Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 47 additions & 23 deletions rust/lance-index/src/vector/hnsw/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use std::cmp::min;
use std::collections::{BinaryHeap, HashMap, VecDeque};
use std::fmt::Debug;
use std::iter;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::RwLock;
use tracing::instrument;
Expand Down Expand Up @@ -307,10 +306,10 @@ impl HNSW {
.inner
.level_count
.iter()
.chain(iter::once(&AtomicUsize::new(0)))
.scan(0, |state, x| {
.chain(iter::once(&0usize))
.scan(0usize, |state, &count| {
let start = *state;
*state += x.load(Ordering::Relaxed);
*state += count;
Some(start)
})
.collect();
Expand All @@ -327,7 +326,7 @@ struct HnswBuilder {
params: HnswBuildParams,

nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
level_count: Vec<AtomicUsize>,
level_count: Vec<usize>,

entry_point: u32,

Expand All @@ -349,7 +348,7 @@ impl HnswBuilder {
}

fn num_nodes(&self, level: usize) -> usize {
self.level_count[level].load(Ordering::Relaxed)
self.level_count[level]
}

fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
Expand All @@ -361,9 +360,7 @@ impl HnswBuilder {
let len = storage.len();
let max_level = params.max_level;

let level_count = (0..max_level)
.map(|_| AtomicUsize::new(0))
.collect::<Vec<_>>();
let level_count = vec![0usize; max_level as usize];

let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
for _ in 0..get_num_compute_intensive_cpus() {
Expand Down Expand Up @@ -445,8 +442,6 @@ impl HnswBuilder {
{
let mut current_node = nodes[node as usize].write().unwrap();
for level in (0..=target_level).rev() {
self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);

let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
for neighbor in &neighbors {
current_node.add_neighbor(neighbor.id, neighbor.dist, level);
Expand Down Expand Up @@ -525,6 +520,17 @@ impl HnswBuilder {
*neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
builder_node.update_from_ranked_neighbors(level);
}

/// Recounts, from the graph itself, how many nodes are present on each level.
///
/// The returned vector has `self.max_level()` entries. A node whose
/// `level_neighbors` holds `k` per-level lists contributes one to each of
/// the first `k` entries; anything beyond the vector's length is ignored.
///
/// # Panics
///
/// Panics if a node's lock is poisoned.
fn compute_level_count(&self) -> Vec<usize> {
    let num_levels = self.max_level() as usize;
    self.nodes
        .iter()
        .fold(vec![0usize; num_levels], |mut totals, entry| {
            // The read guard is a temporary: it is released as soon as the
            // length has been copied out.
            let depth = entry.read().unwrap().level_neighbors.len();
            totals.iter_mut().take(depth).for_each(|slot| *slot += 1);
            totals
        })
}
}

// View of a level in HNSW graph.
Expand Down Expand Up @@ -666,7 +672,7 @@ impl IvfSubIndex for HNSW {
let inner = HnswBuilder {
params: hnsw_metadata.params,
nodes: Arc::new(nodes.into_iter().map(RwLock::new).collect()),
level_count: level_count.into_iter().map(AtomicUsize::new).collect(),
level_count,
entry_point: hnsw_metadata.entry_point,
visited_generator_queue,
};
Expand Down Expand Up @@ -763,34 +769,37 @@ impl IvfSubIndex for HNSW {
where
Self: Sized,
{
let inner = HnswBuilder::with_params(params, storage);
let hnsw = Self {
inner: Arc::new(inner),
};
let mut inner = HnswBuilder::with_params(params, storage);

log::debug!(
"Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
storage.len(),
hnsw.inner.params.max_level,
hnsw.inner.params.m,
hnsw.inner.params.ef_construction,
inner.params.max_level,
inner.params.m,
inner.params.ef_construction,
storage.distance_type(),
);

if storage.is_empty() {
return Ok(hnsw);
return Ok(Self {
inner: Arc::new(inner),
});
}

let len = storage.len();
hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed);
(1..len).into_par_iter().for_each_init(
|| VisitedGenerator::new(len),
|visited_generator, node| {
hnsw.inner.insert(node as u32, visited_generator, storage);
inner.insert(node as u32, visited_generator, storage);
},
);
inner.level_count = inner.compute_level_count();

assert_eq!(hnsw.inner.level_count[0].load(Ordering::Relaxed), len);
let hnsw = Self {
inner: Arc::new(inner),
};

assert_eq!(hnsw.inner.level_count[0], len);
Ok(hnsw)
}

Expand Down Expand Up @@ -945,4 +954,19 @@ mod tests {
.unwrap();
assert_eq!(builder_results, loaded_results);
}

#[test]
fn test_level_offsets_match_batch_rows() {
    // Index 512 random 16-dimensional vectors with the default build
    // parameters, then compare the metadata offsets with the exported batch.
    const VECTOR_DIM: usize = 16;
    const NUM_VECTORS: usize = 512;

    let values = generate_random_array(NUM_VECTORS * VECTOR_DIM);
    let vectors = FixedSizeListArray::try_new_from_values(values, VECTOR_DIM as i32).unwrap();
    let storage = FlatFloatStorage::new(vectors, DistanceType::L2);
    let index = HNSW::index_vectors(&storage, HnswBuildParams::default()).unwrap();

    let metadata = index.metadata();
    let batch = index.to_batch().unwrap();

    // One offset per level, plus a trailing end marker.
    let expected_offsets = index.max_level() as usize + 1;
    assert_eq!(metadata.level_offsets.len(), expected_offsets);

    // The trailing offset must equal the number of rows in the batch.
    let final_offset = *metadata.level_offsets.last().unwrap();
    assert_eq!(final_offset, batch.num_rows());
}
}
Loading