diff --git a/include/disk_utils.h b/include/disk_utils.h index dabeeeb0c..6aaa02426 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -75,7 +75,7 @@ namespace diskann { const _u64 nshards, unsigned max_degree, const std::string &output_vamana, const std::string &medoids_file, bool use_filters = false, - const std::string &labels_to_medoids_file = std::string("")); + const std::string &labels_to_medoids_file = std::string("")); DISKANN_DLLEXPORT void extract_shard_labels( const std::string &in_label_file, const std::string &shard_ids_bin, diff --git a/include/index.h b/include/index.h index 7ea309e34..a866591bf 100644 --- a/include/index.h +++ b/include/index.h @@ -192,6 +192,8 @@ namespace diskann { DISKANN_DLLEXPORT consolidation_report consolidate_deletes(const Parameters ¶meters); + DISKANN_DLLEXPORT void prune_all_nbrs(const Parameters ¶meters); + DISKANN_DLLEXPORT bool is_index_saved(); // repositions frozen points to the end of _data - if they have been moved @@ -238,7 +240,7 @@ namespace diskann { // determines navigating node of the graph by calculating medoid of datafopt unsigned calculate_entry_point(); - void parse_label_file(const std::string &map_file); + size_t parse_label_file(const std::string &map_file); std::pair iterate_to_fixed_point( const T *node_coords, const unsigned Lindex, diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 9b7c3ca24..e6aa02fdb 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -323,6 +323,7 @@ namespace diskann { } std::ofstream mapping_writer(labels_to_medoids_file); + assert(mapping_writer.is_open()); for (auto iter : global_label_to_medoids) { mapping_writer << iter.first << ", "; auto &vec = iter.second; @@ -534,6 +535,7 @@ namespace diskann { diskann::cout << labels_per_point.size() << " is the new number of points" << std::endl; std::ofstream label_writer(out_labels_file); + assert(label_writer.is_open()); for (_u32 i = 0; i < labels_per_point.size(); i++) { for (_u32 j = 0; j < (labels_per_point[i].size() - 1); j++) { label_writer << labels_per_point[i][j] << ","; @@ -551,6 +553,7 @@ namespace diskann { data = (T *) std::realloc((void *) data, labels_per_point.size() * ndims * sizeof(T)); std::ofstream dummy_writer(out_metadata_file); + assert(dummy_writer.is_open()); for (auto i = dummy_pt_ids.begin(); i != dummy_pt_ids.end(); i++) { dummy_writer << i->first << "," << i->second << std::endl; std::memcpy(data + i->first * ndims, data + i->second * ndims, @@ -566,7 +569,7 @@ namespace diskann { const std::string &in_label_file, const std::string &shard_ids_bin, const std::string &shard_label_file) { // assumes ith row is for ith // point in labels file - std::cout << "Extracting labels for shard" << std::endl; + diskann::cout << "Extracting labels for shard" << std::endl; _u32 *ids = nullptr; _u64 num_ids, tmp_dim; @@ -578,6 +581,8 @@ namespace diskann { std::ifstream label_reader(in_label_file); std::ofstream label_writer(shard_label_file); + assert(label_reader.is_open()); + assert(label_writer.is_open()); if (label_reader && label_writer) { while (std::getline(label_reader, cur_line)) { if (shard_counter >= num_ids) { diff --git a/src/index.cpp b/src/index.cpp index 35eab43c3..344ca86d6 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -279,9 +279,9 @@ namespace diskann { if (_filter_to_medoid_id.size() > 0) { std::ofstream medoid_writer(std::string(filename) + "_labels_to_medoids.txt"); + assert(medoid_writer.is_open()); for (auto iter : _filter_to_medoid_id) { medoid_writer << iter.first << ", " << iter.second << std::endl; - // std::cout << iter.first << ", " << iter.second << std::endl; } medoid_writer.close(); } @@ -289,12 +289,14 @@ namespace diskann { if (_use_universal_label) { std::ofstream universal_label_writer(std::string(filename) + "_universal_label.txt"); + assert(universal_label_writer.is_open()); universal_label_writer << _universal_label << std::endl; universal_label_writer.close(); } if (_pts_to_labels.size() > 0) { std::ofstream label_writer(std::string(filename) + "_labels.txt"); + assert(label_writer.is_open()); for (_u32 i = 0; i < _pts_to_labels.size(); i++) { for (_u32 j = 0; j < (_pts_to_labels[i].size() - 1); j++) { label_writer << _pts_to_labels[i][j] << ","; @@ -477,52 +479,12 @@ namespace diskann { _has_built = true; - size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0; + size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0, label_num_pts = 0; std::string mem_index_file(filename); std::string labels_file = mem_index_file + "_labels.txt"; std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt"; - if (file_exists(labels_file)) { - parse_label_file(labels_file); - if (file_exists(labels_to_medoids)) { - std::ifstream medoid_stream(labels_to_medoids); - - std::string line, token; - unsigned line_cnt = 0; - - _filter_to_medoid_id.clear(); - - while (std::getline(medoid_stream, line)) { - std::istringstream iss(line); - _u32 cnt = 0; - _u32 medoid = 0; - label label; - while (std::getline(iss, token, ',')) { - token.erase(std::remove(token.begin(), token.end(), '\n'), - token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), - token.end()); - unsigned token_as_num = std::stoul(token); - if (cnt == 0) - label = token_as_num; - else - medoid = token_as_num; - cnt++; - } - _filter_to_medoid_id[label] = medoid; - line_cnt++; - } - } - - std::string universal_label_file(filename); - universal_label_file += "_universal_label.txt"; - if (file_exists(universal_label_file)) { - std::ifstream universal_label_reader(universal_label_file); - universal_label_reader >> _universal_label; - _use_universal_label = true; - universal_label_reader.close(); - } - } + if (!_save_as_one_file) { // For DLVS Store, we will not support saving the index in multiple files. @@ -561,6 +523,53 @@ namespace diskann { throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } + if (file_exists(labels_file)) { + label_num_pts = parse_label_file(labels_file); + assert(label_num_pts == data_file_num_pts); + if (file_exists(labels_to_medoids)) { + std::ifstream medoid_stream(labels_to_medoids); + assert(medoid_stream.is_open()); + std::string line, token; + unsigned line_cnt = 0; + + _filter_to_medoid_id.clear(); + try { + while (std::getline(medoid_stream, line)) { + std::istringstream iss(line); + _u32 cnt = 0; + _u32 medoid = 0; + label label; + while (std::getline(iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), + token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), + token.end()); + unsigned token_as_num = std::stoul(token); + if (cnt == 0) + label = token_as_num; + else + medoid = token_as_num; + cnt++; + } + _filter_to_medoid_id[label] = medoid; + line_cnt++; + } + } catch (std::system_error &e) { + throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, + __LINE__); + } + } + + std::string universal_label_file(filename); + universal_label_file += "_universal_label.txt"; + if (file_exists(universal_label_file)) { + std::ifstream universal_label_reader(universal_label_file); + assert(universal_label_reader.is_open()); + universal_label_reader >> _universal_label; + _use_universal_label = true; + universal_label_reader.close(); + } + } _nd = data_file_num_pts - _num_frozen_pts; _empty_slots.clear(); @@ -688,13 +697,13 @@ namespace diskann { } } #else - size_t bytes_read = vamana_metadata_size; size_t cc = 0; unsigned nodes_read = 0; while (bytes_read != expected_file_size) { unsigned k; in.read((char *) &k, sizeof(unsigned)); + if (k == 0) { diskann::cerr << "ERROR: Point found with no out-neighbors, point#" << nodes_read << std::endl; @@ -1378,6 +1387,65 @@ namespace diskann { } } + template + void Index::prune_all_nbrs(const Parameters ¶meters) { + const unsigned range = parameters.Get("R"); + const unsigned maxc = parameters.Get("C"); + const float alpha = parameters.Get("alpha"); + _filtered_index = true; + + diskann::Timer timer; +#pragma omp parallel for + for (_s64 node = 0; node < (_s64)(_max_points + _num_frozen_pts); node++) { + if ((size_t) node < _nd || (size_t) node == _max_points) { + if (_final_graph[node].size() > range) { + tsl::robin_set dummy_visited(0); + std::vector dummy_pool(0); + std::vector new_out_neighbors; + + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + + for (auto cur_nbr : _final_graph[node]) { + if (dummy_visited.find(cur_nbr) == dummy_visited.end() && + cur_nbr != node) { + float dist = + _distance->compare(_data + _aligned_dim * (size_t) node, + _data + _aligned_dim * (size_t) cur_nbr, + (unsigned) _aligned_dim); + dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); + dummy_visited.insert(cur_nbr); + } + } + + prune_neighbors((_u32) node, dummy_pool, range, maxc, alpha, new_out_neighbors, scratch); + _final_graph[node].clear(); + for (auto id : new_out_neighbors) + _final_graph[node].emplace_back(id); + } + } + } + + diskann::cout << "Prune time : " << timer.elapsed() / 1000 << "ms" + << std::endl; + size_t max = 0, min = 1 << 30, total = 0, cnt = 0; + for (size_t i = 0; i < (_nd + _num_frozen_pts); i++) { + std::vector pool = _final_graph[i]; + max = (std::max)(max, pool.size()); + min = (std::min)(min, pool.size()); + total += pool.size(); + if (pool.size() < 2) + cnt++; + } + if (min > max) + min = max; + if (_nd > 0) { + diskann::cout << "Index built with degree: max:" << max << " avg:" + << (float) total / (float) (_nd + _num_frozen_pts) + << " min:" << min << " count(deg<2):" << cnt << std::endl; + } + } + template void Index::set_start_point(T *data) { std::unique_lock ul(_update_lock); @@ -1648,13 +1716,13 @@ namespace diskann { } template - void Index::parse_label_file(const std::string &map_file) { + size_t Index::parse_label_file(const std::string &label_file) { // Format of Label txt file: filters with comma separators - - std::ifstream infile(map_file); + std::ifstream infile(label_file); + assert(infile.is_open()); std::string line, token; unsigned line_cnt = 0; - + while (std::getline(infile, line)) { line_cnt++; } @@ -1663,28 +1731,35 @@ namespace diskann { infile.clear(); infile.seekg(0, std::ios::beg); line_cnt = 0; - while (std::getline(infile, line)) { - std::istringstream iss(line); - std::vector