diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 6e5155e1c20..07f16391a0a 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -19,7 +19,7 @@ use arrow_schema::{DataType, Field, Schema, SchemaRef}; use bitpacking::{BitPacker, BitPacker4x}; use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream}; use deepsize::DeepSizeOf; -use futures::{stream, Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt}; use lance_arrow::json::JSON_EXT_NAME; use lance_arrow::{iter_str_array, ARROW_EXT_NAME_KEY}; use lance_core::utils::tokio::get_num_compute_intensive_cpus; @@ -465,19 +465,30 @@ impl InnerBuilder { id, self.with_position ); - let mut batches = stream::iter(posting_lists) - .map(|posting_list| { - let block_max_scores = docs.calculate_block_max_scores( - posting_list.doc_ids.iter(), - posting_list.frequencies.iter(), - ); - spawn_cpu(move || posting_list.to_batch(block_max_scores)) - }) - .buffered(get_num_compute_intensive_cpus()); + let schema = inverted_list_schema(self.with_position); + let docs_for_batches = docs.clone(); + let schema_for_batches = schema.clone(); + + let (tx, mut rx) = tokio::sync::mpsc::channel::<Result<RecordBatch>>(2); + let producer = spawn_cpu(move || { + for posting_list in posting_lists { + let batch = + posting_list.to_batch_with_docs(&docs_for_batches, schema_for_batches.clone()); + let is_err = batch.is_err(); + if tx.blocking_send(batch).is_err() { + break; + } + if is_err { + break; + } + } + Ok(()) + }); let mut write_duration = std::time::Duration::ZERO; let mut num_posting_lists = 0; - while let Some(batch) = batches.try_next().await? 
{ + while let Some(batch) = rx.recv().await { + let batch = batch?; num_posting_lists += 1; let start = std::time::Instant::now(); writer.write_record_batch(batch).await?; @@ -493,6 +504,9 @@ impl InnerBuilder { } } + // Errors from batch generation are sent through the channel and surfaced via `batch?`. + // Awaiting the producer here is just to propagate panics/cancellation. + producer.await?; writer.finish().await?; Ok(()) } diff --git a/rust/lance-index/src/scalar/inverted/encoding.rs b/rust/lance-index/src/scalar/inverted/encoding.rs index 29c4eb39f4b..57bc80cda66 100644 --- a/rust/lance-index/src/scalar/inverted/encoding.rs +++ b/rust/lance-index/src/scalar/inverted/encoding.rs @@ -90,6 +90,89 @@ pub fn compress_posting_list<'a>( Ok(builder.finish()) } +pub fn compress_posting_list_with_scores<'a, F>( + length: usize, + doc_ids: impl Iterator<Item = &'a u32>, + frequencies: impl Iterator<Item = &'a u32>, + mut score_for: F, + idf_scale: f32, +) -> Result<(arrow::array::LargeBinaryArray, f32)> +where + F: FnMut(u32, u32) -> f32, +{ + // `length` comes from posting list size; zero would produce an invalid block + // (a max-score header with no doc/frequency data) and readers assume > 0 docs. 
+ debug_assert!(length > 0); + if length < BLOCK_SIZE { + let mut builder = LargeBinaryBuilder::with_capacity(1, length * 4 * 2 + 1); + let mut max_score = f32::MIN; + let mut doc_id_buffer = Vec::with_capacity(length); + let mut freq_buffer = Vec::with_capacity(length); + for (doc_id, freq) in std::iter::zip(doc_ids, frequencies) { + let doc_id = *doc_id; + let freq = *freq; + doc_id_buffer.push(doc_id); + freq_buffer.push(freq); + let score = score_for(doc_id, freq); + if score > max_score { + max_score = score; + } + } + let max_score = max_score * idf_scale; + let _ = builder.write(max_score.to_le_bytes().as_ref())?; + compress_remainder(&doc_id_buffer, &mut builder)?; + compress_remainder(&freq_buffer, &mut builder)?; + builder.append_value(""); + return Ok((builder.finish(), max_score)); + } + + let mut builder = LargeBinaryBuilder::with_capacity(length.div_ceil(BLOCK_SIZE), length * 3); + let mut buffer = [0u8; BLOCK_SIZE * 4 + 5]; + let mut doc_id_buffer = Vec::with_capacity(BLOCK_SIZE); + let mut freq_buffer = Vec::with_capacity(BLOCK_SIZE); + let mut max_score = f32::MIN; + let mut block_max_score = f32::MIN; + for (doc_id, freq) in std::iter::zip(doc_ids, frequencies) { + let doc_id = *doc_id; + let freq = *freq; + doc_id_buffer.push(doc_id); + freq_buffer.push(freq); + + let score = score_for(doc_id, freq); + if score > block_max_score { + block_max_score = score; + } + + if doc_id_buffer.len() < BLOCK_SIZE { + continue; + } + + let block_score = block_max_score * idf_scale; + if block_score > max_score { + max_score = block_score; + } + let _ = builder.write(block_score.to_le_bytes().as_ref())?; + compress_sorted_block(&doc_id_buffer, &mut buffer, &mut builder)?; + compress_block(&freq_buffer, &mut buffer, &mut builder)?; + builder.append_value(""); + doc_id_buffer.clear(); + freq_buffer.clear(); + block_max_score = f32::MIN; + } + + if !doc_id_buffer.is_empty() { + let block_score = block_max_score * idf_scale; + if block_score > max_score { + 
max_score = block_score; + } + let _ = builder.write(block_score.to_le_bytes().as_ref())?; + compress_remainder(&doc_id_buffer, &mut builder)?; + compress_remainder(&freq_buffer, &mut builder)?; + builder.append_value(""); + } + Ok((builder.finish(), max_score)) +} + #[inline] fn compress_sorted_block( data: &[u32], diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 74ecc3c78b6..37d3f28ddb8 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -59,7 +59,7 @@ use super::{ }; use super::{ builder::{InnerBuilder, PositionRecorder}, - encoding::compress_posting_list, + encoding::{compress_posting_list, compress_posting_list_with_scores}, iter::CompressedPostingListIterator, }; use super::{encoding::compress_positions, iter::PostingListIterator}; @@ -1912,18 +1912,13 @@ impl PostingListBuilder { } } - // assume the posting list is sorted by doc id - pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> { + fn build_batch( + self, + compressed: LargeBinaryArray, + max_score: f32, + schema: SchemaRef, + ) -> Result<RecordBatch> { let length = self.len(); - let max_score = block_max_scores.iter().copied().fold(f32::MIN, f32::max); - - let schema = inverted_list_schema(self.has_positions()); - let compressed = compress_posting_list( - self.doc_ids.len(), - self.doc_ids.iter(), - self.frequencies.iter(), - block_max_scores.into_iter(), - )?; let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, compressed.len() as i32])); let mut columns = vec![ Arc::new(ListArray::try_new( @@ -1934,7 +1929,7 @@ impl PostingListBuilder { )?) 
 as ArrayRef, Arc::new(Float32Array::from_iter_values(std::iter::once(max_score))) as ArrayRef, Arc::new(UInt32Array::from_iter_values(std::iter::once( - self.len() as u32 + length as u32, ))) as ArrayRef, ]; @@ -1958,6 +1953,37 @@ impl PostingListBuilder { Ok(batch) } + // assume the posting list is sorted by doc id + pub fn to_batch(self, block_max_scores: Vec<f32>) -> Result<RecordBatch> { + let max_score = block_max_scores.iter().copied().fold(f32::MIN, f32::max); + let schema = inverted_list_schema(self.has_positions()); + let compressed = compress_posting_list( + self.doc_ids.len(), + self.doc_ids.iter(), + self.frequencies.iter(), + block_max_scores.into_iter(), + )?; + self.build_batch(compressed, max_score, schema) + } + + pub fn to_batch_with_docs(self, docs: &DocSet, schema: SchemaRef) -> Result<RecordBatch> { + let length = self.len(); + let avgdl = docs.average_length(); + let idf_scale = idf(length, docs.len()) * (K1 + 1.0); + let (compressed, max_score) = compress_posting_list_with_scores( + length, + self.doc_ids.iter(), + self.frequencies.iter(), + |doc_id, freq| { + let doc_norm = K1 * (1.0 - B + B * docs.num_tokens(doc_id) as f32 / avgdl); + let freq = freq as f32; + freq / (freq + doc_norm) + }, + idf_scale, + )?; + self.build_batch(compressed, max_score, schema) + } + pub fn remap(&mut self, removed: &[u32]) { let mut cursor = 0; let mut new_doc_ids = ExpLinkedList::with_capacity(self.len()); @@ -2562,10 +2588,12 @@ mod tests { use crate::metrics::NoOpMetricsCollector; use crate::prefilter::NoFilter; - use crate::scalar::inverted::builder::{InnerBuilder, PositionRecorder}; + use crate::scalar::inverted::builder::{inverted_list_schema, InnerBuilder, PositionRecorder}; use crate::scalar::inverted::encoding::decompress_posting_list; use crate::scalar::inverted::query::{FtsSearchParams, Operator}; use crate::scalar::lance_format::LanceIndexStore; + use arrow::array::AsArray; + use arrow::datatypes::{Float32Type, UInt32Type}; use super::*; @@ -2607,6 +2635,54 @@ 
.all(|(a, b)| a == b)); } + #[test] + fn test_posting_list_batch_matches_docset_scoring() { + let mut docs = DocSet::default(); + let num_docs = BLOCK_SIZE + 3; + for doc_id in 0..num_docs as u32 { + docs.append(doc_id as u64, doc_id % 7 + 1); + } + + let doc_ids = (0..num_docs as u32).collect::<Vec<_>>(); + let freqs = doc_ids + .iter() + .map(|doc_id| doc_id % 5 + 1) + .collect::<Vec<_>>(); + + let mut builder_scores = PostingListBuilder::new(false); + let mut builder_docs = PostingListBuilder::new(false); + for (&doc_id, &freq) in doc_ids.iter().zip(freqs.iter()) { + builder_scores.add(doc_id, PositionRecorder::Count(freq)); + builder_docs.add(doc_id, PositionRecorder::Count(freq)); + } + + let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter()); + let batch_scores = builder_scores.to_batch(block_max_scores).unwrap(); + let batch_docs = builder_docs + .to_batch_with_docs(&docs, inverted_list_schema(false)) + .unwrap(); + + let scores_posting = batch_scores[POSTING_COL].as_list::<i32>().value(0); + let scores_posting = scores_posting.as_binary::<i64>(); + let docs_posting = batch_docs[POSTING_COL].as_list::<i32>().value(0); + let docs_posting = docs_posting.as_binary::<i64>(); + assert_eq!(scores_posting, docs_posting); + + let score_left = batch_scores[MAX_SCORE_COL] + .as_primitive::<Float32Type>() + .value(0); + let score_right = batch_docs[MAX_SCORE_COL] + .as_primitive::<Float32Type>() + .value(0); + assert!((score_left - score_right).abs() < 1e-6); + + let len_left = batch_scores[LENGTH_COL] + .as_primitive::<UInt32Type>() + .value(0); + let len_right = batch_docs[LENGTH_COL].as_primitive::<UInt32Type>().value(0); + assert_eq!(len_left, len_right); + } + #[tokio::test] async fn test_remap_to_empty_posting_list() { let tmpdir = TempObjDir::default();