diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index 66e1bee758f..c7648fa746f 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -19,7 +19,6 @@ use std::cmp::min; use std::collections::{BinaryHeap, HashMap, VecDeque}; use std::fmt::Debug; use std::iter; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::sync::RwLock; use tracing::instrument; @@ -307,10 +306,10 @@ impl HNSW { .inner .level_count .iter() - .chain(iter::once(&AtomicUsize::new(0))) - .scan(0, |state, x| { + .chain(iter::once(&0usize)) + .scan(0usize, |state, &count| { let start = *state; - *state += x.load(Ordering::Relaxed); + *state += count; Some(start) }) .collect(); @@ -327,7 +326,7 @@ struct HnswBuilder { params: HnswBuildParams, nodes: Arc>>, - level_count: Vec, + level_count: Vec, entry_point: u32, @@ -349,7 +348,7 @@ impl HnswBuilder { } fn num_nodes(&self, level: usize) -> usize { - self.level_count[level].load(Ordering::Relaxed) + self.level_count[level] } fn nodes(&self) -> Arc>> { @@ -361,9 +360,7 @@ impl HnswBuilder { let len = storage.len(); let max_level = params.max_level; - let level_count = (0..max_level) - .map(|_| AtomicUsize::new(0)) - .collect::>(); + let level_count = vec![0usize; max_level as usize]; let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus())); for _ in 0..get_num_compute_intensive_cpus() { @@ -445,8 +442,6 @@ impl HnswBuilder { { let mut current_node = nodes[node as usize].write().unwrap(); for level in (0..=target_level).rev() { - self.level_count[level as usize].fetch_add(1, Ordering::Relaxed); - let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator); for neighbor in &neighbors { current_node.add_neighbor(neighbor.id, neighbor.dist, level); @@ -525,6 +520,17 @@ impl HnswBuilder { *neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max); builder_node.update_from_ranked_neighbors(level); } + + fn compute_level_count(&self) -> Vec { + let mut level_count = vec![0usize; self.max_level() as usize]; + for node in self.nodes.iter() { + let levels = node.read().unwrap().level_neighbors.len(); + for count in level_count.iter_mut().take(levels) { + *count += 1; + } + } + level_count + } } // View of a level in HNSW graph. @@ -666,7 +672,7 @@ impl IvfSubIndex for HNSW { let inner = HnswBuilder { params: hnsw_metadata.params, nodes: Arc::new(nodes.into_iter().map(RwLock::new).collect()), - level_count: level_count.into_iter().map(AtomicUsize::new).collect(), + level_count, entry_point: hnsw_metadata.entry_point, visited_generator_queue, }; @@ -763,34 +769,37 @@ impl IvfSubIndex for HNSW { where Self: Sized, { - let inner = HnswBuilder::with_params(params, storage); - let hnsw = Self { - inner: Arc::new(inner), - }; + let mut inner = HnswBuilder::with_params(params, storage); log::debug!( "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}", storage.len(), - hnsw.inner.params.max_level, - hnsw.inner.params.m, - hnsw.inner.params.ef_construction, + inner.params.max_level, + inner.params.m, + inner.params.ef_construction, storage.distance_type(), ); if storage.is_empty() { - return Ok(hnsw); + return Ok(Self { + inner: Arc::new(inner), + }); } let len = storage.len(); - hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed); (1..len).into_par_iter().for_each_init( || VisitedGenerator::new(len), |visited_generator, node| { - hnsw.inner.insert(node as u32, visited_generator, storage); + inner.insert(node as u32, visited_generator, storage); }, ); + inner.level_count = inner.compute_level_count(); - assert_eq!(hnsw.inner.level_count[0].load(Ordering::Relaxed), len); + let hnsw = Self { + inner: Arc::new(inner), + }; + + assert_eq!(hnsw.inner.level_count[0], len); Ok(hnsw) } @@ -945,4 +954,19 @@ mod tests { .unwrap(); assert_eq!(builder_results, loaded_results); } + + #[test] + fn test_level_offsets_match_batch_rows() { + const DIM: usize = 16; + const TOTAL: usize = 512; + let data = generate_random_array(TOTAL * DIM); + let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap(); + let store = FlatFloatStorage::new(fsl, DistanceType::L2); + let hnsw = HNSW::index_vectors(&store, HnswBuildParams::default()).unwrap(); + let metadata = hnsw.metadata(); + let batch = hnsw.to_batch().unwrap(); + + assert_eq!(metadata.level_offsets.len(), hnsw.max_level() as usize + 1); + assert_eq!(*metadata.level_offsets.last().unwrap(), batch.num_rows()); + } }