Skip to content
36 changes: 25 additions & 11 deletions rust/lance-index/src/scalar/inverted/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use arrow_schema::{DataType, Field, Schema, SchemaRef};
use bitpacking::{BitPacker, BitPacker4x};
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
use deepsize::DeepSizeOf;
use futures::{stream, Stream, StreamExt, TryStreamExt};
use futures::{Stream, StreamExt, TryStreamExt};
use lance_arrow::json::JSON_EXT_NAME;
use lance_arrow::{iter_str_array, ARROW_EXT_NAME_KEY};
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
Expand Down Expand Up @@ -465,19 +465,30 @@ impl InnerBuilder {
id,
self.with_position
);
let mut batches = stream::iter(posting_lists)
.map(|posting_list| {
let block_max_scores = docs.calculate_block_max_scores(
posting_list.doc_ids.iter(),
posting_list.frequencies.iter(),
);
spawn_cpu(move || posting_list.to_batch(block_max_scores))
})
.buffered(get_num_compute_intensive_cpus());
let schema = inverted_list_schema(self.with_position);
let docs_for_batches = docs.clone();
let schema_for_batches = schema.clone();

let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<RecordBatch>>(2);
let producer = spawn_cpu(move || {
for posting_list in posting_lists {
let batch =
posting_list.to_batch_with_docs(&docs_for_batches, schema_for_batches.clone());
let is_err = batch.is_err();
if tx.blocking_send(batch).is_err() {
break;
}
if is_err {
break;
}
}
Ok(())
});

let mut write_duration = std::time::Duration::ZERO;
let mut num_posting_lists = 0;
while let Some(batch) = batches.try_next().await? {
while let Some(batch) = rx.recv().await {
let batch = batch?;
num_posting_lists += 1;
let start = std::time::Instant::now();
writer.write_record_batch(batch).await?;
Expand All @@ -493,6 +504,9 @@ impl InnerBuilder {
}
}

// Errors from batch generation are sent through the channel and surfaced via `batch?`.
// Awaiting the producer here is just to propagate panics/cancellation.
producer.await?;
writer.finish().await?;
Ok(())
}
Expand Down
83 changes: 83 additions & 0 deletions rust/lance-index/src/scalar/inverted/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,89 @@ pub fn compress_posting_list<'a>(
Ok(builder.finish())
}

pub fn compress_posting_list_with_scores<'a, F>(
length: usize,
doc_ids: impl Iterator<Item = &'a u32>,
frequencies: impl Iterator<Item = &'a u32>,
mut score_for: F,
idf_scale: f32,
) -> Result<(arrow::array::LargeBinaryArray, f32)>
where
F: FnMut(u32, u32) -> f32,
{
// `length` comes from posting list size; zero would produce an invalid block
// (a max-score header with no doc/frequency data) and readers assume > 0 docs.
debug_assert!(length > 0);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to provide more information about this assert.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

if length < BLOCK_SIZE {
let mut builder = LargeBinaryBuilder::with_capacity(1, length * 4 * 2 + 1);
let mut max_score = f32::MIN;
let mut doc_id_buffer = Vec::with_capacity(length);
let mut freq_buffer = Vec::with_capacity(length);
for (doc_id, freq) in std::iter::zip(doc_ids, frequencies) {
let doc_id = *doc_id;
let freq = *freq;
doc_id_buffer.push(doc_id);
freq_buffer.push(freq);
let score = score_for(doc_id, freq);
if score > max_score {
max_score = score;
}
}
let max_score = max_score * idf_scale;
let _ = builder.write(max_score.to_le_bytes().as_ref())?;
compress_remainder(&doc_id_buffer, &mut builder)?;
compress_remainder(&freq_buffer, &mut builder)?;
builder.append_value("");
return Ok((builder.finish(), max_score));
}

let mut builder = LargeBinaryBuilder::with_capacity(length.div_ceil(BLOCK_SIZE), length * 3);
let mut buffer = [0u8; BLOCK_SIZE * 4 + 5];
let mut doc_id_buffer = Vec::with_capacity(BLOCK_SIZE);
let mut freq_buffer = Vec::with_capacity(BLOCK_SIZE);
let mut max_score = f32::MIN;
let mut block_max_score = f32::MIN;
for (doc_id, freq) in std::iter::zip(doc_ids, frequencies) {
let doc_id = *doc_id;
let freq = *freq;
doc_id_buffer.push(doc_id);
freq_buffer.push(freq);

let score = score_for(doc_id, freq);
if score > block_max_score {
block_max_score = score;
}

if doc_id_buffer.len() < BLOCK_SIZE {
continue;
}

let block_score = block_max_score * idf_scale;
if block_score > max_score {
max_score = block_score;
}
let _ = builder.write(block_score.to_le_bytes().as_ref())?;
compress_sorted_block(&doc_id_buffer, &mut buffer, &mut builder)?;
compress_block(&freq_buffer, &mut buffer, &mut builder)?;
builder.append_value("");
doc_id_buffer.clear();
freq_buffer.clear();
block_max_score = f32::MIN;
}

if !doc_id_buffer.is_empty() {
let block_score = block_max_score * idf_scale;
if block_score > max_score {
max_score = block_score;
}
let _ = builder.write(block_score.to_le_bytes().as_ref())?;
compress_remainder(&doc_id_buffer, &mut builder)?;
compress_remainder(&freq_buffer, &mut builder)?;
builder.append_value("");
}
Ok((builder.finish(), max_score))
}

#[inline]
fn compress_sorted_block(
data: &[u32],
Expand Down
104 changes: 90 additions & 14 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use super::{
};
use super::{
builder::{InnerBuilder, PositionRecorder},
encoding::compress_posting_list,
encoding::{compress_posting_list, compress_posting_list_with_scores},
iter::CompressedPostingListIterator,
};
use super::{encoding::compress_positions, iter::PostingListIterator};
Expand Down Expand Up @@ -1912,18 +1912,13 @@ impl PostingListBuilder {
}
}

// assume the posting list is sorted by doc id
pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
fn build_batch(
self,
compressed: LargeBinaryArray,
max_score: f32,
schema: SchemaRef,
) -> Result<RecordBatch> {
let length = self.len();
let max_score = block_max_scores.iter().copied().fold(f32::MIN, f32::max);

let schema = inverted_list_schema(self.has_positions());
let compressed = compress_posting_list(
self.doc_ids.len(),
self.doc_ids.iter(),
self.frequencies.iter(),
block_max_scores.into_iter(),
)?;
let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32]));
let mut columns = vec![
Arc::new(ListArray::try_new(
Expand All @@ -1934,7 +1929,7 @@ impl PostingListBuilder {
)?) as ArrayRef,
Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef,
Arc::new(UInt32Array::from_iter_values(std::iter::once(
self.len() as u32
length as u32,
))) as ArrayRef,
];

Expand All @@ -1958,6 +1953,37 @@ impl PostingListBuilder {
Ok(batch)
}

// assume the posting list is sorted by doc id
pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> {
    let schema = inverted_list_schema(self.has_positions());
    // The overall max score is the max over the per-block maxima.
    let max_score = block_max_scores
        .iter()
        .fold(f32::MIN, |acc, &score| f32::max(acc, score));
    let num_docs = self.doc_ids.len();
    let compressed = compress_posting_list(
        num_docs,
        self.doc_ids.iter(),
        self.frequencies.iter(),
        block_max_scores.into_iter(),
    )?;
    self.build_batch(compressed, max_score, schema)
}

/// Builds the posting-list record batch, deriving block max scores on the fly
/// from `docs` rather than taking a precomputed vector (cf. `to_batch`).
pub fn to_batch_with_docs(self, docs: &DocSet, schema: SchemaRef) -> Result<RecordBatch> {
    let num_postings = self.len();
    let avg_doc_len = docs.average_length();
    // IDF and the (K1 + 1) factor are constant for the whole posting list,
    // so they are folded into a single scale applied to each block max.
    let scale = idf(num_postings, docs.len()) * (K1 + 1.0);
    // BM25 partial score for one posting, before the idf/(K1+1) scaling.
    let partial_score = |doc_id: u32, freq: u32| -> f32 {
        let doc_norm = K1 * (1.0 - B + B * docs.num_tokens(doc_id) as f32 / avg_doc_len);
        let tf = freq as f32;
        tf / (tf + doc_norm)
    };
    let (compressed, max_score) = compress_posting_list_with_scores(
        num_postings,
        self.doc_ids.iter(),
        self.frequencies.iter(),
        partial_score,
        scale,
    )?;
    self.build_batch(compressed, max_score, schema)
}

pub fn remap(&mut self, removed: &[u32]) {
let mut cursor = 0;
let mut new_doc_ids = ExpLinkedList::with_capacity(self.len());
Expand Down Expand Up @@ -2562,10 +2588,12 @@ mod tests {

use crate::metrics::NoOpMetricsCollector;
use crate::prefilter::NoFilter;
use crate::scalar::inverted::builder::{InnerBuilder, PositionRecorder};
use crate::scalar::inverted::builder::{inverted_list_schema, InnerBuilder, PositionRecorder};
use crate::scalar::inverted::encoding::decompress_posting_list;
use crate::scalar::inverted::query::{FtsSearchParams, Operator};
use crate::scalar::lance_format::LanceIndexStore;
use arrow::array::AsArray;
use arrow::datatypes::{Float32Type, UInt32Type};

use super::*;

Expand Down Expand Up @@ -2607,6 +2635,54 @@ mod tests {
.all(|(a, b)| a == b));
}

#[test]
fn test_posting_list_batch_matches_docset_scoring() {
    // One full block plus a small remainder, to exercise both encoding paths.
    let total = BLOCK_SIZE + 3;

    let mut docs = DocSet::default();
    for id in 0..total as u32 {
        docs.append(id as u64, id % 7 + 1);
    }

    let doc_ids: Vec<u32> = (0..total as u32).collect();
    let freqs: Vec<u32> = doc_ids.iter().map(|id| id % 5 + 1).collect();

    // Two identical posting lists: one encoded from precomputed block max
    // scores, the other scored directly against the DocSet.
    let mut via_precomputed = PostingListBuilder::new(false);
    let mut via_docset = PostingListBuilder::new(false);
    for (&id, &freq) in doc_ids.iter().zip(freqs.iter()) {
        via_precomputed.add(id, PositionRecorder::Count(freq));
        via_docset.add(id, PositionRecorder::Count(freq));
    }

    let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter());
    let batch_precomputed = via_precomputed.to_batch(block_max_scores).unwrap();
    let batch_docset = via_docset
        .to_batch_with_docs(&docs, inverted_list_schema(false))
        .unwrap();

    // The compressed posting payloads must be byte-identical.
    let left = batch_precomputed[POSTING_COL].as_list::<i32>().value(0);
    let right = batch_docset[POSTING_COL].as_list::<i32>().value(0);
    assert_eq!(left.as_binary::<i64>(), right.as_binary::<i64>());

    // Max scores must agree up to float rounding.
    let max_left = batch_precomputed[MAX_SCORE_COL]
        .as_primitive::<Float32Type>()
        .value(0);
    let max_right = batch_docset[MAX_SCORE_COL]
        .as_primitive::<Float32Type>()
        .value(0);
    assert!((max_left - max_right).abs() < 1e-6);

    // And so must the recorded posting-list lengths.
    let len_left = batch_precomputed[LENGTH_COL]
        .as_primitive::<UInt32Type>()
        .value(0);
    let len_right = batch_docset[LENGTH_COL]
        .as_primitive::<UInt32Type>()
        .value(0);
    assert_eq!(len_left, len_right);
}

#[tokio::test]
async fn test_remap_to_empty_posting_list() {
let tmpdir = TempObjDir::default();
Expand Down
Loading