diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 3605364715e..74ecc3c78b6 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -302,11 +302,23 @@ impl InvertedIndex { .collect::>(); let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus()); let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref())); + let mut idf_cache: HashMap = HashMap::new(); while let Some(res) = parts.try_next().await? { if res.candidates.is_empty() { continue; } - let tokens_by_position = &res.tokens_by_position; + let mut idf_by_position = Vec::with_capacity(res.tokens_by_position.len()); + for token in &res.tokens_by_position { + let idf_weight = match idf_cache.get(token) { + Some(weight) => *weight, + None => { + let weight = scorer.query_weight(token); + idf_cache.insert(token.clone(), weight); + weight + } + }; + idf_by_position.push(idf_weight); + } for DocCandidate { row_id, freqs, @@ -315,9 +327,9 @@ impl InvertedIndex { { let mut score = 0.0; for (term_index, freq) in freqs.into_iter() { - debug_assert!((term_index as usize) < tokens_by_position.len()); - let token = &tokens_by_position[term_index as usize]; - score += scorer.score(token.as_str(), freq, doc_length); + debug_assert!((term_index as usize) < idf_by_position.len()); + score += + idf_by_position[term_index as usize] * scorer.doc_weight(freq, doc_length); } if candidates.len() < limit { candidates.push(Reverse(ScoredDoc::new(row_id, score))); @@ -2790,4 +2802,86 @@ mod tests { assert_eq!(block_max_scores.len(), expected_blocks); assert_eq!(block_max_scores.capacity(), expected_blocks); } + + #[tokio::test] + async fn test_bm25_search_uses_global_idf() { + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + // Partition 0: 3 docs, only one contains "alpha". + let mut builder0 = InnerBuilder::new(0, false, TokenSetFormat::default()); + builder0.tokens.add("alpha".to_owned()); + builder0.tokens.add("beta".to_owned()); + builder0.posting_lists.push(PostingListBuilder::new(false)); + builder0.posting_lists.push(PostingListBuilder::new(false)); + builder0.posting_lists[0].add(0, PositionRecorder::Count(1)); + builder0.posting_lists[1].add(1, PositionRecorder::Count(1)); + builder0.posting_lists[1].add(2, PositionRecorder::Count(1)); + builder0.docs.append(100, 1); + builder0.docs.append(101, 1); + builder0.docs.append(102, 1); + builder0.write(store.as_ref()).await.unwrap(); + + // Partition 1: 1 doc, contains "alpha". + let mut builder1 = InnerBuilder::new(1, false, TokenSetFormat::default()); + builder1.tokens.add("alpha".to_owned()); + builder1.posting_lists.push(PostingListBuilder::new(false)); + builder1.posting_lists[0].add(0, PositionRecorder::Count(1)); + builder1.docs.append(200, 1); + builder1.write(store.as_ref()).await.unwrap(); + + let metadata = std::collections::HashMap::from_iter(vec![ + ( + "partitions".to_owned(), + serde_json::to_string(&vec![0u64, 1u64]).unwrap(), + ), + ( + "params".to_owned(), + serde_json::to_string(&InvertedIndexParams::default()).unwrap(), + ), + ( + TOKEN_SET_FORMAT_KEY.to_owned(), + TokenSetFormat::default().to_string(), + ), + ]); + let mut writer = store + .new_index_file(METADATA_FILE, Arc::new(arrow_schema::Schema::empty())) + .await + .unwrap(); + writer.finish_with_metadata(metadata).await.unwrap(); + + let cache = Arc::new(LanceCache::with_capacity(4096)); + let index = InvertedIndex::load(store.clone(), None, cache.as_ref()) + .await + .unwrap(); + + let tokens = Arc::new(Tokens::new(vec!["alpha".to_string()], DocType::Text)); + let params = Arc::new(FtsSearchParams::new().with_limit(Some(10))); + let prefilter = Arc::new(NoFilter); + let metrics = Arc::new(NoOpMetricsCollector); + + let (row_ids, scores) = index + .bm25_search(tokens, params, Operator::Or, prefilter, metrics) + .await + .unwrap(); + + assert_eq!(row_ids.len(), 2); + assert!(row_ids.contains(&100)); + assert!(row_ids.contains(&200)); + assert_eq!(row_ids.len(), scores.len()); + + let expected_idf = idf(2, 4); + for score in scores { + assert!( + (score - expected_idf).abs() < 1e-6, + "score: {}, expected: {}", + score, + expected_idf + ); + } + } }