Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,12 @@ impl InvertedIndex {
if postings.is_empty() {
return Result::Ok(PartitionCandidates::empty());
}
let mut tokens_by_position = vec![String::new(); postings.len()];
let max_position = postings
.iter()
.map(|posting| posting.term_index() as usize)
.max()
.unwrap_or_default();
let mut tokens_by_position = vec![String::new(); max_position + 1];
for posting in &postings {
let idx = posting.term_index() as usize;
tokens_by_position[idx] = posting.token().to_owned();
Expand Down Expand Up @@ -978,11 +983,14 @@ impl InvertedPartition {
true => self.expand_fuzzy(tokens, params)?,
false => tokens.clone(),
};
let token_positions = (0..tokens.len())
.map(|index| tokens.position(index))
.collect::<Vec<_>>();
let mut token_ids = Vec::with_capacity(tokens.len());
for token in tokens {
for (index, token) in tokens.into_iter().enumerate() {
let token_id = self.map(&token);
if let Some(token_id) = token_id {
token_ids.push((token_id, token));
token_ids.push((token_id, token, token_positions[index]));
} else if is_phrase_query {
// if the token is not found, we can't do phrase query
return Ok(Vec::new());
Expand All @@ -992,14 +1000,13 @@ impl InvertedPartition {
return Ok(Vec::new());
}
if !is_phrase_query {
token_ids.sort_unstable_by_key(|(token_id, _)| *token_id);
token_ids.dedup_by_key(|(token_id, _)| *token_id);
token_ids.sort_unstable_by_key(|(token_id, _, _)| *token_id);
token_ids.dedup_by_key(|(token_id, _, _)| *token_id);
}

let num_docs = self.docs.len();
stream::iter(token_ids)
.enumerate()
.map(|(position, (token_id, token))| async move {
.map(|(token_id, token, position)| async move {
let posting = self
.inverted_list
.posting_list(token_id, is_phrase_query, metrics)
Expand All @@ -1010,7 +1017,7 @@ impl InvertedPartition {
Result::Ok(PostingIterator::with_query_weight(
token,
token_id,
position as u32,
position,
query_weight,
posting,
num_docs,
Expand Down
16 changes: 15 additions & 1 deletion rust/lance-index/src/scalar/inverted/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,12 +719,19 @@ impl FtsQueryNode for BooleanQuery {
/// Tokens produced from a query or document text, together with each token's
/// position in the original text. Positions may have gaps (e.g. `0, 2`) when
/// the tokenizer removed stop words, which phrase queries rely on to match
/// documents with the same relative token offsets.
#[derive(Clone)]
pub struct Tokens {
    // Token strings in occurrence order.
    tokens: Vec<String>,
    // positions[i] is the source-text position of tokens[i]; taken from the
    // tokenizer output, so not necessarily contiguous (see `with_positions`).
    positions: Vec<u32>,
    // Lookup from token text to its index in `tokens`/`positions`.
    tokens_map: HashMap<String, usize>,
    // Kind of text these tokens came from (see `DocType`).
    token_type: DocType,
}

impl Tokens {
/// Builds `Tokens` with default positions: token `i` is assumed to sit at
/// position `i` in the source text (no gaps).
///
/// Use [`Tokens::with_positions`] when the tokenizer reports real positions
/// (e.g. with stop-word removal, which leaves holes in the sequence).
pub fn new(tokens: Vec<String>, token_type: DocType) -> Self {
    // Contiguous 0..n positions, one per token.
    let positions: Vec<u32> = (0..tokens.len()).map(|i| i as u32).collect();
    Self::with_positions(tokens, positions, token_type)
}

pub fn with_positions(tokens: Vec<String>, positions: Vec<u32>, token_type: DocType) -> Self {
debug_assert_eq!(tokens.len(), positions.len());
let mut tokens_vec = vec![];
let mut tokens_map = HashMap::new();
for (idx, token) in tokens.into_iter().enumerate() {
Expand All @@ -734,6 +741,7 @@ impl Tokens {

Self {
tokens: tokens_vec,
positions,
tokens_map,
token_type,
}
Expand Down Expand Up @@ -762,6 +770,10 @@ impl Tokens {
/// Returns the token text stored at `index`.
///
/// Panics if `index` is out of bounds.
pub fn get_token(&self, index: usize) -> &str {
    self.tokens[index].as_str()
}

/// Returns the source-text position of the token at `index`.
///
/// Panics if `index` is out of bounds (`positions` always has the same
/// length as `tokens` — see the `debug_assert_eq!` in `with_positions`).
pub fn position(&self, index: usize) -> u32 {
    self.positions[index]
}
}

impl IntoIterator for Tokens {
Expand All @@ -786,10 +798,12 @@ pub fn collect_query_tokens(text: &str, tokenizer: &mut Box<dyn LanceTokenizer>)
let token_type = tokenizer.doc_type();
let mut stream = tokenizer.token_stream_for_search(text);
let mut tokens = Vec::new();
let mut positions = Vec::new();
while let Some(token) = stream.next() {
tokens.push(token.text.clone());
positions.push(token.position as u32);
}
Tokens::new(tokens, token_type)
Tokens::with_positions(tokens, positions, token_type)
}

pub fn has_query_token(
Expand Down
38 changes: 38 additions & 0 deletions rust/lance-index/src/scalar/inverted/wand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1993,6 +1993,44 @@ mod tests {
assert!(wand.check_positions(0));
}

#[rstest]
fn test_exact_phrase_respects_query_position_gaps(#[values(false, true)] is_compressed: bool) {
    // One document (row id 0) containing 16 tokens.
    let mut doc_set = DocSet::default();
    doc_set.append(0, 16);
    let num_docs = doc_set.len();

    // One posting iterator per query token. The query positions carry a gap
    // ("want" at 0, "apple" at 2 — as if a stop word between them had been
    // removed), and the document positions match that gap exactly.
    let make_posting = |token: &str, token_id, query_pos, doc_pos: u32| {
        PostingIterator::new(
            String::from(token),
            token_id,
            query_pos,
            generate_posting_list_with_positions(vec![0], vec![vec![doc_pos]], 1.0, is_compressed),
            num_docs,
        )
    };
    let posting_iters = vec![make_posting("want", 0, 0, 0_u32), make_posting("apple", 1, 2, 2_u32)];

    let scorer = IndexBM25Scorer::new(std::iter::empty());
    let wand = Wand::new(Operator::And, posting_iters.into_iter(), &doc_set, scorer);

    // The phrase checks must accept doc 0: the delta between the document
    // positions (0 -> 2) equals the delta between the query positions.
    assert!(wand.check_exact_positions());
    assert!(wand.check_positions(0));
}

#[rstest]
fn test_and_phrase_miss_advances_to_next_candidate(#[values(false, true)] is_compressed: bool) {
let mut docs = DocSet::default();
Expand Down
119 changes: 119 additions & 0 deletions rust/lance/src/dataset/tests/dataset_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,125 @@ async fn test_fts_phrase_query() {
assert_eq!(result.num_rows(), 0);
}

/// Phrase query where the tokenizer removes stop words from the *query*:
/// "want the apple" / "want an apple" tokenize to `want@0, apple@2`, so any
/// document with "want" and "apple" two positions apart should match —
/// including "want green apple", whose middle token is not a stop word.
#[tokio::test]
async fn test_fts_phrase_query_with_removed_stop_words() {
    let tmpdir = TempStrDir::default();
    let uri = tmpdir.to_owned();
    // NOTE(review): the temp-dir handle is dropped before `Dataset::write`;
    // presumably this frees the path so the write recreates it — confirm the
    // directory is still cleaned up after the test.
    drop(tmpdir);

    // Doc 3 ("apple want the") has the tokens in the wrong order and must
    // not match; docs 0-2 all have "want" and "apple" two positions apart.
    let doc_col: Arc<dyn Array> = Arc::new(GenericStringArray::<i32>::from(vec![
        "want the apple",
        "want an apple",
        "want green apple",
        "apple want the",
    ]));
    let ids = UInt64Array::from_iter_values(0..doc_col.len() as u64);
    let batch = RecordBatch::try_new(
        arrow_schema::Schema::new(vec![
            arrow_schema::Field::new("doc", doc_col.data_type().to_owned(), true),
            arrow_schema::Field::new("id", DataType::UInt64, false),
        ])
        .into(),
        vec![Arc::new(doc_col) as ArrayRef, Arc::new(ids) as ArrayRef],
    )
    .unwrap();
    let schema = batch.schema();
    let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
    let mut dataset = Dataset::write(batches, &uri, None).await.unwrap();

    // Inverted index with term positions recorded and stop words removed —
    // both are required to reproduce the position-gap behavior under test.
    dataset
        .create_index(
            &["doc"],
            IndexType::Inverted,
            None,
            &InvertedIndexParams::default()
                .with_position(true)
                .remove_stop_words(true),
            true,
        )
        .await
        .unwrap();

    for query in ["want the apple", "want an apple"] {
        let result = dataset
            .scan()
            .project(&["id"])
            .unwrap()
            .full_text_search(FullTextSearchQuery::new_query(
                PhraseQuery::new(query.to_owned()).into(),
            ))
            .unwrap()
            .try_into_batch()
            .await
            .unwrap();

        // Expect exactly docs 0, 1, 2; doc 3 is excluded by token order.
        let ids = result["id"].as_primitive::<UInt64Type>().values();
        assert_eq!(result.num_rows(), 3, "query={query}, ids={ids:?}");
        assert!(ids.contains(&0), "query={query}, ids={ids:?}");
        assert!(ids.contains(&1), "query={query}, ids={ids:?}");
        assert!(ids.contains(&2), "query={query}, ids={ids:?}");
    }
}

/// Phrase query where stop-word removal leaves position gaps: the query
/// "the united states of america" should only match documents whose surviving
/// tokens sit at the same relative offsets — a document missing the gap
/// ("united states america") or with a wider gap ("... of north america")
/// must not match.
#[tokio::test]
async fn test_fts_phrase_query_preserves_stop_word_gaps() {
    let tmpdir = TempStrDir::default();
    let uri = tmpdir.to_owned();
    // NOTE(review): the temp-dir handle is dropped before `Dataset::write`;
    // presumably this frees the path so the write recreates it — confirm the
    // directory is still cleaned up after the test.
    drop(tmpdir);

    // Doc 0: exact text. Doc 1: "and" occupies the slot the removed "of"
    // left, so offsets still line up. Doc 2: no gap. Doc 3: gap too wide.
    let doc_col: Arc<dyn Array> = Arc::new(GenericStringArray::<i32>::from(vec![
        "the united states of america",
        "the united states and america",
        "united states america",
        "the united states of north america",
    ]));
    let ids = UInt64Array::from_iter_values(0..doc_col.len() as u64);
    let batch = RecordBatch::try_new(
        arrow_schema::Schema::new(vec![
            arrow_schema::Field::new("doc", doc_col.data_type().to_owned(), true),
            arrow_schema::Field::new("id", DataType::UInt64, false),
        ])
        .into(),
        vec![Arc::new(doc_col) as ArrayRef, Arc::new(ids) as ArrayRef],
    )
    .unwrap();
    let schema = batch.schema();
    let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema);
    let mut dataset = Dataset::write(batches, &uri, None).await.unwrap();

    // Inverted index with term positions recorded and stop words removed —
    // both are required to reproduce the position-gap behavior under test.
    dataset
        .create_index(
            &["doc"],
            IndexType::Inverted,
            None,
            &InvertedIndexParams::default()
                .with_position(true)
                .remove_stop_words(true),
            true,
        )
        .await
        .unwrap();

    let result = dataset
        .scan()
        .project(&["id"])
        .unwrap()
        .full_text_search(FullTextSearchQuery::new_query(
            PhraseQuery::new("the united states of america".to_owned()).into(),
        ))
        .unwrap()
        .try_into_batch()
        .await
        .unwrap();

    // Only the documents whose relative token offsets match (0 and 1).
    let ids = result["id"].as_primitive::<UInt64Type>().values();
    assert_eq!(result.num_rows(), 2, "ids={ids:?}");
    assert!(ids.contains(&0), "ids={ids:?}");
    assert!(ids.contains(&1), "ids={ids:?}");
    assert!(!ids.contains(&2), "ids={ids:?}");
    assert!(!ids.contains(&3), "ids={ids:?}");
}

async fn prepare_json_dataset() -> (Dataset, String) {
let text_col = Arc::new(StringArray::from(vec![
r#"{
Expand Down
13 changes: 8 additions & 5 deletions rust/lance/src/io/exec/fts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -872,10 +872,13 @@ impl ExecutionPlan for PhraseQueryExec {
let metrics = Arc::new(FtsIndexMetrics::new(&self.metrics, partition));
let stream = stream::once(async move {
let _timer = metrics.baseline_metrics.elapsed_compute().timer();
let column = query.column.ok_or(DataFusionError::Execution(format!(
"column not set for PhraseQuery {}",
query.terms
)))?;
let column = query
.column
.clone()
.ok_or(DataFusionError::Execution(format!(
"column not set for PhraseQuery {}",
query.terms
)))?;
let index_meta = ds
.load_scalar_index(IndexCriteria::default().for_column(&column).supports_fts())
.await?
Expand All @@ -892,7 +895,7 @@ impl ExecutionPlan for PhraseQueryExec {
context.clone(),
partition,
&prefilter_source,
ds,
ds.clone(),
&[index_meta],
)?;

Expand Down
Loading