From ca1bd8a4b81f98f5ed4e8804192d071163358b74 Mon Sep 17 00:00:00 2001 From: Hao Tan Date: Thu, 5 Aug 2021 18:40:47 +0800 Subject: [PATCH 1/7] Support percentile aggregate function --- be/src/exprs/aggregate_functions.cpp | 79 +++++++++++ be/src/exprs/aggregate_functions.h | 13 ++ be/src/util/counts.h | 130 ++++++++++++++++++ be/test/exprs/CMakeLists.txt | 1 + be/test/exprs/percentile_test.cpp | 119 ++++++++++++++++ be/test/util/CMakeLists.txt | 1 + be/test/util/counts_test.cpp | 70 ++++++++++ .../doris/analysis/FunctionCallExpr.java | 10 ++ .../org/apache/doris/catalog/FunctionSet.java | 10 ++ 9 files changed, 433 insertions(+) create mode 100644 be/src/util/counts.h create mode 100644 be/test/exprs/percentile_test.cpp create mode 100644 be/test/util/counts_test.cpp diff --git a/be/src/exprs/aggregate_functions.cpp b/be/src/exprs/aggregate_functions.cpp index b5a80771c27801..3cb763cd4c0b7c 100644 --- a/be/src/exprs/aggregate_functions.cpp +++ b/be/src/exprs/aggregate_functions.cpp @@ -31,6 +31,7 @@ #include "runtime/string_value.h" #include "util/debug_util.h" #include "util/tdigest.h" +#include "util/counts.h" // TODO: this file should be cross compiled and then all of the builtin // aggregate functions will have a codegen enabled path. 
Then we can remove @@ -147,6 +148,81 @@ void AggregateFunctions::count_remove(FunctionContext*, const AnyVal& src, BigIn } } +struct PercentileState { + PercentileState() : counts(new Counts()) {} + ~PercentileState() { delete counts; } + + Counts* counts = nullptr; + double quantile = -1.0; +}; + +void AggregateFunctions::percentile_init(FunctionContext* ctx, StringVal* dst) { + dst->is_null = false; + dst->len = sizeof(PercentileState); + dst->ptr = (uint8_t*) new PercentileState(); +} + +template +void AggregateFunctions::percentile_update(FunctionContext* ctx, const T& src, + const DoubleVal& quantile, StringVal* dst) { + if (src.is_null) { + return; + } + + DCHECK_GE(quantile.val, 0); + DCHECK_LE(quantile.val, 1); + DCHECK(dst->ptr != nullptr); + DCHECK_EQ(sizeof(PercentileState), dst->len); + + PercentileState* percentile = reinterpret_cast(dst->ptr); + percentile->counts->increment(src.val, 1); + percentile->quantile = quantile.val; +} + +void AggregateFunctions::percentile_merge(FunctionContext* ctx, const StringVal& src, StringVal* dst) { + DCHECK(dst->ptr != nullptr); + DCHECK_EQ(sizeof(PercentileState), dst->len); + + double quantile; + memcpy(&quantile, src.ptr, sizeof(double)); + + PercentileState* src_percentile = new PercentileState(); + src_percentile->quantile = quantile; + src_percentile->counts->unserialize(src.ptr + sizeof(double)); + + PercentileState* dst_percentile = reinterpret_cast(dst->ptr); + dst_percentile->counts->merge(src_percentile->counts); + if (dst_percentile->quantile == -1.0) { + dst_percentile->quantile = quantile; + } + + delete src_percentile; +} + +StringVal AggregateFunctions::percentile_serialize(FunctionContext* ctx, const StringVal& src) { + DCHECK(!src.is_null); + + PercentileState* percentile = reinterpret_cast(src.ptr); + uint32_t serialize_size = percentile->counts->serialized_size(); + StringVal result(ctx, sizeof(double) + serialize_size); + memcpy(result.ptr, &percentile->quantile, sizeof(double)); + 
percentile->counts->serialize(result.ptr + sizeof(double)); + + delete percentile; + return result; +} + +DoubleVal AggregateFunctions::percentile_finalize(FunctionContext* ctx, const StringVal& src) { + DCHECK(!src.is_null); + + PercentileState* percentile = reinterpret_cast(src.ptr); + double quantile = percentile->quantile; + auto result = percentile->counts->terminate(quantile); + + delete percentile; + return result; +} + struct PercentileApproxState { public: PercentileApproxState() : digest(new TDigest()) {} @@ -2628,6 +2704,9 @@ template void AggregateFunctions::offset_fn_update(FunctionContext const DecimalV2Val&, DecimalV2Val* dst); +template void AggregateFunctions::percentile_update( + FunctionContext* ctx, const BigIntVal&, const DoubleVal&, StringVal*); + template void AggregateFunctions::percentile_approx_update( FunctionContext* ctx, const doris_udf::DoubleVal&, const doris_udf::DoubleVal&, doris_udf::StringVal*); diff --git a/be/src/exprs/aggregate_functions.h b/be/src/exprs/aggregate_functions.h index 39e9aa891b6bfa..71bd4db204c4d9 100644 --- a/be/src/exprs/aggregate_functions.h +++ b/be/src/exprs/aggregate_functions.h @@ -65,6 +65,19 @@ class AggregateFunctions { static void count_star_remove(FunctionContext*, BigIntVal* dst); + // Implementation of percentile + static void percentile_init(FunctionContext* ctx, StringVal* dst); + + template + static void percentile_update(FunctionContext* ctx, const T& src, + const DoubleVal& quantile, StringVal* dst); + + static void percentile_merge(FunctionContext* ctx, const StringVal& src, StringVal* dst); + + static StringVal percentile_serialize(FunctionContext* ctx, const StringVal& state_sv); + + static DoubleVal percentile_finalize(FunctionContext* ctx, const StringVal& src); + // Implementation of percentile_approx static void percentile_approx_init(doris_udf::FunctionContext* ctx, doris_udf::StringVal* dst); diff --git a/be/src/util/counts.h b/be/src/util/counts.h new file mode 100644 index 
00000000000000..825b6bea9d2077 --- /dev/null +++ b/be/src/util/counts.h @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef DORIS_BE_SRC_UTIL_COUNTS_H_ +#define DORIS_BE_SRC_UTIL_COUNTS_H_ + +#include +#include +#include "udf/udf.h" + +namespace doris { + +class Counts { +public: + Counts() = default; + + inline void merge(const Counts* other) { + if (other == nullptr || other->_counts.empty()) { + return; + } + + for (auto& cell : other->_counts) { + increment(cell.first, cell.second); + } + } + + void increment(int64_t key, int i) { + auto item = _counts.find(key); + if (item != _counts.end()) { + item->second += i; + } else { + _counts.emplace(std::make_pair(key, i)); + } + } + + uint32_t serialized_size() { + return sizeof(uint32_t) + sizeof(int64_t) * _counts.size() + sizeof(int) * _counts.size(); + } + + void serialize(uint8_t* writer) { + uint32_t size = _counts.size(); + memcpy(writer, &size, sizeof(uint32_t)); + writer += sizeof(uint32_t); + for (auto& cell : _counts) { + memcpy(writer, &cell.first, sizeof(int64_t)); + writer += sizeof(int64_t); + memcpy(writer, &cell.second, sizeof(int)); + writer += sizeof(int); + } + } + + void unserialize(const uint8_t* type_reader) { + uint32_t 
size; + memcpy(&size, type_reader, sizeof(uint32_t)); + type_reader += sizeof(uint32_t); + for (uint32_t i = 0; i < size; ++i) { + int64_t key; + int count; + memcpy(&key, type_reader, sizeof(int64_t)); + type_reader += sizeof(int64_t); + memcpy(&count, type_reader, sizeof(int)); + type_reader += sizeof(int); + _counts.emplace(std::make_pair(key, count)); + } + } + + double get_percentile(std::map& copy_counts, double position) { + long lower = std::floor(position); + long higher = std::ceil(position); + + auto iter = copy_counts.begin(); + for (; iter != copy_counts.end() && iter->second < lower + 1; ++iter); + + int64_t lower_key = iter->first; + if (higher == lower) { + return lower_key; + } + + if (iter->second < higher + 1) { + iter++; + } + + int64_t higher_key = iter->first; + if (lower_key == higher_key) { + return lower_key; + } + + return (higher - position) * lower_key + (position - lower) * higher_key; + } + + doris_udf::DoubleVal terminate(double quantile) { + if (_counts.empty()) { + return doris_udf::DoubleVal(); + } + + std::map copy_counts(_counts); + long total = 0; + for (auto& cell : copy_counts) { + total += cell.second; + cell.second = total; + } + + long max_position = total - 1; + double position = max_position * quantile; + return doris_udf::DoubleVal(get_percentile(copy_counts, position)); + + } + + +private: + std::map _counts; +}; + +} // namespace doris + +#endif // DORIS_BE_SRC_UTIL_COUNTS_H_ diff --git a/be/test/exprs/CMakeLists.txt b/be/test/exprs/CMakeLists.txt index 26e73f01cef923..9f4702bb8f6395 100644 --- a/be/test/exprs/CMakeLists.txt +++ b/be/test/exprs/CMakeLists.txt @@ -29,6 +29,7 @@ ADD_BE_TEST(json_function_test) ADD_BE_TEST(string_functions_test) ADD_BE_TEST(timestamp_functions_test) ADD_BE_TEST(percentile_approx_test) +ADD_BE_TEST(percentile_test) ADD_BE_TEST(bitmap_function_test) ADD_BE_TEST(hll_function_test) ADD_BE_TEST(encryption_functions_test) diff --git a/be/test/exprs/percentile_test.cpp 
b/be/test/exprs/percentile_test.cpp new file mode 100644 index 00000000000000..8054d086241103 --- /dev/null +++ b/be/test/exprs/percentile_test.cpp @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "exprs/aggregate_functions.h" +#include "testutil/function_utils.h" + +namespace doris { + +class PercentileTest : public testing::Test { +public: + PercentileTest() {} +}; + +TEST_F(PercentileTest, testSample) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + DoubleVal doubleQ(0.9); + + StringVal stringVal1; + BigIntVal int1(1); + AggregateFunctions::percentile_init(context, &stringVal1); + AggregateFunctions::percentile_update(context, int1, doubleQ, &stringVal1); + BigIntVal int2(2); + AggregateFunctions::percentile_update(context, int2, doubleQ, &stringVal1); + + StringVal s = AggregateFunctions::percentile_serialize(context, stringVal1); + + StringVal stringVal2; + AggregateFunctions::percentile_init(context, &stringVal2); + AggregateFunctions::percentile_merge(context, s, &stringVal2); + DoubleVal v = AggregateFunctions::percentile_finalize(context, stringVal2); + ASSERT_EQ(v.val, 1.9); + delete futil; +} + 
+TEST_F(PercentileTest, testNoMerge) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + DoubleVal doubleQ(0.9); + + StringVal stringVal1; + BigIntVal val(1); + AggregateFunctions::percentile_init(context, &stringVal1); + AggregateFunctions::percentile_update(context, val, doubleQ, &stringVal1); + BigIntVal val2(2); + AggregateFunctions::percentile_update(context, val2, doubleQ, &stringVal1); + + DoubleVal v = AggregateFunctions::percentile_finalize(context, stringVal1); + ASSERT_EQ(v.val, 1.9); + delete futil; +} + +TEST_F(PercentileTest, testSerialize) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + DoubleVal doubleQ(0.999); + StringVal stringVal; + AggregateFunctions::percentile_init(context, &stringVal); + + for (int i = 1; i <= 100000; i++) { + BigIntVal val(i); + AggregateFunctions::percentile_update(context, val, doubleQ, &stringVal); + } + StringVal serialized = AggregateFunctions::percentile_serialize(context, stringVal); + + // mock serialize + StringVal stringVal2; + AggregateFunctions::percentile_init(context, &stringVal2); + AggregateFunctions::percentile_merge(context, serialized, &stringVal2); + DoubleVal v = AggregateFunctions::percentile_finalize(context, stringVal2); + ASSERT_DOUBLE_EQ(v.val, 99900.001); + + // merge init percentile stringVal3 should not change the correct result + AggregateFunctions::percentile_init(context, &stringVal); + + for (int i = 1; i <= 100000; i++) { + BigIntVal val(i); + AggregateFunctions::percentile_update(context, val, doubleQ, &stringVal); + } + serialized = AggregateFunctions::percentile_serialize(context, stringVal); + + StringVal stringVal3; + AggregateFunctions::percentile_init(context, &stringVal2); + AggregateFunctions::percentile_init(context, &stringVal3); + StringVal serialized2 = AggregateFunctions::percentile_serialize(context, stringVal3); + + 
AggregateFunctions::percentile_merge(context, serialized, &stringVal2); + AggregateFunctions::percentile_merge(context, serialized2, &stringVal2); + v = AggregateFunctions::percentile_finalize(context, stringVal2); + ASSERT_DOUBLE_EQ(v.val, 99900.001); + + delete futil; +} + +} // namespace doris + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/be/test/util/CMakeLists.txt b/be/test/util/CMakeLists.txt index 74374cdafefaa4..634250b566cb22 100644 --- a/be/test/util/CMakeLists.txt +++ b/be/test/util/CMakeLists.txt @@ -72,3 +72,4 @@ ADD_BE_TEST(s3_uri_test) ADD_BE_TEST(s3_storage_backend_test) ADD_BE_TEST(broker_storage_backend_test) ADD_BE_TEST(sort_heap_test) +ADD_BE_TEST(counts_test) diff --git a/be/test/util/counts_test.cpp b/be/test/util/counts_test.cpp new file mode 100644 index 00000000000000..2d683efa0ac6c8 --- /dev/null +++ b/be/test/util/counts_test.cpp @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "util/counts.h" + +#include +#include "test_util/test_util.h" + +namespace doris { + +class TCountsTest : public testing::Test {}; + +TEST_F(TCountsTest, TotalTest) { + Counts counts; + // 1 1 1 2 5 7 7 9 9 19 + // >>> import numpy as np + // >>> a = np.array([1,1,1,2,5,7,7,9,9,19]) + // >>> p = np.percentile(a, 20) + counts.increment(1, 3); + counts.increment(5, 1); + counts.increment(2, 1); + counts.increment(9, 1); + counts.increment(9, 1); + counts.increment(19, 1); + counts.increment(7, 2); + + doris_udf::DoubleVal result = counts.terminate(0.2); + EXPECT_EQ(1, result.val); + uint8_t* writer = new uint8_t[counts.serialized_size()]; + uint8_t* type_reader = writer; + counts.serialize(writer); + + Counts other; + other.unserialize(type_reader); + doris_udf::DoubleVal result1 = other.terminate(0.2); + EXPECT_EQ(result.val, result1.val); + + Counts other1; + other1.increment(1, 1); + other1.increment(100, 3); + other1.increment(50, 3); + other1.increment(10, 1); + other1.increment(99, 2); + + counts.merge(&other1); + // 1 1 1 1 2 5 7 7 9 9 10 19 50 50 50 99 99 100 100 100 + EXPECT_EQ(counts.terminate(0.3).val, 6.4); + +} + +} // namespace doris + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index bda3af926b1588..a35e6399d60ec9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -459,6 +459,16 @@ private void analyzeBuiltinAggFunction(Analyzer analyzer) throws AnalysisExcepti fnParams.setIsDistinct(false); } + if (fnName.getFunction().equalsIgnoreCase("percentile")) { + if (children.size() != 2) { + throw new AnalysisException("percentile(expr, DOUBLE) requires two parameters"); + } + if (!getChild(1).isConstant()) { + 
throw new AnalysisException("percentile requires second parameter must be a constant : " + + this.toSql()); + } + } + if (fnName.getFunction().equalsIgnoreCase("percentile_approx")) { if (children.size() != 2 && children.size() != 3) { throw new AnalysisException("percentile_approx(expr, DOUBLE [, B]) requires two or three parameters"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 4dacd6a286cb29..a9fff4497c7614 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1372,6 +1372,16 @@ private void initAggregateBuiltins() { "_ZN5doris15BitmapFunctions16bitmap_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", true, false, true)); + //Percentile + addBuiltin(AggregateFunction.createBuiltin("percentile", + Lists.newArrayList(Type.BIGINT, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR, + prefix + "15percentile_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + prefix + "17percentile_updateIN9doris_udf9BigIntValEEEvPNS2_15FunctionContextERKT_RKNS2_9DoubleValEPNS2_9StringValE", + prefix + "16percentile_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + prefix + "20percentile_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + prefix + "19percentile_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + false, false, false)); + //PercentileApprox addBuiltin(AggregateFunction.createBuiltin("percentile_approx", Lists.newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.VARCHAR, From 45980d3ee22371e4aae08e14a89567edfb42bf5f Mon Sep 17 00:00:00 2001 From: Tan Hao Date: Thu, 12 Aug 2021 10:39:32 +0800 Subject: [PATCH 2/7] clean code --- be/src/exprs/aggregate_functions.cpp | 2 -- be/src/util/counts.h | 2 -- 2 files changed, 4 deletions(-) diff --git a/be/src/exprs/aggregate_functions.cpp b/be/src/exprs/aggregate_functions.cpp index 
23a1ff3b75653a..3b7a8722782127 100644 --- a/be/src/exprs/aggregate_functions.cpp +++ b/be/src/exprs/aggregate_functions.cpp @@ -194,8 +194,6 @@ void AggregateFunctions::percentile_update(FunctionContext* ctx, const T& src, return; } - DCHECK_GE(quantile.val, 0); - DCHECK_LE(quantile.val, 1); DCHECK(dst->ptr != nullptr); DCHECK_EQ(sizeof(PercentileState), dst->len); diff --git a/be/src/util/counts.h b/be/src/util/counts.h index 825b6bea9d2077..65acc542915e8d 100644 --- a/be/src/util/counts.h +++ b/be/src/util/counts.h @@ -117,10 +117,8 @@ class Counts { long max_position = total - 1; double position = max_position * quantile; return doris_udf::DoubleVal(get_percentile(copy_counts, position)); - } - private: std::map _counts; }; From ad9d34146387452a3b2c37b7519e8c841f684d94 Mon Sep 17 00:00:00 2001 From: Hao Tan Date: Thu, 12 Aug 2021 16:22:56 +0800 Subject: [PATCH 3/7] use unordered_map to optimize performance --- be/src/util/counts.h | 40 +++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/be/src/util/counts.h b/be/src/util/counts.h index 65acc542915e8d..39a766df845ce8 100644 --- a/be/src/util/counts.h +++ b/be/src/util/counts.h @@ -18,15 +18,18 @@ #ifndef DORIS_BE_SRC_UTIL_COUNTS_H_ #define DORIS_BE_SRC_UTIL_COUNTS_H_ -#include #include +#include +#include +#include + #include "udf/udf.h" namespace doris { class Counts { public: - Counts() = default; + Counts() = default; inline void merge(const Counts* other) { if (other == nullptr || other->_counts.empty()) { @@ -38,7 +41,7 @@ class Counts { } } - void increment(int64_t key, int i) { + void increment(int64_t key, uint32_t i) { auto item = _counts.find(key); if (item != _counts.end()) { item->second += i; @@ -48,7 +51,8 @@ class Counts { } uint32_t serialized_size() { - return sizeof(uint32_t) + sizeof(int64_t) * _counts.size() + sizeof(int) * _counts.size(); + return sizeof(uint32_t) + sizeof(int64_t) * _counts.size() + + sizeof(uint32_t) * _counts.size(); } 
void serialize(uint8_t* writer) { @@ -58,8 +62,8 @@ class Counts { for (auto& cell : _counts) { memcpy(writer, &cell.first, sizeof(int64_t)); writer += sizeof(int64_t); - memcpy(writer, &cell.second, sizeof(int)); - writer += sizeof(int); + memcpy(writer, &cell.second, sizeof(uint32_t)); + writer += sizeof(uint32_t); } } @@ -69,21 +73,22 @@ class Counts { type_reader += sizeof(uint32_t); for (uint32_t i = 0; i < size; ++i) { int64_t key; - int count; + uint32_t count; memcpy(&key, type_reader, sizeof(int64_t)); type_reader += sizeof(int64_t); - memcpy(&count, type_reader, sizeof(int)); - type_reader += sizeof(int); + memcpy(&count, type_reader, sizeof(uint32_t)); + type_reader += sizeof(uint32_t); _counts.emplace(std::make_pair(key, count)); } } - double get_percentile(std::map& copy_counts, double position) { + double get_percentile(std::vector>& counts, double position) { long lower = std::floor(position); long higher = std::ceil(position); - auto iter = copy_counts.begin(); - for (; iter != copy_counts.end() && iter->second < lower + 1; ++iter); + auto iter = counts.begin(); + for (; iter != counts.end() && iter->second < lower + 1; ++iter) + ; int64_t lower_key = iter->first; if (higher == lower) { @@ -94,7 +99,7 @@ class Counts { iter++; } - int64_t higher_key = iter->first; + int64_t higher_key = iter->first; if (lower_key == higher_key) { return lower_key; } @@ -107,20 +112,21 @@ class Counts { return doris_udf::DoubleVal(); } - std::map copy_counts(_counts); + std::vector> elems(_counts.begin(), _counts.end()); + std::sort(elems.begin(), elems.end()); long total = 0; - for (auto& cell : copy_counts) { + for (auto& cell : elems) { total += cell.second; cell.second = total; } long max_position = total - 1; double position = max_position * quantile; - return doris_udf::DoubleVal(get_percentile(copy_counts, position)); + return doris_udf::DoubleVal(get_percentile(elems, position)); } private: - std::map _counts; + std::unordered_map _counts; }; } // namespace 
doris From ea5741747ded24258ec51257fa98b00a9890ff7f Mon Sep 17 00:00:00 2001 From: Tan Hao Date: Fri, 13 Aug 2021 10:33:53 +0800 Subject: [PATCH 4/7] fix ut --- be/src/util/counts.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/be/src/util/counts.h b/be/src/util/counts.h index 39a766df845ce8..fc882658f6e0b8 100644 --- a/be/src/util/counts.h +++ b/be/src/util/counts.h @@ -18,8 +18,8 @@ #ifndef DORIS_BE_SRC_UTIL_COUNTS_H_ #define DORIS_BE_SRC_UTIL_COUNTS_H_ +#include #include -#include #include #include @@ -113,7 +113,11 @@ class Counts { } std::vector> elems(_counts.begin(), _counts.end()); - std::sort(elems.begin(), elems.end()); + sort(elems.begin(), elems.end(), + [](const std::pair l, const std::pair r) { + return l.first < r.first; + }); + long total = 0; for (auto& cell : elems) { total += cell.second; From daa1b7f37b4fb7586f8e91417bd0f6d0de26ac22 Mon Sep 17 00:00:00 2001 From: Tan Hao Date: Fri, 13 Aug 2021 13:16:02 +0000 Subject: [PATCH 5/7] fix --- be/src/exprs/aggregate_functions.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/be/src/exprs/aggregate_functions.cpp b/be/src/exprs/aggregate_functions.cpp index 3b7a8722782127..9d1d140df86d67 100644 --- a/be/src/exprs/aggregate_functions.cpp +++ b/be/src/exprs/aggregate_functions.cpp @@ -174,10 +174,7 @@ void AggregateFunctions::count_remove(FunctionContext*, const AnyVal& src, BigIn } struct PercentileState { - PercentileState() : counts(new Counts()) {} - ~PercentileState() { delete counts; } - - Counts* counts = nullptr; + Counts counts; double quantile = -1.0; }; @@ -198,7 +195,7 @@ void AggregateFunctions::percentile_update(FunctionContext* ctx, const T& src, DCHECK_EQ(sizeof(PercentileState), dst->len); PercentileState* percentile = reinterpret_cast(dst->ptr); - percentile->counts->increment(src.val, 1); + percentile->counts.increment(src.val, 1); percentile->quantile = quantile.val; } @@ -211,10 +208,10 @@ void 
AggregateFunctions::percentile_merge(FunctionContext* ctx, const StringVal& PercentileState* src_percentile = new PercentileState(); src_percentile->quantile = quantile; - src_percentile->counts->unserialize(src.ptr + sizeof(double)); + src_percentile->counts.unserialize(src.ptr + sizeof(double)); PercentileState* dst_percentile = reinterpret_cast(dst->ptr); - dst_percentile->counts->merge(src_percentile->counts); + dst_percentile->counts.merge(&src_percentile->counts); if (dst_percentile->quantile == -1.0) { dst_percentile->quantile = quantile; } @@ -226,10 +223,10 @@ StringVal AggregateFunctions::percentile_serialize(FunctionContext* ctx, const S DCHECK(!src.is_null); PercentileState* percentile = reinterpret_cast(src.ptr); - uint32_t serialize_size = percentile->counts->serialized_size(); + uint32_t serialize_size = percentile->counts.serialized_size(); StringVal result(ctx, sizeof(double) + serialize_size); memcpy(result.ptr, &percentile->quantile, sizeof(double)); - percentile->counts->serialize(result.ptr + sizeof(double)); + percentile->counts.serialize(result.ptr + sizeof(double)); delete percentile; return result; @@ -240,7 +237,7 @@ DoubleVal AggregateFunctions::percentile_finalize(FunctionContext* ctx, const St PercentileState* percentile = reinterpret_cast(src.ptr); double quantile = percentile->quantile; - auto result = percentile->counts->terminate(quantile); + auto result = percentile->counts.terminate(quantile); delete percentile; return result; From a1552333d9cfdf009520bfeb6f38f23bffff5444 Mon Sep 17 00:00:00 2001 From: Tan Hao Date: Sat, 14 Aug 2021 11:34:40 +0800 Subject: [PATCH 6/7] add doc for percentile --- .../aggregate-functions/percentile.md | 49 ++++++++++++++++++ .../aggregate-functions/percentile.md | 50 +++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 docs/en/sql-reference/sql-functions/aggregate-functions/percentile.md create mode 100755 docs/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile.md 
diff --git a/docs/en/sql-reference/sql-functions/aggregate-functions/percentile.md b/docs/en/sql-reference/sql-functions/aggregate-functions/percentile.md new file mode 100644 index 00000000000000..30c347ad116dcd --- /dev/null +++ b/docs/en/sql-reference/sql-functions/aggregate-functions/percentile.md @@ -0,0 +1,49 @@ +--- +{ + "title": "PERCENTILE", + "language": "en" +} +--- + + + +# PERCENTILE +## Description +### Syntax + +`PERCENTILE(expr, DOUBLE p)` + +Calculate the exact percentile, suitable for small data volumes. Sort the specified column in ascending order first, and then take the exact p-th percentile. The value of p is between 0 and 1. + +Parameter Description: +expr: required. The value is an integer (bigint at most). +p: required. The exact percentile to compute. The value is in the range [0.0,1.0]. + +## example +``` +MySQL > select `table`, percentile(cost_time,0.99) from log_statis group by `table`; ++---------------------+---------------------------+ +| table | percentile(`cost_time`, 0.99)| ++----------+--------------------------------------+ +| test | 54.22 | ++----------+--------------------------------------+ + +## keyword +PERCENTILE diff --git a/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile.md b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile.md new file mode 100755 index 00000000000000..19a49482557793 --- /dev/null +++ b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile.md @@ -0,0 +1,50 @@ +--- +{ + "title": "PERCENTILE", + "language": "zh-CN" +} +--- + + + +# PERCENTILE +## description +### Syntax + +`PERCENTILE(expr, DOUBLE p)` + +计算精确的百分位数,适用于小数据量。先对指定列升序排列,然后取精确的第 p 位百分数。p的值介于0到1之间。 + +参数说明 +expr:必填。值为整数(最大为bigint) 类型的列。 +p:必填。需要精确的百分位数。取值为 [0.0,1.0]。 + + +## example +``` +MySQL > select `table`, percentile(cost_time,0.99) from log_statis group by `table`; ++---------------------+---------------------------+ +| table | percentile(`cost_time`, 0.99) |
++----------+--------------------------------------+ +| test | 54.22 | ++----------+--------------------------------------+ + +## keyword +PERCENTILE From 3916142456402ea5b50e5dc88a8b7f77e42dd393 Mon Sep 17 00:00:00 2001 From: Tan Hao Date: Sat, 14 Aug 2021 11:35:08 +0800 Subject: [PATCH 7/7] add doc for percentile --- .../sql-functions/aggregate-functions/percentile.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile.md b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile.md index 19a49482557793..077cafa285b3e8 100755 --- a/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile.md +++ b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/percentile.md @@ -36,7 +36,6 @@ under the License. expr:必填。值为整数(最大为bigint) 类型的列。 p:必填。需要精确的百分位数。取值为 [0.0,1.0]。 - ## example ``` MySQL > select `table`, percentile(cost_time,0.99) from log_statis group by `table`;