From f864ddb27fe92f0799fa775e6f5cec77f129c0fe Mon Sep 17 00:00:00 2001 From: yangsiyu Date: Tue, 25 Nov 2025 10:59:47 +0800 Subject: [PATCH] [feature](inverted index) Implement es-like boolean query --- .../query_v2/all_query/all_query.h | 112 +++ .../boolean_query/boolean_query_builder.h | 84 +++ .../query_v2/boolean_query/occur.h | 24 + .../boolean_query/occur_boolean_query.h | 75 ++ .../boolean_query/occur_boolean_weight.cpp | 287 +++++++ .../boolean_query/occur_boolean_weight.h | 88 +++ .../query_v2/{ => boolean_query}/operator.h | 0 ...olean_query.h => operator_boolean_query.h} | 41 +- ...ean_weight.h => operator_boolean_weight.h} | 12 +- .../query_v2/disjunction_scorer.cpp | 129 ++++ .../query_v2/disjunction_scorer.h | 66 ++ .../inverted_index/query_v2/doc_set.h | 4 + .../query_v2/exclude_scorer.cpp | 86 +++ .../inverted_index/query_v2/exclude_scorer.h | 50 ++ .../inverted_index/query_v2/intersection.cpp | 206 +++-- .../inverted_index/query_v2/intersection.h | 20 +- .../phrase_query/multi_phrase_weight.h | 3 +- .../query_v2/phrase_query/phrase_scorer.cpp | 16 +- .../query_v2/phrase_query/phrase_scorer.h | 7 +- .../query_v2/phrase_query/phrase_weight.h | 6 +- .../query_v2/postings/loaded_postings.cpp | 3 +- .../query_v2/regexp_query/regexp_weight.cpp | 2 +- .../inverted_index/query_v2/reqopt_scorer.h | 80 ++ .../query_v2/segment_postings.h | 195 +++-- .../inverted_index/query_v2/size_hint.h | 48 ++ .../query_v2/term_query/term_scorer.h | 5 +- .../query_v2/term_query/term_weight.h | 23 +- .../query_v2/union/buffered_union.cpp | 259 +++++++ .../query_v2/union/buffered_union.h | 65 ++ .../query_v2/union/simple_union.cpp | 2 +- .../query_v2/union/simple_union.h | 5 + .../inverted_index/query_v2/weight.h | 20 +- .../segment_v2/inverted_index/util/tiny_set.h | 34 +- be/src/vec/functions/function_search.cpp | 29 +- be/src/vec/functions/function_search.h | 2 +- .../boolean_query_builder_test.cpp | 387 ++++++++++ .../query_v2/boolean_query_test.cpp | 31 +- 
.../query_v2/buffered_union_test.cpp | 684 +++++++++++++++++ .../query_v2/disjunction_scorer_test.cpp | 389 ++++++++++ .../query_v2/exclude_scorer_test.cpp | 569 ++++++++++++++ .../query_v2/intersection_test.cpp | 94 +-- .../query_v2/occur_boolean_query_test.cpp | 707 ++++++++++++++++++ .../query_v2/reqopt_scorer_test.cpp | 540 +++++++++++++ .../query_v2/segment_postings_test.cpp | 226 +++--- contrib/clucene | 2 +- 45 files changed, 5261 insertions(+), 456 deletions(-) create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/all_query/all_query.h create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder.h create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_query.h create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.cpp create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h rename be/src/olap/rowset/segment_v2/inverted_index/query_v2/{ => boolean_query}/operator.h (100%) rename be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/{boolean_query.h => operator_boolean_query.h} (62%) rename be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/{boolean_weight.h => operator_boolean_weight.h} (95%) create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.cpp create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.h create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.cpp create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.h create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer.h create mode 100644 
be/src/olap/rowset/segment_v2/inverted_index/query_v2/size_hint.h create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.cpp create mode 100644 be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.h create mode 100644 be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder_test.cpp create mode 100644 be/test/olap/rowset/segment_v2/inverted_index/query_v2/buffered_union_test.cpp create mode 100644 be/test/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer_test.cpp create mode 100644 be/test/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer_test.cpp create mode 100644 be/test/olap/rowset/segment_v2/inverted_index/query_v2/occur_boolean_query_test.cpp create mode 100644 be/test/olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer_test.cpp diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/all_query/all_query.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/all_query/all_query.h new file mode 100644 index 00000000000000..cd73860d46fd17 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/all_query/all_query.h @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class AllScorer; +class AllWeight; +class AllQuery; + +using AllScorerPtr = std::shared_ptr; +using AllWeightPtr = std::shared_ptr; +using AllQueryPtr = std::shared_ptr; + +class AllScorer : public Scorer { +public: + explicit AllScorer(uint32_t max_doc) : _max_doc(max_doc) { + if (_max_doc == 0) { + _doc = TERMINATED; + } else { + _doc = 0; + } + } + + ~AllScorer() override = default; + + uint32_t doc() const override { return _doc; } + + uint32_t advance() override { + if (_doc == TERMINATED) { + return TERMINATED; + } + if (_doc + 1 >= _max_doc) { + _doc = TERMINATED; + return TERMINATED; + } + ++_doc; + return _doc; + } + + uint32_t seek(uint32_t target) override { + if (_doc == TERMINATED) { + return TERMINATED; + } + if (target >= _max_doc) { + _doc = TERMINATED; + return TERMINATED; + } + _doc = std::max(_doc, target); + return _doc; + } + + float score() override { return 1.0F; } + + uint32_t size_hint() const override { return _max_doc; } + +private: + uint32_t _max_doc = 0; + uint32_t _doc = TERMINATED; +}; + +class AllWeight : public Weight { +public: + explicit AllWeight(uint32_t max_doc) : _max_doc(max_doc) {} + + ~AllWeight() override = default; + + ScorerPtr scorer(const QueryExecutionContext& context) override { + return std::make_shared(_max_doc); + } + +private: + uint32_t _max_doc = 0; +}; + +class AllQuery : public Query { +public: + explicit AllQuery(uint32_t max_doc) : _max_doc(max_doc) {} + + ~AllQuery() override = default; + + WeightPtr weight(bool /*enable_scoring*/) override { + return std::make_shared(_max_doc); + } + +private: + 
uint32_t _max_doc = 0; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder.h new file mode 100644 index 00000000000000..3cbcca580801f5 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder.h @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class OccurBooleanQueryBuilder { +public: + OccurBooleanQueryBuilder() = default; + ~OccurBooleanQueryBuilder() = default; + + void add(const QueryPtr& query, Occur occur) { _sub_queries.emplace_back(occur, query); } + + void set_minimum_number_should_match(size_t value) { _minimum_number_should_match = value; } + + QueryPtr build() { + if (_minimum_number_should_match.has_value()) { + return std::make_shared(std::move(_sub_queries), + _minimum_number_should_match.value()); + } + return std::make_shared(std::move(_sub_queries)); + } + +private: + std::vector> _sub_queries; + std::optional _minimum_number_should_match; +}; + +using OccurBooleanQueryBuilderPtr = std::shared_ptr; + +class OperatorBooleanQueryBuilder { +public: + OperatorBooleanQueryBuilder(OperatorType type) : _type(type) {} + ~OperatorBooleanQueryBuilder() = default; + + void add(const QueryPtr& query, std::string binding_key = {}) { + _sub_queries.emplace_back(query); + _binding_keys.emplace_back(std::move(binding_key)); + } + + QueryPtr build() { + return std::make_shared(_type, std::move(_sub_queries), + std::move(_binding_keys)); + } + +private: + OperatorType _type; + std::vector _sub_queries; + std::vector _binding_keys; +}; + +using OperatorBooleanQueryBuilderPtr = std::shared_ptr; + +inline OccurBooleanQueryBuilderPtr create_occur_boolean_query_builder() { + return std::make_shared(); +} + +inline OperatorBooleanQueryBuilderPtr create_operator_boolean_query_builder(OperatorType type) { + return std::make_shared(type); +} + +} // namespace 
doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h new file mode 100644 index 00000000000000..2696e6f18200d5 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace doris::segment_v2::inverted_index::query_v2 { + +enum class Occur { MUST = 0, SHOULD = 1, MUST_NOT = 2 }; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_query.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_query.h new file mode 100644 index 00000000000000..018e2d831c7873 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_query.h @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/score_combiner.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +class OccurBooleanQuery; +using OccurBooleanQueryPtr = std::shared_ptr; + +class OccurBooleanQuery : public Query { +public: + explicit OccurBooleanQuery(std::vector> clauses) + : _sub_queries(std::move(clauses)), + _minimum_number_should_match(compute_default_minimum_should_match(_sub_queries)) {} + + OccurBooleanQuery(std::vector> clauses, + size_t minimum_number_should_match) + : _sub_queries(std::move(clauses)), + _minimum_number_should_match(minimum_number_should_match) {} + + ~OccurBooleanQuery() override = default; + + WeightPtr weight(bool enable_scoring) override { + std::vector> sub_weights; + sub_weights.reserve(_sub_queries.size()); + for (const auto& [occur, query] : _sub_queries) { + sub_weights.emplace_back(occur, query->weight(enable_scoring)); + } + return std::make_shared>( + std::move(sub_weights), _minimum_number_should_match, enable_scoring, + 
std::make_shared()); + } + + const std::vector>& clauses() const { return _sub_queries; } + size_t minimum_number_should_match() const { return _minimum_number_should_match; } + +private: + static size_t compute_default_minimum_should_match( + const std::vector>& clauses) { + size_t minimum_required = 0; + for (const auto& [occur, _] : clauses) { + if (occur == Occur::SHOULD) { + minimum_required = 1; + } else if (occur == Occur::MUST || occur == Occur::MUST_NOT) { + return 0; + } + } + return minimum_required; + } + + std::vector> _sub_queries; + size_t _minimum_number_should_match = 0; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.cpp new file mode 100644 index 00000000000000..844d578338c2c5 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.cpp @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h" + +#include "olap/rowset/segment_v2/inverted_index/query_v2/all_query/all_query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/intersection.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +OccurBooleanWeight::OccurBooleanWeight( + std::vector> sub_weights, size_t minimum_number_should_match, + bool enable_scoring, ScoreCombinerPtrT score_combiner) + : _sub_weights(std::move(sub_weights)), + _minimum_number_should_match(minimum_number_should_match), + _enable_scoring(enable_scoring), + _score_combiner(std::move(score_combiner)) {} + +template +ScorerPtr OccurBooleanWeight::scorer(const QueryExecutionContext& context) { + if (_sub_weights.empty()) { + return std::make_shared(); + } + if (_sub_weights.size() == 1) { + const auto& [occur, weight] = _sub_weights[0]; + if (occur == Occur::MUST_NOT) { + return std::make_shared(); + } + return weight->scorer(context); + } + _max_doc = context.segment_num_rows; + if (_enable_scoring) { + auto specialized = complex_scorer(context, _score_combiner); + return into_box_scorer(std::move(specialized), _score_combiner); + } else { + auto combiner = std::make_shared(); + auto specialized = complex_scorer(context, combiner); + return into_box_scorer(std::move(specialized), combiner); + } +} + +template +std::unordered_map> +OccurBooleanWeight::per_occur_scorers(const QueryExecutionContext& context) { + std::unordered_map> result; + for (const auto& [occur, weight] : _sub_weights) { + auto sub_scorer = weight->scorer(context); + if (sub_scorer) { + 
result[occur].push_back(std::move(sub_scorer)); + } + } + return result; +} + +template +AllAndEmptyScorerCounts +OccurBooleanWeight::remove_and_count_all_and_empty_scorers( + std::vector& scorers) { + AllAndEmptyScorerCounts counts; + auto it = scorers.begin(); + while (it != scorers.end()) { + if (dynamic_cast(it->get()) != nullptr) { + counts.num_all_scorers++; + it = scorers.erase(it); + } else if (dynamic_cast(it->get()) != nullptr) { + counts.num_empty_scorers++; + it = scorers.erase(it); + } else { + ++it; + } + } + return counts; +} + +template +template +std::optional OccurBooleanWeight::build_should_opt( + std::vector& must_scorers, std::vector should_scorers, + CombinerT combiner, size_t num_all_scorers) { + if (should_scorers.empty()) { + return Ignored {}; + } + + size_t adjusted_minimum = _minimum_number_should_match > num_all_scorers + ? _minimum_number_should_match - num_all_scorers + : 0; + + size_t num_of_should_scorers = should_scorers.size(); + if (adjusted_minimum > num_of_should_scorers) { + return std::nullopt; + } + + if (adjusted_minimum == 0) { + return Optional {scorer_union(std::move(should_scorers), combiner)}; + } else if (adjusted_minimum == 1) { + return Required {scorer_union(std::move(should_scorers), combiner)}; + } else if (adjusted_minimum == num_of_should_scorers) { + must_scorers.swap(should_scorers); + return Ignored {}; + } else { + return Required {scorer_disjunction(std::move(should_scorers), combiner, adjusted_minimum)}; + } +} + +template +ScorerPtr OccurBooleanWeight::build_exclude_opt( + std::vector must_not_scorers) { + if (must_not_scorers.empty()) { + return nullptr; + } + auto do_nothing = std::make_shared(); + auto specialized_scorer = scorer_union(std::move(must_not_scorers), do_nothing); + return into_box_scorer(std::move(specialized_scorer), do_nothing); +} + +template +template +SpecializedScorer OccurBooleanWeight::build_positive_opt( + CombinationMethod& should_opt, std::vector must_scorers, CombinerT 
combiner, + size_t num_all_scorers) { + const bool has_must = !must_scorers.empty(); + if (std::holds_alternative(should_opt)) { + if (has_must) { + return make_intersect_scorers(std::move(must_scorers), _max_doc); + } + if (num_all_scorers > 0) { + return std::make_shared(_max_doc); + } + return std::make_shared(); + } + + if (std::holds_alternative(should_opt)) { + auto& opt = std::get(should_opt); + if (has_must) { + auto must_scorer = make_intersect_scorers(std::move(must_scorers), _max_doc); + if (_enable_scoring) { + auto should_boxed = into_box_scorer(std::move(opt.scorer), combiner); + return make_required_optional_scorer(must_scorer, should_boxed, combiner); + } else { + return must_scorer; + } + } + return opt.scorer; + } + + if (std::holds_alternative(should_opt)) { + auto& req = std::get(should_opt); + if (has_must) { + must_scorers.push_back(into_box_scorer(std::move(req.scorer), combiner)); + return make_intersect_scorers(std::move(must_scorers), _max_doc); + } + return req.scorer; + } + + return std::make_shared(); +} + +template +template +SpecializedScorer OccurBooleanWeight::complex_scorer( + const QueryExecutionContext& context, CombinerT combiner) { + auto scorers_by_occur = per_occur_scorers(context); + auto must_scorers = std::move(scorers_by_occur[Occur::MUST]); + auto should_scorers = std::move(scorers_by_occur[Occur::SHOULD]); + auto must_not_scorers = std::move(scorers_by_occur[Occur::MUST_NOT]); + + auto must_special_counts = remove_and_count_all_and_empty_scorers(must_scorers); + auto should_special_counts = remove_and_count_all_and_empty_scorers(should_scorers); + auto exclude_special_counts = remove_and_count_all_and_empty_scorers(must_not_scorers); + + if (must_special_counts.num_empty_scorers > 0) { + return std::make_shared(); + } + + if (exclude_special_counts.num_all_scorers > 0) { + return std::make_shared(); + } + + auto should_opt = build_should_opt(must_scorers, std::move(should_scorers), combiner, + 
should_special_counts.num_all_scorers); + if (!should_opt.has_value()) { + return std::make_shared(); + } + + ScorerPtr exclude_opt = build_exclude_opt(std::move(must_not_scorers)); + size_t total_all_scorers = + must_special_counts.num_all_scorers + should_special_counts.num_all_scorers; + SpecializedScorer positive_opt = + build_positive_opt(*should_opt, std::move(must_scorers), combiner, total_all_scorers); + if (exclude_opt) { + ScorerPtr positive_boxed = into_box_scorer(std::move(positive_opt), combiner); + return make_exclude(std::move(positive_boxed), std::move(exclude_opt)); + } + return positive_opt; +} + +template +template +SpecializedScorer OccurBooleanWeight::scorer_union( + std::vector scorers, CombinerT combiner) { + if (scorers.empty()) { + return std::make_shared(); + } + + if (scorers.size() == 1) { + return std::move(scorers[0]); + } + + bool is_all_term_scorers = true; + for (const auto& scorer : scorers) { + auto* term_scorer = dynamic_cast(scorer.get()); + if (term_scorer == nullptr) { + is_all_term_scorers = false; + break; + } + } + if (is_all_term_scorers) { + std::vector term_scorers; + term_scorers.reserve(scorers.size()); + for (auto& scorer : scorers) { + term_scorers.push_back(std::dynamic_pointer_cast(scorer)); + } + return term_scorers; + } + + return make_buffered_union(std::move(scorers), combiner); +} + +template +template +SpecializedScorer OccurBooleanWeight::scorer_disjunction( + std::vector scorers, CombinerT combiner, size_t minimum_match_required) { + if (scorers.empty()) { + return std::make_shared(); + } + + if (scorers.size() == 1) { + return std::move(scorers[0]); + } + + return make_disjunction(std::move(scorers), combiner, minimum_match_required); +} + +template +template +ScorerPtr OccurBooleanWeight::into_box_scorer(SpecializedScorer&& specialized, + CombinerT combiner) { + return std::visit( + [&](auto&& arg) -> ScorerPtr { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + std::vector scorers; + 
scorers.reserve(arg.size()); + for (auto& ts : arg) { + scorers.push_back(std::move(ts)); + } + return make_buffered_union(std::move(scorers), combiner); + } else { + return std::move(arg); + } + }, + std::move(specialized)); +} + +template class OccurBooleanWeight; +template class OccurBooleanWeight; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h new file mode 100644 index 00000000000000..b143777797468e --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +using SpecializedScorer = std::variant, ScorerPtr>; + +struct Ignored {}; +struct Optional { + SpecializedScorer scorer; +}; +struct Required { + SpecializedScorer scorer; +}; +using CombinationMethod = std::variant; + +struct AllAndEmptyScorerCounts { + size_t num_all_scorers = 0; + size_t num_empty_scorers = 0; +}; + +template +class OccurBooleanWeight : public Weight { +public: + OccurBooleanWeight(std::vector> sub_weights, + size_t minimum_number_should_match, bool enable_scoring, + ScoreCombinerPtrT score_combiner); + ~OccurBooleanWeight() override = default; + + ScorerPtr scorer(const QueryExecutionContext& context) override; + +private: + std::unordered_map> per_occur_scorers( + const QueryExecutionContext& context); + AllAndEmptyScorerCounts remove_and_count_all_and_empty_scorers(std::vector& scorers); + + template + SpecializedScorer complex_scorer(const QueryExecutionContext& context, CombinerT combiner); + + template + std::optional build_should_opt(std::vector& must_scorers, + std::vector should_scorers, + CombinerT combiner, size_t num_all_scorers); + ScorerPtr build_exclude_opt(std::vector must_not_scorers); + template + SpecializedScorer build_positive_opt(CombinationMethod& should_opt, + std::vector must_scorers, CombinerT combiner, + size_t num_all_scorers = 0); + + template + SpecializedScorer scorer_union(std::vector scorers, CombinerT combiner); + template + SpecializedScorer scorer_disjunction(std::vector scorers, CombinerT combiner, + size_t minimum_match_required); + + template + ScorerPtr into_box_scorer(SpecializedScorer&& specialized, CombinerT combiner); + + 
std::vector> _sub_weights; + size_t _minimum_number_should_match = 1; + bool _enable_scoring = false; + ScoreCombinerPtrT _score_combiner; + + uint32_t _max_doc = 0; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/operator.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator.h similarity index 100% rename from be/src/olap/rowset/segment_v2/inverted_index/query_v2/operator.h rename to be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator.h diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_query.h similarity index 62% rename from be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query.h rename to be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_query.h index a28ebbd9599c02..4278659bbddb1a 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_query.h @@ -17,24 +17,24 @@ #pragma once -#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_weight.h" -#include "olap/rowset/segment_v2/inverted_index/query_v2/operator.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_weight.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/score_combiner.h" namespace doris::segment_v2::inverted_index::query_v2 { -class BooleanQuery; -using BooleanQueryPtr = std::shared_ptr; +class OperatorBooleanQuery; +using OperatorBooleanQueryPtr = std::shared_ptr; -class BooleanQuery : public 
Query { +class OperatorBooleanQuery : public Query { public: - BooleanQuery(OperatorType type, std::vector clauses, - std::vector binding_keys) + OperatorBooleanQuery(OperatorType type, std::vector clauses, + std::vector binding_keys) : _type(type), _sub_queries(std::move(clauses)), _binding_keys(std::move(binding_keys)) {} - ~BooleanQuery() override = default; + ~OperatorBooleanQuery() override = default; WeightPtr weight(bool enable_scoring) override { std::vector sub_weights; @@ -42,36 +42,15 @@ class BooleanQuery : public Query { sub_weights.emplace_back(query->weight(enable_scoring)); } if (enable_scoring) { - return std::make_shared>( + return std::make_shared>( _type, std::move(sub_weights), _binding_keys, std::make_shared()); } else { - return std::make_shared>( + return std::make_shared>( _type, std::move(sub_weights), _binding_keys, std::make_shared()); } } - class Builder { - public: - Builder(OperatorType type) : _type(type) {} - ~Builder() = default; - - void add(const QueryPtr& query, std::string binding_key = {}) { - _sub_queries.emplace_back(query); - _binding_keys.emplace_back(std::move(binding_key)); - } - - BooleanQueryPtr build() { - return std::make_shared(_type, std::move(_sub_queries), - std::move(_binding_keys)); - } - - private: - OperatorType _type; - std::vector _sub_queries; - std::vector _binding_keys; - }; - private: OperatorType _type; std::vector _sub_queries; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_weight.h similarity index 95% rename from be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_weight.h rename to be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_weight.h index b427386f85a1e6..90c979addade71 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_weight.h +++ 
b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_weight.h @@ -23,25 +23,25 @@ #include #include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/buffered_union_scorer.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/intersection_scorer.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/match_all_docs_scorer.h" -#include "olap/rowset/segment_v2/inverted_index/query_v2/operator.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" namespace doris::segment_v2::inverted_index::query_v2 { template -class BooleanWeight : public Weight { +class OperatorBooleanWeight : public Weight { public: - BooleanWeight(OperatorType type, std::vector sub_weights, - std::vector binding_keys, ScoreCombinerPtrT score_combiner) + OperatorBooleanWeight(OperatorType type, std::vector sub_weights, + std::vector binding_keys, ScoreCombinerPtrT score_combiner) : _type(type), _sub_weights(std::move(sub_weights)), _binding_keys(std::move(binding_keys)), _score_combiner(std::move(score_combiner)) {} - ~BooleanWeight() override = default; + ~OperatorBooleanWeight() override = default; ScorerPtr scorer(const QueryExecutionContext& context) override { if (_is_do_nothing_combiner()) { @@ -113,7 +113,7 @@ class BooleanWeight : public Weight { const auto& sub_weight = _sub_weights[i]; const auto& binding_key = _binding_keys[i]; auto boolean_weight = - std::dynamic_pointer_cast>(sub_weight); + std::dynamic_pointer_cast>(sub_weight); if (boolean_weight != nullptr && boolean_weight->_type == OperatorType::OP_NOT) { auto excludes = boolean_weight->per_scorers(context); for (auto& exclude : excludes) { diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.cpp 
b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.cpp new file mode 100644 index 00000000000000..b15116328a2814 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.cpp @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.h" + +#include "olap/rowset/segment_v2/inverted_index/query_v2/score_combiner.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +DisjunctionScorer::DisjunctionScorer(std::vector scorers, + ScoreCombinerPtrT score_combiner, + size_t minimum_matches_required) + : _minimum_matches_required(minimum_matches_required), + _score_combiner(std::move(score_combiner)) { + DCHECK(minimum_matches_required > 1) << "union scorer works better if just one match required"; + + for (auto& scorer : scorers) { + if (scorer && scorer->doc() != TERMINATED) { + _size_hint = std::max(_size_hint, scorer->size_hint()); + _heap.emplace(std::move(scorer)); + } + } + + if (_minimum_matches_required > _heap.size()) { + return; + } + + do_advance(); +} + +template +void DisjunctionScorer::do_advance() { + size_t current_num_matches = 0; + while (!_heap.empty()) { + ScorerWrapper candidate = std::move(const_cast(_heap.top())); + _heap.pop(); + + uint32_t next = candidate.current_doc; + if (next == TERMINATED) { + continue; + } + + if (_current_doc != next) { + if (current_num_matches >= _minimum_matches_required) { + _heap.push(std::move(candidate)); + _current_score = _score_combiner->score(); + return; + } + current_num_matches = 0; + _current_doc = next; + _score_combiner->clear(); + } + + current_num_matches++; + _score_combiner->update(candidate.scorer); + + candidate.current_doc = candidate.scorer->advance(); + _heap.push(std::move(candidate)); + } + + if (current_num_matches < _minimum_matches_required) { + _current_doc = TERMINATED; + } + _current_score = _score_combiner->score(); +} + +template +uint32_t DisjunctionScorer::advance() { + if (_current_doc == TERMINATED) { + return TERMINATED; + } + do_advance(); + return _current_doc; +} + +template +uint32_t DisjunctionScorer::seek(uint32_t target) { + if (_current_doc == TERMINATED) { + return TERMINATED; + } + while (_current_doc < 
target && _current_doc != TERMINATED) { + do_advance(); + } + return _current_doc; +} + +template +uint32_t DisjunctionScorer::size_hint() const { + return _size_hint; +} + +template +ScorerPtr make_disjunction(std::vector scorers, ScoreCombinerPtrT score_combiner, + size_t minimum_matches_required) { + if (scorers.empty()) { + return std::make_shared(); + } + if (minimum_matches_required > scorers.size()) { + return std::make_shared(); + } + return std::make_shared>( + std::move(scorers), std::move(score_combiner), minimum_matches_required); +} + +template class DisjunctionScorer; +template class DisjunctionScorer; + +template ScorerPtr make_disjunction(std::vector scorers, SumCombinerPtr score_combiner, + size_t minimum_matches_required); +template ScorerPtr make_disjunction(std::vector scorers, + DoNothingCombinerPtr score_combiner, + size_t minimum_matches_required); + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.h new file mode 100644 index 00000000000000..3bcb96aeef4c1e --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.h @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +class DisjunctionScorer : public Scorer { +public: + DisjunctionScorer(std::vector scorers, ScoreCombinerPtrT score_combiner, + size_t minimum_matches_required); + ~DisjunctionScorer() override = default; + + uint32_t advance() override; + uint32_t seek(uint32_t target) override; + uint32_t doc() const override { return _current_doc; } + uint32_t size_hint() const override; + float score() override { return _current_score; } + +private: + struct ScorerWrapper { + ScorerPtr scorer; + uint32_t current_doc; + + ScorerWrapper(ScorerPtr s) : scorer(std::move(s)), current_doc(scorer->doc()) {} + + bool operator>(const ScorerWrapper& other) const { return current_doc > other.current_doc; } + }; + + void do_advance(); + + std::priority_queue, std::greater> + _heap; + size_t _minimum_matches_required; + ScoreCombinerPtrT _score_combiner; + + uint32_t _current_doc = TERMINATED; + float _current_score = 0.0F; + uint32_t _size_hint = 0; +}; + +template +ScorerPtr make_disjunction(std::vector scorers, ScoreCombinerPtrT score_combiner, + size_t minimum_matches_required); + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h index e24bd09b134ae5..f1d44e2ec61648 100644 --- 
a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h @@ -51,6 +51,8 @@ class DocSet { "size_hint() method not implemented in base DocSet class"); } + virtual uint64_t cost() const { return static_cast(size_hint()); } + virtual uint32_t freq() const { throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, "freq() method not implemented in base DocSet class"); @@ -61,6 +63,7 @@ class DocSet { "norm() method not implemented in base DocSet class"); } }; + using DocSetPtr = std::shared_ptr; class MockDocSet : public DocSet { @@ -186,6 +189,7 @@ class MockDocSet : public DocSet { uint32_t _size_hint_val = 0; uint32_t _norm_val = 1; }; + using MockDocSetPtr = std::shared_ptr; } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.cpp new file mode 100644 index 00000000000000..09d7b28c4107d2 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.cpp @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +Exclude::Exclude(TDocSet underlying_docset, + TDocSetExclude excluding_docset) + : _underlying_docset(std::move(underlying_docset)), + _excluding_docset(std::move(excluding_docset)) { + while (_underlying_docset->doc() != TERMINATED) { + uint32_t target = _underlying_docset->doc(); + if (!is_within(_excluding_docset, target)) { + break; + } + _underlying_docset->advance(); + } +} + +template +uint32_t Exclude::advance() { + while (true) { + uint32_t candidate = _underlying_docset->advance(); + if (candidate == TERMINATED) { + return TERMINATED; + } + if (!is_within(_excluding_docset, candidate)) { + return candidate; + } + } +} + +template +uint32_t Exclude::seek(uint32_t target) { + uint32_t candidate = _underlying_docset->seek(target); + if (candidate == TERMINATED) { + return TERMINATED; + } + if (!is_within(_excluding_docset, candidate)) { + return candidate; + } + return advance(); +} + +template +uint32_t Exclude::doc() const { + return _underlying_docset->doc(); +} + +template +uint32_t Exclude::size_hint() const { + return _underlying_docset->size_hint(); +} + +template +float Exclude::score() { + if constexpr (std::is_base_of_v) { + return _underlying_docset->score(); + } + return 0.0F; +} + +ScorerPtr make_exclude(ScorerPtr underlying, ScorerPtr excluding) { + return std::make_shared>(std::move(underlying), + std::move(excluding)); +} + +template class Exclude; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.h new file mode 100644 index 00000000000000..a1f0523467d34d --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software 
Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +inline bool is_within(TDocSetExclude& docset, uint32_t doc) { + return docset->doc() <= doc && docset->seek(doc) == doc; +} + +template +class Exclude final : public Scorer { +public: + Exclude(TDocSet underlying_docset, TDocSetExclude excluding_docset); + ~Exclude() override = default; + + uint32_t advance() override; + uint32_t seek(uint32_t target) override; + uint32_t doc() const override; + uint32_t size_hint() const override; + float score() override; + +private: + TDocSet _underlying_docset; + TDocSetExclude _excluding_docset; +}; + +using ExcludeScorerPtr = std::shared_ptr>; + +ScorerPtr make_exclude(ScorerPtr underlying, ScorerPtr excluding); + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.cpp index 0198c80241ee0e..c2cf7f82564409 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.cpp +++ 
b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.cpp @@ -17,14 +17,84 @@ #include "olap/rowset/segment_v2/inverted_index/query_v2/intersection.h" +#include "common/status.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/postings_with_offset.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/size_hint.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h" namespace doris::segment_v2::inverted_index::query_v2 { +template +struct is_scorer_ptr : std::false_type {}; + +template +struct is_scorer_ptr> + : std::is_base_of {}; + +template +inline constexpr bool is_scorer_ptr_v = is_scorer_ptr::value; + +template +uint32_t go_to_first_doc(const std::vector& docsets) { + if (docsets.empty()) { + throw Exception(ErrorCode::INVALID_ARGUMENT, + "At least 1 docset is required for intersection"); + } + + uint32_t candidate = docsets.front()->doc(); + for (size_t i = 1; i < docsets.size(); ++i) { + candidate = std::max(candidate, docsets[i]->doc()); + } + +outer: + while (true) { + for (const auto& docset : docsets) { + if (docset->doc() < candidate) { + uint32_t seek_doc = docset->seek(candidate); + if (seek_doc > candidate) { + candidate = docset->doc(); + goto outer; + } + } + } + return candidate; + } +} + +ScorerPtr make_intersect_scorers(std::vector scorers, uint32_t num_docs) { + if (scorers.empty()) { + return std::make_shared(); + } + if (scorers.size() == 1) { + return std::move(scorers[0]); + } + std::ranges::sort(scorers, + [](const ScorerPtr& a, const ScorerPtr& b) { return a->cost() < b->cost(); }); + uint32_t doc = go_to_first_doc(scorers); + if (doc == TERMINATED) { + return std::make_shared(); + } + + auto left = scorers[0]; + auto right = scorers[1]; + std::vector others(std::make_move_iterator(scorers.begin() + 2), + std::make_move_iterator(scorers.end())); + + auto left_term = 
std::dynamic_pointer_cast(left); + auto right_term = std::dynamic_pointer_cast(right); + if (left_term && right_term) { + return std::make_shared>( + std::move(left_term), std::move(right_term), std::move(others), num_docs); + } + return std::make_shared>(std::move(left), std::move(right), + std::move(others), num_docs); +} + template template std::enable_if_t, IntersectionPtr> -Intersection::create(std::vector& docsets) { +Intersection::create(std::vector& docsets, uint32_t num_docs) { size_t num_docsets = docsets.size(); if (num_docsets < 2) { throw Exception(ErrorCode::INVALID_ARGUMENT, @@ -32,60 +102,38 @@ Intersection::create(std::vector& docsets) { } std::sort(docsets.begin(), docsets.end(), - [](const TDocSet& a, const TDocSet& b) { return a->size_hint() < b->size_hint(); }); + [](const TDocSet& a, const TDocSet& b) { return a->cost() < b->cost(); }); go_to_first_doc(docsets); TDocSet left = std::move(docsets[0]); TDocSet right = std::move(docsets[1]); docsets.erase(docsets.begin(), docsets.begin() + 2); return std::make_shared>(std::move(left), std::move(right), - std::move(docsets)); + std::move(docsets), num_docs); } template Intersection::Intersection(TDocSet left, TDocSet right, - std::vector others) - : _left(std::move(left)), _right(std::move(right)), _others(std::move(others)) {} + std::vector others, + uint32_t num_docs) + : _left(std::move(left)), + _right(std::move(right)), + _others(std::move(others)), + _num_docs(num_docs) {} template uint32_t Intersection::advance() { - uint32_t candidate = _left->advance(); - - while (true) { - while (true) { - uint32_t right_doc = _right->seek(candidate); - candidate = _left->seek(right_doc); - if (candidate == right_doc) { - break; - } - } - - bool need_continue = false; - for (const auto& docset : _others) { - uint32_t seek_doc = docset->seek(candidate); - if (seek_doc > candidate) { - candidate = _left->seek(seek_doc); - need_continue = true; - break; - } - } - - if (!need_continue) { - return candidate; - 
} - } + return intersect_from(_left->advance()); } template uint32_t Intersection::seek(uint32_t target) { _left->seek(target); - std::vector docsets; - docsets.push_back(_left); - docsets.push_back(_right); - for (auto& docset : _others) { - docsets.push_back(docset); + uint32_t candidate = std::max(_left->doc(), _right->doc()); + for (const auto& docset : _others) { + candidate = std::max(candidate, docset->doc()); } - return go_to_first_doc(docsets); + return intersect_from(candidate); } template @@ -95,7 +143,19 @@ uint32_t Intersection::doc() const { template uint32_t Intersection::size_hint() const { - return _left->size_hint(); + std::vector sizes; + sizes.reserve(2 + _others.size()); + sizes.push_back(_left->size_hint()); + sizes.push_back(_right->size_hint()); + for (const auto& docset : _others) { + sizes.push_back(docset->size_hint()); + } + return estimate_intersection(sizes, _num_docs); +} + +template +uint64_t Intersection::cost() const { + return _left->cost(); } template @@ -104,32 +164,13 @@ uint32_t Intersection::norm() const { } template -uint32_t Intersection::go_to_first_doc(const std::vector& docsets) { - if (docsets.empty()) { - throw Exception(ErrorCode::INVALID_ARGUMENT, - "At least 1 docset is required for intersection"); - } - - uint32_t candidate = docsets.front()->doc(); - for (size_t i = 1; i < docsets.size(); ++i) { - candidate = std::max(candidate, docsets[i]->seek(candidate)); - } - - while (true) { - bool need_continue = false; - - for (const auto& docset : docsets) { - uint32_t seek_doc = docset->seek(candidate); - if (seek_doc > candidate) { - candidate = docset->doc(); - need_continue = true; - break; - } - } - - if (!need_continue) { - return candidate; - } +float Intersection::score() { + if constexpr (is_scorer_ptr_v) { + return _left->score() + _right->score() + + std::accumulate(_others.begin(), _others.end(), 0.0F, + [](float sum, const auto& scorer) { return sum + scorer->score(); }); + } else { + return 0.0F; } } @@ 
-147,15 +188,42 @@ Intersection::docset_mut_specialized(size_t ord) { } } -#define INSTANTIATE_INTERSECTION(T) \ - template class Intersection; \ - template std::enable_if_t, IntersectionPtr> \ - Intersection::create(std::vector & docsets); \ - template std::enable_if_t, T&> \ +template +uint32_t Intersection::intersect_from(uint32_t candidate) { +left_right_intersection: + while (true) { + uint32_t right_doc = _right->seek(candidate); + if (right_doc != candidate) { + candidate = _left->seek(right_doc); + if (candidate != right_doc) { + continue; + } + } + break; + } + + for (const auto& docset : _others) { + if (docset->doc() < candidate) { + uint32_t seek_doc = docset->seek(candidate); + if (seek_doc > candidate) { + candidate = _left->seek(seek_doc); + goto left_right_intersection; + } + } + } + + return candidate; +} + +#define INSTANTIATE_INTERSECTION(T) \ + template class Intersection; \ + template std::enable_if_t, IntersectionPtr> \ + Intersection::create(std::vector & docsets, uint32_t num_docs); \ + template std::enable_if_t, T&> \ Intersection::docset_mut_specialized(size_t ord); INSTANTIATE_INTERSECTION(std::shared_ptr>) -INSTANTIATE_INTERSECTION(std::shared_ptr>) +INSTANTIATE_INTERSECTION(std::shared_ptr>) INSTANTIATE_INTERSECTION(MockDocSetPtr) #undef INSTANTIATE_INTERSECTION diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.h index 8dd2430fd3f3d0..b3073c108bd026 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/intersection.h @@ -30,30 +30,42 @@ template using IntersectionPtr = std::shared_ptr>; template -class Intersection final : public DocSet { +class Intersection final : public Scorer { public: - Intersection(TDocSet left, TDocSet right, std::vector others); + Intersection(TDocSet left, TDocSet right, std::vector others, uint32_t num_docs); ~Intersection() 
override = default; template static std::enable_if_t, IntersectionPtr> create( - std::vector& docsets); + std::vector& docsets, uint32_t num_docs); uint32_t advance() override; uint32_t seek(uint32_t target) override; uint32_t doc() const override; uint32_t size_hint() const override; + uint64_t cost() const override; uint32_t norm() const override; + float score() override; + template std::enable_if_t, TDocSet&> docset_mut_specialized(size_t ord); private: - static uint32_t go_to_first_doc(const std::vector& docsets); + uint32_t intersect_from(uint32_t candidate); TDocSet _left; TDocSet _right; std::vector _others; + + uint32_t _num_docs = 0; }; +ScorerPtr make_intersect_scorers(std::vector scorers, uint32_t num_docs); + +template +auto make_intersection(std::vector docsets, uint32_t num_docs) { + return Intersection::create(docsets, num_docs); +} + } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h index 981ab0148be77c..e75c59e3607799 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_weight.h @@ -98,7 +98,8 @@ class MultiPhraseWeight : public Weight { term_postings_list.emplace_back(offset, std::move(union_posting)); } } - return PhraseScorer::create(term_postings_list, _similarity, 0); + uint32_t num_docs = ctx.segment_num_rows; + return PhraseScorer::create(term_postings_list, _similarity, 0, num_docs); } IndexQueryContextPtr _context; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.cpp index 2aa64fff95c111..bccd9fd85a421d 100644 --- 
a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.cpp @@ -22,7 +22,7 @@ namespace doris::segment_v2::inverted_index::query_v2 { template ScorerPtr PhraseScorer::create_with_offset( const std::vector>& term_postings_with_offset, - const SimilarityPtr& similarity, uint32_t slop, size_t offset) { + const SimilarityPtr& similarity, uint32_t slop, size_t offset, uint32_t num_docs) { size_t max_offset = offset; for (const auto& [term_offset, _] : term_postings_with_offset) { max_offset = std::max(max_offset, term_offset + offset); @@ -38,9 +38,8 @@ ScorerPtr PhraseScorer::create_with_offset( postings_with_offsets.emplace_back(std::move(postings_with_offset)); } - using IntersectionType = - Intersection, PostingsWithOffsetPtr>; - auto intersection_docset = IntersectionType::create(postings_with_offsets); + auto intersection_docset = + make_intersection>(postings_with_offsets, num_docs); std::vector left_positions(100); std::vector right_positions(100); auto scorer = std::make_shared>( @@ -64,7 +63,7 @@ uint32_t PhraseScorer::advance() { template uint32_t PhraseScorer::seek(uint32_t target) { - assert(target > doc()); + assert(target >= doc()); uint32_t doc = _intersection_docset->seek(target); if (doc == TERMINATED || phrase_match()) { return doc; @@ -82,6 +81,11 @@ uint32_t PhraseScorer::size_hint() const { return _intersection_docset->size_hint(); } +template +uint64_t PhraseScorer::cost() const { + return static_cast(_intersection_docset->size_hint()) * 10 * _num_terms; +} + template uint32_t PhraseScorer::norm() const { return _intersection_docset->norm(); @@ -184,6 +188,6 @@ bool PhraseScorer::intersection_exists(const std::vector& l } template class PhraseScorer; -template class PhraseScorer; +template class PhraseScorer; } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git 
a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h index 494dfa90b0eefe..9ee4fe241c26c5 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_scorer.h @@ -49,14 +49,15 @@ class PhraseScorer : public Scorer { ~PhraseScorer() override = default; static ScorerPtr create(const std::vector>& term_postings, - const SimilarityPtr& similarity, uint32_t slop) { - return create_with_offset(term_postings, similarity, slop, 0); + const SimilarityPtr& similarity, uint32_t slop, uint32_t num_docs) { + return create_with_offset(term_postings, similarity, slop, 0, num_docs); } uint32_t advance() override; uint32_t seek(uint32_t target) override; uint32_t doc() const override; uint32_t size_hint() const override; + uint64_t cost() const override; uint32_t norm() const override; float score() override; @@ -66,7 +67,7 @@ class PhraseScorer : public Scorer { private: static ScorerPtr create_with_offset( const std::vector>& term_postings_with_offset, - const SimilarityPtr& similarity, uint32_t slop, size_t offset); + const SimilarityPtr& similarity, uint32_t slop, size_t offset, uint32_t num_docs); bool phrase_exists(); uint32_t compute_phrase_count(); diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h index 0a12fd117d3a16..75457aafef451a 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_weight.h @@ -57,7 +57,7 @@ class PhraseWeight : public Weight { StringHelper::to_string(_field)); } - std::vector> term_postings_list; + std::vector> term_postings_list; for (const auto& term_info : _term_infos) { 
size_t offset = term_info.position; auto posting = @@ -69,7 +69,9 @@ class PhraseWeight : public Weight { return std::make_shared(); } } - return PhraseScorer::create(term_postings_list, _similarity, 0); + uint32_t num_docs = ctx.segment_num_rows; + return PhraseScorer::create(term_postings_list, _similarity, 0, + num_docs); } IndexQueryContextPtr _context; diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/postings/loaded_postings.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/postings/loaded_postings.cpp index 1639c4e0f4ec8a..2d3c8d7986b582 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/postings/loaded_postings.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/postings/loaded_postings.cpp @@ -130,7 +130,6 @@ void LoadedPostings::append_positions_with_offset(uint32_t offset, std::vector>( - SegmentPostings& segment_postings); +template LoadedPostingsPtr LoadedPostings::load(SegmentPostings& segment_postings); } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp index 456cb702d3233a..cb71e10daa463c 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_weight.cpp @@ -96,7 +96,7 @@ ScorerPtr RegexpWeight::regexp_scorer(const QueryExecutionContext& context, auto t = make_term_ptr(_field.c_str(), term.c_str()); auto reader = lookup_reader(_field, context, binding_key); auto iter = make_term_doc_ptr(reader.get(), t.get(), _enable_scoring, _context->io_ctx); - auto segment_postings = std::make_shared>(std::move(iter)); + auto segment_postings = make_segment_postings(std::move(iter), _enable_scoring); uint32_t doc = segment_postings->doc(); while (doc != TERMINATED) { diff 
--git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer.h new file mode 100644 index 00000000000000..b341ff5fdea593 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer.h @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/score_combiner.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +class RequiredOptionalScorer final : public Scorer { +public: + RequiredOptionalScorer(ScorerPtr req_scorer, ScorerPtr opt_scorer, TScoreCombiner combiner) + : _req_scorer(std::move(req_scorer)), + _opt_scorer(std::move(opt_scorer)), + _combiner(std::move(combiner)) {} + + ~RequiredOptionalScorer() override = default; + + uint32_t advance() override { + _score_cache.reset(); + return _req_scorer->advance(); + } + + uint32_t seek(uint32_t target) override { + _score_cache.reset(); + return _req_scorer->seek(target); + } + + uint32_t doc() const override { return _req_scorer->doc(); } + + uint32_t size_hint() const override { return _req_scorer->size_hint(); } + + float score() override { + if (_score_cache.has_value()) { + return _score_cache.value(); + } + uint32_t current_doc = doc(); + auto score_combiner = _combiner->clone(); + score_combiner->update(_req_scorer); + if (_opt_scorer->doc() <= current_doc && _opt_scorer->seek(current_doc) == current_doc) { + score_combiner->update(_opt_scorer); + } + float combined_score = score_combiner->score(); + _score_cache = combined_score; + return combined_score; + } + +private: + ScorerPtr _req_scorer; + ScorerPtr _opt_scorer; + TScoreCombiner _combiner; + std::optional _score_cache; +}; + +template +auto make_required_optional_scorer(ScorerPtr req_scorer, ScorerPtr opt_scorer, + TScoreCombiner combiner) { + return std::make_shared>( + std::move(req_scorer), std::move(opt_scorer), std::move(combiner)); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h index 
91083fbaf24075..455723ba28ffb4 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings.h @@ -17,6 +17,9 @@ #pragma once +#include + +#include "CLucene/index/DocRange.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h" #include "olap/rowset/segment_v2/inverted_index_common.h" @@ -34,103 +37,167 @@ class Postings : public DocSet { virtual void append_positions_with_offset(uint32_t offset, std::vector& output) = 0; }; + using PostingsPtr = std::shared_ptr; -template -class SegmentPostingsBase : public Postings { +class SegmentPostings final : public Postings { public: - SegmentPostingsBase() = default; - SegmentPostingsBase(TCLuceneIter iter) : _iter(std::move(iter)) { - if (_iter->next()) { - int32_t d = _iter->doc(); - _doc = d >= INT_MAX ? TERMINATED : d; - } else { - _doc = TERMINATED; + using IterVariant = std::variant; + + explicit SegmentPostings(TermDocsPtr iter, bool enable_scoring = false) + : _iter(std::move(iter)), _enable_scoring(enable_scoring) { + if (auto* p = std::get_if(&_iter)) { + _raw_iter = p->get(); } + _init_doc(); + } + + explicit SegmentPostings(TermPositionsPtr iter, bool enable_scoring = false) + : _iter(std::move(iter)), _enable_scoring(enable_scoring), _has_positions(true) { + if (auto* p = std::get_if(&_iter)) { + _raw_iter = p->get(); + } + _init_doc(); } uint32_t advance() override { - if (_iter->next()) { - return _doc = _iter->doc(); + if (_block.doc_many && _cursor < _block.doc_many_size_) { + return _doc = (*_block.doc_many)[_cursor++]; } - return _doc = TERMINATED; + if (!_refill()) { + return _doc = TERMINATED; + } + return _doc = (*_block.doc_many)[_cursor++]; } uint32_t seek(uint32_t target) override { if (target <= _doc) { return _doc; } - if (_iter->skipTo(target)) { - return _doc = _iter->doc(); + + if (_block.doc_many) { + while (_cursor < _block.doc_many_size_) { + uint32_t curr = 
(*_block.doc_many)[_cursor++]; + if (curr >= target) { + return _doc = curr; + } + } } + + _raw_iter->skipToBlock(target); + + while (_refill()) { + while (_cursor < _block.doc_many_size_) { + uint32_t curr = (*_block.doc_many)[_cursor++]; + if (curr >= target) { + return _doc = curr; + } + } + } + return _doc = TERMINATED; } uint32_t doc() const override { return _doc; } - uint32_t size_hint() const override { return _iter->docFreq(); } - uint32_t freq() const override { return _iter->freq(); } - uint32_t norm() const override { return _iter->norm(); } - void append_positions_with_offset(uint32_t offset, std::vector& output) override { - throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, - "This posting type does not support position information"); + uint32_t size_hint() const override { return _raw_iter ? _raw_iter->docFreq() : 0; } + + uint32_t freq() const override { + if (!_enable_scoring || !_block.freq_many || _cursor == 0) { + return 1; + } + return (*_block.freq_many)[_cursor - 1]; } -protected: - TCLuceneIter _iter; + uint32_t norm() const override { + if (!_enable_scoring || !_block.norm_many || _cursor == 0) { + return 1; + } + return (*_block.norm_many)[_cursor - 1]; + } -private: - uint32_t _doc = TERMINATED; -}; -template -using SegmentPostingsBasePtr = std::shared_ptr>; + void append_positions_with_offset(uint32_t offset, std::vector& output) override { + if (!_has_positions) { + throw Exception(doris::ErrorCode::NOT_IMPLEMENTED_ERROR, + "This posting type does not support position information"); + } + if (!_block.freq_many) { + throw Exception(doris::ErrorCode::INTERNAL_ERROR, + "Position information requested but freq data is missing"); + } -template -class SegmentPostings final : public SegmentPostingsBase { -public: - SegmentPostings(TCLuceneIter iter) : SegmentPostingsBase(std::move(iter)) {} -}; -using TermPostingsPtr = std::shared_ptr>; -using PositionPostingsPtr = std::shared_ptr>; + auto* term_pos_ptr = std::get_if(&_iter); + if 
(!term_pos_ptr || !*term_pos_ptr) { + throw Exception(doris::ErrorCode::INTERNAL_ERROR, + "Position information requested but TermPositions iterator is missing"); + } -template <> -class SegmentPostings final : public SegmentPostingsBase { -public: - SegmentPostings(TermPositionsPtr iter) - : SegmentPostingsBase(std::move(iter)) {} + uint32_t current_doc_idx = _cursor - 1; + if (current_doc_idx > _prox_cursor) { + int32_t skip_count = 0; + for (uint32_t i = _prox_cursor; i < current_doc_idx; ++i) { + skip_count += (*_block.freq_many)[i]; + } + if (skip_count > 0) { + (*term_pos_ptr)->addLazySkipProxCount(skip_count); + } + } - void append_positions_with_offset(uint32_t offset, std::vector& output) override { - auto freq = this->freq(); - size_t prev_len = output.size(); - output.resize(prev_len + freq); - for (int32_t i = 0; i < freq; ++i) { - auto pos = this->_iter->nextPosition(); - output[prev_len + i] = offset + static_cast(pos); + uint32_t freq = (*_block.freq_many)[current_doc_idx]; + int32_t position = 0; + for (uint32_t i = 0; i < freq; ++i) { + position += (*term_pos_ptr)->nextDeltaPosition(); + output.push_back(position + offset); } + + _prox_cursor = current_doc_idx + 1; } -}; -template -class NoScoreSegmentPosting final : public SegmentPostingsBase { -public: - NoScoreSegmentPosting(TCLuceneIter iter) : SegmentPostingsBase(std::move(iter)) {} + bool scoring_enabled() const { return _enable_scoring; } - uint32_t freq() const override { return 1; } - uint32_t norm() const override { return 1; } -}; +private: + bool _refill() { + _block.need_positions = _has_positions; + if (!_raw_iter->readRange(&_block)) { + return false; + } + _cursor = 0; + _prox_cursor = 0; + return _block.doc_many_size_ > 0; + } -template -class EmptySegmentPosting final : public SegmentPostingsBase { -public: - EmptySegmentPosting() = default; + void _init_doc() { + if (!_raw_iter) { + throw Exception(doris::ErrorCode::INVALID_ARGUMENT, + "SegmentPostings requires a valid 
iterator"); + } + if (_refill()) { + _doc = (*_block.doc_many)[_cursor++]; + } else { + _doc = TERMINATED; + } + } - uint32_t advance() override { return TERMINATED; } - uint32_t seek(uint32_t) override { return TERMINATED; } - uint32_t doc() const override { return TERMINATED; } - uint32_t size_hint() const override { return 0; } + IterVariant _iter; + lucene::index::TermDocs* _raw_iter = nullptr; + uint32_t _doc = TERMINATED; + bool _enable_scoring = false; + bool _has_positions = false; - uint32_t freq() const override { return 1; } - uint32_t norm() const override { return 1; } + DocRange _block; + uint32_t _cursor = 0; + uint32_t _prox_cursor = 0; }; +using SegmentPostingsPtr = std::shared_ptr; + +inline SegmentPostingsPtr make_segment_postings(TermDocsPtr iter, bool enable_scoring = false) { + return std::make_shared(std::move(iter), enable_scoring); +} + +inline SegmentPostingsPtr make_segment_postings(TermPositionsPtr iter, + bool enable_scoring = false) { + return std::make_shared(std::move(iter), enable_scoring); +} + } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/size_hint.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/size_hint.h new file mode 100644 index 00000000000000..2931db3fae0b1b --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/size_hint.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +namespace doris::segment_v2::inverted_index::query_v2 { + +inline uint32_t estimate_intersection(const std::vector& docset_sizes, + uint32_t max_docs) { + if (max_docs == 0 || docset_sizes.empty()) { + return 0; + } + + double co_loc_factor = 1.3; + + auto intersection_estimate = static_cast(docset_sizes.front()); + double smallest_docset_size = intersection_estimate; + + for (size_t i = 1; i < docset_sizes.size(); ++i) { + co_loc_factor = std::max(co_loc_factor - 0.1, 1.0); + intersection_estimate *= + (static_cast(docset_sizes[i]) / static_cast(max_docs)) * + co_loc_factor; + smallest_docset_size = std::min(smallest_docset_size, static_cast(docset_sizes[i])); + } + + return static_cast(std::min(std::round(intersection_estimate), smallest_docset_size)); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h index 9099a71877dea7..e67621c2ab6679 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h @@ -30,7 +30,6 @@ namespace doris::segment_v2::inverted_index::query_v2 { -template class TermScorer final : public Scorer { public: TermScorer(SegmentPostingsPtr segment_postings, SimilarityPtr similarity, @@ -84,8 +83,6 @@ class TermScorer final : public Scorer { std::optional _null_bitmap; }; -using 
TS_Base = std::shared_ptr>>>; -using TS_NoScore = std::shared_ptr>>>; -using TS_Empty = std::shared_ptr>>>; +using TermScorerPtr = std::shared_ptr; } // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h index d532e9664cb9a9..893467f3845603 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_weight.h @@ -38,33 +38,20 @@ class TermWeight : public Weight { ScorerPtr scorer(const QueryExecutionContext& ctx, const std::string& binding_key) override { auto reader = lookup_reader(_field, ctx, binding_key); auto logical_field = logical_field_or_fallback(ctx, binding_key, _field); - auto make_scorer = [&](auto segment_postings) -> ScorerPtr { - using PostingsT = decltype(segment_postings); - return std::make_shared>(std::move(segment_postings), _similarity, - logical_field); - }; if (!reader) { - auto segment_postings = std::make_shared>(); - return make_scorer(std::move(segment_postings)); + return std::make_shared(); } auto t = make_term_ptr(_field.c_str(), _term.c_str()); auto iter = make_term_doc_ptr(reader.get(), t.get(), _enable_scoring, _context->io_ctx); - if (iter) { - if (_enable_scoring) { - auto segment_postings = - std::make_shared>(std::move(iter)); - return make_scorer(std::move(segment_postings)); - } - auto segment_postings = - std::make_shared>(std::move(iter)); - return make_scorer(std::move(segment_postings)); + return std::make_shared( + make_segment_postings(std::move(iter), _enable_scoring), _similarity, + logical_field); } - auto segment_postings = std::make_shared>(); - return make_scorer(std::move(segment_postings)); + return std::make_shared(); } private: diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.cpp 
b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.cpp new file mode 100644 index 00000000000000..b7377257515156 --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.cpp @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.h" + +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/score_combiner.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +static constexpr size_t HORIZON_NUM_TINYBITSETS = 64; +static constexpr uint32_t HORIZON = static_cast(64) * HORIZON_NUM_TINYBITSETS; + +template +ScorerPtr make_buffered_union(const std::vector& scorers, + ScoreCombinerPtrU score_combiner) { + std::vector non_empty_scorers; + for (const auto& docset : scorers) { + if (docset && docset->doc() != TERMINATED) { + non_empty_scorers.push_back(docset); + } + } + auto bitsets = std::vector(HORIZON_NUM_TINYBITSETS); + auto scores = std::vector(HORIZON); + std::ranges::generate(scores, [&score_combiner]() { return score_combiner->clone(); }); + + std::vector term_scorers; + term_scorers.reserve(non_empty_scorers.size()); + bool all_term_scorers = true; + for (const auto& scorer : non_empty_scorers) { + if (auto term_scorer = std::dynamic_pointer_cast(scorer)) { + term_scorers.push_back(term_scorer); + } else { + all_term_scorers = false; + break; + } + } + + if (all_term_scorers && !term_scorers.empty()) { + auto union_scorer = std::make_shared>( + std::move(term_scorers), bitsets, scores, HORIZON_NUM_TINYBITSETS, 0, 0); + + if (union_scorer->refill()) { + union_scorer->advance(); + } else { + union_scorer->_doc = TERMINATED; + } + + return union_scorer; + } + + auto union_scorer = std::make_shared>( + std::move(non_empty_scorers), std::move(bitsets), std::move(scores), + HORIZON_NUM_TINYBITSETS, 0, 0); + + if (union_scorer->refill()) { + union_scorer->advance(); + } else { + union_scorer->_doc = TERMINATED; + } + + return union_scorer; +} + +template +void unordered_drain_filter(std::vector& v, Predicate predicate) { + size_t i = 
0; + while (i < v.size()) { + if (predicate(v[i])) { + v[i] = std::move(v.back()); + v.pop_back(); + } else { + i++; + } + } +} + +template +inline bool refill_scorer_predicate(ScorerT& scorer, std::vector& bitsets, + std::vector& scores, uint32_t min_doc, + uint32_t horizon, const ScorerPtrT& scorer_ptr) { + while (true) { + uint32_t doc = scorer.doc(); + if (doc >= horizon) { + return false; + } + uint32_t delta = doc - min_doc; + bitsets[static_cast(delta / 64)].insert_mut(delta % 64); + if constexpr (!std::is_same_v) { + scores[static_cast(delta)]->update(scorer_ptr); + } + if (scorer.advance() == TERMINATED) { + return true; + } + } +} + +template +BufferedUnion::BufferedUnion(std::vector scorers, + std::vector bitsets, + std::vector scores, + size_t cursor, uint32_t offset, + uint32_t doc) + : _scorers(std::move(scorers)), + _bitsets(std::move(bitsets)), + _scores(std::move(scores)), + _cursor(cursor), + _offset(offset), + _doc(doc) {} + +template +bool BufferedUnion::refill() { + if (_scorers.empty()) { + return false; + } + uint32_t min_doc = TERMINATED; + for (const auto& ds : _scorers) { + min_doc = std::min(min_doc, ds->doc()); + } + if (min_doc == TERMINATED) { + return false; + } + _offset = min_doc; + _cursor = 0; + _doc = min_doc; + refill(_scorers, _bitsets, _scores, min_doc); + return true; +} + +template +void BufferedUnion::refill(std::vector& scorers, + std::vector& bitsets, + std::vector& scores, + uint32_t min_doc) { + uint32_t horizon = min_doc + HORIZON; + unordered_drain_filter(scorers, [&](const ScorerPtrT& scorer_ptr) -> bool { + return refill_scorer_predicate(*scorer_ptr, bitsets, scores, min_doc, horizon, scorer_ptr); + }); +} + +template +bool BufferedUnion::advance_buffered() { + while (_cursor < HORIZON_NUM_TINYBITSETS) { + auto& bitset = _bitsets[_cursor]; + if (!bitset.is_empty()) { + uint32_t val = bitset.pop_lowest_unchecked(); + uint32_t delta = val + (static_cast(_cursor) * 64); + _doc = _offset + delta; + if constexpr 
(!std::is_same_v) { + auto& score_combiner = _scores[static_cast(delta)]; + _score = score_combiner->score(); + score_combiner->clear(); + } + return true; + } + _cursor++; + } + return false; +} + +template +uint32_t BufferedUnion::advance() { + if (advance_buffered()) { + return _doc; + } + if (!refill()) { + _doc = TERMINATED; + return TERMINATED; + } + if (!advance_buffered()) { + return TERMINATED; + } + return _doc; +} + +template +uint32_t BufferedUnion::seek(uint32_t target) { + if (_doc >= target) { + return _doc; + } + uint32_t gap = target - _offset; + if (gap < HORIZON) { + size_t new_cursor = static_cast(gap) / 64; + for (size_t i = _cursor; i < new_cursor; ++i) { + _bitsets[i].clear(); + } + for (size_t i = _cursor * 64; i < new_cursor * 64; ++i) { + _scores[i]->clear(); + } + _cursor = new_cursor; + uint32_t current_doc = _doc; + while (current_doc < target) { + current_doc = advance(); + } + return current_doc; + } else { + for (auto& tinyset : _bitsets) { + tinyset.clear(); + } + for (auto& score_combiner : _scores) { + score_combiner->clear(); + } + unordered_drain_filter(_scorers, [target](auto& docset) { + if (docset->doc() < target) { + docset->seek(target); + } + return docset->doc() == TERMINATED; + }); + if (!refill()) { + _doc = TERMINATED; + return TERMINATED; + } + return advance(); + } +} + +template +uint32_t BufferedUnion::doc() const { + return _doc; +} + +template +uint32_t BufferedUnion::size_hint() const { + uint32_t max_hint = 0; + for (const auto& docset : _scorers) { + max_hint = std::max(max_hint, docset->size_hint()); + } + return max_hint; +} + +template +float BufferedUnion::score() { + return _score; +} + +template ScorerPtr make_buffered_union(const std::vector& scorers, + SumCombinerPtr score_combiner); +template ScorerPtr make_buffered_union(const std::vector& scorers, + DoNothingCombinerPtr score_combiner); + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git 
a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.h new file mode 100644 index 00000000000000..10b3a21436060c --- /dev/null +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.h @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" +#include "olap/rowset/segment_v2/inverted_index/util/tiny_set.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +template +ScorerPtr make_buffered_union(const std::vector& scorers, + ScoreCombinerPtrU score_combiner); + +template +class BufferedUnion : public Scorer { +public: + BufferedUnion(std::vector scorers, std::vector bitsets, + std::vector scores, size_t cursor, uint32_t offset, + uint32_t doc); + ~BufferedUnion() override = default; + + uint32_t advance() override; + uint32_t seek(uint32_t target) override; + uint32_t doc() const override; + uint32_t size_hint() const override; + + float score() override; + + template + friend ScorerPtr make_buffered_union(const std::vector& scorers, + ScoreCombinerPtrU score_combiner); + +private: + bool refill(); + void refill(std::vector& scorers, std::vector& bitsets, + std::vector& scores, uint32_t min_doc); + bool advance_buffered(); + + std::vector _scorers; + std::vector _bitsets; + std::vector _scores; + size_t _cursor = 0; + uint32_t _offset = 0; + uint32_t _doc = 0; + float _score = 0.0F; +}; + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/simple_union.cpp b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/simple_union.cpp index e997e262a5828e..44bca4479c4bb0 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/simple_union.cpp +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/simple_union.cpp @@ -138,6 +138,6 @@ void SimpleUnion::append_positions_with_offset(uint32_t offset, template class SimpleUnion; template class SimpleUnion; -template class SimpleUnion; +template class SimpleUnion; } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git 
a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/simple_union.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/simple_union.h index 362a07674bf6e8..ce68ed422b4ff2 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/simple_union.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/union/simple_union.h @@ -57,4 +57,9 @@ class SimpleUnion final : public Postings { uint32_t _doc; }; +template +auto make_simple_union(std::vector docsets) { + return SimpleUnion::create(std::move(docsets)); +} + } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h index ffd0bebb01b213..2b53284dbf6de1 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/query_v2/weight.h @@ -58,7 +58,6 @@ class Weight { virtual ~Weight() = default; virtual ScorerPtr scorer(const QueryExecutionContext& context) { return scorer(context, {}); } - virtual ScorerPtr scorer(const QueryExecutionContext& context, const std::string& binding_key) { (void)binding_key; return scorer(context); @@ -107,31 +106,32 @@ class Weight { return nullptr; } - TermPostingsPtr create_term_posting(lucene::index::IndexReader* reader, - const std::wstring& field, const std::string& term, - bool enable_scoring, const io::IOContext* io_ctx) const { + SegmentPostingsPtr create_term_posting(lucene::index::IndexReader* reader, + const std::wstring& field, const std::string& term, + bool enable_scoring, const io::IOContext* io_ctx) const { auto term_wstr = StringHelper::to_wstring(term); auto t = make_term_ptr(field.c_str(), term_wstr.c_str()); auto iter = make_term_doc_ptr(reader, t.get(), enable_scoring, io_ctx); if (iter) { - return std::make_shared>(std::move(iter)); + return make_segment_postings(std::move(iter), enable_scoring); } 
return nullptr; } - PositionPostingsPtr create_position_posting(lucene::index::IndexReader* reader, - const std::wstring& field, const std::string& term, - bool enable_scoring, - const io::IOContext* io_ctx) const { + SegmentPostingsPtr create_position_posting(lucene::index::IndexReader* reader, + const std::wstring& field, const std::string& term, + bool enable_scoring, + const io::IOContext* io_ctx) const { auto term_wstr = StringHelper::to_wstring(term); auto t = make_term_ptr(field.c_str(), term_wstr.c_str()); auto iter = make_term_positions_ptr(reader, t.get(), enable_scoring, io_ctx); if (iter) { - return std::make_shared>(std::move(iter)); + return make_segment_postings(std::move(iter), enable_scoring); } return nullptr; } }; + using WeightPtr = std::shared_ptr; } // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/src/olap/rowset/segment_v2/inverted_index/util/tiny_set.h b/be/src/olap/rowset/segment_v2/inverted_index/util/tiny_set.h index f33b1cb7875886..0564ccab87b682 100644 --- a/be/src/olap/rowset/segment_v2/inverted_index/util/tiny_set.h +++ b/be/src/olap/rowset/segment_v2/inverted_index/util/tiny_set.h @@ -17,43 +17,43 @@ #pragma once +#include #include -#include namespace doris::segment_v2::inverted_index { class TinySet { public: + TinySet() = default; explicit TinySet(uint64_t value) : _bits(value) {} ~TinySet() = default; - bool is_empty() const { return _bits == 0; } + TinySet(const TinySet&) = default; + TinySet& operator=(const TinySet&) = default; + TinySet(TinySet&&) = default; + TinySet& operator=(TinySet&&) = default; + + static TinySet empty() { return TinySet(0); } + static TinySet singleton(uint32_t el) { return TinySet(uint64_t(1) << el); } + bool is_empty() const { return _bits == 0; } + uint32_t len() const { return static_cast(std::popcount(_bits)); } void clear() { _bits = 0; } bool insert_mut(uint32_t el) { - if (el >= 64) { - return false; - } - uint64_t old_bits = _bits; - _bits |= (1ULL << el); - return 
old_bits != _bits; + uint64_t old = _bits; + _bits |= (uint64_t(1) << el); + return old != _bits; } - std::optional pop_lowest() { - if (is_empty()) { - return std::nullopt; - } - uint32_t lowest = std::countr_zero(_bits); - _bits ^= (1ULL << lowest); + uint32_t pop_lowest_unchecked() { + auto lowest = static_cast(std::countr_zero(_bits)); + _bits ^= (uint64_t(1) << lowest); return lowest; } - uint32_t len() const { return std::popcount(_bits); } - private: uint64_t _bits = 0; }; -using TinySetPtr = std::shared_ptr; } // namespace doris::segment_v2::inverted_index \ No newline at end of file diff --git a/be/src/vec/functions/function_search.cpp b/be/src/vec/functions/function_search.cpp index 4e69ab470a17d2..6fd7da39208f8f 100644 --- a/be/src/vec/functions/function_search.cpp +++ b/be/src/vec/functions/function_search.cpp @@ -37,8 +37,8 @@ #include "olap/rowset/segment_v2/inverted_index/analyzer/analyzer.h" #include "olap/rowset/segment_v2/inverted_index/query/query_helper.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_query.h" -#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query.h" -#include "olap/rowset/segment_v2/inverted_index/query_v2/operator.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/multi_phrase_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/phrase_query/phrase_query.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/regexp_query/regexp_query.h" @@ -426,7 +426,7 @@ Status FunctionSearch::build_query_recursive(const TSearchClause& clause, op = query_v2::OperatorType::OP_NOT; } - query_v2::BooleanQuery::Builder builder(op); + auto builder = create_operator_boolean_query_builder(op); if (clause.__isset.children) { for (const auto& child_clause : 
clause.children) { query_v2::QueryPtr child_query; @@ -438,11 +438,11 @@ Status FunctionSearch::build_query_recursive(const TSearchClause& clause, // - AND with empty bitmap → result is empty // - OR with empty bitmap → empty bitmap is ignored by OR logic // - NOT with empty bitmap → NOT(empty) = all rows (handled by BooleanQuery) - builder.add(child_query, std::move(child_binding_key)); + builder->add(child_query, std::move(child_binding_key)); } } - *out = builder.build(); + *out = builder->build(); return Status::OK(); } @@ -525,13 +525,13 @@ Status FunctionSearch::build_leaf_query(const TSearchClause& clause, return Status::OK(); } - query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_OR); + auto builder = create_operator_boolean_query_builder(query_v2::OperatorType::OP_OR); for (const auto& term_info : term_infos) { std::wstring term_wstr = StringHelper::to_wstring(term_info.get_single_term()); - builder.add(make_term_query(term_wstr), binding.binding_key); + builder->add(make_term_query(term_wstr), binding.binding_key); } - *out = builder.build(); + *out = builder->build(); return Status::OK(); } @@ -577,12 +577,13 @@ Status FunctionSearch::build_leaf_query(const TSearchClause& clause, std::wstring term_wstr = StringHelper::to_wstring(term_info.get_single_term()); *out = std::make_shared(context, field_wstr, term_wstr); } else { - query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_OR); + auto builder = + create_operator_boolean_query_builder(query_v2::OperatorType::OP_OR); for (const auto& term : term_info.get_multi_terms()) { std::wstring term_wstr = StringHelper::to_wstring(term); - builder.add(make_term_query(term_wstr), binding.binding_key); + builder->add(make_term_query(term_wstr), binding.binding_key); } - *out = builder.build(); + *out = builder->build(); } } else { if (QueryHelper::is_simple_phrase(phrase_term_infos)) { @@ -638,12 +639,12 @@ Status FunctionSearch::build_leaf_query(const TSearchClause& clause, return 
Status::OK(); } - query_v2::BooleanQuery::Builder builder(bool_type); + auto builder = create_operator_boolean_query_builder(bool_type); for (const auto& term_info : term_infos) { std::wstring term_wstr = StringHelper::to_wstring(term_info.get_single_term()); - builder.add(make_term_query(term_wstr), binding.binding_key); + builder->add(make_term_query(term_wstr), binding.binding_key); } - *out = builder.build(); + *out = builder->build(); return Status::OK(); } diff --git a/be/src/vec/functions/function_search.h b/be/src/vec/functions/function_search.h index b1ec1f638959fe..4b2f1fa2532e6e 100644 --- a/be/src/vec/functions/function_search.h +++ b/be/src/vec/functions/function_search.h @@ -27,7 +27,7 @@ #include "gen_cpp/Exprs_types.h" #include "olap/rowset/segment_v2/index_query_context.h" -#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_query.h" #include "vec/core/block.h" #include "vec/core/types.h" #include "vec/data_types/data_type.h" diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder_test.cpp new file mode 100644 index 00000000000000..c3728e96d6705c --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder_test.cpp @@ -0,0 +1,387 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder.h" + +#include + +#include +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator_boolean_query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +namespace { + +class MockQuery : public Query { +public: + MockQuery(int id = 0) : _id(id) {} + ~MockQuery() override = default; + + WeightPtr weight(bool enable_scoring) override { return nullptr; } + + int id() const { return _id; } + +private: + int _id; +}; + +using MockQueryPtr = std::shared_ptr; + +} // namespace + +class BooleanQueryBuilderTest : public ::testing::Test {}; + +TEST_F(BooleanQueryBuilderTest, OccurBooleanQueryBuilderEmpty) { + auto builder = create_occur_boolean_query_builder(); + ASSERT_NE(nullptr, builder); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); + + auto occur_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, occur_query); + EXPECT_TRUE(occur_query->clauses().empty()); +} + +TEST_F(BooleanQueryBuilderTest, OccurBooleanQueryBuilderSingleMust) { + auto builder = create_occur_boolean_query_builder(); + auto mock_query = std::make_shared(1); + + builder->add(mock_query, 
Occur::MUST); + auto query = builder->build(); + + auto occur_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, occur_query); + ASSERT_EQ(1u, occur_query->clauses().size()); + EXPECT_EQ(Occur::MUST, occur_query->clauses()[0].first); + EXPECT_EQ(mock_query, occur_query->clauses()[0].second); +} + +TEST_F(BooleanQueryBuilderTest, OccurBooleanQueryBuilderSingleShould) { + auto builder = create_occur_boolean_query_builder(); + auto mock_query = std::make_shared(2); + + builder->add(mock_query, Occur::SHOULD); + auto query = builder->build(); + + auto occur_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, occur_query); + ASSERT_EQ(1u, occur_query->clauses().size()); + EXPECT_EQ(Occur::SHOULD, occur_query->clauses()[0].first); +} + +TEST_F(BooleanQueryBuilderTest, OccurBooleanQueryBuilderSingleMustNot) { + auto builder = create_occur_boolean_query_builder(); + auto mock_query = std::make_shared(3); + + builder->add(mock_query, Occur::MUST_NOT); + auto query = builder->build(); + + auto occur_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, occur_query); + ASSERT_EQ(1u, occur_query->clauses().size()); + EXPECT_EQ(Occur::MUST_NOT, occur_query->clauses()[0].first); +} + +TEST_F(BooleanQueryBuilderTest, OccurBooleanQueryBuilderMultipleClauses) { + auto builder = create_occur_boolean_query_builder(); + auto query1 = std::make_shared(1); + auto query2 = std::make_shared(2); + auto query3 = std::make_shared(3); + + builder->add(query1, Occur::MUST); + builder->add(query2, Occur::SHOULD); + builder->add(query3, Occur::MUST_NOT); + + auto query = builder->build(); + auto occur_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, occur_query); + ASSERT_EQ(3u, occur_query->clauses().size()); + + EXPECT_EQ(Occur::MUST, occur_query->clauses()[0].first); + EXPECT_EQ(query1, occur_query->clauses()[0].second); + + EXPECT_EQ(Occur::SHOULD, occur_query->clauses()[1].first); + EXPECT_EQ(query2, occur_query->clauses()[1].second); + + 
EXPECT_EQ(Occur::MUST_NOT, occur_query->clauses()[2].first); + EXPECT_EQ(query3, occur_query->clauses()[2].second); +} + +TEST_F(BooleanQueryBuilderTest, OccurBooleanQueryBuilderMixedOccurs) { + auto builder = create_occur_boolean_query_builder(); + + builder->add(std::make_shared(1), Occur::MUST); + builder->add(std::make_shared(2), Occur::MUST); + builder->add(std::make_shared(3), Occur::SHOULD); + builder->add(std::make_shared(4), Occur::SHOULD); + builder->add(std::make_shared(5), Occur::MUST_NOT); + + auto query = builder->build(); + auto occur_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, occur_query); + ASSERT_EQ(5u, occur_query->clauses().size()); + + int must_count = 0; + int should_count = 0; + int must_not_count = 0; + for (const auto& [occur, q] : occur_query->clauses()) { + switch (occur) { + case Occur::MUST: + must_count++; + break; + case Occur::SHOULD: + should_count++; + break; + case Occur::MUST_NOT: + must_not_count++; + break; + } + } + EXPECT_EQ(2, must_count); + EXPECT_EQ(2, should_count); + EXPECT_EQ(1, must_not_count); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderEmpty) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_AND); + ASSERT_NE(nullptr, builder); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); + + auto op_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, op_query); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderOpAnd) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_AND); + auto query1 = std::make_shared(1); + auto query2 = std::make_shared(2); + + builder->add(query1); + builder->add(query2); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderOpOr) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_OR); + auto query1 = std::make_shared(1); + auto query2 = std::make_shared(2); + + builder->add(query1); 
+ builder->add(query2); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderOpNot) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_NOT); + auto query1 = std::make_shared(1); + auto query2 = std::make_shared(2); + + builder->add(query1); + builder->add(query2); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderWithBindingKeys) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_AND); + auto query1 = std::make_shared(1); + auto query2 = std::make_shared(2); + auto query3 = std::make_shared(3); + + builder->add(query1, "field1"); + builder->add(query2, "field2"); + builder->add(query3, "field3"); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderMixedBindingKeys) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_OR); + + builder->add(std::make_shared(1), "key1"); + builder->add(std::make_shared(2)); + builder->add(std::make_shared(3), "key3"); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderEmptyBindingKey) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_AND); + + builder->add(std::make_shared(1), ""); + builder->add(std::make_shared(2), std::string {}); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderSingleQuery) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_AND); + builder->add(std::make_shared(1), "single_key"); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderManyQueries) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_OR); + + for (int i = 0; i < 
100; ++i) { + builder->add(std::make_shared(i), "key_" + std::to_string(i)); + } + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, CreateOccurBooleanQueryBuilderFunction) { + auto builder1 = create_occur_boolean_query_builder(); + auto builder2 = create_occur_boolean_query_builder(); + + ASSERT_NE(nullptr, builder1); + ASSERT_NE(nullptr, builder2); + EXPECT_NE(builder1, builder2); +} + +TEST_F(BooleanQueryBuilderTest, CreateOperatorBooleanQueryBuilderFunction) { + auto builder_and = create_operator_boolean_query_builder(OperatorType::OP_AND); + auto builder_or = create_operator_boolean_query_builder(OperatorType::OP_OR); + auto builder_not = create_operator_boolean_query_builder(OperatorType::OP_NOT); + + ASSERT_NE(nullptr, builder_and); + ASSERT_NE(nullptr, builder_or); + ASSERT_NE(nullptr, builder_not); +} + +TEST_F(BooleanQueryBuilderTest, OccurBooleanQueryBuilderAddOrder) { + auto builder = create_occur_boolean_query_builder(); + std::vector queries; + std::vector occurs {Occur::MUST, Occur::SHOULD, Occur::MUST_NOT, Occur::MUST, + Occur::SHOULD}; + + for (size_t i = 0; i < occurs.size(); ++i) { + auto q = std::make_shared(static_cast(i)); + queries.push_back(q); + builder->add(q, occurs[i]); + } + + auto query = builder->build(); + auto occur_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, occur_query); + ASSERT_EQ(occurs.size(), occur_query->clauses().size()); + + for (size_t i = 0; i < occurs.size(); ++i) { + EXPECT_EQ(occurs[i], occur_query->clauses()[i].first); + auto mock = std::dynamic_pointer_cast(occur_query->clauses()[i].second); + ASSERT_NE(nullptr, mock); + EXPECT_EQ(static_cast(i), mock->id()); + } +} + +TEST_F(BooleanQueryBuilderTest, OccurBooleanQueryBuilderSameQueryMultipleTimes) { + auto builder = create_occur_boolean_query_builder(); + auto shared_query = std::make_shared(42); + + builder->add(shared_query, Occur::MUST); + builder->add(shared_query, Occur::SHOULD); + 
builder->add(shared_query, Occur::MUST_NOT); + + auto query = builder->build(); + auto occur_query = std::dynamic_pointer_cast(query); + ASSERT_NE(nullptr, occur_query); + ASSERT_EQ(3u, occur_query->clauses().size()); + + for (const auto& [occur, q] : occur_query->clauses()) { + EXPECT_EQ(shared_query, q); + } +} + +TEST_F(BooleanQueryBuilderTest, OperatorBooleanQueryBuilderSameQueryMultipleTimes) { + auto builder = create_operator_boolean_query_builder(OperatorType::OP_AND); + auto shared_query = std::make_shared(99); + + builder->add(shared_query, "key1"); + builder->add(shared_query, "key2"); + builder->add(shared_query, "key3"); + + auto query = builder->build(); + ASSERT_NE(nullptr, query); +} + +TEST_F(BooleanQueryBuilderTest, NestedOccurBooleanQueries) { + auto inner_builder = create_occur_boolean_query_builder(); + inner_builder->add(std::make_shared(1), Occur::MUST); + inner_builder->add(std::make_shared(2), Occur::SHOULD); + auto inner_query = inner_builder->build(); + + auto outer_builder = create_occur_boolean_query_builder(); + outer_builder->add(inner_query, Occur::MUST); + outer_builder->add(std::make_shared(3), Occur::MUST_NOT); + + auto outer_query = outer_builder->build(); + auto occur_query = std::dynamic_pointer_cast(outer_query); + ASSERT_NE(nullptr, occur_query); + ASSERT_EQ(2u, occur_query->clauses().size()); + EXPECT_EQ(inner_query, occur_query->clauses()[0].second); +} + +TEST_F(BooleanQueryBuilderTest, NestedOperatorBooleanQueries) { + auto inner_builder = create_operator_boolean_query_builder(OperatorType::OP_OR); + inner_builder->add(std::make_shared(1), "inner1"); + inner_builder->add(std::make_shared(2), "inner2"); + auto inner_query = inner_builder->build(); + + auto outer_builder = create_operator_boolean_query_builder(OperatorType::OP_AND); + outer_builder->add(inner_query, "nested"); + outer_builder->add(std::make_shared(3), "outer1"); + + auto outer_query = outer_builder->build(); + ASSERT_NE(nullptr, outer_query); +} + 
+TEST_F(BooleanQueryBuilderTest, MixedNestedQueries) { + auto occur_builder = create_occur_boolean_query_builder(); + occur_builder->add(std::make_shared(1), Occur::MUST); + occur_builder->add(std::make_shared(2), Occur::SHOULD); + auto occur_query = occur_builder->build(); + + auto operator_builder = create_operator_boolean_query_builder(OperatorType::OP_AND); + operator_builder->add(occur_query, "occur_nested"); + operator_builder->add(std::make_shared(3), "simple"); + + auto final_query = operator_builder->build(); + ASSERT_NE(nullptr, final_query); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query_test.cpp index e2a3cb479fefc0..1a12a231004ded 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/boolean_query_test.cpp @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query.h" - #include #include @@ -29,7 +27,8 @@ #include "olap/rowset/segment_v2/index_query_context.h" #include "olap/rowset/segment_v2/inverted_index/analyzer/custom_analyzer.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/bit_set_query/bit_set_query.h" -#include "olap/rowset/segment_v2/inverted_index/query_v2/operator.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/boolean_query_builder.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/operator.h" #include "olap/rowset/segment_v2/inverted_index/query_v2/term_query/term_query.h" #include "olap/rowset/segment_v2/inverted_index/util/string_helper.h" @@ -129,9 +128,9 @@ static Status boolean_query_search( context->collection_statistics = std::make_shared(); context->collection_similarity = std::make_shared(); - query_v2::BooleanQuery::Builder builder(op); + query_v2::OperatorBooleanQueryBuilder builder(op); { - query_v2::BooleanQuery::Builder builder_child(query_v2::OperatorType::OP_AND); + query_v2::OperatorBooleanQueryBuilder builder_child(query_v2::OperatorType::OP_AND); for (const auto& term : terms.first) { std::wstring t = StringHelper::to_wstring(term); auto clause = std::make_shared(context, field, t); @@ -141,7 +140,7 @@ static Status boolean_query_search( builder.add(boolean_query, binding_key); } { - query_v2::BooleanQuery::Builder builder_child(query_v2::OperatorType::OP_OR); + query_v2::OperatorBooleanQueryBuilder builder_child(query_v2::OperatorType::OP_OR); for (const auto& term : terms.second) { std::wstring t = StringHelper::to_wstring(term); auto clause = std::make_shared(context, field, t); @@ -279,7 +278,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_not_operation) { std::string binding_key = std::string("name1") + "#" + std::to_string(static_cast(query_type)); - query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_NOT); + 
query_v2::OperatorBooleanQueryBuilder builder(query_v2::OperatorType::OP_NOT); builder.add(std::make_shared(context, field, StringHelper::to_wstring("apple")), binding_key); @@ -317,7 +316,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_or_with_not_operation) { auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); ASSERT_TRUE(reader_holder != nullptr); - query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_OR); + query_v2::OperatorBooleanQueryBuilder builder(query_v2::OperatorType::OP_OR); auto query_type = segment_v2::InvertedIndexQueryType::EQUAL_QUERY; std::string include_key = std::string("name1") + "#" + std::to_string(static_cast(query_type)); @@ -325,7 +324,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_or_with_not_operation) { StringHelper::to_wstring("apple")), include_key); { - query_v2::BooleanQuery::Builder not_builder(query_v2::OperatorType::OP_NOT); + query_v2::OperatorBooleanQueryBuilder not_builder(query_v2::OperatorType::OP_NOT); not_builder.add(std::make_shared(context, field, StringHelper::to_wstring("banana")), include_key); @@ -381,19 +380,19 @@ TEST_F(BooleanQueryTest, test_boolean_query_scoring_or) { auto reader_holder = make_shared_reader(lucene::index::IndexReader::open(dir, true)); ASSERT_TRUE(reader_holder != nullptr); - query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_OR); + query_v2::OperatorBooleanQueryBuilder builder(query_v2::OperatorType::OP_OR); auto query_type = segment_v2::InvertedIndexQueryType::EQUAL_QUERY; std::string binding_key = std::string("name1") + "#" + std::to_string(static_cast(query_type)); { - query_v2::BooleanQuery::Builder builder_child(query_v2::OperatorType::OP_AND); + query_v2::OperatorBooleanQueryBuilder builder_child(query_v2::OperatorType::OP_AND); auto clause = std::make_shared(context, field, StringHelper::to_wstring("apple")); builder_child.add(clause, binding_key); builder.add(builder_child.build(), binding_key); } { - 
query_v2::BooleanQuery::Builder builder_child(query_v2::OperatorType::OP_OR); + query_v2::OperatorBooleanQueryBuilder builder_child(query_v2::OperatorType::OP_OR); auto clause = std::make_shared(context, field, StringHelper::to_wstring("kiwi")); builder_child.add(clause, binding_key); @@ -466,7 +465,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_cross_fields_with_composite_reader) context->collection_similarity = std::make_shared(); { - query_v2::BooleanQuery::Builder b(query_v2::OperatorType::OP_AND); + query_v2::OperatorBooleanQueryBuilder b(query_v2::OperatorType::OP_AND); b.add(std::make_shared(context, wfield1, StringHelper::to_wstring("apple")), binding1); @@ -487,7 +486,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_cross_fields_with_composite_reader) } { - query_v2::BooleanQuery::Builder b(query_v2::OperatorType::OP_OR); + query_v2::OperatorBooleanQueryBuilder b(query_v2::OperatorType::OP_OR); b.add(std::make_shared(context, wfield1, StringHelper::to_wstring("apple")), binding1); @@ -542,7 +541,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_bitmap_and_term) { } } - query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_AND); + query_v2::OperatorBooleanQueryBuilder builder(query_v2::OperatorType::OP_AND); builder.add(std::make_shared(context, field, StringHelper::to_wstring("apple"))); builder.add(std::make_shared(bm)); @@ -593,7 +592,7 @@ TEST_F(BooleanQueryTest, test_boolean_query_bitmap_or_term) { } } - query_v2::BooleanQuery::Builder builder(query_v2::OperatorType::OP_OR); + query_v2::OperatorBooleanQueryBuilder builder(query_v2::OperatorType::OP_OR); builder.add(std::make_shared(context, field, StringHelper::to_wstring("apple"))); builder.add(std::make_shared(bm)); diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/buffered_union_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/buffered_union_test.cpp new file mode 100644 index 00000000000000..bcd2b8b7fd85b0 --- /dev/null +++ 
b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/buffered_union_test.cpp @@ -0,0 +1,684 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/union/buffered_union.h" + +#include + +#include +#include +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/doc_set.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/score_combiner.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { +namespace { + +// HORIZON constant used in buffered_union.cpp +constexpr uint32_t HORIZON = 64 * 64; // 4096 + +// Mock Scorer for testing +class MockScorer : public Scorer { +public: + MockScorer(std::vector docs, std::vector scores = {}, + uint32_t size_hint_val = 0) + : _docs(std::move(docs)), _scores(std::move(scores)), _size_hint_val(size_hint_val) { + if (_docs.empty()) { + _current_doc = TERMINATED; + } else { + std::ranges::sort(_docs); + _current_doc = _docs[0]; + } + if (_scores.size() != _docs.size()) { + _scores.resize(_docs.size(), 1.0F); + } + if (_size_hint_val == 0) { + _size_hint_val = static_cast(_docs.size()); + } + } + + uint32_t advance() override 
{ + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + ++_index; + if (_index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + _current_doc = _docs[_index]; + return _current_doc; + } + + uint32_t seek(uint32_t target) override { + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + if (_current_doc >= target) { + return _current_doc; + } + auto it = std::lower_bound(_docs.begin() + static_cast(_index), _docs.end(), + target); + if (it == _docs.end()) { + _index = _docs.size(); + _current_doc = TERMINATED; + return TERMINATED; + } + _index = static_cast(it - _docs.begin()); + _current_doc = *it; + return _current_doc; + } + + uint32_t doc() const override { return _current_doc; } + + uint32_t size_hint() const override { return _size_hint_val; } + + float score() override { + if (_index >= _scores.size()) { + return 0.0F; + } + return _scores[_index]; + } + +private: + std::vector _docs; + std::vector _scores; + size_t _index = 0; + uint32_t _current_doc = TERMINATED; + uint32_t _size_hint_val = 0; +}; + +using MockScorerPtr = std::shared_ptr; + +} // anonymous namespace + +class BufferedUnionTest : public testing::Test { +public: + void SetUp() override {} + void TearDown() override {} + + // Helper to collect all docs from a scorer + std::vector collect_all_docs(ScorerPtr scorer) { + std::vector result; + while (scorer->doc() != TERMINATED) { + result.push_back(scorer->doc()); + if (scorer->advance() == TERMINATED) { + break; + } + } + return result; + } + + // Helper to collect docs and scores + std::pair, std::vector> collect_docs_and_scores(ScorerPtr scorer) { + std::vector docs; + std::vector scores; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scores.push_back(scorer->score()); + if (scorer->advance() == TERMINATED) { + break; + } + } + return {docs, scores}; + } +}; + +TEST_F(BufferedUnionTest, 
MakeBufferedUnionWithEmptyScorers) { + std::vector scorers; + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), TERMINATED); + EXPECT_EQ(union_scorer->advance(), TERMINATED); +} + +TEST_F(BufferedUnionTest, MakeBufferedUnionWithNullScorers) { + std::vector scorers; + scorers.push_back(nullptr); + scorers.push_back(nullptr); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), TERMINATED); +} + +TEST_F(BufferedUnionTest, MakeBufferedUnionWithAllTerminatedScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {})); + scorers.push_back(std::make_shared(std::vector {})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), TERMINATED); +} + +TEST_F(BufferedUnionTest, MakeBufferedUnionWithMixedNullAndValid) { + std::vector scorers; + scorers.push_back(nullptr); + scorers.push_back(std::make_shared(std::vector {1, 5, 10})); + scorers.push_back(nullptr); + scorers.push_back(std::make_shared(std::vector {})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto docs = collect_all_docs(union_scorer); + EXPECT_EQ(docs, (std::vector {1, 5, 10})); +} + +TEST_F(BufferedUnionTest, MakeBufferedUnionWithDoNothingCombiner) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 3, 5})); + scorers.push_back(std::make_shared(std::vector {2, 4, 6})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto docs = collect_all_docs(union_scorer); + EXPECT_EQ(docs, (std::vector {1, 2, 3, 4, 5, 6})); + + EXPECT_FLOAT_EQ(union_scorer->score(), 0.0F); +} + +TEST_F(BufferedUnionTest, BasicAdvanceSingleScorer) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 
5, 10, 15, 20})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 1); + EXPECT_EQ(union_scorer->advance(), 5); + EXPECT_EQ(union_scorer->advance(), 10); + EXPECT_EQ(union_scorer->advance(), 15); + EXPECT_EQ(union_scorer->advance(), 20); + EXPECT_EQ(union_scorer->advance(), TERMINATED); +} + +TEST_F(BufferedUnionTest, BasicAdvanceTwoDisjointScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 3, 5})); + scorers.push_back(std::make_shared(std::vector {2, 4, 6})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto docs = collect_all_docs(union_scorer); + EXPECT_EQ(docs, (std::vector {1, 2, 3, 4, 5, 6})); +} + +TEST_F(BufferedUnionTest, BasicAdvanceTwoOverlappingScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 3, 5, 7})); + scorers.push_back(std::make_shared(std::vector {3, 5, 9})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto docs = collect_all_docs(union_scorer); + EXPECT_EQ(docs, (std::vector {1, 3, 5, 7, 9})); +} + +TEST_F(BufferedUnionTest, BasicAdvanceMultipleScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 10, 20})); + scorers.push_back(std::make_shared(std::vector {5, 15, 25})); + scorers.push_back(std::make_shared(std::vector {3, 13, 23})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto docs = collect_all_docs(union_scorer); + EXPECT_EQ(docs, (std::vector {1, 3, 5, 10, 13, 15, 20, 23, 25})); +} + +TEST_F(BufferedUnionTest, BasicAdvanceIdenticalScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10})); + scorers.push_back(std::make_shared(std::vector {1, 5, 10})); + + auto combiner = std::make_shared(); + auto union_scorer = 
make_buffered_union(scorers, combiner); + + auto docs = collect_all_docs(union_scorer); + EXPECT_EQ(docs, (std::vector {1, 5, 10})); +} + +TEST_F(BufferedUnionTest, SeekToExactDoc) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10, 15})); + scorers.push_back(std::make_shared(std::vector {2, 6, 12, 16})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 1); + EXPECT_EQ(union_scorer->seek(10), 10); + EXPECT_EQ(union_scorer->doc(), 10); +} + +TEST_F(BufferedUnionTest, SeekToNonExistentDoc) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10, 15})); + scorers.push_back(std::make_shared(std::vector {2, 6, 12, 16})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 1); + EXPECT_EQ(union_scorer->seek(7), 10); + EXPECT_EQ(union_scorer->doc(), 10); +} + +TEST_F(BufferedUnionTest, SeekBeyondAllDocs) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10})); + scorers.push_back(std::make_shared(std::vector {2, 6, 12})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->seek(100), TERMINATED); + EXPECT_EQ(union_scorer->doc(), TERMINATED); +} + +TEST_F(BufferedUnionTest, SeekToCurrentDoc) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10, 15})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 1); + EXPECT_EQ(union_scorer->advance(), 5); + EXPECT_EQ(union_scorer->doc(), 5); + + EXPECT_EQ(union_scorer->seek(5), 5); + EXPECT_EQ(union_scorer->doc(), 5); +} + +TEST_F(BufferedUnionTest, SeekBackwards) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10, 15})); + + auto combiner = 
std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->seek(10), 10); + EXPECT_EQ(union_scorer->doc(), 10); + + EXPECT_EQ(union_scorer->seek(5), 10); + EXPECT_EQ(union_scorer->doc(), 10); +} + +TEST_F(BufferedUnionTest, SeekThenAdvance) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10, 15, 20})); + scorers.push_back(std::make_shared(std::vector {2, 6, 12, 16, 22})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->seek(8), 10); + EXPECT_EQ(union_scorer->advance(), 12); + EXPECT_EQ(union_scorer->advance(), 15); + EXPECT_EQ(union_scorer->advance(), 16); +} + +TEST_F(BufferedUnionTest, SeekWithinBuffer) { + std::vector scorers; + std::vector docs; + for (uint32_t i = 0; i < 100; i += 5) { + docs.push_back(i); + } + scorers.push_back(std::make_shared(docs)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 0); + EXPECT_EQ(union_scorer->seek(50), 50); + EXPECT_EQ(union_scorer->doc(), 50); + EXPECT_EQ(union_scorer->advance(), 55); +} + +TEST_F(BufferedUnionTest, SeekOutsideBuffer) { + std::vector scorers; + std::vector docs = {0, 100, 5000, 5100, 10000}; + scorers.push_back(std::make_shared(docs)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 0); + EXPECT_EQ(union_scorer->seek(5000), 5000); + EXPECT_EQ(union_scorer->doc(), 5000); + EXPECT_EQ(union_scorer->advance(), 5100); +} + +TEST_F(BufferedUnionTest, SeekFarOutsideBufferToTerminated) { + std::vector scorers; + std::vector docs = {0, 100, 200}; + scorers.push_back(std::make_shared(docs)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->seek(100000), TERMINATED); + 
EXPECT_EQ(union_scorer->doc(), TERMINATED); +} + +TEST_F(BufferedUnionTest, CrossHorizonBoundary) { + std::vector scorers; + std::vector docs; + for (uint32_t i = 0; i <= HORIZON + 500; i += 100) { + docs.push_back(i); + } + scorers.push_back(std::make_shared(docs)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto collected = collect_all_docs(union_scorer); + EXPECT_EQ(collected, docs); +} + +TEST_F(BufferedUnionTest, DocAtExactHorizonBoundary) { + std::vector scorers; + std::vector docs = {0, HORIZON - 1, HORIZON, HORIZON + 1, HORIZON * 2}; + scorers.push_back(std::make_shared(docs)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto collected = collect_all_docs(union_scorer); + EXPECT_EQ(collected, docs); +} + +TEST_F(BufferedUnionTest, MultipleScorersAcrossHorizon) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {0, 1000, HORIZON + 100})); + scorers.push_back( + std::make_shared(std::vector {500, HORIZON - 1, HORIZON + 500})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto collected = collect_all_docs(union_scorer); + std::vector expected = {0, 500, 1000, HORIZON - 1, HORIZON + 100, HORIZON + 500}; + EXPECT_EQ(collected, expected); +} + +TEST_F(BufferedUnionTest, ScoringWithSumCombiner) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 3, 5}, + std::vector {1.0F, 2.0F, 3.0F})); + scorers.push_back(std::make_shared(std::vector {2, 3, 6}, + std::vector {0.5F, 1.5F, 2.5F})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto [docs, scores] = collect_docs_and_scores(union_scorer); + + std::vector expected_docs = {1, 2, 3, 5, 6}; + std::vector expected_scores = {1.0F, 0.5F, 3.5F, 3.0F, 2.5F}; + + EXPECT_EQ(docs, expected_docs); + ASSERT_EQ(scores.size(), 
expected_scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + EXPECT_FLOAT_EQ(scores[i], expected_scores[i]); + } +} + +TEST_F(BufferedUnionTest, ScoringThreeOverlappingScorers) { + std::vector scorers; + scorers.push_back( + std::make_shared(std::vector {5}, std::vector {1.0F})); + scorers.push_back( + std::make_shared(std::vector {5}, std::vector {2.0F})); + scorers.push_back( + std::make_shared(std::vector {5}, std::vector {3.0F})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 5); + EXPECT_FLOAT_EQ(union_scorer->score(), 6.0F); +} + +TEST_F(BufferedUnionTest, ScoringWithDoNothingCombiner) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 3}, + std::vector {5.0F, 10.0F})); + scorers.push_back(std::make_shared(std::vector {2, 3}, + std::vector {3.0F, 7.0F})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + while (union_scorer->doc() != TERMINATED) { + EXPECT_FLOAT_EQ(union_scorer->score(), 0.0F); + union_scorer->advance(); + } +} + +TEST_F(BufferedUnionTest, SizeHintReturnsMaxOfScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, HORIZON + 100}, + std::vector {}, 10)); + scorers.push_back(std::make_shared(std::vector {2, HORIZON + 200}, + std::vector {}, 20)); + scorers.push_back(std::make_shared(std::vector {3, HORIZON + 300}, + std::vector {}, 15)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->size_hint(), 20); +} + +TEST_F(BufferedUnionTest, SizeHintSingleScorer) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, HORIZON + 100}, + std::vector {}, 100)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->size_hint(), 100); +} + +TEST_F(BufferedUnionTest, 
SingleDocScorer) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {42})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 42); + EXPECT_EQ(union_scorer->advance(), TERMINATED); +} + +TEST_F(BufferedUnionTest, LargeDocIds) { + std::vector scorers; + std::vector docs = {1000000, 2000000, 3000000}; + scorers.push_back(std::make_shared(docs)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto collected = collect_all_docs(union_scorer); + EXPECT_EQ(collected, docs); +} + +TEST_F(BufferedUnionTest, ConsecutiveDocIds) { + std::vector scorers; + std::vector docs; + for (uint32_t i = 0; i < 100; ++i) { + docs.push_back(i); + } + scorers.push_back(std::make_shared(docs)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + auto collected = collect_all_docs(union_scorer); + EXPECT_EQ(collected, docs); +} + +TEST_F(BufferedUnionTest, ManyScorersWithSingleDoc) { + std::vector scorers; + for (uint32_t i = 0; i < 20; ++i) { + scorers.push_back(std::make_shared(std::vector {i * 10})); + } + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + std::vector expected; + for (uint32_t i = 0; i < 20; ++i) { + expected.push_back(i * 10); + } + + auto collected = collect_all_docs(union_scorer); + EXPECT_EQ(collected, expected); +} + +TEST_F(BufferedUnionTest, AllScorersPointToSameDoc) { + std::vector scorers; + for (int i = 0; i < 10; ++i) { + scorers.push_back(std::make_shared(std::vector {100}, + std::vector {1.0F})); + } + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 100); + EXPECT_FLOAT_EQ(union_scorer->score(), 10.0F); + EXPECT_EQ(union_scorer->advance(), TERMINATED); +} + +TEST_F(BufferedUnionTest, 
RefillAfterAllScorersExhausted) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 2})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 1); + EXPECT_EQ(union_scorer->advance(), 2); + EXPECT_EQ(union_scorer->advance(), TERMINATED); + EXPECT_EQ(union_scorer->advance(), TERMINATED); + EXPECT_EQ(union_scorer->doc(), TERMINATED); +} + +TEST_F(BufferedUnionTest, LargeNumberOfDocs) { + std::vector scorers; + std::vector docs; + for (uint32_t i = 0; i < 10000; ++i) { + docs.push_back(i * 2); + } + scorers.push_back(std::make_shared(docs)); + + std::vector docs2; + for (uint32_t i = 0; i < 10000; ++i) { + docs2.push_back(i * 2 + 1); + } + scorers.push_back(std::make_shared(docs2)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + uint32_t expected_doc = 0; + while (union_scorer->doc() != TERMINATED) { + EXPECT_EQ(union_scorer->doc(), expected_doc); + union_scorer->advance(); + expected_doc++; + } + EXPECT_EQ(expected_doc, 20000u); +} + +TEST_F(BufferedUnionTest, ManySmallScorers) { + std::vector scorers; + for (uint32_t s = 0; s < 100; ++s) { + std::vector docs; + for (uint32_t d = 0; d < 10; ++d) { + docs.push_back(s * 100 + d * 10); + } + scorers.push_back(std::make_shared(docs)); + } + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + uint32_t count = 0; + while (union_scorer->doc() != TERMINATED) { + count++; + union_scorer->advance(); + } + + EXPECT_EQ(count, 1000u); +} + +TEST_F(BufferedUnionTest, ComplexSeekAdvanceSequence) { + std::vector scorers; + std::vector docs; + for (uint32_t i = 0; i < 200; i += 2) { + docs.push_back(i); + } + scorers.push_back(std::make_shared(docs)); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + EXPECT_EQ(union_scorer->doc(), 0); + 
EXPECT_EQ(union_scorer->advance(), 2); + EXPECT_EQ(union_scorer->seek(50), 50); + EXPECT_EQ(union_scorer->advance(), 52); + EXPECT_EQ(union_scorer->seek(100), 100); + EXPECT_EQ(union_scorer->seek(99), 100); + EXPECT_EQ(union_scorer->advance(), 102); + EXPECT_EQ(union_scorer->seek(1000), TERMINATED); +} + +TEST_F(BufferedUnionTest, DocReturnsCorrectValue) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {10, 20, 30})); + scorers.push_back(std::make_shared(std::vector {15, 25, 35})); + + auto combiner = std::make_shared(); + auto union_scorer = make_buffered_union(scorers, combiner); + + std::vector expected = {10, 15, 20, 25, 30, 35}; + for (uint32_t exp_doc : expected) { + EXPECT_EQ(union_scorer->doc(), exp_doc); + union_scorer->advance(); + } + EXPECT_EQ(union_scorer->doc(), TERMINATED); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer_test.cpp new file mode 100644 index 00000000000000..7b2285457d49aa --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer_test.cpp @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/disjunction_scorer.h" + +#include + +#include +#include +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/score_combiner.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +namespace { + +class VectorScorer final : public Scorer { +public: + VectorScorer(std::vector docs, std::vector scores = {}, + uint32_t size_hint_val = 0) + : _docs(std::move(docs)), _scores(std::move(scores)), _size_hint_val(size_hint_val) { + if (_docs.empty()) { + _current_doc = TERMINATED; + } else { + std::sort(_docs.begin(), _docs.end()); + _current_doc = _docs[0]; + } + if (_scores.size() != _docs.size()) { + _scores.resize(_docs.size(), 1.0F); + } + if (_size_hint_val == 0) { + _size_hint_val = static_cast(_docs.size()); + } + } + + uint32_t advance() override { + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + ++_index; + if (_index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + _current_doc = _docs[_index]; + return _current_doc; + } + + uint32_t seek(uint32_t target) override { + while (_current_doc < target && _current_doc != TERMINATED) { + advance(); + } + return _current_doc; + } + + uint32_t doc() const override { return _current_doc; } + uint32_t size_hint() const override { return _size_hint_val; } + + float score() override { + if (_index >= _scores.size()) { + return 0.0F; + } + return _scores[_index]; + } + +private: + std::vector _docs; + std::vector _scores; + size_t _index = 0; + uint32_t _current_doc = TERMINATED; + uint32_t _size_hint_val = 0; +}; + +std::vector compute_expected(const std::vector>& arrays, + size_t pass_line) { + std::map counts; + for (const auto& array : arrays) { + for (uint32_t element : array) 
{ + counts[element]++; + } + } + std::vector result; + for (const auto& [element, count] : counts) { + if (count >= pass_line) { + result.push_back(element); + } + } + return result; +} + +} // namespace + +class DisjunctionScorerTest : public ::testing::Test {}; + +TEST_F(DisjunctionScorerTest, BasicDisjunctionMinMatch2) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 3333, 100000000})); + scorers.push_back(std::make_shared(std::vector {1, 2, 100000000})); + scorers.push_back(std::make_shared(std::vector {1, 2, 100000000})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {1, 2, 100000000}; + EXPECT_EQ(expected, docs); +} + +TEST_F(DisjunctionScorerTest, BasicDisjunctionMinMatch3) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 3333, 100000000})); + scorers.push_back(std::make_shared(std::vector {1, 2, 100000000})); + scorers.push_back(std::make_shared(std::vector {1, 2, 100000000})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 3); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {1, 100000000}; + EXPECT_EQ(expected, docs); +} + +TEST_F(DisjunctionScorerTest, NoIntersection) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {8})); + scorers.push_back(std::make_shared(std::vector {3, 4, 0xC0FFEE})); + scorers.push_back(std::make_shared(std::vector {1, 2, 100000000})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + 
+TEST_F(DisjunctionScorerTest, ScoreCalculation) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 2}, + std::vector {1.0F, 1.0F})); + scorers.push_back(std::make_shared(std::vector {1, 3}, + std::vector {1.0F, 1.0F})); + scorers.push_back(std::make_shared(std::vector {1, 4}, + std::vector {1.0F, 1.0F})); + scorers.push_back(std::make_shared(std::vector {1, 2}, + std::vector {1.0F, 1.0F})); + scorers.push_back(std::make_shared(std::vector {1, 2}, + std::vector {1.0F, 1.0F})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 3); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_FLOAT_EQ(5.0F, scorer->score()); + + EXPECT_EQ(2u, scorer->advance()); + EXPECT_FLOAT_EQ(3.0F, scorer->score()); + + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(DisjunctionScorerTest, ScoreCalculationCornerCase) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 2}, + std::vector {1.0F, 1.0F})); + scorers.push_back(std::make_shared(std::vector {1, 3}, + std::vector {1.0F, 1.0F})); + scorers.push_back(std::make_shared(std::vector {1, 3}, + std::vector {1.0F, 1.0F})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_FLOAT_EQ(3.0F, scorer->score()); + + EXPECT_EQ(3u, scorer->advance()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(DisjunctionScorerTest, EmptyScorers) { + std::vector scorers; + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(DisjunctionScorerTest, MinMatchExceedsScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 2, 3})); + scorers.push_back(std::make_shared(std::vector {1, 2, 3})); + 
+ auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 5); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(DisjunctionScorerTest, AllEmptyScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {})); + scorers.push_back(std::make_shared(std::vector {})); + scorers.push_back(std::make_shared(std::vector {})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(DisjunctionScorerTest, SeekFunctionality) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10, 20})); + scorers.push_back(std::make_shared(std::vector {1, 5, 15, 20})); + scorers.push_back(std::make_shared(std::vector {1, 5, 10, 20})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_EQ(10u, scorer->seek(8)); + EXPECT_EQ(20u, scorer->seek(15)); + EXPECT_EQ(TERMINATED, scorer->seek(100)); +} + +TEST_F(DisjunctionScorerTest, SeekToCurrentDoc) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {5, 10, 15})); + scorers.push_back(std::make_shared(std::vector {5, 10, 15})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(5u, scorer->seek(5)); + EXPECT_EQ(5u, scorer->seek(3)); +} + +TEST_F(DisjunctionScorerTest, SizeHint) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 2}, + std::vector {}, 100)); + scorers.push_back(std::make_shared(std::vector {1, 2}, + std::vector {}, 50)); + scorers.push_back(std::make_shared(std::vector {1, 2}, + std::vector {}, 200)); + + auto combiner = std::make_shared(); + auto scorer = 
make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(200u, scorer->size_hint()); +} + +TEST_F(DisjunctionScorerTest, AdvanceAfterTerminated) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1})); + scorers.push_back(std::make_shared(std::vector {1})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(DisjunctionScorerTest, SeekAfterTerminated) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1})); + scorers.push_back(std::make_shared(std::vector {1})); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->seek(100)); +} + +TEST_F(DisjunctionScorerTest, NullScorersAreIgnored) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {1, 5, 10})); + scorers.push_back(nullptr); + scorers.push_back(std::make_shared(std::vector {1, 5, 10})); + scorers.push_back(nullptr); + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {1, 5, 10}; + EXPECT_EQ(expected, docs); +} + +TEST_F(DisjunctionScorerTest, SingleDocMultipleScorers) { + std::vector scorers; + scorers.push_back(std::make_shared(std::vector {42})); + scorers.push_back(std::make_shared(std::vector {42})); + scorers.push_back(std::make_shared(std::vector {42})); + scorers.push_back(std::make_shared(std::vector {42})); + + auto 
combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 3); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(42u, scorer->doc()); + EXPECT_FLOAT_EQ(4.0F, scorer->score()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(DisjunctionScorerTest, LargeDataSet) { + std::vector> data; + for (int i = 0; i < 100; ++i) { + std::vector docs; + for (uint32_t j = i; j < 10000; j += 100) { + docs.push_back(j); + } + data.push_back(docs); + } + + std::vector scorers; + for (const auto& docs : data) { + scorers.push_back(std::make_shared(docs)); + } + + auto combiner = std::make_shared(); + auto scorer = make_disjunction(std::move(scorers), combiner, 2); + ASSERT_NE(nullptr, scorer); + + auto expected = compute_expected(data, 2); + std::vector actual; + while (scorer->doc() != TERMINATED) { + actual.push_back(scorer->doc()); + scorer->advance(); + } + + EXPECT_EQ(expected, actual); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer_test.cpp new file mode 100644 index 00000000000000..ee941e7338caa7 --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer_test.cpp @@ -0,0 +1,569 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/exclude_scorer.h" + +#include + +#include +#include +#include +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +namespace { + +class VectorScorer final : public Scorer { +public: + VectorScorer(std::vector docs, std::vector scores = {}, + uint32_t size_hint_val = 0) + : _docs(std::move(docs)), _scores(std::move(scores)), _size_hint_val(size_hint_val) { + if (_docs.empty()) { + _current_doc = TERMINATED; + } else { + std::sort(_docs.begin(), _docs.end()); + _current_doc = _docs[0]; + } + if (_scores.size() != _docs.size()) { + _scores.resize(_docs.size(), 1.0F); + } + if (_size_hint_val == 0) { + _size_hint_val = static_cast(_docs.size()); + } + } + + uint32_t advance() override { + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + ++_index; + if (_index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + _current_doc = _docs[_index]; + return _current_doc; + } + + uint32_t seek(uint32_t target) override { + while (_current_doc < target && _current_doc != TERMINATED) { + advance(); + } + return _current_doc; + } + + uint32_t doc() const override { return _current_doc; } + uint32_t size_hint() const override { return _size_hint_val; } + + float score() override { + if (_index >= _scores.size()) { + return 0.0F; + } + return _scores[_index]; + } + +private: + std::vector _docs; + std::vector _scores; + size_t _index = 
0; + uint32_t _current_doc = TERMINATED; + uint32_t _size_hint_val = 0; +}; + +std::vector compute_exclude(const std::vector& underlying, + const std::vector& excluding) { + std::set exclude_set(excluding.begin(), excluding.end()); + std::vector result; + for (uint32_t doc : underlying) { + if (exclude_set.find(doc) == exclude_set.end()) { + result.push_back(doc); + } + } + return result; +} + +std::vector sample_with_seed(uint32_t max_doc, double probability, uint32_t seed) { + std::mt19937 gen(seed); + std::uniform_real_distribution<> dis(0.0, 1.0); + std::vector result; + for (uint32_t i = 0; i < max_doc; ++i) { + if (dis(gen) < probability) { + result.push_back(i); + } + } + return result; +} + +} // namespace + +class ExcludeScorerTest : public ::testing::Test {}; + +TEST_F(ExcludeScorerTest, BasicExclude) { + auto underlying = + std::make_shared(std::vector {1, 2, 5, 8, 10, 15, 24}); + auto excluding = std::make_shared(std::vector {1, 2, 3, 10, 16, 24}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {5, 8, 15}; + EXPECT_EQ(expected, docs); +} + +TEST_F(ExcludeScorerTest, NoExclusions) { + auto underlying = std::make_shared(std::vector {1, 2, 3, 4, 5}); + auto excluding = std::make_shared(std::vector {}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {1, 2, 3, 4, 5}; + EXPECT_EQ(expected, docs); +} + +TEST_F(ExcludeScorerTest, AllExcluded) { + auto underlying = std::make_shared(std::vector {1, 2, 3}); + auto excluding = std::make_shared(std::vector {1, 2, 3}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, 
scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(ExcludeScorerTest, EmptyUnderlying) { + auto underlying = std::make_shared(std::vector {}); + auto excluding = std::make_shared(std::vector {1, 2, 3}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(ExcludeScorerTest, BothEmpty) { + auto underlying = std::make_shared(std::vector {}); + auto excluding = std::make_shared(std::vector {}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(ExcludeScorerTest, NoIntersection) { + auto underlying = std::make_shared(std::vector {1, 3, 5, 7, 9}); + auto excluding = std::make_shared(std::vector {2, 4, 6, 8, 10}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {1, 3, 5, 7, 9}; + EXPECT_EQ(expected, docs); +} + +TEST_F(ExcludeScorerTest, SeekBasic) { + auto underlying = + std::make_shared(std::vector {1, 2, 5, 8, 10, 15, 24}); + auto excluding = std::make_shared(std::vector {1, 2, 3, 10, 16, 24}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(8u, scorer->seek(8)); + EXPECT_EQ(15u, scorer->seek(10)); + EXPECT_EQ(TERMINATED, scorer->seek(24)); +} + +TEST_F(ExcludeScorerTest, SeekToExcludedDoc) { + auto underlying = std::make_shared(std::vector {5, 10, 15, 20, 25}); + auto excluding = std::make_shared(std::vector {10, 20}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(15u, scorer->seek(10)); + EXPECT_EQ(25u, scorer->seek(20)); +} + 
+TEST_F(ExcludeScorerTest, SeekToCurrentDoc) { + auto underlying = std::make_shared(std::vector {5, 10, 15}); + auto excluding = std::make_shared(std::vector {}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(5u, scorer->seek(5)); + EXPECT_EQ(5u, scorer->seek(3)); +} + +TEST_F(ExcludeScorerTest, SeekPastEnd) { + auto underlying = std::make_shared(std::vector {1, 5, 10}); + auto excluding = std::make_shared(std::vector {}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->seek(100)); +} + +TEST_F(ExcludeScorerTest, AdvanceAfterTerminated) { + auto underlying = std::make_shared(std::vector {1}); + auto excluding = std::make_shared(std::vector {}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(ExcludeScorerTest, SeekAfterTerminated) { + auto underlying = std::make_shared(std::vector {1}); + auto excluding = std::make_shared(std::vector {}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->seek(50)); +} + +TEST_F(ExcludeScorerTest, ScoreDelegation) { + auto underlying = std::make_shared(std::vector {1, 5, 10}, + std::vector {1.5F, 2.5F, 3.5F}); + auto excluding = std::make_shared(std::vector {5}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_FLOAT_EQ(1.5F, scorer->score()); + EXPECT_EQ(10u, scorer->advance()); + EXPECT_FLOAT_EQ(3.5F, scorer->score()); +} + 
+TEST_F(ExcludeScorerTest, SizeHintDelegation) { + auto underlying = std::make_shared(std::vector {1, 2, 3}, + std::vector {}, 100); + auto excluding = std::make_shared(std::vector {2}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(100u, scorer->size_hint()); +} + +TEST_F(ExcludeScorerTest, SingleDocNotExcluded) { + auto underlying = std::make_shared(std::vector {42}); + auto excluding = std::make_shared(std::vector {1, 2, 3}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(42u, scorer->doc()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(ExcludeScorerTest, SingleDocExcluded) { + auto underlying = std::make_shared(std::vector {42}); + auto excluding = std::make_shared(std::vector {42}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(ExcludeScorerTest, FirstDocExcluded) { + auto underlying = std::make_shared(std::vector {1, 5, 10}); + auto excluding = std::make_shared(std::vector {1}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(5u, scorer->doc()); +} + +TEST_F(ExcludeScorerTest, LastDocExcluded) { + auto underlying = std::make_shared(std::vector {1, 5, 10}); + auto excluding = std::make_shared(std::vector {10}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {1, 5}; + EXPECT_EQ(expected, docs); +} + +TEST_F(ExcludeScorerTest, ConsecutiveDocsExcluded) { + auto underlying = + std::make_shared(std::vector {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto excluding = std::make_shared(std::vector {3, 4, 5, 6, 7}); + + auto scorer 
= make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {1, 2, 8, 9, 10}; + EXPECT_EQ(expected, docs); +} + +TEST_F(ExcludeScorerTest, ExcludingHasExtraDocs) { + auto underlying = std::make_shared(std::vector {5, 10, 15}); + auto excluding = std::make_shared(std::vector {1, 2, 5, 6, 7, 10, 100}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {15}; + EXPECT_EQ(expected, docs); +} + +TEST_F(ExcludeScorerTest, LargeDocIds) { + auto underlying = std::make_shared( + std::vector {100000000, 200000000, 300000000, 400000000}); + auto excluding = std::make_shared(std::vector {200000000, 400000000}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {100000000, 300000000}; + EXPECT_EQ(expected, docs); +} + +TEST_F(ExcludeScorerTest, RandomData) { + auto sample_underlying = sample_with_seed(10000, 0.1, 1); + auto sample_excluding = sample_with_seed(10000, 0.05, 2); + + auto underlying = std::make_shared(sample_underlying); + auto excluding = std::make_shared(sample_excluding); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + auto expected = compute_exclude(sample_underlying, sample_excluding); + std::vector actual; + while (scorer->doc() != TERMINATED) { + actual.push_back(scorer->doc()); + scorer->advance(); + } + + EXPECT_EQ(expected, actual); +} + +TEST_F(ExcludeScorerTest, RandomDataSeek) { + auto sample_underlying = 
sample_with_seed(10000, 0.1, 1); + auto sample_excluding = sample_with_seed(10000, 0.05, 2); + auto sample_seek_targets = sample_with_seed(10000, 0.005, 3); + + auto expected = compute_exclude(sample_underlying, sample_excluding); + + auto underlying = std::make_shared(sample_underlying); + auto excluding = std::make_shared(sample_excluding); + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + for (uint32_t target : sample_seek_targets) { + auto it = std::lower_bound(expected.begin(), expected.end(), target); + uint32_t expected_doc = (it == expected.end()) ? TERMINATED : *it; + + underlying = std::make_shared(sample_underlying); + excluding = std::make_shared(sample_excluding); + scorer = make_exclude(std::move(underlying), std::move(excluding)); + + uint32_t actual_doc = scorer->seek(target); + EXPECT_EQ(expected_doc, actual_doc) << "Failed for target: " << target; + } +} + +TEST_F(ExcludeScorerTest, InterleavedAdvanceAndSeek) { + auto underlying = std::make_shared( + std::vector {1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50}); + auto excluding = std::make_shared(std::vector {5, 20, 35, 50}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_EQ(10u, scorer->advance()); + EXPECT_EQ(15u, scorer->seek(12)); + EXPECT_EQ(25u, scorer->advance()); + EXPECT_EQ(40u, scorer->seek(35)); + EXPECT_EQ(45u, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(ExcludeScorerTest, DocDoesNotAdvance) { + auto underlying = std::make_shared(std::vector {5, 10, 15}); + auto excluding = std::make_shared(std::vector {}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(5u, scorer->doc()); + + scorer->advance(); + EXPECT_EQ(10u, scorer->doc()); + EXPECT_EQ(10u, scorer->doc()); +} 
+ +TEST_F(ExcludeScorerTest, ExcludingLargerThanUnderlying) { + auto underlying = std::make_shared(std::vector {5, 10}); + auto excluding = std::make_shared( + std::vector {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(ExcludeScorerTest, LargeDataSet) { + std::vector underlying_docs; + std::vector excluding_docs; + + for (uint32_t i = 0; i < 100000; i += 10) { + underlying_docs.push_back(i); + } + + for (uint32_t i = 0; i < 100000; i += 30) { + excluding_docs.push_back(i); + } + + auto underlying = std::make_shared(underlying_docs); + auto excluding = std::make_shared(excluding_docs); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + auto expected = compute_exclude(underlying_docs, excluding_docs); + std::vector actual; + while (scorer->doc() != TERMINATED) { + actual.push_back(scorer->doc()); + scorer->advance(); + } + + EXPECT_EQ(expected, actual); +} + +TEST_F(ExcludeScorerTest, AlternatingExcluded) { + auto underlying = + std::make_shared(std::vector {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto excluding = std::make_shared(std::vector {2, 4, 6, 8, 10}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + std::vector expected {1, 3, 5, 7, 9}; + EXPECT_EQ(expected, docs); +} + +TEST_F(ExcludeScorerTest, IsWithinBehavior) { + auto underlying = std::make_shared(std::vector {100, 200, 300}); + auto excluding = std::make_shared(std::vector {50, 100, 150, 200, 250}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + 
scorer->advance(); + } + + std::vector expected {300}; + EXPECT_EQ(expected, docs); +} + +/* seek(50) targets a doc that is itself excluded, so the next surviving doc (60) is expected. The final advance() past 90 must skip the excluded 100 and report TERMINATED. */ +TEST_F(ExcludeScorerTest, SeekExcludingCatchUp) { + auto underlying = std::make_shared( + std::vector {10, 20, 30, 40, 50, 60, 70, 80, 90, 100}); + auto excluding = std::make_shared(std::vector {5, 15, 25, 50, 75, 100}); + + auto scorer = make_exclude(std::move(underlying), std::move(excluding)); + ASSERT_NE(nullptr, scorer); + + EXPECT_EQ(10u, scorer->doc()); + EXPECT_EQ(60u, scorer->seek(50)); + EXPECT_EQ(70u, scorer->advance()); + EXPECT_EQ(80u, scorer->advance()); + EXPECT_EQ(90u, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/intersection_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/intersection_test.cpp index ebe61bc4590c2c..71c02bc583ce3f 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/intersection_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/intersection_test.cpp @@ -43,7 +43,7 @@ class IntersectionTest : public ::testing::Test { TEST_F(IntersectionTest, test_create_with_empty_docsets) { std::vector docsets; - EXPECT_THROW((Intersection::create(docsets)), Exception); + EXPECT_THROW((Intersection::create(docsets, 10000)), Exception); } // Test creating intersection with only 1 docset (should throw exception) @@ -51,7 +51,7 @@ TEST_F(IntersectionTest, test_create_with_single_docset) { std::vector docsets; docsets.push_back(std::make_shared(std::vector {1, 2, 3})); - EXPECT_THROW((Intersection::create(docsets)), Exception); + EXPECT_THROW((Intersection::create(docsets, 10000)), Exception); } // Test creating intersection with exactly 2 docsets @@ -60,11 +60,11 @@ TEST_F(IntersectionTest, test_create_with_two_docsets) { docsets.push_back(std::make_shared(std::vector {1, 2, 3, 4, 5})); docsets.push_back(std::make_shared(std::vector {2, 3, 4, 6, 7})); - auto 
intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); // Should start at first matching document - EXPECT_EQ(2u, intersection->doc()); + EXPECT_EQ(2U, intersection->doc()); } // Test intersection advance with two docsets @@ -73,7 +73,7 @@ TEST_F(IntersectionTest, test_advance_two_docsets) { docsets.push_back(std::make_shared(std::vector {1, 3, 5, 7, 9})); docsets.push_back(std::make_shared(std::vector {3, 5, 9, 11})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); std::vector results; @@ -94,7 +94,7 @@ TEST_F(IntersectionTest, test_three_docsets) { docsets.push_back(std::make_shared(std::vector {2, 4, 6, 8})); docsets.push_back(std::make_shared(std::vector {2, 3, 4, 6, 7})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); std::vector results; @@ -116,7 +116,7 @@ TEST_F(IntersectionTest, test_four_docsets) { docsets.push_back(std::make_shared(std::vector {2, 3, 5, 10, 11})); docsets.push_back(std::make_shared(std::vector {5, 10, 15})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); std::vector results; @@ -136,7 +136,7 @@ TEST_F(IntersectionTest, test_no_intersection) { docsets.push_back(std::make_shared(std::vector {1, 3, 5})); docsets.push_back(std::make_shared(std::vector {2, 4, 6})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); // Should be terminated immediately @@ -151,10 +151,10 @@ TEST_F(IntersectionTest, test_single_common_document) { docsets.push_back(std::make_shared(std::vector {2, 5, 8})); docsets.push_back(std::make_shared(std::vector {5, 7, 9})); - auto intersection = 
Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); - EXPECT_EQ(5u, intersection->doc()); + EXPECT_EQ(5U, intersection->doc()); EXPECT_EQ(TERMINATED, intersection->advance()); } @@ -164,16 +164,16 @@ TEST_F(IntersectionTest, test_seek) { docsets.push_back(std::make_shared(std::vector {1, 5, 10, 15, 20})); docsets.push_back(std::make_shared(std::vector {5, 10, 15, 20, 25})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); // Seek to doc 10 - EXPECT_EQ(10u, intersection->seek(8)); - EXPECT_EQ(10u, intersection->doc()); + EXPECT_EQ(10U, intersection->seek(8)); + EXPECT_EQ(10U, intersection->doc()); // Seek to doc 20 - EXPECT_EQ(20u, intersection->seek(18)); - EXPECT_EQ(20u, intersection->doc()); + EXPECT_EQ(20U, intersection->seek(18)); + EXPECT_EQ(20U, intersection->doc()); // Seek beyond all docs EXPECT_EQ(TERMINATED, intersection->seek(30)); @@ -185,17 +185,17 @@ TEST_F(IntersectionTest, test_seek_current_position) { docsets.push_back(std::make_shared(std::vector {5, 10, 15})); docsets.push_back(std::make_shared(std::vector {5, 10, 15, 20})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); - EXPECT_EQ(5u, intersection->doc()); + EXPECT_EQ(5U, intersection->doc()); // Seek to current position or before should stay at current - EXPECT_EQ(5u, intersection->seek(5)); - EXPECT_EQ(5u, intersection->doc()); + EXPECT_EQ(5U, intersection->seek(5)); + EXPECT_EQ(5U, intersection->doc()); - EXPECT_EQ(5u, intersection->seek(3)); - EXPECT_EQ(5u, intersection->doc()); + EXPECT_EQ(5U, intersection->seek(3)); + EXPECT_EQ(5U, intersection->doc()); } // Test size_hint - should return smallest docset's size hint @@ -205,11 +205,11 @@ TEST_F(IntersectionTest, test_size_hint) { 
docsets.push_back(std::make_shared(std::vector {2, 3, 4}, 50)); docsets.push_back(std::make_shared(std::vector {2, 3, 4, 5, 6}, 75)); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); - // Should return the smallest size hint (from the smallest docset after sorting) - EXPECT_EQ(50u, intersection->size_hint()); + // Should return the estimated size hint (using estimate_intersection algorithm) + EXPECT_EQ(0U, intersection->size_hint()); } // Test norm - should return left docset's norm @@ -218,7 +218,7 @@ TEST_F(IntersectionTest, test_norm) { docsets.push_back(std::make_shared(std::vector {1, 2, 3}, 0, 10)); docsets.push_back(std::make_shared(std::vector {2, 3, 4}, 0, 20)); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); // After creation, docsets are sorted by size_hint, and smallest becomes left @@ -234,13 +234,13 @@ TEST_F(IntersectionTest, test_docset_mut_specialized) { docsets.push_back(std::make_shared(std::vector {2, 3, 4})); docsets.push_back(std::make_shared(std::vector {2, 3, 5})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); // Access left (ord 0) auto& docset0 = intersection->docset_mut_specialized(0); EXPECT_NE(nullptr, docset0); - EXPECT_EQ(2u, docset0->doc()); + EXPECT_EQ(2U, docset0->doc()); // Access right (ord 1) auto& docset1 = intersection->docset_mut_specialized(1); @@ -260,7 +260,7 @@ TEST_F(IntersectionTest, test_all_identical_docsets) { docsets.push_back(std::make_shared(common_docs)); docsets.push_back(std::make_shared(common_docs)); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); std::vector results; @@ -279,7 +279,7 @@ TEST_F(IntersectionTest, 
test_one_empty_docset) { docsets.push_back(std::make_shared(std::vector {1, 2, 3})); docsets.push_back(std::make_shared(std::vector {})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); // Should be terminated immediately @@ -292,10 +292,10 @@ TEST_F(IntersectionTest, test_advance_after_termination) { docsets.push_back(std::make_shared(std::vector {1})); docsets.push_back(std::make_shared(std::vector {1})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); - EXPECT_EQ(1u, intersection->doc()); + EXPECT_EQ(1U, intersection->doc()); EXPECT_EQ(TERMINATED, intersection->advance()); EXPECT_EQ(TERMINATED, intersection->advance()); EXPECT_EQ(TERMINATED, intersection->advance()); @@ -309,7 +309,7 @@ TEST_F(IntersectionTest, test_large_document_ids) { docsets.push_back( std::make_shared(std::vector {500, 10000, 50000, 1000000})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); std::vector results; @@ -330,7 +330,7 @@ TEST_F(IntersectionTest, test_sparse_docsets) { docsets.push_back( std::make_shared(std::vector {50, 100, 150, 200, 250, 300})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); std::vector results; @@ -352,20 +352,20 @@ TEST_F(IntersectionTest, test_seek_after_advance) { docsets.push_back( std::make_shared(std::vector {5, 10, 15, 20, 25, 30, 35})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); // Start at doc 5 - EXPECT_EQ(5u, intersection->doc()); + EXPECT_EQ(5U, intersection->doc()); // Advance to doc 10 - EXPECT_EQ(10u, intersection->advance()); + EXPECT_EQ(10U, 
intersection->advance()); // Seek to doc 25 - EXPECT_EQ(25u, intersection->seek(22)); + EXPECT_EQ(25U, intersection->seek(22)); // Continue advancing - EXPECT_EQ(30u, intersection->advance()); + EXPECT_EQ(30U, intersection->advance()); EXPECT_EQ(TERMINATED, intersection->advance()); } @@ -375,22 +375,22 @@ TEST_F(IntersectionTest, test_multiple_seeks) { docsets.push_back(std::make_shared(std::vector {10, 20, 30, 40, 50})); docsets.push_back(std::make_shared(std::vector {10, 20, 30, 40, 50, 60})); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); - EXPECT_EQ(10u, intersection->doc()); + EXPECT_EQ(10U, intersection->doc()); // Seek to 25 - EXPECT_EQ(30u, intersection->seek(25)); + EXPECT_EQ(30U, intersection->seek(25)); // Seek to 35 - EXPECT_EQ(40u, intersection->seek(35)); + EXPECT_EQ(40U, intersection->seek(35)); // Seek to same position - EXPECT_EQ(40u, intersection->seek(40)); + EXPECT_EQ(40U, intersection->seek(40)); // Seek backwards (should stay at current) - EXPECT_EQ(40u, intersection->seek(35)); + EXPECT_EQ(40U, intersection->seek(35)); } // Test docsets with different sizes are sorted correctly @@ -405,11 +405,11 @@ TEST_F(IntersectionTest, test_docsets_sorted_by_size) { // Medium docsets.push_back(std::make_shared(std::vector {2, 3, 4, 5, 6}, 500)); - auto intersection = Intersection::create(docsets); + auto intersection = Intersection::create(docsets, 10000); ASSERT_NE(nullptr, intersection); - // size_hint should be from smallest docset - EXPECT_EQ(100u, intersection->size_hint()); + // size_hint should be from estimated intersection + EXPECT_EQ(1U, intersection->size_hint()); std::vector results; uint32_t doc = intersection->doc(); diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/occur_boolean_query_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/occur_boolean_query_test.cpp new file mode 100644 index 
00000000000000..bcffa0d7082812 --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/occur_boolean_query_test.cpp @@ -0,0 +1,707 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_query.h" + +#include + +#include +#include +#include +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/boolean_query/occur_boolean_weight.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/query.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/weight.h" + +namespace doris::segment_v2::inverted_index::query_v2 { +namespace { + +/* Scorer test double: iterates a sorted copy of the given doc ids and reports the same score value for every doc. */ +class MockScorer : public Scorer { +public: + /* Sorts the ids and positions the scorer on the smallest one; an empty list starts at TERMINATED. */ + MockScorer(std::vector docs, float score_val = 1.0F) + : _docs(std::move(docs)), _score_val(score_val) { + if (_docs.empty()) { + _current_doc = TERMINATED; + } else { + std::ranges::sort(_docs); + _current_doc = _docs[0]; + } + } + + /* Steps to the next doc id; returns TERMINATED once the list is exhausted and keeps returning it afterwards. */ + uint32_t advance() override { + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + 
++_index; + if (_index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + _current_doc = _docs[_index]; + return _current_doc; + } + + /* Moves to the first doc id >= target using a binary search over the ids from the current index onward. A target at or before the current doc leaves the position unchanged. */ + uint32_t seek(uint32_t target) override { + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + if (_current_doc >= target) { + return _current_doc; + } + auto it = std::lower_bound(_docs.begin() + static_cast(_index), _docs.end(), + target); + if (it == _docs.end()) { + _index = _docs.size(); + _current_doc = TERMINATED; + return TERMINATED; + } + _index = static_cast(it - _docs.begin()); + _current_doc = *it; + return _current_doc; + } + + uint32_t doc() const override { return _current_doc; } + /* The size hint is simply the number of stored doc ids. */ + uint32_t size_hint() const override { return static_cast(_docs.size()); } + float score() override { return _score_val; } + +private: + std::vector _docs; + size_t _index = 0; + uint32_t _current_doc = TERMINATED; + float _score_val = 1.0F; +}; + +/* Weight test double: scorer() builds a scorer from the stored doc ids and score value and ignores the execution context. */ +class MockWeight : public Weight { +public: + explicit MockWeight(std::vector docs, float score_val = 1.0F) + : _docs(std::move(docs)), _score_val(score_val) {} + + ScorerPtr scorer(const QueryExecutionContext& /*context*/) override { + return std::make_shared(_docs, _score_val); + } + +private: + std::vector _docs; + float _score_val; +}; + +/* Query test double: weight() builds a weight from the stored doc ids and score value and ignores the enable_scoring flag. */ +class MockQuery : public Query { +public: + explicit MockQuery(std::vector docs, float score_val = 1.0F) + : _docs(std::move(docs)), _score_val(score_val) {} + + WeightPtr weight(bool /*enable_scoring*/) override { + return std::make_shared(_docs, _score_val); + } + +private: + std::vector _docs; + float _score_val; +}; + +} // anonymous namespace + +/* Fixture holding a default-constructed QueryExecutionContext plus helpers for draining scorers and computing reference set results. */ +class OccurBooleanQueryTest : public testing::Test { +protected: + QueryExecutionContext _ctx; + +public: + /* Drains the scorer from its current doc with advance() and returns every doc id seen before TERMINATED. */ + std::vector collect_docs(ScorerPtr scorer) { + std::vector result; + uint32_t doc = scorer->doc(); + while (doc != TERMINATED) { + result.push_back(doc); + doc = scorer->advance(); + } + return result; + } + + std::set to_set(const 
std::vector& v) { + return std::set(v.begin(), v.end()); + } + + /* Reference union of a and b, returned sorted and without duplicates. */ + std::vector set_union(const std::vector& a, + const std::vector& b) { + std::set result; + result.insert(a.begin(), a.end()); + result.insert(b.begin(), b.end()); + return std::vector(result.begin(), result.end()); + } + + /* Reference intersection of a and b, returned sorted and without duplicates. */ + std::vector set_intersection(const std::vector& a, + const std::vector& b) { + std::set sa(a.begin(), a.end()); + std::set sb(b.begin(), b.end()); + std::vector result; + std::set_intersection(sa.begin(), sa.end(), sb.begin(), sb.end(), + std::back_inserter(result)); + return result; + } + + /* Reference difference a minus b, returned sorted and without duplicates. */ + std::vector set_difference(const std::vector& a, + const std::vector& b) { + std::set sa(a.begin(), a.end()); + std::set sb(b.begin(), b.end()); + std::vector result; + std::set_difference(sa.begin(), sa.end(), sb.begin(), sb.end(), std::back_inserter(result)); + return result; + } + + /* Returns `count` distinct doc ids drawn from [0, max_doc) with a generator seeded by `seed`, in ascending order. A request for more ids than the range holds can never be satisfied and would keep the fill loop spinning forever, so it is reported as a test failure and clamped to max_doc. */ + std::vector generate_random_docs(size_t count, uint32_t max_doc, uint32_t seed) { + if (count > max_doc) { + ADD_FAILURE() << "generate_random_docs: count " << count << " exceeds max_doc " << max_doc; + count = max_doc; + } + if (count == 0) { + /* Nothing to draw; this also avoids evaluating max_doc - 1 when max_doc is 0. */ + return {}; + } + std::mt19937 gen(seed); + std::uniform_int_distribution dis(0, max_doc - 1); + std::set doc_set; + while (doc_set.size() < count) { + doc_set.insert(dis(gen)); + } + return std::vector(doc_set.begin(), doc_set.end()); + } + + /* Returns start, start + step, ... for every value below `end`. A zero step would never make progress, so it is reported as a test failure and yields an empty list. */ + std::vector generate_range_docs(uint32_t start, uint32_t end, uint32_t step = 1) { + std::vector result; + if (step == 0) { + ADD_FAILURE() << "generate_range_docs: step must be non-zero"; + return result; + } + for (uint32_t i = start; i < end; i += step) { + result.push_back(i); + /* Stop before i += step could reach or wrap past `end`; near UINT32_MAX the wrapped value would restart the loop. */ + if (end - i <= step) { + break; + } + } + return result; + } +}; + +/* A query without clauses matches nothing. */ +TEST_F(OccurBooleanQueryTest, EmptyQuery) { + std::vector> clauses; + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + + EXPECT_EQ(scorer->doc(), TERMINATED); +} + +/* A lone MUST clause returns exactly its own docs, in order. */ +TEST_F(OccurBooleanQueryTest, SingleMustClause) { + auto docs = generate_range_docs(0, 100, 2); + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result, docs); +}
+ +/* A lone SHOULD clause acts as the whole query: every doc of the clause is returned, in order. */ +TEST_F(OccurBooleanQueryTest, SingleShouldClause) { + auto docs = generate_range_docs(0, 100, 3); + std::vector> clauses; + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result, docs); +} + +/* A query made of a single MUST_NOT clause has nothing to select from and yields no docs. */ +TEST_F(OccurBooleanQueryTest, SingleMustNotClauseReturnsEmpty) { + auto docs = generate_range_docs(0, 100); + std::vector> clauses; + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + + EXPECT_EQ(scorer->doc(), TERMINATED); +} + +/* Two MUST clauses (multiples of 2 and of 3 below 1000) must yield exactly their intersection, in ascending order. */ +TEST_F(OccurBooleanQueryTest, TwoMustClausesIntersection) { + auto docs1 = generate_range_docs(0, 1000, 2); + auto docs2 = generate_range_docs(0, 1000, 3); + auto expected = set_intersection(docs1, docs2); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(docs1)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(result, expected); +} + +/* Two overlapping SHOULD clauses must yield the union of their docs; the size check also rules out duplicate emissions. */ +TEST_F(OccurBooleanQueryTest, TwoShouldClausesUnion) { + auto docs1 = generate_range_docs(0, 500, 2); + auto docs2 = generate_range_docs(250, 750, 2); + auto expected = set_union(docs1, docs2); + + std::vector> clauses; + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs1)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* MUST_NOT removes its docs (every third id) from the docs matched by the MUST clause. */ +TEST_F(OccurBooleanQueryTest, 
MustWithMustNotExclusion) { + auto must_docs = generate_range_docs(0, 1000); + auto must_not_docs = generate_range_docs(0, 1000, 3); + auto expected = set_difference(must_docs, must_not_docs); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(must_docs)); + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(must_not_docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* MUST_NOT also filters a SHOULD-only query: the even ids below 500 are excluded from the SHOULD matches. */ +TEST_F(OccurBooleanQueryTest, ShouldWithMustNotExclusion) { + auto should_docs = generate_range_docs(0, 1000, 2); + auto must_not_docs = generate_range_docs(0, 500); + auto expected = set_difference(should_docs, must_not_docs); + + std::vector> clauses; + clauses.emplace_back(Occur::SHOULD, std::make_shared(should_docs)); + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(must_not_docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* When a MUST clause is present, the SHOULD clause neither adds nor removes docs: the result equals the MUST docs. */ +TEST_F(OccurBooleanQueryTest, MustAndShouldCombined) { + auto must_docs = generate_range_docs(0, 500); + auto should_docs = generate_range_docs(250, 750); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(must_docs)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(should_docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result, must_docs); +} + +/* Three MUST clauses (multiples of 2, 3 and 5 below 10000) must yield the docs common to all of them. */ +TEST_F(OccurBooleanQueryTest, MultipleMustClausesIntersection) { + auto docs1 = generate_range_docs(0, 10000, 2); + auto docs2 = generate_range_docs(0, 10000, 3); + auto docs3 = 
generate_range_docs(0, 10000, 5); + auto expected = set_intersection(set_intersection(docs1, docs2), docs3); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(docs1)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs2)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs3)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* Three partly overlapping SHOULD clauses must yield the union of all their docs. */ +TEST_F(OccurBooleanQueryTest, MultipleShouldClausesUnion) { + auto docs1 = generate_range_docs(0, 3000, 7); + auto docs2 = generate_range_docs(1000, 4000, 11); + auto docs3 = generate_range_docs(2000, 5000, 13); + auto expected = set_union(set_union(docs1, docs2), docs3); + + std::vector> clauses; + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs1)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs2)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs3)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* Combines two MUST clauses, one SHOULD clause and one MUST_NOT clause. The expected set is the MUST intersection minus the MUST_NOT docs; the SHOULD clause is not part of the expected membership. */ +TEST_F(OccurBooleanQueryTest, ComplexMustShouldMustNot) { + auto must_docs1 = generate_range_docs(0, 2000); + auto must_docs2 = generate_range_docs(500, 2500); + auto should_docs = generate_range_docs(0, 3000, 3); + auto must_not_docs = generate_range_docs(1000, 1500); + + auto must_intersection = set_intersection(must_docs1, must_docs2); + auto expected = set_difference(must_intersection, must_not_docs); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(must_docs1)); + clauses.emplace_back(Occur::MUST, std::make_shared(must_docs2)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(should_docs)); + 
clauses.emplace_back(Occur::MUST_NOT, std::make_shared(must_not_docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* Intersection of two seeded random doc sets of 5000 ids each, drawn from [0, 100000). */ +TEST_F(OccurBooleanQueryTest, LargeScaleIntersection) { + auto docs1 = generate_random_docs(5000, 100000, 42); + auto docs2 = generate_random_docs(5000, 100000, 123); + auto expected = set_intersection(docs1, docs2); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(docs1)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* Union of three seeded random doc sets of 3000 ids each, drawn from [0, 50000). */ +TEST_F(OccurBooleanQueryTest, LargeScaleUnion) { + auto docs1 = generate_random_docs(3000, 50000, 1); + auto docs2 = generate_random_docs(3000, 50000, 2); + auto docs3 = generate_random_docs(3000, 50000, 3); + auto expected = set_union(set_union(docs1, docs2), docs3); + + std::vector> clauses; + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs1)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs2)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs3)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* Excludes 10000 seeded random ids from the dense MUST range [0, 50000). */ +TEST_F(OccurBooleanQueryTest, LargeScaleExclusion) { + auto must_docs = generate_range_docs(0, 50000); + auto must_not_docs = generate_random_docs(10000, 50000, 999); + auto expected = set_difference(must_docs, must_not_docs); + + std::vector> clauses; + 
clauses.emplace_back(Occur::MUST, std::make_shared(must_docs)); + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(must_not_docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* MUST clauses with no doc in common produce an empty result. */ +TEST_F(OccurBooleanQueryTest, DisjointMustClausesEmpty) { + auto docs1 = generate_range_docs(0, 100); + auto docs2 = generate_range_docs(200, 300); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(docs1)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_TRUE(result.empty()); +} + +/* A MUST_NOT clause that covers every MUST doc leaves nothing to return. */ +TEST_F(OccurBooleanQueryTest, MustNotExcludesAllMust) { + auto must_docs = generate_range_docs(0, 100); + auto must_not_docs = generate_range_docs(0, 200); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(must_docs)); + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(must_not_docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_TRUE(result.empty()); +} + +/* An empty MUST clause makes the whole conjunction empty, whatever the other MUST clause matches. */ +TEST_F(OccurBooleanQueryTest, EmptyMustClause) { + std::vector empty_docs; + auto docs2 = generate_range_docs(0, 100); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(empty_docs)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_TRUE(result.empty()); +} + +/* An empty SHOULD clause contributes nothing; the other SHOULD clause's docs are returned unchanged. */ +TEST_F(OccurBooleanQueryTest, EmptyShouldClause) { + std::vector empty_docs; + auto docs2 = 
generate_range_docs(0, 100); + + std::vector> clauses; + clauses.emplace_back(Occur::SHOULD, std::make_shared(empty_docs)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result, docs2); +} + +/* With scoring enabled, a doc matched by both SHOULD clauses is expected to score 1.0 + 2.0 = 3.0, while a doc matched by only one clause scores that clause's value. Both kinds of doc must be observed. */ +TEST_F(OccurBooleanQueryTest, ScoringEnabled) { + auto docs1 = generate_range_docs(0, 100, 2); + auto docs2 = generate_range_docs(0, 100, 3); + auto overlap = set_intersection(docs1, docs2); + + std::vector> clauses; + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs1, 1.0F)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs2, 2.0F)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(true); + auto scorer = weight->scorer(_ctx); + + std::set overlap_set(overlap.begin(), overlap.end()); + bool found_overlap_with_higher_score = false; + bool found_single_match = false; + + uint32_t doc = scorer->seek(0); + while (doc != TERMINATED) { + float s = scorer->score(); + if (overlap_set.count(doc) > 0) { + EXPECT_FLOAT_EQ(s, 3.0F); + found_overlap_with_higher_score = true; + } else { + EXPECT_TRUE(s == 1.0F || s == 2.0F); + found_single_match = true; + } + doc = scorer->advance(); + } + + EXPECT_TRUE(found_overlap_with_higher_score); + EXPECT_TRUE(found_single_match); +} + +/* Ten MUST clauses holding the multiples of 2 through 11 below 10000 must yield the docs common to all of them. */ +TEST_F(OccurBooleanQueryTest, ManyMustClausesStress) { + std::vector> doc_sets; + for (int i = 0; i < 10; ++i) { + doc_sets.push_back(generate_range_docs(0, 10000, i + 2)); + } + + auto expected = doc_sets[0]; + for (size_t i = 1; i < doc_sets.size(); ++i) { + expected = set_intersection(expected, doc_sets[i]); + } + + std::vector> clauses; + for (const auto& docs : doc_sets) { + clauses.emplace_back(Occur::MUST, std::make_shared(docs)); + } + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = 
collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* Twenty SHOULD clauses of 500 seeded random ids each must yield the union of all of them, with no doc emitted twice. */ +TEST_F(OccurBooleanQueryTest, ManyShouldClausesStress) { + std::vector> doc_sets; + for (int i = 0; i < 20; ++i) { + doc_sets.push_back(generate_random_docs(500, 20000, i * 100)); + } + + std::set expected_set; + for (const auto& docs : doc_sets) { + expected_set.insert(docs.begin(), docs.end()); + } + + std::vector> clauses; + for (const auto& docs : doc_sets) { + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs)); + } + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected_set.size()); + EXPECT_EQ(to_set(result), expected_set); +} + +/* Three separate MUST_NOT ranges must all be removed from the MUST docs. */ +TEST_F(OccurBooleanQueryTest, MultipleMustNotClauses) { + auto must_docs = generate_range_docs(0, 5000); + auto must_not_docs1 = generate_range_docs(0, 1000); + auto must_not_docs2 = generate_range_docs(2000, 3000); + auto must_not_docs3 = generate_range_docs(4000, 5000); + + auto expected = set_difference(must_docs, must_not_docs1); + expected = set_difference(expected, must_not_docs2); + expected = set_difference(expected, must_not_docs3); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(must_docs)); + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(must_not_docs1)); + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(must_not_docs2)); + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(must_not_docs3)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result.size(), expected.size()); + EXPECT_EQ(to_set(result), to_set(expected)); +} + +/* Exercises seek() on a two-clause MUST query whose matches are the multiples of 6: each seek must land on the next common doc, seeking must still work after a run of advance() calls, and a target past the last match must return TERMINATED. NOTE(review): `expected` is computed in this test but never read. */ +TEST_F(OccurBooleanQueryTest, SeekOperations) { + auto docs1 = generate_range_docs(0, 10000, 2); + auto docs2 = generate_range_docs(0, 10000, 3); + auto 
expected = set_intersection(docs1, docs2); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(docs1)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + + EXPECT_EQ(scorer->seek(100), 102); + EXPECT_EQ(scorer->seek(500), 504); + EXPECT_EQ(scorer->seek(1000), 1002); + + uint32_t current = scorer->doc(); + while (current != TERMINATED && current < 5000) { + current = scorer->advance(); + } + + EXPECT_EQ(scorer->seek(6000), 6000); + EXPECT_EQ(scorer->seek(9999), TERMINATED); +} + +/* Three MUST clauses over the same doc list must return that list unchanged. */ +TEST_F(OccurBooleanQueryTest, IdenticalDocSets) { + auto docs = generate_range_docs(0, 1000); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(docs)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs)); + clauses.emplace_back(Occur::MUST, std::make_shared(docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result, docs); +} + +/* Three SHOULD clauses over the same doc list must return each doc exactly once. */ +TEST_F(OccurBooleanQueryTest, OverlappingShouldClauses) { + auto docs = generate_range_docs(0, 100); + + std::vector> clauses; + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs)); + clauses.emplace_back(Occur::SHOULD, std::make_shared(docs)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result, docs); +} + +/* Intersection of two MUST clauses whose doc ids lie thousands apart. */ +TEST_F(OccurBooleanQueryTest, SparseDocIds) { + std::vector docs1 = {0, 10000, 20000, 30000, 40000}; + std::vector docs2 = {0, 5000, 10000, 15000, 20000, 25000, 30000}; + auto expected = set_intersection(docs1, docs2); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST, std::make_shared(docs1)); + 
clauses.emplace_back(Occur::MUST, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + auto result = collect_docs(scorer); + + EXPECT_EQ(result, expected); +} + +TEST_F(OccurBooleanQueryTest, OnlyMustNotClausesEmpty) { + auto docs1 = generate_range_docs(0, 100); + auto docs2 = generate_range_docs(50, 150); + + std::vector> clauses; + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(docs1)); + clauses.emplace_back(Occur::MUST_NOT, std::make_shared(docs2)); + + OccurBooleanQuery query(std::move(clauses)); + auto weight = query.weight(false); + auto scorer = weight->scorer(_ctx); + + EXPECT_EQ(scorer->doc(), TERMINATED); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer_test.cpp new file mode 100644 index 00000000000000..f4751828e6cf8e --- /dev/null +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer_test.cpp @@ -0,0 +1,540 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "olap/rowset/segment_v2/inverted_index/query_v2/reqopt_scorer.h" + +#include + +#include +#include +#include +#include + +#include "olap/rowset/segment_v2/inverted_index/query_v2/score_combiner.h" +#include "olap/rowset/segment_v2/inverted_index/query_v2/scorer.h" + +namespace doris::segment_v2::inverted_index::query_v2 { + +namespace { + +class VectorScorer final : public Scorer { +public: + VectorScorer(std::vector docs, float const_score = 1.0F, uint32_t size_hint_val = 0) + : _docs(std::move(docs)), _const_score(const_score), _size_hint_val(size_hint_val) { + if (_docs.empty()) { + _current_doc = TERMINATED; + } else { + std::sort(_docs.begin(), _docs.end()); + _current_doc = _docs[0]; + } + if (_size_hint_val == 0) { + _size_hint_val = static_cast(_docs.size()); + } + } + + uint32_t advance() override { + if (_docs.empty() || _index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + ++_index; + if (_index >= _docs.size()) { + _current_doc = TERMINATED; + return TERMINATED; + } + _current_doc = _docs[_index]; + return _current_doc; + } + + uint32_t seek(uint32_t target) override { + while (_current_doc < target && _current_doc != TERMINATED) { + advance(); + } + return _current_doc; + } + + uint32_t doc() const override { return _current_doc; } + uint32_t size_hint() const override { return _size_hint_val; } + float score() override { return _const_score; } + +private: + std::vector _docs; + float _const_score; + size_t _index = 0; + uint32_t _current_doc = TERMINATED; + uint32_t _size_hint_val = 0; +}; + +std::vector sample_with_seed(uint32_t max_doc, double probability, uint32_t seed) { + std::mt19937 gen(seed); + std::uniform_real_distribution<> dis(0.0, 1.0); + std::vector result; + for (uint32_t i = 0; i < max_doc; ++i) { + if (dis(gen) < probability) { + result.push_back(i); + } + } + return result; +} + +} // namespace + +class RequiredOptionalScorerTest : public ::testing::Test {}; + 
+TEST_F(RequiredOptionalScorerTest, EmptyOptional) { + std::vector req_docs {1, 3, 7}; + auto req_scorer = std::make_shared(req_docs, 1.0F); + auto opt_scorer = std::make_shared(std::vector {}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + std::vector docs; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scorer->advance(); + } + + EXPECT_EQ(req_docs, docs); +} + +TEST_F(RequiredOptionalScorerTest, BasicScoring) { + auto req_scorer = + std::make_shared(std::vector {1, 3, 7, 8, 9, 10, 13, 15}, 1.0F); + auto opt_scorer = + std::make_shared(std::vector {1, 2, 7, 11, 12, 15}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(3u, scorer->advance()); + EXPECT_EQ(3u, scorer->doc()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(7u, scorer->advance()); + EXPECT_EQ(7u, scorer->doc()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(8u, scorer->advance()); + EXPECT_EQ(8u, scorer->doc()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(9u, scorer->advance()); + EXPECT_EQ(9u, scorer->doc()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(10u, scorer->advance()); + EXPECT_EQ(10u, scorer->doc()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(13u, scorer->advance()); + EXPECT_EQ(13u, scorer->doc()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(15u, scorer->advance()); + EXPECT_EQ(15u, scorer->doc()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(RequiredOptionalScorerTest, SeekFunctionality) { + auto req_scorer = + std::make_shared(std::vector {1, 3, 7, 8, 9, 10, 13, 15}, 1.0F); + auto opt_scorer = + 
std::make_shared(std::vector {2, 7, 11, 12, 15}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + EXPECT_EQ(7u, scorer->seek(7)); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + EXPECT_EQ(13u, scorer->seek(12)); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); +} + +TEST_F(RequiredOptionalScorerTest, EmptyRequired) { + auto req_scorer = std::make_shared(std::vector {}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {1, 2, 3}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(RequiredOptionalScorerTest, BothEmpty) { + auto req_scorer = std::make_shared(std::vector {}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(TERMINATED, scorer->doc()); +} + +TEST_F(RequiredOptionalScorerTest, NoIntersection) { + auto req_scorer = std::make_shared(std::vector {1, 3, 5, 7, 9}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {2, 4, 6, 8, 10}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + std::vector docs; + std::vector scores; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scores.push_back(scorer->score()); + scorer->advance(); + } + + std::vector expected_docs {1, 3, 5, 7, 9}; + std::vector expected_scores {1.0F, 1.0F, 1.0F, 1.0F, 1.0F}; + EXPECT_EQ(expected_docs, docs); + EXPECT_EQ(expected_scores, scores); +} + +TEST_F(RequiredOptionalScorerTest, FullIntersection) { + auto 
req_scorer = std::make_shared(std::vector {1, 3, 5}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {1, 3, 5}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + std::vector docs; + std::vector scores; + while (scorer->doc() != TERMINATED) { + docs.push_back(scorer->doc()); + scores.push_back(scorer->score()); + scorer->advance(); + } + + std::vector expected_docs {1, 3, 5}; + std::vector expected_scores {2.0F, 2.0F, 2.0F}; + EXPECT_EQ(expected_docs, docs); + EXPECT_EQ(expected_scores, scores); +} + +TEST_F(RequiredOptionalScorerTest, OptionalLargerThanRequired) { + auto req_scorer = std::make_shared(std::vector {5, 10}, 1.0F); + auto opt_scorer = std::make_shared( + std::vector {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(5u, scorer->doc()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(10u, scorer->advance()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(RequiredOptionalScorerTest, DifferentScores) { + auto req_scorer = std::make_shared(std::vector {1, 5, 10}, 2.5F); + auto opt_scorer = std::make_shared(std::vector {5, 10, 15}, 1.5F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_FLOAT_EQ(2.5F, scorer->score()); + + EXPECT_EQ(5u, scorer->advance()); + EXPECT_FLOAT_EQ(4.0F, scorer->score()); + + EXPECT_EQ(10u, scorer->advance()); + EXPECT_FLOAT_EQ(4.0F, scorer->score()); + + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(RequiredOptionalScorerTest, ScoreCaching) { + auto req_scorer = std::make_shared(std::vector {1, 5, 10}, 1.0F); + auto opt_scorer 
= std::make_shared(std::vector {1, 5}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + scorer->advance(); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); +} + +TEST_F(RequiredOptionalScorerTest, AdvanceAfterTerminated) { + auto req_scorer = std::make_shared(std::vector {1}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {1}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(RequiredOptionalScorerTest, SeekAfterTerminated) { + auto req_scorer = std::make_shared(std::vector {1}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {1}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_EQ(TERMINATED, scorer->advance()); + EXPECT_EQ(TERMINATED, scorer->seek(100)); +} + +TEST_F(RequiredOptionalScorerTest, SeekToCurrentDoc) { + auto req_scorer = std::make_shared(std::vector {5, 10, 15}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {5, 15}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(5u, scorer->seek(5)); + EXPECT_EQ(5u, scorer->seek(3)); +} + +TEST_F(RequiredOptionalScorerTest, SeekPastEnd) { + auto req_scorer = 
std::make_shared(std::vector {1, 5, 10}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {1, 5}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(TERMINATED, scorer->seek(100)); +} + +TEST_F(RequiredOptionalScorerTest, SizeHintDelegation) { + auto req_scorer = std::make_shared(std::vector {1, 2, 3}, 1.0F, 100); + auto opt_scorer = std::make_shared(std::vector {1, 2}, 1.0F, 50); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(100u, scorer->size_hint()); +} + +TEST_F(RequiredOptionalScorerTest, DocDoesNotAdvance) { + auto req_scorer = std::make_shared(std::vector {5, 10, 15}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {5}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(5u, scorer->doc()); + EXPECT_EQ(5u, scorer->doc()); + + scorer->advance(); + EXPECT_EQ(10u, scorer->doc()); + EXPECT_EQ(10u, scorer->doc()); +} + +TEST_F(RequiredOptionalScorerTest, DoNothingCombiner) { + auto req_scorer = std::make_shared(std::vector {1, 5, 10}, 2.5F); + auto opt_scorer = std::make_shared(std::vector {5, 10}, 1.5F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_FLOAT_EQ(0.0F, scorer->score()); + + EXPECT_EQ(5u, scorer->advance()); + EXPECT_FLOAT_EQ(0.0F, scorer->score()); +} + +TEST_F(RequiredOptionalScorerTest, LargeDocIds) { + auto req_scorer = std::make_shared( + std::vector {100000000, 200000000, 300000000}, 1.0F); + auto opt_scorer = + std::make_shared(std::vector {200000000, 400000000}, 1.0F); + auto 
combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(100000000u, scorer->doc()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(200000000u, scorer->advance()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(300000000u, scorer->advance()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(RequiredOptionalScorerTest, RandomData) { + auto req_docs = sample_with_seed(10000, 0.02, 1); + auto opt_docs = sample_with_seed(10000, 0.02, 2); + + auto req_scorer = std::make_shared(req_docs, 1.0F); + auto opt_scorer = std::make_shared(opt_docs, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + std::vector actual; + while (scorer->doc() != TERMINATED) { + actual.push_back(scorer->doc()); + scorer->advance(); + } + + EXPECT_EQ(req_docs, actual); +} + +TEST_F(RequiredOptionalScorerTest, RandomDataSeek) { + auto req_docs = sample_with_seed(10000, 0.02, 1); + auto opt_docs = sample_with_seed(10000, 0.02, 2); + auto seek_targets = sample_with_seed(10000, 0.001, 3); + + for (uint32_t target : seek_targets) { + auto req_scorer = std::make_shared(req_docs, 1.0F); + auto opt_scorer = std::make_shared(opt_docs, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + auto it = std::lower_bound(req_docs.begin(), req_docs.end(), target); + uint32_t expected_doc = (it == req_docs.end()) ? 
TERMINATED : *it; + + uint32_t actual_doc = scorer->seek(target); + EXPECT_EQ(expected_doc, actual_doc) << "Failed for target: " << target; + } +} + +TEST_F(RequiredOptionalScorerTest, InterleavedAdvanceAndSeek) { + auto req_scorer = std::make_shared( + std::vector {1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {5, 20, 35, 50}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(1u, scorer->doc()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(5u, scorer->advance()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(15u, scorer->seek(12)); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + + EXPECT_EQ(20u, scorer->advance()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(35u, scorer->seek(35)); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + + EXPECT_EQ(40u, scorer->advance()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); +} + +TEST_F(RequiredOptionalScorerTest, SingleDocRequired) { + auto req_scorer = std::make_shared(std::vector {42}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {42}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(42u, scorer->doc()); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + +TEST_F(RequiredOptionalScorerTest, SingleDocRequiredNotMatched) { + auto req_scorer = std::make_shared(std::vector {42}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {1, 2, 3}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_EQ(42u, scorer->doc()); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + EXPECT_EQ(TERMINATED, scorer->advance()); +} + 
+TEST_F(RequiredOptionalScorerTest, ScoreCacheResetOnAdvance) { + auto req_scorer = std::make_shared(std::vector {1, 5}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {1}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + scorer->advance(); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); +} + +TEST_F(RequiredOptionalScorerTest, ScoreCacheResetOnSeek) { + auto req_scorer = std::make_shared(std::vector {1, 5, 10}, 1.0F); + auto opt_scorer = std::make_shared(std::vector {1, 10}, 1.0F); + auto combiner = std::make_shared(); + + auto scorer = make_required_optional_scorer(std::move(req_scorer), std::move(opt_scorer), + std::move(combiner)); + + EXPECT_FLOAT_EQ(2.0F, scorer->score()); + scorer->seek(5); + EXPECT_FLOAT_EQ(1.0F, scorer->score()); + scorer->seek(10); + EXPECT_FLOAT_EQ(2.0F, scorer->score()); +} + +} // namespace doris::segment_v2::inverted_index::query_v2 diff --git a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings_test.cpp b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings_test.cpp index 10567b0bf4856d..3febf6ec106583 100644 --- a/be/test/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings_test.cpp +++ b/be/test/olap/rowset/segment_v2/inverted_index/query_v2/segment_postings_test.cpp @@ -20,8 +20,6 @@ #include #include -#include -#include #include #include "CLucene/index/DocRange.h" @@ -30,8 +28,8 @@ namespace doris::segment_v2::inverted_index::query_v2 { class MockTermDocs : public lucene::index::TermDocs { public: - MockTermDocs(std::vector docs, std::vector freqs, std::vector norms, - int32_t doc_freq) + MockTermDocs(std::vector docs, std::vector freqs, + std::vector norms, int32_t doc_freq) : _docs(std::move(docs)), _freqs(std::move(freqs)), _norms(std::move(norms)), @@ -40,149 +38,117 @@ class MockTermDocs : public 
lucene::index::TermDocs { void seek(lucene::index::Term* term) override {} void seek(lucene::index::TermEnum* termEnum) override {} - int32_t doc() const override { - if (_index >= 0 && _index < static_cast(_docs.size())) { - return _docs[_index]; - } - return INT_MAX; - } - - int32_t freq() const override { - if (_index >= 0 && _index < static_cast(_freqs.size())) { - return _freqs[_index]; - } - return 0; - } - - int32_t norm() const override { - if (_index >= 0 && _index < static_cast(_norms.size())) { - return _norms[_index]; - } - return 1; - } - - bool next() override { - if (_index + 1 < static_cast(_docs.size())) { - ++_index; - return true; - } - return false; - } + int32_t doc() const override { return 0; } + int32_t freq() const override { return 0; } + int32_t norm() const override { return 1; } + bool next() override { return false; } int32_t read(int32_t*, int32_t*, int32_t) override { return 0; } int32_t read(int32_t*, int32_t*, int32_t*, int32_t) override { return 0; } - bool readRange(DocRange*) override { return false; } - bool skipTo(const int32_t target) override { - auto size = static_cast(_docs.size()); - while (_index + 1 < size && _docs[_index + 1] < target) { - ++_index; + bool readRange(DocRange* docRange) override { + if (_read_done || _docs.empty()) { + return false; } - if (_index + 1 < size) { - ++_index; - return true; - } - return false; + docRange->type_ = DocRangeType::kMany; + docRange->doc_many = &_docs; + docRange->freq_many = &_freqs; + docRange->norm_many = &_norms; + docRange->doc_many_size_ = static_cast(_docs.size()); + docRange->freq_many_size_ = static_cast(_freqs.size()); + docRange->norm_many_size_ = static_cast(_norms.size()); + _read_done = true; + return true; } + bool skipTo(const int32_t target) override { return false; } + void skipToBlock(const int32_t target) override {} + void close() override {} lucene::index::TermPositions* __asTermPositions() override { return nullptr; } int32_t docFreq() override { return 
_doc_freq; } -private: - std::vector _docs; - std::vector _freqs; - std::vector _norms; +protected: + std::vector _docs; + std::vector _freqs; + std::vector _norms; int32_t _doc_freq; - int32_t _index = -1; + bool _read_done = false; }; class MockTermPositions : public lucene::index::TermPositions { public: - MockTermPositions(std::vector docs, std::vector freqs, - std::vector norms, std::vector> positions, + MockTermPositions(std::vector docs, std::vector freqs, + std::vector norms, std::vector> positions, int32_t doc_freq) : _docs(std::move(docs)), _freqs(std::move(freqs)), _norms(std::move(norms)), - _positions(std::move(positions)), - _doc_freq(doc_freq) {} - - void seek(lucene::index::Term* term) override {} - void seek(lucene::index::TermEnum* termEnum) override {} - - int32_t doc() const override { - if (_index >= 0 && _index < static_cast(_docs.size())) { - return _docs[_index]; - } - return INT_MAX; - } - - int32_t freq() const override { - if (_index >= 0 && _index < static_cast(_freqs.size())) { - return _freqs[_index]; + _doc_freq(doc_freq) { + for (const auto& doc_pos : positions) { + uint32_t last_pos = 0; + for (uint32_t pos : doc_pos) { + _deltas.push_back(pos - last_pos); + last_pos = pos; + } } - return 0; } - int32_t norm() const override { - if (_index >= 0 && _index < static_cast(_norms.size())) { - return _norms[_index]; - } - return 1; - } + void seek(lucene::index::Term* term) override {} + void seek(lucene::index::TermEnum* termEnum) override {} - bool next() override { - if (_index + 1 < static_cast(_docs.size())) { - ++_index; - _pos_index = 0; - return true; - } - return false; - } + int32_t doc() const override { return 0; } + int32_t freq() const override { return 0; } + int32_t norm() const override { return 1; } + bool next() override { return false; } int32_t read(int32_t*, int32_t*, int32_t) override { return 0; } int32_t read(int32_t*, int32_t*, int32_t*, int32_t) override { return 0; } - bool readRange(DocRange*) override { 
return false; } - bool skipTo(const int32_t target) override { - auto size = static_cast(_docs.size()); - while (_index + 1 < size && _docs[_index + 1] < target) { - ++_index; + bool readRange(DocRange* docRange) override { + if (_read_done || _docs.empty()) { + return false; } - _pos_index = 0; - if (_index + 1 < size) { - ++_index; - return true; - } - return false; + docRange->type_ = DocRangeType::kMany; + docRange->doc_many = &_docs; + docRange->freq_many = &_freqs; + docRange->norm_many = &_norms; + docRange->doc_many_size_ = static_cast(_docs.size()); + docRange->freq_many_size_ = static_cast(_freqs.size()); + docRange->norm_many_size_ = static_cast(_norms.size()); + _read_done = true; + return true; } + bool skipTo(const int32_t target) override { return false; } + void skipToBlock(const int32_t target) override {} + void close() override {} lucene::index::TermPositions* __asTermPositions() override { return this; } lucene::index::TermDocs* __asTermDocs() override { return this; } - int32_t nextPosition() override { - if (_index >= 0 && _index < static_cast(_positions.size()) && - _pos_index < _positions[_index].size()) { - return _positions[_index][_pos_index++]; - } - return 0; - } - + int32_t nextPosition() override { return 0; } int32_t getPayloadLength() const override { return 0; } uint8_t* getPayload(uint8_t*) override { return nullptr; } bool isPayloadAvailable() const override { return false; } int32_t docFreq() override { return _doc_freq; } + void addLazySkipProxCount(int32_t count) override { _prox_idx += count; } + int32_t nextDeltaPosition() override { + if (_prox_idx < _deltas.size()) { + return _deltas[_prox_idx++]; + } + return 0; + } + private: - std::vector _docs; - std::vector _freqs; - std::vector _norms; - std::vector> _positions; + std::vector _docs; + std::vector _freqs; + std::vector _norms; + std::vector _deltas; int32_t _doc_freq; - int32_t _index = -1; - size_t _pos_index = 0; + size_t _prox_idx = 0; + bool _read_done = false; }; 
class SegmentPostingsTest : public testing::Test {}; @@ -205,14 +171,9 @@ TEST_F(SegmentPostingsTest, test_postings_positions_with_offset) { EXPECT_EQ(output[1], 120); } -TEST_F(SegmentPostingsTest, test_segment_postings_base_default_constructor) { - SegmentPostingsBase base; - EXPECT_EQ(base.doc(), TERMINATED); -} - TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_next_true) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3)); - SegmentPostingsBase base(std::move(ptr)); + SegmentPostings base(std::move(ptr), true); EXPECT_EQ(base.doc(), 1); EXPECT_EQ(base.size_hint(), 3); @@ -222,21 +183,21 @@ TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_next_true) { TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_next_false) { TermDocsPtr ptr(new MockTermDocs({}, {}, {}, 0)); - SegmentPostingsBase base(std::move(ptr)); + SegmentPostings base(std::move(ptr)); EXPECT_EQ(base.doc(), TERMINATED); } -TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_doc_int_max) { - TermDocsPtr ptr(new MockTermDocs({INT_MAX}, {1}, {1}, 1)); - SegmentPostingsBase base(std::move(ptr)); +TEST_F(SegmentPostingsTest, test_segment_postings_base_constructor_doc_terminate) { + TermDocsPtr ptr(new MockTermDocs({TERMINATED}, {1}, {1}, 1)); + SegmentPostings base(std::move(ptr)); EXPECT_EQ(base.doc(), TERMINATED); } TEST_F(SegmentPostingsTest, test_segment_postings_base_advance_success) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3)); - SegmentPostingsBase base(std::move(ptr)); + SegmentPostings base(std::move(ptr)); EXPECT_EQ(base.doc(), 1); EXPECT_EQ(base.advance(), 3); @@ -245,36 +206,36 @@ TEST_F(SegmentPostingsTest, test_segment_postings_base_advance_success) { TEST_F(SegmentPostingsTest, test_segment_postings_base_advance_end) { TermDocsPtr ptr(new MockTermDocs({1}, {2}, {1}, 1)); - SegmentPostingsBase base(std::move(ptr)); + SegmentPostings base(std::move(ptr)); EXPECT_EQ(base.advance(), 
TERMINATED); } TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_target_le_doc) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3)); - SegmentPostingsBase base(std::move(ptr)); + SegmentPostings base(std::move(ptr)); EXPECT_EQ(base.seek(0), 1); EXPECT_EQ(base.seek(1), 1); } -TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_skipTo_success) { +TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_in_block_success) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5, 7}, {2, 4, 6, 8}, {1, 1, 1, 1}, 4)); - SegmentPostingsBase base(std::move(ptr)); + SegmentPostings base(std::move(ptr)); EXPECT_EQ(base.seek(5), 5); } -TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_skipTo_fail) { +TEST_F(SegmentPostingsTest, test_segment_postings_base_seek_fail) { TermDocsPtr ptr(new MockTermDocs({1, 3, 5}, {2, 4, 6}, {1, 1, 1}, 3)); - SegmentPostingsBase base(std::move(ptr)); + SegmentPostings base(std::move(ptr)); EXPECT_EQ(base.seek(10), TERMINATED); } TEST_F(SegmentPostingsTest, test_segment_postings_base_append_positions_exception) { TermDocsPtr ptr(new MockTermDocs({1}, {2}, {1}, 1)); - SegmentPostingsBase base(std::move(ptr)); + SegmentPostings base(std::move(ptr)); std::vector output; EXPECT_THROW(base.append_positions_with_offset(0, output), Exception); @@ -282,7 +243,7 @@ TEST_F(SegmentPostingsTest, test_segment_postings_base_append_positions_exceptio TEST_F(SegmentPostingsTest, test_segment_postings_termdocs) { TermDocsPtr ptr(new MockTermDocs({1, 3}, {2, 4}, {1, 1}, 2)); - SegmentPostings postings(std::move(ptr)); + SegmentPostings postings(std::move(ptr)); EXPECT_EQ(postings.doc(), 1); EXPECT_EQ(postings.size_hint(), 2); @@ -291,7 +252,7 @@ TEST_F(SegmentPostingsTest, test_segment_postings_termdocs) { TEST_F(SegmentPostingsTest, test_segment_postings_termpositions) { TermPositionsPtr ptr( new MockTermPositions({1, 3}, {2, 3}, {1, 1}, {{10, 20}, {30, 40, 50}}, 2)); - SegmentPostings postings(std::move(ptr)); + 
SegmentPostings postings(std::move(ptr), true); EXPECT_EQ(postings.doc(), 1); EXPECT_EQ(postings.freq(), 2); @@ -300,7 +261,7 @@ TEST_F(SegmentPostingsTest, test_segment_postings_termpositions) { TEST_F(SegmentPostingsTest, test_segment_postings_termpositions_append_positions) { TermPositionsPtr ptr( new MockTermPositions({1, 3}, {2, 3}, {1, 1}, {{10, 20}, {30, 40, 50}}, 2)); - SegmentPostings postings(std::move(ptr)); + SegmentPostings postings(std::move(ptr), true); std::vector output = {999}; postings.append_positions_with_offset(100, output); @@ -313,22 +274,11 @@ TEST_F(SegmentPostingsTest, test_segment_postings_termpositions_append_positions TEST_F(SegmentPostingsTest, test_no_score_segment_posting) { TermDocsPtr ptr(new MockTermDocs({1, 3}, {5, 7}, {10, 20}, 2)); - NoScoreSegmentPosting posting(std::move(ptr)); + SegmentPostings posting(std::move(ptr)); EXPECT_EQ(posting.doc(), 1); EXPECT_EQ(posting.freq(), 1); EXPECT_EQ(posting.norm(), 1); } -TEST_F(SegmentPostingsTest, test_empty_segment_posting) { - EmptySegmentPosting posting; - - EXPECT_EQ(posting.doc(), TERMINATED); - EXPECT_EQ(posting.size_hint(), 0); - EXPECT_EQ(posting.freq(), 1); - EXPECT_EQ(posting.norm(), 1); - EXPECT_EQ(posting.advance(), TERMINATED); - EXPECT_EQ(posting.seek(100), TERMINATED); -} - } // namespace doris::segment_v2::inverted_index::query_v2 \ No newline at end of file diff --git a/contrib/clucene b/contrib/clucene index a8d1f58f393ef3..8b57674e9d7876 160000 --- a/contrib/clucene +++ b/contrib/clucene @@ -1 +1 @@ -Subproject commit a8d1f58f393ef3ed13cedf82c77a3581ab5d57ef +Subproject commit 8b57674e9d78769b10aa0c1441cd12671a394745