1 change: 1 addition & 0 deletions Cargo.lock

1 change: 1 addition & 0 deletions rust/lance-index/Cargo.toml
@@ -59,6 +59,7 @@ rayon.workspace = true
serde_json.workspace = true
serde.workspace = true
snafu.workspace = true
smallvec = "1.15"
tantivy.workspace = true
lindera = { workspace = true, optional = true }
lindera-tantivy = { workspace = true, optional = true }
138 changes: 124 additions & 14 deletions rust/lance-index/src/scalar/inverted/builder.rs
@@ -28,6 +28,7 @@ use lance_core::{error::LanceOptionExt, utils::tempfile::TempDir};
use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
use lance_io::object_store::ObjectStore;
use object_store::path::Path;
use smallvec::SmallVec;
use snafu::location;
use std::collections::HashMap;
use std::pin::Pin;
@@ -572,6 +573,10 @@ struct IndexWorker {
total_doc_length: usize,
fragment_mask: Option<u64>,
token_set_format: TokenSetFormat,
token_occurrences: HashMap<u32, PositionRecorder>,
token_ids: Vec<u32>,
last_token_count: usize,
last_unique_token_count: usize,
}

impl IndexWorker {
@@ -601,6 +606,10 @@
total_doc_length: 0,
fragment_mask,
token_set_format,
token_occurrences: HashMap::new(),
token_ids: Vec::new(),
last_token_count: 0,
last_unique_token_count: 0,
})
}

@@ -618,20 +627,40 @@

let with_position = self.has_position();
for (doc, row_id) in docs {
let mut token_occurrences = HashMap::new();
let mut token_num = 0;
{
let mut token_num: u32 = 0;
if with_position {
if self.token_occurrences.capacity() < self.last_unique_token_count {
self.token_occurrences
.reserve(self.last_unique_token_count - self.token_occurrences.capacity());
}
self.token_occurrences.clear();

let mut token_stream = self.tokenizer.token_stream_for_doc(doc);
while token_stream.advance() {
let token = token_stream.token_mut();
let token_text = std::mem::take(&mut token.text);
let token_id = self.builder.tokens.add(token_text) as usize;
token_occurrences
.entry(token_id as u32)
.or_insert_with(|| PositionRecorder::new(with_position))
let token_id = self.builder.tokens.add(token_text);
self.token_occurrences
.entry(token_id)
.or_insert_with(|| PositionRecorder::new(true))
.push(token.position as u32);
token_num += 1;
}
} else {
if self.token_ids.capacity() < self.last_token_count {
self.token_ids
.reserve(self.last_token_count - self.token_ids.capacity());
}
self.token_ids.clear();

let mut token_stream = self.tokenizer.token_stream_for_doc(doc);
while token_stream.advance() {
let token = token_stream.token_mut();
let token_text = std::mem::take(&mut token.text);
let token_id = self.builder.tokens.add(token_text);
self.token_ids.push(token_id);
token_num += 1;
}
}
self.builder
.posting_lists
@@ -641,16 +670,44 @@
let doc_id = self.builder.docs.append(row_id, token_num);
self.total_doc_length += doc.len();

token_occurrences
.into_iter()
.for_each(|(token_id, term_positions)| {
if with_position {
let unique_tokens = self.token_occurrences.len();
for (token_id, term_positions) in self.token_occurrences.drain() {
let posting_list = &mut self.builder.posting_lists[token_id as usize];

let old_size = posting_list.size();
posting_list.add(doc_id, term_positions);
let new_size = posting_list.size();
self.estimated_size += new_size - old_size;
});
}
self.last_unique_token_count = unique_tokens;
} else if token_num > 0 {
self.token_ids.sort_unstable();
let mut iter = self.token_ids.iter();
let mut current = *iter.next().unwrap();
let mut count = 1u32;
for &token_id in iter {
if token_id == current {
count += 1;
continue;
}

let posting_list = &mut self.builder.posting_lists[current as usize];
let old_size = posting_list.size();
posting_list.add(doc_id, PositionRecorder::Count(count));
let new_size = posting_list.size();
self.estimated_size += new_size - old_size;

current = token_id;
count = 1;
}
let posting_list = &mut self.builder.posting_lists[current as usize];
let old_size = posting_list.size();
posting_list.add(doc_id, PositionRecorder::Count(count));
let new_size = posting_list.size();
self.estimated_size += new_size - old_size;
}
self.last_token_count = token_num as usize;

if self.builder.docs.len() as u32 == u32::MAX
|| self.estimated_size >= *LANCE_FTS_PARTITION_SIZE << 20
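
Note on the frequency-only path above: rather than filling a per-document hash map, the worker reuses one `token_ids` buffer across documents, sorts it, and counts runs of equal ids. A minimal standalone sketch of that counting step (the helper name and the `(token_id, count)` return shape are illustrative only; the PR updates posting lists in place):

```rust
/// Count term frequencies by sorting token ids and scanning runs of
/// equal values, mirroring the loop in the diff above.
fn run_length_counts(token_ids: &mut Vec<u32>) -> Vec<(u32, u32)> {
    let mut counts = Vec::new();
    if token_ids.is_empty() {
        return counts;
    }
    token_ids.sort_unstable();
    let mut iter = token_ids.iter();
    let mut current = *iter.next().unwrap();
    let mut count = 1u32;
    for &id in iter {
        if id == current {
            count += 1;
            continue;
        }
        counts.push((current, count));
        current = id;
        count = 1;
    }
    counts.push((current, count)); // flush the final run
    counts
}

fn main() {
    let mut ids = vec![7, 2, 7, 2, 2];
    assert_eq!(run_length_counts(&mut ids), vec![(2, 3), (7, 2)]);
}
```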
@@ -699,14 +756,14 @@

#[derive(Debug, Clone)]
pub enum PositionRecorder {
Position(Vec<u32>),
Position(SmallVec<[u32; 4]>),
Count(u32),
}

impl PositionRecorder {
fn new(with_position: bool) -> Self {
if with_position {
Self::Position(Vec::new())
Self::Position(SmallVec::new())
} else {
Self::Count(0)
}
@@ -732,7 +789,7 @@

pub fn into_vec(self) -> Vec<u32> {
match self {
Self::Position(positions) => positions,
Self::Position(positions) => positions.into_vec(),
Self::Count(_) => vec![0],
}
}
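
For context on the `SmallVec<[u32; 4]>` change above: most tokens occur only a handful of times within a document, so the first four positions fit in the inline array and the recorder never heap-allocates in the common case. A small sketch of that behavior, assuming smallvec 1.x (`record_positions` is a hypothetical helper, not part of the PR):

```rust
use smallvec::SmallVec;

// Hypothetical helper mimicking how PositionRecorder accumulates positions.
fn record_positions(positions: &[u32]) -> SmallVec<[u32; 4]> {
    let mut recorder: SmallVec<[u32; 4]> = SmallVec::new();
    for &pos in positions {
        recorder.push(pos); // heap-allocates only past the fourth element
    }
    recorder
}

fn main() {
    assert!(!record_positions(&[0, 3, 7]).spilled()); // fits inline
    assert!(record_positions(&[0, 1, 2, 3, 4]).spilled()); // moved to heap
}
```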
@@ -1192,8 +1249,11 @@ pub fn document_input(
#[cfg(test)]
mod tests {
use super::*;
use crate::metrics::NoOpMetricsCollector;
use arrow_array::{RecordBatch, StringArray, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use futures::stream;
use lance_core::cache::LanceCache;
use lance_core::utils::tempfile::TempDir;
use lance_core::ROW_ID;
@@ -1285,4 +1345,54 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_inverted_index_without_positions_tracks_frequency() -> Result<()> {
let index_dir = TempDir::default();
let store = Arc::new(LanceIndexStore::new(
ObjectStore::local().into(),
index_dir.obj_path(),
Arc::new(LanceCache::no_cache()),
));

let schema = Arc::new(Schema::new(vec![
Field::new("doc", DataType::Utf8, true),
Field::new(ROW_ID, DataType::UInt64, false),
]));
let docs = Arc::new(StringArray::from(vec![Some("hello hello world")]));
let row_ids = Arc::new(UInt64Array::from(vec![0u64]));
let batch = RecordBatch::try_new(schema.clone(), vec![docs, row_ids])?;
let stream = RecordBatchStreamAdapter::new(schema, stream::iter(vec![Ok(batch)]));
let stream = Box::pin(stream);

let params = InvertedIndexParams::new(
"whitespace".to_string(),
tantivy::tokenizer::Language::English,
)
.with_position(false)
.remove_stop_words(false)
.stem(false)
.max_token_length(None);

let mut builder = InvertedIndexBuilder::new(params);
builder.update(stream, store.as_ref()).await?;

let index = InvertedIndex::load(store, None, &LanceCache::no_cache()).await?;
assert_eq!(index.partitions.len(), 1);
let partition = &index.partitions[0];
let token_id = partition.tokens.get("hello").unwrap();
let posting = partition
.inverted_list
.posting_list(token_id, false, &NoOpMetricsCollector)
.await?;

let mut iter = posting.iter();
let (doc_id, freq, positions) = iter.next().unwrap();
assert_eq!(doc_id, 0);
assert_eq!(freq, 2);
assert!(positions.is_none());
assert!(iter.next().is_none());

Ok(())
}
}
4 changes: 2 additions & 2 deletions rust/lance-index/src/scalar/inverted/index.rs
@@ -1632,7 +1632,7 @@ impl PostingList {
let freq = freq as u32;
let positions = match positions {
Some(positions) => {
PositionRecorder::Position(positions.collect::<Vec<_>>())
PositionRecorder::Position(positions.collect::<Vec<_>>().into())
}
None => PositionRecorder::Count(freq),
};
@@ -1650,7 +1650,7 @@
posting.iter().for_each(|(doc_id, freq, positions)| {
let positions = match positions {
Some(positions) => {
PositionRecorder::Position(positions.collect::<Vec<_>>())
PositionRecorder::Position(positions.collect::<Vec<_>>().into())
}
None => PositionRecorder::Count(freq),
};
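
The `.into()` calls above compile because smallvec provides `From<Vec<T>>` for `SmallVec`; when the collected positions exceed the inline capacity, the conversion can take over the Vec's existing heap allocation rather than copying element by element. A minimal sketch, assuming smallvec 1.x:

```rust
use smallvec::SmallVec;

fn main() {
    // Mirrors the collect-then-convert pattern in the diff above.
    let positions: Vec<u32> = (0..8).collect();
    let sv: SmallVec<[u32; 4]> = positions.into();
    assert_eq!(sv.len(), 8);
    assert!(sv.spilled()); // eight elements exceed the four inline slots
}
```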