From c8b74c802f861715c46d87412c147e3b66772ae6 Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Wed, 18 Sep 2024 22:10:19 +0800 Subject: [PATCH 1/8] X --- .../exec/aggregation_source_operator.cpp | 5 + .../aggregate_function_kurtosis.cpp | 59 +++++++ .../aggregate_function_simple_factory.cpp | 5 + .../aggregate_function_simple_factory.h | 8 +- .../aggregate_function_skew.cpp | 58 +++++++ .../aggregate_function_statistic.h | 156 ++++++++++++++++++ be/src/vec/aggregate_functions/moments.h | 112 +++++++++++++ .../catalog/BuiltinAggregateFunctions.java | 6 +- .../org/apache/doris/catalog/FunctionSet.java | 36 ++++ .../trees/expressions/functions/agg/Kurt.java | 89 ++++++++++ .../trees/expressions/functions/agg/Skew.java | 89 ++++++++++ .../visitor/AggregateFunctionVisitor.java | 10 ++ .../aggregate/aggregate_function_kurt.out | 52 ++++++ .../aggregate/aggregate_function_skew.out | 52 ++++++ .../aggregate/aggregate_function_kurt.groovy | 78 +++++++++ .../aggregate/aggregate_function_skew.groovy | 78 +++++++++ 16 files changed, 891 insertions(+), 2 deletions(-) create mode 100644 be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp create mode 100644 be/src/vec/aggregate_functions/aggregate_function_skew.cpp create mode 100644 be/src/vec/aggregate_functions/aggregate_function_statistic.h create mode 100644 be/src/vec/aggregate_functions/moments.h create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java create mode 100644 regression-test/data/query_p0/aggregate/aggregate_function_kurt.out create mode 100644 regression-test/data/query_p0/aggregate/aggregate_function_skew.out create mode 100644 regression-test/suites/query_p0/aggregate/aggregate_function_kurt.groovy create mode 100644 regression-test/suites/query_p0/aggregate/aggregate_function_skew.groovy diff --git 
a/be/src/pipeline/exec/aggregation_source_operator.cpp b/be/src/pipeline/exec/aggregation_source_operator.cpp index a5f40a431c5ee6..fe03eba4102955 100644 --- a/be/src/pipeline/exec/aggregation_source_operator.cpp +++ b/be/src/pipeline/exec/aggregation_source_operator.cpp @@ -416,6 +416,11 @@ Status AggLocalState::_get_without_key_result(RuntimeState* state, vectorized::B } } + // Result of operator is nullable, but aggregate function result is not nullable + // this happens when: + // 1. no group by + // 2. input of aggregate function is empty + // 3. all of input columns are not nullable if (column_type->is_nullable() && !data_types[i]->is_nullable()) { vectorized::ColumnPtr ptr = std::move(columns[i]); // unless `count`, other aggregate function dispose empty set should be null diff --git a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp new file mode 100644 index 00000000000000..e76507594e7b9d --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/aggregate_function_statistic.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" + +namespace doris::vectorized { + +template +AggregateFunctionPtr type_dispatch_for_aggregate_function_kurtosis(const DataTypes& argument_types, + const bool result_is_nullable) { + using StatFunctionTemplate = StatFuncOneArg; + return creator_without_type::create>( + argument_types, result_is_nullable, StatisticsFunctionKind::kurtPop); +}; + +AggregateFunctionPtr create_aggregate_function_kurtosis(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { + WhichDataType type(remove_nullable(argument_types[0])); + +#define DISPATCH(TYPE) \ + if (type.idx == TypeIndex::TYPE) \ + return type_dispatch_for_aggregate_function_kurtosis(argument_types, \ + result_is_nullable); + FOR_NUMERIC_TYPES(DISPATCH) +#undef DISPATCH + + LOG(WARNING) << "unsupported input type " << argument_types[0]->get_name() + << " for aggregate function " << name; + return nullptr; +} + +void register_aggregate_function_kurtosis_pop(AggregateFunctionSimpleFactory& factory) { + factory.register_function_both("kurt", create_aggregate_function_kurtosis); + factory.register_alias("kurt", "kurt_pop"); + factory.register_alias("kurt", "kurtosis"); +} + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 436691c6ef2aad..19a0eacfe8d2fc 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -68,6 +68,8 @@ void 
register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory& fact void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_skew_pop(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_kurtosis_pop(AggregateFunctionSimpleFactory& factory); AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { static std::once_flag oc; @@ -119,6 +121,9 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_covar_samp(instance); register_aggregate_function_combinator_foreach(instance); + + register_aggregate_function_skew_pop(instance); + register_aggregate_function_kurtosis_pop(instance); }); return instance; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index cc504b9f99609d..22f2e0f8bae1d6 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -143,13 +143,19 @@ class AggregateFunctionSimpleFactory { if (function_alias.contains(name)) { name_str = function_alias[name]; } - + LOG_INFO("Find function name: {}", name_str); if (nullable) { + LOG_INFO("Nullable input function {} found: {}", name, + nullable_aggregate_functions.find(name_str) != + nullable_aggregate_functions.end()); return nullable_aggregate_functions.find(name_str) == nullable_aggregate_functions.end() ? 
nullptr : nullable_aggregate_functions[name_str](name_str, argument_types, result_is_nullable); } else { + LOG_INFO("Non-nullable input function {} found: {}", name, + aggregate_functions.find(name_str) != + aggregate_functions.end()); return aggregate_functions.find(name_str) == aggregate_functions.end() ? nullptr : aggregate_functions[name_str](name_str, argument_types, diff --git a/be/src/vec/aggregate_functions/aggregate_function_skew.cpp b/be/src/vec/aggregate_functions/aggregate_function_skew.cpp new file mode 100644 index 00000000000000..6acc12faa7dd00 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_skew.cpp @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/aggregate_function_statistic.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" + +namespace doris::vectorized { + +template +AggregateFunctionPtr type_dispatch_for_aggregate_function_skew(const DataTypes& argument_types, + const bool result_is_nullable) { + using StatFunctionTemplate = StatFuncOneArg; + return creator_without_type::create>( + argument_types, result_is_nullable, StatisticsFunctionKind::skewPop); +}; + +AggregateFunctionPtr create_aggregate_function_skew(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { + WhichDataType type(remove_nullable(argument_types[0])); + +#define DISPATCH(TYPE) \ + if (type.idx == TypeIndex::TYPE) \ + return type_dispatch_for_aggregate_function_skew(argument_types, result_is_nullable); + FOR_NUMERIC_TYPES(DISPATCH) +#undef DISPATCH + + LOG(WARNING) << "unsupported input type " << argument_types[0]->get_name() + << " for aggregate function " << name; + return nullptr; +} + +void register_aggregate_function_skew_pop(AggregateFunctionSimpleFactory& factory) { + factory.register_function_both("skew", create_aggregate_function_skew); + factory.register_alias("skew", "skew_pop"); + factory.register_alias("skew", "skewness"); +} + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_statistic.h b/be/src/vec/aggregate_functions/aggregate_function_statistic.h new file mode 100644 index 00000000000000..5fab9812e0a53d --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_statistic.h @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once +#include +#include +#include +#include + +#include "common/exception.h" +#include "common/status.h" +#include "moments.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/moments.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_vector.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" + +namespace doris::vectorized { + +enum class StatisticsFunctionKind : uint8_t { skewPop, kurtPop }; + +inline std::string to_string(StatisticsFunctionKind kind) { + switch (kind) { + case StatisticsFunctionKind::skewPop: + return "skewness"; + case StatisticsFunctionKind::kurtPop: + return "kurtosis"; + default: + return "Unknown"; + } +} + +template +struct StatFuncOneArg { + using Type1 = T; + using Type2 = T; + using ResultType = Float64; + using Data = VarMoments; + + static constexpr UInt32 num_args = 1; +}; + +template +class AggregateFunctionVarianceSimple + : public IAggregateFunctionDataHelper> { +public: + using T1 = typename StatFunc::Type1; + using T2 = typename StatFunc::Type2; + using ColVecT1 = ColumnVectorOrDecimal; + using ColVecT2 = ColumnVectorOrDecimal; + using ResultType = 
typename StatFunc::ResultType; + using ColVecResult = ColumnVector; + + explicit AggregateFunctionVarianceSimple(StatisticsFunctionKind kind_, + const DataTypes& argument_types_) + : IAggregateFunctionDataHelper>( + argument_types_), + kind(kind_) { + DCHECK(!argument_types_.empty()); + } + + String get_name() const override { return to_string(kind); } + + DataTypePtr get_return_type() const override { return std::make_shared(); } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + this->data(place).add(assert_cast(*columns[0]).get_data()[row_num]); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + this->data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + this->data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + const auto& data = this->data(place); + ColVecResult* dst_column = assert_cast(&to); + + switch (kind) { + case StatisticsFunctionKind::skewPop: { + // If input is empty set, we will get NAN from getPopulation() + ResultType var_value = data.getPopulation(); + if (!std::isnan(var_value) && var_value > 0) { + ResultType moments3 = data.getMoment3(); + if (!std::isnan(moments3)) [[likely]] { + dst_column->get_data().push_back( + static_cast(moments3 / pow(var_value, 1.5))); + } else { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "skewness calculation error, result is NAN"); + } + } else { + // Empty input, result column will be: + // Nullable if without group by + // Nullable if with group by, and input column is nullable + // Non-Nullable if with group by, and input column is non-nullable + dst_column->insert_default(); + } +
break; + } + case StatisticsFunctionKind::kurtPop: { + ResultType var_value = data.getPopulation(); + if (!std::isnan(var_value) && var_value > 0) { + ResultType moments4 = data.getMoment4(); + if (!std::isnan(moments4)) [[likely]] { + dst_column->get_data().push_back( + static_cast(moments4 / pow(var_value, 2))); + } else { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "kurtosis calculation error, result is NAN"); + } + } else { + dst_column->insert_default(); + } + break; + } + default: + throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Unknown statistics function kind"); + } + } + +private: + StatisticsFunctionKind kind; +}; + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/moments.h b/be/src/vec/aggregate_functions/moments.h new file mode 100644 index 00000000000000..fe189e9f74ca52 --- /dev/null +++ b/be/src/vec/aggregate_functions/moments.h @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +#pragma once + +#include + +#include "common/exception.h" +#include "common/status.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +class BufferReadable; +class BufferWritable; + +template +struct VarMoments { + // m[0] = count of added values + // m[1] = sum(x), m[2] = sum(x^2) + // m[3] = sum(x^3) + // m[4] = sum(x^4) + T m[_level + 1] {}; + + void add(T x) { + ++m[0]; + m[1] += x; + m[2] += x * x; + if constexpr (_level >= 3) m[3] += x * x * x; + if constexpr (_level >= 4) m[4] += x * x * x * x; + } + + void merge(const VarMoments& rhs) { + m[0] += rhs.m[0]; + m[1] += rhs.m[1]; + m[2] += rhs.m[2]; + if constexpr (_level >= 3) m[3] += rhs.m[3]; + if constexpr (_level >= 4) m[4] += rhs.m[4]; + } + + void write(BufferWritable& buf) const { write_binary(*this, buf); } + + void read(BufferReadable& buf) { read_binary(*this, buf); } + + T get() const { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "Variation moments should be obtained by either 'getSample' or " + "'getPopulation' method"); + } + + T getPopulation() const { + if (m[0] == 0) return std::numeric_limits::quiet_NaN(); + + /// Due to numerical errors, the result can be slightly less than zero, + /// but it should be impossible. Trim to zero.
+ + return std::max(T {}, (m[2] - m[1] * m[1] / m[0]) / m[0]); + } + + T getSample() const { + if (m[0] <= 1) return std::numeric_limits::quiet_NaN(); + return std::max(T {}, (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1)); + } + + T getMoment3() const { + if constexpr (_level < 3) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "The third central moment is only available when " + "VarMoments is instantiated with _level >= 3"); + } else { + if (m[0] == 0) return std::numeric_limits::quiet_NaN(); + // to avoid accuracy problem + if (m[0] == 1) return 0; + /// \[ \frac{1}{m_0} (m_3 - (3 * m_2 - \frac{2 * {m_1}^2}{m_0}) * \frac{m_1}{m_0});\] + return (m[3] - (3 * m[2] - 2 * m[1] * m[1] / m[0]) * m[1] / m[0]) / m[0]; + } + } + + T getMoment4() const { + if constexpr (_level < 4) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, + "The fourth central moment is only available when " + "VarMoments is instantiated with _level >= 4"); + } else { + if (m[0] == 0) return std::numeric_limits::quiet_NaN(); + // to avoid accuracy problem + if (m[0] == 1) return 0; + /// \[ \frac{1}{m_0}(m_4 - (4 * m_3 - (6 * m_2 - \frac{3 * m_1^2}{m_0} ) \frac{m_1}{m_0})\frac{m_1}{m_0})\] + return (m[4] - + (4 * m[3] - (6 * m[2] - 3 * m[1] * m[1] / m[0]) * m[1] / m[0]) * m[1] / m[0]) / + m[0]; + } + } + + void reset() { *this = VarMoments {}; } +}; + +} // namespace doris::vectorized \ No newline at end of file diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java index 28b1352eaf4551..6889adc6b6c9ae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java @@ -43,6 +43,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion; import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg; import
org.apache.doris.nereids.trees.expressions.functions.agg.IntersectCount; +import org.apache.doris.nereids.trees.expressions.functions.agg.Kurt; import org.apache.doris.nereids.trees.expressions.functions.agg.MapAgg; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.MaxBy; @@ -64,6 +65,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Retention; import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount; import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch; +import org.apache.doris.nereids.trees.expressions.functions.agg.Skew; import org.apache.doris.nereids.trees.expressions.functions.agg.Stddev; import org.apache.doris.nereids.trees.expressions.functions.agg.StddevSamp; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; @@ -145,7 +147,9 @@ public class BuiltinAggregateFunctions implements FunctionHelper { agg(TopNWeighted.class, "topn_weighted"), agg(Variance.class, "var_pop", "variance_pop", "variance"), agg(VarianceSamp.class, "var_samp", "variance_samp"), - agg(WindowFunnel.class, "window_funnel") + agg(WindowFunnel.class, "window_funnel"), + agg(Skew.class, "skew", "skew_pop", "skewness"), + agg(Kurt.class, "kurt", "kurt_pop", "kurtosis") ); public final Set aggFuncNames = aggregateFunctions.stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 4165f9362214ed..94501f8caccb6d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1940,6 +1940,42 @@ private void initAggregateBuiltins() { Lists.newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.DOUBLE, "", "", "", "", "", "", "", false, false, false, true)); + + + List skewnessAndKurtosis = Lists.newArrayList("skew", "skew_pop",
"skewness", "kurt", + "kurt_pop", "kurtosis"); + // Each alias is registered exactly once per supported numeric argument type. + + for (String name : skewnessAndKurtosis) { + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.TINYINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.SMALLINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.INT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.BIGINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.LARGEINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.FLOAT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.DOUBLE), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + } } public Map> getVectorizedFunctions() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java new file mode 100644 index 00000000000000..206618505f436f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * AggregateFunction 'Kurt'. 
+ */ +public class Kurt extends NullableAggregateFunction + implements UnaryExpression, ExplicitlyCastableSignature { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE) + ); + + public Kurt(Expression arg1) { + this(false, arg1); + } + + public Kurt(boolean distinct, Expression arg1) { + this(distinct, false, arg1); + } + + public Kurt(boolean distinct, boolean alwaysNullable, Expression arg1) { + super("Kurt", distinct, alwaysNullable, arg1); + } + + /** + * withDistinctAndChildren. 
+ */ + @Override + public Kurt withDistinctAndChildren(boolean distinct, List children) { + Preconditions.checkArgument(children.size() == 1); + return new Kurt(distinct, alwaysNullable, children.get(0)); + } + + @Override + public Kurt withAlwaysNullable(boolean alwaysNullable) { + return new Kurt(distinct, alwaysNullable, children.get(0)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitKurt(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java new file mode 100644 index 00000000000000..b8b6a7976b1813 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * AggregateFunction 'Skew'. + */ +public class Skew extends NullableAggregateFunction + implements UnaryExpression, ExplicitlyCastableSignature { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE) + ); + + public Skew(Expression arg1) { + this(false, arg1); + } + + public Skew(boolean distinct, Expression arg1) { + this(distinct, false, arg1); + } + + public Skew(boolean distinct, boolean alwaysNullable, Expression arg1) { + super("Skew", distinct, alwaysNullable, arg1); + } + + /** + * withDistinctAndChildren. 
+ */ + @Override + public Skew withDistinctAndChildren(boolean distinct, List children) { + Preconditions.checkArgument(children.size() == 1); + return new Skew(distinct, alwaysNullable, children.get(0)); + } + + @Override + public Skew withAlwaysNullable(boolean alwaysNullable) { + return new Skew(distinct, alwaysNullable, children.get(0)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitSkew(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index abe8044c28c342..5d3885abef046a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -44,6 +44,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion; import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg; import org.apache.doris.nereids.trees.expressions.functions.agg.IntersectCount; +import org.apache.doris.nereids.trees.expressions.functions.agg.Kurt; import org.apache.doris.nereids.trees.expressions.functions.agg.MapAgg; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.MaxBy; @@ -66,6 +67,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Retention; import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceCount; import org.apache.doris.nereids.trees.expressions.functions.agg.SequenceMatch; +import org.apache.doris.nereids.trees.expressions.functions.agg.Skew; import org.apache.doris.nereids.trees.expressions.functions.agg.Stddev; import 
org.apache.doris.nereids.trees.expressions.functions.agg.StddevSamp; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; @@ -210,6 +212,10 @@ default R visitIntersectCount(IntersectCount intersectCount, C context) { return visitAggregateFunction(intersectCount, context); } + default R visitKurt(Kurt kurt, C context) { + return visitNullableAggregateFunction(kurt, context); + } + default R visitMapAgg(MapAgg mapAgg, C context) { return visitAggregateFunction(mapAgg, context); } @@ -278,6 +284,10 @@ default R visitSequenceMatch(SequenceMatch sequenceMatch, C context) { return visitNullableAggregateFunction(sequenceMatch, context); } + default R visitSkew(Skew skew, C context) { + return visitNullableAggregateFunction(skew, context); + } + default R visitStddev(Stddev stddev, C context) { return visitNullableAggregateFunction(stddev, context); } diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out new file mode 100644 index 00000000000000..15571714e3cbf1 --- /dev/null +++ b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out @@ -0,0 +1,52 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !sql_empty_1 -- +\N \N + +-- !sql_empty_2 -- + +-- !sql_1 -- +0.0 0.0 + +-- !sql_2 -- +0.0 0.0 + +-- !sql_3 -- +3.162124583734851 1.5000000000000007 + +-- !sql_4 -- +0.0 0.0 +0.0 \N +0.0 0.0 +0.0 \N +0.0 0.0 + +-- !sql_distinct_1 -- +2.2985631952470373 + +-- !sql_distinct_2 -- +1.5000000000000007 + +-- !sql_distinct_3 -- +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_distinct_4 -- +0.0 +\N +0.0 +\N +0.0 + +-- !sql_5 -- +3.162124583734851 1.5000000000000007 + +-- !sql_6 -- +0.0 0.0 +0.0 \N +0.0 0.0 +0.0 \N +0.0 0.0 + diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_skew.out b/regression-test/data/query_p0/aggregate/aggregate_function_skew.out new file mode 100644 index 00000000000000..5ed27c441fdbfc --- /dev/null +++ b/regression-test/data/query_p0/aggregate/aggregate_function_skew.out @@ -0,0 +1,52 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql_empty_1 -- +\N \N + +-- !sql_empty_2 -- + +-- !sql_1 -- +0.0 0.0 + +-- !sql_2 -- +0.0 0.0 + +-- !sql_3 -- +1.4337199628825619 0.675885787569108 + +-- !sql_4 -- +0.0 \N +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 \N + +-- !sql_distinct_1 -- +1.1135657469022011 + +-- !sql_distinct_2 -- +0.675885787569108 + +-- !sql_distinct_3 -- +0.0 +0.0 +0.0 +0.0 +0.0 + +-- !sql_distinct_4 -- +\N +0.0 +0.0 +0.0 +\N + +-- !sql_5 -- +1.4337199628825619 0.675885787569108 + +-- !sql_6 -- +0.0 \N +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 \N + diff --git a/regression-test/suites/query_p0/aggregate/aggregate_function_kurt.groovy b/regression-test/suites/query_p0/aggregate/aggregate_function_kurt.groovy new file mode 100644 index 00000000000000..0e475467a16f87 --- /dev/null +++ b/regression-test/suites/query_p0/aggregate/aggregate_function_kurt.groovy @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("aggregate_function_kurt") { + sql """ + drop table if exists aggregate_function_kurt; + """ + sql""" + create table aggregate_function_kurt (tag int, val1 double not null, val2 double null) distributed by hash(tag) buckets 10 properties('replication_num' = '1'); + """ + + qt_sql_empty_1 """ + select kurtosis(val1),kurtosis(val2) from aggregate_function_kurt; + """ + qt_sql_empty_2 """ + select kurtosis(val1),kurtosis(val2) from aggregate_function_kurt group by tag; + """ + + sql """ + insert into aggregate_function_kurt values (1, -10.0, -10.0); + """ + + qt_sql_1 """ + select kurtosis(val1),kurtosis(val2) from aggregate_function_kurt; + """ + qt_sql_2 """ + select kurtosis(val1),kurtosis(val2) from aggregate_function_kurt group by tag; + """ + + sql """ + insert into aggregate_function_kurt values (2, -20.0, NULL), (3, 100, NULL), (4, 100, 100), (5,1000, 1000); + """ + qt_sql_3 """ + select kurtosis(val1),kurtosis(val2) from aggregate_function_kurt; + """ + qt_sql_4 """ + select kurtosis(val1),kurtosis(val2) from aggregate_function_kurt group by tag; + """ + + qt_sql_distinct_1 """ + select kurtosis(distinct val1) from aggregate_function_kurt; + """ + qt_sql_distinct_2 """ + select kurtosis(distinct val2) from aggregate_function_kurt; + """ + + qt_sql_distinct_3 
""" + select kurtosis(distinct val1) from aggregate_function_kurt group by tag; + """ + qt_sql_distinct_4 """ + select kurtosis(distinct val2) from aggregate_function_kurt group by tag; + """ + + sql """ + insert into aggregate_function_kurt select * from aggregate_function_kurt; + """ + + qt_sql_5 """ + select kurt(val1),kurt_pop(val2) from aggregate_function_kurt; + """ + qt_sql_6 """ + select kurt(val1),kurt_pop(val2) from aggregate_function_kurt group by tag; + """ +} \ No newline at end of file diff --git a/regression-test/suites/query_p0/aggregate/aggregate_function_skew.groovy b/regression-test/suites/query_p0/aggregate/aggregate_function_skew.groovy new file mode 100644 index 00000000000000..b36e354cc481e3 --- /dev/null +++ b/regression-test/suites/query_p0/aggregate/aggregate_function_skew.groovy @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +suite("aggregate_function_skew") { + sql """ + drop table if exists aggregate_function_skew; + """ + sql""" + create table aggregate_function_skew (tag int, val1 double not null, val2 double null) distributed by hash(tag) buckets 10 properties('replication_num' = '1'); + """ + + qt_sql_empty_1 """ + select skewness(val1),skewness(val2) from aggregate_function_skew; + """ + qt_sql_empty_2 """ + select skewness(val1),skewness(val2) from aggregate_function_skew group by tag; + """ + + sql """ + insert into aggregate_function_skew values (1, -10.0, -10.0); + """ + + qt_sql_1 """ + select skewness(val1),skewness(val2) from aggregate_function_skew; + """ + qt_sql_2 """ + select skewness(val1),skewness(val2) from aggregate_function_skew group by tag; + """ + + sql """ + insert into aggregate_function_skew values (2, -20.0, NULL), (3, 100, NULL), (4, 100, 100), (5,1000, 1000); + """ + qt_sql_3 """ + select skewness(val1),skewness(val2) from aggregate_function_skew; + """ + qt_sql_4 """ + select skewness(val1),skewness(val2) from aggregate_function_skew group by tag; + """ + + qt_sql_distinct_1 """ + select skewness(distinct val1) from aggregate_function_skew; + """ + qt_sql_distinct_2 """ + select skewness(distinct val2) from aggregate_function_skew; + """ + + qt_sql_distinct_3 """ + select skewness(distinct val1) from aggregate_function_skew group by tag; + """ + qt_sql_distinct_4 """ + select skewness(distinct val2) from aggregate_function_skew group by tag; + """ + + sql """ + insert into aggregate_function_skew select * from aggregate_function_skew; + """ + + qt_sql_5 """ + select skew(val1),skew_pop(val2) from aggregate_function_skew; + """ + qt_sql_6 """ + select skew(val1),skew_pop(val2) from aggregate_function_skew group by tag; + """ +} \ No newline at end of file From 094cfc419a36193595625a0058b9352ac0fdde4b Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Wed, 18 Sep 2024 22:24:41 +0800 Subject: [PATCH 2/8] X --- .../aggregate_function_simple_factory.h | 7 
--- .../org/apache/doris/catalog/FunctionSet.java | 58 +++++++++---------- 2 files changed, 29 insertions(+), 36 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h index 22f2e0f8bae1d6..7d83bcbbf04945 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h @@ -143,19 +143,12 @@ class AggregateFunctionSimpleFactory { if (function_alias.contains(name)) { name_str = function_alias[name]; } - LOG_INFO("Find function name: {}", name_str); if (nullable) { - LOG_INFO("Nullable input function {} found: {}", name, - nullable_aggregate_functions.find(name_str) != - nullable_aggregate_functions.end()); return nullable_aggregate_functions.find(name_str) == nullable_aggregate_functions.end() ? nullptr : nullable_aggregate_functions[name_str](name_str, argument_types, result_is_nullable); } else { - LOG_INFO("Not input function {} found: {}", name, - nullable_aggregate_functions.find(name_str) != - nullable_aggregate_functions.end()); return aggregate_functions.find(name_str) == aggregate_functions.end() ? 
nullptr : aggregate_functions[name_str](name_str, argument_types, diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 94501f8caccb6d..2e7063a0aec17a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1943,38 +1943,38 @@ private void initAggregateBuiltins() { List skewnessAndKurtosis = Lists.newArrayList("skew", "skew_pop", "skewness", "kurt", - "kurt_pop", "kurtosis"); + "kurt_pop", "kurtosis"); skewnessAndKurtosis.addAll(skewnessAndKurtosis); for (String name : skewnessAndKurtosis) { - addBuiltin(AggregateFunction.createBuiltin(name, - Lists.newArrayList(Type.TINYINT), Type.DOUBLE, Type.DOUBLE, - "", "", "", "", "", "", "", - false, false, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, - Lists.newArrayList(Type.SMALLINT), Type.DOUBLE, Type.DOUBLE, - "", "", "", "", "", "", "", - false, false, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, - Lists.newArrayList(Type.INT), Type.DOUBLE, Type.DOUBLE, - "", "", "", "", "", "", "", - false, false, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, - Lists.newArrayList(Type.BIGINT), Type.DOUBLE, Type.DOUBLE, - "", "", "", "", "", "", "", - false, false, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, - Lists.newArrayList(Type.LARGEINT), Type.DOUBLE, Type.DOUBLE, - "", "", "", "", "", "", "", - false, false, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, - Lists.newArrayList(Type.FLOAT), Type.DOUBLE, Type.DOUBLE, - "", "", "", "", "", "", "", - false, false, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, - Lists.newArrayList(Type.DOUBLE), Type.DOUBLE, Type.DOUBLE, - "", "", "", "", "", "", "", - false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.TINYINT), 
Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.SMALLINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.INT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.BIGINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.LARGEINT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.FLOAT), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.newArrayList(Type.DOUBLE), Type.DOUBLE, Type.DOUBLE, + "", "", "", "", "", "", "", + false, false, false, true)); } } From c5fa11c0855a5362bca5e665b81d05b564c1aee6 Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Thu, 19 Sep 2024 14:11:37 +0800 Subject: [PATCH 3/8] FIX NULLABLE --- .../aggregate_function_kurtosis.cpp | 47 ++++++++---- .../aggregate_function_simple_factory.cpp | 8 +- .../aggregate_function_skew.cpp | 36 +++++++-- .../aggregate_function_statistic.h | 74 ++++++++++--------- .../trees/expressions/functions/agg/Kurt.java | 18 ++--- .../trees/expressions/functions/agg/Skew.java | 19 ++--- .../visitor/AggregateFunctionVisitor.java | 4 +- .../aggregate/aggregate_function_kurt.out | 46 ++++++------ .../aggregate/aggregate_function_skew.out | 46 ++++++------ 9 files changed, 166 insertions(+), 132 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp index 
e76507594e7b9d..53a7da582d66c7 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp @@ -22,26 +22,47 @@ #include "vec/aggregate_functions/helpers.h" #include "vec/core/types.h" #include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" namespace doris::vectorized { template -AggregateFunctionPtr type_dispatch_for_aggregate_function_kurtosis(const DataTypes& argument_types, - const bool result_is_nullable) { +AggregateFunctionPtr type_dispatch_for_aggregate_function_kurt(const DataTypes& argument_types, + const bool result_is_nullable, + bool nullable_input) { using StatFunctionTemplate = StatFuncOneArg; - return creator_without_type::create>( - argument_types, result_is_nullable, StatisticsFunctionKind::kurtPop); + + if (nullable_input) { + return creator_without_type::create_ignore_nullable< + AggregateFunctionVarianceSimple>( + argument_types, result_is_nullable, StatisticsFunctionKind::kurtPop); + } else { + return creator_without_type::create_ignore_nullable< + AggregateFunctionVarianceSimple>( + argument_types, result_is_nullable, StatisticsFunctionKind::kurtPop); + } }; -AggregateFunctionPtr create_aggregate_function_kurtosis(const std::string& name, - const DataTypes& argument_types, - const bool result_is_nullable) { +AggregateFunctionPtr create_aggregate_function_kurt(const std::string& name, + const DataTypes& argument_types, + const bool result_is_nullable) { + if (argument_types.size() != 1) { + LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument"; + return nullptr; + } + + if (!result_is_nullable) { + LOG(WARNING) << "aggregate function " << name << " requires nullable result type"; + return nullptr; + } + + const bool nullable_input = argument_types[0]->is_nullable(); WhichDataType type(remove_nullable(argument_types[0])); -#define DISPATCH(TYPE) \ - if (type.idx == TypeIndex::TYPE) \ - return 
type_dispatch_for_aggregate_function_kurtosis(argument_types, \ - result_is_nullable); +#define DISPATCH(TYPE) \ + if (type.idx == TypeIndex::TYPE) \ + return type_dispatch_for_aggregate_function_kurt(argument_types, result_is_nullable, \ + nullable_input); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH @@ -50,8 +71,8 @@ AggregateFunctionPtr create_aggregate_function_kurtosis(const std::string& name, return nullptr; } -void register_aggregate_function_kurtosis_pop(AggregateFunctionSimpleFactory& factory) { - factory.register_function_both("kurt", create_aggregate_function_kurtosis); +void register_aggregate_function_kurtosis(AggregateFunctionSimpleFactory& factory) { + factory.register_function_both("kurt", create_aggregate_function_kurt); factory.register_alias("kurt", "kurt_pop"); factory.register_alias("kurt", "kurtosis"); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 19a0eacfe8d2fc..d11ec714889be9 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -68,8 +68,8 @@ void register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory& fact void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& factory); -void register_aggregate_function_skew_pop(AggregateFunctionSimpleFactory& factory); -void register_aggregate_function_kurtosis_pop(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_skewness(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_kurtosis(AggregateFunctionSimpleFactory& factory); AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { static std::once_flag oc; @@ -122,8 +122,8 
@@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_combinator_foreach(instance); - register_aggregate_function_skew_pop(instance); - register_aggregate_function_kurtosis_pop(instance); + register_aggregate_function_skewness(instance); + register_aggregate_function_kurtosis(instance); }); return instance; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_skew.cpp b/be/src/vec/aggregate_functions/aggregate_function_skew.cpp index 6acc12faa7dd00..95e0715fcd7c1e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_skew.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_skew.cpp @@ -22,25 +22,47 @@ #include "vec/aggregate_functions/helpers.h" #include "vec/core/types.h" #include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_nullable.h" namespace doris::vectorized { template AggregateFunctionPtr type_dispatch_for_aggregate_function_skew(const DataTypes& argument_types, - const bool result_is_nullable) { + const bool result_is_nullable, + bool nullable_input) { using StatFunctionTemplate = StatFuncOneArg; - return creator_without_type::create>( - argument_types, result_is_nullable, StatisticsFunctionKind::skewPop); + + if (nullable_input) { + return creator_without_type::create_ignore_nullable< + AggregateFunctionVarianceSimple>( + argument_types, result_is_nullable, StatisticsFunctionKind::skewPop); + } else { + return creator_without_type::create_ignore_nullable< + AggregateFunctionVarianceSimple>( + argument_types, result_is_nullable, StatisticsFunctionKind::skewPop); + } }; AggregateFunctionPtr create_aggregate_function_skew(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + if (argument_types.size() != 1) { + LOG(WARNING) << "aggregate function " << name << " requires exactly 1 argument"; + return nullptr; + } + + if (!result_is_nullable) { + LOG(WARNING) << "aggregate function " << name << " requires nullable result 
type"; + return nullptr; + } + + const bool nullable_input = argument_types[0]->is_nullable(); WhichDataType type(remove_nullable(argument_types[0])); -#define DISPATCH(TYPE) \ - if (type.idx == TypeIndex::TYPE) \ - return type_dispatch_for_aggregate_function_skew(argument_types, result_is_nullable); +#define DISPATCH(TYPE) \ + if (type.idx == TypeIndex::TYPE) \ + return type_dispatch_for_aggregate_function_skew(argument_types, result_is_nullable, \ + nullable_input); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH @@ -49,7 +71,7 @@ AggregateFunctionPtr create_aggregate_function_skew(const std::string& name, return nullptr; } -void register_aggregate_function_skew_pop(AggregateFunctionSimpleFactory& factory) { +void register_aggregate_function_skewness(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("skew", create_aggregate_function_skew); factory.register_alias("skew", "skew_pop"); factory.register_alias("skew", "skewness"); diff --git a/be/src/vec/aggregate_functions/aggregate_function_statistic.h b/be/src/vec/aggregate_functions/aggregate_function_statistic.h index 5fab9812e0a53d..f86ce017e0f97b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_statistic.h +++ b/be/src/vec/aggregate_functions/aggregate_function_statistic.h @@ -28,6 +28,7 @@ #include "vec/aggregate_functions/moments.h" #include "vec/columns/column_nullable.h" #include "vec/columns/column_vector.h" +#include "vec/common/assert_cast.h" #include "vec/core/types.h" #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_nullable.h" @@ -58,10 +59,11 @@ struct StatFuncOneArg { static constexpr UInt32 num_args = 1; }; -template +template class AggregateFunctionVarianceSimple - : public IAggregateFunctionDataHelper> { + : public IAggregateFunctionDataHelper< + typename StatFunc::Data, + AggregateFunctionVarianceSimple> { public: using T1 = typename StatFunc::Type1; using T2 = typename StatFunc::Type2; @@ -72,20 +74,30 @@ class 
AggregateFunctionVarianceSimple explicit AggregateFunctionVarianceSimple(StatisticsFunctionKind kind_, const DataTypes& argument_types_) - : IAggregateFunctionDataHelper>( - argument_types_), + : IAggregateFunctionDataHelper< + typename StatFunc::Data, + AggregateFunctionVarianceSimple>(argument_types_), kind(kind_) { DCHECK(!argument_types_.empty()); } String get_name() const override { return to_string(kind); } - DataTypePtr get_return_type() const override { return std::make_shared(); } + DataTypePtr get_return_type() const override { + return make_nullable(std::make_shared()); + } void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - this->data(place).add(assert_cast(*columns[0]).get_data()[row_num]); + if constexpr (NullableInput) { + const ColumnNullable& column_with_nullable = + assert_cast(*columns[0]); + this->data(place).add( + assert_cast(column_with_nullable.get_nested_column()) + .get_data()[row_num]); + } else { + this->data(place).add(assert_cast(*columns[0]).get_data()[row_num]); + } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, @@ -104,43 +116,37 @@ class AggregateFunctionVarianceSimple void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { const auto& data = this->data(place); - ColVecResult* dst_column = assert_cast(&to); - + ColumnNullable& dst_column_with_nullable = assert_cast(to); + ColVecResult* dst_column = + assert_cast(&(dst_column_with_nullable.get_nested_column())); + ; switch (kind) { case StatisticsFunctionKind::skewPop: { // If input is empty set, we will get NAN from getPopulation() ResultType var_value = data.getPopulation(); - if (!std::isnan(var_value) && var_value > 0) { - ResultType moments3 = data.getMoment3(); - if (!std::isnan(moments3)) [[likely]] { - dst_column->get_data().push_back( - static_cast(moments3 / pow(var_value, 1.5))); - } else { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - 
"skewness calculation error, result is NAN"); - } - } else { - // Empty input, result column will be: - // Nullable if without group by - // Nullable if with group by, and input column is nullable - // Non-Nullable if with group by, and input column is non-nullable + ResultType moments_3 = data.getMoment3(); + + if (std::isnan(var_value) || std::isnan(moments_3) || var_value <= 0) { + dst_column_with_nullable.get_null_map_data().push_back(1); dst_column->insert_default(); + } else { + dst_column_with_nullable.get_null_map_data().push_back(0); + dst_column->get_data().push_back( + static_cast(moments_3 / pow(var_value, 1.5))); } break; } case StatisticsFunctionKind::kurtPop: { ResultType var_value = data.getPopulation(); - if (!std::isnan(var_value) && var_value > 0) { - ResultType moments4 = data.getMoment4(); - if (!std::isnan(moments4)) [[likely]] { - dst_column->get_data().push_back( - static_cast(moments4 / pow(var_value, 2))); - } else { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "skewness calculation error, result is NAN"); - } - } else { + ResultType moments_4 = data.getMoment4(); + + if (std::isnan(var_value) || std::isnan(moments_4) || var_value <= 0) { + dst_column_with_nullable.get_null_map_data().push_back(1); dst_column->insert_default(); + } else { + dst_column_with_nullable.get_null_map_data().push_back(0); + dst_column->get_data().push_back( + static_cast(moments_4 / pow(var_value, 2))); } break; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java index 206618505f436f..3480f1c53c38a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java @@ -19,6 +19,7 @@ import org.apache.doris.catalog.FunctionSignature; import 
org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; @@ -38,8 +39,8 @@ /** * AggregateFunction 'Kurt'. */ -public class Kurt extends NullableAggregateFunction - implements UnaryExpression, ExplicitlyCastableSignature { +public class Kurt extends AggregateFunction + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), @@ -56,11 +57,7 @@ public Kurt(Expression arg1) { } public Kurt(boolean distinct, Expression arg1) { - this(distinct, false, arg1); - } - - public Kurt(boolean distinct, boolean alwaysNullable, Expression arg1) { - super("Kurt", distinct, alwaysNullable, arg1); + super("kurt", distinct, arg1); } /** @@ -69,12 +66,7 @@ public Kurt(boolean distinct, boolean alwaysNullable, Expression arg1) { @Override public Kurt withDistinctAndChildren(boolean distinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new Kurt(distinct, alwaysNullable, children.get(0)); - } - - @Override - public Kurt withAlwaysNullable(boolean alwaysNullable) { - return new Kurt(distinct, alwaysNullable, children.get(0)); + return new Kurt(distinct, children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java index b8b6a7976b1813..24b14d08ef12da 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java @@ -19,6 +19,7 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; @@ -38,8 +39,9 @@ /** * AggregateFunction 'Skew'. */ -public class Skew extends NullableAggregateFunction - implements UnaryExpression, ExplicitlyCastableSignature { + +public class Skew extends AggregateFunction + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), @@ -56,11 +58,7 @@ public Skew(Expression arg1) { } public Skew(boolean distinct, Expression arg1) { - this(distinct, false, arg1); - } - - public Skew(boolean distinct, boolean alwaysNullable, Expression arg1) { - super("Skew", distinct, alwaysNullable, arg1); + super("skew", distinct, arg1); } /** @@ -69,12 +67,7 @@ public Skew(boolean distinct, boolean alwaysNullable, Expression arg1) { @Override public Skew withDistinctAndChildren(boolean distinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new Skew(distinct, alwaysNullable, children.get(0)); - } - - @Override - public Skew withAlwaysNullable(boolean alwaysNullable) { - return new Skew(distinct, alwaysNullable, children.get(0)); + return new Skew(distinct, children.get(0)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index 
5d3885abef046a..b0f39ca6f7ef74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -213,7 +213,7 @@ default R visitIntersectCount(IntersectCount intersectCount, C context) { } default R visitKurt(Kurt kurt, C context) { - return visitNullableAggregateFunction(kurt, context); + return visitAggregateFunction(kurt, context); } default R visitMapAgg(MapAgg mapAgg, C context) { @@ -285,7 +285,7 @@ default R visitSequenceMatch(SequenceMatch sequenceMatch, C context) { } default R visitSkew(Skew skew, C context) { - return visitNullableAggregateFunction(skew, context); + return visitAggregateFunction(skew, context); } default R visitStddev(Stddev stddev, C context) { diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out index 15571714e3cbf1..4c5d85db37c2d1 100644 --- a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out +++ b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out @@ -5,48 +5,48 @@ -- !sql_empty_2 -- -- !sql_1 -- -0.0 0.0 +\N \N -- !sql_2 -- -0.0 0.0 +\N \N -- !sql_3 -- -3.162124583734851 1.5000000000000007 +3.162124583734851 3.195564571744395 -- !sql_4 -- -0.0 0.0 -0.0 \N -0.0 0.0 -0.0 \N -0.0 0.0 +\N \N +\N \N +\N \N +\N \N +\N \N -- !sql_distinct_1 -- 2.2985631952470373 -- !sql_distinct_2 -- -1.5000000000000007 +2.3039881240500497 -- !sql_distinct_3 -- -0.0 -0.0 -0.0 -0.0 -0.0 +\N +\N +\N +\N +\N -- !sql_distinct_4 -- -0.0 \N -0.0 \N -0.0 +\N +\N +\N -- !sql_5 -- -3.162124583734851 1.5000000000000007 +3.162124583734851 3.195564571744395 -- !sql_6 -- -0.0 0.0 -0.0 \N -0.0 0.0 -0.0 \N -0.0 0.0 +\N \N +\N \N +\N \N +\N \N +\N \N diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_skew.out 
b/regression-test/data/query_p0/aggregate/aggregate_function_skew.out index 5ed27c441fdbfc..8c271a7eb05124 100644 --- a/regression-test/data/query_p0/aggregate/aggregate_function_skew.out +++ b/regression-test/data/query_p0/aggregate/aggregate_function_skew.out @@ -5,48 +5,48 @@ -- !sql_empty_2 -- -- !sql_1 -- -0.0 0.0 +\N \N -- !sql_2 -- -0.0 0.0 +\N \N -- !sql_3 -- -1.4337199628825619 0.675885787569108 +1.4337199628825619 1.4622886709763663 -- !sql_4 -- -0.0 \N -0.0 0.0 -0.0 0.0 -0.0 0.0 -0.0 \N +\N \N +\N \N +\N \N +\N \N +\N \N -- !sql_distinct_1 -- 1.1135657469022011 -- !sql_distinct_2 -- -0.675885787569108 +1.1197287397085194 -- !sql_distinct_3 -- -0.0 -0.0 -0.0 -0.0 -0.0 +\N +\N +\N +\N +\N -- !sql_distinct_4 -- \N -0.0 -0.0 -0.0 +\N +\N +\N \N -- !sql_5 -- -1.4337199628825619 0.675885787569108 +1.4337199628825619 1.4622886709763663 -- !sql_6 -- -0.0 \N -0.0 0.0 -0.0 0.0 -0.0 0.0 -0.0 \N +\N \N +\N \N +\N \N +\N \N +\N \N From 43091a90ca6375a84b1c46ecb737e15e81ae3c84 Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Fri, 20 Sep 2024 14:50:37 +0800 Subject: [PATCH 4/8] X --- .../aggregate_function_kurtosis.cpp | 4 +- .../aggregate_function_skew.cpp | 4 +- .../aggregate_function_statistic.h | 70 +++++++++---------- be/src/vec/aggregate_functions/moments.h | 28 ++++---- .../trees/expressions/functions/agg/Kurt.java | 10 ++- .../trees/expressions/functions/agg/Skew.java | 10 ++- .../aggregate/aggregate_function_kurt.out | 4 +- .../aggregate/aggregate_function_skew.out | 4 +- 8 files changed, 65 insertions(+), 69 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp index 53a7da582d66c7..00ad1893eafcf6 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_kurtosis.cpp @@ -35,11 +35,11 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_kurt(const DataTypes& if (nullable_input) { 
return creator_without_type::create_ignore_nullable< AggregateFunctionVarianceSimple>( - argument_types, result_is_nullable, StatisticsFunctionKind::kurtPop); + argument_types, result_is_nullable, STATISTICS_FUNCTION_KIND::KURT_POP); } else { return creator_without_type::create_ignore_nullable< AggregateFunctionVarianceSimple>( - argument_types, result_is_nullable, StatisticsFunctionKind::kurtPop); + argument_types, result_is_nullable, STATISTICS_FUNCTION_KIND::KURT_POP); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_skew.cpp b/be/src/vec/aggregate_functions/aggregate_function_skew.cpp index 95e0715fcd7c1e..144e482ad239ed 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_skew.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_skew.cpp @@ -35,11 +35,11 @@ AggregateFunctionPtr type_dispatch_for_aggregate_function_skew(const DataTypes& if (nullable_input) { return creator_without_type::create_ignore_nullable< AggregateFunctionVarianceSimple>( - argument_types, result_is_nullable, StatisticsFunctionKind::skewPop); + argument_types, result_is_nullable, STATISTICS_FUNCTION_KIND::SKEW_POP); } else { return creator_without_type::create_ignore_nullable< AggregateFunctionVarianceSimple>( - argument_types, result_is_nullable, StatisticsFunctionKind::skewPop); + argument_types, result_is_nullable, STATISTICS_FUNCTION_KIND::SKEW_POP); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_statistic.h b/be/src/vec/aggregate_functions/aggregate_function_statistic.h index f86ce017e0f97b..1d1d508971d5d8 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_statistic.h +++ b/be/src/vec/aggregate_functions/aggregate_function_statistic.h @@ -36,13 +36,13 @@ namespace doris::vectorized { -enum class StatisticsFunctionKind : uint8_t { skewPop, kurtPop }; +enum class STATISTICS_FUNCTION_KIND : uint8_t { SKEW_POP, KURT_POP }; -inline std::string to_string(StatisticsFunctionKind kind) { +inline std::string 
to_string(STATISTICS_FUNCTION_KIND kind) { switch (kind) { - case StatisticsFunctionKind::skewPop: + case STATISTICS_FUNCTION_KIND::SKEW_POP: return "skewness"; - case StatisticsFunctionKind::kurtPop: + case STATISTICS_FUNCTION_KIND::KURT_POP: return "kurtosis"; default: return "Unknown"; @@ -51,12 +51,8 @@ inline std::string to_string(StatisticsFunctionKind kind) { template struct StatFuncOneArg { - using Type1 = T; - using Type2 = T; - using ResultType = Float64; - using Data = VarMoments; - - static constexpr UInt32 num_args = 1; + using Type = T; + using Data = VarMoments; }; template @@ -65,14 +61,10 @@ class AggregateFunctionVarianceSimple typename StatFunc::Data, AggregateFunctionVarianceSimple> { public: - using T1 = typename StatFunc::Type1; - using T2 = typename StatFunc::Type2; - using ColVecT1 = ColumnVectorOrDecimal; - using ColVecT2 = ColumnVectorOrDecimal; - using ResultType = typename StatFunc::ResultType; - using ColVecResult = ColumnVector; - - explicit AggregateFunctionVarianceSimple(StatisticsFunctionKind kind_, + using InputCol = ColumnVector; + using ResultCol = ColumnVector; + + explicit AggregateFunctionVarianceSimple(STATISTICS_FUNCTION_KIND kind_, const DataTypes& argument_types_) : IAggregateFunctionDataHelper< typename StatFunc::Data, @@ -91,12 +83,18 @@ class AggregateFunctionVarianceSimple Arena*) const override { if constexpr (NullableInput) { const ColumnNullable& column_with_nullable = - assert_cast(*columns[0]); - this->data(place).add( - assert_cast(column_with_nullable.get_nested_column()) - .get_data()[row_num]); + assert_cast(*columns[0]); + + if (column_with_nullable.is_null_at(row_num)) { + return; + } else { + this->data(place).add(assert_cast( + column_with_nullable.get_nested_column()) + .get_data()[row_num]); + } + } else { - this->data(place).add(assert_cast(*columns[0]).get_data()[row_num]); + this->data(place).add(assert_cast(*columns[0]).get_data()[row_num]); } } @@ -117,14 +115,14 @@ class 
AggregateFunctionVarianceSimple void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { const auto& data = this->data(place); ColumnNullable& dst_column_with_nullable = assert_cast(to); - ColVecResult* dst_column = - assert_cast(&(dst_column_with_nullable.get_nested_column())); - ; + ResultCol* dst_column = + assert_cast(&(dst_column_with_nullable.get_nested_column())); + switch (kind) { - case StatisticsFunctionKind::skewPop: { - // If input is empty set, we will get NAN from getPopulation() - ResultType var_value = data.getPopulation(); - ResultType moments_3 = data.getMoment3(); + case STATISTICS_FUNCTION_KIND::SKEW_POP: { + // If input is empty set, we will get NAN from get_population() + Float64 var_value = data.get_population(); + Float64 moments_3 = data.get_moment_3(); if (std::isnan(var_value) || std::isnan(moments_3) || var_value <= 0) { dst_column_with_nullable.get_null_map_data().push_back(1); @@ -132,13 +130,13 @@ class AggregateFunctionVarianceSimple } else { dst_column_with_nullable.get_null_map_data().push_back(0); dst_column->get_data().push_back( - static_cast(moments_3 / pow(var_value, 1.5))); + static_cast(moments_3 / pow(var_value, 1.5))); } break; } - case StatisticsFunctionKind::kurtPop: { - ResultType var_value = data.getPopulation(); - ResultType moments_4 = data.getMoment4(); + case STATISTICS_FUNCTION_KIND::KURT_POP: { + Float64 var_value = data.get_population(); + Float64 moments_4 = data.get_moment_4(); if (std::isnan(var_value) || std::isnan(moments_4) || var_value <= 0) { dst_column_with_nullable.get_null_map_data().push_back(1); @@ -146,7 +144,7 @@ class AggregateFunctionVarianceSimple } else { dst_column_with_nullable.get_null_map_data().push_back(0); dst_column->get_data().push_back( - static_cast(moments_4 / pow(var_value, 2))); + static_cast(moments_4 / pow(var_value, 2))); } break; } @@ -156,7 +154,7 @@ class AggregateFunctionVarianceSimple } private: - StatisticsFunctionKind kind; + 
STATISTICS_FUNCTION_KIND kind; }; } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/moments.h b/be/src/vec/aggregate_functions/moments.h index fe189e9f74ca52..d9db12774709bc 100644 --- a/be/src/vec/aggregate_functions/moments.h +++ b/be/src/vec/aggregate_functions/moments.h @@ -58,11 +58,10 @@ struct VarMoments { T get() const { throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "Variation moments should be obtained by either 'getSample' or " - "'getPopulation' method"); + "Variation moments should be obtained by 'get_population' method"); } - T getPopulation() const { + T get_population() const { if (m[0] == 0) return std::numeric_limits::quiet_NaN(); /// Due to numerical errors, the result can be slightly less than zero, @@ -71,16 +70,16 @@ struct VarMoments { return std::max(T {}, (m[2] - m[1] * m[1] / m[0]) / m[0]); } - T getSample() const { + T get_sample() const { if (m[0] <= 1) return std::numeric_limits::quiet_NaN(); return std::max(T {}, (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1)); } - T getMoment3() const { + T get_moment_3() const { if constexpr (_level < 3) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "Variation moments should be obtained by either 'getSample' or " - "'getPopulation' method"); + throw doris::Exception( + ErrorCode::INTERNAL_ERROR, + "Variation moments should be obtained by 'get_population' method"); } else { if (m[0] == 0) return std::numeric_limits::quiet_NaN(); // to avoid accuracy problem @@ -90,11 +89,11 @@ struct VarMoments { } } - T getMoment4() const { + T get_moment_4() const { if constexpr (_level < 4) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "Variation moments should be obtained by either 'getSample' or " - "'getPopulation' method"); + throw doris::Exception( + ErrorCode::INTERNAL_ERROR, + "Variation moments should be obtained by 'get_population' method"); } else { if (m[0] == 0) return std::numeric_limits::quiet_NaN(); // to avoid accuracy problem 
@@ -106,7 +105,10 @@ struct VarMoments { } } - void reset() { return; } + void reset() { + m = {}; + return; + } }; } // namespace doris::vectorized \ No newline at end of file diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java index 3480f1c53c38a6..13b24838e2edf3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Kurt.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.FloatType; import org.apache.doris.nereids.types.IntegerType; -import org.apache.doris.nereids.types.LargeIntType; import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.TinyIntType; @@ -43,13 +42,12 @@ public class Kurt extends AggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { public static final List SIGNATURES = ImmutableList.of( - FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE) + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE) ); public Kurt(Expression arg1) { diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java index 24b14d08ef12da..4041b7a386339f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Skew.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.FloatType; import org.apache.doris.nereids.types.IntegerType; -import org.apache.doris.nereids.types.LargeIntType; import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.TinyIntType; @@ -44,13 +43,12 @@ public class Skew extends AggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { public static final List SIGNATURES = ImmutableList.of( - FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE) + FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE) ); public Skew(Expression arg1) { diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out index 4c5d85db37c2d1..c5b0fe88d5b85b 100644 --- 
a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out +++ b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out @@ -11,7 +11,7 @@ \N \N -- !sql_3 -- -3.162124583734851 3.195564571744395 +3.162124583734851 2.3039881240500497 -- !sql_4 -- \N \N @@ -41,7 +41,7 @@ \N -- !sql_5 -- -3.162124583734851 3.195564571744395 +3.162124583734851 2.3039881240500497 -- !sql_6 -- \N \N diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_skew.out b/regression-test/data/query_p0/aggregate/aggregate_function_skew.out index 8c271a7eb05124..171e96da057a9c 100644 --- a/regression-test/data/query_p0/aggregate/aggregate_function_skew.out +++ b/regression-test/data/query_p0/aggregate/aggregate_function_skew.out @@ -11,7 +11,7 @@ \N \N -- !sql_3 -- -1.4337199628825619 1.4622886709763663 +1.4337199628825619 1.1197287397085194 -- !sql_4 -- \N \N @@ -41,7 +41,7 @@ \N -- !sql_5 -- -1.4337199628825619 1.4622886709763663 +1.4337199628825619 1.1197287397085194 -- !sql_6 -- \N \N From 2392c7df980f8e1f1fa46bc595fe25a67fb5b23b Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Fri, 20 Sep 2024 14:54:48 +0800 Subject: [PATCH 5/8] ASSERTCAT --- be/src/vec/aggregate_functions/aggregate_function_statistic.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_statistic.h b/be/src/vec/aggregate_functions/aggregate_function_statistic.h index 1d1d508971d5d8..09947c20c7b246 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_statistic.h +++ b/be/src/vec/aggregate_functions/aggregate_function_statistic.h @@ -94,7 +94,9 @@ class AggregateFunctionVarianceSimple } } else { - this->data(place).add(assert_cast(*columns[0]).get_data()[row_num]); + this->data(place).add( + assert_cast(*columns[0]) + .get_data()[row_num]); } } From 28bf4ec0980d5c57c25ad402ca84b2a54c741267 Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Fri, 20 Sep 2024 19:08:15 +0800 Subject: [PATCH 6/8] X --- 
.../data/query_p0/aggregate/aggregate_function_kurt.out | 6 +++--- .../data/query_p0/aggregate/aggregate_function_skew.out | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out index c5b0fe88d5b85b..0c078b9e7ef5c5 100644 --- a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out +++ b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out @@ -11,7 +11,7 @@ \N \N -- !sql_3 -- -3.162124583734851 2.3039881240500497 +3.162124583734851 1.5000000000000007 -- !sql_4 -- \N \N @@ -24,7 +24,7 @@ 2.2985631952470373 -- !sql_distinct_2 -- -2.3039881240500497 +1.5000000000000007 -- !sql_distinct_3 -- \N @@ -41,7 +41,7 @@ \N -- !sql_5 -- -3.162124583734851 2.3039881240500497 +3.162124583734851 1.5000000000000007 -- !sql_6 -- \N \N diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_skew.out b/regression-test/data/query_p0/aggregate/aggregate_function_skew.out index 171e96da057a9c..3320371dfbb37c 100644 --- a/regression-test/data/query_p0/aggregate/aggregate_function_skew.out +++ b/regression-test/data/query_p0/aggregate/aggregate_function_skew.out @@ -11,7 +11,7 @@ \N \N -- !sql_3 -- -1.4337199628825619 1.1197287397085194 +1.4337199628825619 0.675885787569108 -- !sql_4 -- \N \N @@ -24,7 +24,7 @@ 1.1135657469022011 -- !sql_distinct_2 -- -1.1197287397085194 +0.675885787569108 -- !sql_distinct_3 -- \N @@ -41,7 +41,7 @@ \N -- !sql_5 -- -1.4337199628825619 1.1197287397085194 +1.4337199628825619 0.675885787569108 -- !sql_6 -- \N \N From f1f9be265eca0669c4f26e9df77fa97560c89fd1 Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Mon, 23 Sep 2024 20:41:16 +0800 Subject: [PATCH 7/8] F --- be/src/vec/aggregate_functions/aggregate_function_statistic.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_statistic.h 
b/be/src/vec/aggregate_functions/aggregate_function_statistic.h index 09947c20c7b246..a1fd4395eb848a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_statistic.h +++ b/be/src/vec/aggregate_functions/aggregate_function_statistic.h @@ -145,8 +145,9 @@ class AggregateFunctionVarianceSimple dst_column->insert_default(); } else { dst_column_with_nullable.get_null_map_data().push_back(0); + // excess kurtosis = moments_4 / var_value^2 - 3 (var_value is the population variance) dst_column->get_data().push_back( - static_cast(moments_4 / pow(var_value, 2))); + static_cast(moments_4 / pow(var_value, 2)) - 3); } break; } From 42228ba95fa800ceaab3214f0af03b22de18f686 Mon Sep 17 00:00:00 2001 From: zhiqiang-hhhh Date: Mon, 23 Sep 2024 20:50:07 +0800 Subject: [PATCH 8/8] X --- .../data/query_p0/aggregate/aggregate_function_kurt.out | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out index 0c078b9e7ef5c5..362bd25d078c5e 100644 --- a/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out +++ b/regression-test/data/query_p0/aggregate/aggregate_function_kurt.out @@ -11,7 +11,7 @@ \N \N -- !sql_3 -- -3.162124583734851 1.5000000000000007 +0.16212458373485106 -1.4999999999999993 -- !sql_4 -- \N \N @@ -21,10 +21,10 @@ \N \N -- !sql_distinct_1 -- -2.2985631952470373 +-0.7014368047529627 -- !sql_distinct_2 -- -1.5000000000000007 +-1.4999999999999993 -- !sql_distinct_3 -- \N @@ -41,7 +41,7 @@ \N -- !sql_5 -- -3.162124583734851 1.5000000000000007 +0.16212458373485106 -1.4999999999999993 -- !sql_6 -- \N \N