From a5429303873bae893deed68aa33ab79f3e184585 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 1 Dec 2025 20:32:02 +0800 Subject: [PATCH] Optimize HNSW search and benchmarks --- rust/lance-index/benches/hnsw.rs | 107 +++++++++----------- rust/lance-index/src/vector/graph.rs | 19 ++-- rust/lance-index/src/vector/hnsw/builder.rs | 72 ++++++++----- 3 files changed, 105 insertions(+), 93 deletions(-) diff --git a/rust/lance-index/benches/hnsw.rs b/rust/lance-index/benches/hnsw.rs index 12848dd6a06..967b2e67b67 100644 --- a/rust/lance-index/benches/hnsw.rs +++ b/rust/lance-index/benches/hnsw.rs @@ -22,10 +22,10 @@ use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_seed; fn bench_hnsw(c: &mut Criterion) { - const DIMENSION: usize = 512; - const TOTAL: usize = 10 * 1024; + const DIMENSION: usize = 128; + const TOTAL: usize = 100_000; const SEED: [u8; 32] = [42; 32]; - const K: usize = 10; + const K: usize = 100; let rt = tokio::runtime::Runtime::new().unwrap(); @@ -34,64 +34,55 @@ fn bench_hnsw(c: &mut Criterion) { let vectors = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2)); let query = fsl.value(0); - c.bench_function( - format!("create_hnsw({TOTAL}x{DIMENSION},levels=6)").as_str(), - |b| { - b.to_async(&rt).iter(|| async { - let hnsw = - HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default().max_level(6)) - .unwrap(); - let uids: HashSet = hnsw - .search_basic( - query.clone(), - K, - &HnswQueryParams { - ef: 300, - lower_bound: None, - upper_bound: None, - dist_q_c: 0.0, - }, - None, - vectors.as_ref(), - ) - .unwrap() - .iter() - .map(|node| node.id) - .collect(); + c.bench_function(format!("create_hnsw({TOTAL}x{DIMENSION})").as_str(), |b| { + b.to_async(&rt).iter(|| async { + let hnsw = HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default()).unwrap(); + let uids: HashSet = hnsw + .search_basic( + query.clone(), + K, + &HnswQueryParams { + ef: 300, + lower_bound: None, + upper_bound: None, + dist_q_c: 0.0, + }, + None, + vectors.as_ref(), + ) + .unwrap() + .iter() + .map(|node| node.id) + .collect(); - assert_eq!(uids.len(), K); - }) - }, - ); + assert_eq!(uids.len(), K); + }) + }); - let hnsw = - HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default().max_level(6)).unwrap(); - c.bench_function( - format!("search_hnsw{TOTAL}x{DIMENSION}, levels=6").as_str(), - |b| { - b.to_async(&rt).iter(|| async { - let uids: HashSet = hnsw - .search_basic( - query.clone(), - K, - &HnswQueryParams { - ef: 300, - lower_bound: None, - upper_bound: None, - dist_q_c: 0.0, - }, - None, - vectors.as_ref(), - ) - .unwrap() - .iter() - .map(|node| node.id) - .collect(); + let hnsw = HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default()).unwrap(); + c.bench_function(format!("search_hnsw{TOTAL}x{DIMENSION}").as_str(), |b| { + b.to_async(&rt).iter(|| async { + let uids: HashSet = hnsw + .search_basic( + query.clone(), + K, + &HnswQueryParams { + ef: 300, + lower_bound: None, + upper_bound: None, + dist_q_c: 0.0, + }, + None, + vectors.as_ref(), + ) + .unwrap() + .iter() + .map(|node| node.id) + .collect(); - assert_eq!(uids.len(), K); - }) - }, - ); + assert_eq!(uids.len(), K); + }) + }); } #[cfg(target_os = "linux")] diff --git a/rust/lance-index/src/vector/graph.rs b/rust/lance-index/src/vector/graph.rs index 4daca614ef2..2dac8a23f7e 100644 --- a/rust/lance-index/src/vector/graph.rs +++ b/rust/lance-index/src/vector/graph.rs @@ -181,6 +181,11 @@ impl Visited<'_> { self.visited[node_id_usize] } + #[inline(always)] + pub fn iter_ones(&self) -> impl Iterator + '_ { + self.visited.iter_ones() + } + pub fn count_ones(&self) -> usize { self.visited.count_ones() } @@ -310,20 +315,15 @@ pub fn beam_search( if current.dist > furthest && results.len() == k { break; } - let neighbors = graph.neighbors(current.id); - let furthest = results .peek() .map(|node| node.dist) .unwrap_or(OrderedFloat(f32::INFINITY)); - let unvisited_neighbors: Vec<_> = neighbors - .iter() - .filter(|&&neighbor| !visited.contains(neighbor)) - .copied() - .collect(); - let process_neighbor = |neighbor: u32| { + if visited.contains(neighbor) { + return; + } visited.insert(neighbor); let dist: OrderedFloat = dist_calc.distance(neighbor).into(); if dist <= furthest || results.len() < k { @@ -343,8 +343,9 @@ pub fn beam_search( candidates.push(Reverse((dist, neighbor).into())); } }; + let neighbors = graph.neighbors(current.id); process_neighbors_with_look_ahead( - &unvisited_neighbors, + &neighbors, process_neighbor, prefetch_distance, dist_calc, diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index fd444fb053b..66e1bee758f 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -16,7 +16,7 @@ use lance_linalg::distance::DistanceType; use rayon::prelude::*; use snafu::location; use std::cmp::min; -use std::collections::{BinaryHeap, HashMap}; +use std::collections::{BinaryHeap, HashMap, VecDeque}; use std::fmt::Debug; use std::iter; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -243,39 +243,59 @@ impl HNSW { prefilter_bitset: Visited, params: &HnswQueryParams, ) -> Vec { - let node_ids = storage - .row_ids() - .enumerate() - .filter_map(|(node_id, _)| { - prefilter_bitset - .contains(node_id as u32) - .then_some(node_id as u32) - }) - .collect_vec(); - let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into(); let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into(); let dist_calc = storage.dist_calculator(query, params.dist_q_c); let mut heap = BinaryHeap::::with_capacity(k); - for i in 0..node_ids.len() { - if let Some(ahead) = self.inner.params.prefetch_distance { - if i + ahead < node_ids.len() { - dist_calc.prefetch(node_ids[i + ahead]); + + match self.inner.params.prefetch_distance { + Some(ahead) if ahead > 0 => { + let mut ids_iter = prefilter_bitset.iter_ones().map(|i| i as u32); + let mut buffer = VecDeque::with_capacity(ahead + 1); + for _ in 0..=ahead { + if let Some(id) = ids_iter.next() { + buffer.push_back(id); + } else { + break; + } + } + + while let Some(node_id) = buffer.pop_front() { + if let Some(&prefetch_id) = buffer.get(ahead - 1) { + dist_calc.prefetch(prefetch_id); + } + if let Some(next) = ids_iter.next() { + buffer.push_back(next); + } + + let dist: OrderedFloat = dist_calc.distance(node_id).into(); + if dist <= lower_bound || dist > upper_bound { + continue; + } + if heap.len() < k { + heap.push((dist, node_id).into()); + } else if dist < heap.peek().unwrap().dist { + heap.pop(); + heap.push((dist, node_id).into()); + } } } - let node_id = node_ids[i]; - let dist: OrderedFloat = dist_calc.distance(node_id).into(); - if dist <= lower_bound || dist > upper_bound { - continue; - } - if heap.len() < k { - heap.push((dist, node_id).into()); - } else if dist < heap.peek().unwrap().dist { - heap.pop(); - heap.push((dist, node_id).into()); + _ => { + for node_id in prefilter_bitset.iter_ones().map(|i| i as u32) { + let dist: OrderedFloat = dist_calc.distance(node_id).into(); + if dist <= lower_bound || dist > upper_bound { + continue; + } + if heap.len() < k { + heap.push((dist, node_id).into()); + } else if dist < heap.peek().unwrap().dist { + heap.pop(); + heap.push((dist, node_id).into()); + } + } } - } + }; heap.into_sorted_vec() }