diff --git a/rust/lance-index/benches/inverted.rs b/rust/lance-index/benches/inverted.rs index 415c1bc3fc4..f08d711ce5b 100644 --- a/rust/lance-index/benches/inverted.rs +++ b/rust/lance-index/benches/inverted.rs @@ -14,7 +14,6 @@ use futures::stream; use itertools::Itertools; use lance_core::cache::LanceCache; use lance_core::ROW_ID; -use lance_datagen::{array, RowCount}; use lance_index::prefilter::NoFilter; use lance_index::scalar::inverted::lance_tokenizer::DocType; use lance_index::scalar::inverted::query::{FtsSearchParams, Operator, Tokens}; @@ -27,6 +26,8 @@ use lance_io::object_store::ObjectStore; use object_store::path::Path; #[cfg(target_os = "linux")] use pprof::criterion::{Output, PProfProfiler}; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use rand_distr::Zipf; fn bench_inverted(c: &mut Criterion) { const TOTAL: usize = 1_000_000; @@ -43,16 +44,32 @@ fn bench_inverted(c: &mut Criterion) { )) }); - // generate random words using lance-datagen let row_id_col = Arc::new(UInt64Array::from( (0..TOTAL).map(|i| i as u64).collect_vec(), )); - // Generate random words with 1-100 words per document - let mut words_gen = array::random_sentence(1, 100, true); - let doc_col = words_gen - .generate_default(RowCount::from(TOTAL as u64)) - .unwrap(); + // Generate Zipf-distributed words to better reflect real-world term frequency. + const VOCAB_SIZE: usize = 100_000; + const MIN_WORDS: usize = 1; + const MAX_WORDS: usize = 100; + const ZIPF_EXPONENT: f64 = 1.1; + let vocab: Vec = (0..VOCAB_SIZE).map(|i| format!("term{i:05}")).collect(); + let word_zipf = Zipf::new(VOCAB_SIZE as f64, ZIPF_EXPONENT).unwrap(); + let mut rng = StdRng::seed_from_u64(42); + let mut docs = Vec::with_capacity(TOTAL); + for _ in 0..TOTAL { + let num_words = rng.random_range(MIN_WORDS..=MAX_WORDS); + let mut doc = String::with_capacity(num_words * 8); + for i in 0..num_words { + let idx = (rng.sample(word_zipf) as usize).clamp(1, VOCAB_SIZE) - 1; + if i > 0 { + doc.push(' '); + } + doc.push_str(&vocab[idx]); + } + docs.push(doc); + } + let doc_col = Arc::new(LargeStringArray::from(docs)); let batch = RecordBatch::try_new( arrow_schema::Schema::new(vec![ arrow_schema::Field::new("doc", arrow_schema::DataType::LargeUtf8, false), @@ -86,32 +103,48 @@ fn bench_inverted(c: &mut Criterion) { let no_filter = Arc::new(NoFilter); // Get some sample words from the generated documents for search - let large_string_array = doc_col.as_any().downcast_ref::().unwrap(); - let sample_doc = large_string_array.value(0); + let sample_doc = doc_col.value(0); let sample_words: Vec = sample_doc .split_whitespace() .map(|s| s.to_owned()) .collect(); + let sample_words_len = sample_words.len(); + const TOKENS_PER_QUERY: usize = 15; + const QUERY_SET_SIZE: usize = 1024; + let mut query_rng = StdRng::seed_from_u64(7); + let mut queries = Vec::with_capacity(QUERY_SET_SIZE); + for _ in 0..QUERY_SET_SIZE { + let mut query_tokens = Vec::with_capacity(TOKENS_PER_QUERY); + for _ in 0..TOKENS_PER_QUERY { + let word_idx = query_rng.random_range(0..sample_words_len); + query_tokens.push(sample_words[word_idx].clone()); + } + queries.push(Arc::new(Tokens::new(query_tokens, DocType::Text))); + } + let mut query_idx = 0usize; c.bench_function(format!("invert_search({TOTAL})").as_str(), |b| { - b.to_async(&rt).iter(|| async { - // Pick a random word from our sample - let word_idx = rand::random_range(0..sample_words.len()); - black_box( - invert_index - .bm25_search( - Arc::new(Tokens::new( - vec![sample_words[word_idx].clone()], - DocType::Text, - )), - params.clone().into(), - Operator::Or, - no_filter.clone(), - Arc::new(NoOpMetricsCollector), - ) - .await - .unwrap(), - ); + b.to_async(&rt).iter(|| { + // Cycle through pre-generated queries to avoid skewing benchmark results. + let query = queries[query_idx % queries.len()].clone(); + query_idx = query_idx.wrapping_add(1); + let invert_index = invert_index.clone(); + let params = params.clone(); + let no_filter = no_filter.clone(); + async move { + black_box( + invert_index + .bm25_search( + query, + params.clone().into(), + Operator::Or, + no_filter.clone(), + Arc::new(NoOpMetricsCollector), + ) + .await + .unwrap(), + ); + } }) }); }