diff --git a/rust/lance/tests/query/inverted.rs b/rust/lance/tests/query/inverted.rs
new file mode 100644
index 00000000000..c6872447049
--- /dev/null
+++ b/rust/lance/tests/query/inverted.rs
@@ -0,0 +1,268 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array};
+use lance::dataset::scanner::ColumnOrdering;
+use lance::Dataset;
+use lance_index::scalar::inverted::query::{FtsQuery, PhraseQuery};
+use lance_index::scalar::{FullTextSearchQuery, InvertedIndexParams};
+use lance_index::{DatasetIndexExt, IndexType};
+use tantivy::tokenizer::Language;
+
+use super::{strip_score_column, test_fts, test_scan, test_take};
+use crate::utils::DatasetTestCases;
+
+// Build baseline inverted index parameters for tests, toggling token positions.
+fn base_inverted_params(with_position: bool) -> InvertedIndexParams {
+    InvertedIndexParams::new("simple".to_string(), Language::English)
+        .with_position(with_position)
+        .lower_case(true)
+        .stem(false)
+        .remove_stop_words(false)
+        .ascii_folding(false)
+        .max_token_length(None)
+}
+
+// Like `base_inverted_params`, but also parameterizes tokenizer and lowercasing.
+fn params_for(base_tokenizer: &str, lower_case: bool, with_position: bool) -> InvertedIndexParams {
+    InvertedIndexParams::new(base_tokenizer.to_string(), Language::English)
+        .with_position(with_position)
+        .lower_case(lower_case)
+        .stem(false)
+        .remove_stop_words(false)
+        .ascii_folding(false)
+        .max_token_length(None)
+}
+
+// Execute a full-text search with optional filter and deterministic id ordering.
+async fn run_fts(ds: &Dataset, query: FullTextSearchQuery, filter: Option<&str>) -> RecordBatch {
+    let mut scanner = ds.scan();
+    scanner.full_text_search(query).unwrap();
+    if let Some(predicate) = filter {
+        scanner.filter(predicate).unwrap();
+    }
+    scanner
+        .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
+            "id".to_string(),
+        )]))
+        .unwrap();
+    scanner.try_into_batch().await.unwrap()
+}
+
+// Run an FTS query and assert results match a deterministic expected batch.
+async fn assert_fts_expected(
+    original: &RecordBatch,
+    ds: &Dataset,
+    query: FullTextSearchQuery,
+    filter: Option<&str>,
+    expected_ids: &[i32],
+) {
+    let scanned = run_fts(ds, query, filter).await;
+    let scanned = strip_score_column(&scanned, original.schema().as_ref());
+
+    let indices_u32: Vec<u32> = expected_ids.iter().map(|&i| i as u32).collect();
+    let indices_array = UInt32Array::from(indices_u32);
+    let expected = arrow::compute::take_record_batch(original, &indices_array).unwrap();
+
+    // Ensure ordering is deterministic (id asc) and matches the expected rows.
+    assert_eq!(&expected, &scanned);
+}
+
+#[tokio::test]
+// Ensure indexed and non-indexed full-text search return the same ids.
+async fn test_inverted_basic_equivalence() {
+    let ids = Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()));
+    let text_values = vec![
+        Some("hello world"),
+        Some("world hello"),
+        Some("hello"),
+        Some("lance database"),
+        Some(""),
+        None,
+        Some("hello lance"),
+        Some("lance"),
+        Some("database"),
+        Some("world"),
+    ];
+    let text = Arc::new(StringArray::from(text_values)) as ArrayRef;
+    let batch = RecordBatch::try_from_iter(vec![("id", ids as ArrayRef), ("text", text)]).unwrap();
+
+    DatasetTestCases::from_data(batch.clone())
+        .run(|ds, original| async move {
+            let mut ds = ds;
+            let query = FullTextSearchQuery::new("hello".to_string())
+                .with_column("text".to_string())
+                .unwrap();
+
+            // Flat scan (no index) must already produce the expected ids.
+            let expected_ids = vec![0, 1, 2, 6];
+            assert_fts_expected(&original, &ds, query.clone(), None, &expected_ids).await;
+
+            // Building the inverted index must not change the result set.
+            let params = base_inverted_params(false);
+            ds.create_index(&["text"], IndexType::Inverted, None, &params, true)
+                .await
+                .unwrap();
+            assert_fts_expected(&original, &ds, query.clone(), None, &expected_ids).await;
+            test_fts(&original, &ds, "text", "hello", None, true, false).await;
+
+            test_scan(&original, &ds).await;
+            test_take(&original, &ds).await;
+        })
+        .await;
+}
+
+#[tokio::test]
+// Verify phrase queries require token positions and match contiguous terms.
+async fn test_inverted_phrase_query_with_positions() {
+    let ids = Arc::new(Int32Array::from((0..6).collect::<Vec<i32>>()));
+    let text_values = vec![
+        Some("lance database"),
+        Some("lance and database"),
+        Some("database lance"),
+        Some("lance database test"),
+        Some("lance database"),
+        None,
+    ];
+    let text = Arc::new(StringArray::from(text_values)) as ArrayRef;
+    let batch = RecordBatch::try_from_iter(vec![("id", ids as ArrayRef), ("text", text)]).unwrap();
+
+    DatasetTestCases::from_data(batch.clone())
+        .run(|ds, original| async move {
+            let mut ds = ds;
+            // Positions are required for phrase queries.
+            let params = base_inverted_params(true);
+            ds.create_index(&["text"], IndexType::Inverted, None, &params, true)
+                .await
+                .unwrap();
+
+            let phrase = PhraseQuery::new("lance database".to_string())
+                .with_column(Some("text".to_string()));
+            let query = FullTextSearchQuery::new_query(FtsQuery::Phrase(phrase));
+
+            // Only rows where "lance" is immediately followed by "database" match.
+            assert_fts_expected(&original, &ds, query, None, &[0, 3, 4]).await;
+            test_fts(&original, &ds, "text", "lance database", None, true, true).await;
+        })
+        .await;
+}
+
+#[tokio::test]
+// Validate filters are applied alongside inverted index search results.
+async fn test_inverted_with_filter() {
+    let ids = Arc::new(Int32Array::from((0..5).collect::<Vec<i32>>()));
+    let text_values = vec![
+        Some("lance database"),
+        Some("lance vector"),
+        Some("random text"),
+        Some("lance"),
+        None,
+    ];
+    let categories = vec![
+        Some("keep"),
+        Some("drop"),
+        Some("keep"),
+        Some("keep"),
+        Some("keep"),
+    ];
+    let text = Arc::new(StringArray::from(text_values)) as ArrayRef;
+    let category = Arc::new(StringArray::from(categories)) as ArrayRef;
+    let batch = RecordBatch::try_from_iter(vec![
+        ("id", ids as ArrayRef),
+        ("text", text),
+        ("category", category),
+    ])
+    .unwrap();
+
+    DatasetTestCases::from_data(batch.clone())
+        // Cross the FTS query against every scalar-index flavor on the filter column.
+        .with_index_types(
+            "category",
+            [
+                None,
+                Some(IndexType::Bitmap),
+                Some(IndexType::BTree),
+                Some(IndexType::BloomFilter),
+                Some(IndexType::ZoneMap),
+            ],
+        )
+        .run(|ds, original| async move {
+            let mut ds = ds;
+            let params = base_inverted_params(false);
+            ds.create_index(&["text"], IndexType::Inverted, None, &params, true)
+                .await
+                .unwrap();
+
+            let query = FullTextSearchQuery::new("lance".to_string())
+                .with_column("text".to_string())
+                .unwrap();
+            // Row 1 matches "lance" but is filtered out by category = 'keep'.
+            assert_fts_expected(&original, &ds, query, Some("category = 'keep'"), &[0, 3]).await;
+            test_fts(
+                &original,
+                &ds,
+                "text",
+                "lance",
+                Some("category = 'keep'"),
+                true,
+                false,
+            )
+            .await;
+        })
+        .await;
+}
+
+#[tokio::test]
+// Validate tokenizer/lowercase/position parameter combinations against expected matches.
+async fn test_inverted_params_combinations() {
+    let ids = Arc::new(Int32Array::from((0..5).collect::<Vec<i32>>()));
+    let text_values = vec![
+        Some("Hello there, this is a longer sentence about Lance."),
+        Some("In this longer sentence we say hello to the database."),
+        Some("Another line: hello world appears in a longer phrase."),
+        Some("Saying HELLO loudly in a long sentence for testing."),
+        None,
+    ];
+    let text = Arc::new(StringArray::from(text_values)) as ArrayRef;
+    let batch = RecordBatch::try_from_iter(vec![("id", ids as ArrayRef), ("text", text)]).unwrap();
+
+    // (case name, index params, expected matching ids, lower_case flag for SQL oracle)
+    let cases = vec![
+        (
+            "simple_lc_pos",
+            params_for("simple", true, true),
+            vec![0, 1, 2, 3],
+            true,
+        ),
+        (
+            "simple_no_lc",
+            params_for("simple", false, false),
+            vec![1, 2],
+            false,
+        ),
+        (
+            "whitespace_lc",
+            params_for("whitespace", true, false),
+            vec![0, 1, 2, 3],
+            true,
+        ),
+        (
+            "whitespace_no_lc_pos",
+            params_for("whitespace", false, true),
+            vec![1, 2],
+            false,
+        ),
+    ];
+
+    for (_name, params, expected, lower_case) in cases {
+        let params = params.clone();
+        let expected = expected.clone();
+        DatasetTestCases::from_data(batch.clone())
+            .with_index_types_and_inverted_index_params("text", [Some(IndexType::Inverted)], params)
+            .run(|ds, original| {
+                let expected = expected.clone();
+                async move {
+                    let query = FullTextSearchQuery::new("hello".to_string())
+                        .with_column("text".to_string())
+                        .unwrap();
+                    assert_fts_expected(&original, &ds, query.clone(), None, &expected).await;
+                    test_fts(&original, &ds, "text", "hello", None, lower_case, false).await;
+                }
+            })
+            .await;
+    }
+}
diff --git a/rust/lance/tests/query/mod.rs b/rust/lance/tests/query/mod.rs
index 5816d786f89..c9514100a63 100644
--- a/rust/lance/tests/query/mod.rs
+++ b/rust/lance/tests/query/mod.rs
@@ -10,6 +10,8 @@ use datafusion::prelude::SessionContext;
 use lance::dataset::scanner::ColumnOrdering;
 use lance::Dataset;
 use lance_datafusion::udf::register_functions;
+use lance_index::scalar::inverted::query::{FtsQuery, PhraseQuery};
+use lance_index::scalar::FullTextSearchQuery;
 
 /// Creates a fresh SessionContext with Lance UDFs registered
 fn create_datafusion_context() -> SessionContext {
@@ -18,6 +20,7 @@ fn create_datafusion_context() -> SessionContext {
     ctx
 }
 
+mod inverted;
 mod primitives;
 mod vectors;
 
@@ -94,6 +97,106 @@ async fn test_filter(original: &RecordBatch, ds: &Dataset, predicate: &str) {
     assert_eq!(&expected, &scanned);
 }
 
+// Rebuild a batch using only columns present in the schema (drops _score from FTS results).
+fn strip_score_column(batch: &RecordBatch, schema: &arrow_schema::Schema) -> RecordBatch {
+    let columns = schema
+        .fields()
+        .iter()
+        .map(|field| batch.column_by_name(field.name()).unwrap().clone())
+        .collect::<Vec<_>>();
+    RecordBatch::try_new(Arc::new(schema.clone()), columns).unwrap()
+}
+
+/// Full text search should match results computed in DataFusion using the constructed SQL
+async fn test_fts(
+    original: &RecordBatch,
+    ds: &Dataset,
+    column: &str,
+    query: &str,
+    filter: Option<&str>,
+    lower_case: bool,
+    phrase_query: bool,
+) {
+    // Scan with FTS and order
+    let mut scanner = ds.scan();
+    let fts_query = if phrase_query {
+        let phrase = PhraseQuery::new(query.to_string()).with_column(Some(column.to_string()));
+        FullTextSearchQuery::new_query(FtsQuery::Phrase(phrase))
+    } else {
+        FullTextSearchQuery::new(query.to_string())
+            .with_column(column.to_string())
+            .unwrap()
+    };
+    scanner.full_text_search(fts_query).unwrap();
+    if let Some(predicate) = filter {
+        scanner.filter(predicate).unwrap();
+    }
+    scanner
+        .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
+            "id".to_string(),
+        )]))
+        .unwrap();
+    let scanned = scanner.try_into_batch().await.unwrap();
+    let scanned = strip_score_column(&scanned, original.schema().as_ref());
+
+    // Compute the expected rows with a DataFusion LIKE-based oracle.
+    let ctx = create_datafusion_context();
+    let table = MemTable::try_new(original.schema(), vec![vec![original.clone()]]).unwrap();
+    ctx.register_table("t", Arc::new(table)).unwrap();
+
+    let col_expr = if lower_case {
+        format!("lower(t.{})", column)
+    } else {
+        format!("t.{}", column)
+    };
+    let normalized_query = if lower_case {
+        query.to_lowercase()
+    } else {
+        query.to_string()
+    };
+    let expected_from_where = |where_clause: String| async move {
+        let sql = format!("SELECT * FROM t WHERE {} ORDER BY id", where_clause);
+        let df = ctx.sql(&sql).await.unwrap();
+        let expected_batches = df.collect().await.unwrap();
+        concat_batches(&original.schema(), &expected_batches).unwrap()
+    };
+    let expected = if normalized_query.is_empty() {
+        expected_from_where(filter.unwrap_or("true").to_string()).await
+    } else if phrase_query {
+        // Phrase semantics approximated as a single contiguous substring match.
+        let predicate = format!("{} LIKE '%{}%'", col_expr, normalized_query);
+        let where_clause = if let Some(extra) = filter {
+            format!("{} AND {}", predicate, extra)
+        } else {
+            predicate
+        };
+        expected_from_where(where_clause).await
+    } else {
+        // Term semantics approximated as AND of per-token substring matches.
+        let tokens = collect_tokens(&normalized_query);
+        if tokens.is_empty() {
+            expected_from_where(filter.unwrap_or("true").to_string()).await
+        } else {
+            let predicate = tokens
+                .into_iter()
+                .map(|token| format!("{} LIKE '%{}%'", col_expr, token))
+                .collect::<Vec<_>>()
+                .join(" AND ");
+            let where_clause = if let Some(extra) = filter {
+                format!("{} AND {}", predicate, extra)
+            } else {
+                predicate
+            };
+            expected_from_where(where_clause).await
+        }
+    };
+
+    assert_eq!(&expected, &scanned);
+}
+
+// Split text on non-alphanumeric characters, dropping empty fragments.
+fn collect_tokens(text: &str) -> Vec<&str> {
+    text.split(|c: char| !c.is_alphanumeric())
+        .filter(|word| !word.is_empty())
+        .collect()
+}
+
 /// Test that an exhaustive ANN query gives the same results as brute force
 /// KNN against the original batch.
/// diff --git a/rust/lance/tests/utils/mod.rs b/rust/lance/tests/utils/mod.rs index 9ef9f39b10d..930813ee17c 100644 --- a/rust/lance/tests/utils/mod.rs +++ b/rust/lance/tests/utils/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::collections::HashMap; use std::panic::AssertUnwindSafe; use std::sync::Arc; @@ -11,7 +12,7 @@ use lance::{ dataset::{InsertBuilder, WriteParams}, Dataset, }; -use lance_index::scalar::ScalarIndexParams; +use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams}; use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::ivf::IvfBuildParams; use lance_index::vector::pq::PQBuildParams; @@ -40,6 +41,7 @@ pub enum DeletionState { pub struct DatasetTestCases { original: RecordBatch, index_options: Vec<(String, Vec>)>, + inverted_index_params: HashMap, } impl DatasetTestCases { @@ -47,6 +49,7 @@ impl DatasetTestCases { Self { original, index_options: Vec::new(), + inverted_index_params: HashMap::new(), } } @@ -60,6 +63,19 @@ impl DatasetTestCases { self } + pub fn with_index_types_and_inverted_index_params( + mut self, + column: impl Into, + index_types: impl IntoIterator>, + inverted_params: InvertedIndexParams, + ) -> Self { + let column = column.into(); + self.index_options + .push((column.clone(), index_types.into_iter().collect())); + self.inverted_index_params.insert(column, inverted_params); + self + } + fn generate_index_combinations(&self) -> Vec> { if self.index_options.is_empty() { return vec![vec![]]; @@ -109,12 +125,17 @@ impl DatasetTestCases { ] { let index_combinations = self.generate_index_combinations(); for indices in index_combinations { - let ds = - build_dataset(self.original.clone(), fragmentation, deletion, &indices) - .await; + let ds = build_dataset( + self.original.clone(), + fragmentation, + deletion, + &indices, + &self.inverted_index_params, + ) + .await; let context = format!( - "fragmentation: {:?}, 
deletion: {:?}, index: {:?}", - fragmentation, deletion, indices + "fragmentation: {:?}, deletion: {:?}, index: {:?}, inverted_index_params: {:?}", + fragmentation, deletion, indices, self.inverted_index_params ); // Catch unwind so we can add test context to the panic. AssertUnwindSafe(test_fn(ds, self.original.clone())) @@ -136,6 +157,7 @@ async fn build_dataset( fragmentation: Fragmentation, deletion: DeletionState, indices: &[(&str, IndexType)], + inverted_index_params: &HashMap, ) -> Dataset { let data_to_write = fill_deleted_rows(&original, deletion); @@ -172,10 +194,17 @@ async fn build_dataset( | IndexType::LabelList | IndexType::NGram | IndexType::ZoneMap - | IndexType::Inverted | IndexType::BloomFilter => Box::new(ScalarIndexParams::for_builtin( (*index_type).try_into().unwrap(), )), + IndexType::Inverted => inverted_index_params + .get(*column) + .map(|params| Box::new(params.clone()) as Box) + .unwrap_or_else(|| { + Box::new(ScalarIndexParams::for_builtin( + (*index_type).try_into().unwrap(), + )) + }), IndexType::IvfFlat => { // Use a small number of partitions for testing Box::new(VectorIndexParams::ivf_flat(2, MetricType::L2))