Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2192,7 +2192,9 @@ impl DocSet {
) -> Vec<f32> {
let avgdl = self.average_length();
let length = doc_ids.size_hint().0;
let mut block_max_scores = Vec::with_capacity(length);
let num_blocks = length.div_ceil(BLOCK_SIZE);
let mut block_max_scores = Vec::with_capacity(num_blocks);
let idf_scale = idf(length, self.len()) * (K1 + 1.0);
let mut max_score = f32::MIN;
for (i, (doc_id, freq)) in doc_ids.zip(freqs).enumerate() {
let doc_norm = K1 * (1.0 - B + B * self.num_tokens(*doc_id) as f32 / avgdl);
Expand All @@ -2202,13 +2204,13 @@ impl DocSet {
max_score = score;
}
if (i + 1) % BLOCK_SIZE == 0 {
max_score *= idf(length, self.len()) * (K1 + 1.0);
max_score *= idf_scale;
block_max_scores.push(max_score);
max_score = f32::MIN;
}
}
if length % BLOCK_SIZE > 0 {
max_score *= idf(length, self.len()) * (K1 + 1.0);
max_score *= idf_scale;
block_max_scores.push(max_score);
}
block_max_scores
Expand Down Expand Up @@ -2771,4 +2773,21 @@ mod tests {
"Should contain row_id from partition 1"
);
}

#[test]
fn test_block_max_scores_capacity_matches_block_count() {
let mut docs = DocSet::default();
let num_docs = BLOCK_SIZE * 3 + 7;
let doc_ids = (0..num_docs as u32).collect::<Vec<_>>();
for doc_id in &doc_ids {
docs.append(*doc_id as u64, 1);
}

let freqs = vec![1_u32; doc_ids.len()];
let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), freqs.iter());
let expected_blocks = doc_ids.len().div_ceil(BLOCK_SIZE);

assert_eq!(block_max_scores.len(), expected_blocks);
assert_eq!(block_max_scores.capacity(), expected_blocks);
}
}