Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/disk_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace diskann {
const _u64 nshards, unsigned max_degree,
const std::string &output_vamana,
const std::string &medoids_file, bool use_filters = false,
const std::string &labels_to_medoids_file = std::string(""));
const std::string &labels_to_medoids_file = std::string(""));

DISKANN_DLLEXPORT void extract_shard_labels(
const std::string &in_label_file, const std::string &shard_ids_bin,
Expand Down
4 changes: 3 additions & 1 deletion include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ namespace diskann {
DISKANN_DLLEXPORT consolidation_report
consolidate_deletes(const Parameters &parameters);

DISKANN_DLLEXPORT void prune_all_nbrs(const Parameters &parameters);

DISKANN_DLLEXPORT bool is_index_saved();

// repositions frozen points to the end of _data - if they have been moved
Expand Down Expand Up @@ -238,7 +240,7 @@ namespace diskann {
// determines navigating node of the graph by calculating medoid of data
unsigned calculate_entry_point();

void parse_label_file(const std::string &map_file);
size_t parse_label_file(const std::string &map_file);

std::pair<uint32_t, uint32_t> iterate_to_fixed_point(
const T *node_coords, const unsigned Lindex,
Expand Down
7 changes: 6 additions & 1 deletion src/disk_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ namespace diskann {
}

std::ofstream mapping_writer(labels_to_medoids_file);
assert(mapping_writer.is_open());
for (auto iter : global_label_to_medoids) {
mapping_writer << iter.first << ", ";
auto &vec = iter.second;
Expand Down Expand Up @@ -534,6 +535,7 @@ namespace diskann {
diskann::cout << labels_per_point.size() << " is the new number of points"
<< std::endl;
std::ofstream label_writer(out_labels_file);
assert(label_writer.is_open());
for (_u32 i = 0; i < labels_per_point.size(); i++) {
for (_u32 j = 0; j < (labels_per_point[i].size() - 1); j++) {
label_writer << labels_per_point[i][j] << ",";
Expand All @@ -551,6 +553,7 @@ namespace diskann {
data = (T *) std::realloc((void *) data,
labels_per_point.size() * ndims * sizeof(T));
std::ofstream dummy_writer(out_metadata_file);
assert(dummy_writer.is_open());
for (auto i = dummy_pt_ids.begin(); i != dummy_pt_ids.end(); i++) {
dummy_writer << i->first << "," << i->second << std::endl;
std::memcpy(data + i->first * ndims, data + i->second * ndims,
Expand All @@ -566,7 +569,7 @@ namespace diskann {
const std::string &in_label_file, const std::string &shard_ids_bin,
const std::string &shard_label_file) { // assumes ith row is for ith
// point in labels file
std::cout << "Extracting labels for shard" << std::endl;
diskann::cout << "Extracting labels for shard" << std::endl;

_u32 *ids = nullptr;
_u64 num_ids, tmp_dim;
Expand All @@ -578,6 +581,8 @@ namespace diskann {
std::ifstream label_reader(in_label_file);
std::ofstream label_writer(shard_label_file);

assert(label_reader.is_open());
assert(label_writer.is_open());
if (label_reader && label_writer) {
while (std::getline(label_reader, cur_line)) {
if (shard_counter >= num_ids) {
Expand Down
219 changes: 146 additions & 73 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,24 @@ namespace diskann {
if (_filter_to_medoid_id.size() > 0) {
std::ofstream medoid_writer(std::string(filename) +
"_labels_to_medoids.txt");
assert(medoid_writer.is_open());
for (auto iter : _filter_to_medoid_id) {
medoid_writer << iter.first << ", " << iter.second << std::endl;
// std::cout << iter.first << ", " << iter.second << std::endl;
}
medoid_writer.close();
}

if (_use_universal_label) {
std::ofstream universal_label_writer(std::string(filename) +
"_universal_label.txt");
assert(universal_label_writer.is_open());
universal_label_writer << _universal_label << std::endl;
universal_label_writer.close();
}

if (_pts_to_labels.size() > 0) {
std::ofstream label_writer(std::string(filename) + "_labels.txt");
assert(label_writer.is_open());
for (_u32 i = 0; i < _pts_to_labels.size(); i++) {
for (_u32 j = 0; j < (_pts_to_labels[i].size() - 1); j++) {
label_writer << _pts_to_labels[i][j] << ",";
Expand Down Expand Up @@ -477,52 +479,12 @@ namespace diskann {

_has_built = true;

size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0;
size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0, label_num_pts = 0;

std::string mem_index_file(filename);
std::string labels_file = mem_index_file + "_labels.txt";
std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt";
if (file_exists(labels_file)) {
parse_label_file(labels_file);
if (file_exists(labels_to_medoids)) {
std::ifstream medoid_stream(labels_to_medoids);

std::string line, token;
unsigned line_cnt = 0;

_filter_to_medoid_id.clear();

while (std::getline(medoid_stream, line)) {
std::istringstream iss(line);
_u32 cnt = 0;
_u32 medoid = 0;
label label;
while (std::getline(iss, token, ',')) {
token.erase(std::remove(token.begin(), token.end(), '\n'),
token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'),
token.end());
unsigned token_as_num = std::stoul(token);
if (cnt == 0)
label = token_as_num;
else
medoid = token_as_num;
cnt++;
}
_filter_to_medoid_id[label] = medoid;
line_cnt++;
}
}

std::string universal_label_file(filename);
universal_label_file += "_universal_label.txt";
if (file_exists(universal_label_file)) {
std::ifstream universal_label_reader(universal_label_file);
universal_label_reader >> _universal_label;
_use_universal_label = true;
universal_label_reader.close();
}
}


if (!_save_as_one_file) {
// For DLVS Store, we will not support saving the index in multiple files.
Expand Down Expand Up @@ -561,6 +523,53 @@ namespace diskann {
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
__LINE__);
}
if (file_exists(labels_file)) {
label_num_pts = parse_label_file(labels_file);
assert(label_num_pts == data_file_num_pts);
if (file_exists(labels_to_medoids)) {
std::ifstream medoid_stream(labels_to_medoids);
assert(medoid_stream.is_open());
std::string line, token;
unsigned line_cnt = 0;

_filter_to_medoid_id.clear();
try {
while (std::getline(medoid_stream, line)) {
std::istringstream iss(line);
_u32 cnt = 0;
_u32 medoid = 0;
label label;
while (std::getline(iss, token, ',')) {
token.erase(std::remove(token.begin(), token.end(), '\n'),
token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'),
token.end());
unsigned token_as_num = std::stoul(token);
if (cnt == 0)
label = token_as_num;
else
medoid = token_as_num;
cnt++;
}
_filter_to_medoid_id[label] = medoid;
line_cnt++;
}
} catch (std::system_error &e) {
throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__,
__LINE__);
}
}

std::string universal_label_file(filename);
universal_label_file += "_universal_label.txt";
if (file_exists(universal_label_file)) {
std::ifstream universal_label_reader(universal_label_file);
assert(universal_label_reader.is_open());
universal_label_reader >> _universal_label;
_use_universal_label = true;
universal_label_reader.close();
}
}

_nd = data_file_num_pts - _num_frozen_pts;
_empty_slots.clear();
Expand Down Expand Up @@ -688,13 +697,13 @@ namespace diskann {
}
}
#else

size_t bytes_read = vamana_metadata_size;
size_t cc = 0;
unsigned nodes_read = 0;
while (bytes_read != expected_file_size) {
unsigned k;
in.read((char *) &k, sizeof(unsigned));

if (k == 0) {
diskann::cerr << "ERROR: Point found with no out-neighbors, point#"
<< nodes_read << std::endl;
Expand Down Expand Up @@ -1378,6 +1387,65 @@ namespace diskann {
}
}

// Re-prunes the out-neighbor list of every active node whose degree exceeds
// the requested maximum, then prints degree statistics for the graph.
//
// Parameters read from `parameters`:
//   "R"     (unsigned) - degree bound; only nodes with more than R
//                        out-neighbors are re-pruned.
//   "C"     (unsigned) - passed through to prune_neighbors as `maxc`.
//   "alpha" (float)    - passed through to prune_neighbors as `alpha`.
//
// Side effects: sets _filtered_index to true and rewrites _final_graph[node]
// for every node that gets pruned.
template<typename T, typename TagT>
void Index<T, TagT>::prune_all_nbrs(const Parameters &parameters) {
  const unsigned range = parameters.Get<unsigned>("R");
  const unsigned maxc = parameters.Get<unsigned>("C");
  // alpha is a floating-point factor; fetching it as `unsigned` would
  // misread (or truncate) the stored value before it reaches the pruner.
  const float alpha = parameters.Get<float>("alpha");
  _filtered_index = true;

  diskann::Timer timer;
#pragma omp parallel for
  for (_s64 node = 0; node < (_s64)(_max_points + _num_frozen_pts); node++) {
    // Only slots below _nd and the slot at index _max_points are considered;
    // every other slot is skipped.
    const bool is_candidate =
        (size_t) node < _nd || (size_t) node == _max_points;
    if (!is_candidate || _final_graph[node].size() <= range)
      continue;

    tsl::robin_set<unsigned> dummy_visited(0);
    std::vector<Neighbor>    dummy_pool(0);
    std::vector<unsigned>    new_out_neighbors;

    ScratchStoreManager<InMemQueryScratch<T>> manager(_query_scratch);
    auto scratch = manager.scratch_space();

    // Build the candidate pool from the current neighbors, dropping
    // self-loops and duplicate ids.
    for (auto cur_nbr : _final_graph[node]) {
      if (dummy_visited.find(cur_nbr) == dummy_visited.end() &&
          (_s64) cur_nbr != node) {
        float dist = _distance->compare(
            _data + _aligned_dim * (size_t) node,
            _data + _aligned_dim * (size_t) cur_nbr, (unsigned) _aligned_dim);
        dummy_pool.emplace_back(Neighbor(cur_nbr, dist));
        dummy_visited.insert(cur_nbr);
      }
    }

    prune_neighbors((_u32) node, dummy_pool, range, maxc, alpha,
                    new_out_neighbors, scratch);
    _final_graph[node].assign(new_out_neighbors.begin(),
                              new_out_neighbors.end());
  }

  diskann::cout << "Prune time : " << timer.elapsed() / 1000 << "ms"
                << std::endl;

  // Degree statistics over the first (_nd + _num_frozen_pts) adjacency lists.
  size_t max = 0, min = (size_t) 1 << 30, total = 0, cnt = 0;
  for (size_t i = 0; i < (_nd + _num_frozen_pts); i++) {
    // Bind by reference: only the size is needed, so avoid copying the list.
    const auto  &pool = _final_graph[i];
    const size_t degree = pool.size();
    max = (std::max)(max, degree);
    min = (std::min)(min, degree);
    total += degree;
    if (degree < 2)
      cnt++;
  }
  // If no list was visited (or every list is huge) keep min <= max.
  if (min > max)
    min = max;
  if (_nd > 0) {
    diskann::cout << "Index built with degree: max:" << max << "  avg:"
                  << (float) total / (float) (_nd + _num_frozen_pts)
                  << "  min:" << min << "  count(deg<2):" << cnt << std::endl;
  }
}

template<typename T, typename TagT>
void Index<T, TagT>::set_start_point(T *data) {
std::unique_lock<std::shared_timed_mutex> ul(_update_lock);
Expand Down Expand Up @@ -1648,13 +1716,13 @@ namespace diskann {
}

template<typename T, typename TagT>
void Index<T, TagT>::parse_label_file(const std::string &map_file) {
size_t Index<T, TagT>::parse_label_file(const std::string &label_file) {
// Format of Label txt file: filters with comma separators

std::ifstream infile(map_file);
std::ifstream infile(label_file);
assert(infile.is_open());
std::string line, token;
unsigned line_cnt = 0;

while (std::getline(infile, line)) {
line_cnt++;
}
Expand All @@ -1663,28 +1731,35 @@ namespace diskann {
infile.clear();
infile.seekg(0, std::ios::beg);
line_cnt = 0;
while (std::getline(infile, line)) {
std::istringstream iss(line);
std::vector<label> lbls(0);
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ',')) {
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
unsigned token_as_num = std::stoul(token);
lbls.push_back(token_as_num);
_labels.insert(token_as_num);
}
if (lbls.size() <= 0) {
std::cout << "No label found";
exit(-1);
try {
while (std::getline(infile, line)) {
std::istringstream iss(line);
std::vector<label> lbls(0);
getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, ',')) {
token.erase(std::remove(token.begin(), token.end(), '\n'),
token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'),
token.end());
unsigned token_as_num = std::stoul(token);
lbls.push_back(token_as_num);
_labels.insert(token_as_num);
}
if (lbls.size() <= 0) {
diskann::cout << "No label found";
exit(-1);
}
std::sort(lbls.begin(), lbls.end());
_pts_to_labels[line_cnt] = lbls;
line_cnt++;
}
std::sort(lbls.begin(), lbls.end());
_pts_to_labels[line_cnt] = lbls;
line_cnt++;
diskann::cout << "Identified " << _labels.size() << " distinct label(s)"
<< std::endl;
} catch (std::system_error &e) {
throw FileException(label_file, e, __FUNCSIG__, __FILE__, __LINE__);
}
std::cout << "Identified " << _labels.size() << " distinct label(s)"
<< std::endl;
return line_cnt;
}

template<typename T, typename TagT>
Expand All @@ -1702,9 +1777,9 @@ namespace diskann {
_labels_file = label_file;
_filtered_index = true;
_filter_to_medoid_id.clear();
parse_label_file(label_file); // determines medoid for each label and
size_t label_num_pts = parse_label_file(label_file); // determines medoid for each label and
// identifies the points to label mapping

assert(label_num_pts != num_points_to_load);
_u32 counter = 0;
#pragma omp parallel for schedule(dynamic, 1)
for (int lbl = 0; lbl < _labels.size(); lbl++) {
Expand Down Expand Up @@ -1744,15 +1819,13 @@ namespace diskann {
_filter_to_medoid_id[x] = best_medoid;
_medoid_counts[best_medoid]++;
std::stringstream a;
// a << "Medoid of " << x << " is " << best_medoid <<
// std::endl; std::cout << a.str();
}
}
#pragma omp critical
counter++;
std::stringstream a;
a << ((100.0 * counter) / _labels.size()) << "\% processed \r";
std::cout << a.str() << std::flush;
diskann::cout << a.str() << std::flush;
}

this->build(filename, num_points_to_load, parameters, tags);
Expand Down
Loading