Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 98 additions & 4 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,23 @@ impl InvertedIndex {
.collect::<Vec<_>>();
let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
let mut idf_cache: HashMap<String, f32> = HashMap::new();
while let Some(res) = parts.try_next().await? {
if res.candidates.is_empty() {
continue;
}
let tokens_by_position = &res.tokens_by_position;
let mut idf_by_position = Vec::with_capacity(res.tokens_by_position.len());
for token in &res.tokens_by_position {
let idf_weight = match idf_cache.get(token) {
Some(weight) => *weight,
None => {
let weight = scorer.query_weight(token);
idf_cache.insert(token.clone(), weight);
weight
}
};
idf_by_position.push(idf_weight);
}
for DocCandidate {
row_id,
freqs,
Expand All @@ -315,9 +327,9 @@ impl InvertedIndex {
{
let mut score = 0.0;
for (term_index, freq) in freqs.into_iter() {
debug_assert!((term_index as usize) < tokens_by_position.len());
let token = &tokens_by_position[term_index as usize];
score += scorer.score(token.as_str(), freq, doc_length);
debug_assert!((term_index as usize) < idf_by_position.len());
score +=
idf_by_position[term_index as usize] * scorer.doc_weight(freq, doc_length);
}
if candidates.len() < limit {
candidates.push(Reverse(ScoredDoc::new(row_id, score)));
Expand Down Expand Up @@ -2790,4 +2802,86 @@ mod tests {
assert_eq!(block_max_scores.len(), expected_blocks);
assert_eq!(block_max_scores.capacity(), expected_blocks);
}

#[tokio::test]
async fn test_bm25_search_uses_global_idf() {
let tmpdir = TempObjDir::default();
let store = Arc::new(LanceIndexStore::new(
ObjectStore::local().into(),
tmpdir.clone(),
Arc::new(LanceCache::no_cache()),
));

// Partition 0: 3 docs, only one contains "alpha".
let mut builder0 = InnerBuilder::new(0, false, TokenSetFormat::default());
builder0.tokens.add("alpha".to_owned());
builder0.tokens.add("beta".to_owned());
builder0.posting_lists.push(PostingListBuilder::new(false));
builder0.posting_lists.push(PostingListBuilder::new(false));
builder0.posting_lists[0].add(0, PositionRecorder::Count(1));
builder0.posting_lists[1].add(1, PositionRecorder::Count(1));
builder0.posting_lists[1].add(2, PositionRecorder::Count(1));
builder0.docs.append(100, 1);
builder0.docs.append(101, 1);
builder0.docs.append(102, 1);
builder0.write(store.as_ref()).await.unwrap();

// Partition 1: 1 doc, contains "alpha".
let mut builder1 = InnerBuilder::new(1, false, TokenSetFormat::default());
builder1.tokens.add("alpha".to_owned());
builder1.posting_lists.push(PostingListBuilder::new(false));
builder1.posting_lists[0].add(0, PositionRecorder::Count(1));
builder1.docs.append(200, 1);
builder1.write(store.as_ref()).await.unwrap();

let metadata = std::collections::HashMap::from_iter(vec![
(
"partitions".to_owned(),
serde_json::to_string(&vec![0u64, 1u64]).unwrap(),
),
(
"params".to_owned(),
serde_json::to_string(&InvertedIndexParams::default()).unwrap(),
),
(
TOKEN_SET_FORMAT_KEY.to_owned(),
TokenSetFormat::default().to_string(),
),
]);
let mut writer = store
.new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty()))
.await
.unwrap();
writer.finish_with_metadata(metadata).await.unwrap();

let cache = Arc::new(LanceCache::with_capacity(4096));
let index = InvertedIndex::load(store.clone(), None, cache.as_ref())
.await
.unwrap();

let tokens = Arc::new(Tokens::new(vec!["alpha".to_string()], DocType::Text));
let params = Arc::new(FtsSearchParams::new().with_limit(Some(10)));
let prefilter = Arc::new(NoFilter);
let metrics = Arc::new(NoOpMetricsCollector);

let (row_ids, scores) = index
.bm25_search(tokens, params, Operator::Or, prefilter, metrics)
.await
.unwrap();

assert_eq!(row_ids.len(), 2);
assert!(row_ids.contains(&100));
assert!(row_ids.contains(&200));
assert_eq!(row_ids.len(), scores.len());

let expected_idf = idf(2, 4);
for score in scores {
assert!(
(score - expected_idf).abs() < 1e-6,
"score: {}, expected: {}",
score,
expected_idf
);
}
}
}
Loading