Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 67 additions & 17 deletions rust/lance-index/src/scalar/inverted/wand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use super::{
encoding::{decompress_positions, decompress_posting_block, decompress_posting_remainder},
query::FtsSearchParams,
scorer::Scorer,
DocSet, PostingList, RawDocInfo,
CompressedPostingList, DocSet, PostingList, RawDocInfo,
};
use super::{builder::BLOCK_SIZE, DocInfo};
use super::{
Expand Down Expand Up @@ -140,6 +140,28 @@ impl Ord for PostingIterator {
}

impl PostingIterator {
/// Returns a raw mutable pointer to the decompression state stored in `self.compressed`.
///
/// # Panics
///
/// Panics if `self.compressed` is `None`; debug builds assert this first.
#[inline]
fn compressed_state_ptr(&self) -> *mut CompressedState {
    debug_assert!(self.compressed.is_some());
    // Hot path: the state lives in an `UnsafeCell` rather than a `RefCell` so that
    // no runtime borrow check is paid on every call.
    let cell = self.compressed.as_ref().unwrap();
    cell.get()
}

/// Makes sure the cached decompression state holds block `block_idx` of `list` and
/// returns a raw pointer to that state.
///
/// The block is decompressed only when the cache refers to a different block or its
/// `doc_ids` are empty; otherwise the cached contents are reused as-is.
#[inline]
fn ensure_compressed_block_ptr(
    &self,
    list: &CompressedPostingList,
    block_idx: usize,
) -> *mut CompressedState {
    // NOTE(review): assumes no other reference into the cached state is alive while
    // this exclusive reference is in use — confirm against the callers.
    let state = unsafe { &mut *self.compressed_state_ptr() };
    let already_loaded = state.block_idx == block_idx && !state.doc_ids.is_empty();
    if !already_loaded {
        let raw_block = list.blocks.value(block_idx);
        state.decompress(raw_block, block_idx, list.blocks.len(), list.length);
    }
    state as *mut CompressedState
}

pub(crate) fn new(
token: String,
token_id: u32,
Expand Down Expand Up @@ -194,19 +216,9 @@ impl PostingIterator {

match self.list {
PostingList::Compressed(ref list) => {
debug_assert!(self.compressed.is_some());
// this method is called very frequently, so we prefer to use `UnsafeCell` instead of `RefCell`
// to avoid the overhead of runtime borrow checking
let compressed = unsafe {
let compressed = self.compressed.as_ref().unwrap();
&mut *compressed.get()
};
let block_idx = self.index / BLOCK_SIZE;
let block_offset = self.index % BLOCK_SIZE;
if compressed.block_idx != block_idx || compressed.doc_ids.is_empty() {
let block = list.blocks.value(block_idx);
compressed.decompress(block, block_idx, list.blocks.len(), list.length);
}
let compressed = unsafe { &mut *self.ensure_compressed_block_ptr(list, block_idx) };

// Read from the decompressed block
let doc_id = compressed.doc_ids[block_offset];
Expand All @@ -232,7 +244,7 @@ impl PostingIterator {
// move to the next doc id that is greater than or equal to least_id
fn next(&mut self, least_id: u64) {
match self.list {
PostingList::Compressed(ref mut list) => {
PostingList::Compressed(ref list) => {
debug_assert!(least_id <= u32::MAX as u64);
let least_id = least_id as u32;
let mut block_idx = self.index / BLOCK_SIZE;
Expand All @@ -242,9 +254,24 @@ impl PostingIterator {
block_idx += 1;
}
self.index = self.index.max(block_idx * BLOCK_SIZE);
let length = self.list.len();
while self.index < length && (self.doc().unwrap().doc_id() as u32) < least_id {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this PR faster? Is `self.doc()` a heavy operation? Would it be better to provide an API such as `compressed_doc(doc_index)` instead of maintaining complex logic inside `next`?

Copy link
Copy Markdown
Contributor Author

@BubbleCal BubbleCal Jan 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is to use binary search to avoid scanning the entire block.

`doc()` is the most costly operation during FTS search because it decompresses the containing block whenever that block is not already cached.

self.index += 1;
let length = list.length as usize;
while self.index < length {
let block_idx = self.index / BLOCK_SIZE;
let block_offset = self.index % BLOCK_SIZE;
let compressed =
unsafe { &mut *self.ensure_compressed_block_ptr(list, block_idx) };
let in_block = &compressed.doc_ids[block_offset..];
let offset_in_block = in_block.partition_point(|&doc_id| doc_id < least_id);
let new_offset = block_offset + offset_in_block;
if new_offset < compressed.doc_ids.len() {
self.index = block_idx * BLOCK_SIZE + new_offset;
break;
}
if block_idx + 1 >= list.blocks.len() {
self.index = length;
break;
}
self.index = (block_idx + 1) * BLOCK_SIZE;
}
self.block_idx = self.index / BLOCK_SIZE;
}
Expand All @@ -256,7 +283,7 @@ impl PostingIterator {

fn shallow_next(&mut self, least_id: u64) {
match self.list {
PostingList::Compressed(ref mut list) => {
PostingList::Compressed(ref list) => {
debug_assert!(least_id <= u32::MAX as u64);
let least_id = least_id as u32;
while self.block_idx + 1 < list.blocks.len()
Expand Down Expand Up @@ -952,6 +979,29 @@ mod tests {
assert_eq!(result.len(), 0); // Should not panic
}

/// Exercises `PostingIterator::next` on a compressed posting list: a seek inside the
/// first block, a seek into the following block, and a seek past every doc id.
#[test]
fn test_posting_iterator_next_compressed_partition_point() {
    // Consecutive doc ids: two blocks' worth of entries plus five extra.
    let total_docs = (BLOCK_SIZE * 2 + 5) as u32;

    let mut docs = DocSet::default();
    for id in 0..total_docs {
        docs.append(id as u64, 1);
    }

    let posting = generate_posting_list((0..total_docs).collect::<Vec<_>>(), 1.0, None, true);
    let mut iter = PostingIterator::new(String::from("term"), 0, 0, posting, docs.len());

    // Each target exists in the list, so the iterator must land exactly on it.
    for target in [10u64, BLOCK_SIZE as u64 + 3] {
        iter.next(target);
        assert_eq!(iter.doc().unwrap().doc_id(), target);
    }

    // A target beyond the largest doc id exhausts the iterator.
    iter.next(total_docs as u64 + 10);
    assert!(iter.doc().is_none());
}

#[test]
fn test_wand_skip_to_next_block() {
let mut docs = DocSet::default();
Expand Down
Loading