diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 422d154b453..8ecc92b9170 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -476,7 +476,12 @@ impl InvertedIndex { if postings.is_empty() { return Result::Ok(PartitionCandidates::empty()); } - let mut tokens_by_position = vec![String::new(); postings.len()]; + let max_position = postings + .iter() + .map(|posting| posting.term_index() as usize) + .max() + .unwrap_or_default(); + let mut tokens_by_position = vec![String::new(); max_position + 1]; for posting in &postings { let idx = posting.term_index() as usize; tokens_by_position[idx] = posting.token().to_owned(); @@ -978,11 +983,14 @@ impl InvertedPartition { true => self.expand_fuzzy(tokens, params)?, false => tokens.clone(), }; + let token_positions = (0..tokens.len()) + .map(|index| tokens.position(index)) + .collect::>(); let mut token_ids = Vec::with_capacity(tokens.len()); - for token in tokens { + for (index, token) in tokens.into_iter().enumerate() { let token_id = self.map(&token); if let Some(token_id) = token_id { - token_ids.push((token_id, token)); + token_ids.push((token_id, token, token_positions[index])); } else if is_phrase_query { // if the token is not found, we can't do phrase query return Ok(Vec::new()); @@ -992,14 +1000,13 @@ impl InvertedPartition { return Ok(Vec::new()); } if !is_phrase_query { - token_ids.sort_unstable_by_key(|(token_id, _)| *token_id); - token_ids.dedup_by_key(|(token_id, _)| *token_id); + token_ids.sort_unstable_by_key(|(token_id, _, _)| *token_id); + token_ids.dedup_by_key(|(token_id, _, _)| *token_id); } let num_docs = self.docs.len(); stream::iter(token_ids) - .enumerate() - .map(|(position, (token_id, token))| async move { + .map(|(token_id, token, position)| async move { let posting = self .inverted_list .posting_list(token_id, is_phrase_query, metrics) @@ -1010,7 +1017,7 @@ impl InvertedPartition { Result::Ok(PostingIterator::with_query_weight( token, token_id, - position as u32, + position, query_weight, posting, num_docs, diff --git a/rust/lance-index/src/scalar/inverted/query.rs b/rust/lance-index/src/scalar/inverted/query.rs index 4c207d83134..6a8ebb07840 100644 --- a/rust/lance-index/src/scalar/inverted/query.rs +++ b/rust/lance-index/src/scalar/inverted/query.rs @@ -719,12 +719,19 @@ impl FtsQueryNode for BooleanQuery { #[derive(Clone)] pub struct Tokens { tokens: Vec, + positions: Vec, tokens_map: HashMap, token_type: DocType, } impl Tokens { pub fn new(tokens: Vec, token_type: DocType) -> Self { + let positions = (0..tokens.len() as u32).collect(); + Self::with_positions(tokens, positions, token_type) + } + + pub fn with_positions(tokens: Vec, positions: Vec, token_type: DocType) -> Self { + debug_assert_eq!(tokens.len(), positions.len()); let mut tokens_vec = vec![]; let mut tokens_map = HashMap::new(); for (idx, token) in tokens.into_iter().enumerate() { @@ -734,6 +741,7 @@ impl Tokens { Self { tokens: tokens_vec, + positions, tokens_map, token_type, } @@ -762,6 +770,10 @@ impl Tokens { pub fn get_token(&self, index: usize) -> &str { &self.tokens[index] } + + pub fn position(&self, index: usize) -> u32 { + self.positions[index] + } } impl IntoIterator for Tokens { @@ -786,10 +798,12 @@ pub fn collect_query_tokens(text: &str, tokenizer: &mut Box) let token_type = tokenizer.doc_type(); let mut stream = tokenizer.token_stream_for_search(text); let mut tokens = Vec::new(); + let mut positions = Vec::new(); while let Some(token) = stream.next() { tokens.push(token.text.clone()); + positions.push(token.position as u32); } - Tokens::new(tokens, token_type) + Tokens::with_positions(tokens, positions, token_type) } pub fn has_query_token( diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs index 7da5ddfb0fb..b06c75c0021 100644 --- a/rust/lance-index/src/scalar/inverted/wand.rs +++ b/rust/lance-index/src/scalar/inverted/wand.rs @@ -1993,6 +1993,44 @@ mod tests { assert!(wand.check_positions(0)); } + #[rstest] + fn test_exact_phrase_respects_query_position_gaps(#[values(false, true)] is_compressed: bool) { + let mut docs = DocSet::default(); + docs.append(0, 16); + + let postings = vec![ + PostingIterator::new( + String::from("want"), + 0, + 0, + generate_posting_list_with_positions( + vec![0], + vec![vec![0_u32]], + 1.0, + is_compressed, + ), + docs.len(), + ), + PostingIterator::new( + String::from("apple"), + 1, + 2, + generate_posting_list_with_positions( + vec![0], + vec![vec![2_u32]], + 1.0, + is_compressed, + ), + docs.len(), + ), + ]; + + let bm25 = IndexBM25Scorer::new(std::iter::empty()); + let wand = Wand::new(Operator::And, postings.into_iter(), &docs, bm25); + assert!(wand.check_exact_positions()); + assert!(wand.check_positions(0)); + } + #[rstest] fn test_and_phrase_miss_advances_to_next_candidate(#[values(false, true)] is_compressed: bool) { let mut docs = DocSet::default(); diff --git a/rust/lance/src/dataset/tests/dataset_index.rs b/rust/lance/src/dataset/tests/dataset_index.rs index c3aac4493d4..87c17cdba32 100644 --- a/rust/lance/src/dataset/tests/dataset_index.rs +++ b/rust/lance/src/dataset/tests/dataset_index.rs @@ -1776,6 +1776,125 @@ async fn test_fts_phrase_query() { assert_eq!(result.num_rows(), 0); } +#[tokio::test] +async fn test_fts_phrase_query_with_removed_stop_words() { + let tmpdir = TempStrDir::default(); + let uri = tmpdir.to_owned(); + drop(tmpdir); + + let doc_col: Arc = Arc::new(GenericStringArray::::from(vec![ + "want the apple", + "want an apple", + "want green apple", + "apple want the", + ])); + let ids = UInt64Array::from_iter_values(0..doc_col.len() as u64); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("doc", doc_col.data_type().to_owned(), true), + arrow_schema::Field::new("id", DataType::UInt64, false), + ]) + .into(), + vec![Arc::new(doc_col) as ArrayRef, Arc::new(ids) as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(batches, &uri, None).await.unwrap(); + + dataset + .create_index( + &["doc"], + IndexType::Inverted, + None, + &InvertedIndexParams::default() + .with_position(true) + .remove_stop_words(true), + true, + ) + .await + .unwrap(); + + for query in ["want the apple", "want an apple"] { + let result = dataset + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new_query( + PhraseQuery::new(query.to_owned()).into(), + )) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + let ids = result["id"].as_primitive::().values(); + assert_eq!(result.num_rows(), 3, "query={query}, ids={ids:?}"); + assert!(ids.contains(&0), "query={query}, ids={ids:?}"); + assert!(ids.contains(&1), "query={query}, ids={ids:?}"); + assert!(ids.contains(&2), "query={query}, ids={ids:?}"); + } +} + +#[tokio::test] +async fn test_fts_phrase_query_preserves_stop_word_gaps() { + let tmpdir = TempStrDir::default(); + let uri = tmpdir.to_owned(); + drop(tmpdir); + + let doc_col: Arc = Arc::new(GenericStringArray::::from(vec![ + "the united states of america", + "the united states and america", + "united states america", + "the united states of north america", + ])); + let ids = UInt64Array::from_iter_values(0..doc_col.len() as u64); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("doc", doc_col.data_type().to_owned(), true), + arrow_schema::Field::new("id", DataType::UInt64, false), + ]) + .into(), + vec![Arc::new(doc_col) as ArrayRef, Arc::new(ids) as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(batches, &uri, None).await.unwrap(); + + dataset + .create_index( + &["doc"], + IndexType::Inverted, + None, + &InvertedIndexParams::default() + .with_position(true) + .remove_stop_words(true), + true, + ) + .await + .unwrap(); + + let result = dataset + .scan() + .project(&["id"]) + .unwrap() + .full_text_search(FullTextSearchQuery::new_query( + PhraseQuery::new("the united states of america".to_owned()).into(), + )) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + let ids = result["id"].as_primitive::().values(); + assert_eq!(result.num_rows(), 2, "ids={ids:?}"); + assert!(ids.contains(&0), "ids={ids:?}"); + assert!(ids.contains(&1), "ids={ids:?}"); + assert!(!ids.contains(&2), "ids={ids:?}"); + assert!(!ids.contains(&3), "ids={ids:?}"); +} + async fn prepare_json_dataset() -> (Dataset, String) { let text_col = Arc::new(StringArray::from(vec![ r#"{ diff --git a/rust/lance/src/io/exec/fts.rs b/rust/lance/src/io/exec/fts.rs index 6e129841c76..7cefcb2046c 100644 --- a/rust/lance/src/io/exec/fts.rs +++ b/rust/lance/src/io/exec/fts.rs @@ -872,10 +872,13 @@ impl ExecutionPlan for PhraseQueryExec { let metrics = Arc::new(FtsIndexMetrics::new(&self.metrics, partition)); let stream = stream::once(async move { let _timer = metrics.baseline_metrics.elapsed_compute().timer(); - let column = query.column.ok_or(DataFusionError::Execution(format!( - "column not set for PhraseQuery {}", - query.terms - )))?; + let column = query + .column + .clone() + .ok_or(DataFusionError::Execution(format!( + "column not set for PhraseQuery {}", + query.terms + )))?; let index_meta = ds .load_scalar_index(IndexCriteria::default().for_column(&column).supports_fts()) .await? @@ -892,7 +895,7 @@ impl ExecutionPlan for PhraseQueryExec { context.clone(), partition, &prefilter_source, - ds, + ds.clone(), &[index_meta], )?;