diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 12fc918b5..ea1144969 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -232,11 +232,19 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, if (filtered_search && !tags) { std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; - - auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L, + if (diverse_search){ + auto retval = index->diverse_search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L, + query_result_ids[test_id].data() + i * recall_at, + query_result_dists[test_id].data() + i * recall_at, maxLperSeller); + cmp_stats[i] = retval.second; + } + else + { + auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L, query_result_ids[test_id].data() + i * recall_at, query_result_dists[test_id].data() + i * recall_at); - cmp_stats[i] = retval.second; + cmp_stats[i] = retval.second; + } } else if (metric == diskann::FAST_L2) { diff --git a/include/abstract_index.h b/include/abstract_index.h index c8b01105c..397fb24d5 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -82,6 +82,12 @@ class AbstractIndex std::pair search_with_filters(const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, IndexType *indices, float *distances); + // Filter support + diverse search + // IndexType is either uint32_t or uint64_t + template + std::pair diverse_search_with_filters(const DataType &query, const std::string &raw_label, + const size_t K, const uint32_t L, IndexType *indices, + float *distances, const uint32_t maxLperSeller); // insert points with labels, labels should be present for filtered index template @@ -119,6 +125,9 @@ class AbstractIndex virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label, const size_t K, const uint32_t L, std::any &indices, float *distances) = 0; + virtual std::pair _diverse_search_with_filters(const DataType &query, const std::string &filter_label, + const size_t K, const uint32_t L, std::any &indices, + float *distances, const uint32_t maxLperSeller) = 0; virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0; virtual int _insert_point(const DataType &data_point, const TagType tag) = 0; virtual int _lazy_delete(const TagType &tag) = 0; diff --git a/include/index.h b/include/index.h index 71a1c37f2..0c5dd3a97 100644 --- a/include/index.h +++ b/include/index.h @@ -147,7 +147,12 @@ template clas template DISKANN_DLLEXPORT std::pair search_with_filters(const T *query, const LabelT &filter_label, const size_t K, const uint32_t L, - IndexType *indices, float *distances); + IndexType *indices, float *distances, const uint32_t maxLperSeller = 0); + // Filter support + diverse search + template + DISKANN_DLLEXPORT std::pair diverse_search_with_filters(const T *query, const LabelT &filter_label, + const size_t K, const uint32_t L, + IndexType *indices, float *distances, const uint32_t maxLperSeller = 0); // Will fail if tag already in the index or if tag=0. DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag); @@ -216,6 +221,11 @@ template clas float *distances) override; virtual std::pair _diverse_search(const DataType &query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices, float *distances = nullptr) override; + virtual std::pair _diverse_search_with_filters(const DataType &query, + const std::string &filter_label_raw, const size_t K, + const uint32_t L, std::any &indices, + float *distances, const uint32_t maxLperSeller) override; + virtual int _insert_point(const DataType &data_point, const TagType tag) override; virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override; diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index fd2aafa20..e0d8a6227 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -51,6 +51,15 @@ std::pair AbstractIndex::search_with_filters(const DataType return _search_with_filters(query, raw_label, K, L, any_indices, distances); } +template +std::pair AbstractIndex::diverse_search_with_filters(const DataType &query, const std::string &raw_label, + const size_t K, const uint32_t L, IndexType *indices, + float *distances, const uint32_t maxLperSeller) +{ + auto any_indices = std::any(indices); + return _diverse_search_with_filters(query, raw_label, K, L, any_indices, distances, maxLperSeller); +} + template void AbstractIndex::search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices) { @@ -178,7 +187,13 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_ template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( const int8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search_with_filters( + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, + float *distances, const uint32_t maxLperSeller); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search_with_filters( + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices, + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, diff --git a/src/index.cpp b/src/index.cpp index f12415dcf..5781dd319 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2318,11 +2318,34 @@ std::pair Index::_search_with_filters(const } } +template +std::pair Index::_diverse_search_with_filters(const DataType &query, + const std::string &raw_label, const size_t K, + const uint32_t L, std::any &indices, + float *distances, const uint32_t maxLperSeller) +{ + auto converted_label = this->get_converted_label(raw_label); + if (typeid(uint64_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances, maxLperSeller); + } + else if (typeid(uint32_t *) == indices.type()) + { + auto ptr = std::any_cast(indices); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances, maxLperSeller); + } + else + { + throw ANNException("Error: Id type can only be uint64_t or uint32_t.", -1); + } +} + template template std::pair Index::search_with_filters(const T *query, const LabelT &filter_label, const size_t K, const uint32_t L, - IdType *indices, float *distances) + IdType *indices, float *distances, const uint32_t maxLperSeller) { if (K > (uint64_t)L) { @@ -2364,9 +2387,14 @@ std::pair Index::search_with_filters(const filter_vec.emplace_back(filter_label); _data_store->preprocess_query(query, scratch); - auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); + auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller); - auto best_L_nodes = scratch->best_l_nodes(); + NeighborPriorityQueue best_L_nodes; + if (maxLperSeller == 0) { + best_L_nodes = scratch->best_l_nodes(); + } else { + best_L_nodes = scratch->best_diverse_nodes().best_L_nodes; + } size_t pos = 0; for (size_t i = 0; i < best_L_nodes.size(); ++i) @@ -3660,41 +3688,41 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); @@ -3724,40 +3752,40 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, - float *distances); + float *distances, const uint32_t maxLperSeller); } // namespace diskann