Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class IDataType;

struct AggregateFunctionAttr {
bool enable_decimal256 {false};
std::vector<std::pair<std::string, bool>> column_infos;
std::vector<std::string> column_names;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change this basic data structure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize the redundant data structures added in approx_top_k.

};

template <bool nullable, typename ColVecType>
Expand Down
80 changes: 80 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_approx_top.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,92 @@
#pragma once

#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_nullable.h"

namespace doris::vectorized {

class AggregateFunctionApproxTop {
public:
AggregateFunctionApproxTop(const std::vector<std::string>& column_names)
: _column_names(column_names) {}

static int32_t is_valid_const_columns(const std::vector<bool>& is_const_columns) {
int32_t true_count = 0;
bool found_false_after_true = false;
for (int32_t i = is_const_columns.size() - 1; i >= 0; --i) {
if (is_const_columns[i]) {
true_count++;
if (found_false_after_true) {
return false;
}
} else {
if (true_count > 2) {
return false;
}
found_false_after_true = true;
}
}
if (true_count > 2) {
throw Exception(ErrorCode::INVALID_ARGUMENT, "Invalid is_const_columns configuration");
}
return true_count;
}

protected:
void lazy_init(const IColumn** columns, ssize_t row_num,
const DataTypes& argument_types) const {
auto get_param = [](size_t idx, const DataTypes& data_types,
const IColumn** columns) -> uint64_t {
const auto& data_type = data_types.at(idx);
const IColumn* column = columns[idx];

const auto* type = data_type.get();
if (type->is_nullable()) {
type = assert_cast<const DataTypeNullable*, TypeCheckOnRelease::DISABLE>(type)
->get_nested_type()
.get();
}
int64_t value = 0;
WhichDataType which(type);
if (which.idx == TypeIndex::Int8) {
value = assert_cast<const ColumnInt8*, TypeCheckOnRelease::DISABLE>(column)
->get_element(0);
} else if (which.idx == TypeIndex::Int16) {
value = assert_cast<const ColumnInt16*, TypeCheckOnRelease::DISABLE>(column)
->get_element(0);
} else if (which.idx == TypeIndex::Int32) {
value = assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(column)
->get_element(0);
}
if (value <= 0) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The parameter cannot be less than or equal to 0.");
}
return value;
};

_threshold =
std::min(get_param(_column_names.size(), argument_types, columns), (uint64_t)4096);
_reserved = std::min(
std::max(get_param(_column_names.size() + 1, argument_types, columns), _threshold),
(uint64_t)4096);

if (_threshold == 0 || _reserved == 0 || _threshold > 4096 || _reserved > 4096) {
throw Exception(ErrorCode::INTERNAL_ERROR,
"approx_top_sum param error, _threshold: {}, _reserved: {}", _threshold,
_reserved);
}

_init_flag = true;
}

static inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;

mutable std::vector<std::string> _column_names;
mutable bool _init_flag = false;
mutable uint64_t _threshold = 10;
mutable uint64_t _reserved = 30;
};

} // namespace doris::vectorized
48 changes: 3 additions & 45 deletions be/src/vec/aggregate_functions/aggregate_function_approx_top_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,58 +24,16 @@

namespace doris::vectorized {

int32_t is_valid_const_columns(const std::vector<bool>& is_const_columns) {
int32_t true_count = 0;
bool found_false_after_true = false;
for (int32_t i = is_const_columns.size() - 1; i >= 0; --i) {
if (is_const_columns[i]) {
true_count++;
if (found_false_after_true) {
return false;
}
} else {
if (true_count > 2) {
return false;
}
found_false_after_true = true;
}
}
if (true_count > 2) {
throw Exception(ErrorCode::INVALID_ARGUMENT, "Invalid is_const_columns configuration");
}
return true_count;
}

AggregateFunctionPtr create_aggregate_function_approx_top_k(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.empty()) {
if (argument_types.size() < 3) {
return nullptr;
}

std::vector<bool> is_const_columns;
std::vector<std::string> column_names;
for (const auto& [name, is_const] : attr.column_infos) {
is_const_columns.push_back(is_const);
if (!is_const) {
column_names.push_back(name);
}
}

int32_t true_count = is_valid_const_columns(is_const_columns);
if (true_count == 0) {
return creator_without_type::create<AggregateFunctionApproxTopK<0>>(
argument_types, result_is_nullable, column_names);
} else if (true_count == 1) {
return creator_without_type::create<AggregateFunctionApproxTopK<1>>(
argument_types, result_is_nullable, column_names);
} else if (true_count == 2) {
return creator_without_type::create<AggregateFunctionApproxTopK<2>>(
argument_types, result_is_nullable, column_names);
} else {
return nullptr;
}
return creator_without_type::create<AggregateFunctionApproxTopK>(
argument_types, result_is_nullable, attr.column_names);
}

void register_aggregate_function_approx_top_k(AggregateFunctionSimpleFactory& factory) {
Expand Down
71 changes: 5 additions & 66 deletions be/src/vec/aggregate_functions/aggregate_function_approx_top_k.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,25 @@

namespace doris::vectorized {

inline constexpr UInt64 TOP_K_MAX_SIZE = 0xFFFFFF;

struct AggregateFunctionTopKGenericData {
using Set = SpaceSaving<StringRef, StringRefHash>;

Set value;
};

template <int32_t ArgsSize>
class AggregateFunctionApproxTopK final
: public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData,
AggregateFunctionApproxTopK<ArgsSize>>,
AggregateFunctionApproxTopK>,
AggregateFunctionApproxTop {
private:
using State = AggregateFunctionTopKGenericData;

public:
AggregateFunctionApproxTopK(std::vector<std::string> column_names,
AggregateFunctionApproxTopK(const std::vector<std::string>& column_names,
const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData,
AggregateFunctionApproxTopK<ArgsSize>>(argument_types_),
_column_names(std::move(column_names)) {}
AggregateFunctionApproxTopK>(argument_types_),
AggregateFunctionApproxTop(column_names) {}

String get_name() const override { return "approx_top_k"; }

Expand Down Expand Up @@ -141,7 +138,7 @@ class AggregateFunctionApproxTopK final
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
if (!_init_flag) {
lazy_init(columns, row_num);
lazy_init(columns, row_num, this->get_argument_types());
}

auto& set = this->data(place).value;
Expand Down Expand Up @@ -227,64 +224,6 @@ class AggregateFunctionApproxTopK final
std::string res = buffer.GetString();
data_to.insert_data(res.data(), res.size());
}

private:
void lazy_init(const IColumn** columns, ssize_t row_num) const {
auto get_param = [](size_t idx, const DataTypes& data_types,
const IColumn** columns) -> uint64_t {
const auto& data_type = data_types.at(idx);
const IColumn* column = columns[idx];

const auto* type = data_type.get();
if (type->is_nullable()) {
type = assert_cast<const DataTypeNullable*, TypeCheckOnRelease::DISABLE>(type)
->get_nested_type()
.get();
}
int64_t value = 0;
WhichDataType which(type);
if (which.idx == TypeIndex::Int8) {
value = assert_cast<const ColumnInt8*, TypeCheckOnRelease::DISABLE>(column)
->get_element(0);
} else if (which.idx == TypeIndex::Int16) {
value = assert_cast<const ColumnInt16*, TypeCheckOnRelease::DISABLE>(column)
->get_element(0);
} else if (which.idx == TypeIndex::Int32) {
value = assert_cast<const ColumnInt32*, TypeCheckOnRelease::DISABLE>(column)
->get_element(0);
}
if (value <= 0) {
throw Exception(ErrorCode::INVALID_ARGUMENT,
"The parameter cannot be less than or equal to 0.");
}
return value;
};

const auto& data_types = this->get_argument_types();
if (ArgsSize == 1) {
_threshold =
std::min(get_param(_column_names.size(), data_types, columns), (uint64_t)1000);
} else if (ArgsSize == 2) {
_threshold =
std::min(get_param(_column_names.size(), data_types, columns), (uint64_t)1000);
_reserved = std::min(
std::max(get_param(_column_names.size() + 1, data_types, columns), _threshold),
(uint64_t)1000);
}

if (_threshold == 0 || _reserved == 0 || _threshold > 1000 || _reserved > 1000) {
throw Exception(ErrorCode::INTERNAL_ERROR,
"approx_top_k param error, _threshold: {}, _reserved: {}", _threshold,
_reserved);
}

_init_flag = true;
}

mutable std::vector<std::string> _column_names;
mutable bool _init_flag = false;
mutable uint64_t _threshold = 10;
mutable uint64_t _reserved = 300;
};

} // namespace doris::vectorized
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "vec/aggregate_functions/aggregate_function_approx_top_sum.h"

#include "common/exception.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/data_types/data_type.h"

namespace doris::vectorized {

template <size_t N>
AggregateFunctionPtr create_aggregate_function_multi_top_sum_impl(
const DataTypes& argument_types, const bool result_is_nullable,
const std::vector<std::string>& column_names) {
if (N == argument_types.size() - 3) {
return creator_with_type_base<true, false, false, N>::template create<
AggregateFunctionApproxTopSumSimple>(argument_types, result_is_nullable,
column_names);
} else {
return create_aggregate_function_multi_top_sum_impl<N - 1>(
argument_types, result_is_nullable, column_names);
}
}

template <>
AggregateFunctionPtr create_aggregate_function_multi_top_sum_impl<0>(
const DataTypes& argument_types, const bool result_is_nullable,
const std::vector<std::string>& column_names) {
return creator_with_type_base<true, false, false, 0>::template create<
AggregateFunctionApproxTopSumSimple>(argument_types, result_is_nullable, column_names);
}

AggregateFunctionPtr create_aggregate_function_approx_top_sum(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable,
const AggregateFunctionAttr& attr) {
if (argument_types.size() < 3) {
return nullptr;
}

constexpr size_t max_param_value = 10;
if (argument_types.size() > max_param_value) {
throw Exception(ErrorCode::INTERNAL_ERROR,
"Argument types size exceeds the supported limit.");
}

return create_aggregate_function_multi_top_sum_impl<max_param_value>(
argument_types, result_is_nullable, attr.column_names);
}

void register_aggregate_function_approx_top_sum(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("approx_top_sum", create_aggregate_function_approx_top_sum);
}

} // namespace doris::vectorized
Loading