diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp b/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp
index 738d777441c360..ed3d32cd5261b9 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.cpp
@@ -89,5 +89,8 @@ void register_aggregate_function_regr_union(AggregateFunctionSimpleFactory& fact
     factory.register_function_both("regr_slope", create_aggregate_function_regr<RegrSlopeFunc>);
     factory.register_function_both("regr_intercept",
                                    create_aggregate_function_regr<RegrInterceptFunc>);
+    factory.register_function_both("regr_sxx", create_aggregate_function_regr<RegrSxxFunc>);
+    factory.register_function_both("regr_sxy", create_aggregate_function_regr<RegrSxyFunc>);
+    factory.register_function_both("regr_syy", create_aggregate_function_regr<RegrSyyFunc>);
 }
 } // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
index a95daaf0d840cd..bae0d6152b9ffe 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_regr_union.h
@@ -98,6 +98,189 @@ struct AggregateFunctionRegrData {
         return slope;
     }
 };
+
+/// State for regr_sxx(y, x): Sxx = sum(x*x) - sum(x)*sum(x)/count.
+template <typename T>
+struct AggregateFunctionRegrSxxData {
+    using Type = T;
+    UInt64 count = 0;
+    Float64 sum_x {};
+    Float64 sum_of_x_squared {};
+
+    // The field order written here must match read().
+    void write(BufferWritable& buf) const {
+        write_binary(sum_x, buf);
+        write_binary(sum_of_x_squared, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_x, buf);
+        read_binary(sum_of_x_squared, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_x = {};
+        sum_of_x_squared = {};
+        count = 0;
+    }
+
+    void merge(const AggregateFunctionRegrSxxData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_x += rhs.sum_x;
+        sum_of_x_squared += rhs.sum_of_x_squared;
+        count += rhs.count;
+    }
+
+    // Only x contributes to Sxx; y is accepted to keep the (y, x) call shape.
+    void add(T /*value_y*/, T value_x) {
+        // Widen first: squaring in T can overflow (integers) or lose precision (float).
+        const auto x = static_cast<Float64>(value_x);
+        sum_x += x;
+        sum_of_x_squared += x * x;
+        count += 1;
+    }
+
+    Float64 get_regr_sxx_result() const {
+        // count == 0 gives 0.0 / 0 = NaN here; an empty input is expected to be reported
+        // as NULL by the caller (TODO confirm), so that value should never be surfaced.
+        return sum_of_x_squared - (sum_x * sum_x / static_cast<Float64>(count));
+    }
+};
+
+/// State for regr_sxy(y, x): Sxy = sum(x*y) - sum(x)*sum(y)/count.
+template <typename T>
+struct AggregateFunctionRegrSxyData {
+    using Type = T;
+    UInt64 count = 0;
+    Float64 sum_x {};
+    Float64 sum_y {};
+    Float64 sum_of_x_mul_y {};
+
+    // The field order written here must match read().
+    void write(BufferWritable& buf) const {
+        write_binary(sum_x, buf);
+        write_binary(sum_y, buf);
+        write_binary(sum_of_x_mul_y, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_x, buf);
+        read_binary(sum_y, buf);
+        read_binary(sum_of_x_mul_y, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_x = {};
+        sum_y = {};
+        sum_of_x_mul_y = {};
+        count = 0;
+    }
+
+    void merge(const AggregateFunctionRegrSxyData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_x += rhs.sum_x;
+        sum_y += rhs.sum_y;
+        sum_of_x_mul_y += rhs.sum_of_x_mul_y;
+        count += rhs.count;
+    }
+
+    void add(T value_y, T value_x) {
+        // Widen first: multiplying in T can overflow (integers) or lose precision (float).
+        const auto x = static_cast<Float64>(value_x);
+        const auto y = static_cast<Float64>(value_y);
+        sum_x += x;
+        sum_y += y;
+        sum_of_x_mul_y += x * y;
+        count += 1;
+    }
+
+    Float64 get_regr_sxy_result() const {
+        // count == 0 gives 0.0 / 0 = NaN here; an empty input is expected to be reported
+        // as NULL by the caller (TODO confirm), so that value should never be surfaced.
+        return sum_of_x_mul_y - (sum_x * sum_y / static_cast<Float64>(count));
+    }
+};
+
+/// State for regr_syy(y, x): Syy = sum(y*y) - sum(y)*sum(y)/count.
+template <typename T>
+struct AggregateFunctionRegrSyyData {
+    using Type = T;
+    UInt64 count = 0;
+    Float64 sum_y {};
+    Float64 sum_of_y_squared {};
+
+    // The field order written here must match read().
+    void write(BufferWritable& buf) const {
+        write_binary(sum_y, buf);
+        write_binary(sum_of_y_squared, buf);
+        write_binary(count, buf);
+    }
+
+    void read(BufferReadable& buf) {
+        read_binary(sum_y, buf);
+        read_binary(sum_of_y_squared, buf);
+        read_binary(count, buf);
+    }
+
+    void reset() {
+        sum_y = {};
+        sum_of_y_squared = {};
+        count = 0;
+    }
+
+    void merge(const AggregateFunctionRegrSyyData& rhs) {
+        if (rhs.count == 0) {
+            return;
+        }
+        sum_y += rhs.sum_y;
+        sum_of_y_squared += rhs.sum_of_y_squared;
+        count += rhs.count;
+    }
+
+    // Only y contributes to Syy; x is accepted to keep the (y, x) call shape.
+    void add(T value_y, T /*value_x*/) {
+        // Widen first: squaring in T can overflow (integers) or lose precision (float).
+        const auto y = static_cast<Float64>(value_y);
+        sum_y += y;
+        sum_of_y_squared += y * y;
+        count += 1;
+    }
+
+    Float64 get_regr_syy_result() const {
+        // count == 0 gives 0.0 / 0 = NaN here; an empty input is expected to be reported
+        // as NULL by the caller (TODO confirm), so that value should never be surfaced.
+        return sum_of_y_squared - (sum_y * sum_y / static_cast<Float64>(count));
+    }
+};
+
+template <typename T>
+struct RegrSxxFunc : AggregateFunctionRegrSxxData<T> {
+    static constexpr const char* name = "regr_sxx";
+
+    Float64 get_result() const { return this->get_regr_sxx_result(); }
+};
+
+template <typename T>
+struct RegrSxyFunc : AggregateFunctionRegrSxyData<T> {
+    static constexpr const char* name = "regr_sxy";
+
+    Float64 get_result() const { return this->get_regr_sxy_result(); }
+};
+
+template <typename T>
+struct RegrSyyFunc : AggregateFunctionRegrSyyData<T> {
+    static constexpr const char* name = "regr_syy";
+
+    Float64 get_result() const { return this->get_regr_syy_result(); }
+};
 
 template <typename T>
 struct RegrSlopeFunc : AggregateFunctionRegrData<T> {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 0a31fd6c5763ac..58bfce1b5790f6 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -58,6 +58,7 @@ void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& f
 void register_aggregate_function_window_funnel_old(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_regr_union(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_retention(AggregateFunctionSimpleFactory& factory);
+void register_aggregate_function_regr_mixed(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory);
 void
register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& factory);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
index df9c792aec6c41..e71cfee52819dd 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
@@ -55,6 +55,7 @@ public class AggregateFunction extends Function {
             "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT,
             "ndv_no_finalize", "percentile_array", "histogram", FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG,
             FunctionSet.BITMAP_AGG, FunctionSet.ARRAY_AGG,
+            FunctionSet.REGR_SXX, FunctionSet.REGR_SYY, FunctionSet.REGR_SXY,
             FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET, FunctionSet.GROUP_ARRAY_INTERSECT,
             FunctionSet.SUM0, FunctionSet.MULTI_DISTINCT_SUM0, FunctionSet.REGR_INTERCEPT, FunctionSet.REGR_SLOPE);
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
index 3611764886b48a..718cf92ebba95a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java
@@ -64,6 +64,9 @@
 import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
 import org.apache.doris.nereids.trees.expressions.functions.agg.RegrIntercept;
 import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSlope;
+import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSxx;
+import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSxy;
+import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSyy;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Retention;
 import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount;
 import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch;
@@ -139,6 +142,9 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
             agg(QuantileUnion.class, "quantile_union"),
             agg(RegrIntercept.class, "regr_intercept"),
             agg(RegrSlope.class, "regr_slope"),
+            agg(RegrSxx.class, "regr_sxx"),
+            agg(RegrSxy.class, "regr_sxy"),
+            agg(RegrSyy.class, "regr_syy"),
             agg(Retention.class, "retention"),
             agg(SequenceCount.class, "sequence_count"),
             agg(SequenceMatch.class, "sequence_match"),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index c9110bb0e13a0c..2089d489bcece5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -635,6 +635,12 @@ public void addBuiltinBothScalaAndVectorized(Function fn) {
 
     public static final String REGR_SLOPE = "regr_slope";
 
+    public static final String REGR_SXX = "regr_sxx";
+
+    public static final String REGR_SXY = "regr_sxy";
+
+    public static final String REGR_SYY = "regr_syy";
+
     public static final String SEQUENCE_MATCH = "sequence_match";
 
     public static final String SEQUENCE_COUNT = "sequence_count";
@@ -708,6 +714,19 @@ private void initAggregateBuiltins() {
                     null, false, true, true, true));
         }
 
+        // regr_sxx, regr_sxy and regr_syy are registered with one identical (DOUBLE, DOUBLE) shape.
+        for (String regrFnName : Lists.newArrayList(FunctionSet.REGR_SXX, FunctionSet.REGR_SXY,
+                FunctionSet.REGR_SYY)) {
+            addBuiltin(AggregateFunction.createBuiltin(regrFnName,
+                    Lists.newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.DOUBLE,
+                    "",
+                    "",
+                    "",
+                    null, null,
+                    "",
+                    null, false, false, false, true));
+        }
+
         // Vectorization does not need symbol any more, we should clean it in the future.
         addBuiltin(AggregateFunction.createBuiltin(FunctionSet.WINDOW_FUNNEL,
                 Lists.newArrayList(Type.BIGINT, Type.STRING, Type.DATETIME, Type.BOOLEAN),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSxx.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSxx.java
new file mode 100644
index 00000000000000..d9dc1f81f03919
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSxx.java
@@ -0,0 +1,94 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
+import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/** Nereids expression for the two-argument regr_sxx aggregate; the result is a nullable DOUBLE. */
+public class RegrSxx extends AggregateFunction
+        implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE));
+
+    /**
+     * Creates a non-distinct regr_sxx over the two argument expressions.
+     */
+    public RegrSxx(Expression arg1, Expression arg2) {
+        this(false, arg1, arg2);
+    }
+
+    /**
+     * Creates regr_sxx over the two argument expressions with an explicit distinct flag.
+     */
+    public RegrSxx(boolean distinct, Expression arg1, Expression arg2) {
+        super("regr_sxx", distinct, arg1, arg2);
+    }
+
+    @Override
+    public void checkLegalityBeforeTypeCoercion() throws AnalysisException {
+        DataType arg0Type = left().getDataType();
+        DataType arg1Type = right().getDataType();
+        if ((!arg0Type.isNumericType() && !arg0Type.isNullType())
+                || arg0Type.isOnlyMetricType()) {
+            throw new AnalysisException("regr_sxx requires numeric for first parameter: " + toSql());
+        } else if ((!arg1Type.isNumericType() && !arg1Type.isNullType())
+                || arg1Type.isOnlyMetricType()) {
+            throw new AnalysisException("regr_sxx requires numeric for second parameter: " + toSql());
+        }
+    }
+
+    @Override
+    public RegrSxx withDistinctAndChildren(boolean distinct, List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new RegrSxx(distinct, children.get(0), children.get(1));
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitRegrSxx(this, context);
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSxy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSxy.java
new file mode 100644
index 00000000000000..768279ea25b2ef
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSxy.java
@@ -0,0 +1,94 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
+import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/** Nereids expression for the two-argument regr_sxy aggregate; the result is a nullable DOUBLE. */
+public class RegrSxy extends AggregateFunction
+        implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE));
+
+    /**
+     * Creates a non-distinct regr_sxy over the two argument expressions.
+     */
+    public RegrSxy(Expression arg1, Expression arg2) {
+        this(false, arg1, arg2);
+    }
+
+    /**
+     * Creates regr_sxy over the two argument expressions with an explicit distinct flag.
+     */
+    public RegrSxy(boolean distinct, Expression arg1, Expression arg2) {
+        super("regr_sxy", distinct, arg1, arg2);
+    }
+
+    @Override
+    public void checkLegalityBeforeTypeCoercion() throws AnalysisException {
+        DataType arg0Type = left().getDataType();
+        DataType arg1Type = right().getDataType();
+        if ((!arg0Type.isNumericType() && !arg0Type.isNullType())
+                || arg0Type.isOnlyMetricType()) {
+            throw new AnalysisException("regr_sxy requires numeric for first parameter: " + toSql());
+        } else if ((!arg1Type.isNumericType() && !arg1Type.isNullType())
+                || arg1Type.isOnlyMetricType()) {
+            throw new AnalysisException("regr_sxy requires numeric for second parameter: " + toSql());
+        }
+    }
+
+    @Override
+    public RegrSxy withDistinctAndChildren(boolean distinct, List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new RegrSxy(distinct, children.get(0), children.get(1));
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitRegrSxy(this, context);
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSyy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSyy.java
new file mode 100644
index 00000000000000..8bb4860bf10039
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/RegrSyy.java
@@ -0,0 +1,94 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.agg;
+
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
+import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/** Nereids expression for the two-argument regr_syy aggregate; the result is a nullable DOUBLE. */
+public class RegrSyy extends AggregateFunction
+        implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
+            FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE));
+
+    /**
+     * Creates a non-distinct regr_syy over the two argument expressions.
+     */
+    public RegrSyy(Expression arg1, Expression arg2) {
+        this(false, arg1, arg2);
+    }
+
+    /**
+     * Creates regr_syy over the two argument expressions with an explicit distinct flag.
+     */
+    public RegrSyy(boolean distinct, Expression arg1, Expression arg2) {
+        super("regr_syy", distinct, arg1, arg2);
+    }
+
+    @Override
+    public void checkLegalityBeforeTypeCoercion() throws AnalysisException {
+        DataType arg0Type = left().getDataType();
+        DataType arg1Type = right().getDataType();
+        if ((!arg0Type.isNumericType() && !arg0Type.isNullType())
+                || arg0Type.isOnlyMetricType()) {
+            throw new AnalysisException("regr_syy requires numeric for first parameter: " + toSql());
+        } else if ((!arg1Type.isNumericType() && !arg1Type.isNullType())
+                || arg1Type.isOnlyMetricType()) {
+            throw new AnalysisException("regr_syy requires numeric for second parameter: " + toSql());
+        }
+    }
+
+    @Override
+    public RegrSyy withDistinctAndChildren(boolean distinct, List<Expression> children) {
+        Preconditions.checkArgument(children.size() == 2);
+        return new RegrSyy(distinct, children.get(0), children.get(1));
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitRegrSyy(this, context);
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return SIGNATURES;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
index 4f9f9e5e3643f5..a95278537fa6d2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
@@ -66,6 +66,9 @@
 import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
 import org.apache.doris.nereids.trees.expressions.functions.agg.RegrIntercept;
 import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSlope;
+import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSxx;
+import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSxy;
+import org.apache.doris.nereids.trees.expressions.functions.agg.RegrSyy;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Retention;
 import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount;
 import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch;
@@ -282,6 +285,18 @@ default R visitRegrSlope(RegrSlope regrSlope, C context) {
         return visitAggregateFunction(regrSlope, context);
     }
 
+    default R visitRegrSxx(RegrSxx regrSxx, C context) { // default: defer to the generic aggregate visit
+        return visitAggregateFunction(regrSxx, context);
+    }
+
+    default R visitRegrSxy(RegrSxy regrSxy, C context) { // default: defer to the generic aggregate visit
+        return visitAggregateFunction(regrSxy, context);
+    }
+
+    default R visitRegrSyy(RegrSyy regrSyy, C context) { // default: defer to the generic aggregate visit
+        return visitAggregateFunction(regrSyy, context);
+    }
+
     default R visitRetention(Retention retention, C context) {
         return visitNullableAggregateFunction(retention, context);
     }
diff --git a/regression-test/data/nereids_function_p0/agg_function/test_regr_sxx.out b/regression-test/data/nereids_function_p0/agg_function/test_regr_sxx.out
new file mode 100644
index 00000000000000..0f4b68c6f1b1c8
--- /dev/null
+++ b/regression-test/data/nereids_function_p0/agg_function/test_regr_sxx.out
@@ -0,0 +1,24 @@
+-- !sql --
+\N
+
+-- !sql --
+92.79999999999995
+
+-- !sql --
+92.79999999999995
+
+-- !sql --
+92.79999999999995
+
+-- !sql --
+86.93426485321584
+
+-- !sql --
+92.75
+
+-- !sql --
+92.75
+
+-- !sql --
+92.75
+
diff --git a/regression-test/data/nereids_function_p0/agg_function/test_regr_sxy.out b/regression-test/data/nereids_function_p0/agg_function/test_regr_sxy.out
new file mode 100644
index 00000000000000..6b33d373bbd15b
--- /dev/null
+++ b/regression-test/data/nereids_function_p0/agg_function/test_regr_sxy.out
@@ -0,0 +1,30 @@
+-- !sql --
+\N
+
+-- !sql --
+0.0
+
+-- !sql --
+0.0
+
+-- !sql --
+0.0
+
+-- !sql --
+\N
+
+-- !sql --
+\N
+
+-- !sql --
+60.742502897294685
+
+-- !sql --
+59.75
+
+-- !sql --
+59.75
+
+-- !sql --
+59.75
+
diff --git a/regression-test/data/nereids_function_p0/agg_function/test_regr_syy.out b/regression-test/data/nereids_function_p0/agg_function/test_regr_syy.out
new file mode 100644
index 00000000000000..cecaa04286a3b8
--- /dev/null
+++ b/regression-test/data/nereids_function_p0/agg_function/test_regr_syy.out
@@ -0,0 +1,24 @@
+-- !sql --
+\N
+
+-- !sql --
+413.20000000000005
+
+-- !sql --
+413.20000000000005
+
+-- !sql --
+\N
+
+-- !sql --
+413.35074075257376
+
+-- !sql --
+224.75
+
+-- !sql --
+224.75
+
+-- !sql --
+224.75
+
diff --git a/regression-test/suites/nereids_function_p0/agg_function/test_regr_sxx.groovy b/regression-test/suites/nereids_function_p0/agg_function/test_regr_sxx.groovy
new file mode 100644
index 00000000000000..6881c1d3931780
--- /dev/null
+++ b/regression-test/suites/nereids_function_p0/agg_function/test_regr_sxx.groovy
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_regr_sxx") {
+    sql """ DROP TABLE IF EXISTS test_regr_sxx_int """
+    sql """ DROP TABLE IF EXISTS test_regr_sxx_double """
+    sql """ DROP TABLE IF EXISTS test_regr_sxx_nullable_col """
+
+
+    sql """ SET enable_nereids_planner=true """
+    sql """ SET enable_fallback_to_original_planner=false """
+
+    sql """
+        CREATE TABLE test_regr_sxx_int (
+          `id` int,
+          `x` int,
+          `y` int,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1"
+        );
+        """
+    sql """
+        CREATE TABLE test_regr_sxx_double (
+          `id` int,
+          `x` double,
+          `y` double,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1"
+        );
+        """
+    sql """
+        CREATE TABLE test_regr_sxx_nullable_col (
+          `id` int,
+          `x` int,
+          `y` int,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1"
+        );
+        """
+    // empty table: no rows have been inserted yet, the expected output for this query is NULL
+    qt_sql "select regr_sxx(y,x) from test_regr_sxx_int"
+    sql """ truncate table test_regr_sxx_int """
+
+    sql """
+        insert into test_regr_sxx_int values
+        (1, 18, 13),
+        (2, 14, 27),
+        (3, 12, 2),
+        (4, 5, 6),
+        (5, 10, 20);
+        """
+
+    sql """
+        insert into test_regr_sxx_double values
+        (1, 18.27123456, 13.27123456),
+        (2, 14.65890846, 27.65890846),
+        (3, 12.25345846, 2.253458468),
+        (4, 5.890846835, 6.890846835),
+        (5, 10.14345678, 20.14345678);
+        """
+
+    sql """
+        insert into test_regr_sxx_nullable_col values
+        (1, 18, 13),
+        (2, 14, 27),
+        (3, 5, 7),
+        (4, 10, 20);
+        """
+
+    // both arguments are NULL literals; issued with plain sql, so presumably no result is compared - TODO confirm
+    sql """select regr_sxx(NULL, NULL);"""
+
+    // first argument is the literal 10, second argument is column x
+    qt_sql "select regr_sxx(10,x) from test_regr_sxx_int"
+
+    // first argument is the literal 4, second argument is column x
+    qt_sql "select regr_sxx(4,x) from test_regr_sxx_int"
+
+    // both arguments are int columns
+    qt_sql "select regr_sxx(y,x) from test_regr_sxx_int"
+    sql """ truncate table test_regr_sxx_int """
+
+    // both arguments are double columns
+    qt_sql "select regr_sxx(y,x) from test_regr_sxx_double"
+    sql """ truncate table test_regr_sxx_double """
+
+    // first argument as declared, second argument wrapped in non_nullable()
+    qt_sql "select regr_sxx(y,non_nullable(x)) from test_regr_sxx_nullable_col"
+
+    // first argument wrapped in non_nullable(), second argument as declared
+    qt_sql "select regr_sxx(non_nullable(y),x) from test_regr_sxx_nullable_col"
+
+    // both arguments wrapped in non_nullable()
+    qt_sql "select regr_sxx(non_nullable(y),non_nullable(x)) from test_regr_sxx_nullable_col"
+    sql """ truncate table test_regr_sxx_nullable_col """
+
+
+
+    // non-numeric arguments must be rejected with the analysis errors asserted below
+    test{
+        sql """select regr_sxx('range', 1);"""
+        exception "regr_sxx requires numeric for first parameter"
+    }
+
+    test{
+        sql """select regr_sxx(1, 'hello');"""
+        exception "regr_sxx requires numeric for second parameter"
+    }
+
+    test{
+        sql """select regr_sxx(y, 'hello') from test_regr_sxx_int;"""
+        exception "regr_sxx requires numeric for second parameter"
+    }
+
+    test{
+        sql """select regr_sxx(1, true);"""
+        exception "regr_sxx requires numeric for second parameter"
+    }
+
+}
diff --git a/regression-test/suites/nereids_function_p0/agg_function/test_regr_sxy.groovy b/regression-test/suites/nereids_function_p0/agg_function/test_regr_sxy.groovy
new file mode 100644
index 00000000000000..5d5b5274572f43
--- /dev/null
+++ b/regression-test/suites/nereids_function_p0/agg_function/test_regr_sxy.groovy
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Regression suite for the regr_sxy(y, x) aggregate function.
+// It creates three small tables (int, double, and a second int table used with
+// non_nullable()), queries regr_sxy over them, and finally checks that
+// non-numeric arguments are rejected.
+// NOTE(review): qt_sql results are presumably compared against this suite's
+// recorded expected output, so the order of the qt_sql statements matters --
+// confirm against the regression framework.
+suite("test_regr_sxy") {
+    // Start from a clean slate so that reruns do not hit "table already exists".
+    sql """ DROP TABLE IF EXISTS test_regr_sxy_int """
+    sql """ DROP TABLE IF EXISTS test_regr_sxy_double """
+    sql """ DROP TABLE IF EXISTS test_regr_sxy_nullable_col """
+
+
+    // Force the Nereids planner and forbid falling back to the original planner.
+    sql """ SET enable_nereids_planner=true """
+    sql """ SET enable_fallback_to_original_planner=false """
+
+    // Integer inputs.
+    sql """
+        CREATE TABLE test_regr_sxy_int (
+          `id` int,
+          `x` int,
+          `y` int,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+    // Floating-point inputs.
+    sql """
+        CREATE TABLE test_regr_sxy_double (
+          `id` int,
+          `x` double,
+          `y` double,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+    // Table used for the nullable / non_nullable() argument combinations.
+    sql """
+        CREATE TABLE test_regr_sxy_nullable_col (
+          `id` int,
+          `x` int,
+          `y` int,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+    // no value: the table was just created, so the aggregate runs over zero rows
+    qt_sql "select regr_sxy(y,x) from test_regr_sxy_int"
+    sql """ truncate table test_regr_sxy_int """
+
+    // Rows are (id, x, y).
+    sql """
+        insert into test_regr_sxy_int values
+        (1, 18, 13),
+        (2, 14, 27),
+        (3, 12, 2),
+        (4, 5, 6),
+        (5, 10, 20)
+    """
+
+    sql """
+        insert into test_regr_sxy_double values
+        (1, 18.27123456, 13.27123456),
+        (2, 14.65890846, 27.65890846),
+        (3, 12.25345846, 2.253458468),
+        (4, 5.890846835, 6.890846835),
+        (5, 10.14345678, 20.14345678)
+    """
+
+    sql """
+        insert into test_regr_sxy_nullable_col values
+        (1, 18, 13),
+        (2, 14, 27),
+        (3, 5, 7),
+        (4, 10, 20);
+    """
+
+    // value is null
+    // NOTE(review): issued with `sql` rather than `qt_sql`, so the result is
+    // presumably only executed and not compared -- confirm this is intended.
+    sql """select regr_sxy(NULL, NULL);"""
+
+    // All int-table queries below must see the five inserted rows, so the table
+    // is truncated exactly once, after the last of them. Previously it was
+    // truncated after regr_sxy(y,20) and again after regr_sxy(10,x), which made
+    // the regr_sxy(10,x) and "int value" queries run against an empty table.
+    // NOTE(review): any recorded expected output for those two queries must be
+    // regenerated after this change.
+
+    // literal first argument, column second argument
+    qt_sql "select regr_sxy(4,x) from test_regr_sxy_int"
+
+    // column first argument, literal second argument
+    qt_sql "select regr_sxy(y,10) from test_regr_sxy_int"
+
+    // column first argument, literal second argument (a different literal)
+    qt_sql "select regr_sxy(y,20) from test_regr_sxy_int"
+
+    // literal first argument, column second argument (a different literal)
+    qt_sql "select regr_sxy(10,x) from test_regr_sxy_int"
+
+    // int value: both arguments are int columns
+    qt_sql "select regr_sxy(y,x) from test_regr_sxy_int"
+    sql """ truncate table test_regr_sxy_int """
+
+    // double value
+    qt_sql "select regr_sxy(y,x) from test_regr_sxy_double"
+    sql """ truncate table test_regr_sxy_double """
+
+    // nullable and non_nullable
+    qt_sql "select regr_sxy(y,non_nullable(x)) from test_regr_sxy_nullable_col"
+
+    // non_nullable and nullable
+    qt_sql "select regr_sxy(non_nullable(y),x) from test_regr_sxy_nullable_col"
+
+    // non_nullable and non_nullable
+    qt_sql "select regr_sxy(non_nullable(y),non_nullable(x)) from test_regr_sxy_nullable_col"
+    sql """ truncate table test_regr_sxy_nullable_col """
+
+    // exception test: each statement below is expected to fail with the given message
+    test{
+        sql """select regr_sxy('range', 1);"""
+        exception "regr_sxy requires numeric for first parameter"
+    }
+    test{
+        sql """select regr_sxy(1, 'hello');"""
+        exception "regr_sxy requires numeric for second parameter"
+    }
+
+}
diff --git a/regression-test/suites/nereids_function_p0/agg_function/test_regr_syy.groovy b/regression-test/suites/nereids_function_p0/agg_function/test_regr_syy.groovy
new file mode 100644
index 00000000000000..cc3cf934c108bd
--- /dev/null
+++ b/regression-test/suites/nereids_function_p0/agg_function/test_regr_syy.groovy
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Regression suite for the regr_syy(y, x) aggregate function.
+// It creates three small tables (int, double, and an int table with explicitly
+// nullable columns), queries regr_syy over them, and finally checks that
+// non-numeric arguments are rejected.
+// NOTE(review): qt_sql results are presumably compared against this suite's
+// recorded expected output, so the order of the qt_sql statements matters --
+// confirm against the regression framework.
+suite("test_regr_syy") {
+    // Start from a clean slate so that reruns do not hit "table already exists".
+    sql """ DROP TABLE IF EXISTS test_regr_syy_int """
+    sql """ DROP TABLE IF EXISTS test_regr_syy_double """
+    sql """ DROP TABLE IF EXISTS test_regr_syy_nullable_col """
+
+
+    // Force the Nereids planner and forbid falling back to the original planner.
+    sql """ SET enable_nereids_planner=true """
+    sql """ SET enable_fallback_to_original_planner=false """
+
+    // Integer inputs.
+    sql """
+        CREATE TABLE test_regr_syy_int (
+          `id` int,
+          `x` int,
+          `y` int,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+    // Floating-point inputs.
+    sql """
+        CREATE TABLE test_regr_syy_double (
+          `id` int,
+          `x` double,
+          `y` double,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+    // Explicitly nullable columns, used for the nullable / non_nullable() combinations.
+    sql """
+        CREATE TABLE test_regr_syy_nullable_col (
+          `id` int NULL,
+          `x` int NULL,
+          `y` int NULL,
+        ) ENGINE=OLAP
+        Duplicate KEY (`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 4
+        PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+    // no value: the table was just created, so the aggregate runs over zero rows
+    qt_sql "select regr_syy(y,x) from test_regr_syy_int"
+    sql """ truncate table test_regr_syy_int """
+
+    // Rows are (id, x, y).
+    sql """
+        insert into test_regr_syy_int values
+        (1, 18, 13),
+        (2, 14, 27),
+        (3, 12, 2),
+        (4, 5, 6),
+        (5, 10, 20)
+    """
+
+    sql """
+        insert into test_regr_syy_double values
+        (1, 18.27123456, 13.27123456),
+        (2, 14.65890846, 27.65890846),
+        (3, 12.25345846, 2.253458468),
+        (4, 5.890846835, 6.890846835),
+        (5, 10.14345678, 20.14345678)
+    """
+
+    sql """
+        insert into test_regr_syy_nullable_col values
+        (1, 18, 13),
+        (2, 14, 27),
+        (3, 5, 7),
+        (4, 10, 20);
+    """
+
+    // value is null
+    // NOTE(review): issued with `sql` rather than `qt_sql`, so the result is
+    // presumably only executed and not compared -- confirm this is intended.
+    sql """select regr_syy(NULL, NULL);"""
+
+    // All int-table queries below must see the five inserted rows, so the table
+    // is truncated exactly once, after the last of them. Previously it was
+    // truncated right after regr_syy(y,20), which made the "int value" query
+    // run against an empty table.
+    // NOTE(review): any recorded expected output for the "int value" query must
+    // be regenerated after this change.
+
+    // column first argument, literal second argument
+    qt_sql "select regr_syy(y,4) from test_regr_syy_int"
+
+    // column first argument, literal second argument (a different literal)
+    qt_sql "select regr_syy(y,20) from test_regr_syy_int"
+
+    // int value: both arguments are int columns
+    qt_sql "select regr_syy(y,x) from test_regr_syy_int"
+    sql """ truncate table test_regr_syy_int """
+
+    // double value
+    qt_sql "select regr_syy(y,x) from test_regr_syy_double"
+    sql """ truncate table test_regr_syy_double """
+
+    // nullable and non_nullable
+    qt_sql "select regr_syy(y,non_nullable(x)) from test_regr_syy_nullable_col"
+
+    // non_nullable and nullable
+    qt_sql "select regr_syy(non_nullable(y),x) from test_regr_syy_nullable_col"
+
+    // non_nullable and non_nullable
+    qt_sql "select regr_syy(non_nullable(y),non_nullable(x)) from test_regr_syy_nullable_col"
+    sql """ truncate table test_regr_syy_nullable_col """
+
+    // exception test: each statement below is expected to fail with the given message
+    test{
+        sql """select regr_syy('range', 1);"""
+        exception "regr_syy requires numeric for first parameter"
+    }
+    test{
+        sql """select regr_syy(1, 'hello');"""
+        exception "regr_syy requires numeric for second parameter"
+    }
+
+}