Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 49 additions & 58 deletions rust/lance-index/benches/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ use lance_linalg::distance::DistanceType;
use lance_testing::datagen::generate_random_array_with_seed;

fn bench_hnsw(c: &mut Criterion) {
const DIMENSION: usize = 512;
const TOTAL: usize = 10 * 1024;
const DIMENSION: usize = 128;
const TOTAL: usize = 100_000;
const SEED: [u8; 32] = [42; 32];
const K: usize = 10;
const K: usize = 100;

let rt = tokio::runtime::Runtime::new().unwrap();

Expand All @@ -34,64 +34,55 @@ fn bench_hnsw(c: &mut Criterion) {
let vectors = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2));

let query = fsl.value(0);
c.bench_function(
format!("create_hnsw({TOTAL}x{DIMENSION},levels=6)").as_str(),
|b| {
b.to_async(&rt).iter(|| async {
let hnsw =
HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default().max_level(6))
.unwrap();
let uids: HashSet<u32> = hnsw
.search_basic(
query.clone(),
K,
&HnswQueryParams {
ef: 300,
lower_bound: None,
upper_bound: None,
dist_q_c: 0.0,
},
None,
vectors.as_ref(),
)
.unwrap()
.iter()
.map(|node| node.id)
.collect();
c.bench_function(format!("create_hnsw({TOTAL}x{DIMENSION})").as_str(), |b| {
b.to_async(&rt).iter(|| async {
let hnsw = HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default()).unwrap();
let uids: HashSet<u32> = hnsw
.search_basic(
query.clone(),
K,
&HnswQueryParams {
ef: 300,
lower_bound: None,
upper_bound: None,
dist_q_c: 0.0,
},
None,
vectors.as_ref(),
)
.unwrap()
.iter()
.map(|node| node.id)
.collect();

assert_eq!(uids.len(), K);
})
},
);
assert_eq!(uids.len(), K);
})
});

let hnsw =
HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default().max_level(6)).unwrap();
c.bench_function(
format!("search_hnsw{TOTAL}x{DIMENSION}, levels=6").as_str(),
|b| {
b.to_async(&rt).iter(|| async {
let uids: HashSet<u32> = hnsw
.search_basic(
query.clone(),
K,
&HnswQueryParams {
ef: 300,
lower_bound: None,
upper_bound: None,
dist_q_c: 0.0,
},
None,
vectors.as_ref(),
)
.unwrap()
.iter()
.map(|node| node.id)
.collect();
let hnsw = HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default()).unwrap();
c.bench_function(format!("search_hnsw{TOTAL}x{DIMENSION}").as_str(), |b| {
b.to_async(&rt).iter(|| async {
let uids: HashSet<u32> = hnsw
.search_basic(
query.clone(),
K,
&HnswQueryParams {
ef: 300,
lower_bound: None,
upper_bound: None,
dist_q_c: 0.0,
},
None,
vectors.as_ref(),
)
.unwrap()
.iter()
.map(|node| node.id)
.collect();

assert_eq!(uids.len(), K);
})
},
);
assert_eq!(uids.len(), K);
})
});
}

#[cfg(target_os = "linux")]
Expand Down
19 changes: 10 additions & 9 deletions rust/lance-index/src/vector/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ impl Visited<'_> {
self.visited[node_id_usize]
}

#[inline(always)]
pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
self.visited.iter_ones()
}

pub fn count_ones(&self) -> usize {
self.visited.count_ones()
}
Expand Down Expand Up @@ -310,20 +315,15 @@ pub fn beam_search(
if current.dist > furthest && results.len() == k {
break;
}
let neighbors = graph.neighbors(current.id);

let furthest = results
.peek()
.map(|node| node.dist)
.unwrap_or(OrderedFloat(f32::INFINITY));

let unvisited_neighbors: Vec<_> = neighbors
.iter()
.filter(|&&neighbor| !visited.contains(neighbor))
.copied()
.collect();

let process_neighbor = |neighbor: u32| {
if visited.contains(neighbor) {
return;
}
visited.insert(neighbor);
let dist: OrderedFloat = dist_calc.distance(neighbor).into();
if dist <= furthest || results.len() < k {
Expand All @@ -343,8 +343,9 @@ pub fn beam_search(
candidates.push(Reverse((dist, neighbor).into()));
}
};
let neighbors = graph.neighbors(current.id);
process_neighbors_with_look_ahead(
&unvisited_neighbors,
&neighbors,
process_neighbor,
prefetch_distance,
dist_calc,
Expand Down
72 changes: 46 additions & 26 deletions rust/lance-index/src/vector/hnsw/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use lance_linalg::distance::DistanceType;
use rayon::prelude::*;
use snafu::location;
use std::cmp::min;
use std::collections::{BinaryHeap, HashMap};
use std::collections::{BinaryHeap, HashMap, VecDeque};
use std::fmt::Debug;
use std::iter;
use std::sync::atomic::{AtomicUsize, Ordering};
Expand Down Expand Up @@ -243,39 +243,59 @@ impl HNSW {
prefilter_bitset: Visited,
params: &HnswQueryParams,
) -> Vec<OrderedNode> {
let node_ids = storage
.row_ids()
.enumerate()
.filter_map(|(node_id, _)| {
prefilter_bitset
.contains(node_id as u32)
.then_some(node_id as u32)
})
.collect_vec();

let lower_bound: OrderedFloat = params.lower_bound.unwrap_or(f32::MIN).into();
let upper_bound: OrderedFloat = params.upper_bound.unwrap_or(f32::MAX).into();

let dist_calc = storage.dist_calculator(query, params.dist_q_c);
let mut heap = BinaryHeap::<OrderedNode>::with_capacity(k);
for i in 0..node_ids.len() {
if let Some(ahead) = self.inner.params.prefetch_distance {
if i + ahead < node_ids.len() {
dist_calc.prefetch(node_ids[i + ahead]);

match self.inner.params.prefetch_distance {
Some(ahead) if ahead > 0 => {
let mut ids_iter = prefilter_bitset.iter_ones().map(|i| i as u32);
let mut buffer = VecDeque::with_capacity(ahead + 1);
for _ in 0..=ahead {
if let Some(id) = ids_iter.next() {
buffer.push_back(id);
} else {
break;
}
}

while let Some(node_id) = buffer.pop_front() {
if let Some(&prefetch_id) = buffer.get(ahead - 1) {
dist_calc.prefetch(prefetch_id);
}
if let Some(next) = ids_iter.next() {
buffer.push_back(next);
}

let dist: OrderedFloat = dist_calc.distance(node_id).into();
if dist <= lower_bound || dist > upper_bound {
continue;
}
if heap.len() < k {
heap.push((dist, node_id).into());
} else if dist < heap.peek().unwrap().dist {
heap.pop();
heap.push((dist, node_id).into());
}
}
}
let node_id = node_ids[i];
let dist: OrderedFloat = dist_calc.distance(node_id).into();
if dist <= lower_bound || dist > upper_bound {
continue;
}
if heap.len() < k {
heap.push((dist, node_id).into());
} else if dist < heap.peek().unwrap().dist {
heap.pop();
heap.push((dist, node_id).into());
_ => {
for node_id in prefilter_bitset.iter_ones().map(|i| i as u32) {
let dist: OrderedFloat = dist_calc.distance(node_id).into();
if dist <= lower_bound || dist > upper_bound {
continue;
}
if heap.len() < k {
heap.push((dist, node_id).into());
} else if dist < heap.peek().unwrap().dist {
heap.pop();
heap.push((dist, node_id).into());
}
}
}
}
};
heap.into_sorted_vec()
}

Expand Down
Loading