From 96d4a20ab7d50b1a36e85622d827ecf4a670f6d4 Mon Sep 17 00:00:00 2001 From: yoruet <1559650411@qq.com> Date: Tue, 24 Sep 2024 20:49:25 +0800 Subject: [PATCH 01/17] add regr_intercept and regr_slope aggregate functions --- .../aggregate_function_regr_intercept.cpp | 37 ++++ .../aggregate_function_regr_intercept.h | 184 ++++++++++++++++++ .../aggregate_function_regr_intercept_impl.h | 1 + .../aggregate_function_regr_slope.cpp | 37 ++++ .../aggregate_function_regr_slope.h | 183 +++++++++++++++++ .../aggregate_function_simple_factory.cpp | 7 +- .../doris/catalog/AggregateFunction.java | 2 +- .../catalog/BuiltinAggregateFunctions.java | 18 +- .../org/apache/doris/catalog/FunctionSet.java | 23 +++ .../functions/agg/RegrIntercept.java | 107 ++++++++++ .../expressions/functions/agg/RegrSlope.java | 108 ++++++++++ .../visitor/AggregateFunctionVisitor.java | 10 + .../agg_function/test_regr_intercept.groovy | 141 ++++++++++++++ .../agg_function/test_regr_slope.groovy | 141 ++++++++++++++ 14 files changed, 989 insertions(+), 10 deletions(-) create mode 100644 be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp create mode 100644 be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h create mode 100644 be/src/vec/aggregate_functions/aggregate_function_regr_intercept_impl.h create mode 100644 be/src/vec/aggregate_functions/aggregate_function_regr_slope.cpp create mode 100644 be/src/vec/aggregate_functions/aggregate_function_regr_slope.h create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrIntercept.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSlope.java create mode 100644 regression-test/suites/nereids_function_p0/agg_function/test_regr_intercept.groovy create mode 100644 regression-test/suites/nereids_function_p0/agg_function/test_regr_slope.groovy diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp new file mode 100644 index 00000000000000..0a93b7113c39b4 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/aggregate_functions/aggregate_function_regr_intercept.h" + +#include + +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/data_types/data_type.h" + +namespace doris::vectorized { + +AggregateFunctionPtr create_aggregate_function_regr_intercept(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { + return std::make_shared(argument_types); +} + +void register_aggregate_function_regr_intercept(AggregateFunctionSimpleFactory& factory) { + factory.register_function("regr_intercept", create_aggregate_function_regr_intercept, false); +} + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h new file mode 100644 index 00000000000000..38b20431bd312a --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.h @@ -0,0 +1,184 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "olap/olap_common.h" +#include "runtime/decimalv2_value.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column.h" +#include "vec/columns/column_nullable.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +struct AggregateFunctionRegrInterceptData { + UInt64 count = 0; + double sum_x = 0.0; + double sum_y = 0.0; + double sum_xx = 0.0; + double sum_xy = 0.0; +}; + +class AggregateFunctionRegrIntercept final + : public IAggregateFunctionDataHelper { +public: + AggregateFunctionRegrIntercept(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper(argument_types_) { + } + + String get_name() const override { return "regr_intercept"; } + + DataTypePtr get_return_type() const override { return std::make_shared(); } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + bool y_valid = true; + bool x_valid = true; + const IColumn* y_column = columns[0]; + const IColumn* x_column = columns[1]; + + if (y_column->is_nullable()) { + const auto& nullable_col = assert_cast(*y_column); + y_valid = !nullable_col.is_null_at(row_num); + y_column = &nullable_col.get_nested_column(); + } + + if (x_column->is_nullable()) { + const auto& nullable_col = assert_cast(*x_column); + x_valid = !nullable_col.is_null_at(row_num); + x_column = &nullable_col.get_nested_column(); + } + + + if (y_valid && x_valid) { + double y_value = 0.0; + double x_value = 0.0; + + // Handle different numeric types for y + if (const auto* float64_col = check_and_get_column(y_column)) { + y_value = float64_col->get_data()[row_num]; + } else if (const auto* float32_col = check_and_get_column(y_column)) { + y_value = static_cast(float32_col->get_data()[row_num]); + } else if (const auto* int64_col = check_and_get_column(y_column)) { + y_value = static_cast(int64_col->get_data()[row_num]); + } else if (const auto* int32_col = check_and_get_column(y_column)) { + y_value = static_cast(int32_col->get_data()[row_num]); + } else { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Unexpected column type for y in regr_intercept"); + } + + // Handle different numeric types for x + if (const auto* float64_col = check_and_get_column(x_column)) { + x_value = float64_col->get_data()[row_num]; + } else if (const auto* float32_col = check_and_get_column(x_column)) { + x_value = static_cast(float32_col->get_data()[row_num]); + } else if (const auto* int64_col = check_and_get_column(x_column)) { + x_value = static_cast(int64_col->get_data()[row_num]); + } else if (const auto* int32_col = check_and_get_column(x_column)) { + x_value = static_cast(int32_col->get_data()[row_num]); + } else { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Unexpected column type for x in regr_intercept"); + } + + data(place).count += 1; + data(place).sum_x += x_value; + data(place).sum_y += y_value; + data(place).sum_xx += x_value * x_value; + data(place).sum_xy += x_value * y_value; + } + } + + void reset(AggregateDataPtr __restrict place) const override { + data(place).count = 0; + data(place).sum_x = 0.0; + data(place).sum_y = 0.0; + data(place).sum_xx = 0.0; + data(place).sum_xy = 0.0; + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + data(place).count += data(rhs).count; + data(place).sum_x += data(rhs).sum_x; + data(place).sum_y += data(rhs).sum_y; + data(place).sum_xx += data(rhs).sum_xx; + data(place).sum_xy += data(rhs).sum_xy; + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + write_var_uint(data(place).count, buf); + write_float_binary(data(place).sum_x, buf); + write_float_binary(data(place).sum_y, buf); + write_float_binary(data(place).sum_xx, buf); + write_float_binary(data(place).sum_xy, buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + read_var_uint(data(place).count, buf); + read_float_binary(data(place).sum_x, buf); + read_float_binary(data(place).sum_y, buf); + read_float_binary(data(place).sum_xx, buf); + read_float_binary(data(place).sum_xy, buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + UInt64 n = data(place).count; + double sum_x = data(place).sum_x; + double sum_y = data(place).sum_y; + double sum_xx = data(place).sum_xx; + double sum_xy = data(place).sum_xy; + + double denominator = n * sum_xx - sum_x * sum_x; + bool result_is_null = (n == 0 || denominator == 0.0); + + double intercept = 0.0; + if (!result_is_null) { + double slope = (n * sum_xy - sum_x * sum_y) / denominator; + intercept = (sum_y - slope * sum_x) / n; + } + + if (to.is_nullable()) { + auto& null_column = assert_cast(to); + null_column.get_null_map_data().push_back(result_is_null); + auto& nested_column = assert_cast(null_column.get_nested_column()); + nested_column.get_data().push_back(intercept); + } else { + if (result_is_null) { + assert_cast(to).get_data().push_back( + std::numeric_limits::quiet_NaN()); + } else { + assert_cast(to).get_data().push_back(intercept); + } + } + } + +}; + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept_impl.h b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept_impl.h new file mode 100644 index 00000000000000..0519ecba6ea913 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept_impl.h @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_slope.cpp b/be/src/vec/aggregate_functions/aggregate_function_regr_slope.cpp new file mode 100644 index 00000000000000..b3a672390bdac2 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_slope.cpp @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/aggregate_functions/aggregate_function_regr_slope.h" + +#include + +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/data_types/data_type.h" + +namespace doris::vectorized { + +AggregateFunctionPtr create_aggregate_function_regr_slope(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { + return std::make_shared(argument_types); +} + +void register_aggregate_function_regr_slope(AggregateFunctionSimpleFactory& factory) { + factory.register_function("regr_slope", create_aggregate_function_regr_slope, false); +} + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_slope.h b/be/src/vec/aggregate_functions/aggregate_function_regr_slope.h new file mode 100644 index 00000000000000..8cbb5d60800509 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_slope.h @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "olap/olap_common.h" +#include "runtime/decimalv2_value.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column.h" +#include "vec/columns/column_nullable.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +struct AggregateFunctionRegrSlopeData { + UInt64 count = 0; + double sum_x = 0.0; + double sum_y = 0.0; + double sum_xx = 0.0; + double sum_xy = 0.0; +}; + +class AggregateFunctionRegrSlope final + : public IAggregateFunctionDataHelper { +public: + AggregateFunctionRegrSlope(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper(argument_types_) { + } + + String get_name() const override { return "regr_slope"; } + + DataTypePtr get_return_type() const override { return std::make_shared(); } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + bool y_valid = true; + bool x_valid = true; + const IColumn* y_column = columns[0]; + const IColumn* x_column = columns[1]; + + if (y_column->is_nullable()) { + const auto& nullable_col = assert_cast(*y_column); + y_valid = !nullable_col.is_null_at(row_num); + y_column = &nullable_col.get_nested_column(); + } + + if (x_column->is_nullable()) { + const auto& nullable_col = assert_cast(*x_column); + x_valid = !nullable_col.is_null_at(row_num); + x_column = &nullable_col.get_nested_column(); + } + + if (y_valid && x_valid) { + double y_value = 0.0; + double x_value = 0.0; + + // Handle different numeric types for y + if (const auto* float64_col = check_and_get_column(y_column)) { + y_value = float64_col->get_data()[row_num]; + } else if (const auto* float32_col = check_and_get_column(y_column)) { + y_value = static_cast(float32_col->get_data()[row_num]); + } else if (const auto* int64_col = check_and_get_column(y_column)) { + y_value = static_cast(int64_col->get_data()[row_num]); + } else if (const auto* int32_col = check_and_get_column(y_column)) { + y_value = static_cast(int32_col->get_data()[row_num]); + } else { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Unexpected column type for y in regr_slope"); + } + + // Handle different numeric types for x + if (const auto* float64_col = check_and_get_column(x_column)) { + x_value = float64_col->get_data()[row_num]; + } else if (const auto* float32_col = check_and_get_column(x_column)) { + x_value = static_cast(float32_col->get_data()[row_num]); + } else if (const auto* int64_col = check_and_get_column(x_column)) { + x_value = static_cast(int64_col->get_data()[row_num]); + } else if (const auto* int32_col = check_and_get_column(x_column)) { + x_value = static_cast(int32_col->get_data()[row_num]); + } else { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Unexpected column type for x in regr_slope"); + } + + data(place).count += 1; + data(place).sum_x += x_value; + data(place).sum_y += y_value; + data(place).sum_xx += x_value * x_value; + data(place).sum_xy += x_value * y_value; + } + } + + void reset(AggregateDataPtr __restrict place) const override { + data(place).count = 0; + data(place).sum_x = 0.0; + data(place).sum_y = 0.0; + data(place).sum_xx = 0.0; + data(place).sum_xy = 0.0; + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + data(place).count += data(rhs).count; + data(place).sum_x += data(rhs).sum_x; + data(place).sum_y += data(rhs).sum_y; + data(place).sum_xx += data(rhs).sum_xx; + data(place).sum_xy += data(rhs).sum_xy; + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + write_var_uint(data(place).count, buf); + write_float_binary(data(place).sum_x, buf); + write_float_binary(data(place).sum_y, buf); + write_float_binary(data(place).sum_xx, buf); + write_float_binary(data(place).sum_xy, buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + read_var_uint(data(place).count, buf); + read_float_binary(data(place).sum_x, buf); + read_float_binary(data(place).sum_y, buf); + read_float_binary(data(place).sum_xx, buf); + read_float_binary(data(place).sum_xy, buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + UInt64 n = data(place).count; + double sum_x = data(place).sum_x; + double sum_y = data(place).sum_y; + double sum_xx = data(place).sum_xx; + double sum_xy = data(place).sum_xy; + + double denominator = n * sum_xx - sum_x * sum_x; + bool result_is_null = (n == 0 || denominator == 0.0); + + double slope = 0.0; + if (!result_is_null) { + slope = (n * sum_xy - sum_x * sum_y) / denominator; + } + + if (to.is_nullable()) { + auto& null_column = assert_cast(to); + null_column.get_null_map_data().push_back(result_is_null); + auto& nested_column = assert_cast(null_column.get_nested_column()); + nested_column.get_data().push_back(slope); + } else { + if (result_is_null) { + assert_cast(to).get_data().push_back( + std::numeric_limits::quiet_NaN()); + } else { + assert_cast(to).get_data().push_back(slope); + } + } + } + +}; + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 436691c6ef2aad..11a926f71e4194 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -56,6 +56,8 @@ void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& fact void register_aggregate_function_percentile_old(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_window_funnel_old(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_regr_intercept(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_regr_slope(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory& factory); @@ -89,8 +91,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_group_concat(instance); register_aggregate_function_quantile_state(instance); register_aggregate_function_combinator_distinct(instance); - register_aggregate_function_reader_load( - instance); // register aggregate function for agg reader + register_aggregate_function_reader_load(instance); // register aggregate function for agg reader register_aggregate_function_window_rank(instance); register_aggregate_function_stddev_variance_pop(instance); register_aggregate_function_topn(instance); @@ -100,6 +101,8 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_percentile_approx(instance); register_aggregate_function_window_funnel(instance); register_aggregate_function_window_funnel_old(instance); + register_aggregate_function_regr_intercept(instance); + register_aggregate_function_regr_slope(instance); register_aggregate_function_retention(instance); register_aggregate_function_orthogonal_bitmap(instance); register_aggregate_function_collect_list(instance); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index 2786dc6470e39d..df9c792aec6c41 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -56,7 +56,7 @@ public class AggregateFunction extends Function { "ndv_no_finalize", "percentile_array", "histogram", FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG, FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET, FunctionSet.GROUP_ARRAY_INTERSECT, - FunctionSet.SUM0, FunctionSet.MULTI_DISTINCT_SUM0); + FunctionSet.SUM0, FunctionSet.MULTI_DISTINCT_SUM0, FunctionSet.REGR_INTERCEPT, FunctionSet.REGR_SLOPE); public static ImmutableSet ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx", "first_value", diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java index 28b1352eaf4551..f517770400b5dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java @@ -61,6 +61,8 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileApproxWeighted; import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray; import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.RegrIntercept; +import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSlope; import org.apache.doris.nereids.trees.expressions.functions.agg.Retention; import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount; import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch; @@ -127,13 +129,15 @@ public class BuiltinAggregateFunctions implements FunctionHelper { agg(Ndv.class, "approx_count_distinct", "ndv"), agg(OrthogonalBitmapIntersect.class, "orthogonal_bitmap_intersect"), agg(OrthogonalBitmapIntersectCount.class, "orthogonal_bitmap_intersect_count"), - agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"), - agg(Percentile.class, "percentile"), - agg(PercentileApprox.class, "percentile_approx"), - agg(PercentileApproxWeighted.class, "percentile_approx_weighted"), - agg(PercentileArray.class, "percentile_array"), - agg(QuantileUnion.class, "quantile_union"), - agg(Retention.class, "retention"), + agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"), + agg(Percentile.class, "percentile"), + agg(PercentileApprox.class, "percentile_approx"), + agg(PercentileApproxWeighted.class, "percentile_approx_weighted"), + agg(PercentileArray.class, "percentile_array"), + agg(QuantileUnion.class, "quantile_union"), + agg(RegrIntercept.class, "regr_intercept"), + agg(RegrSlope.class, "regr_slope"), + agg(Retention.class, "retention"), agg(SequenceCount.class, "sequence_count"), agg(SequenceMatch.class, "sequence_match"), agg(Stddev.class, "stddev_pop", "stddev"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 4165f9362214ed..bf2a623f86db6d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -631,6 +631,10 @@ public void addBuiltinBothScalaAndVectorized(Function fn) { public static final String RETENTION = "retention"; + public static final String REGR_INTERCEPT = "regr_intercept"; + + public static final String REGR_SLOPE = "regr_slope"; + public static final String SEQUENCE_MATCH = "sequence_match"; public static final String SEQUENCE_COUNT = "sequence_count"; @@ -664,6 +668,25 @@ private void initAggregateBuiltins() { "", null, false, true, true, true)); + // regr_intercept + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.REGR_INTERCEPT, + Lists.newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.DOUBLE, + "", + "", + "", + null, null, + "", + null, false, true, true, true)); + // regr_slope + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.REGR_SLOPE, + Lists.newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.DOUBLE, + "", + "", + "", + null, null, + "", + null, false, true, true, true)); + // count(array/map/struct) for (Type complexType : Lists.newArrayList(Type.ARRAY, Type.MAP, Type.GENERIC_STRUCT)) { addBuiltin(AggregateFunction.createBuiltin(FunctionSet.COUNT, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrIntercept.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrIntercept.java new file mode 100644 index 00000000000000..d6855023faac0d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrIntercept.java @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic; +import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * AggregateFunction 'regr_intercept'. + */ +public class RegrIntercept extends NullableAggregateFunction + implements BinaryExpression, ExplicitlyCastableSignature, SupportWindowAnalytic { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, LargeIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD)); + + /** + * Constructor with 2 arguments. + */ + public RegrIntercept(Expression arg1, Expression arg2) { + this(false, arg1, arg2); + } + + /** + * Constructor with distinct flag and 2 arguments. + */ + public RegrIntercept(boolean distinct, Expression arg1, Expression arg2) { + this(distinct, false, arg1, arg2); + } + + private RegrIntercept(boolean distinct, boolean alwaysNullable, Expression arg1, Expression arg2) { + super("regr_intercept", distinct, alwaysNullable, arg1, arg2); + } + + @Override + public RegrIntercept withDistinctAndChildren(boolean distinct, List children) { + Preconditions.checkArgument(children.size() == 2); + return new RegrIntercept(distinct, alwaysNullable, children.get(0), children.get(1)); + } + + @Override + public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) { + return new RegrIntercept(distinct, alwaysNullable, child(0), child(1)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitNullableAggregateFunction(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } + + @Override + public void checkLegalityBeforeTypeCoercion() throws AnalysisException { + DataType arg0Type = left().getDataType(); + DataType arg1Type = right().getDataType(); + if ((!arg0Type.isNumericType() && !arg0Type.isNullType()) + || arg0Type.isOnlyMetricType()) { + throw new AnalysisException("regr_intercept requires numeric for first parameter: " + toSql()); + } else if ((!arg1Type.isNumericType() && !arg1Type.isNullType()) + || arg1Type.isOnlyMetricType()) { + throw new AnalysisException("regr_intercept requires numeric for second parameter: " + toSql()); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSlope.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSlope.java new file mode 100644 index 00000000000000..e396a10cc035f7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSlope.java @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic; +import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * AggregateFunction 'regr_slope'. + */ + +public class RegrSlope extends NullableAggregateFunction + implements BinaryExpression, ExplicitlyCastableSignature, SupportWindowAnalytic { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, LargeIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD)); + + /** + * Constructor with 2 arguments. + */ + public RegrSlope(Expression arg1, Expression arg2) { + this(false, arg1, arg2); + } + + /** + * Constructor with distinct flag and 2 arguments. + */ + public RegrSlope(boolean distinct, Expression arg1, Expression arg2) { + this(distinct, false, arg1, arg2); + } + + private RegrSlope(boolean distinct, boolean alwaysNullable, Expression arg1, Expression arg2) { + super("regr_slope", distinct, alwaysNullable, arg1, arg2); + } + + @Override + public RegrSlope withDistinctAndChildren(boolean distinct, List children) { + Preconditions.checkArgument(children.size() == 2); + return new RegrSlope(distinct, alwaysNullable, children.get(0), children.get(1)); + } + + @Override + public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) { + return new RegrSlope(distinct, alwaysNullable, child(0), child(1)); + } + + @Override + public void checkLegalityBeforeTypeCoercion() throws AnalysisException { + DataType arg0Type = left().getDataType(); + DataType arg1Type = right().getDataType(); + if ((!arg0Type.isNumericType() && !arg0Type.isNullType()) + || arg0Type.isOnlyMetricType()) { + throw new AnalysisException("regr_slope requires numeric for first parameter: " + toSql()); + } else if ((!arg1Type.isNumericType() && !arg1Type.isNullType()) + || arg1Type.isOnlyMetricType()) { + throw new AnalysisException("regr_slope requires numeric for second parameter: " + toSql()); + } + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitNullableAggregateFunction(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index abe8044c28c342..3f8c157b376bd6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -63,6 +63,8 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileApproxWeighted; import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray; import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.RegrIntercept; +import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSlope; import org.apache.doris.nereids.trees.expressions.functions.agg.Retention; import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount; import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch; @@ -266,6 +268,14 @@ default R visitQuantileUnion(QuantileUnion quantileUnion, C context) { return visitAggregateFunction(quantileUnion, context); } + default R visitRegrIntercept(RegrIntercept regrIntercept, C context) { + return visitAggregateFunction(regrIntercept, context); + } + + default R visitRegrSlope(RegrSlope regrSlope, C context) { + return visitAggregateFunction(regrSlope, context); + } + default R visitRetention(Retention retention, C context) { return visitNullableAggregateFunction(retention, context); } diff --git a/regression-test/suites/nereids_function_p0/agg_function/test_regr_intercept.groovy b/regression-test/suites/nereids_function_p0/agg_function/test_regr_intercept.groovy new file mode 100644 index 00000000000000..f2402f51d4b3dc --- /dev/null +++ b/regression-test/suites/nereids_function_p0/agg_function/test_regr_intercept.groovy @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_regr_intercept") { + sql """ DROP TABLE IF EXISTS test_regr_intercept_int """ + sql """ DROP TABLE IF EXISTS test_regr_intercept_double """ + sql """ DROP TABLE IF EXISTS test_regr_intercept_nullable_col """ + + + sql """ SET enable_nereids_planner=true """ + sql """ SET enable_fallback_to_original_planner=false """ + + sql """ + CREATE TABLE test_regr_intercept_int ( + `id` int, + `x` int, + `y` int + ) ENGINE=OLAP + DUPLICATE KEY (`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 4 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """ + CREATE TABLE test_regr_intercept_double ( + `id` int, + `x` double, + `y` double + ) ENGINE=OLAP + DUPLICATE KEY (`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 4 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """ + CREATE TABLE test_regr_intercept_nullable_col ( + `id` int, + `x` int, + `y` int + ) ENGINE=OLAP + DUPLICATE KEY (`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 4 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + // no value + qt_sql "select regr_intercept(y,x) from test_regr_intercept_int" + sql """ TRUNCATE TABLE test_regr_intercept_int """ + + sql """ + INSERT INTO test_regr_intercept_int VALUES + (1, 18, 13), + (2, 14, 27), + (3, 12, 2), + (4, 5, 6), + (5, 10, 20); + """ + + sql """ + INSERT INTO test_regr_intercept_double VALUES + (1, 18.27123456, 13.27123456), + (2, 14.65890846, 27.65890846), + (3, 12.25345846, 2.253458468), + (4, 5.890846835, 6.890846835), + (5, 10.14345678, 20.14345678); + """ + + sql """ + INSERT INTO test_regr_intercept_nullable_col VALUES + (1, 18, 13), + (2, 14, 27), + (3, 5, 7), + (4, 10, 20); + """ + + // value is null + sql """SELECT regr_intercept(NULL, NULL);""" + + // parameter is literal and columns + qt_sql "select regr_intercept(10,x) from test_regr_intercept_int" + + // literal and column + qt_sql "select regr_intercept(4,x) from test_regr_intercept_int" + + // int value + qt_sql "select regr_intercept(y,x) from test_regr_intercept_int" + sql """ TRUNCATE TABLE test_regr_intercept_int """ + + // double value + qt_sql "select regr_intercept(y,x) from test_regr_intercept_double" + sql """ TRUNCATE TABLE test_regr_intercept_double """ + + // nullable and non_nullable + qt_sql "select regr_intercept(y,non_nullable(x)) from test_regr_intercept_nullable_col" + + // non_nullable and nullable + qt_sql "select regr_intercept(non_nullable(y),x) from test_regr_intercept_nullable_col" + + // non_nullable and non_nullable + qt_sql "select regr_intercept(non_nullable(y),non_nullable(x)) from test_regr_intercept_nullable_col" + sql """ TRUNCATE TABLE test_regr_intercept_nullable_col """ + + // exception test + test{ + sql """select regr_intercept('range', 1);""" + exception "regr_intercept requires numeric for first parameter" + } + + test{ + sql """select regr_intercept(1, 'hello');""" + exception "regr_intercept requires numeric for second parameter" + } + + test{ + sql """select regr_intercept(y, 'hello') from test_regr_intercept_int;""" + exception "regr_intercept requires numeric for second parameter" + } + + test{ + sql """select regr_intercept(1, true);""" + exception "regr_intercept requires numeric for second parameter" + } + +} diff --git a/regression-test/suites/nereids_function_p0/agg_function/test_regr_slope.groovy b/regression-test/suites/nereids_function_p0/agg_function/test_regr_slope.groovy new file mode 100644 index 00000000000000..783a822c5da113 --- /dev/null +++ b/regression-test/suites/nereids_function_p0/agg_function/test_regr_slope.groovy @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_regr_slope") { + sql """ DROP TABLE IF EXISTS test_regr_slope_int """ + sql """ DROP TABLE IF EXISTS test_regr_slope_double """ + sql """ DROP TABLE IF EXISTS test_regr_slope_nullable_col """ + + + sql """ SET enable_nereids_planner=true """ + sql """ SET enable_fallback_to_original_planner=false """ + + sql """ + CREATE TABLE test_regr_slope_int ( + `id` int, + `x` int, + `y` int + ) ENGINE=OLAP + DUPLICATE KEY (`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 4 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """ + CREATE TABLE test_regr_slope_double ( + `id` int, + `x` double, + `y` double + ) ENGINE=OLAP + DUPLICATE KEY (`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 4 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """ + CREATE TABLE test_regr_slope_nullable_col ( + `id` int, + `x` int, + `y` int + ) ENGINE=OLAP + DUPLICATE KEY (`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 4 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + // no value + qt_sql "select regr_slope(y,x) from test_regr_slope_int" + sql """ TRUNCATE TABLE test_regr_slope_int """ + + sql """ + INSERT INTO test_regr_slope_int VALUES + (1, 18, 13), + (2, 14, 27), + (3, 12, 2), + (4, 5, 6), + (5, 10, 20); + """ + + sql """ + INSERT INTO test_regr_slope_double VALUES + (1, 18.27123456, 13.27123456), + (2, 14.65890846, 27.65890846), + (3, 12.25345846, 2.253458468), + (4, 5.890846835, 6.890846835), + (5, 10.14345678, 20.14345678); + """ + + sql """ + INSERT INTO test_regr_slope_nullable_col VALUES + (1, 18, 13), + (2, 14, 27), + (3, 5, 7), + (4, 10, 20); + """ + + // value is null + sql """SELECT regr_slope(NULL, NULL);""" + + // parameter is literal and columns + qt_sql "select regr_slope(10,x) from test_regr_slope_int" + + // literal and column + qt_sql "select regr_slope(4,x) from test_regr_slope_int" + + // int value + qt_sql "select regr_slope(y,x) from test_regr_slope_int" + sql """ TRUNCATE TABLE test_regr_slope_int """ + + // double value + qt_sql "select regr_slope(y,x) from test_regr_slope_double" + sql """ TRUNCATE TABLE test_regr_slope_double """ + + // nullable and non_nullable + qt_sql "select regr_slope(y,non_nullable(x)) from test_regr_slope_nullable_col" + + // non_nullable and nullable + qt_sql "select regr_slope(non_nullable(y),x) from test_regr_slope_nullable_col" + + // non_nullable and non_nullable + qt_sql "select regr_slope(non_nullable(y),non_nullable(x)) from test_regr_slope_nullable_col" + sql """ TRUNCATE TABLE test_regr_slope_nullable_col """ + + // exception test + test{ + sql """select regr_slope('range', 1);""" + exception "regr_slope requires numeric for first parameter" + } + + test{ + sql """select regr_slope(1, 'hello');""" + exception "regr_slope requires numeric for second parameter" + } + + test{ + sql """select regr_slope(y, 'hello') from test_regr_slope_int;""" + exception "regr_slope requires numeric for second parameter" + } + + test{ + sql """select regr_slope(1, true);""" + exception "regr_slope requires numeric for second parameter" + } + +} From 5882639cb88b656af00e76de2377afaf27058e5b Mon Sep 17 00:00:00 2001 From: yoruet <1559650411@qq.com> Date: Tue, 24 Sep 2024 21:30:25 +0800 Subject: [PATCH 02/17] delete unused files --- .../aggregate_functions/aggregate_function_regr_intercept_impl.h | 1 - 1 file changed, 1 deletion(-) delete mode 100644 be/src/vec/aggregate_functions/aggregate_function_regr_intercept_impl.h diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept_impl.h b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept_impl.h deleted file mode 100644 index 0519ecba6ea913..00000000000000 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept_impl.h +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file From 7d3d24ea4ac8af48c3d7d19fecd5606700d1d840 Mon Sep 17 00:00:00 2001 From: yoruet <1559650411@qq.com> Date: Tue, 24 Sep 2024 23:19:20 +0800 Subject: [PATCH 03/17] fix bug in regr_slope and regr_intercept --- .../aggregate_function_regr_intercept.cpp | 34 ++- .../aggregate_function_regr_intercept.h | 217 ++++++++---------- .../aggregate_function_regr_slope.cpp | 34 ++- .../aggregate_function_regr_slope.h | 214 ++++++++--------- .../functions/agg/RegrIntercept.java | 30 +-- .../expressions/functions/agg/RegrSlope.java | 26 +-- .../agg_function/test_regr_intercept.out | 25 ++ .../agg_function/test_regr_slope.out | 25 ++ 8 files changed, 332 insertions(+), 273 deletions(-) create mode 100644 regression-test/data/nereids_function_p0/agg_function/test_regr_intercept.out create mode 100644 regression-test/data/nereids_function_p0/agg_function/test_regr_slope.out diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp index 0a93b7113c39b4..c785ee8bb2421f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_regr_intercept.cpp @@ -14,24 +14,44 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - -#include "vec/aggregate_functions/aggregate_function_regr_intercept.h" +#include #include +#include "common/logging.h" +#include "vec/aggregate_functions/aggregate_function_regr_intercept.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" #include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" namespace doris::vectorized { +template