diff --git a/Cargo.lock b/Cargo.lock index 9c46fbf8da4..28e0069601b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5066,6 +5066,7 @@ dependencies = [ "rstest 0.23.0", "serde", "serde_json", + "smallvec", "snafu", "tantivy", "tempfile", diff --git a/rust/lance-index/Cargo.toml b/rust/lance-index/Cargo.toml index 29a6ba18475..70262b8de92 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -59,6 +59,7 @@ rayon.workspace = true serde_json.workspace = true serde.workspace = true snafu.workspace = true +smallvec = "1.15" tantivy.workspace = true lindera = { workspace = true, optional = true } lindera-tantivy = { workspace = true, optional = true } diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 75c3aa9a33b..15b037a00ad 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -28,6 +28,7 @@ use lance_core::{error::LanceOptionExt, utils::tempfile::TempDir}; use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD}; use lance_io::object_store::ObjectStore; use object_store::path::Path; +use smallvec::SmallVec; use snafu::location; use std::collections::HashMap; use std::pin::Pin; @@ -572,6 +573,10 @@ struct IndexWorker { total_doc_length: usize, fragment_mask: Option, token_set_format: TokenSetFormat, + token_occurrences: HashMap, + token_ids: Vec, + last_token_count: usize, + last_unique_token_count: usize, } impl IndexWorker { @@ -601,6 +606,10 @@ impl IndexWorker { total_doc_length: 0, fragment_mask, token_set_format, + token_occurrences: HashMap::new(), + token_ids: Vec::new(), + last_token_count: 0, + last_unique_token_count: 0, }) } @@ -618,20 +627,40 @@ impl IndexWorker { let with_position = self.has_position(); for (doc, row_id) in docs { - let mut token_occurrences = HashMap::new(); - let mut token_num = 0; - { + let mut token_num: u32 = 0; + if with_position { + if self.token_occurrences.capacity() < self.last_unique_token_count { + self.token_occurrences + .reserve(self.last_unique_token_count - self.token_occurrences.capacity()); + } + self.token_occurrences.clear(); + let mut token_stream = self.tokenizer.token_stream_for_doc(doc); while token_stream.advance() { let token = token_stream.token_mut(); let token_text = std::mem::take(&mut token.text); - let token_id = self.builder.tokens.add(token_text) as usize; - token_occurrences - .entry(token_id as u32) - .or_insert_with(|| PositionRecorder::new(with_position)) + let token_id = self.builder.tokens.add(token_text); + self.token_occurrences + .entry(token_id) + .or_insert_with(|| PositionRecorder::new(true)) .push(token.position as u32); token_num += 1; } + } else { + if self.token_ids.capacity() < self.last_token_count { + self.token_ids + .reserve(self.last_token_count - self.token_ids.capacity()); + } + self.token_ids.clear(); + + let mut token_stream = self.tokenizer.token_stream_for_doc(doc); + while token_stream.advance() { + let token = token_stream.token_mut(); + let token_text = std::mem::take(&mut token.text); + let token_id = self.builder.tokens.add(token_text); + self.token_ids.push(token_id); + token_num += 1; + } } self.builder .posting_lists @@ -641,16 +670,44 @@ impl IndexWorker { let doc_id = self.builder.docs.append(row_id, token_num); self.total_doc_length += doc.len(); - token_occurrences - .into_iter() - .for_each(|(token_id, term_positions)| { + if with_position { + let unique_tokens = self.token_occurrences.len(); + for (token_id, term_positions) in self.token_occurrences.drain() { let posting_list = &mut self.builder.posting_lists[token_id as usize]; let old_size = posting_list.size(); posting_list.add(doc_id, term_positions); let new_size = posting_list.size(); self.estimated_size += new_size - old_size; - }); + } + self.last_unique_token_count = unique_tokens; + } else if token_num > 0 { + self.token_ids.sort_unstable(); + let mut iter = self.token_ids.iter(); + let mut current = *iter.next().unwrap(); + let mut count = 1u32; + for &token_id in iter { + if token_id == current { + count += 1; + continue; + } + + let posting_list = &mut self.builder.posting_lists[current as usize]; + let old_size = posting_list.size(); + posting_list.add(doc_id, PositionRecorder::Count(count)); + let new_size = posting_list.size(); + self.estimated_size += new_size - old_size; + + current = token_id; + count = 1; + } + let posting_list = &mut self.builder.posting_lists[current as usize]; + let old_size = posting_list.size(); + posting_list.add(doc_id, PositionRecorder::Count(count)); + let new_size = posting_list.size(); + self.estimated_size += new_size - old_size; + } + self.last_token_count = token_num as usize; if self.builder.docs.len() as u32 == u32::MAX || self.estimated_size >= *LANCE_FTS_PARTITION_SIZE << 20 @@ -699,14 +756,14 @@ impl IndexWorker { #[derive(Debug, Clone)] pub enum PositionRecorder { - Position(Vec), + Position(SmallVec<[u32; 4]>), Count(u32), } impl PositionRecorder { fn new(with_position: bool) -> Self { if with_position { - Self::Position(Vec::new()) + Self::Position(SmallVec::new()) } else { Self::Count(0) } @@ -732,7 +789,7 @@ impl PositionRecorder { pub fn into_vec(self) -> Vec { match self { - Self::Position(positions) => positions, + Self::Position(positions) => positions.into_vec(), Self::Count(_) => vec![0], } } @@ -1192,8 +1249,11 @@ pub fn document_input( #[cfg(test)] mod tests { use super::*; + use crate::metrics::NoOpMetricsCollector; use arrow_array::{RecordBatch, StringArray, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use futures::stream; use lance_core::cache::LanceCache; use lance_core::utils::tempfile::TempDir; use lance_core::ROW_ID; @@ -1285,4 +1345,54 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_inverted_index_without_positions_tracks_frequency() -> Result<()> { + let index_dir = TempDir::default(); + let store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + index_dir.obj_path(), + Arc::new(LanceCache::no_cache()), + )); + + let schema = Arc::new(Schema::new(vec![ + Field::new("doc", DataType::Utf8, true), + Field::new(ROW_ID, DataType::UInt64, false), + ])); + let docs = Arc::new(StringArray::from(vec![Some("hello hello world")])); + let row_ids = Arc::new(UInt64Array::from(vec![0u64])); + let batch = RecordBatch::try_new(schema.clone(), vec![docs, row_ids])?; + let stream = RecordBatchStreamAdapter::new(schema, stream::iter(vec![Ok(batch)])); + let stream = Box::pin(stream); + + let params = InvertedIndexParams::new( + "whitespace".to_string(), + tantivy::tokenizer::Language::English, + ) + .with_position(false) + .remove_stop_words(false) + .stem(false) + .max_token_length(None); + + let mut builder = InvertedIndexBuilder::new(params); + builder.update(stream, store.as_ref()).await?; + + let index = InvertedIndex::load(store, None, &LanceCache::no_cache()).await?; + assert_eq!(index.partitions.len(), 1); + let partition = &index.partitions[0]; + let token_id = partition.tokens.get("hello").unwrap(); + let posting = partition + .inverted_list + .posting_list(token_id, false, &NoOpMetricsCollector) + .await?; + + let mut iter = posting.iter(); + let (doc_id, freq, positions) = iter.next().unwrap(); + assert_eq!(doc_id, 0); + assert_eq!(freq, 2); + assert!(positions.is_none()); + assert!(iter.next().is_none()); + + Ok(()) + } } diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 2c5c7a847a5..61c44a0c005 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -1632,7 +1632,7 @@ impl PostingList { let freq = freq as u32; let positions = match positions { Some(positions) => { - PositionRecorder::Position(positions.collect::>()) + PositionRecorder::Position(positions.collect::>().into()) } None => PositionRecorder::Count(freq), }; @@ -1650,7 +1650,7 @@ impl PostingList { posting.iter().for_each(|(doc_id, freq, positions)| { let positions = match positions { Some(positions) => { - PositionRecorder::Position(positions.collect::>()) + PositionRecorder::Position(positions.collect::>().into()) } None => PositionRecorder::Count(freq), };