Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,19 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
if (filtered_search && !tags)
{
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];

auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
if (diverse_search){
auto retval = index->diverse_search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
query_result_ids[test_id].data() + i * recall_at,
query_result_dists[test_id].data() + i * recall_at, maxLperSeller);
cmp_stats[i] = retval.second;
}
else
{
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
query_result_ids[test_id].data() + i * recall_at,
query_result_dists[test_id].data() + i * recall_at);
cmp_stats[i] = retval.second;
cmp_stats[i] = retval.second;
}
}
else if (metric == diskann::FAST_L2)
{
Expand Down
9 changes: 9 additions & 0 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ class AbstractIndex
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
float *distances);
// Filter support + diverse search
// IndexType is either uint32_t or uint64_t
template <typename IndexType>
std::pair<uint32_t, uint32_t> diverse_search_with_filters(const DataType &query, const std::string &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
float *distances, const uint32_t maxLperSeller);

// insert points with labels, labels should be present for filtered index
template <typename data_type, typename tag_type, typename label_type>
Expand Down Expand Up @@ -119,6 +125,9 @@ class AbstractIndex
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
const size_t K, const uint32_t L, std::any &indices,
float *distances) = 0;
virtual std::pair<uint32_t, uint32_t> _diverse_search_with_filters(const DataType &query, const std::string &filter_label,
const size_t K, const uint32_t L, std::any &indices,
float *distances, const uint32_t maxLperSeller) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;
virtual int _lazy_delete(const TagType &tag) = 0;
Expand Down
12 changes: 11 additions & 1 deletion include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,12 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
template <typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const LabelT &filter_label,
const size_t K, const uint32_t L,
IndexType *indices, float *distances);
IndexType *indices, float *distances, const uint32_t maxLperSeller = 0);
// Filter support + diverse search
template <typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> diverse_search_with_filters(const T *query, const LabelT &filter_label,
const size_t K, const uint32_t L,
IndexType *indices, float *distances, const uint32_t maxLperSeller = 0);

// Will fail if tag already in the index or if tag=0.
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);
Expand Down Expand Up @@ -216,6 +221,11 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
float *distances) override;
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType &query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
std::any &indices, float *distances = nullptr) override;
virtual std::pair<uint32_t, uint32_t> _diverse_search_with_filters(const DataType &query,
const std::string &filter_label_raw, const size_t K,
const uint32_t L, std::any &indices,
float *distances, const uint32_t maxLperSeller) override;


virtual int _insert_point(const DataType &data_point, const TagType tag) override;
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override;
Expand Down
15 changes: 15 additions & 0 deletions src/abstract_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType
return _search_with_filters(query, raw_label, K, L, any_indices, distances);
}

template <typename IndexType>
std::pair<uint32_t, uint32_t> AbstractIndex::diverse_search_with_filters(const DataType &query, const std::string &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
float *distances, const uint32_t maxLperSeller)
{
auto any_indices = std::any(indices);
return _diverse_search_with_filters(query, raw_label, K, L, any_indices, distances, maxLperSeller);
}

template <typename data_type>
void AbstractIndex::search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices)
{
Expand Down Expand Up @@ -178,7 +187,13 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::diverse_
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::diverse_search<int8_t, uint64_t>(
const int8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint64_t *indices, float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::diverse_search_with_filters<uint32_t>(
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances, const uint32_t maxLperSeller);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::diverse_search_with_filters<uint64_t>(
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances, const uint32_t maxLperSeller);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices,
Expand Down
82 changes: 55 additions & 27 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2318,11 +2318,34 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const
}
}

template <typename T, typename TagT, typename LabelT>
std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_diverse_search_with_filters(const DataType &query,
const std::string &raw_label, const size_t K,
const uint32_t L, std::any &indices,
float *distances, const uint32_t maxLperSeller)
{
auto converted_label = this->get_converted_label(raw_label);
if (typeid(uint64_t *) == indices.type())
{
auto ptr = std::any_cast<uint64_t *>(indices);
return this->search_with_filters(std::any_cast<T *>(query), converted_label, K, L, ptr, distances, maxLperSeller);
}
else if (typeid(uint32_t *) == indices.type())
{
auto ptr = std::any_cast<uint32_t *>(indices);
return this->search_with_filters(std::any_cast<T *>(query), converted_label, K, L, ptr, distances, maxLperSeller);
}
else
{
throw ANNException("Error: Id type can only be uint64_t or uint32_t.", -1);
}
}

template <typename T, typename TagT, typename LabelT>
template <typename IdType>
std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const T *query, const LabelT &filter_label,
const size_t K, const uint32_t L,
IdType *indices, float *distances)
IdType *indices, float *distances, const uint32_t maxLperSeller)
{
if (K > (uint64_t)L)
{
Expand Down Expand Up @@ -2364,9 +2387,14 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
filter_vec.emplace_back(filter_label);

_data_store->preprocess_query(query, scratch);
auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true);
auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller);

auto best_L_nodes = scratch->best_l_nodes();
NeighborPriorityQueue best_L_nodes;
if (maxLperSeller == 0) {
best_L_nodes = scratch->best_l_nodes();
} else {
best_L_nodes = scratch->best_diverse_nodes().best_L_nodes;
}

size_t pos = 0;
for (size_t i = 0; i < best_L_nodes.size(); ++i)
Expand Down Expand Up @@ -3660,41 +3688,41 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search_with_filters<
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search_with_filters<
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search_with_filters<
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
// TagT==uint32_t
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search_with_filters<
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search_with_filters<
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search_with_filters<
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search_with_filters<
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search<uint64_t>(
const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller);
Expand Down Expand Up @@ -3724,40 +3752,40 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search_with_filters<
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search_with_filters<
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search_with_filters<
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
// TagT==uint32_t
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search_with_filters<
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search_with_filters<
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search_with_filters<
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search_with_filters<
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
float *distances);
float *distances, const uint32_t maxLperSeller);

} // namespace diskann
Loading