From 8848a7334a289c823e6d4008e2d98d011eef0d3a Mon Sep 17 00:00:00 2001 From: chesterxu Date: Thu, 14 Mar 2024 20:58:28 +0800 Subject: [PATCH 01/21] init --- ...gregate_function_group_array_intersect.cpp | 40 ++ ...aggregate_function_group_array_intersect.h | 398 ++++++++++++++++++ .../aggregate_function_simple_factory.cpp | 2 + be/src/vec/common/hash_table/hash_table.h | 2 + docs/sidebars.json | 1 + .../doris/analysis/FunctionCallExpr.java | 3 +- .../catalog/BuiltinAggregateFunctions.java | 2 + .../org/apache/doris/catalog/FunctionSet.java | 6 +- .../functions/agg/GroupArrayIntersect.java | 92 ++++ .../visitor/AggregateFunctionVisitor.java | 5 + 10 files changed, 549 insertions(+), 2 deletions(-) create mode 100644 be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp create mode 100644 be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp new file mode 100644 index 00000000000000..b051e2d3950bc4 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include "vec/aggregate_functions/aggregate_function_group_array_intersect.h" + +namespace doris::vectorized { +AggregateFunctionPtr create_aggregate_function_group_array_intersect( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + assert_unary(name, argument_types); + + if (!WhichDataType(argument_types.at(0)).is_array()) + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Aggregate function groupArrayIntersect accepts only array type argument."); + + return create_aggregate_function_group_array_intersect_impl(name, argument_types, + result_is_nullable); +} + +void register_aggregate_function_group_array_intersect(AggregateFunctionSimpleFactory& factory) { + factory.register_function_both("group_array_intersect", + create_aggregate_function_group_array_intersect); +} +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h new file mode 100644 index 00000000000000..faf58943d588f4 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -0,0 +1,398 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp +// and modified by Doris + +#include +#include +#include +#include + +#include +#include + +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/columns/column_array.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" +#include "vec/data_types/data_type_time_v2.h" +#include "vec/io/io_helper.h" +#include "vec/io/var_int.h" + +namespace doris { +namespace vectorized { +class Arena; +class BufferReadable; +class BufferWritable; +} // namespace vectorized +} // namespace doris + +namespace doris::vectorized { + +template +struct AggregateFunctionGroupArrayIntersectData { + using Set = HashSet; + + Set value; + UInt64 version = 0; +}; + +/// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. +template +class AggregateFunctionGroupArrayIntersect + : public IAggregateFunctionDataHelper, + AggregateFunctionGroupArrayIntersect> { +private: + using State = AggregateFunctionGroupArrayIntersectData; + DataTypePtr argument_type; + +public: + AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper, + AggregateFunctionGroupArrayIntersect>( + {argument_types_}), + argument_type(this->argument_types[0]) {} + + AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_, + const bool result_is_nullable) + : IAggregateFunctionDataHelper, + AggregateFunctionGroupArrayIntersect>( + {argument_types_}), + argument_type(this->argument_types[0]) {} + + String get_name() const override { return "group_array_intersect"; } + + // DataTypePtr get_return_type() const override { return argument_type; } + DataTypePtr get_return_type() const override { + return std::make_shared(argument_type); + } + + bool allocates_memory_in_arena() const override { return false; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + auto& version = this->data(place).version; + auto& set = this->data(place).value; + + const auto data_column = assert_cast(*columns[0]).get_data_ptr(); + const auto& offsets = assert_cast(*columns[0]).get_offsets(); + const size_t offset = offsets[row_num - 1]; + const auto arr_size = offsets[row_num] - offset; + + ++version; + if (version == 1) { + for (size_t i = 0; i < arr_size; ++i) + set.insert(static_cast((*data_column)[offset + i].get())); + } else if (!set.empty()) { + typename State::Set new_set; + for (size_t i = 0; i < arr_size; ++i) { + typename State::Set::LookupResult set_value = + set.find(static_cast((*data_column)[offset + i].get())); + if (set_value != nullptr) + new_set.insert(static_cast((*data_column)[offset + i].get())); + } + set = std::move(new_set); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + auto& set = this->data(place).value; + const auto& rhs_set = this->data(rhs).value; + + if (this->data(rhs).version == 0) return; + + UInt64 version = this->data(place).version++; + if (version == 0) { + for (auto& rhs_elem : rhs_set) set.insert(rhs_elem.get_value()); + return; + } + + if (!set.empty()) { + auto create_new_set = [](auto& lhs_val, auto& rhs_val) { + typename State::Set new_set; + for (auto& lhs_elem : lhs_val) { + auto res = rhs_val.find(lhs_elem.get_value()); + if (res != nullptr) new_set.insert(lhs_elem.get_value()); + } + return new_set; + }; + auto new_set = rhs_set.size() < set.size() ? create_new_set(rhs_set, set) + : create_new_set(set, rhs_set); + set = std::move(new_set); + } + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + auto& set = this->data(place).value; + auto version = this->data(place).version; + + write_var_uint(version, buf); + write_var_uint(set.size(), buf); + + for (const auto& elem : set) write_int_binary(elem.get_value(), buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + read_var_uint(this->data(place).version, buf); + this->data(place).value.read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + ColumnArray& arr_to = assert_cast(to); + auto& offsets_to = arr_to.get_offsets(); + + const auto& set = this->data(place).value; + offsets_to.push_back(offsets_to.back() + set.size()); + + typename ColumnVector::Container& data_to = + assert_cast&>(arr_to.get_data()).get_data(); + size_t old_size = data_to.size(); + data_to.resize(old_size + set.size()); + + size_t i = 0; + for (auto it = set.begin(); it != set.end(); ++it, ++i) + data_to[old_size + i] = it->get_value(); + } +}; + +/// Generic implementation, it uses serialized representation as object descriptor. +struct AggregateFunctionGroupArrayIntersectGenericData { + using Set = HashSet; + + Set value; + UInt64 version = 0; +}; + +/** Template parameter with true value should be used for columns that store their elements in memory continuously. + * For such columns GroupArrayIntersect() can be implemented more efficiently (especially for small numeric arrays). + */ +template +class AggregateFunctionGroupArrayIntersectGeneric + : public IAggregateFunctionDataHelper< + AggregateFunctionGroupArrayIntersectGenericData, + AggregateFunctionGroupArrayIntersectGeneric> { + const DataTypePtr& input_data_type; + + using State = AggregateFunctionGroupArrayIntersectGenericData; + +public: + AggregateFunctionGroupArrayIntersectGeneric(const DataTypes& input_data_type_) + : IAggregateFunctionDataHelper< + AggregateFunctionGroupArrayIntersectGenericData, + AggregateFunctionGroupArrayIntersectGeneric>( + input_data_type_), + input_data_type(this->argument_types[0]) {} + + String get_name() const override { return "group_array_intersect"; } + + DataTypePtr get_return_type() const override { + return std::make_shared(input_data_type); + } + + bool allocates_memory_in_arena() const override { return true; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena* arena) const override { + auto& set = this->data(place).value; + auto& version = this->data(place).version; + bool inserted; + State::Set::LookupResult it; + + const auto data_column = assert_cast(*columns[0]).get_data_ptr(); + const auto& offsets = assert_cast(*columns[0]).get_offsets(); + const size_t offset = offsets[row_num - 1]; + const auto arr_size = offsets[row_num] - offset; + + ++version; + if (version == 1) { + for (size_t i = 0; i < arr_size; ++i) { + if constexpr (is_plain_column) { + StringRef key = data_column->get_data_at(offset + i); + key.data = arena->insert(key.data, key.size); + set.emplace(key, it, inserted); + } else { + const char* begin = nullptr; + StringRef serialized = + data_column->serialize_value_into_arena(offset + i, *arena, begin); + assert(serialized.data != nullptr); + serialized.data = arena->insert(serialized.data, serialized.size); + set.emplace(serialized, it, inserted); + } + } + } else if (!set.empty()) { + typename State::Set new_set; + for (size_t i = 0; i < arr_size; ++i) { + if constexpr (is_plain_column) { + it = set.find(data_column->get_data_at(offset + i)); + if (it != nullptr) { + StringRef key = data_column->get_data_at(offset + i); + key.data = arena->insert(key.data, key.size); + new_set.emplace(key, it, inserted); + } + } else { + const char* begin = nullptr; + StringRef serialized = + data_column->serialize_value_into_arena(offset + i, *arena, begin); + assert(serialized.data != nullptr); + it = set.find(serialized); + + if (it != nullptr) { + serialized.data = arena->insert(serialized.data, serialized.size); + new_set.emplace(serialized, it, inserted); + } + } + } + set = std::move(new_set); + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena* arena) const override { + auto& set = this->data(place).value; + const auto& rhs_value = this->data(rhs).value; + + if (this->data(rhs).version == 0) return; + + UInt64 version = this->data(place).version++; + if (version == 0) { + bool inserted; + State::Set::LookupResult it; + for (auto& rhs_elem : rhs_value) { + StringRef key = rhs_elem.get_value(); + key.data = arena->insert(key.data, key.size); + set.emplace(key, it, inserted); + } + } else if (!set.empty()) { + auto create_new_map = [](auto& lhs_val, auto& rhs_val) { + typename State::Set new_map; + for (auto& lhs_elem : lhs_val) { + auto val = rhs_val.find(lhs_elem.get_value()); + if (val != nullptr) new_map.insert(lhs_elem.get_value()); + } + return new_map; + }; + auto new_map = rhs_value.size() < set.size() ? create_new_map(rhs_value, set) + : create_new_map(set, rhs_value); + set = std::move(new_map); + } + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + auto& set = this->data(place).value; + auto& version = this->data(place).version; + write_var_uint(version, buf); + write_var_uint(set.size(), buf); + + for (const auto& elem : set) write_string_binary(elem.get_value(), buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena* arena) const override { + auto& set = this->data(place).value; + auto& version = this->data(place).version; + size_t size; + read_var_uint(version, buf); + read_var_uint(size, buf); + set.reserve(size); + UInt64 elem_version; + for (size_t i = 0; i < size; ++i) { + auto key = read_string_binary_into(*arena, buf); + read_var_uint(elem_version, buf); + set.insert(key); + } + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + ColumnArray& arr_to = assert_cast(to); + auto& offsets_to = arr_to.get_offsets(); + IColumn& data_to = arr_to.get_data(); + + auto& set = this->data(place).value; + + offsets_to.push_back(offsets_to.back() + set.size()); + + for (auto& elem : set) { + if constexpr (is_plain_column) + data_to.insert_data(elem.get_value().data, elem.get_value().size); + else + std::ignore = data_to.deserialize_and_insert_from_arena(elem.get_value().data); + } + } +}; + +namespace { + +/// Substitute return type for DateV2 and DateTimeV2 +class AggregateFunctionGroupArrayIntersectDateV2 + : public AggregateFunctionGroupArrayIntersect { +public: + explicit AggregateFunctionGroupArrayIntersectDateV2(const DataTypes& argument_types_) + : AggregateFunctionGroupArrayIntersect( + DataTypes(argument_types_.begin(), argument_types_.end())) {} +}; + +class AggregateFunctionGroupArrayIntersectDateTimeV2 + : public AggregateFunctionGroupArrayIntersect { +public: + explicit AggregateFunctionGroupArrayIntersectDateTimeV2(const DataTypes& argument_types_) + : AggregateFunctionGroupArrayIntersect( + DataTypes(argument_types_.begin(), argument_types_.end())) {} +}; + +IAggregateFunction* create_with_extra_types(const DataTypes& argument_types) { + WhichDataType which(argument_types[0]); + if (which.idx == TypeIndex::DateV2) + return new AggregateFunctionGroupArrayIntersectDateV2(argument_types); + else if (which.idx == TypeIndex::DateTimeV2) + return new AggregateFunctionGroupArrayIntersectDateTimeV2(argument_types); + else { + /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric + if (argument_types[0]->is_value_unambiguously_represented_in_contiguous_memory_region()) + return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); + else + return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); + } +} + +inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const auto& nested_type = + dynamic_cast(*(argument_types[0])).get_nested_type(); + DataTypes new_argument_types = {nested_type}; + AggregateFunctionPtr res( + creator_with_numeric_type::creator( + "", new_argument_types, result_is_nullable)); + if (!res) { + res = AggregateFunctionPtr(create_with_extra_types(argument_types)); + } + + if (!res) + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Illegal type {} of argument for aggregate function {}", + argument_types[0]->get_name(), name); + + return res; +} +} // namespace + +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index c33b8b50609635..056dddf175175e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -49,6 +49,7 @@ void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFact void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_group_array_intersect(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory); @@ -79,6 +80,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_uniq(instance); register_aggregate_function_bit(instance); register_aggregate_function_bitmap(instance); + register_aggregate_function_group_array_intersect(instance); register_aggregate_function_group_concat(instance); register_aggregate_function_quantile_state(instance); register_aggregate_function_combinator_distinct(instance); diff --git a/be/src/vec/common/hash_table/hash_table.h b/be/src/vec/common/hash_table/hash_table.h index 04a5ff8f0e4c7b..ede7897ecdee74 100644 --- a/be/src/vec/common/hash_table/hash_table.h +++ b/be/src/vec/common/hash_table/hash_table.h @@ -806,6 +806,8 @@ class HashTable : private boost::noncopyable, } } + void reserve(size_t num_elements) { resize(num_elements); } + /// Insert a value. In the case of any more complex values, it is better to use the `emplace` function. std::pair ALWAYS_INLINE insert(const value_type& x) { std::pair res; diff --git a/docs/sidebars.json b/docs/sidebars.json index 0c130b20f0fbc4..a30e808a141d79 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -609,6 +609,7 @@ "sql-manual/sql-functions/aggregate-functions/percentile-array", "sql-manual/sql-functions/aggregate-functions/percentile-approx", "sql-manual/sql-functions/aggregate-functions/histogram", + "sql-manual/sql-functions/aggregate-functions/group-array-intersect", "sql-manual/sql-functions/aggregate-functions/group-bitmap-xor", "sql-manual/sql-functions/aggregate-functions/group-bit-and", "sql-manual/sql-functions/aggregate-functions/group-bit-or", diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index b5184c33fcd546..0c214c826f7fea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -1713,7 +1713,8 @@ && collectChildReturnTypes()[0].isDecimalV3()) { } if (fnName.getFunction().equalsIgnoreCase("group_uniq_array") - || fnName.getFunction().equalsIgnoreCase("group_array")) { + || fnName.getFunction().equalsIgnoreCase("group_array") + || fnName.getFunction().equalsIgnoreCase("group_array_intersect")) { fn.setReturnType(new ArrayType(getChild(0).type)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java index e6134cfc31f9e5..6c7dd1abae0d15 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum; import org.apache.doris.nereids.trees.expressions.functions.agg.Covar; import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect; import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd; import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitOr; import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitXor; @@ -101,6 +102,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper { agg(CountByEnum.class, "count_by_enum"), agg(Covar.class, "covar", "covar_pop"), agg(CovarSamp.class, "covar_samp"), + agg(GroupArrayIntersect.class, "group_array_intersect"), agg(GroupBitAnd.class, "group_bit_and"), agg(GroupBitOr.class, "group_bit_or"), agg(GroupBitXor.class, "group_bit_xor"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index b69d52b2c11c1b..73f622c9a82dfa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -611,6 +611,8 @@ public void addBuiltinBothScalaAndVectorized(Function fn) { public static final String GROUP_ARRAY = "group_array"; + public static final String GROUP_ARRAY_INTERSECT = "group_array_intersect"; + public static final String ARRAY_AGG = "array_agg"; // Populate all the aggregate builtins in the catalog. @@ -1442,7 +1444,9 @@ private void initAggregateBuiltins() { addBuiltin( AggregateFunction.createBuiltin(GROUP_ARRAY, Lists.newArrayList(t, Type.INT), new ArrayType(t), t, "", "", "", "", "", true, false, true, true)); - + addBuiltin( + AggregateFunction.createBuiltin(GROUP_ARRAY_INTERSECT, Lists.newArrayList(new ArrayType(t)), + new ArrayType(t), Type.VARCHAR, "", "", "", "", "", true, false, true, true)); addBuiltin(AggregateFunction.createBuiltin(ARRAY_AGG, Lists.newArrayList(t), new ArrayType(t), t, "", "", "", "", "", true, false, true, true)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java new file mode 100644 index 00000000000000..3ff557140aad05 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.coercion.AnyDataType; +import org.apache.doris.nereids.types.coercion.FollowToAnyDataType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * AggregateFunction 'group_array_intersect'. + */ +public class GroupArrayIntersect extends NullableAggregateFunction + implements ExplicitlyCastableSignature { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))) + .args(ArrayType.of(new AnyDataType(0))) + ); + + /** + * constructor with 1 argument. + */ + public GroupArrayIntersect(Expression arg) { + super("group_array_intersect", arg); + } + + /** + * constructor with 1 argument. + */ + public GroupArrayIntersect(boolean distinct, Expression arg) { + super("group_array_intersect", distinct, arg); + } + + public FunctionSignature computeSignature(FunctionSignature signature) { + signature = signature.withReturnType(ArrayType.of(getArgumentType(0))); + return super.computeSignature(signature); + } + + /** + * withChildren. + */ + @Override + public GroupArrayIntersect withChildren(List children) { + Preconditions.checkArgument(children.size() == 1); + return new GroupArrayIntersect(children.get(0)); + } + + @Override + public AggregateFunction withDistinctAndChildren(boolean distinct, List children) { + Preconditions.checkArgument(children.size() == 1); + return new CollectSet(distinct, children.get(0)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitGroupArrayIntersect(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } + + @Override + public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) { + return null; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index 594f9c754335aa..3431cf1203eb18 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum; import org.apache.doris.nereids.trees.expressions.functions.agg.Covar; import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect; import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd; import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitOr; import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitXor; @@ -162,6 +163,10 @@ default R visitMultiDistinctSum(MultiDistinctSum multiDistinctSum, C context) { return visitAggregateFunction(multiDistinctSum, context); } + default R visitGroupArrayIntersect(GroupArrayIntersect groupArrayIntersect, C context) { + return visitAggregateFunction(groupArrayIntersect, context); + } + default R visitGroupBitAnd(GroupBitAnd groupBitAnd, C context) { return visitNullableAggregateFunction(groupBitAnd, context); } From e43c88dde8d08c5d460d58456063c675f990f09c Mon Sep 17 00:00:00 2001 From: chesterxu Date: Wed, 20 Mar 2024 11:40:35 +0800 Subject: [PATCH 02/21] fix and add logs --- ...gregate_function_group_array_intersect.cpp | 16 ++- ...aggregate_function_group_array_intersect.h | 130 ++++++++++++++++-- .../doris/analysis/FunctionCallExpr.java | 3 +- .../doris/catalog/AggregateFunction.java | 3 +- .../org/apache/doris/catalog/FunctionSet.java | 2 +- .../functions/agg/GroupArrayIntersect.java | 24 +--- 6 files changed, 136 insertions(+), 42 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp index b051e2d3950bc4..a812d8d9b71508 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp @@ -24,12 +24,20 @@ namespace doris::vectorized { AggregateFunctionPtr create_aggregate_function_group_array_intersect( const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { assert_unary(name, argument_types); + std::string demangled_name_before = boost::core::demangle(typeid((argument_types[0])).name()); + LOG(INFO) << "in the cpp, before remove, name of argument_types[0]: " << demangled_name_before; + const DataTypePtr& argument_type = remove_nullable(argument_types[0]); - if (!WhichDataType(argument_types.at(0)).is_array()) - throw Exception(ErrorCode::INVALID_ARGUMENT, - "Aggregate function groupArrayIntersect accepts only array type argument."); + std::string demangled_name_argument_type = boost::core::demangle(typeid(argument_type).name()); + LOG(INFO) << "in the cpp, after remove, name of argument_type: " + << demangled_name_argument_type; - return create_aggregate_function_group_array_intersect_impl(name, argument_types, + if (!WhichDataType(argument_type).is_array()) + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Aggregate function groupArrayIntersect accepts only array type argument. " + "Provided argument type: " + + argument_type->get_name()); + return create_aggregate_function_group_array_intersect_impl(name, {argument_type}, result_is_nullable); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index faf58943d588f4..662116755d01f2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -68,21 +69,30 @@ class AggregateFunctionGroupArrayIntersect AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_) : IAggregateFunctionDataHelper, AggregateFunctionGroupArrayIntersect>( - {argument_types_}), - argument_type(this->argument_types[0]) {} + argument_types_), + argument_type(argument_types_[0]) {} AggregateFunctionGroupArrayIntersect(const DataTypes& argument_types_, const bool result_is_nullable) : IAggregateFunctionDataHelper, AggregateFunctionGroupArrayIntersect>( - {argument_types_}), - argument_type(this->argument_types[0]) {} + argument_types_), + argument_type(argument_types_[0]) {} String get_name() const override { return "group_array_intersect"; } // DataTypePtr get_return_type() const override { return argument_type; } DataTypePtr get_return_type() const override { - return std::make_shared(argument_type); + std::string demangled_name = boost::core::demangle(typeid(argument_type).name()); + LOG(INFO) << "in the get_return_type, name of argument_type: " << demangled_name; + std::string demangled_name_T = boost::core::demangle(typeid(T).name()); + LOG(INFO) << "in the get_return_type, name of T: " << demangled_name_T; + // return std::make_shared(make_nullable(argument_type)); + // return std::make_shared(argument_type); + using ReturnDataType = DataTypeNumber; + std::string demangled_name_re = boost::core::demangle(typeid(ReturnDataType).name()); + LOG(INFO) << "in the get_return_type, name of ReturnDataType: " << demangled_name_re; + return std::make_shared(std::make_shared()); } bool allocates_memory_in_arena() const override { return false; } @@ -92,10 +102,22 @@ class AggregateFunctionGroupArrayIntersect auto& version = this->data(place).version; auto& set = this->data(place).value; + LOG(INFO) << "Input row_num: " << row_num; // 输出输入的 row_num + const auto data_column = assert_cast(*columns[0]).get_data_ptr(); const auto& offsets = assert_cast(*columns[0]).get_offsets(); const size_t offset = offsets[row_num - 1]; const auto arr_size = offsets[row_num] - offset; + LOG(INFO) << "In add func, the name of *columns[0] is: " << (*columns[0]).get_name(); + LOG(INFO) << "In add func, offsets size: " << offsets.size(); + LOG(INFO) << "In add func, offsets capacity: " << offsets.capacity(); + + LOG(INFO) << "Before update: version = " << version + << ", set = {"; // 输出更新前的 version 和 set + for (const auto& elem : set) { + LOG(INFO) << elem.get_value() << " "; // 逐个输出 set 中的元素 + } + LOG(INFO) << "}"; ++version; if (version == 1) { @@ -111,6 +133,12 @@ class AggregateFunctionGroupArrayIntersect } set = std::move(new_set); } + + LOG(INFO) << "After update: set = {"; // 输出更新后的 set + for (const auto& elem : set) { + LOG(INFO) << elem.get_value() << " "; + } + LOG(INFO) << "}"; } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, @@ -118,15 +146,25 @@ class AggregateFunctionGroupArrayIntersect auto& set = this->data(place).value; const auto& rhs_set = this->data(rhs).value; - if (this->data(rhs).version == 0) return; + LOG(INFO) << "merge: place set size = " << set.size() + << ", rhs set size = " << rhs_set.size(); + + if (this->data(rhs).version == 0) { + LOG(INFO) << "rhs version is 0, skipping merge"; + return; + } UInt64 version = this->data(place).version++; + LOG(INFO) << "merge: version = " << version; + if (version == 0) { + LOG(INFO) << "Copying rhs set to place set"; for (auto& rhs_elem : rhs_set) set.insert(rhs_elem.get_value()); return; } if (!set.empty()) { + LOG(INFO) << "Merging place set and rhs set"; auto create_new_set = [](auto& lhs_val, auto& rhs_val) { typename State::Set new_set; for (auto& lhs_elem : lhs_val) { @@ -139,39 +177,80 @@ class AggregateFunctionGroupArrayIntersect : create_new_set(set, rhs_set); set = std::move(new_set); } + + LOG(INFO) << "After merge: set = {"; + for (const auto& elem : set) { + LOG(INFO) << elem.get_value() << " "; + } + LOG(INFO) << "}"; } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { auto& set = this->data(place).value; auto version = this->data(place).version; + LOG(INFO) << "serialize: version = " << version << ", set size = " << set.size(); + write_var_uint(version, buf); write_var_uint(set.size(), buf); - for (const auto& elem : set) write_int_binary(elem.get_value(), buf); + for (const auto& elem : set) { + LOG(INFO) << "Serializing element: " << elem.get_value(); + write_int_binary(elem.get_value(), buf); + } } void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena*) const override { + LOG(INFO) << "deserialize"; read_var_uint(this->data(place).version, buf); + LOG(INFO) << "Deserialized version: " << this->data(place).version; this->data(place).value.read(buf); + LOG(INFO) << "Deserialized set: {"; + for (const auto& elem : this->data(place).value) { + LOG(INFO) << elem.get_value() << " "; + } + LOG(INFO) << "}"; } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + LOG(INFO) << "in the insert_result_into, name of T: " + << boost::core::demangle(typeid(T).name()); + LOG(INFO) << "in the start of insert, Type name: " << typeid(T).name(); + LOG(INFO) << "the name of to is: " << to.get_name(); + ColumnArray& arr_to = assert_cast(to); - auto& offsets_to = arr_to.get_offsets(); + LOG(INFO) << "the name of arr_to is: " << arr_to.get_name(); + + ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); const auto& set = this->data(place).value; - offsets_to.push_back(offsets_to.back() + set.size()); + LOG(INFO) << "insert_result_into: set size = " << set.size(); + + if (offsets_to.size() == 0) { + offsets_to.push_back(set.size()); + } else { + offsets_to.push_back(offsets_to.back() + set.size()); + } typename ColumnVector::Container& data_to = assert_cast&>(arr_to.get_data()).get_data(); + std::string demangled_name = boost::core::demangle(typeid(data_to).name()); + LOG(INFO) << "name of data_to: " << demangled_name << std::endl; size_t old_size = data_to.size(); + LOG(INFO) << "old_size of data_to: " << old_size; + data_to.resize(old_size + set.size()); + LOG(INFO) << "after resize, size of data_to: " << data_to.size(); size_t i = 0; - for (auto it = set.begin(); it != set.end(); ++it, ++i) - data_to[old_size + i] = it->get_value(); + for (auto it = set.begin(); it != set.end(); ++it, ++i) { + T value = it->get_value(); + LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); + data_to[old_size + i] = value; + } + LOG(INFO) << "after for loop, size of data_to: " << data_to.size(); + LOG(INFO) << "After making value to data_to, leaving..."; } }; @@ -201,11 +280,13 @@ class AggregateFunctionGroupArrayIntersectGeneric AggregateFunctionGroupArrayIntersectGenericData, AggregateFunctionGroupArrayIntersectGeneric>( input_data_type_), - input_data_type(this->argument_types[0]) {} + input_data_type(input_data_type_[0]) {} String get_name() const override { return "group_array_intersect"; } + // DataTypePtr get_return_type() const override { return input_data_type; } DataTypePtr get_return_type() const override { + // return std::make_shared(make_nullable(input_data_type)); return std::make_shared(input_data_type); } @@ -377,13 +458,32 @@ IAggregateFunction* create_with_extra_types(const DataTypes& argument_types) { inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl( const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { const auto& nested_type = - dynamic_cast(*(argument_types[0])).get_nested_type(); - DataTypes new_argument_types = {nested_type}; + dynamic_cast(*(remove_nullable(argument_types[0]))) + .get_nested_type(); + WhichDataType which_type(remove_nullable(nested_type)); + if (which_type.is_int()) { + LOG(INFO) << "nested_type is int"; + } else { + LOG(INFO) << "nested_type is not int"; + } + DataTypes new_argument_types = {remove_nullable(nested_type)}; + std::string demangled_name_nested_type = + boost::core::demangle(typeid(remove_nullable(nested_type)).name()); + LOG(INFO) << "in the create_aggregate_function_group_array_intersect_impl, name of T: " + << demangled_name_nested_type; AggregateFunctionPtr res( creator_with_numeric_type::creator( "", new_argument_types, result_is_nullable)); +// AggregateFunctionPtr res = nullptr; +// WhichDataType which(*remove_nullable(nested_type)); +// #define DISPATCH(TYPE) \ +// if (which.idx == TypeIndex::TYPE) \ +// res = creator_without_type::create>( \ +// argument_types, result_is_nullable); +// FOR_NUMERIC_TYPES(DISPATCH) +// #undef DISPATCH if (!res) { - res = AggregateFunctionPtr(create_with_extra_types(argument_types)); + res = AggregateFunctionPtr(create_with_extra_types(remove_nullable(argument_types))); } if (!res) diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 0c214c826f7fea..b5184c33fcd546 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -1713,8 +1713,7 @@ && collectChildReturnTypes()[0].isDecimalV3()) { } if (fnName.getFunction().equalsIgnoreCase("group_uniq_array") - || fnName.getFunction().equalsIgnoreCase("group_array") - || fnName.getFunction().equalsIgnoreCase("group_array_intersect")) { + || fnName.getFunction().equalsIgnoreCase("group_array")) { fn.setReturnType(new ArrayType(getChild(0).type)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index d954ab41a95f94..55df2233197241 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -55,7 +55,8 @@ public class AggregateFunction extends Function { FunctionSet.COUNT, "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL, FunctionSet.RETENTION, FunctionSet.SEQUENCE_MATCH, FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG, - FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET); + FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET, + FunctionSet.GROUP_ARRAY_INTERSECT); public static ImmutableSet ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx", "first_value", diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 73f622c9a82dfa..61088bdf974e7e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1446,7 +1446,7 @@ private void initAggregateBuiltins() { t, "", "", "", "", "", true, false, true, true)); addBuiltin( AggregateFunction.createBuiltin(GROUP_ARRAY_INTERSECT, Lists.newArrayList(new ArrayType(t)), - new ArrayType(t), Type.VARCHAR, "", "", "", "", "", true, false, true, true)); + new ArrayType(t), t, "", "", "", "", "", true, false, true, true)); addBuiltin(AggregateFunction.createBuiltin(ARRAY_AGG, Lists.newArrayList(t), new ArrayType(t), t, "", "", "", "", "", true, false, true, true)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java index 3ff557140aad05..c6b1bc96a6d0c7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java @@ -19,7 +19,9 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.coercion.AnyDataType; @@ -33,8 +35,8 @@ /** * AggregateFunction 'group_array_intersect'. */ -public class GroupArrayIntersect extends NullableAggregateFunction - implements ExplicitlyCastableSignature { +public class GroupArrayIntersect extends AggregateFunction + implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))) @@ -55,24 +57,13 @@ public GroupArrayIntersect(boolean distinct, Expression arg) { super("group_array_intersect", distinct, arg); } - public FunctionSignature computeSignature(FunctionSignature signature) { - signature = signature.withReturnType(ArrayType.of(getArgumentType(0))); - return super.computeSignature(signature); - } - /** * withChildren. */ - @Override - public GroupArrayIntersect withChildren(List children) { - Preconditions.checkArgument(children.size() == 1); - return new GroupArrayIntersect(children.get(0)); - } - @Override public AggregateFunction withDistinctAndChildren(boolean distinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new CollectSet(distinct, children.get(0)); + return new GroupArrayIntersect(distinct, children.get(0)); } @Override @@ -84,9 +75,4 @@ public R accept(ExpressionVisitor visitor, C context) { public List getSignatures() { return SIGNATURES; } - - @Override - public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) { - return null; - } } From e831cd7317cc22859e8746dcda967536360a6418 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Fri, 22 Mar 2024 11:35:41 +0800 Subject: [PATCH 03/21] fix column in add, fix impl --- ...aggregate_function_group_array_intersect.h | 66 +++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 662116755d01f2..2b10e6bc219b7f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -88,11 +88,12 @@ class AggregateFunctionGroupArrayIntersect std::string demangled_name_T = boost::core::demangle(typeid(T).name()); LOG(INFO) << "in the get_return_type, name of T: " << demangled_name_T; // return std::make_shared(make_nullable(argument_type)); - // return std::make_shared(argument_type); - using ReturnDataType = DataTypeNumber; - std::string demangled_name_re = boost::core::demangle(typeid(ReturnDataType).name()); - LOG(INFO) << "in the get_return_type, name of ReturnDataType: " << demangled_name_re; - return std::make_shared(std::make_shared()); + return std::make_shared(argument_type); + + // using ReturnDataType = DataTypeNumber; + // std::string demangled_name_re = boost::core::demangle(typeid(ReturnDataType).name()); + // LOG(INFO) << "in the get_return_type, name of ReturnDataType: " << demangled_name_re; + // return std::make_shared(std::make_shared()); } bool allocates_memory_in_arena() const override { return false; } @@ -103,9 +104,18 @@ class AggregateFunctionGroupArrayIntersect auto& set = this->data(place).value; LOG(INFO) << "Input row_num: " << row_num; // 输出输入的 row_num - - const auto data_column = assert_cast(*columns[0]).get_data_ptr(); - const auto& offsets = assert_cast(*columns[0]).get_offsets(); + const bool col_is_nullable = (*columns[0]).is_nullable(); + const ColumnArray& column = + col_is_nullable ? assert_cast( + assert_cast(*columns[0]) + .get_nested_column()) + : assert_cast(*columns[0]); + + std::string demangled_name_column = boost::core::demangle(typeid(column).name()); + LOG(INFO) << "In add func, the name of column: " << demangled_name_column; + + const auto data_column = column.get_data_ptr(); + const auto& offsets = column.get_offsets(); const size_t offset = offsets[row_num - 1]; const auto arr_size = offsets[row_num] - offset; LOG(INFO) << "In add func, the name of *columns[0] is: " << (*columns[0]).get_name(); @@ -458,30 +468,34 @@ IAggregateFunction* create_with_extra_types(const DataTypes& argument_types) { inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl( const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { const auto& nested_type = - dynamic_cast(*(remove_nullable(argument_types[0]))) - .get_nested_type(); + dynamic_cast(*(argument_types[0])).get_nested_type(); WhichDataType which_type(remove_nullable(nested_type)); if (which_type.is_int()) { LOG(INFO) << "nested_type is int"; } else { LOG(INFO) << "nested_type is not int"; } - DataTypes new_argument_types = {remove_nullable(nested_type)}; - std::string demangled_name_nested_type = - boost::core::demangle(typeid(remove_nullable(nested_type)).name()); - LOG(INFO) << "in the create_aggregate_function_group_array_intersect_impl, name of T: " - << demangled_name_nested_type; - AggregateFunctionPtr res( - creator_with_numeric_type::creator( - "", new_argument_types, result_is_nullable)); -// AggregateFunctionPtr res = nullptr; -// WhichDataType which(*remove_nullable(nested_type)); -// #define DISPATCH(TYPE) \ -// if (which.idx == TypeIndex::TYPE) \ -// res = creator_without_type::create>( \ -// argument_types, result_is_nullable); -// FOR_NUMERIC_TYPES(DISPATCH) -// #undef DISPATCH + DataTypes new_argument_types = {nested_type}; + LOG(INFO) << "in the create_aggregate_function_group_array_intersect_impl, name of " + "remove_nullable(nested_type): " + << boost::core::demangle(typeid(remove_nullable(nested_type)).name()); + + const auto& argument_type = dynamic_cast(*(argument_types[0])); + LOG(INFO) << "in the create_aggregate_function_group_array_intersect_impl, name of " + "argument_type: " + << boost::core::demangle(typeid(argument_type).name()); + + // AggregateFunctionPtr res( + // creator_with_numeric_type::creator( + // "", new_argument_types, result_is_nullable)); + AggregateFunctionPtr res = nullptr; + WhichDataType which(remove_nullable(nested_type)); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + res = creator_without_type::create>( \ + new_argument_types, result_is_nullable); + FOR_NUMERIC_TYPES(DISPATCH) +#undef DISPATCH if (!res) { res = AggregateFunctionPtr(create_with_extra_types(remove_nullable(argument_types))); } From 56860f3e92c202e68f518bfbb93b39a2c0bfee27 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Fri, 22 Mar 2024 19:08:38 +0800 Subject: [PATCH 04/21] fix cast to data_to --- .../aggregate_function_group_array_intersect.h | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 2b10e6bc219b7f..4bce94165f7534 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -88,7 +88,7 @@ class AggregateFunctionGroupArrayIntersect std::string demangled_name_T = boost::core::demangle(typeid(T).name()); LOG(INFO) << "in the get_return_type, name of T: " << demangled_name_T; // return std::make_shared(make_nullable(argument_type)); - return std::make_shared(argument_type); + return argument_type; // using ReturnDataType = DataTypeNumber; // std::string demangled_name_re = boost::core::demangle(typeid(ReturnDataType).name()); @@ -243,8 +243,13 @@ class AggregateFunctionGroupArrayIntersect offsets_to.push_back(offsets_to.back() + set.size()); } - typename ColumnVector::Container& data_to = - assert_cast&>(arr_to.get_data()).get_data(); + auto& to_nested_col = arr_to.get_data(); + using ElementType = T; + using ColVecType = ColumnVector; + auto col_null = reinterpret_cast(&to_nested_col); + auto& data_to = assert_cast(col_null->get_nested_column()).get_data(); + // typename ColumnVector::Container& data_to = + // assert_cast&>(arr_to.get_data()).get_data(); std::string demangled_name = boost::core::demangle(typeid(data_to).name()); LOG(INFO) << "name of data_to: " << demangled_name << std::endl; size_t old_size = data_to.size(); @@ -493,7 +498,7 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl #define DISPATCH(TYPE) \ if (which.idx == TypeIndex::TYPE) \ res = creator_without_type::create>( \ - new_argument_types, result_is_nullable); + argument_types, result_is_nullable); FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH if (!res) { From 95f7026a83cee2a1b0042c10e8d38f4ee2e62fe6 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sat, 23 Mar 2024 17:37:22 +0800 Subject: [PATCH 05/21] add to_nested_col.is_nullable() to judge --- ...aggregate_function_group_array_intersect.h | 73 ++++++++++++++----- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 4bce94165f7534..c0602b3d6f066f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -246,26 +246,63 @@ class AggregateFunctionGroupArrayIntersect auto& to_nested_col = arr_to.get_data(); using ElementType = T; using ColVecType = ColumnVector; - auto col_null = reinterpret_cast(&to_nested_col); - auto& data_to = assert_cast(col_null->get_nested_column()).get_data(); - // typename ColumnVector::Container& data_to = - // assert_cast&>(arr_to.get_data()).get_data(); - std::string demangled_name = boost::core::demangle(typeid(data_to).name()); - LOG(INFO) << "name of data_to: " << demangled_name << std::endl; - size_t old_size = data_to.size(); - LOG(INFO) << "old_size of data_to: " << old_size; - - data_to.resize(old_size + set.size()); - LOG(INFO) << "after resize, size of data_to: " << data_to.size(); - - size_t i = 0; - for (auto it = set.begin(); it != set.end(); ++it, ++i) { - T value = it->get_value(); - LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); - data_to[old_size + i] = value; + + bool is_nullable = to_nested_col.is_nullable(); + + if (is_nullable) { + LOG(INFO) << "nested_col is nullable. "; + auto col_null = reinterpret_cast(&to_nested_col); + auto& nested_col = assert_cast(col_null->get_nested_column()); + + size_t old_size = nested_col.get_data().size(); + LOG(INFO) << "old_size of data_to: " << old_size; + nested_col.get_data().resize(old_size + set.size()); + + size_t i = 0; + for (auto it = set.begin(); it != set.end(); ++it, ++i) { + T value = it->get_value(); + LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); + nested_col.get_data()[old_size + i] = value; + } + + auto& null_map_data = col_null->get_null_map_data(); + null_map_data.resize_fill(nested_col.size(), 0); + } else { + LOG(INFO) << "nested_col is not nullable. "; + auto& nested_col = static_cast(to_nested_col); + size_t old_size = nested_col.get_data().size(); + LOG(INFO) << "old_size of data_to: " << old_size; + nested_col.get_data().resize(old_size + set.size()); + + size_t i = 0; + for (auto it = set.begin(); it != set.end(); ++it, ++i) { + T value = it->get_value(); + LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); + nested_col.get_data()[old_size + i] = value; + } } - LOG(INFO) << "after for loop, size of data_to: " << data_to.size(); LOG(INFO) << "After making value to data_to, leaving..."; + + // auto col_null = reinterpret_cast(&to_nested_col); + // auto& data_to = assert_cast(col_null->get_nested_column()).get_data(); + // // typename ColumnVector::Container& data_to = + // // assert_cast&>(arr_to.get_data()).get_data(); + // std::string demangled_name = boost::core::demangle(typeid(data_to).name()); + // LOG(INFO) << "name of data_to: " << demangled_name << std::endl; + // size_t old_size = data_to.size(); + // LOG(INFO) << "old_size of data_to: " << old_size; + + // data_to.resize(old_size + set.size()); + // LOG(INFO) << "after resize, size of data_to: " << data_to.size(); + + // size_t i = 0; + // for (auto it = set.begin(); it != set.end(); ++it, ++i) { + // T value = it->get_value(); + // LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); + // data_to[old_size + i] = value; + // } + // LOG(INFO) << "after for loop, size of data_to: " << data_to.size(); + // LOG(INFO) << "After making value to data_to, leaving..."; } }; From cb771dbfa0e3a2b42890a34a36251b5c5fa73010 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sat, 23 Mar 2024 19:29:53 +0800 Subject: [PATCH 06/21] fix float type --- ...aggregate_function_group_array_intersect.h | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index c0602b3d6f066f..0c3e5d1d5f6ae0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -110,11 +110,16 @@ class AggregateFunctionGroupArrayIntersect assert_cast(*columns[0]) .get_nested_column()) : assert_cast(*columns[0]); + const auto& nested_column = + assert_cast(column.get_data()).get_nested_column(); + using ColVecType = ColumnVector; + const auto& nested_column_data = assert_cast(nested_column); - std::string demangled_name_column = boost::core::demangle(typeid(column).name()); - LOG(INFO) << "In add func, the name of column: " << demangled_name_column; + std::string demangled_name_column = + boost::core::demangle(typeid(nested_column_data).name()); + LOG(INFO) << "In add func, the name of nested_column_data: " << demangled_name_column; - const auto data_column = column.get_data_ptr(); + // const auto data_column = nested_column_data.get_data(); const auto& offsets = column.get_offsets(); const size_t offset = offsets[row_num - 1]; const auto arr_size = offsets[row_num] - offset; @@ -131,20 +136,32 @@ class AggregateFunctionGroupArrayIntersect ++version; if (version == 1) { - for (size_t i = 0; i < arr_size; ++i) - set.insert(static_cast((*data_column)[offset + i].get())); + LOG(INFO) << "Inserting elements into an empty set..."; + for (size_t i = 0; i < arr_size; ++i) { + // T value = (*data_column)[offset + i].get(); + T value = nested_column_data.get_element(offset + i); + LOG(INFO) << "Inserting value: " << value; + set.insert(value); + } } else if (!set.empty()) { + LOG(INFO) << "Updating an existing set..."; typename State::Set new_set; for (size_t i = 0; i < arr_size; ++i) { - typename State::Set::LookupResult set_value = - set.find(static_cast((*data_column)[offset + i].get())); - if (set_value != nullptr) - new_set.insert(static_cast((*data_column)[offset + i].get())); + // T value = (*data_column)[offset + i].get(); + T value = nested_column_data.get_element(offset + i); + LOG(INFO) << "Checking value: " << value; + typename State::Set::LookupResult set_value = set.find(value); + if (set_value != nullptr) { + LOG(INFO) << "Value found in the set, inserting into new_set"; + new_set.insert(value); + } else { + LOG(INFO) << "Value not found in the set, skipping"; + } } set = std::move(new_set); } - LOG(INFO) << "After update: set = {"; // 输出更新后的 set + LOG(INFO) << "After update: set = {"; for (const auto& elem : set) { LOG(INFO) << elem.get_value() << " "; } From db38507547efebf5067ac6d3b24e2dd8f85bc457 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sun, 24 Mar 2024 19:15:46 +0800 Subject: [PATCH 07/21] fix return 0 by adding has_null in data --- ...aggregate_function_group_array_intersect.h | 205 ++++++++++++++---- 1 file changed, 160 insertions(+), 45 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 0c3e5d1d5f6ae0..76d362197d706b 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -54,6 +54,7 @@ struct AggregateFunctionGroupArrayIntersectData { Set value; UInt64 version = 0; + UInt64 has_null = 0; // 添加一个私有变量has_null }; /// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. @@ -102,6 +103,7 @@ class AggregateFunctionGroupArrayIntersect Arena*) const override { auto& version = this->data(place).version; auto& set = this->data(place).value; + auto& has_null = this->data(place).has_null; LOG(INFO) << "Input row_num: " << row_num; // 输出输入的 row_num const bool col_is_nullable = (*columns[0]).is_nullable(); @@ -110,16 +112,30 @@ class AggregateFunctionGroupArrayIntersect assert_cast(*columns[0]) .get_nested_column()) : assert_cast(*columns[0]); - const auto& nested_column = - assert_cast(column.get_data()).get_nested_column(); + using ColVecType = ColumnVector; - const auto& nested_column_data = assert_cast(nested_column); + LOG(INFO) << "the name of column is: " << column.get_name(); + + const auto& column_data = column.get_data(); + + bool is_column_data_nullable = column_data.is_nullable(); + ColumnNullable* col_null = nullptr; + const ColVecType* nested_column_data = nullptr; + + if (is_column_data_nullable) { + LOG(INFO) << "nested_col is nullable. "; + auto const_col_data = const_cast(&column_data); + col_null = static_cast(const_col_data); + nested_column_data = &assert_cast(col_null->get_nested_column()); + } else { + LOG(INFO) << "nested_col is not nullable. "; + nested_column_data = &static_cast(column_data); + } std::string demangled_name_column = - boost::core::demangle(typeid(nested_column_data).name()); + boost::core::demangle(typeid(*nested_column_data).name()); LOG(INFO) << "In add func, the name of nested_column_data: " << demangled_name_column; - // const auto data_column = nested_column_data.get_data(); const auto& offsets = column.get_offsets(); const size_t offset = offsets[row_num - 1]; const auto arr_size = offsets[row_num] - offset; @@ -138,17 +154,34 @@ class AggregateFunctionGroupArrayIntersect if (version == 1) { LOG(INFO) << "Inserting elements into an empty set..."; for (size_t i = 0; i < arr_size; ++i) { - // T value = (*data_column)[offset + i].get(); - T value = nested_column_data.get_element(offset + i); - LOG(INFO) << "Inserting value: " << value; - set.insert(value); + if (col_null->is_null_at(offset + i)) { + LOG(INFO) << "Encounting null: "; + // this->data(place).null_set.insert(offset + i); + has_null = 1; + break; + } else { + T value = nested_column_data->get_element(offset + i); + LOG(INFO) << "Inserting value: " << value; + set.insert(value); + } } } else if (!set.empty()) { LOG(INFO) << "Updating an existing set..."; typename State::Set new_set; + bool found_null = false; for (size_t i = 0; i < arr_size; ++i) { - // T value = (*data_column)[offset + i].get(); - T value = nested_column_data.get_element(offset + i); + T value; // 将value的声明移动到循环体的开始 + if (col_null && col_null->is_null_at(offset + i)) { + LOG(INFO) << "Encounting null: "; + if (!found_null) { // 如果还没有找到null值 + // this->data(place).null_set.insert(offset + i); + has_null = 1; + found_null = true; // 更新标志 + } + break; // 遇到null值,跳出循环 + } else { + value = nested_column_data->get_element(offset + i); + } LOG(INFO) << "Checking value: " << value; typename State::Set::LookupResult set_value = set.find(value); if (set_value != nullptr) { @@ -166,12 +199,25 @@ class AggregateFunctionGroupArrayIntersect LOG(INFO) << elem.get_value() << " "; } LOG(INFO) << "}"; + + if (is_column_data_nullable) { + auto& null_map_data = col_null->get_null_map_data(); + null_map_data.resize_fill(nested_column_data->size(), 0); + } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena*) const override { + auto place_has_null = this->data(place).has_null; + auto rhs_has_null = this->data(rhs).has_null; + LOG(INFO) << "before, merge: place_has_null = " << place_has_null + << ", rhs_has_null = " << rhs_has_null; + this->data(place).has_null = rhs_has_null; + auto& set = this->data(place).value; const auto& rhs_set = this->data(rhs).value; + // auto& null_set = this->data(place).null_set; + // const auto& rhs_null_set = this->data(rhs).null_set; LOG(INFO) << "merge: place set size = " << set.size() << ", rhs set size = " << rhs_set.size(); @@ -210,37 +256,81 @@ class AggregateFunctionGroupArrayIntersect LOG(INFO) << elem.get_value() << " "; } LOG(INFO) << "}"; + // null_set.insert(rhs_null_set.begin(), rhs_null_set.end()); + LOG(INFO) << "after, merge: has_null = " << this->data(place).has_null; } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { auto& set = this->data(place).value; auto version = this->data(place).version; + auto has_null = this->data(place).has_null; + + LOG(INFO) << "before, serialize: has_null = " << has_null; LOG(INFO) << "serialize: version = " << version << ", set size = " << set.size(); - write_var_uint(version, buf); - write_var_uint(set.size(), buf); + write_var_uint(has_null, buf); - for (const auto& elem : set) { - LOG(INFO) << "Serializing element: " << elem.get_value(); - write_int_binary(elem.get_value(), buf); - } + // if (has_null == 0) { + write_var_uint(version, buf); + write_var_uint(set.size(), buf); + for (const auto& elem : set) { + LOG(INFO) << "Serializing element: " << elem.get_value(); + write_int_binary(elem.get_value(), buf); + } + // } + // } else { + // } + LOG(INFO) << "after, serialize: has_null = " << has_null; + // // 序列化null_set + // write_var_uint(null_set.size(), buf); + // for (const auto& null_index : null_set) { + // write_var_uint(null_index, buf); + // } } void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena*) const override { LOG(INFO) << "deserialize"; - read_var_uint(this->data(place).version, buf); - LOG(INFO) << "Deserialized version: " << this->data(place).version; - this->data(place).value.read(buf); - LOG(INFO) << "Deserialized set: {"; - for (const auto& elem : this->data(place).value) { - LOG(INFO) << elem.get_value() << " "; - } - LOG(INFO) << "}"; + // auto& has_null = this->data(place).has_null; + + // LOG(INFO) << "before, deserialize: has_null = " << has_null; + + read_var_uint(this->data(place).has_null, buf); + auto has_null = this->data(place).has_null; + LOG(INFO) << "another way to read the has_null: " << this->data(place).has_null; + LOG(INFO) << "another has_null value: " << has_null; + + // if (has_null == 0) { + read_var_uint(this->data(place).version, buf); + LOG(INFO) << "Deserialized version: " << this->data(place).version; + this->data(place).value.read(buf); + LOG(INFO) << "Deserialized set: {"; + for (const auto& elem : this->data(place).value) { + LOG(INFO) << elem.get_value() << " "; + } + LOG(INFO) << "}"; + // } + // 读取has_null的值 + // } else { + // read_var_uint(has_null, buf); + // LOG(INFO) << "Deserialized has_null: " << this->data(place).has_null; + // } + + LOG(INFO) << "after, deserialize: has_null = " << this->data(place).has_null; + + // size_t null_set_size; + // read_var_uint(null_set_size, buf); + // for (size_t i = 0; i < null_set_size; ++i) { + // size_t null_index; + // read_var_uint(null_index, buf); + // this->data(place).null_set.insert(null_index); + // } } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + LOG(INFO) << "before, insert: has_null = " << this->data(place).has_null; + LOG(INFO) << "in the insert_result_into, name of T: " << boost::core::demangle(typeid(T).name()); LOG(INFO) << "in the start of insert, Type name: " << typeid(T).name(); @@ -251,15 +341,6 @@ class AggregateFunctionGroupArrayIntersect ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); - const auto& set = this->data(place).value; - LOG(INFO) << "insert_result_into: set size = " << set.size(); - - if (offsets_to.size() == 0) { - offsets_to.push_back(set.size()); - } else { - offsets_to.push_back(offsets_to.back() + set.size()); - } - auto& to_nested_col = arr_to.get_data(); using ElementType = T; using ColVecType = ColumnVector; @@ -270,25 +351,59 @@ class AggregateFunctionGroupArrayIntersect LOG(INFO) << "nested_col is nullable. "; auto col_null = reinterpret_cast(&to_nested_col); auto& nested_col = assert_cast(col_null->get_nested_column()); + auto& null_map_data = col_null->get_null_map_data(); - size_t old_size = nested_col.get_data().size(); - LOG(INFO) << "old_size of data_to: " << old_size; - nested_col.get_data().resize(old_size + set.size()); + if (this->data(place).has_null == 1) { + LOG(INFO) << "We have null, pass!"; + // nested_col.insert_default(); + // if (offsets_to.size() == 0) { + // offsets_to.push_back(0); + // } else { + // offsets_to.push_back(offsets_to.back()); + // } + // nested_col.get_data().resize(1); + // nested_col.insert_default(); + col_null->insert_data(nullptr, 0); + offsets_to.push_back(to_nested_col.size()); + // null_map_data.push_back(1); + } else { + size_t old_size = nested_col.get_data().size(); + LOG(INFO) << "old_size of data_to: " << old_size; + + const auto& set = this->data(place).value; + LOG(INFO) << "insert_result_into: set size = " << set.size(); + + if (offsets_to.size() == 0) { + offsets_to.push_back(set.size()); + } else { + offsets_to.push_back(offsets_to.back() + set.size()); + } - size_t i = 0; - for (auto it = set.begin(); it != set.end(); ++it, ++i) { - T value = it->get_value(); - LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); - nested_col.get_data()[old_size + i] = value; - } + nested_col.get_data().resize(old_size + set.size()); - auto& null_map_data = col_null->get_null_map_data(); - null_map_data.resize_fill(nested_col.size(), 0); + size_t i = 0; + for (auto it = set.begin(); it != set.end(); ++it, ++i) { + T value = it->get_value(); + LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); + nested_col.get_data()[old_size + i] = value; + null_map_data.push_back(0); + } + } } else { LOG(INFO) << "nested_col is not nullable. "; auto& nested_col = static_cast(to_nested_col); size_t old_size = nested_col.get_data().size(); LOG(INFO) << "old_size of data_to: " << old_size; + + const auto& set = this->data(place).value; + LOG(INFO) << "insert_result_into: set size = " << set.size(); + + if (offsets_to.size() == 0) { + offsets_to.push_back(set.size()); + } else { + offsets_to.push_back(offsets_to.back() + set.size()); + } + nested_col.get_data().resize(old_size + set.size()); size_t i = 0; From 67c6ebdcdde87074d4061ea19bd09662412d42ee Mon Sep 17 00:00:00 2001 From: chesterxu Date: Wed, 27 Mar 2024 11:39:05 +0800 Subject: [PATCH 08/21] opt by using HybridSet --- be/src/exprs/hybrid_set.h | 23 +- ...aggregate_function_group_array_intersect.h | 544 +++++++++++++----- 2 files changed, 408 insertions(+), 159 deletions(-) diff --git a/be/src/exprs/hybrid_set.h b/be/src/exprs/hybrid_set.h index ba5fabe509be62..7c6d15d19b6593 100644 --- a/be/src/exprs/hybrid_set.h +++ b/be/src/exprs/hybrid_set.h @@ -170,6 +170,8 @@ class DynamicContainer { size_t size() const { return _set.size(); } + void clear() { _set.clear(); } + private: vectorized::flat_hash_set _set; }; @@ -233,7 +235,18 @@ class HybridSetBase : public RuntimeFilterFuncBase { virtual IteratorBase* begin() = 0; - bool contain_null() const { return _contains_null && _null_aware; } + bool contain_null() const { + LOG(INFO) << "Entering the func contain_null.. "; + if (_contains_null && _null_aware) { + LOG(INFO) + << "The func finds out containing null, _contains_null && _null_aware is true."; + } else if (_null_aware) { + LOG(INFO) << "The func finds out not containing null, _contains_null is not true"; + } else { + LOG(INFO) << "The func finds out not containing null, both are false "; + } + return _contains_null && _null_aware; + } bool _contains_null = false; }; @@ -271,6 +284,12 @@ class HybridSet : public HybridSetBase { ~HybridSet() override = default; void insert(const void* data) override { + LOG(INFO) << "Entering the func insert.. "; + if (data == nullptr) { + LOG(INFO) << "The func finds out data is nullptr."; + } else { + LOG(INFO) << "The func finds out data is not nullptr. "; + } if (data == nullptr) { _contains_null = true; return; @@ -308,6 +327,8 @@ class HybridSet : public HybridSetBase { int size() override { return _set.size(); } + void clear() { _set.clear(); } + bool find(const void* data) const override { if (data == nullptr) { return false; diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 76d362197d706b..27fc2b83d77d07 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -27,6 +27,7 @@ #include #include +#include "exprs/hybrid_set.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/helpers.h" @@ -48,13 +49,48 @@ class BufferWritable; namespace doris::vectorized { +template +constexpr PrimitiveType TypeToPrimitiveType() { + if constexpr (std::is_same_v || std::is_same_v) { + return TYPE_TINYINT; + } else if constexpr (std::is_same_v) { + return TYPE_SMALLINT; + } else if constexpr (std::is_same_v) { + return TYPE_INT; + } else if constexpr (std::is_same_v) { + return TYPE_BIGINT; + } else if constexpr (std::is_same_v) { + return TYPE_LARGEINT; + } else if constexpr (std::is_same_v) { + return TYPE_FLOAT; + // } else if constexpr (std::is_same_v) { + // return TYPE_DOUBLE; + // } else { + // return TYPE_STRING; + } else { + return TYPE_DOUBLE; + } +} + +template +class NullableKeySet + : public HybridSet(), DynamicContainer()>::CppType>> { +public: + NullableKeySet() { this->_null_aware = true; } + + void change_contains_null_value(bool target_value) { this->_contains_null = target_value; } +}; + template struct AggregateFunctionGroupArrayIntersectData { - using Set = HashSet; + using NullableKeySetType = NullableKeySet; + using Set = std::shared_ptr; + + AggregateFunctionGroupArrayIntersectData() : value(std::make_shared()) {} Set value; UInt64 version = 0; - UInt64 has_null = 0; // 添加一个私有变量has_null }; /// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. @@ -101,9 +137,10 @@ class AggregateFunctionGroupArrayIntersect void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - auto& version = this->data(place).version; - auto& set = this->data(place).value; - auto& has_null = this->data(place).has_null; + auto& data = this->data(place); + auto& version = data.version; + auto& set = data.value; + CHECK(set != nullptr); LOG(INFO) << "Input row_num: " << row_num; // 输出输入的 row_num const bool col_is_nullable = (*columns[0]).is_nullable(); @@ -113,6 +150,11 @@ class AggregateFunctionGroupArrayIntersect .get_nested_column()) : assert_cast(*columns[0]); + const auto data_column = column.get_data_ptr(); + const auto& offsets = column.get_offsets(); + const size_t offset = offsets[static_cast(row_num) - 1]; + const auto arr_size = offsets[row_num] - offset; + using ColVecType = ColumnVector; LOG(INFO) << "the name of column is: " << column.get_name(); @@ -132,156 +174,317 @@ class AggregateFunctionGroupArrayIntersect nested_column_data = &static_cast(column_data); } - std::string demangled_name_column = - boost::core::demangle(typeid(*nested_column_data).name()); - LOG(INFO) << "In add func, the name of nested_column_data: " << demangled_name_column; - - const auto& offsets = column.get_offsets(); - const size_t offset = offsets[row_num - 1]; - const auto arr_size = offsets[row_num] - offset; - LOG(INFO) << "In add func, the name of *columns[0] is: " << (*columns[0]).get_name(); - LOG(INFO) << "In add func, offsets size: " << offsets.size(); - LOG(INFO) << "In add func, offsets capacity: " << offsets.capacity(); - - LOG(INFO) << "Before update: version = " << version - << ", set = {"; // 输出更新前的 version 和 set - for (const auto& elem : set) { - LOG(INFO) << elem.get_value() << " "; // 逐个输出 set 中的元素 - } - LOG(INFO) << "}"; - ++version; if (version == 1) { - LOG(INFO) << "Inserting elements into an empty set..."; + LOG(INFO) << "version is 1."; for (size_t i = 0; i < arr_size; ++i) { - if (col_null->is_null_at(offset + i)) { - LOG(INFO) << "Encounting null: "; - // this->data(place).null_set.insert(offset + i); - has_null = 1; - break; + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + const T* src_data = + is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); + // const char* src_data = reinterpret_cast( + // &(nested_column_data->get_element(offset + i))); + // auto src_data = nested_column_data->get_data_at(offset + i); + // set->insert((void*)&value); + if (is_null_element) { + LOG(INFO) << "src_data is null!!!!"; + // src_data.data = nullptr; } else { - T value = nested_column_data->get_element(offset + i); - LOG(INFO) << "Inserting value: " << value; - set.insert(value); + LOG(INFO) << "src_data is not null"; + LOG(INFO) << "Inserting value: " << *(src_data); } + + if (src_data == nullptr) { + LOG(INFO) << "src_data==nullptr is true. "; + } + + // set->insert(src_data.data); + set->insert(src_data); + LOG(INFO) << "After inserting value."; } - } else if (!set.empty()) { - LOG(INFO) << "Updating an existing set..."; - typename State::Set new_set; - bool found_null = false; + + if (set->contain_null()) { + LOG(INFO) << "in the last of version==1, the set contains null."; + } + + // LOG(INFO) << "before insert "; + // set->insert_fixed_len(data_column, offset); + // LOG(INFO) << "after insert "; + } else if (set->size() != 0 || set->contain_null()) { + // typename State::Set new_set; + typename State::Set new_set = std::make_shared(); + + CHECK(new_set != nullptr); for (size_t i = 0; i < arr_size; ++i) { - T value; // 将value的声明移动到循环体的开始 - if (col_null && col_null->is_null_at(offset + i)) { - LOG(INFO) << "Encounting null: "; - if (!found_null) { // 如果还没有找到null值 - // this->data(place).null_set.insert(offset + i); - has_null = 1; - found_null = true; // 更新标志 - } - break; // 遇到null值,跳出循环 + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + const T* src_data = + is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); + // const char* src_data = reinterpret_cast( + // &(nested_column_data->get_element(offset + i))); + // auto src_data = nested_column_data->get_data_at(offset + i); + if (is_null_element) { + LOG(INFO) << "src_data is null again ~"; + // src_data.data = nullptr; } else { - value = nested_column_data->get_element(offset + i); + LOG(INFO) << "Inserting value: " << *src_data; } - LOG(INFO) << "Checking value: " << value; - typename State::Set::LookupResult set_value = set.find(value); - if (set_value != nullptr) { - LOG(INFO) << "Value found in the set, inserting into new_set"; - new_set.insert(value); - } else { - LOG(INFO) << "Value not found in the set, skipping"; + + if (src_data == nullptr) { + LOG(INFO) << "src_data==nullptr is true. "; } + // if (set->find((void*)(src_data.data))) { + // LOG(INFO) << "we can find set's value: " << src_data.to_string(); + // // new_set->insert((void*)&value); + // new_set->insert((src_data.data)); + // LOG(INFO) << "After inserting value."; + // } + if (set->find(src_data) || src_data == nullptr) { + if (set->find(src_data)) { + LOG(INFO) << "we can find set's value: " << *src_data; + } else { + LOG(INFO) << "src_data is null again "; + } + // new_set->insert((void*)&value); + new_set->insert(src_data); + LOG(INFO) << "After inserting value."; + } + } + + if (set->contain_null()) { + LOG(INFO) << "before swap between set and new_set, the set contains null."; } + if (new_set->contain_null()) { + LOG(INFO) << "before swap between set and new_set, the new_set contains null."; + } + + // set->clear(); + // set->change_contains_null_value(new_set->contain_null()); + // HybridSetBase::IteratorBase* it = new_set->begin(); + // while (it->has_next()) { + // const T* value = reinterpret_cast(it->get_value()); + // LOG(INFO) << "In the while loop, we insert the value: " << *value; + + // set->insert(value); + // it->next(); + // } + // new_set->clear(); set = std::move(new_set); - } - LOG(INFO) << "After update: set = {"; - for (const auto& elem : set) { - LOG(INFO) << elem.get_value() << " "; + if (set->contain_null()) { + LOG(INFO) << "after swap between set and new_set, the set contains null."; + } + // if (new_set->contain_null()) { + // LOG(INFO) << "after swap between set and new_set, the new_set contains null."; + // } } - LOG(INFO) << "}"; - if (is_column_data_nullable) { - auto& null_map_data = col_null->get_null_map_data(); - null_map_data.resize_fill(nested_column_data->size(), 0); - } + // using ColVecType = ColumnVector; + // LOG(INFO) << "the name of column is: " << column.get_name(); + + // const auto& column_data = column.get_data(); + + // bool is_column_data_nullable = column_data.is_nullable(); + // ColumnNullable* col_null = nullptr; + // const ColVecType* nested_column_data = nullptr; + + // if (is_column_data_nullable) { + // LOG(INFO) << "nested_col is nullable. "; + // auto const_col_data = const_cast(&column_data); + // col_null = static_cast(const_col_data); + // nested_column_data = &assert_cast(col_null->get_nested_column()); + // } else { + // LOG(INFO) << "nested_col is not nullable. "; + // nested_column_data = &static_cast(column_data); + // } + + // std::string demangled_name_column = + // boost::core::demangle(typeid(*nested_column_data).name()); + // LOG(INFO) << "In add func, the name of nested_column_data: " << demangled_name_column; + + // const auto& offsets = column.get_offsets(); + // const size_t offset = offsets[row_num - 1]; + // const auto arr_size = offsets[row_num] - offset; + // LOG(INFO) << "In add func, the name of *columns[0] is: " << (*columns[0]).get_name(); + // LOG(INFO) << "In add func, offsets size: " << offsets.size(); + // LOG(INFO) << "In add func, offsets capacity: " << offsets.capacity(); + + // LOG(INFO) << "Before update: version = " << version + // << ", set = {"; // 输出更新前的 version 和 set + // for (const auto& elem : set) { + // LOG(INFO) << elem.get_value() << " "; // 逐个输出 set 中的元素 + // } + // LOG(INFO) << "}"; + + // ++version; + // if (version == 1) { + // LOG(INFO) << "Inserting elements into an empty set..."; + // for (size_t i = 0; i < arr_size; ++i) { + // if (col_null->is_null_at(offset + i)) { + // LOG(INFO) << "Encounting null: "; + // // this->data(place).null_set.insert(offset + i); + // has_null = 1; + // break; + // } else { + // T value = nested_column_data->get_element(offset + i); + // LOG(INFO) << "Inserting value: " << value; + // set.insert(value); + // } + // } + // } else if (!set.empty()) { + // LOG(INFO) << "Updating an existing set..."; + // typename State::Set new_set; + // bool found_null = false; + // for (size_t i = 0; i < arr_size; ++i) { + // T value; // 将value的声明移动到循环体的开始 + // if (col_null && col_null->is_null_at(offset + i)) { + // LOG(INFO) << "Encounting null: "; + // if (!found_null) { // 如果还没有找到null值 + // // this->data(place).null_set.insert(offset + i); + // has_null = 1; + // found_null = true; // 更新标志 + // } + // break; // 遇到null值,跳出循环 + // } else { + // value = nested_column_data->get_element(offset + i); + // } + // LOG(INFO) << "Checking value: " << value; + // typename State::Set::LookupResult set_value = set.find(value); + // if (set_value != nullptr) { + // LOG(INFO) << "Value found in the set, inserting into new_set"; + // new_set.insert(value); + // } else { + // LOG(INFO) << "Value not found in the set, skipping"; + // } + // } + // set = std::move(new_set); + // } + + // LOG(INFO) << "After update: set = {"; + // for (const auto& elem : set) { + // LOG(INFO) << elem.get_value() << " "; + // } + // LOG(INFO) << "}"; + + // if (is_column_data_nullable) { + // auto& null_map_data = col_null->get_null_map_data(); + // null_map_data.resize_fill(nested_column_data->size(), 0); + // } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena*) const override { - auto place_has_null = this->data(place).has_null; - auto rhs_has_null = this->data(rhs).has_null; - LOG(INFO) << "before, merge: place_has_null = " << place_has_null - << ", rhs_has_null = " << rhs_has_null; - this->data(place).has_null = rhs_has_null; + auto& data = this->data(place); - auto& set = this->data(place).value; - const auto& rhs_set = this->data(rhs).value; + auto& set = data.value; + auto& rhs_set = this->data(rhs).value; // auto& null_set = this->data(place).null_set; // const auto& rhs_null_set = this->data(rhs).null_set; + set->change_contains_null_value(rhs_set->contain_null()); - LOG(INFO) << "merge: place set size = " << set.size() - << ", rhs set size = " << rhs_set.size(); + LOG(INFO) << "merge: place set size = " << set->size(); if (this->data(rhs).version == 0) { LOG(INFO) << "rhs version is 0, skipping merge"; return; } - UInt64 version = this->data(place).version++; + UInt64 version = data.version++; LOG(INFO) << "merge: version = " << version; if (version == 0) { LOG(INFO) << "Copying rhs set to place set"; - for (auto& rhs_elem : rhs_set) set.insert(rhs_elem.get_value()); + const auto& rhs_set = this->data(rhs).value; + HybridSetBase::IteratorBase* it = rhs_set->begin(); + while (it->has_next()) { + const void* value = it->get_value(); + set->insert(value); + it->next(); + } return; } - if (!set.empty()) { + if (set->size() != 0) { LOG(INFO) << "Merging place set and rhs set"; auto create_new_set = [](auto& lhs_val, auto& rhs_val) { typename State::Set new_set; - for (auto& lhs_elem : lhs_val) { - auto res = rhs_val.find(lhs_elem.get_value()); - if (res != nullptr) new_set.insert(lhs_elem.get_value()); + HybridSetBase::IteratorBase* it = lhs_val->begin(); + while (it->has_next()) { + const void* value = it->get_value(); + bool found = rhs_val->find(value); + if (found) { + new_set->insert(value); + } + it->next(); } return new_set; }; - auto new_set = rhs_set.size() < set.size() ? create_new_set(rhs_set, set) - : create_new_set(set, rhs_set); + auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set) + : create_new_set(set, rhs_set); + // set->clear(); + // set->change_contains_null_value(new_set->contain_null()); + // HybridSetBase::IteratorBase* it = new_set->begin(); + // while (it->has_next()) { + // const void* value = it->get_value(); + // set->insert(value); + // it->next(); + // } + // new_set->clear(); set = std::move(new_set); } LOG(INFO) << "After merge: set = {"; - for (const auto& elem : set) { - LOG(INFO) << elem.get_value() << " "; + HybridSetBase::IteratorBase* it = set->begin(); + while (it->has_next()) { + T value = *reinterpret_cast(it->get_value()); + LOG(INFO) << value << " "; + it->next(); } LOG(INFO) << "}"; // null_set.insert(rhs_null_set.begin(), rhs_null_set.end()); - LOG(INFO) << "after, merge: has_null = " << this->data(place).has_null; } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { - auto& set = this->data(place).value; - auto version = this->data(place).version; - auto has_null = this->data(place).has_null; + auto& data = this->data(place); + auto& set = data.value; + auto version = data.version; + CHECK(set != nullptr); - LOG(INFO) << "before, serialize: has_null = " << has_null; + LOG(INFO) << "serialize: version = " << version << ", set size = " << set->size(); - LOG(INFO) << "serialize: version = " << version << ", set size = " << set.size(); + bool is_set_contains_null = set->contain_null(); + if (is_set_contains_null) { + LOG(INFO) << "Before writing of serialize, the set contains null."; + } - write_var_uint(has_null, buf); + write_pod_binary(is_set_contains_null, buf); - // if (has_null == 0) { - write_var_uint(version, buf); - write_var_uint(set.size(), buf); - for (const auto& elem : set) { - LOG(INFO) << "Serializing element: " << elem.get_value(); - write_int_binary(elem.get_value(), buf); + write_var_uint(version, buf); + write_var_uint(set->size(), buf); + HybridSetBase::IteratorBase* it = set->begin(); + + if (it == nullptr) { + LOG(INFO) << "Before writing of serialize, the set->begin() is nullptr."; + } + + if (it->has_next()) { + LOG(INFO) << "Before writing of serialize, the it->has_next() is true."; + } + + while (it->has_next()) { + if (it->get_value() == nullptr) { + LOG(INFO) << "during writing of serialize, the it->get_value() is nullptr."; } + const T* value_ptr = static_cast(it->get_value()); + // const StringRef* str_ref = reinterpret_cast(value_ptr); + LOG(INFO) << "Serializing element: " << (*value_ptr); + LOG(INFO) << "after writing element... "; + write_int_binary((*value_ptr), buf); + it->next(); + } // } // } else { // } - LOG(INFO) << "after, serialize: has_null = " << has_null; // // 序列化null_set // write_var_uint(null_set.size(), buf); // for (const auto& null_index : null_set) { @@ -291,33 +494,45 @@ class AggregateFunctionGroupArrayIntersect void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena*) const override { + auto& data = this->data(place); + bool is_set_contains_null; + LOG(INFO) << "deserialize"; - // auto& has_null = this->data(place).has_null; - - // LOG(INFO) << "before, deserialize: has_null = " << has_null; - - read_var_uint(this->data(place).has_null, buf); - auto has_null = this->data(place).has_null; - LOG(INFO) << "another way to read the has_null: " << this->data(place).has_null; - LOG(INFO) << "another has_null value: " << has_null; - - // if (has_null == 0) { - read_var_uint(this->data(place).version, buf); - LOG(INFO) << "Deserialized version: " << this->data(place).version; - this->data(place).value.read(buf); - LOG(INFO) << "Deserialized set: {"; - for (const auto& elem : this->data(place).value) { - LOG(INFO) << elem.get_value() << " "; - } - LOG(INFO) << "}"; + + read_pod_binary(is_set_contains_null, buf); + data.value->change_contains_null_value(is_set_contains_null); + + read_var_uint(data.version, buf); + LOG(INFO) << "Deserialized version: " << data.version; + // this->data(place).value.read(buf); + size_t size; + read_var_uint(size, buf); + LOG(INFO) << "Deserialized size: " << size; + + // read_binary(data.value, buf); + LOG(INFO) << "Deserialized set: {"; + T element; + // for (const auto& elem : this->data(place).value) { + // LOG(INFO) << elem.get_value() << " "; // } - // 读取has_null的值 - // } else { - // read_var_uint(has_null, buf); - // LOG(INFO) << "Deserialized has_null: " << this->data(place).has_null; + // HybridSetBase::IteratorBase* it = data.value->begin(); + // while (it->has_next()) { + // const T* value_ptr = reinterpret_cast(it->get_value()); + // // const StringRef* str_ref = reinterpret_cast(value_ptr); + // LOG(INFO) << "derializing element: " << (*value_ptr); + // it->next(); // } + for (size_t i = 0; i < size; ++i) { + read_int_binary(element, buf); + LOG(INFO) << "derializing element: " << element; + data.value->insert(static_cast(&element)); + } + LOG(INFO) << "}"; - LOG(INFO) << "after, deserialize: has_null = " << this->data(place).has_null; + if (data.value->contain_null()) { + LOG(INFO) << "After reading of deserialize, the set contains null."; + } + // } // size_t null_set_size; // read_var_uint(null_set_size, buf); @@ -329,8 +544,6 @@ class AggregateFunctionGroupArrayIntersect } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - LOG(INFO) << "before, insert: has_null = " << this->data(place).has_null; - LOG(INFO) << "in the insert_result_into, name of T: " << boost::core::demangle(typeid(T).name()); LOG(INFO) << "in the start of insert, Type name: " << typeid(T).name(); @@ -353,42 +566,54 @@ class AggregateFunctionGroupArrayIntersect auto& nested_col = assert_cast(col_null->get_nested_column()); auto& null_map_data = col_null->get_null_map_data(); - if (this->data(place).has_null == 1) { - LOG(INFO) << "We have null, pass!"; - // nested_col.insert_default(); - // if (offsets_to.size() == 0) { - // offsets_to.push_back(0); - // } else { - // offsets_to.push_back(offsets_to.back()); - // } - // nested_col.get_data().resize(1); - // nested_col.insert_default(); - col_null->insert_data(nullptr, 0); - offsets_to.push_back(to_nested_col.size()); - // null_map_data.push_back(1); - } else { - size_t old_size = nested_col.get_data().size(); - LOG(INFO) << "old_size of data_to: " << old_size; + // if (this->data(place).has_null == 1) { + // LOG(INFO) << "We have null, pass!"; + // // nested_col.insert_default(); + // // if (offsets_to.size() == 0) { + // // offsets_to.push_back(0); + // // } else { + // // offsets_to.push_back(offsets_to.back()); + // // } + // // nested_col.get_data().resize(1); + // // nested_col.insert_default(); + // col_null->insert_data(nullptr, 0); + // offsets_to.push_back(to_nested_col.size()); + // // null_map_data.push_back(1); + // } else { + size_t old_size = nested_col.get_data().size(); + LOG(INFO) << "old_size of data_to: " << old_size; - const auto& set = this->data(place).value; - LOG(INFO) << "insert_result_into: set size = " << set.size(); + const auto& set = this->data(place).value; + LOG(INFO) << "insert_result_into: set size = " << set->size(); - if (offsets_to.size() == 0) { - offsets_to.push_back(set.size()); - } else { - offsets_to.push_back(offsets_to.back() + set.size()); - } + auto res_size = set->size(); + size_t i = 0; - nested_col.get_data().resize(old_size + set.size()); + if (set->contain_null()) { + LOG(INFO) << "We have null, insert it!"; + col_null->insert_data(nullptr, 0); + res_size += 1; + i = 1; + } - size_t i = 0; - for (auto it = set.begin(); it != set.end(); ++it, ++i) { - T value = it->get_value(); - LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); - nested_col.get_data()[old_size + i] = value; - null_map_data.push_back(0); - } + if (offsets_to.size() == 0) { + offsets_to.push_back(res_size); + } else { + offsets_to.push_back(offsets_to.back() + res_size); + } + + nested_col.get_data().resize(old_size + res_size); + + HybridSetBase::IteratorBase* it = set->begin(); + while (it->has_next()) { + T value = *reinterpret_cast(it->get_value()); + LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); + nested_col.get_data()[old_size + i] = value; + null_map_data.push_back(0); + it->next(); + ++i; } + // } } else { LOG(INFO) << "nested_col is not nullable. "; auto& nested_col = static_cast(to_nested_col); @@ -396,21 +621,24 @@ class AggregateFunctionGroupArrayIntersect LOG(INFO) << "old_size of data_to: " << old_size; const auto& set = this->data(place).value; - LOG(INFO) << "insert_result_into: set size = " << set.size(); + LOG(INFO) << "insert_result_into: set size = " << set->size(); if (offsets_to.size() == 0) { - offsets_to.push_back(set.size()); + offsets_to.push_back(set->size()); } else { - offsets_to.push_back(offsets_to.back() + set.size()); + offsets_to.push_back(offsets_to.back() + set->size()); } - nested_col.get_data().resize(old_size + set.size()); + nested_col.get_data().resize(old_size + set->size()); size_t i = 0; - for (auto it = set.begin(); it != set.end(); ++it, ++i) { - T value = it->get_value(); + HybridSetBase::IteratorBase* it = set->begin(); + while (it->has_next()) { + T value = *reinterpret_cast(it->get_value()); LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); nested_col.get_data()[old_size + i] = value; + it->next(); + ++i; } } LOG(INFO) << "After making value to data_to, leaving..."; @@ -429,7 +657,7 @@ class AggregateFunctionGroupArrayIntersect // size_t i = 0; // for (auto it = set.begin(); it != set.end(); ++it, ++i) { - // T value = it->get_value(); + // T value = it.get_value(); // LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); // data_to[old_size + i] = value; // } From 576ddb13443eec0d1ea9061dcec1b749efc053cb Mon Sep 17 00:00:00 2001 From: chesterxu Date: Sun, 31 Mar 2024 19:38:59 +0800 Subject: [PATCH 09/21] remove useless code --- ...aggregate_function_group_array_intersect.h | 208 +----------------- 1 file changed, 2 insertions(+), 206 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 27fc2b83d77d07..9bdfa7640af8a8 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -118,19 +118,12 @@ class AggregateFunctionGroupArrayIntersect String get_name() const override { return "group_array_intersect"; } - // DataTypePtr get_return_type() const override { return argument_type; } DataTypePtr get_return_type() const override { std::string demangled_name = boost::core::demangle(typeid(argument_type).name()); LOG(INFO) << "in the get_return_type, name of argument_type: " << demangled_name; std::string demangled_name_T = boost::core::demangle(typeid(T).name()); LOG(INFO) << "in the get_return_type, name of T: " << demangled_name_T; - // return std::make_shared(make_nullable(argument_type)); return argument_type; - - // using ReturnDataType = DataTypeNumber; - // std::string demangled_name_re = boost::core::demangle(typeid(ReturnDataType).name()); - // LOG(INFO) << "in the get_return_type, name of ReturnDataType: " << demangled_name_re; - // return std::make_shared(std::make_shared()); } bool allocates_memory_in_arena() const override { return false; } @@ -182,10 +175,7 @@ class AggregateFunctionGroupArrayIntersect is_column_data_nullable && col_null->is_null_at(offset + i); const T* src_data = is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); - // const char* src_data = reinterpret_cast( - // &(nested_column_data->get_element(offset + i))); - // auto src_data = nested_column_data->get_data_at(offset + i); - // set->insert((void*)&value); + if (is_null_element) { LOG(INFO) << "src_data is null!!!!"; // src_data.data = nullptr; @@ -207,9 +197,6 @@ class AggregateFunctionGroupArrayIntersect LOG(INFO) << "in the last of version==1, the set contains null."; } - // LOG(INFO) << "before insert "; - // set->insert_fixed_len(data_column, offset); - // LOG(INFO) << "after insert "; } else if (set->size() != 0 || set->contain_null()) { // typename State::Set new_set; typename State::Set new_set = std::make_shared(); @@ -220,9 +207,7 @@ class AggregateFunctionGroupArrayIntersect is_column_data_nullable && col_null->is_null_at(offset + i); const T* src_data = is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); - // const char* src_data = reinterpret_cast( - // &(nested_column_data->get_element(offset + i))); - // auto src_data = nested_column_data->get_data_at(offset + i); + if (is_null_element) { LOG(INFO) << "src_data is null again ~"; // src_data.data = nullptr; @@ -233,19 +218,12 @@ class AggregateFunctionGroupArrayIntersect if (src_data == nullptr) { LOG(INFO) << "src_data==nullptr is true. "; } - // if (set->find((void*)(src_data.data))) { - // LOG(INFO) << "we can find set's value: " << src_data.to_string(); - // // new_set->insert((void*)&value); - // new_set->insert((src_data.data)); - // LOG(INFO) << "After inserting value."; - // } if (set->find(src_data) || src_data == nullptr) { if (set->find(src_data)) { LOG(INFO) << "we can find set's value: " << *src_data; } else { LOG(INFO) << "src_data is null again "; } - // new_set->insert((void*)&value); new_set->insert(src_data); LOG(INFO) << "After inserting value."; } @@ -257,129 +235,19 @@ class AggregateFunctionGroupArrayIntersect if (new_set->contain_null()) { LOG(INFO) << "before swap between set and new_set, the new_set contains null."; } - - // set->clear(); - // set->change_contains_null_value(new_set->contain_null()); - // HybridSetBase::IteratorBase* it = new_set->begin(); - // while (it->has_next()) { - // const T* value = reinterpret_cast(it->get_value()); - // LOG(INFO) << "In the while loop, we insert the value: " << *value; - - // set->insert(value); - // it->next(); - // } - // new_set->clear(); set = std::move(new_set); if (set->contain_null()) { LOG(INFO) << "after swap between set and new_set, the set contains null."; } - // if (new_set->contain_null()) { - // LOG(INFO) << "after swap between set and new_set, the new_set contains null."; - // } } - - // using ColVecType = ColumnVector; - // LOG(INFO) << "the name of column is: " << column.get_name(); - - // const auto& column_data = column.get_data(); - - // bool is_column_data_nullable = column_data.is_nullable(); - // ColumnNullable* col_null = nullptr; - // const ColVecType* nested_column_data = nullptr; - - // if (is_column_data_nullable) { - // LOG(INFO) << "nested_col is nullable. "; - // auto const_col_data = const_cast(&column_data); - // col_null = static_cast(const_col_data); - // nested_column_data = &assert_cast(col_null->get_nested_column()); - // } else { - // LOG(INFO) << "nested_col is not nullable. "; - // nested_column_data = &static_cast(column_data); - // } - - // std::string demangled_name_column = - // boost::core::demangle(typeid(*nested_column_data).name()); - // LOG(INFO) << "In add func, the name of nested_column_data: " << demangled_name_column; - - // const auto& offsets = column.get_offsets(); - // const size_t offset = offsets[row_num - 1]; - // const auto arr_size = offsets[row_num] - offset; - // LOG(INFO) << "In add func, the name of *columns[0] is: " << (*columns[0]).get_name(); - // LOG(INFO) << "In add func, offsets size: " << offsets.size(); - // LOG(INFO) << "In add func, offsets capacity: " << offsets.capacity(); - - // LOG(INFO) << "Before update: version = " << version - // << ", set = {"; // 输出更新前的 version 和 set - // for (const auto& elem : set) { - // LOG(INFO) << elem.get_value() << " "; // 逐个输出 set 中的元素 - // } - // LOG(INFO) << "}"; - - // ++version; - // if (version == 1) { - // LOG(INFO) << "Inserting elements into an empty set..."; - // for (size_t i = 0; i < arr_size; ++i) { - // if (col_null->is_null_at(offset + i)) { - // LOG(INFO) << "Encounting null: "; - // // this->data(place).null_set.insert(offset + i); - // has_null = 1; - // break; - // } else { - // T value = nested_column_data->get_element(offset + i); - // LOG(INFO) << "Inserting value: " << value; - // set.insert(value); - // } - // } - // } else if (!set.empty()) { - // LOG(INFO) << "Updating an existing set..."; - // typename State::Set new_set; - // bool found_null = false; - // for (size_t i = 0; i < arr_size; ++i) { - // T value; // 将value的声明移动到循环体的开始 - // if (col_null && col_null->is_null_at(offset + i)) { - // LOG(INFO) << "Encounting null: "; - // if (!found_null) { // 如果还没有找到null值 - // // this->data(place).null_set.insert(offset + i); - // has_null = 1; - // found_null = true; // 更新标志 - // } - // break; // 遇到null值,跳出循环 - // } else { - // value = nested_column_data->get_element(offset + i); - // } - // LOG(INFO) << "Checking value: " << value; - // typename State::Set::LookupResult set_value = set.find(value); - // if (set_value != nullptr) { - // LOG(INFO) << "Value found in the set, inserting into new_set"; - // new_set.insert(value); - // } else { - // LOG(INFO) << "Value not found in the set, skipping"; - // } - // } - // set = std::move(new_set); - // } - - // LOG(INFO) << "After update: set = {"; - // for (const auto& elem : set) { - // LOG(INFO) << elem.get_value() << " "; - // } - // LOG(INFO) << "}"; - - // if (is_column_data_nullable) { - // auto& null_map_data = col_null->get_null_map_data(); - // null_map_data.resize_fill(nested_column_data->size(), 0); - // } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena*) const override { auto& data = this->data(place); - auto& set = data.value; auto& rhs_set = this->data(rhs).value; - // auto& null_set = this->data(place).null_set; - // const auto& rhs_null_set = this->data(rhs).null_set; set->change_contains_null_value(rhs_set->contain_null()); LOG(INFO) << "merge: place set size = " << set->size(); @@ -421,15 +289,6 @@ class AggregateFunctionGroupArrayIntersect }; auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set) : create_new_set(set, rhs_set); - // set->clear(); - // set->change_contains_null_value(new_set->contain_null()); - // HybridSetBase::IteratorBase* it = new_set->begin(); - // while (it->has_next()) { - // const void* value = it->get_value(); - // set->insert(value); - // it->next(); - // } - // new_set->clear(); set = std::move(new_set); } @@ -482,14 +341,6 @@ class AggregateFunctionGroupArrayIntersect write_int_binary((*value_ptr), buf); it->next(); } - // } - // } else { - // } - // // 序列化null_set - // write_var_uint(null_set.size(), buf); - // for (const auto& null_index : null_set) { - // write_var_uint(null_index, buf); - // } } void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, @@ -509,19 +360,8 @@ class AggregateFunctionGroupArrayIntersect read_var_uint(size, buf); LOG(INFO) << "Deserialized size: " << size; - // read_binary(data.value, buf); LOG(INFO) << "Deserialized set: {"; T element; - // for (const auto& elem : this->data(place).value) { - // LOG(INFO) << elem.get_value() << " "; - // } - // HybridSetBase::IteratorBase* it = data.value->begin(); - // while (it->has_next()) { - // const T* value_ptr = reinterpret_cast(it->get_value()); - // // const StringRef* str_ref = reinterpret_cast(value_ptr); - // LOG(INFO) << "derializing element: " << (*value_ptr); - // it->next(); - // } for (size_t i = 0; i < size; ++i) { read_int_binary(element, buf); LOG(INFO) << "derializing element: " << element; @@ -532,15 +372,6 @@ class AggregateFunctionGroupArrayIntersect if (data.value->contain_null()) { LOG(INFO) << "After reading of deserialize, the set contains null."; } - // } - - // size_t null_set_size; - // read_var_uint(null_set_size, buf); - // for (size_t i = 0; i < null_set_size; ++i) { - // size_t null_index; - // read_var_uint(null_index, buf); - // this->data(place).null_set.insert(null_index); - // } } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { @@ -566,20 +397,6 @@ class AggregateFunctionGroupArrayIntersect auto& nested_col = assert_cast(col_null->get_nested_column()); auto& null_map_data = col_null->get_null_map_data(); - // if (this->data(place).has_null == 1) { - // LOG(INFO) << "We have null, pass!"; - // // nested_col.insert_default(); - // // if (offsets_to.size() == 0) { - // // offsets_to.push_back(0); - // // } else { - // // offsets_to.push_back(offsets_to.back()); - // // } - // // nested_col.get_data().resize(1); - // // nested_col.insert_default(); - // col_null->insert_data(nullptr, 0); - // offsets_to.push_back(to_nested_col.size()); - // // null_map_data.push_back(1); - // } else { size_t old_size = nested_col.get_data().size(); LOG(INFO) << "old_size of data_to: " << old_size; @@ -642,27 +459,6 @@ class AggregateFunctionGroupArrayIntersect } } LOG(INFO) << "After making value to data_to, leaving..."; - - // auto col_null = reinterpret_cast(&to_nested_col); - // auto& data_to = assert_cast(col_null->get_nested_column()).get_data(); - // // typename ColumnVector::Container& data_to = - // // assert_cast&>(arr_to.get_data()).get_data(); - // std::string demangled_name = boost::core::demangle(typeid(data_to).name()); - // LOG(INFO) << "name of data_to: " << demangled_name << std::endl; - // size_t old_size = data_to.size(); - // LOG(INFO) << "old_size of data_to: " << old_size; - - // data_to.resize(old_size + set.size()); - // LOG(INFO) << "after resize, size of data_to: " << data_to.size(); - - // size_t i = 0; - // for (auto it = set.begin(); it != set.end(); ++it, ++i) { - // T value = it.get_value(); - // LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); - // data_to[old_size + i] = value; - // } - // LOG(INFO) << "after for loop, size of data_to: " << data_to.size(); - // LOG(INFO) << "After making value to data_to, leaving..."; } }; From 299824767d4b5093d9ad92ba328a260f70228dc8 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Mon, 1 Apr 2024 16:54:32 +0800 Subject: [PATCH 10/21] fix string and date part --- ...aggregate_function_group_array_intersect.h | 405 +++++++++++++----- 1 file changed, 302 insertions(+), 103 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 9bdfa7640af8a8..36aafb72361698 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -49,6 +49,7 @@ class BufferWritable; namespace doris::vectorized { +/// Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType so that to inherit HybridSet template constexpr PrimitiveType TypeToPrimitiveType() { if constexpr (std::is_same_v || std::is_same_v) { @@ -63,31 +64,36 @@ constexpr PrimitiveType TypeToPrimitiveType() { return TYPE_LARGEINT; } else if constexpr (std::is_same_v) { return TYPE_FLOAT; - // } else if constexpr (std::is_same_v) { - // return TYPE_DOUBLE; - // } else { - // return TYPE_STRING; - } else { + } else if constexpr (std::is_same_v) { return TYPE_DOUBLE; + } else if constexpr (std::is_same_v) { + return TYPE_DATEV2; + } else if constexpr (std::is_same_v) { + return TYPE_DATETIMEV2; + } else { + throw Exception( + ErrorCode::INVALID_ARGUMENT, + "Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType"); } } template -class NullableKeySet +class NullableNumericOrDateSet : public HybridSet(), DynamicContainer()>::CppType>> { public: - NullableKeySet() { this->_null_aware = true; } + NullableNumericOrDateSet() { this->_null_aware = true; } void change_contains_null_value(bool target_value) { this->_contains_null = target_value; } }; template struct AggregateFunctionGroupArrayIntersectData { - using NullableKeySetType = NullableKeySet; - using Set = std::shared_ptr; + using NullableNumericOrDateSetType = NullableNumericOrDateSet; + using Set = std::shared_ptr; - AggregateFunctionGroupArrayIntersectData() : value(std::make_shared()) {} + AggregateFunctionGroupArrayIntersectData() + : value(std::make_shared()) {} Set value; UInt64 version = 0; @@ -199,7 +205,8 @@ class AggregateFunctionGroupArrayIntersect } else if (set->size() != 0 || set->contain_null()) { // typename State::Set new_set; - typename State::Set new_set = std::make_shared(); + typename State::Set new_set = + std::make_shared(); CHECK(new_set != nullptr); for (size_t i = 0; i < arr_size; ++i) { @@ -218,7 +225,7 @@ class AggregateFunctionGroupArrayIntersect if (src_data == nullptr) { LOG(INFO) << "src_data==nullptr is true. "; } - if (set->find(src_data) || src_data == nullptr) { + if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) { if (set->find(src_data)) { LOG(INFO) << "we can find set's value: " << *src_data; } else { @@ -279,7 +286,7 @@ class AggregateFunctionGroupArrayIntersect HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { const void* value = it->get_value(); - bool found = rhs_val->find(value); + bool found = (rhs_val->find(value) || (rhs_val->contain_null() && value == nullptr)); if (found) { new_set->insert(value); } @@ -300,7 +307,6 @@ class AggregateFunctionGroupArrayIntersect it->next(); } LOG(INFO) << "}"; - // null_set.insert(rhs_null_set.begin(), rhs_null_set.end()); } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { @@ -463,9 +469,18 @@ class AggregateFunctionGroupArrayIntersect }; /// Generic implementation, it uses serialized representation as object descriptor. +class NullableStringSet : public StringValueSet> { +public: + NullableStringSet() { this->_null_aware = true; } + + void change_contains_null_value(bool target_value) { this->_contains_null = target_value; } +}; + struct AggregateFunctionGroupArrayIntersectGenericData { - using Set = HashSet; + using Set = std::shared_ptr; + AggregateFunctionGroupArrayIntersectGenericData() + : value(std::make_shared()) {} Set value; UInt64 version = 0; }; @@ -495,137 +510,320 @@ class AggregateFunctionGroupArrayIntersectGeneric // DataTypePtr get_return_type() const override { return input_data_type; } DataTypePtr get_return_type() const override { // return std::make_shared(make_nullable(input_data_type)); - return std::make_shared(input_data_type); + return input_data_type; } bool allocates_memory_in_arena() const override { return true; } void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { - auto& set = this->data(place).value; - auto& version = this->data(place).version; - bool inserted; - State::Set::LookupResult it; - - const auto data_column = assert_cast(*columns[0]).get_data_ptr(); - const auto& offsets = assert_cast(*columns[0]).get_offsets(); - const size_t offset = offsets[row_num - 1]; + auto& data = this->data(place); + auto& version = data.version; + auto& set = data.value; + CHECK(set != nullptr); + + LOG(INFO) << "Input row_num: " << row_num; // 输出输入的 row_num + const bool col_is_nullable = (*columns[0]).is_nullable(); + const ColumnArray& column = + col_is_nullable ? assert_cast( + assert_cast(*columns[0]) + .get_nested_column()) + : assert_cast(*columns[0]); + + const auto data_column = column.get_data_ptr(); + const auto& offsets = column.get_offsets(); + const size_t offset = offsets[static_cast(row_num) - 1]; const auto arr_size = offsets[row_num] - offset; + LOG(INFO) << "the name of column is: " << column.get_name(); + + const auto& column_data = column.get_data(); + + bool is_column_data_nullable = column_data.is_nullable(); + ColumnNullable* col_null = nullptr; + const ColumnArray* nested_column_data = nullptr; + + if (is_column_data_nullable) { + LOG(INFO) << "nested_col is nullable. "; + auto const_col_data = const_cast(&column_data); + col_null = static_cast(const_col_data); + nested_column_data = &assert_cast(col_null->get_nested_column()); + } else { + LOG(INFO) << "nested_col is not nullable. "; + nested_column_data = &static_cast(column_data); + } + ++version; if (version == 1) { + LOG(INFO) << "version is 1."; for (size_t i = 0; i < arr_size; ++i) { + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + + StringRef src = StringRef(); if constexpr (is_plain_column) { - StringRef key = data_column->get_data_at(offset + i); - key.data = arena->insert(key.data, key.size); - set.emplace(key, it, inserted); + src = nested_column_data->get_data_at(offset + i); + src.data = arena->insert(src.data, src.size); } else { const char* begin = nullptr; - StringRef serialized = - data_column->serialize_value_into_arena(offset + i, *arena, begin); - assert(serialized.data != nullptr); - serialized.data = arena->insert(serialized.data, serialized.size); - set.emplace(serialized, it, inserted); + src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); + src.data = arena->insert(src.data, src.size); + } + + const auto* src_data = is_null_element ? nullptr : src.data; + + if (is_null_element) { + LOG(INFO) << "src_data is null!!!!"; + // src_data.data = nullptr; + } else { + LOG(INFO) << "src_data is not null"; + LOG(INFO) << "Inserting value: " << *(src_data); + } + + if (src_data == nullptr) { + LOG(INFO) << "src_data==nullptr is true. "; } + + // set->insert(src_data.data); + set->insert(src_data); + LOG(INFO) << "After inserting value."; + } + + if (set->contain_null()) { + LOG(INFO) << "in the last of version==1, the set contains null."; } - } else if (!set.empty()) { - typename State::Set new_set; + + } else if (set->size() != 0 || set->contain_null()) { + // typename State::Set new_set; + typename State::Set new_set = std::make_shared(); + + CHECK(new_set != nullptr); for (size_t i = 0; i < arr_size; ++i) { + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + StringRef src = StringRef(); if constexpr (is_plain_column) { - it = set.find(data_column->get_data_at(offset + i)); - if (it != nullptr) { - StringRef key = data_column->get_data_at(offset + i); - key.data = arena->insert(key.data, key.size); - new_set.emplace(key, it, inserted); - } + src = nested_column_data->get_data_at(offset + i); + src.data = arena->insert(src.data, src.size); } else { const char* begin = nullptr; - StringRef serialized = - data_column->serialize_value_into_arena(offset + i, *arena, begin); - assert(serialized.data != nullptr); - it = set.find(serialized); - - if (it != nullptr) { - serialized.data = arena->insert(serialized.data, serialized.size); - new_set.emplace(serialized, it, inserted); + src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); + src.data = arena->insert(src.data, src.size); + } + + const auto* src_data = is_null_element ? nullptr : src.data; + + if (is_null_element) { + LOG(INFO) << "src_data is null again ~"; + // src_data.data = nullptr; + } else { + LOG(INFO) << "Inserting value: " << *src_data; + } + + if (src_data == nullptr) { + LOG(INFO) << "src_data==nullptr is true. "; + } + if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) { + if (set->find(src_data)) { + LOG(INFO) << "we can find set's value: " << *src_data; + } else { + LOG(INFO) << "src_data is null again "; } + new_set->insert(src_data); + LOG(INFO) << "After inserting value."; } } + + if (set->contain_null()) { + LOG(INFO) << "before swap between set and new_set, the set contains null."; + } + if (new_set->contain_null()) { + LOG(INFO) << "before swap between set and new_set, the new_set contains null."; + } set = std::move(new_set); + + if (set->contain_null()) { + LOG(INFO) << "after swap between set and new_set, the set contains null."; + } } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena* arena) const override { - auto& set = this->data(place).value; - const auto& rhs_value = this->data(rhs).value; + auto& data = this->data(place); + auto& set = data.value; + auto& rhs_set = this->data(rhs).value; + set->change_contains_null_value(rhs_set->contain_null()); + + LOG(INFO) << "merge: place set size = " << set->size(); + + if (this->data(rhs).version == 0) { + LOG(INFO) << "rhs version is 0, skipping merge"; + return; + } - if (this->data(rhs).version == 0) return; + UInt64 version = data.version++; + LOG(INFO) << "merge: version = " << version; - UInt64 version = this->data(place).version++; if (version == 0) { - bool inserted; - State::Set::LookupResult it; - for (auto& rhs_elem : rhs_value) { - StringRef key = rhs_elem.get_value(); - key.data = arena->insert(key.data, key.size); - set.emplace(key, it, inserted); + LOG(INFO) << "Copying rhs set to place set"; + const auto& rhs_set = this->data(rhs).value; + HybridSetBase::IteratorBase* it = rhs_set->begin(); + while (it->has_next()) { + const void* value = it->get_value(); + set->insert(value); + it->next(); } - } else if (!set.empty()) { - auto create_new_map = [](auto& lhs_val, auto& rhs_val) { - typename State::Set new_map; - for (auto& lhs_elem : lhs_val) { - auto val = rhs_val.find(lhs_elem.get_value()); - if (val != nullptr) new_map.insert(lhs_elem.get_value()); + return; + } + + if (set->size() != 0) { + LOG(INFO) << "Merging place set and rhs set"; + auto create_new_set = [](auto& lhs_val, auto& rhs_val) { + typename State::Set new_set; + HybridSetBase::IteratorBase* it = lhs_val->begin(); + while (it->has_next()) { + const void* value = it->get_value(); + bool found = + (rhs_val->find(value) || (rhs_val->contain_null() && value == nullptr)); + if (found) { + new_set->insert(value); + } + it->next(); } - return new_map; + return new_set; }; - auto new_map = rhs_value.size() < set.size() ? create_new_map(rhs_value, set) - : create_new_map(set, rhs_value); - set = std::move(new_map); + auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set) + : create_new_set(set, rhs_set); + set = std::move(new_set); + } + + LOG(INFO) << "After merge: set = {"; + HybridSetBase::IteratorBase* it = set->begin(); + while (it->has_next()) { + auto* value = it->get_value(); + LOG(INFO) << value << " "; + it->next(); } + LOG(INFO) << "}"; } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { - auto& set = this->data(place).value; - auto& version = this->data(place).version; + auto& data = this->data(place); + auto& set = data.value; + auto version = data.version; + CHECK(set != nullptr); + + LOG(INFO) << "serialize: version = " << version << ", set size = " << set->size(); + + bool is_set_contains_null = set->contain_null(); + if (is_set_contains_null) { + LOG(INFO) << "Before writing of serialize, the set contains null."; + } + + write_pod_binary(is_set_contains_null, buf); + write_var_uint(version, buf); - write_var_uint(set.size(), buf); + write_var_uint(set->size(), buf); + HybridSetBase::IteratorBase* it = set->begin(); + + if (it == nullptr) { + LOG(INFO) << "Before writing of serialize, the set->begin() is nullptr."; + } + + if (it->has_next()) { + LOG(INFO) << "Before writing of serialize, the it->has_next() is true."; + } - for (const auto& elem : set) write_string_binary(elem.get_value(), buf); + while (it->has_next()) { + if (it->get_value() == nullptr) { + LOG(INFO) << "during writing of serialize, the it->get_value() is nullptr."; + } + const auto* value_ptr = reinterpret_cast(it->get_value()); + // const StringRef* str_ref = reinterpret_cast(value_ptr); + LOG(INFO) << "Serializing element: " << *(value_ptr->data); + LOG(INFO) << "after writing element... "; + write_string_binary(*(value_ptr), buf); + it->next(); + } } void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena* arena) const override { - auto& set = this->data(place).value; - auto& version = this->data(place).version; + auto& data = this->data(place); + bool is_set_contains_null; + + LOG(INFO) << "deserialize"; + + read_pod_binary(is_set_contains_null, buf); + data.value->change_contains_null_value(is_set_contains_null); + + read_var_uint(data.version, buf); + LOG(INFO) << "Deserialized version: " << data.version; + // this->data(place).value.read(buf); size_t size; - read_var_uint(version, buf); read_var_uint(size, buf); - set.reserve(size); - UInt64 elem_version; + LOG(INFO) << "Deserialized size: " << size; + + LOG(INFO) << "Deserialized set: {"; + StringRef element; for (size_t i = 0; i < size; ++i) { - auto key = read_string_binary_into(*arena, buf); - read_var_uint(elem_version, buf); - set.insert(key); + element = read_string_binary_into(*arena, buf); + LOG(INFO) << "derializing element: " << element.to_string(); + data.value->insert(element.data); + } + LOG(INFO) << "}"; + + if (data.value->contain_null()) { + LOG(INFO) << "After reading of deserialize, the set contains null."; } } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + LOG(INFO) << "the name of to is: " << to.get_name(); + ColumnArray& arr_to = assert_cast(to); - auto& offsets_to = arr_to.get_offsets(); - IColumn& data_to = arr_to.get_data(); + LOG(INFO) << "the name of arr_to is: " << arr_to.get_name(); + + ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); - auto& set = this->data(place).value; + LOG(INFO) << "nested_col is nullable. "; + auto& data_to = arr_to.get_data(); + auto col_null = reinterpret_cast(&data_to); + auto& null_map_data = col_null->get_null_map_data(); - offsets_to.push_back(offsets_to.back() + set.size()); + const auto& set = this->data(place).value; + LOG(INFO) << "insert_result_into: set size = " << set->size(); - for (auto& elem : set) { - if constexpr (is_plain_column) - data_to.insert_data(elem.get_value().data, elem.get_value().size); - else - std::ignore = data_to.deserialize_and_insert_from_arena(elem.get_value().data); + auto res_size = set->size(); + + if (set->contain_null()) { + LOG(INFO) << "We have null, insert it!"; + col_null->insert_data(nullptr, 0); + res_size += 1; + } + + if (offsets_to.size() == 0) { + offsets_to.push_back(res_size); + } else { + offsets_to.push_back(offsets_to.back() + res_size); } + + data_to.resize(res_size); + + HybridSetBase::IteratorBase* it = set->begin(); + while (it->has_next()) { + const auto* value = + reinterpret_cast(it->get_value()); + if constexpr (is_plain_column) { + data_to.insert_data(value->data, value->size); + } else { + std::ignore = data_to.deserialize_and_insert_from_arena(value->data); + } + null_map_data.push_back(0); + it->next(); + } + LOG(INFO) << "After making value to data_to, leaving..."; } }; @@ -633,26 +831,27 @@ namespace { /// Substitute return type for DateV2 and DateTimeV2 class AggregateFunctionGroupArrayIntersectDateV2 - : public AggregateFunctionGroupArrayIntersect { + : public AggregateFunctionGroupArrayIntersect { public: explicit AggregateFunctionGroupArrayIntersectDateV2(const DataTypes& argument_types_) - : AggregateFunctionGroupArrayIntersect( + : AggregateFunctionGroupArrayIntersect( DataTypes(argument_types_.begin(), argument_types_.end())) {} }; class AggregateFunctionGroupArrayIntersectDateTimeV2 - : public AggregateFunctionGroupArrayIntersect { + : public AggregateFunctionGroupArrayIntersect { public: explicit AggregateFunctionGroupArrayIntersectDateTimeV2(const DataTypes& argument_types_) - : AggregateFunctionGroupArrayIntersect( + : AggregateFunctionGroupArrayIntersect( DataTypes(argument_types_.begin(), argument_types_.end())) {} }; -IAggregateFunction* create_with_extra_types(const DataTypes& argument_types) { - WhichDataType which(argument_types[0]); - if (which.idx == TypeIndex::DateV2) +IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, + const DataTypes& argument_types) { + WhichDataType which(nested_type); + if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) return new AggregateFunctionGroupArrayIntersectDateV2(argument_types); - else if (which.idx == TypeIndex::DateTimeV2) + else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) return new AggregateFunctionGroupArrayIntersectDateTimeV2(argument_types); else { /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric @@ -665,9 +864,9 @@ IAggregateFunction* create_with_extra_types(const DataTypes& argument_types) { inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl( const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - const auto& nested_type = - dynamic_cast(*(argument_types[0])).get_nested_type(); - WhichDataType which_type(remove_nullable(nested_type)); + const auto& nested_type = remove_nullable( + dynamic_cast(*(argument_types[0])).get_nested_type()); + WhichDataType which_type(nested_type); if (which_type.is_int()) { LOG(INFO) << "nested_type is int"; } else { @@ -676,7 +875,7 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl DataTypes new_argument_types = {nested_type}; LOG(INFO) << "in the create_aggregate_function_group_array_intersect_impl, name of " "remove_nullable(nested_type): " - << boost::core::demangle(typeid(remove_nullable(nested_type)).name()); + << boost::core::demangle(typeid(nested_type).name()); const auto& argument_type = dynamic_cast(*(argument_types[0])); LOG(INFO) << "in the create_aggregate_function_group_array_intersect_impl, name of " @@ -687,7 +886,7 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl // creator_with_numeric_type::creator( // "", new_argument_types, result_is_nullable)); AggregateFunctionPtr res = nullptr; - WhichDataType which(remove_nullable(nested_type)); + WhichDataType which(nested_type); #define DISPATCH(TYPE) \ if (which.idx == TypeIndex::TYPE) \ res = creator_without_type::create>( \ @@ -695,7 +894,7 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl FOR_NUMERIC_TYPES(DISPATCH) #undef DISPATCH if (!res) { - res = AggregateFunctionPtr(create_with_extra_types(remove_nullable(argument_types))); + res = AggregateFunctionPtr(create_with_extra_types(nested_type, argument_types)); } if (!res) From 036c57679d56a20c58524a03b80b5cb984670e9e Mon Sep 17 00:00:00 2001 From: chesterxu Date: Tue, 2 Apr 2024 19:22:17 +0800 Subject: [PATCH 11/21] fix --- ...aggregate_function_group_array_intersect.h | 139 ++++++++++-------- 1 file changed, 77 insertions(+), 62 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 36aafb72361698..a53a90f81229a0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -71,9 +71,8 @@ constexpr PrimitiveType TypeToPrimitiveType() { } else if constexpr (std::is_same_v) { return TYPE_DATETIMEV2; } else { - throw Exception( - ErrorCode::INVALID_ARGUMENT, - "Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType"); + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType"); } } @@ -255,7 +254,6 @@ class AggregateFunctionGroupArrayIntersect auto& data = this->data(place); auto& set = data.value; auto& rhs_set = this->data(rhs).value; - set->change_contains_null_value(rhs_set->contain_null()); LOG(INFO) << "merge: place set size = " << set->size(); @@ -268,8 +266,8 @@ class AggregateFunctionGroupArrayIntersect LOG(INFO) << "merge: version = " << version; if (version == 0) { + set->change_contains_null_value(rhs_set->contain_null()); LOG(INFO) << "Copying rhs set to place set"; - const auto& rhs_set = this->data(rhs).value; HybridSetBase::IteratorBase* it = rhs_set->begin(); while (it->has_next()) { const void* value = it->get_value(); @@ -286,12 +284,14 @@ class AggregateFunctionGroupArrayIntersect HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { const void* value = it->get_value(); - bool found = (rhs_val->find(value) || (rhs_val->contain_null() && value == nullptr)); + bool found = (rhs_val->find(value)); if (found) { new_set->insert(value); } it->next(); } + new_set->change_contains_null_value(lhs_val->contain_null() && + rhs_val->contain_null()); return new_set; }; auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set) @@ -493,9 +493,9 @@ class AggregateFunctionGroupArrayIntersectGeneric : public IAggregateFunctionDataHelper< AggregateFunctionGroupArrayIntersectGenericData, AggregateFunctionGroupArrayIntersectGeneric> { - const DataTypePtr& input_data_type; - +private: using State = AggregateFunctionGroupArrayIntersectGenericData; + DataTypePtr input_data_type; public: AggregateFunctionGroupArrayIntersectGeneric(const DataTypes& input_data_type_) @@ -510,6 +510,8 @@ class AggregateFunctionGroupArrayIntersectGeneric // DataTypePtr get_return_type() const override { return input_data_type; } DataTypePtr get_return_type() const override { // return std::make_shared(make_nullable(input_data_type)); + LOG(INFO) << "in the get_return_type, name of input_data_type: " + << boost::core::demangle(typeid(input_data_type).name()); return input_data_type; } @@ -541,17 +543,18 @@ class AggregateFunctionGroupArrayIntersectGeneric bool is_column_data_nullable = column_data.is_nullable(); ColumnNullable* col_null = nullptr; - const ColumnArray* nested_column_data = nullptr; + // const ColumnArray* nested_column_data = nullptr; if (is_column_data_nullable) { LOG(INFO) << "nested_col is nullable. "; auto const_col_data = const_cast(&column_data); col_null = static_cast(const_col_data); - nested_column_data = &assert_cast(col_null->get_nested_column()); + // nested_column_data = &assert_cast(col_null->get_nested_column()); } else { LOG(INFO) << "nested_col is not nullable. "; - nested_column_data = &static_cast(column_data); + // nested_column_data = &static_cast(column_data); } + auto nested_column_data = data_column; ++version; if (version == 1) { @@ -562,30 +565,32 @@ class AggregateFunctionGroupArrayIntersectGeneric StringRef src = StringRef(); if constexpr (is_plain_column) { + LOG(INFO) << "is_plain_column is true"; src = nested_column_data->get_data_at(offset + i); - src.data = arena->insert(src.data, src.size); + LOG(INFO) << "we get src"; } else { + LOG(INFO) << "is_plain_column is false"; const char* begin = nullptr; src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); - src.data = arena->insert(src.data, src.size); + LOG(INFO) << "we get src"; } - const auto* src_data = is_null_element ? nullptr : src.data; + src.data = is_null_element ? nullptr : arena->insert(src.data, src.size); if (is_null_element) { - LOG(INFO) << "src_data is null!!!!"; + LOG(INFO) << "src.data is null!!!!"; // src_data.data = nullptr; } else { - LOG(INFO) << "src_data is not null"; - LOG(INFO) << "Inserting value: " << *(src_data); + LOG(INFO) << "src.data is not null"; + LOG(INFO) << "Inserting value: " << *(src.data); } - if (src_data == nullptr) { - LOG(INFO) << "src_data==nullptr is true. "; + if (src.data == nullptr) { + LOG(INFO) << "src.data==nullptr is true. "; } // set->insert(src_data.data); - set->insert(src_data); + set->insert((void*)src.data, src.size); LOG(INFO) << "After inserting value."; } @@ -604,32 +609,30 @@ class AggregateFunctionGroupArrayIntersectGeneric StringRef src = StringRef(); if constexpr (is_plain_column) { src = nested_column_data->get_data_at(offset + i); - src.data = arena->insert(src.data, src.size); } else { const char* begin = nullptr; src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); - src.data = arena->insert(src.data, src.size); } - const auto* src_data = is_null_element ? nullptr : src.data; + src.data = is_null_element ? nullptr : arena->insert(src.data, src.size); if (is_null_element) { - LOG(INFO) << "src_data is null again ~"; + LOG(INFO) << "src.data is null again ~"; // src_data.data = nullptr; } else { - LOG(INFO) << "Inserting value: " << *src_data; + LOG(INFO) << "Inserting value: " << *(src.data); } - if (src_data == nullptr) { + if (src.data == nullptr) { LOG(INFO) << "src_data==nullptr is true. "; } - if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) { - if (set->find(src_data)) { - LOG(INFO) << "we can find set's value: " << *src_data; + if (set->find(src.data, src.size) || (set->contain_null() && src.data == nullptr)) { + if (set->find(src.data, src.size)) { + LOG(INFO) << "we can find set's value: " << *(src.data); } else { LOG(INFO) << "src_data is null again "; } - new_set->insert(src_data); + new_set->insert((void*)src.data, src.size); LOG(INFO) << "After inserting value."; } } @@ -649,48 +652,53 @@ class AggregateFunctionGroupArrayIntersectGeneric } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, - Arena* arena) const override { + Arena*) const override { auto& data = this->data(place); auto& set = data.value; auto& rhs_set = this->data(rhs).value; - set->change_contains_null_value(rhs_set->contain_null()); LOG(INFO) << "merge: place set size = " << set->size(); if (this->data(rhs).version == 0) { LOG(INFO) << "rhs version is 0, skipping merge"; return; + } else { + LOG(INFO) << "rhs version is: " << this->data(rhs).version; } UInt64 version = data.version++; LOG(INFO) << "merge: version = " << version; if (version == 0) { + set->change_contains_null_value(rhs_set->contain_null()); LOG(INFO) << "Copying rhs set to place set"; - const auto& rhs_set = this->data(rhs).value; HybridSetBase::IteratorBase* it = rhs_set->begin(); while (it->has_next()) { - const void* value = it->get_value(); - set->insert(value); + const StringRef* value = reinterpret_cast(it->get_value()); + LOG(INFO) << "before inserting..."; + LOG(INFO) << "inserting: " << *(char*)(value->data); + set->insert((void*)(value->data), value->size); + LOG(INFO) << "after inserting..."; it->next(); + LOG(INFO) << "after next..."; } - return; - } - - if (set->size() != 0) { + } else if (set->size() != 0) { LOG(INFO) << "Merging place set and rhs set"; auto create_new_set = [](auto& lhs_val, auto& rhs_val) { typename State::Set new_set; HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { - const void* value = it->get_value(); - bool found = - (rhs_val->find(value) || (rhs_val->contain_null() && value == nullptr)); + const auto* value = reinterpret_cast(it->get_value()); + bool found = (rhs_val->find(value)); + LOG(INFO) << "before inserting..."; if (found) { - new_set->insert(value); + new_set->insert((void*)value->data, value->size); } + LOG(INFO) << "after inserting..."; it->next(); } + new_set->change_contains_null_value(lhs_val->contain_null() && + rhs_val->contain_null()); return new_set; }; auto new_set = rhs_set->size() < set->size() ? create_new_set(rhs_set, set) @@ -698,14 +706,14 @@ class AggregateFunctionGroupArrayIntersectGeneric set = std::move(new_set); } - LOG(INFO) << "After merge: set = {"; - HybridSetBase::IteratorBase* it = set->begin(); - while (it->has_next()) { - auto* value = it->get_value(); - LOG(INFO) << value << " "; - it->next(); - } - LOG(INFO) << "}"; + // LOG(INFO) << "After merge: set = {"; + // HybridSetBase::IteratorBase* it = set->begin(); + // while (it->has_next()) { + // auto* value = reinterpret_cast(it->get_value()); + // LOG(INFO) << (*value).to_string() << " "; + // it->next(); + // } + // LOG(INFO) << "}"; } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { @@ -738,12 +746,14 @@ class AggregateFunctionGroupArrayIntersectGeneric while (it->has_next()) { if (it->get_value() == nullptr) { LOG(INFO) << "during writing of serialize, the it->get_value() is nullptr."; + } else { + LOG(INFO) << "the it->get_value() is not null "; } - const auto* value_ptr = reinterpret_cast(it->get_value()); + const auto* value = reinterpret_cast(it->get_value()); // const StringRef* str_ref = reinterpret_cast(value_ptr); - LOG(INFO) << "Serializing element: " << *(value_ptr->data); + LOG(INFO) << "Serializing element: " << (*(value)).to_string(); LOG(INFO) << "after writing element... "; - write_string_binary(*(value_ptr), buf); + write_string_binary(*value, buf); it->next(); } } @@ -769,8 +779,8 @@ class AggregateFunctionGroupArrayIntersectGeneric StringRef element; for (size_t i = 0; i < size; ++i) { element = read_string_binary_into(*arena, buf); - LOG(INFO) << "derializing element: " << element.to_string(); - data.value->insert(element.data); + LOG(INFO) << "derializing element: " << *(element.data); + data.value->insert((void*)element.data, element.size); } LOG(INFO) << "}"; @@ -790,7 +800,7 @@ class AggregateFunctionGroupArrayIntersectGeneric LOG(INFO) << "nested_col is nullable. "; auto& data_to = arr_to.get_data(); auto col_null = reinterpret_cast(&data_to); - auto& null_map_data = col_null->get_null_map_data(); + // auto& null_map_data = col_null->get_null_map_data(); const auto& set = this->data(place).value; LOG(INFO) << "insert_result_into: set size = " << set->size(); @@ -809,20 +819,25 @@ class AggregateFunctionGroupArrayIntersectGeneric offsets_to.push_back(offsets_to.back() + res_size); } - data_to.resize(res_size); + // data_to.resize(res_size); HybridSetBase::IteratorBase* it = set->begin(); while (it->has_next()) { - const auto* value = - reinterpret_cast(it->get_value()); + const auto* value = reinterpret_cast(it->get_value()); if constexpr (is_plain_column) { data_to.insert_data(value->data, value->size); } else { std::ignore = data_to.deserialize_and_insert_from_arena(value->data); } - null_map_data.push_back(0); + // null_map_data.push_back(0); it->next(); } + + LOG(INFO) << "res_size size = " << res_size; + LOG(INFO) << "offsets_to size = " << offsets_to.size(); + LOG(INFO) << "set size = " << set->size(); + LOG(INFO) << "data_to size = " << data_to.size(); + LOG(INFO) << "After making value to data_to, leaving..."; } }; @@ -855,7 +870,7 @@ IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, return new AggregateFunctionGroupArrayIntersectDateTimeV2(argument_types); else { /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric - if (argument_types[0]->is_value_unambiguously_represented_in_contiguous_memory_region()) + if (nested_type->is_value_unambiguously_represented_in_contiguous_memory_region()) return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); else return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); From 924868152126736cd36618c085dffe2b70215d21 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Tue, 2 Apr 2024 20:49:50 +0800 Subject: [PATCH 12/21] opt --- be/src/exprs/hybrid_set.h | 23 +- ...gregate_function_group_array_intersect.cpp | 49 ++- ...aggregate_function_group_array_intersect.h | 382 +----------------- be/src/vec/common/hash_table/hash_table.h | 2 - 4 files changed, 53 insertions(+), 403 deletions(-) diff --git a/be/src/exprs/hybrid_set.h b/be/src/exprs/hybrid_set.h index 7c6d15d19b6593..ba5fabe509be62 100644 --- a/be/src/exprs/hybrid_set.h +++ b/be/src/exprs/hybrid_set.h @@ -170,8 +170,6 @@ class DynamicContainer { size_t size() const { return _set.size(); } - void clear() { _set.clear(); } - private: vectorized::flat_hash_set _set; }; @@ -235,18 +233,7 @@ class HybridSetBase : public RuntimeFilterFuncBase { virtual IteratorBase* begin() = 0; - bool contain_null() const { - LOG(INFO) << "Entering the func contain_null.. "; - if (_contains_null && _null_aware) { - LOG(INFO) - << "The func finds out containing null, _contains_null && _null_aware is true."; - } else if (_null_aware) { - LOG(INFO) << "The func finds out not containing null, _contains_null is not true"; - } else { - LOG(INFO) << "The func finds out not containing null, both are false "; - } - return _contains_null && _null_aware; - } + bool contain_null() const { return _contains_null && _null_aware; } bool _contains_null = false; }; @@ -284,12 +271,6 @@ class HybridSet : public HybridSetBase { ~HybridSet() override = default; void insert(const void* data) override { - LOG(INFO) << "Entering the func insert.. "; - if (data == nullptr) { - LOG(INFO) << "The func finds out data is nullptr."; - } else { - LOG(INFO) << "The func finds out data is not nullptr. "; - } if (data == nullptr) { _contains_null = true; return; @@ -327,8 +308,6 @@ class HybridSet : public HybridSetBase { int size() override { return _set.size(); } - void clear() { _set.clear(); } - bool find(const void* data) const override { if (data == nullptr) { return false; diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp index a812d8d9b71508..c2c9e0fb0742e2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp @@ -21,17 +21,54 @@ #include "vec/aggregate_functions/aggregate_function_group_array_intersect.h" namespace doris::vectorized { + +IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, + const DataTypes& argument_types) { + WhichDataType which(nested_type); + if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) + return new AggregateFunctionGroupArrayIntersect(argument_types); + else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) + return new AggregateFunctionGroupArrayIntersect(argument_types); + else { + /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric + if (nested_type->is_value_unambiguously_represented_in_contiguous_memory_region()) + return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); + else + return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); + } +} + +inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const auto& nested_type = remove_nullable( + dynamic_cast(*(argument_types[0])).get_nested_type()); + AggregateFunctionPtr res = nullptr; + + WhichDataType which(nested_type); +#define DISPATCH(TYPE) \ + if (which.idx == TypeIndex::TYPE) \ + res = creator_without_type::create>( \ + argument_types, result_is_nullable); + FOR_NUMERIC_TYPES(DISPATCH) +#undef DISPATCH + + if (!res) { + res = AggregateFunctionPtr(create_with_extra_types(nested_type, argument_types)); + } + + if (!res) + throw Exception(ErrorCode::INVALID_ARGUMENT, + "Illegal type {} of argument for aggregate function {}", + argument_types[0]->get_name(), name); + + return res; +} + AggregateFunctionPtr create_aggregate_function_group_array_intersect( const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { assert_unary(name, argument_types); - std::string demangled_name_before = boost::core::demangle(typeid((argument_types[0])).name()); - LOG(INFO) << "in the cpp, before remove, name of argument_types[0]: " << demangled_name_before; const DataTypePtr& argument_type = remove_nullable(argument_types[0]); - std::string demangled_name_argument_type = boost::core::demangle(typeid(argument_type).name()); - LOG(INFO) << "in the cpp, after remove, name of argument_type: " - << demangled_name_argument_type; - if (!WhichDataType(argument_type).is_array()) throw Exception(ErrorCode::INVALID_ARGUMENT, "Aggregate function groupArrayIntersect accepts only array type argument. " diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index a53a90f81229a0..3f3a33f5fbbffd 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -18,20 +18,17 @@ // https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionGroupArrayIntersect.cpp // and modified by Doris -#include -#include -#include -#include - -#include #include #include #include "exprs/hybrid_set.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/factory_helpers.h" #include "vec/aggregate_functions/helpers.h" #include "vec/columns/column_array.h" +#include "vec/common/assert_cast.h" +#include "vec/core/field.h" #include "vec/data_types/data_type_array.h" #include "vec/data_types/data_type_number.h" #include "vec/data_types/data_type_string.h" @@ -123,13 +120,7 @@ class AggregateFunctionGroupArrayIntersect String get_name() const override { return "group_array_intersect"; } - DataTypePtr get_return_type() const override { - std::string demangled_name = boost::core::demangle(typeid(argument_type).name()); - LOG(INFO) << "in the get_return_type, name of argument_type: " << demangled_name; - std::string demangled_name_T = boost::core::demangle(typeid(T).name()); - LOG(INFO) << "in the get_return_type, name of T: " << demangled_name_T; - return argument_type; - } + DataTypePtr get_return_type() const override { return argument_type; } bool allocates_memory_in_arena() const override { return false; } @@ -138,9 +129,7 @@ class AggregateFunctionGroupArrayIntersect auto& data = this->data(place); auto& version = data.version; auto& set = data.value; - CHECK(set != nullptr); - LOG(INFO) << "Input row_num: " << row_num; // 输出输入的 row_num const bool col_is_nullable = (*columns[0]).is_nullable(); const ColumnArray& column = col_is_nullable ? assert_cast( @@ -154,8 +143,6 @@ class AggregateFunctionGroupArrayIntersect const auto arr_size = offsets[row_num] - offset; using ColVecType = ColumnVector; - LOG(INFO) << "the name of column is: " << column.get_name(); - const auto& column_data = column.get_data(); bool is_column_data_nullable = column_data.is_nullable(); @@ -163,89 +150,38 @@ class AggregateFunctionGroupArrayIntersect const ColVecType* nested_column_data = nullptr; if (is_column_data_nullable) { - LOG(INFO) << "nested_col is nullable. "; auto const_col_data = const_cast(&column_data); col_null = static_cast(const_col_data); nested_column_data = &assert_cast(col_null->get_nested_column()); } else { - LOG(INFO) << "nested_col is not nullable. "; nested_column_data = &static_cast(column_data); } ++version; if (version == 1) { - LOG(INFO) << "version is 1."; for (size_t i = 0; i < arr_size; ++i) { const bool is_null_element = is_column_data_nullable && col_null->is_null_at(offset + i); const T* src_data = is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); - if (is_null_element) { - LOG(INFO) << "src_data is null!!!!"; - // src_data.data = nullptr; - } else { - LOG(INFO) << "src_data is not null"; - LOG(INFO) << "Inserting value: " << *(src_data); - } - - if (src_data == nullptr) { - LOG(INFO) << "src_data==nullptr is true. "; - } - - // set->insert(src_data.data); set->insert(src_data); - LOG(INFO) << "After inserting value."; } - - if (set->contain_null()) { - LOG(INFO) << "in the last of version==1, the set contains null."; - } - } else if (set->size() != 0 || set->contain_null()) { - // typename State::Set new_set; typename State::Set new_set = std::make_shared(); - CHECK(new_set != nullptr); for (size_t i = 0; i < arr_size; ++i) { const bool is_null_element = is_column_data_nullable && col_null->is_null_at(offset + i); const T* src_data = is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); - if (is_null_element) { - LOG(INFO) << "src_data is null again ~"; - // src_data.data = nullptr; - } else { - LOG(INFO) << "Inserting value: " << *src_data; - } - - if (src_data == nullptr) { - LOG(INFO) << "src_data==nullptr is true. "; - } if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) { - if (set->find(src_data)) { - LOG(INFO) << "we can find set's value: " << *src_data; - } else { - LOG(INFO) << "src_data is null again "; - } new_set->insert(src_data); - LOG(INFO) << "After inserting value."; } } - - if (set->contain_null()) { - LOG(INFO) << "before swap between set and new_set, the set contains null."; - } - if (new_set->contain_null()) { - LOG(INFO) << "before swap between set and new_set, the new_set contains null."; - } set = std::move(new_set); - - if (set->contain_null()) { - LOG(INFO) << "after swap between set and new_set, the set contains null."; - } } } @@ -255,19 +191,13 @@ class AggregateFunctionGroupArrayIntersect auto& set = data.value; auto& rhs_set = this->data(rhs).value; - LOG(INFO) << "merge: place set size = " << set->size(); - if (this->data(rhs).version == 0) { - LOG(INFO) << "rhs version is 0, skipping merge"; return; } UInt64 version = data.version++; - LOG(INFO) << "merge: version = " << version; - if (version == 0) { set->change_contains_null_value(rhs_set->contain_null()); - LOG(INFO) << "Copying rhs set to place set"; HybridSetBase::IteratorBase* it = rhs_set->begin(); while (it->has_next()) { const void* value = it->get_value(); @@ -278,14 +208,12 @@ class AggregateFunctionGroupArrayIntersect } if (set->size() != 0) { - LOG(INFO) << "Merging place set and rhs set"; auto create_new_set = [](auto& lhs_val, auto& rhs_val) { typename State::Set new_set; HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { const void* value = it->get_value(); - bool found = (rhs_val->find(value)); - if (found) { + if ((rhs_val->find(value))) { new_set->insert(value); } it->next(); @@ -298,29 +226,14 @@ class AggregateFunctionGroupArrayIntersect : create_new_set(set, rhs_set); set = std::move(new_set); } - - LOG(INFO) << "After merge: set = {"; - HybridSetBase::IteratorBase* it = set->begin(); - while (it->has_next()) { - T value = *reinterpret_cast(it->get_value()); - LOG(INFO) << value << " "; - it->next(); - } - LOG(INFO) << "}"; } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { auto& data = this->data(place); auto& set = data.value; auto version = data.version; - CHECK(set != nullptr); - - LOG(INFO) << "serialize: version = " << version << ", set size = " << set->size(); bool is_set_contains_null = set->contain_null(); - if (is_set_contains_null) { - LOG(INFO) << "Before writing of serialize, the set contains null."; - } write_pod_binary(is_set_contains_null, buf); @@ -328,22 +241,8 @@ class AggregateFunctionGroupArrayIntersect write_var_uint(set->size(), buf); HybridSetBase::IteratorBase* it = set->begin(); - if (it == nullptr) { - LOG(INFO) << "Before writing of serialize, the set->begin() is nullptr."; - } - - if (it->has_next()) { - LOG(INFO) << "Before writing of serialize, the it->has_next() is true."; - } - while (it->has_next()) { - if (it->get_value() == nullptr) { - LOG(INFO) << "during writing of serialize, the it->get_value() is nullptr."; - } const T* value_ptr = static_cast(it->get_value()); - // const StringRef* str_ref = reinterpret_cast(value_ptr); - LOG(INFO) << "Serializing element: " << (*value_ptr); - LOG(INFO) << "after writing element... "; write_int_binary((*value_ptr), buf); it->next(); } @@ -354,40 +253,21 @@ class AggregateFunctionGroupArrayIntersect auto& data = this->data(place); bool is_set_contains_null; - LOG(INFO) << "deserialize"; - read_pod_binary(is_set_contains_null, buf); data.value->change_contains_null_value(is_set_contains_null); - read_var_uint(data.version, buf); - LOG(INFO) << "Deserialized version: " << data.version; - // this->data(place).value.read(buf); size_t size; read_var_uint(size, buf); - LOG(INFO) << "Deserialized size: " << size; - LOG(INFO) << "Deserialized set: {"; T element; for (size_t i = 0; i < size; ++i) { read_int_binary(element, buf); - LOG(INFO) << "derializing element: " << element; data.value->insert(static_cast(&element)); } - LOG(INFO) << "}"; - - if (data.value->contain_null()) { - LOG(INFO) << "After reading of deserialize, the set contains null."; - } } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - LOG(INFO) << "in the insert_result_into, name of T: " - << boost::core::demangle(typeid(T).name()); - LOG(INFO) << "in the start of insert, Type name: " << typeid(T).name(); - LOG(INFO) << "the name of to is: " << to.get_name(); - ColumnArray& arr_to = assert_cast(to); - LOG(INFO) << "the name of arr_to is: " << arr_to.get_name(); ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); @@ -398,22 +278,18 @@ class AggregateFunctionGroupArrayIntersect bool is_nullable = to_nested_col.is_nullable(); if (is_nullable) { - LOG(INFO) << "nested_col is nullable. "; auto col_null = reinterpret_cast(&to_nested_col); auto& nested_col = assert_cast(col_null->get_nested_column()); auto& null_map_data = col_null->get_null_map_data(); size_t old_size = nested_col.get_data().size(); - LOG(INFO) << "old_size of data_to: " << old_size; const auto& set = this->data(place).value; - LOG(INFO) << "insert_result_into: set size = " << set->size(); auto res_size = set->size(); size_t i = 0; if (set->contain_null()) { - LOG(INFO) << "We have null, insert it!"; col_null->insert_data(nullptr, 0); res_size += 1; i = 1; @@ -430,21 +306,16 @@ class AggregateFunctionGroupArrayIntersect HybridSetBase::IteratorBase* it = set->begin(); while (it->has_next()) { T value = *reinterpret_cast(it->get_value()); - LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); nested_col.get_data()[old_size + i] = value; null_map_data.push_back(0); it->next(); ++i; } - // } } else { - LOG(INFO) << "nested_col is not nullable. "; auto& nested_col = static_cast(to_nested_col); size_t old_size = nested_col.get_data().size(); - LOG(INFO) << "old_size of data_to: " << old_size; const auto& set = this->data(place).value; - LOG(INFO) << "insert_result_into: set size = " << set->size(); if (offsets_to.size() == 0) { offsets_to.push_back(set->size()); @@ -458,13 +329,11 @@ class AggregateFunctionGroupArrayIntersect HybridSetBase::IteratorBase* it = set->begin(); while (it->has_next()) { T value = *reinterpret_cast(it->get_value()); - LOG(INFO) << "Inserting value: " << value << " at index " << (old_size + i); nested_col.get_data()[old_size + i] = value; it->next(); ++i; } } - LOG(INFO) << "After making value to data_to, leaving..."; } }; @@ -507,13 +376,7 @@ class AggregateFunctionGroupArrayIntersectGeneric String get_name() const override { return "group_array_intersect"; } - // DataTypePtr get_return_type() const override { return input_data_type; } - DataTypePtr get_return_type() const override { - // return std::make_shared(make_nullable(input_data_type)); - LOG(INFO) << "in the get_return_type, name of input_data_type: " - << boost::core::demangle(typeid(input_data_type).name()); - return input_data_type; - } + DataTypePtr get_return_type() const override { return input_data_type; } bool allocates_memory_in_arena() const override { return true; } @@ -522,9 +385,7 @@ class AggregateFunctionGroupArrayIntersectGeneric auto& data = this->data(place); auto& version = data.version; auto& set = data.value; - CHECK(set != nullptr); - LOG(INFO) << "Input row_num: " << row_num; // 输出输入的 row_num const bool col_is_nullable = (*columns[0]).is_nullable(); const ColumnArray& column = col_is_nullable ? assert_cast( @@ -532,77 +393,39 @@ class AggregateFunctionGroupArrayIntersectGeneric .get_nested_column()) : assert_cast(*columns[0]); - const auto data_column = column.get_data_ptr(); + const auto nested_column_data = column.get_data_ptr(); const auto& offsets = column.get_offsets(); const size_t offset = offsets[static_cast(row_num) - 1]; const auto arr_size = offsets[row_num] - offset; - - LOG(INFO) << "the name of column is: " << column.get_name(); - const auto& column_data = column.get_data(); - bool is_column_data_nullable = column_data.is_nullable(); ColumnNullable* col_null = nullptr; - // const ColumnArray* nested_column_data = nullptr; if (is_column_data_nullable) { - LOG(INFO) << "nested_col is nullable. "; auto const_col_data = const_cast(&column_data); col_null = static_cast(const_col_data); - // nested_column_data = &assert_cast(col_null->get_nested_column()); - } else { - LOG(INFO) << "nested_col is not nullable. "; - // nested_column_data = &static_cast(column_data); } - auto nested_column_data = data_column; ++version; if (version == 1) { - LOG(INFO) << "version is 1."; for (size_t i = 0; i < arr_size; ++i) { const bool is_null_element = is_column_data_nullable && col_null->is_null_at(offset + i); StringRef src = StringRef(); if constexpr (is_plain_column) { - LOG(INFO) << "is_plain_column is true"; src = nested_column_data->get_data_at(offset + i); - LOG(INFO) << "we get src"; } else { - LOG(INFO) << "is_plain_column is false"; const char* begin = nullptr; src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); - LOG(INFO) << "we get src"; } src.data = is_null_element ? nullptr : arena->insert(src.data, src.size); - - if (is_null_element) { - LOG(INFO) << "src.data is null!!!!"; - // src_data.data = nullptr; - } else { - LOG(INFO) << "src.data is not null"; - LOG(INFO) << "Inserting value: " << *(src.data); - } - - if (src.data == nullptr) { - LOG(INFO) << "src.data==nullptr is true. "; - } - - // set->insert(src_data.data); set->insert((void*)src.data, src.size); - LOG(INFO) << "After inserting value."; - } - - if (set->contain_null()) { - LOG(INFO) << "in the last of version==1, the set contains null."; } - } else if (set->size() != 0 || set->contain_null()) { - // typename State::Set new_set; typename State::Set new_set = std::make_shared(); - CHECK(new_set != nullptr); for (size_t i = 0; i < arr_size; ++i) { const bool is_null_element = is_column_data_nullable && col_null->is_null_at(offset + i); @@ -615,39 +438,11 @@ class AggregateFunctionGroupArrayIntersectGeneric } src.data = is_null_element ? nullptr : arena->insert(src.data, src.size); - - if (is_null_element) { - LOG(INFO) << "src.data is null again ~"; - // src_data.data = nullptr; - } else { - LOG(INFO) << "Inserting value: " << *(src.data); - } - - if (src.data == nullptr) { - LOG(INFO) << "src_data==nullptr is true. "; - } if (set->find(src.data, src.size) || (set->contain_null() && src.data == nullptr)) { - if (set->find(src.data, src.size)) { - LOG(INFO) << "we can find set's value: " << *(src.data); - } else { - LOG(INFO) << "src_data is null again "; - } new_set->insert((void*)src.data, src.size); - LOG(INFO) << "After inserting value."; } } - - if (set->contain_null()) { - LOG(INFO) << "before swap between set and new_set, the set contains null."; - } - if (new_set->contain_null()) { - LOG(INFO) << "before swap between set and new_set, the new_set contains null."; - } set = std::move(new_set); - - if (set->contain_null()) { - LOG(INFO) << "after swap between set and new_set, the set contains null."; - } } } @@ -657,44 +452,28 @@ class AggregateFunctionGroupArrayIntersectGeneric auto& set = data.value; auto& rhs_set = this->data(rhs).value; - LOG(INFO) << "merge: place set size = " << set->size(); - if (this->data(rhs).version == 0) { - LOG(INFO) << "rhs version is 0, skipping merge"; return; - } else { - LOG(INFO) << "rhs version is: " << this->data(rhs).version; } UInt64 version = data.version++; - LOG(INFO) << "merge: version = " << version; - if (version == 0) { set->change_contains_null_value(rhs_set->contain_null()); - LOG(INFO) << "Copying rhs set to place set"; HybridSetBase::IteratorBase* it = rhs_set->begin(); while (it->has_next()) { const StringRef* value = reinterpret_cast(it->get_value()); - LOG(INFO) << "before inserting..."; - LOG(INFO) << "inserting: " << *(char*)(value->data); set->insert((void*)(value->data), value->size); - LOG(INFO) << "after inserting..."; it->next(); - LOG(INFO) << "after next..."; } } else if (set->size() != 0) { - LOG(INFO) << "Merging place set and rhs set"; auto create_new_set = [](auto& lhs_val, auto& rhs_val) { typename State::Set new_set; HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { const auto* value = reinterpret_cast(it->get_value()); - bool found = (rhs_val->find(value)); - LOG(INFO) << "before inserting..."; - if (found) { + if (rhs_val->find(value)) { new_set->insert((void*)value->data, value->size); } - LOG(INFO) << "after inserting..."; it->next(); } new_set->change_contains_null_value(lhs_val->contain_null() && @@ -705,54 +484,22 @@ class AggregateFunctionGroupArrayIntersectGeneric : create_new_set(set, rhs_set); set = std::move(new_set); } - - // LOG(INFO) << "After merge: set = {"; - // HybridSetBase::IteratorBase* it = set->begin(); - // while (it->has_next()) { - // auto* value = reinterpret_cast(it->get_value()); - // LOG(INFO) << (*value).to_string() << " "; - // it->next(); - // } - // LOG(INFO) << "}"; } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { auto& data = this->data(place); auto& set = data.value; auto version = data.version; - CHECK(set != nullptr); - - LOG(INFO) << "serialize: version = " << version << ", set size = " << set->size(); bool is_set_contains_null = set->contain_null(); - if (is_set_contains_null) { - LOG(INFO) << "Before writing of serialize, the set contains null."; - } write_pod_binary(is_set_contains_null, buf); - write_var_uint(version, buf); write_var_uint(set->size(), buf); - HybridSetBase::IteratorBase* it = set->begin(); - - if (it == nullptr) { - LOG(INFO) << "Before writing of serialize, the set->begin() is nullptr."; - } - - if (it->has_next()) { - LOG(INFO) << "Before writing of serialize, the it->has_next() is true."; - } + HybridSetBase::IteratorBase* it = set->begin(); while (it->has_next()) { - if (it->get_value() == nullptr) { - LOG(INFO) << "during writing of serialize, the it->get_value() is nullptr."; - } else { - LOG(INFO) << "the it->get_value() is not null "; - } const auto* value = reinterpret_cast(it->get_value()); - // const StringRef* str_ref = reinterpret_cast(value_ptr); - LOG(INFO) << "Serializing element: " << (*(value)).to_string(); - LOG(INFO) << "after writing element... "; write_string_binary(*value, buf); it->next(); } @@ -763,52 +510,30 @@ class AggregateFunctionGroupArrayIntersectGeneric auto& data = this->data(place); bool is_set_contains_null; - LOG(INFO) << "deserialize"; - read_pod_binary(is_set_contains_null, buf); data.value->change_contains_null_value(is_set_contains_null); read_var_uint(data.version, buf); - LOG(INFO) << "Deserialized version: " << data.version; - // this->data(place).value.read(buf); size_t size; read_var_uint(size, buf); - LOG(INFO) << "Deserialized size: " << size; - LOG(INFO) << "Deserialized set: {"; StringRef element; for (size_t i = 0; i < size; ++i) { element = read_string_binary_into(*arena, buf); - LOG(INFO) << "derializing element: " << *(element.data); data.value->insert((void*)element.data, element.size); } - LOG(INFO) << "}"; - - if (data.value->contain_null()) { - LOG(INFO) << "After reading of deserialize, the set contains null."; - } } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - LOG(INFO) << "the name of to is: " << to.get_name(); - ColumnArray& arr_to = assert_cast(to); - LOG(INFO) << "the name of arr_to is: " << arr_to.get_name(); - ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); - - LOG(INFO) << "nested_col is nullable. "; auto& data_to = arr_to.get_data(); auto col_null = reinterpret_cast(&data_to); - // auto& null_map_data = col_null->get_null_map_data(); const auto& set = this->data(place).value; - LOG(INFO) << "insert_result_into: set size = " << set->size(); - auto res_size = set->size(); if (set->contain_null()) { - LOG(INFO) << "We have null, insert it!"; col_null->insert_data(nullptr, 0); res_size += 1; } @@ -819,8 +544,6 @@ class AggregateFunctionGroupArrayIntersectGeneric offsets_to.push_back(offsets_to.back() + res_size); } - // data_to.resize(res_size); - HybridSetBase::IteratorBase* it = set->begin(); while (it->has_next()) { const auto* value = reinterpret_cast(it->get_value()); @@ -829,96 +552,9 @@ class AggregateFunctionGroupArrayIntersectGeneric } else { std::ignore = data_to.deserialize_and_insert_from_arena(value->data); } - // null_map_data.push_back(0); it->next(); } - - LOG(INFO) << "res_size size = " << res_size; - LOG(INFO) << "offsets_to size = " << offsets_to.size(); - LOG(INFO) << "set size = " << set->size(); - LOG(INFO) << "data_to size = " << data_to.size(); - - LOG(INFO) << "After making value to data_to, leaving..."; } }; -namespace { - -/// Substitute return type for DateV2 and DateTimeV2 -class AggregateFunctionGroupArrayIntersectDateV2 - : public AggregateFunctionGroupArrayIntersect { -public: - explicit AggregateFunctionGroupArrayIntersectDateV2(const DataTypes& argument_types_) - : AggregateFunctionGroupArrayIntersect( - DataTypes(argument_types_.begin(), argument_types_.end())) {} -}; - -class AggregateFunctionGroupArrayIntersectDateTimeV2 - : public AggregateFunctionGroupArrayIntersect { -public: - explicit AggregateFunctionGroupArrayIntersectDateTimeV2(const DataTypes& argument_types_) - : AggregateFunctionGroupArrayIntersect( - DataTypes(argument_types_.begin(), argument_types_.end())) {} -}; - -IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, - const DataTypes& argument_types) { - WhichDataType which(nested_type); - if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) - return new AggregateFunctionGroupArrayIntersectDateV2(argument_types); - else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) - return new AggregateFunctionGroupArrayIntersectDateTimeV2(argument_types); - else { - /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric - if (nested_type->is_value_unambiguously_represented_in_contiguous_memory_region()) - return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); - else - return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); - } -} - -inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl( - const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { - const auto& nested_type = remove_nullable( - dynamic_cast(*(argument_types[0])).get_nested_type()); - WhichDataType which_type(nested_type); - if (which_type.is_int()) { - LOG(INFO) << "nested_type is int"; - } else { - LOG(INFO) << "nested_type is not int"; - } - DataTypes new_argument_types = {nested_type}; - LOG(INFO) << "in the create_aggregate_function_group_array_intersect_impl, name of " - "remove_nullable(nested_type): " - << boost::core::demangle(typeid(nested_type).name()); - - const auto& argument_type = dynamic_cast(*(argument_types[0])); - LOG(INFO) << "in the create_aggregate_function_group_array_intersect_impl, name of " - "argument_type: " - << boost::core::demangle(typeid(argument_type).name()); - - // AggregateFunctionPtr res( - // creator_with_numeric_type::creator( - // "", new_argument_types, result_is_nullable)); - AggregateFunctionPtr res = nullptr; - WhichDataType which(nested_type); -#define DISPATCH(TYPE) \ - if (which.idx == TypeIndex::TYPE) \ - res = creator_without_type::create>( \ - argument_types, result_is_nullable); - FOR_NUMERIC_TYPES(DISPATCH) -#undef DISPATCH - if (!res) { - res = AggregateFunctionPtr(create_with_extra_types(nested_type, argument_types)); - } - - if (!res) - throw Exception(ErrorCode::INVALID_ARGUMENT, - "Illegal type {} of argument for aggregate function {}", - argument_types[0]->get_name(), name); - - return res; -} -} // namespace - } // namespace doris::vectorized diff --git a/be/src/vec/common/hash_table/hash_table.h b/be/src/vec/common/hash_table/hash_table.h index ede7897ecdee74..04a5ff8f0e4c7b 100644 --- a/be/src/vec/common/hash_table/hash_table.h +++ b/be/src/vec/common/hash_table/hash_table.h @@ -806,8 +806,6 @@ class HashTable : private boost::noncopyable, } } - void reserve(size_t num_elements) { resize(num_elements); } - /// Insert a value. In the case of any more complex values, it is better to use the `emplace` function. std::pair ALWAYS_INLINE insert(const value_type& x) { std::pair res; From b90b7f96d487e5ec64143cae9e64b82c4526a3b0 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Thu, 4 Apr 2024 00:21:09 +0800 Subject: [PATCH 13/21] opt2 --- ...aggregate_function_group_array_intersect.h | 18 +-- .../aggregate/group_array_intersect.groovy | 106 ++++++++++++++++++ 2 files changed, 109 insertions(+), 15 deletions(-) create mode 100644 regression-test/suites/query_p0/aggregate/group_array_intersect.groovy diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 3f3a33f5fbbffd..f41b8f9301a99f 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -295,11 +295,7 @@ class AggregateFunctionGroupArrayIntersect i = 1; } - if (offsets_to.size() == 0) { - offsets_to.push_back(res_size); - } else { - offsets_to.push_back(offsets_to.back() + res_size); - } + offsets_to.push_back(offsets_to.back() + res_size); nested_col.get_data().resize(old_size + res_size); @@ -317,11 +313,7 @@ class AggregateFunctionGroupArrayIntersect const auto& set = this->data(place).value; - if (offsets_to.size() == 0) { - offsets_to.push_back(set->size()); - } else { - offsets_to.push_back(offsets_to.back() + set->size()); - } + offsets_to.push_back(offsets_to.back() + set->size()); nested_col.get_data().resize(old_size + set->size()); @@ -538,11 +530,7 @@ class AggregateFunctionGroupArrayIntersectGeneric res_size += 1; } - if (offsets_to.size() == 0) { - offsets_to.push_back(res_size); - } else { - offsets_to.push_back(offsets_to.back() + res_size); - } + offsets_to.push_back(offsets_to.back() + res_size); HybridSetBase::IteratorBase* it = set->begin(); while (it->has_next()) { diff --git a/regression-test/suites/query_p0/aggregate/group_array_intersect.groovy b/regression-test/suites/query_p0/aggregate/group_array_intersect.groovy new file mode 100644 index 00000000000000..a802d42a66aca3 --- /dev/null +++ b/regression-test/suites/query_p0/aggregate/group_array_intersect.groovy @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("group_array_intersect") { + sql "DROP TABLE IF EXISTS `group_array_intersect_test`;" + sql """ + CREATE TABLE `group_array_intersect_test` ( + `id` int(11) NULL COMMENT "" + , `c_array_int` ARRAY NULL COMMENT "" + , `c_array_datetimev2` ARRAY NULL COMMENT "" + , `c_array_float` ARRAY NULL COMMENT "" + , `c_array_datev2` ARRAY NULL COMMENT "" + , `c_array_string` ARRAY NULL COMMENT "" + , `c_array_bigint` ARRAY NULL COMMENT "" + , `c_array_decimal` ARRAY NULL COMMENT "" + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + COMMENT "OLAP" + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "in_memory" = "false", + "storage_format" = "V2" + ); + """ + + sql """INSERT INTO `group_array_intersect_test`(id, c_array_int) VALUES (0, [0]),(1, [1,2,3,4,5]), (2, [6,7,8]), (3, []), (4, null), (5, [6, 7]), (6, [NULL]);""" + sql """INSERT INTO `group_array_intersect_test`(id, c_array_int) VALUES (12, [12, null, 13]), (13, [null, null]), (14, [12, 13]);""" + sql """INSERT INTO `group_array_intersect_test`(id, c_array_float) VALUES (7, [6.3, 7.3]), (8, [7.3, 8.3]), (9, [7.3, 9.3, 8.3]);""" + sql """INSERT INTO `group_array_intersect_test`(id, c_array_datetimev2) VALUES (10, ['2024-03-23 00:00:00', '2024-03-24 00:00:00']), (11, ['2024-03-24 00:00:00', '2024-03-25 00:00:00']);""" + sql """INSERT INTO `group_array_intersect_test`(id, c_array_datev2) VALUES (15, ['2024-05-23', '2024-03-29']), (16, ['2024-03-29', '2024-03-25']), (17, ['2024-05-23', null]);""" + sql """INSERT INTO `group_array_intersect_test`(id, c_array_string) VALUES (18, ['a', 'b', 'c', 'd', 'e', 'f']), (19, ['a', 'aa', 'b', 'bb', 'c', 'cc', 'd', 'dd', 'f', 'ff']);""" + sql """INSERT INTO `group_array_intersect_test`(id, c_array_string) VALUES (20, ['a', null]), (21, [null, null]), (22, ['x', 'y']);""" + sql """INSERT INTO `group_array_intersect_test`(id, c_array_bigint) VALUES (23, [1234567890123456]), (24, [1234567890123456, 2333333333333333]);""" + sql """INSERT INTO `group_array_intersect_test`(id, c_array_decimal) VALUES (25, [1.34,2.00188888888888888]), (26, [1.34,2.00123344444455555]);""" + + qt_int_1 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (6, 12);""" + qt_int_2 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (14, 12);""" + qt_int_3 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (0, 6);""" + qt_int_4 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (13);""" + qt_int_5 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (2, 5);""" + qt_int_6 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (6, 13);""" + qt_int_7 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (12);""" + qt_int_8 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (6, 7);""" + qt_int_9 """select group_array_intersect(c_array_int) from group_array_intersect_test where id in (9, 12);""" + qt_float_1 """select group_array_intersect(c_array_float) from group_array_intersect_test where id = 7;""" + qt_float_2 """select group_array_intersect(c_array_float) from group_array_intersect_test where id between 7 and 8;""" + qt_float_3 """select group_array_intersect(c_array_float) from group_array_intersect_test where id in (7, 9);""" + qt_datetimev2_1 """select group_array_intersect(c_array_datetimev2) from group_array_intersect_test;""" + qt_datetimev2_2 """select group_array_intersect(c_array_datetimev2) from group_array_intersect_test where id in (10, 11);""" + qt_datev2_1 """select group_array_intersect(c_array_datev2) from group_array_intersect_test where id in (15, 16);""" + qt_datev2_2 """select group_array_intersect(c_array_datev2) from group_array_intersect_test where id in (15, 17);""" + qt_string_1 """select group_array_intersect(c_array_string) from group_array_intersect_test where id in (17, 20);""" + qt_string_2 """select group_array_intersect(c_array_string) from group_array_intersect_test where id in (18, 20);""" + qt_bigint """select group_array_intersect(c_array_bigint) from group_array_intersect_test where id in (23, 24);""" + qt_decimal """select group_array_intersect(c_array_decimal) from group_array_intersect_test where id in (25, 26);""" + qt_groupby_1 """select id, group_array_intersect(c_array_int) from group_array_intersect_test where id <= 1 group by id order by id;""" + qt_groupby_2 """select id, group_array_intersect(c_array_string) from group_array_intersect_test where c_array_string is not null group by id order by id;""" + qt_groupby_3 """select id, group_array_intersect(c_array_string) from group_array_intersect_test where id = 18 group by id order by id;""" + + + sql "DROP TABLE IF EXISTS `group_array_intersect_test_not_null`;" + sql """ + CREATE TABLE `group_array_intersect_test_not_null` ( + `id` int(11) NULL COMMENT "" + , `c_array_int` ARRAY NOT NULL COMMENT "" + , `c_array_float` ARRAY NOT NULL COMMENT "" + , `c_array_string` ARRAY NOT NULL COMMENT "" + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + COMMENT "OLAP" + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "in_memory" = "false", + "storage_format" = "V2" + ); + """ + + sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (1, [1, 2, 3, 4, 5], [1.1, 2.2, 3.3, 4.4, 5.5], ['a', 'b', 'c', 'd', 'e', 'f']);""" + sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (2, [6, 7, 8], [6.6, 7.7, 8.8], ['a', 'aa', 'b', 'bb', 'c', 'cc', 'd', 'dd', 'f', 'ff'])""" + sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (3, [], [], []);""" + sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (4, [6, 7], [6.6, 7.7], ['a']);""" + sql """INSERT INTO `group_array_intersect_test_not_null`(id, c_array_int, c_array_float, c_array_string) VALUES (5, [null], [null], ['x', 'y']);""" + + qt_notnull_1 """select group_array_intersect(c_array_float) from group_array_intersect_test_not_null where array_size(c_array_float) between 1 and 2;""" + qt_notnull_2 """select group_array_intersect(c_array_int), group_array_intersect(c_array_float) from group_array_intersect_test_not_null where id between 2 and 3;""" + qt_notnull_3 """select group_array_intersect(c_array_float) from group_array_intersect_test_not_null where array_size(c_array_float) between 2 and 3;""" + qt_notnull_4 """select group_array_intersect(c_array_string) from group_array_intersect_test_not_null where id between 1 and 2;""" + qt_notnull_5 """select group_array_intersect(c_array_int), group_array_intersect(c_array_float), group_array_intersect(c_array_string) from group_array_intersect_test_not_null where id between 3 and 4;""" + qt_notnull_6 """select group_array_intersect(c_array_string) from group_array_intersect_test_not_null where id between 1 and 5;""" +} From 7f52c916f1aa2c834e16369b93d68ece327abf5a Mon Sep 17 00:00:00 2001 From: chesterxu Date: Thu, 4 Apr 2024 16:18:26 +0800 Subject: [PATCH 14/21] opt3 --- ...aggregate_function_group_array_intersect.h | 5 +- .../aggregate/group_array_intersect.out | 93 +++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 regression-test/data/query_p0/aggregate/group_array_intersect.out diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index f41b8f9301a99f..2e1fae4cbf9305 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -209,7 +209,8 @@ class AggregateFunctionGroupArrayIntersect if (set->size() != 0) { auto create_new_set = [](auto& lhs_val, auto& rhs_val) { - typename State::Set new_set; + typename State::Set new_set = + std::make_shared(); HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { const void* value = it->get_value(); @@ -459,7 +460,7 @@ class AggregateFunctionGroupArrayIntersectGeneric } } else if (set->size() != 0) { auto create_new_set = [](auto& lhs_val, auto& rhs_val) { - typename State::Set new_set; + typename State::Set new_set = std::make_shared(); HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { const auto* value = reinterpret_cast(it->get_value()); diff --git a/regression-test/data/query_p0/aggregate/group_array_intersect.out b/regression-test/data/query_p0/aggregate/group_array_intersect.out new file mode 100644 index 00000000000000..e9e0efec5dce21 --- /dev/null +++ b/regression-test/data/query_p0/aggregate/group_array_intersect.out @@ -0,0 +1,93 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !int_1 -- +[null] + +-- !int_2 -- +[13, 12] + +-- !int_3 -- +[] + +-- !int_4 -- +[null] + +-- !int_5 -- +[7, 6] + +-- !int_6 -- +[null] + +-- !int_7 -- +[null, 13, 12] + +-- !int_8 -- +[] + +-- !int_9 -- +[] + +-- !float_1 -- +[6.3, 7.3] + +-- !float_2 -- +[7.3] + +-- !float_3 -- +[7.3] + +-- !datetimev2_1 -- +[] + +-- !datetimev2_2 -- +["2024-03-24 00:00:00.000"] + +-- !datev2_1 -- +["2024-03-29"] + +-- !datev2_2 -- +["2024-05-23"] + +-- !string_1 -- +[] + +-- !string_2 -- +["a"] + +-- !bigint -- +[1234567890123456] + +-- !decimal -- +[1.34000] + +-- !groupby_1 -- +0 [0] +1 [4, 1, 5, 2, 3] + +-- !groupby_2 -- +18 ["c", "e", "b", "d", "a", "f"] +19 ["c", "ff", "cc", "bb", "f", "aa", "dd", "b", "d", "a"] +20 [null, "a"] +21 [null] +22 ["x", "y"] + +-- !groupby_3 -- +18 ["c", "e", "b", "d", "a", "f"] + +-- !notnull_1 -- +[] + +-- !notnull_2 -- +[] [] + +-- !notnull_3 -- +[7.7, 6.6] + +-- !notnull_4 -- +["c", "b", "d", "a", "f"] + +-- !notnull_5 -- +[] [] [] + +-- !notnull_6 -- +[] + From a50f50d27c6ba1793d72fff38f74e2eaa43fa644 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Thu, 4 Apr 2024 17:03:07 +0800 Subject: [PATCH 15/21] add tests for nereids --- .../nereids_function_p0/agg_function/agg.out | 142 ++++++++++++++++++ .../agg_function/agg.groovy | 66 ++++++++ 2 files changed, 208 insertions(+) diff --git a/regression-test/data/nereids_function_p0/agg_function/agg.out b/regression-test/data/nereids_function_p0/agg_function/agg.out index bfd27bf07d95ca..10337f9f420ab9 100644 --- a/regression-test/data/nereids_function_p0/agg_function/agg.out +++ b/regression-test/data/nereids_function_p0/agg_function/agg.out @@ -1410,6 +1410,148 @@ -- !sql_count_AnyData_agg_phase_4_notnull -- 12 12 +-- !sql_group_array_intersect_array_bool -- +[0] + +-- !sql_group_array_intersect_array_tinyint -- +[1] + +-- !sql_group_array_intersect_array_smallint -- +[] + +-- !sql_group_array_intersect_array_int -- +[] +[1] +[2] +[3] +[4] +[5] +[6] +[7] +[8] +[9] +[10] +[11] +[12] + +-- !sql_group_array_intersect_array_bigint -- +[] + +-- !sql_group_array_intersect_array_largeint -- +[8] + +-- !sql_group_array_intersect_array_float -- +[5] + +-- !sql_group_array_intersect_array_double -- +[0.2] + +-- !sql_group_array_intersect_array_date -- +["2012-03-03"] + +-- !sql_group_array_intersect_array_datetime -- +["2012-03-04 04:03:04"] + +-- !sql_group_array_intersect_array_datev2 -- +["2012-03-06"] + +-- !sql_group_array_intersect_array_datetimev2 -- +["2012-03-09 09:08:09.000000"] + +-- !sql_group_array_intersect_array_char -- +["char21", "char11", "char31"] + +-- !sql_group_array_intersect_array_varchar -- +["varchar11", "char11", "varchar31", "char31", "varchar21", "char21"] + +-- !sql_group_array_intersect_array_string -- +["varchar11", "string1", "varchar31", "char31", "varchar21", "char21"] + +-- !sql_group_array_intersect_array_decimal -- +[] +[0.100000000] +[0.200000000] +[0.300000000] +[0.400000000] +[0.500000000] +[0.600000000] +[0.700000000] +[0.800000000] +[0.900000000] +[1.000000000] +[1.100000000] +[1.200000000] + +-- !sql_group_array_intersect_array_bool_notnull -- +[0] + +-- !sql_group_array_intersect_array_tinyint_notnull -- +[1] + +-- !sql_group_array_intersect_array_smallint_notnull -- +[] + +-- !sql_group_array_intersect_array_int_notnull -- +[1] +[2] +[3] +[4] +[5] +[6] +[7] +[8] +[9] +[10] +[11] +[12] + +-- !sql_group_array_intersect_array_bigint_notnull -- +[] + +-- !sql_group_array_intersect_array_largeint_notnull -- +[8] + +-- !sql_group_array_intersect_array_float_notnull -- +[5] + +-- !sql_group_array_intersect_array_double_notnull -- +[0.2] + +-- !sql_group_array_intersect_array_date_notnull -- +["2012-03-03"] + +-- !sql_group_array_intersect_array_datetime_notnull -- +["2012-03-04 04:03:04"] + +-- !sql_group_array_intersect_array_datev2_notnull -- +["2012-03-06"] + +-- !sql_group_array_intersect_array_datetimev2_notnull -- +["2012-03-09 09:08:09.000000"] + +-- !sql_group_array_intersect_array_char_notnull -- +["char21", "char11", "char31"] + +-- !sql_group_array_intersect_array_varchar_notnull -- +["varchar11", "char11", "varchar31", "char31", "varchar21", "char21"] + +-- !sql_group_array_intersect_array_string_notnull -- +["varchar11", "string1", "varchar31", "char31", "varchar21", "char21"] + +-- !sql_group_array_intersect_array_decimal_notnull -- +[0.100000000] +[0.200000000] +[0.300000000] +[0.400000000] +[0.500000000] +[0.600000000] +[0.700000000] +[0.800000000] +[0.900000000] +[1.000000000] +[1.100000000] +[1.200000000] + -- !sql_group_bit_and_TinyInt_gb -- \N 0 diff --git a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy index 81c84ad32d8f7e..1caf47cdacdb93 100644 --- a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy +++ b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy @@ -589,6 +589,72 @@ suite("nereids_agg_fn") { select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id, kint), count(kint) from fn_test group by kbool order by kbool''' qt_sql_count_AnyData_agg_phase_4_notnull ''' select /*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), count(kint) from fn_test''' + + qt_sql_group_array_intersect_array_bool ''' + select group_array_intersect(kabool) from fn_test where id<= 2;''' + qt_sql_group_array_intersect_array_tinyint ''' + select group_array_intersect(katint) from fn_test where id between 7 and 10;''' + qt_sql_group_array_intersect_array_smallint ''' + select group_array_intersect(kasint) from fn_test;''' + qt_sql_group_array_intersect_array_int ''' + select group_array_intersect(kaint) from fn_test group by id order by id;''' + qt_sql_group_array_intersect_array_bigint ''' + select group_array_intersect(kabint) from fn_test where id between 7 and 10;''' + qt_sql_group_array_intersect_array_largeint ''' + select group_array_intersect(kalint) from fn_test where id = 7;''' + qt_sql_group_array_intersect_array_float ''' + select group_array_intersect(kafloat) from fn_test where id = 4;''' + qt_sql_group_array_intersect_array_double ''' + select group_array_intersect(kadbl) from fn_test where id = 1;''' + qt_sql_group_array_intersect_array_date ''' + select group_array_intersect(kadt) from fn_test where id = 2;''' + qt_sql_group_array_intersect_array_datetime ''' + select group_array_intersect(kadtm) from fn_test where id = 3;''' + qt_sql_group_array_intersect_array_datev2 ''' + select group_array_intersect(kadtv2) from fn_test where id = 5;''' + qt_sql_group_array_intersect_array_datetimev2 ''' + select group_array_intersect(kadtmv2) from fn_test where id = 8;''' + qt_sql_group_array_intersect_array_char ''' + select group_array_intersect(kachr) from fn_test where id in (0, 3);''' + qt_sql_group_array_intersect_array_varchar ''' + select group_array_intersect(kavchr) from fn_test where id = 6;''' + qt_sql_group_array_intersect_array_string ''' + select group_array_intersect(kastr) from fn_test where id in (6, 9);''' + qt_sql_group_array_intersect_array_decimal ''' + select group_array_intersect(kadcml) from fn_test group by id order by id;''' + qt_sql_group_array_intersect_array_bool_notnull ''' + select group_array_intersect(kabool) from fn_test_not_nullable where id<= 2;''' + qt_sql_group_array_intersect_array_tinyint_notnull ''' + select group_array_intersect(katint) from fn_test_not_nullable where id between 7 and 10;''' + qt_sql_group_array_intersect_array_smallint_notnull ''' + select group_array_intersect(kasint) from fn_test_not_nullable;''' + qt_sql_group_array_intersect_array_int_notnull ''' + select group_array_intersect(kaint) from fn_test_not_nullable group by id order by id;''' + qt_sql_group_array_intersect_array_bigint_notnull ''' + select group_array_intersect(kabint) from fn_test_not_nullable where id between 7 and 10;''' + qt_sql_group_array_intersect_array_largeint_notnull ''' + select group_array_intersect(kalint) from fn_test_not_nullable where id = 7;''' + qt_sql_group_array_intersect_array_float_notnull ''' + select group_array_intersect(kafloat) from fn_test_not_nullable where id = 4;''' + qt_sql_group_array_intersect_array_double_notnull ''' + select group_array_intersect(kadbl) from fn_test_not_nullable where id = 1;''' + qt_sql_group_array_intersect_array_date_notnull ''' + select group_array_intersect(kadt) from fn_test_not_nullable where id = 2;''' + qt_sql_group_array_intersect_array_datetime_notnull ''' + select group_array_intersect(kadtm) from fn_test_not_nullable where id = 3;''' + qt_sql_group_array_intersect_array_datev2_notnull ''' + select group_array_intersect(kadtv2) from fn_test_not_nullable where id = 5;''' + qt_sql_group_array_intersect_array_datetimev2_notnull ''' + select group_array_intersect(kadtmv2) from fn_test_not_nullable where id = 8;''' + qt_sql_group_array_intersect_array_char_notnull ''' + select group_array_intersect(kachr) from fn_test_not_nullable where id in (0, 3);''' + qt_sql_group_array_intersect_array_varchar_notnull ''' + select group_array_intersect(kavchr) from fn_test_not_nullable where id = 6;''' + qt_sql_group_array_intersect_array_string_notnull ''' + select group_array_intersect(kastr) from fn_test_not_nullable where id in (6, 9);''' + qt_sql_group_array_intersect_array_decimal_notnull ''' + select group_array_intersect(kadcml) from fn_test_not_nullable group by id order by id;''' + qt_sql_group_bit_and_TinyInt_gb ''' select group_bit_and(ktint) from fn_test group by kbool order by kbool''' qt_sql_group_bit_and_TinyInt ''' From 96f8c80076cf2bc78ce08ae0ba8f2fa2299d2cab Mon Sep 17 00:00:00 2001 From: chesterxu Date: Thu, 4 Apr 2024 18:41:26 +0800 Subject: [PATCH 16/21] opt4 --- ...aggregate_function_group_array_intersect.h | 90 ++++++++----------- 1 file changed, 36 insertions(+), 54 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 2e1fae4cbf9305..c322eba6d269d4 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -269,7 +269,6 @@ class AggregateFunctionGroupArrayIntersect void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { ColumnArray& arr_to = assert_cast(to); - ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); auto& to_nested_col = arr_to.get_data(); @@ -278,54 +277,42 @@ class AggregateFunctionGroupArrayIntersect bool is_nullable = to_nested_col.is_nullable(); - if (is_nullable) { - auto col_null = reinterpret_cast(&to_nested_col); - auto& nested_col = assert_cast(col_null->get_nested_column()); - auto& null_map_data = col_null->get_null_map_data(); - + auto insert_values = [](ColVecType& nested_col, auto& set, bool is_nullable = false, + ColumnNullable* col_null = nullptr) { size_t old_size = nested_col.get_data().size(); - - const auto& set = this->data(place).value; - - auto res_size = set->size(); + size_t res_size = set->size(); size_t i = 0; - if (set->contain_null()) { + if (is_nullable && set->contain_null()) { col_null->insert_data(nullptr, 0); res_size += 1; i = 1; } - offsets_to.push_back(offsets_to.back() + res_size); - nested_col.get_data().resize(old_size + res_size); HybridSetBase::IteratorBase* it = set->begin(); while (it->has_next()) { - T value = *reinterpret_cast(it->get_value()); + ElementType value = *reinterpret_cast(it->get_value()); nested_col.get_data()[old_size + i] = value; - null_map_data.push_back(0); + if (is_nullable) { + col_null->get_null_map_data().push_back(0); + } it->next(); ++i; } + }; + + const auto& set = this->data(place).value; + if (is_nullable) { + auto col_null = reinterpret_cast(&to_nested_col); + auto& nested_col = assert_cast(col_null->get_nested_column()); + offsets_to.push_back(offsets_to.back() + set->size() + (set->contain_null() ? 1 : 0)); + insert_values(nested_col, set, true, col_null); } else { auto& nested_col = static_cast(to_nested_col); - size_t old_size = nested_col.get_data().size(); - - const auto& set = this->data(place).value; - offsets_to.push_back(offsets_to.back() + set->size()); - - nested_col.get_data().resize(old_size + set->size()); - - size_t i = 0; - HybridSetBase::IteratorBase* it = set->begin(); - while (it->has_next()) { - T value = *reinterpret_cast(it->get_value()); - nested_col.get_data()[old_size + i] = value; - it->next(); - ++i; - } + insert_values(nested_col, set); } } }; @@ -348,7 +335,7 @@ struct AggregateFunctionGroupArrayIntersectGenericData { }; /** Template parameter with true value should be used for columns that store their elements in memory continuously. - * For such columns GroupArrayIntersect() can be implemented more efficiently (especially for small numeric arrays). + * For such columns group_array_intersect() can be implemented more efficiently (especially for small numeric arrays). */ template class AggregateFunctionGroupArrayIntersectGeneric @@ -399,38 +386,33 @@ class AggregateFunctionGroupArrayIntersectGeneric col_null = static_cast(const_col_data); } + auto process_element = [&](size_t i) { + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); + + StringRef src = StringRef(); + if constexpr (is_plain_column) { + src = nested_column_data->get_data_at(offset + i); + } else { + const char* begin = nullptr; + src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); + } + + src.data = is_null_element ? nullptr : arena->insert(src.data, src.size); + return src; + }; + ++version; if (version == 1) { for (size_t i = 0; i < arr_size; ++i) { - const bool is_null_element = - is_column_data_nullable && col_null->is_null_at(offset + i); - - StringRef src = StringRef(); - if constexpr (is_plain_column) { - src = nested_column_data->get_data_at(offset + i); - } else { - const char* begin = nullptr; - src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); - } - - src.data = is_null_element ? nullptr : arena->insert(src.data, src.size); + StringRef src = process_element(i); set->insert((void*)src.data, src.size); } } else if (set->size() != 0 || set->contain_null()) { typename State::Set new_set = std::make_shared(); for (size_t i = 0; i < arr_size; ++i) { - const bool is_null_element = - is_column_data_nullable && col_null->is_null_at(offset + i); - StringRef src = StringRef(); - if constexpr (is_plain_column) { - src = nested_column_data->get_data_at(offset + i); - } else { - const char* begin = nullptr; - src = nested_column_data->serialize_value_into_arena(offset + i, *arena, begin); - } - - src.data = is_null_element ? nullptr : arena->insert(src.data, src.size); + StringRef src = process_element(i); if (set->find(src.data, src.size) || (set->contain_null() && src.data == nullptr)) { new_set->insert((void*)src.data, src.size); } From 6aa4266164ab4c81109e43d77e2bf2504a377ca8 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Fri, 5 Apr 2024 11:50:30 +0800 Subject: [PATCH 17/21] fmt --- .../aggregate_function_group_array_intersect.cpp | 9 +++++---- .../aggregate_function_group_array_intersect.h | 8 +++----- docs/sidebars.json | 0 3 files changed, 8 insertions(+), 9 deletions(-) delete mode 100644 docs/sidebars.json diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp index c2c9e0fb0742e2..2d5ac8a6a8409e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp @@ -25,11 +25,11 @@ namespace doris::vectorized { IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, const DataTypes& argument_types) { WhichDataType which(nested_type); - if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) + if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) { return new AggregateFunctionGroupArrayIntersect(argument_types); - else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) + } else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) { return new AggregateFunctionGroupArrayIntersect(argument_types); - else { + } else { /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric if (nested_type->is_value_unambiguously_represented_in_contiguous_memory_region()) return new AggregateFunctionGroupArrayIntersectGeneric(argument_types); @@ -56,10 +56,11 @@ inline AggregateFunctionPtr create_aggregate_function_group_array_intersect_impl res = AggregateFunctionPtr(create_with_extra_types(nested_type, argument_types)); } - if (!res) + if (!res) { throw Exception(ErrorCode::INVALID_ARGUMENT, "Illegal type {} of argument for aggregate function {}", argument_types[0]->get_name(), name); + } return res; } diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index c322eba6d269d4..2721276ae84986 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -36,13 +36,11 @@ #include "vec/io/io_helper.h" #include "vec/io/var_int.h" -namespace doris { -namespace vectorized { +namespace doris::vectorized { class Arena; class BufferReadable; class BufferWritable; -} // namespace vectorized -} // namespace doris +} // namespace doris::vectorized namespace doris::vectorized { @@ -436,7 +434,7 @@ class AggregateFunctionGroupArrayIntersectGeneric set->change_contains_null_value(rhs_set->contain_null()); HybridSetBase::IteratorBase* it = rhs_set->begin(); while (it->has_next()) { - const StringRef* value = reinterpret_cast(it->get_value()); + const auto* value = reinterpret_cast(it->get_value()); set->insert((void*)(value->data), value->size); it->next(); } diff --git a/docs/sidebars.json b/docs/sidebars.json deleted file mode 100644 index e69de29bb2d1d6..00000000000000 From efa5d25102705565ed81d0f5695a53fdb05c7ff8 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Fri, 5 Apr 2024 21:40:00 +0800 Subject: [PATCH 18/21] fix distinct --- .../trees/expressions/functions/agg/GroupArrayIntersect.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java index c6b1bc96a6d0c7..695a23d0a01396 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java @@ -54,7 +54,7 @@ public GroupArrayIntersect(Expression arg) { * constructor with 1 argument. */ public GroupArrayIntersect(boolean distinct, Expression arg) { - super("group_array_intersect", distinct, arg); + super("group_array_intersect", false, arg); } /** From 7a7f6d2181c51d838df4758d9477ebddd16e31ed Mon Sep 17 00:00:00 2001 From: chesterxu Date: Mon, 8 Apr 2024 19:10:53 +0800 Subject: [PATCH 19/21] opt, fix --- ...aggregate_function_group_array_intersect.h | 83 +++++++++---------- .../functions/agg/GroupArrayIntersect.java | 6 +- 2 files changed, 42 insertions(+), 47 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index 2721276ae84986..fa12b01f132c82 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -84,13 +84,13 @@ class NullableNumericOrDateSet template struct AggregateFunctionGroupArrayIntersectData { using NullableNumericOrDateSetType = NullableNumericOrDateSet; - using Set = std::shared_ptr; + using Set = std::unique_ptr; AggregateFunctionGroupArrayIntersectData() - : value(std::make_shared()) {} + : value(std::make_unique()) {} Set value; - UInt64 version = 0; + bool init = false; }; /// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. @@ -125,7 +125,7 @@ class AggregateFunctionGroupArrayIntersect void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { auto& data = this->data(place); - auto& version = data.version; + auto& init = data.init; auto& set = data.value; const bool col_is_nullable = (*columns[0]).is_nullable(); @@ -143,8 +143,8 @@ class AggregateFunctionGroupArrayIntersect using ColVecType = ColumnVector; const auto& column_data = column.get_data(); - bool is_column_data_nullable = column_data.is_nullable(); - ColumnNullable* col_null = nullptr; + const bool is_column_data_nullable = column_data.is_nullable(); + const ColumnNullable* col_null = nullptr; const ColVecType* nested_column_data = nullptr; if (is_column_data_nullable) { @@ -155,8 +155,7 @@ class AggregateFunctionGroupArrayIntersect nested_column_data = &static_cast(column_data); } - ++version; - if (version == 1) { + if (!init) { for (size_t i = 0; i < arr_size; ++i) { const bool is_null_element = is_column_data_nullable && col_null->is_null_at(offset + i); @@ -165,9 +164,10 @@ class AggregateFunctionGroupArrayIntersect set->insert(src_data); } + init = true; } else if (set->size() != 0 || set->contain_null()) { typename State::Set new_set = - std::make_shared(); + std::make_unique(); for (size_t i = 0; i < arr_size; ++i) { const bool is_null_element = @@ -189,12 +189,12 @@ class AggregateFunctionGroupArrayIntersect auto& set = data.value; auto& rhs_set = this->data(rhs).value; - if (this->data(rhs).version == 0) { + if (!this->data(rhs).init) { return; } - UInt64 version = data.version++; - if (version == 0) { + auto& init = data.init; + if (!init) { set->change_contains_null_value(rhs_set->contain_null()); HybridSetBase::IteratorBase* it = rhs_set->begin(); while (it->has_next()) { @@ -202,13 +202,14 @@ class AggregateFunctionGroupArrayIntersect set->insert(value); it->next(); } + init = true; return; } if (set->size() != 0) { auto create_new_set = [](auto& lhs_val, auto& rhs_val) { typename State::Set new_set = - std::make_shared(); + std::make_unique(); HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { const void* value = it->get_value(); @@ -230,13 +231,11 @@ class AggregateFunctionGroupArrayIntersect void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { auto& data = this->data(place); auto& set = data.value; - auto version = data.version; - - bool is_set_contains_null = set->contain_null(); + auto& init = data.init; + const bool is_set_contains_null = set->contain_null(); write_pod_binary(is_set_contains_null, buf); - - write_var_uint(version, buf); + write_pod_binary(init, buf); write_var_uint(set->size(), buf); HybridSetBase::IteratorBase* it = set->begin(); @@ -254,7 +253,7 @@ class AggregateFunctionGroupArrayIntersect read_pod_binary(is_set_contains_null, buf); data.value->change_contains_null_value(is_set_contains_null); - read_var_uint(data.version, buf); + read_pod_binary(data.init, buf); size_t size; read_var_uint(size, buf); @@ -270,10 +269,9 @@ class AggregateFunctionGroupArrayIntersect ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); auto& to_nested_col = arr_to.get_data(); - using ElementType = T; - using ColVecType = ColumnVector; + using ColVecType = ColumnVector; - bool is_nullable = to_nested_col.is_nullable(); + const bool is_nullable = to_nested_col.is_nullable(); auto insert_values = [](ColVecType& nested_col, auto& set, bool is_nullable = false, ColumnNullable* col_null = nullptr) { @@ -291,7 +289,7 @@ class AggregateFunctionGroupArrayIntersect HybridSetBase::IteratorBase* it = set->begin(); while (it->has_next()) { - ElementType value = *reinterpret_cast(it->get_value()); + const auto value = *reinterpret_cast(it->get_value()); nested_col.get_data()[old_size + i] = value; if (is_nullable) { col_null->get_null_map_data().push_back(0); @@ -324,12 +322,12 @@ class NullableStringSet : public StringValueSet> { }; struct AggregateFunctionGroupArrayIntersectGenericData { - using Set = std::shared_ptr; + using Set = std::unique_ptr; AggregateFunctionGroupArrayIntersectGenericData() - : value(std::make_shared()) {} + : value(std::make_unique()) {} Set value; - UInt64 version = 0; + bool init = false; }; /** Template parameter with true value should be used for columns that store their elements in memory continuously. @@ -361,7 +359,7 @@ class AggregateFunctionGroupArrayIntersectGeneric void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena* arena) const override { auto& data = this->data(place); - auto& version = data.version; + auto& init = data.init; auto& set = data.value; const bool col_is_nullable = (*columns[0]).is_nullable(); @@ -373,10 +371,10 @@ class AggregateFunctionGroupArrayIntersectGeneric const auto nested_column_data = column.get_data_ptr(); const auto& offsets = column.get_offsets(); - const size_t offset = offsets[static_cast(row_num) - 1]; + const auto offset = offsets[row_num - 1]; const auto arr_size = offsets[row_num] - offset; const auto& column_data = column.get_data(); - bool is_column_data_nullable = column_data.is_nullable(); + const bool is_column_data_nullable = column_data.is_nullable(); ColumnNullable* col_null = nullptr; if (is_column_data_nullable) { @@ -400,14 +398,14 @@ class AggregateFunctionGroupArrayIntersectGeneric return src; }; - ++version; - if (version == 1) { + if (!init) { for (size_t i = 0; i < arr_size; ++i) { StringRef src = process_element(i); set->insert((void*)src.data, src.size); } + init = true; } else if (set->size() != 0 || set->contain_null()) { - typename State::Set new_set = std::make_shared(); + typename State::Set new_set = std::make_unique(); for (size_t i = 0; i < arr_size; ++i) { StringRef src = process_element(i); @@ -425,12 +423,12 @@ class AggregateFunctionGroupArrayIntersectGeneric auto& set = data.value; auto& rhs_set = this->data(rhs).value; - if (this->data(rhs).version == 0) { + if (!this->data(rhs).init) { return; } - UInt64 version = data.version++; - if (version == 0) { + auto& init = data.init; + if (!init) { set->change_contains_null_value(rhs_set->contain_null()); HybridSetBase::IteratorBase* it = rhs_set->begin(); while (it->has_next()) { @@ -438,9 +436,10 @@ class AggregateFunctionGroupArrayIntersectGeneric set->insert((void*)(value->data), value->size); it->next(); } + init = true; } else if (set->size() != 0) { auto create_new_set = [](auto& lhs_val, auto& rhs_val) { - typename State::Set new_set = std::make_shared(); + typename State::Set new_set = std::make_unique(); HybridSetBase::IteratorBase* it = lhs_val->begin(); while (it->has_next()) { const auto* value = reinterpret_cast(it->get_value()); @@ -462,12 +461,11 @@ class AggregateFunctionGroupArrayIntersectGeneric void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { auto& data = this->data(place); auto& set = data.value; - auto version = data.version; - - bool is_set_contains_null = set->contain_null(); + auto& init = data.init; + const bool is_set_contains_null = set->contain_null(); write_pod_binary(is_set_contains_null, buf); - write_var_uint(version, buf); + write_pod_binary(init, buf); write_var_uint(set->size(), buf); HybridSetBase::IteratorBase* it = set->begin(); @@ -485,8 +483,7 @@ class AggregateFunctionGroupArrayIntersectGeneric read_pod_binary(is_set_contains_null, buf); data.value->change_contains_null_value(is_set_contains_null); - - read_var_uint(data.version, buf); + read_pod_binary(data.init, buf); size_t size; read_var_uint(size, buf); @@ -498,7 +495,7 @@ class AggregateFunctionGroupArrayIntersectGeneric } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - ColumnArray& arr_to = assert_cast(to); + auto& arr_to = assert_cast(to); ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); auto& data_to = arr_to.get_data(); auto col_null = reinterpret_cast(&data_to); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java index 695a23d0a01396..3d6216d0d09161 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/GroupArrayIntersect.java @@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.coercion.AnyDataType; -import org.apache.doris.nereids.types.coercion.FollowToAnyDataType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -39,9 +38,8 @@ public class GroupArrayIntersect extends AggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable { public static final List SIGNATURES = ImmutableList.of( - FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0))) - .args(ArrayType.of(new AnyDataType(0))) - ); + FunctionSignature.retArgType(0) + .args(ArrayType.of(new AnyDataType(0)))); /** * constructor with 1 argument. From a25a16111b6bb29674465a6e355c16c9703e7d3e Mon Sep 17 00:00:00 2001 From: chesterxu Date: Fri, 12 Apr 2024 00:19:09 +0800 Subject: [PATCH 20/21] opt --- ...gregate_function_group_array_intersect.cpp | 8 +- ...aggregate_function_group_array_intersect.h | 91 +++++++++---------- 2 files changed, 51 insertions(+), 48 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp index 2d5ac8a6a8409e..b3b9a8b9af47c6 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.cpp @@ -25,9 +25,13 @@ namespace doris::vectorized { IAggregateFunction* create_with_extra_types(const DataTypePtr& nested_type, const DataTypes& argument_types) { WhichDataType which(nested_type); - if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateV2) { + if (which.idx == TypeIndex::Date || which.idx == TypeIndex::DateTime) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "We don't support array or array for " + "group_array_intersect(), please use array or array."); + } else if (which.idx == TypeIndex::DateV2) { return new AggregateFunctionGroupArrayIntersect(argument_types); - } else if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::DateTimeV2) { + } else if (which.idx == TypeIndex::DateTimeV2) { return new AggregateFunctionGroupArrayIntersect(argument_types); } else { /// Check that we can use plain version of AggregateFunctionGroupArrayIntersectGeneric diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index fa12b01f132c82..a45640c50e2804 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -46,7 +46,7 @@ namespace doris::vectorized { /// Only for changing Numeric type or Date(DateTime)V2 type to PrimitiveType so that to inherit HybridSet template -constexpr PrimitiveType TypeToPrimitiveType() { +constexpr PrimitiveType type_to_primitive_type() { if constexpr (std::is_same_v || std::is_same_v) { return TYPE_TINYINT; } else if constexpr (std::is_same_v) { @@ -72,9 +72,9 @@ constexpr PrimitiveType TypeToPrimitiveType() { } template -class NullableNumericOrDateSet - : public HybridSet(), DynamicContainer()>::CppType>> { +class NullableNumericOrDateSet : public HybridSet(), + DynamicContainer()>::CppType>> { public: NullableNumericOrDateSet() { this->_null_aware = true; } @@ -83,6 +83,7 @@ class NullableNumericOrDateSet template struct AggregateFunctionGroupArrayIntersectData { + using ColVecType = ColumnVector; using NullableNumericOrDateSetType = NullableNumericOrDateSet; using Set = std::unique_ptr; @@ -91,9 +92,36 @@ struct AggregateFunctionGroupArrayIntersectData { Set value; bool init = false; + + void process_col_data(const ColumnNullable* col_null, const ColVecType* nested_column_data, + size_t offset, size_t arr_size, bool init, Set& set) { + if (!init) { + for (size_t i = 0; i < arr_size; ++i) { + const bool is_null_element = col_null && col_null->is_null_at(offset + i); + const T* src_data = + is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); + + set->insert(src_data); + } + init = true; + } else if (set->size() != 0 || set->contain_null()) { + Set new_set = std::make_unique(); + + for (size_t i = 0; i < arr_size; ++i) { + const bool is_null_element = col_null && col_null->is_null_at(offset + i); + const T* src_data = + is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); + + if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) { + new_set->insert(src_data); + } + } + set = std::move(new_set); + } + } }; -/// Puts all values to the hash set. Returns an array of unique values. Implemented for numeric types. +/// Puts all values to the hybrid set. Returns an array of unique values. Implemented for numeric/date types. template class AggregateFunctionGroupArrayIntersect : public IAggregateFunctionDataHelper, @@ -137,50 +165,23 @@ class AggregateFunctionGroupArrayIntersect const auto data_column = column.get_data_ptr(); const auto& offsets = column.get_offsets(); - const size_t offset = offsets[static_cast(row_num) - 1]; + const auto offset = offsets[row_num - 1]; const auto arr_size = offsets[row_num] - offset; - - using ColVecType = ColumnVector; const auto& column_data = column.get_data(); - const bool is_column_data_nullable = column_data.is_nullable(); const ColumnNullable* col_null = nullptr; - const ColVecType* nested_column_data = nullptr; + const typename State::ColVecType* nested_column_data = nullptr; - if (is_column_data_nullable) { + if (column_data.is_nullable()) { auto const_col_data = const_cast(&column_data); col_null = static_cast(const_col_data); - nested_column_data = &assert_cast(col_null->get_nested_column()); + nested_column_data = + &assert_cast(col_null->get_nested_column()); } else { - nested_column_data = &static_cast(column_data); + nested_column_data = &static_cast(column_data); } - if (!init) { - for (size_t i = 0; i < arr_size; ++i) { - const bool is_null_element = - is_column_data_nullable && col_null->is_null_at(offset + i); - const T* src_data = - is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); - - set->insert(src_data); - } - init = true; - } else if (set->size() != 0 || set->contain_null()) { - typename State::Set new_set = - std::make_unique(); - - for (size_t i = 0; i < arr_size; ++i) { - const bool is_null_element = - is_column_data_nullable && col_null->is_null_at(offset + i); - const T* src_data = - is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); - - if (set->find(src_data) || (set->contain_null() && src_data == nullptr)) { - new_set->insert(src_data); - } - } - set = std::move(new_set); - } + data.process_col_data(col_null, nested_column_data, offset, arr_size, init, set); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, @@ -267,14 +268,11 @@ class AggregateFunctionGroupArrayIntersect void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { ColumnArray& arr_to = assert_cast(to); ColumnArray::Offsets64& offsets_to = arr_to.get_offsets(); - auto& to_nested_col = arr_to.get_data(); - using ColVecType = ColumnVector; - const bool is_nullable = to_nested_col.is_nullable(); - auto insert_values = [](ColVecType& nested_col, auto& set, bool is_nullable = false, - ColumnNullable* col_null = nullptr) { + auto insert_values = [](typename State::ColVecType& nested_col, auto& set, + bool is_nullable = false, ColumnNullable* col_null = nullptr) { size_t old_size = nested_col.get_data().size(); size_t res_size = set->size(); size_t i = 0; @@ -302,11 +300,12 @@ class AggregateFunctionGroupArrayIntersect const auto& set = this->data(place).value; if (is_nullable) { auto col_null = reinterpret_cast(&to_nested_col); - auto& nested_col = assert_cast(col_null->get_nested_column()); + auto& nested_col = + assert_cast(col_null->get_nested_column()); offsets_to.push_back(offsets_to.back() + set->size() + (set->contain_null() ? 1 : 0)); insert_values(nested_col, set, true, col_null); } else { - auto& nested_col = static_cast(to_nested_col); + auto& nested_col = static_cast(to_nested_col); offsets_to.push_back(offsets_to.back() + set->size()); insert_values(nested_col, set); } From 78e6ecd485ea94025feb7c77edb97dfc80a883e8 Mon Sep 17 00:00:00 2001 From: chesterxu Date: Fri, 12 Apr 2024 19:36:00 +0800 Subject: [PATCH 21/21] fix --- ...aggregate_function_group_array_intersect.h | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h index a45640c50e2804..03c1639c45aa09 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_array_intersect.h @@ -93,11 +93,24 @@ struct AggregateFunctionGroupArrayIntersectData { Set value; bool init = false; - void process_col_data(const ColumnNullable* col_null, const ColVecType* nested_column_data, - size_t offset, size_t arr_size, bool init, Set& set) { + void process_col_data(auto& column_data, size_t offset, size_t arr_size, bool& init, Set& set) { + const bool is_column_data_nullable = column_data.is_nullable(); + + const ColumnNullable* col_null = nullptr; + const ColVecType* nested_column_data = nullptr; + + if (is_column_data_nullable) { + auto* const_col_data = const_cast(&column_data); + col_null = static_cast(const_col_data); + nested_column_data = &assert_cast(col_null->get_nested_column()); + } else { + nested_column_data = &static_cast(column_data); + } + if (!init) { for (size_t i = 0; i < arr_size; ++i) { - const bool is_null_element = col_null && col_null->is_null_at(offset + i); + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); const T* src_data = is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); @@ -108,7 +121,8 @@ struct AggregateFunctionGroupArrayIntersectData { Set new_set = std::make_unique(); for (size_t i = 0; i < arr_size; ++i) { - const bool is_null_element = col_null && col_null->is_null_at(offset + i); + const bool is_null_element = + is_column_data_nullable && col_null->is_null_at(offset + i); const T* src_data = is_null_element ? nullptr : &(nested_column_data->get_element(offset + i)); @@ -163,25 +177,12 @@ class AggregateFunctionGroupArrayIntersect .get_nested_column()) : assert_cast(*columns[0]); - const auto data_column = column.get_data_ptr(); const auto& offsets = column.get_offsets(); const auto offset = offsets[row_num - 1]; const auto arr_size = offsets[row_num] - offset; const auto& column_data = column.get_data(); - const ColumnNullable* col_null = nullptr; - const typename State::ColVecType* nested_column_data = nullptr; - - if (column_data.is_nullable()) { - auto const_col_data = const_cast(&column_data); - col_null = static_cast(const_col_data); - nested_column_data = - &assert_cast(col_null->get_nested_column()); - } else { - nested_column_data = &static_cast(column_data); - } - - data.process_col_data(col_null, nested_column_data, offset, arr_size, init, set); + data.process_col_data(column_data, offset, arr_size, init, set); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,