From 98f7081c7f12b64043728a95841c08f764e86bc1 Mon Sep 17 00:00:00 2001 From: lijinglun Date: Tue, 12 Aug 2025 10:19:31 +0800 Subject: [PATCH 1/5] feat: pushdown contains_tokens using inverted index --- rust/lance-index/src/scalar/expression.rs | 3 - rust/lance-index/src/scalar/inverted/index.rs | 52 ++++++-- rust/lance/src/dataset.rs | 111 +++++++++++++++++- 3 files changed, 155 insertions(+), 11 deletions(-) diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 4c20aefaed8..a2abf0b49ff 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -478,9 +478,6 @@ impl ScalarQueryParser for FtsQueryParser { } let scalar = maybe_scalar(&args[1], data_type)?; if let ScalarValue::Utf8(Some(scalar_str)) = scalar { - // TODO(https://github.com/lancedb/lance/issues/3855): - // - // Create the contains_tokens UDF if func.name() == "contains_tokens" { let query = TextQuery::StringContains(scalar_str); Some(IndexedExpression::index_query( diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index bd0dd6062b0..80ccb52cdee 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -13,6 +13,8 @@ use std::{ ops::Range, }; +use crate::metrics::NoOpMetricsCollector; +use crate::prefilter::NoFilter; use arrow::{ array::LargeBinaryBuilder, datatypes::{self, Float32Type, Int32Type, UInt64Type}, @@ -33,10 +35,11 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion_common::DataFusionError; use deepsize::DeepSizeOf; use fst::{Automaton, IntoStreamer, Streamer}; -use futures::{stream, StreamExt, TryStreamExt}; +use futures::{stream, FutureExt, StreamExt, TryStreamExt}; use itertools::Itertools; use lance_arrow::{iter_str_array, RecordBatchExt}; use lance_core::cache::{CacheKey, LanceCache}; +use lance_core::utils::mask::RowIdTreeMap; use 
lance_core::utils::{ mask::RowIdMask, tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS}, @@ -69,7 +72,7 @@ use super::{ use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams}; use crate::frag_reuse::FragReuseIndex; use crate::scalar::{ - AnyQuery, IndexReader, IndexStore, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, + AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult, TextQuery, }; use crate::Index; use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys}; @@ -94,6 +97,8 @@ pub static SCORE_FIELD: LazyLock = LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true)); pub static FTS_SCHEMA: LazyLock = LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()]))); +static ROW_ID_SCHEMA: LazyLock = + LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()]))); #[derive(Clone)] pub struct InvertedIndex { @@ -320,6 +325,31 @@ impl Index for InvertedIndex { } } +impl InvertedIndex { + /// Search docs match the input text. + async fn do_search(&self, text: &str) -> Result { + let params = FtsSearchParams::new(); + let mut tokenizer = self.tokenizer.clone(); + let tokens = collect_tokens(text, &mut tokenizer, None); + + let (doc_ids, _) = self + .bm25_search( + tokens.into(), + params.into(), + Operator::Or, + Arc::new(NoFilter), + Arc::new(NoOpMetricsCollector), + ) + .boxed() + .await?; + + Ok(RecordBatch::try_new( + ROW_ID_SCHEMA.clone(), + vec![Arc::new(UInt64Array::from(doc_ids))], + )?) 
+ } +} + #[async_trait] impl ScalarIndex for InvertedIndex { // return the row ids of the documents that contain the query @@ -329,11 +359,19 @@ impl ScalarIndex for InvertedIndex { query: &dyn AnyQuery, _metrics: &dyn MetricsCollector, ) -> Result { - let query = query.as_any().downcast_ref::().unwrap(); - return Err(Error::invalid_input( - format!("unsupported query {:?} for inverted index", query), - location!(), - )); + let query = query.as_any().downcast_ref::().unwrap(); + match query { + TextQuery::StringContains(text) => { + let records = self.do_search(text).await?; + let row_ids = records + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let row_ids = row_ids.iter().flatten().collect_vec(); + Ok(SearchResult::AtMost(RowIdTreeMap::from_iter(row_ids))) + } + } } fn can_remap(&self) -> bool { diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index fc49813c2af..973de92e114 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -2080,7 +2080,7 @@ mod tests { }; use arrow_ord::sort::sort_to_indices; use arrow_schema::{ - DataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema, + DataType, Field as ArrowField, Field, Fields as ArrowFields, Schema as ArrowSchema, }; use lance_arrow::bfloat16::{self, ARROW_EXT_META_KEY, ARROW_EXT_NAME_KEY, BFLOAT16_EXT_NAME}; use lance_core::datatypes::LANCE_STORAGE_CLASS_SCHEMA_META_KEY; @@ -2098,8 +2098,11 @@ mod tests { use lance_table::feature_flags; use lance_table::format::{DataFile, WriterVersion}; + use crate::datafusion::LanceTableProvider; use all_asserts::assert_true; + use datafusion::prelude::SessionContext; use lance_datafusion::datagen::DatafusionDatagenExt; + use lance_datafusion::udf::register_functions; use lance_testing::datagen::generate_random_array; use pretty_assertions::assert_eq; use rand::seq::SliceRandom; @@ -7269,4 +7272,110 @@ mod tests { dataset.validate().await.unwrap(); assert_eq!(dataset.count_rows(None).await.unwrap(), 3); } + + 
#[tokio::test] + async fn test_sql_contains_tokens() { + let text_col = Arc::new(StringArray::from(vec![ + "a cat", + "lovely cat", + "white cat", + "catch up", + "fish", + ])); + + // Prepare dataset + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![Field::new("text", DataType::Utf8, false)]).into(), + vec![text_col.clone()], + ) + .unwrap(); + let schema = batch.schema(); + let stream = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(stream, "memory://test/table", None) + .await + .unwrap(); + + // Test without fts index + let results = execute_sql( + "select * from foo where contains_tokens(text, 'cat')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + + assert_results( + results, + &StringArray::from(vec!["a cat", "lovely cat", "white cat", "catch up"]), + ); + + // Test with fts index + dataset + .create_index( + &["text"], + IndexType::Inverted, + None, + &ScalarIndexParams::default(), + false, + ) + .await + .unwrap(); + + let results = execute_sql( + "select * from foo where contains_tokens(text, 'cat')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + + // FTS index introduces false negatives, so "catch up" will miss. 
+ assert_results( + results, + &StringArray::from(vec!["a cat", "lovely cat", "white cat"]), + ); + + // Test multiple tokens + let results = execute_sql( + "select * from foo where contains_tokens(text, 'lovely cat')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + + assert_results(results, &StringArray::from(vec!["lovely cat"])); + } + + async fn execute_sql( + sql: &str, + table: String, + dataset: Arc, + ) -> Result> { + let ctx = SessionContext::new(); + ctx.register_table( + table, + Arc::new(LanceTableProvider::new(dataset, false, false)), + )?; + register_functions(&ctx); + + let df = ctx.sql(sql).await?; + Ok(df + .execute_stream() + .await + .unwrap() + .try_collect::>() + .await?) + } + + fn assert_results(results: Vec, values: &T) { + assert_eq!(results.len(), 1); + let results = results.into_iter().next().unwrap(); + assert_eq!(results.num_columns(), 1); + + assert_eq!( + results.column(0).as_any().downcast_ref::().unwrap(), + values + ) + } } From 0571e12193e29767bb1e0cc20d057c134a7b577b Mon Sep 17 00:00:00 2001 From: lijinglun Date: Thu, 21 Aug 2025 14:51:09 +0800 Subject: [PATCH 2/5] update udf contains_tokens syntax --- rust/lance-datafusion/src/udf.rs | 60 ++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/rust/lance-datafusion/src/udf.rs b/rust/lance-datafusion/src/udf.rs index 6c02383b44d..a7bb9197e58 100644 --- a/rust/lance-datafusion/src/udf.rs +++ b/rust/lance-datafusion/src/udf.rs @@ -3,7 +3,7 @@ //! Datafusion user defined functions -use arrow_array::{ArrayRef, BooleanArray, StringArray}; +use arrow_array::{Array, ArrayRef, BooleanArray, StringArray}; use arrow_schema::DataType; use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility}; use datafusion::prelude::SessionContext; @@ -15,20 +15,25 @@ pub fn register_functions(ctx: &SessionContext) { ctx.register_udf(CONTAINS_TOKENS_UDF.clone()); } -/// This method checks whether a string contains another string. 
It utilizes FTS (Full-Text Search) -/// indexes, but due to the false negative characteristic of FTS, the results may have omissions. -/// For example, "bakin" will not match documents containing "baking." -/// If the query string is a whole word, or if you prioritize better performance, `contains_tokens` -/// is the better choice. Otherwise, you can use the `contains` method to obtain accurate results. +/// This method checks whether a string contains all specified tokens. The tokens are separated by +/// punctuations and white spaces. /// +/// The functionality is equivalent to FTS MatchQuery (with fuzziness disabled, Operator::And, +/// and using the simple tokenizer). If FTS index exists and suits the query, it will be used to +/// optimize the query. /// /// Usage /// * Use `contains_tokens` in sql. /// ```rust,ignore -/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'bakin')" +/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'fox jumps dog')"; /// let mut ds = Dataset::open(&ds_path).await?; -/// let mut builder = ds.sql(&sql); -/// let records = builder.clone().build().await?.into_batch_records().await?; +/// let ctx = SessionContext::new(); +/// ctx.register_table( +/// "table", +/// Arc::new(LanceTableProvider::new(dataset, false, false)), +/// )?; +/// register_functions(&ctx); +/// let df = ctx.sql(sql).await?; /// ``` fn contains_tokens() -> ScalarUDF { let function = Arc::new(make_scalar_function( @@ -44,10 +49,22 @@ ), )?; + // Split tokens based on punctuations and white spaces. 
+ let tokens: Option> = match scalar_str.len() { + 0 => None, + _ => Some(collect_tokens(scalar_str.value(0))) + }; + let result = column .iter() - .enumerate() - .map(|(i, column)| column.map(|value| value.contains(scalar_str.value(i)))); + .map(|text| text.map(|text| { + let text_tokens = collect_tokens(text); + if let Some(tokens) = &tokens { + tokens.len() == tokens.iter().filter(|token| text_tokens.contains(*token)).count() + } else { + true + } + })); Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef) }, @@ -63,6 +80,13 @@ fn contains_tokens() -> ScalarUDF { ) } +/// Split tokens separated by punctuations and white spaces. +fn collect_tokens(text: &str) -> Vec<&str> { + text.split(|c: char| !c.is_alphanumeric()) + .filter(|word|!word.is_empty()) + .collect() +} + static CONTAINS_TOKENS_UDF: LazyLock = LazyLock::new(contains_tokens); #[cfg(test)] @@ -79,13 +103,13 @@ mod tests { // Prepare arguments let contains_tokens = CONTAINS_TOKENS_UDF.clone(); let text_col = Arc::new(StringArray::from(vec![ - "a cat", - "lovely cat", - "white cat", - "catch up", - "fish", + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat catchup fish", + "cat fish catch", ])); - let token = Arc::new(StringArray::from(vec!["cat", "cat", "cat", "cat", "cat"])); + let token = Arc::new(StringArray::from(vec![" cat catch fish.", " cat catch fish.", " cat catch fish.", " cat catch fish.", " cat catch fish."])); let args = vec![ColumnarValue::Array(text_col), ColumnarValue::Array(token)]; let arg_fields = vec![ @@ -107,7 +131,7 @@ mod tests { let array = array.as_any().downcast_ref::().unwrap(); assert_eq!( array.clone(), - BooleanArray::from(vec![true, true, true, true, false]) + BooleanArray::from(vec![true, true, true, false, true]) ); } else { panic!("Expected an Array but got {:?}", values); From 0fc8a1d86a1b095a083323a9f630f412eea7c4c4 Mon Sep 17 00:00:00 2001 From: lijinglun Date: Thu, 21 Aug 2025 16:07:16 +0800 Subject: [PATCH 3/5] push 
contains_tokens to inverted index --- rust/lance-datafusion/src/udf.rs | 2 +- rust/lance-index/src/scalar.rs | 38 +++++++++++++++++++ rust/lance-index/src/scalar/expression.rs | 6 +-- rust/lance-index/src/scalar/inverted/index.rs | 23 ++++++++--- rust/lance/src/dataset.rs | 38 +++++++------------ 5 files changed, 73 insertions(+), 34 deletions(-) diff --git a/rust/lance-datafusion/src/udf.rs b/rust/lance-datafusion/src/udf.rs index a7bb9197e58..a3065039b16 100644 --- a/rust/lance-datafusion/src/udf.rs +++ b/rust/lance-datafusion/src/udf.rs @@ -87,7 +87,7 @@ fn collect_tokens(text: &str) -> Vec<&str> { .collect() } -static CONTAINS_TOKENS_UDF: LazyLock = LazyLock::new(contains_tokens); +pub static CONTAINS_TOKENS_UDF: LazyLock = LazyLock::new(contains_tokens); #[cfg(test)] mod tests { diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 12d81585cb3..ba8cb591f29 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -40,6 +40,7 @@ pub mod zonemap; use crate::frag_reuse::FragReuseIndex; pub use inverted::tokenizer::InvertedIndexParams; +use lance_datafusion::udf::CONTAINS_TOKENS_UDF; pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index"; @@ -554,6 +555,43 @@ impl AnyQuery for TextQuery { } } +/// A query that an InvertedIndex can satisfy +#[derive(Debug, Clone, PartialEq)] +pub enum TokenQuery { +    /// Retrieve all row ids where the text contains all tokens parsed from given string. The tokens +    /// are separated by punctuations and white spaces. 
+ TokensContains(String), +} + +impl AnyQuery for TokenQuery { + fn as_any(&self) -> &dyn Any { + self + } + + fn format(&self, col: &str) -> String { + format!("{}", self.to_expr(col.to_string())) + } + + fn to_expr(&self, col: String) -> Expr { + match self { + Self::TokensContains(substr) => Expr::ScalarFunction(ScalarFunction { + func: Arc::new(CONTAINS_TOKENS_UDF.clone()), + args: vec![ + Expr::Column(Column::new_unqualified(col)), + Expr::Literal(ScalarValue::Utf8(Some(substr.clone())), None), + ], + }), + } + } + + fn dyn_eq(&self, other: &dyn AnyQuery) -> bool { + match other.as_any().downcast_ref::() { + Some(o) => self == o, + None => false, + } + } +} + /// The result of a search operation against a scalar index #[derive(Debug, PartialEq)] pub enum SearchResult { diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index a2abf0b49ff..a30487a3834 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -24,9 +24,7 @@ use roaring::RoaringBitmap; use snafu::location; use tracing::instrument; -use super::{ - AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, TextQuery, -}; +use super::{AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, TextQuery, TokenQuery}; const MAX_DEPTH: usize = 500; @@ -479,7 +477,7 @@ impl ScalarQueryParser for FtsQueryParser { let scalar = maybe_scalar(&args[1], data_type)?; if let ScalarValue::Utf8(Some(scalar_str)) = scalar { if func.name() == "contains_tokens" { - let query = TextQuery::StringContains(scalar_str); + let query = TokenQuery::TokensContains(scalar_str); Some(IndexedExpression::index_query( column.to_string(), self.index_name.clone(), diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 80ccb52cdee..8afef522664 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ 
b/rust/lance-index/src/scalar/inverted/index.rs @@ -71,9 +71,7 @@ use super::{ }; use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams}; use crate::frag_reuse::FragReuseIndex; -use crate::scalar::{ - AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult, TextQuery, -}; +use crate::scalar::{AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult, TextQuery, TokenQuery}; use crate::Index; use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys}; @@ -326,6 +324,17 @@ impl Index for InvertedIndex { } impl InvertedIndex { + /// Whether the query can use the current index. + fn is_query_allowed(&self, query: &TokenQuery) -> bool { + match query { + TokenQuery::TokensContains(_) => { + self.params.base_tokenizer == "simple" && + self.params.max_token_length.is_none() + } + } + } + + /// Search docs match the input text. async fn do_search(&self, text: &str) -> Result { let params = FtsSearchParams::new(); @@ -359,9 +368,13 @@ impl ScalarIndex for InvertedIndex { query: &dyn AnyQuery, _metrics: &dyn MetricsCollector, ) -> Result { - let query = query.as_any().downcast_ref::().unwrap(); + let query = query.as_any().downcast_ref::().unwrap(); + if !self.is_query_allowed(query) { + return Ok(SearchResult::AtLeast(RowIdTreeMap::from_iter::>(vec![]))) + } + match query { - TextQuery::StringContains(text) => { + TokenQuery::TokensContains(text) => { let records = self.do_search(text).await?; let row_ids = records .column(0) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 973de92e114..d7e6d9d4614 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -7276,11 +7276,11 @@ mod tests { #[tokio::test] async fn test_sql_contains_tokens() { let text_col = Arc::new(StringArray::from(vec![ - "a cat", - "lovely cat", - "white cat", - "catch up", - "fish", + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat catchup fish", + "cat fish 
catch", ])); // Prepare dataset @@ -7297,7 +7297,7 @@ mod tests { // Test without fts index let results = execute_sql( - "select * from foo where contains_tokens(text, 'cat')", + "select * from foo where contains_tokens(text, 'cat catch fish')", "foo".to_string(), Arc::new(dataset.clone()), ) @@ -7306,7 +7306,8 @@ mod tests { assert_results( results, - &StringArray::from(vec!["a cat", "lovely cat", "white cat", "catch up"]), + &StringArray::from(vec!["a cat catch a fish", "a fish catch a cat", + "a white cat catch a big fish", "cat fish catch"]), ); // Test with fts index @@ -7315,36 +7316,25 @@ mod tests { &["text"], IndexType::Inverted, None, - &ScalarIndexParams::default(), + &InvertedIndexParams::default().max_token_length(None), false, ) .await .unwrap(); let results = execute_sql( - "select * from foo where contains_tokens(text, 'cat')", + "select * from foo where contains_tokens(text, 'cat catch fish')", "foo".to_string(), Arc::new(dataset.clone()), ) - .await - .unwrap(); + .await + .unwrap(); - // FTS index introduces false negatives, so "catch up" will miss. 
assert_results( results, - &StringArray::from(vec!["a cat", "lovely cat", "white cat"]), + &StringArray::from(vec!["a cat catch a fish", "a fish catch a cat", + "a white cat catch a big fish", "cat fish catch"]), ); - - // Test multiple tokens - let results = execute_sql( - "select * from foo where contains_tokens(text, 'lovely cat')", - "foo".to_string(), - Arc::new(dataset.clone()), - ) - .await - .unwrap(); - - assert_results(results, &StringArray::from(vec!["lovely cat"])); } async fn execute_sql( From 80a2c6688963c6f20e9c526eeec3d4fa70c245a6 Mon Sep 17 00:00:00 2001 From: lijinglun Date: Thu, 21 Aug 2025 16:07:35 +0800 Subject: [PATCH 4/5] fmt and clippy --- rust/lance-datafusion/src/udf.rs | 27 ++++++++++++------- rust/lance-index/src/scalar/expression.rs | 5 +++- rust/lance-index/src/scalar/inverted/index.rs | 18 ++++++++----- rust/lance/src/dataset.rs | 20 +++++++++----- 4 files changed, 48 insertions(+), 22 deletions(-) diff --git a/rust/lance-datafusion/src/udf.rs b/rust/lance-datafusion/src/udf.rs index a3065039b16..003582006ea 100644 --- a/rust/lance-datafusion/src/udf.rs +++ b/rust/lance-datafusion/src/udf.rs @@ -49,22 +49,25 @@ fn contains_tokens() -> ScalarUDF { ), )?; - // Split tokens based on punctuations and white spaces. 
let tokens: Option> = match scalar_str.len() { 0 => None, - _ => Some(collect_tokens(scalar_str.value(0))) + _ => Some(collect_tokens(scalar_str.value(0))), }; - let result = column - .iter() - .map(|text| text.map(|text| { + let result = column.iter().map(|text| { + text.map(|text| { let text_tokens = collect_tokens(text); if let Some(tokens) = &tokens { - tokens.len() == tokens.iter().filter(|token| text_tokens.contains(*token)).count() + tokens.len() + == tokens + .iter() + .filter(|token| text_tokens.contains(*token)) + .count() } else { true } - })); + }) + }); Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef) }, @@ -83,7 +86,7 @@ fn contains_tokens() -> ScalarUDF { /// Split tokens separated by punctuations and white spaces. fn collect_tokens(text: &str) -> Vec<&str> { text.split(|c: char| !c.is_alphanumeric()) - .filter(|word|!word.is_empty()) + .filter(|word| !word.is_empty()) .collect() } @@ -109,7 +112,13 @@ mod tests { "cat catchup fish", "cat fish catch", ])); - let token = Arc::new(StringArray::from(vec![" cat catch fish.", " cat catch fish.", " cat catch fish.", " cat catch fish.", " cat catch fish."])); + let token = Arc::new(StringArray::from(vec![ + " cat catch fish.", + " cat catch fish.", + " cat catch fish.", + " cat catch fish.", + " cat catch fish.", + ])); let args = vec![ColumnarValue::Array(text_col), ColumnarValue::Array(token)]; let arg_fields = vec![ diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index a30487a3834..74db917237a 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -24,7 +24,10 @@ use roaring::RoaringBitmap; use snafu::location; use tracing::instrument; -use super::{AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, TextQuery, TokenQuery}; +use super::{ + AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, + TextQuery, TokenQuery, +}; const 
MAX_DEPTH: usize = 500; diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 8afef522664..a70ca95946c 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -49,6 +49,7 @@ use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD}; use roaring::RoaringBitmap; use snafu::location; use std::sync::LazyLock; +use tantivy::tokenizer::Language; use tracing::{info, instrument}; use super::{ @@ -71,7 +72,9 @@ use super::{ }; use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams}; use crate::frag_reuse::FragReuseIndex; -use crate::scalar::{AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult, TextQuery, TokenQuery}; +use crate::scalar::{ + AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult, TokenQuery, +}; use crate::Index; use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys}; @@ -328,13 +331,14 @@ impl InvertedIndex { fn is_query_allowed(&self, query: &TokenQuery) -> bool { match query { TokenQuery::TokensContains(_) => { - self.params.base_tokenizer == "simple" && - self.params.max_token_length.is_none() + self.params.base_tokenizer == "simple" + && self.params.max_token_length.is_none() + && self.params.language == Language::English + && !self.params.stem } } } - /// Search docs match the input text. 
async fn do_search(&self, text: &str) -> Result { let params = FtsSearchParams::new(); @@ -345,7 +349,7 @@ impl InvertedIndex { .bm25_search( tokens.into(), params.into(), - Operator::Or, + Operator::And, Arc::new(NoFilter), Arc::new(NoOpMetricsCollector), ) @@ -370,7 +374,9 @@ impl ScalarIndex for InvertedIndex { ) -> Result { let query = query.as_any().downcast_ref::().unwrap(); if !self.is_query_allowed(query) { - return Ok(SearchResult::AtLeast(RowIdTreeMap::from_iter::>(vec![]))) + return Ok(SearchResult::AtLeast(RowIdTreeMap::from_iter::>( + vec![], + ))); } match query { diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index d7e6d9d4614..005e8bd0783 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -7306,8 +7306,12 @@ mod tests { assert_results( results, - &StringArray::from(vec!["a cat catch a fish", "a fish catch a cat", - "a white cat catch a big fish", "cat fish catch"]), + &StringArray::from(vec![ + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat fish catch", + ]), ); // Test with fts index @@ -7327,13 +7331,17 @@ mod tests { "foo".to_string(), Arc::new(dataset.clone()), ) - .await - .unwrap(); + .await + .unwrap(); assert_results( results, - &StringArray::from(vec!["a cat catch a fish", "a fish catch a cat", - "a white cat catch a big fish", "cat fish catch"]), + &StringArray::from(vec![ + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat fish catch", + ]), ); } From f066ce6fbb1d328f31c18a9a214e11f47e26e422 Mon Sep 17 00:00:00 2001 From: lijinglun Date: Fri, 22 Aug 2025 19:49:13 +0800 Subject: [PATCH 5/5] follow comments --- rust/lance-index/src/scalar/expression.rs | 35 ++++----- rust/lance-index/src/scalar/inverted/index.rs | 7 +- rust/lance/src/dataset.rs | 72 ++++++++++++++++++- rust/lance/src/index.rs | 23 ++++-- 4 files changed, 106 insertions(+), 31 deletions(-) diff --git a/rust/lance-index/src/scalar/expression.rs 
b/rust/lance-index/src/scalar/expression.rs index 74db917237a..e7c57f82a56 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -17,6 +17,11 @@ use datafusion_expr::{ Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF, }; +use super::{ + AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, + TextQuery, TokenQuery, +}; +use crate::scalar::inverted::InvertedIndex; use futures::join; use lance_core::{utils::mask::RowIdMask, Error, Result}; use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; @@ -24,11 +29,6 @@ use roaring::RoaringBitmap; use snafu::location; use tracing::instrument; -use super::{ - AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, - TextQuery, TokenQuery, -}; - const MAX_DEPTH: usize = 500; /// An indexed expression consists of a scalar index query with a post-scan filter @@ -428,11 +428,15 @@ impl ScalarQueryParser for TextQueryParser { #[derive(Debug, Clone)] pub struct FtsQueryParser { index_name: String, + index: InvertedIndex, } impl FtsQueryParser { - pub fn new(index_name: String) -> Self { - Self { index_name } + pub fn new(name: String, index: InvertedIndex) -> Self { + Self { + index_name: name, + index, + } } } @@ -481,17 +485,16 @@ impl ScalarQueryParser for FtsQueryParser { if let ScalarValue::Utf8(Some(scalar_str)) = scalar { if func.name() == "contains_tokens" { let query = TokenQuery::TokensContains(scalar_str); - Some(IndexedExpression::index_query( - column.to_string(), - self.index_name.clone(), - Arc::new(query), - )) - } else { - None + if self.index.is_query_allowed(&query) { + return Some(IndexedExpression::index_query( + column.to_string(), + self.index_name.clone(), + Arc::new(query), + )); + } } - } else { - None } + None } } diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index a70ca95946c..4d4357e5b95 100644 --- 
a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -328,7 +328,7 @@ impl Index for InvertedIndex { impl InvertedIndex { /// Whether the query can use the current index. - fn is_query_allowed(&self, query: &TokenQuery) -> bool { + pub fn is_query_allowed(&self, query: &TokenQuery) -> bool { match query { TokenQuery::TokensContains(_) => { self.params.base_tokenizer == "simple" @@ -373,11 +373,6 @@ impl ScalarIndex for InvertedIndex { _metrics: &dyn MetricsCollector, ) -> Result { let query = query.as_any().downcast_ref::().unwrap(); - if !self.is_query_allowed(query) { - return Ok(SearchResult::AtLeast(RowIdTreeMap::from_iter::>( - vec![], - ))); - } match query { TokenQuery::TokensContains(text) => { diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 005e8bd0783..863a87e4678 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -2100,6 +2100,7 @@ mod tests { use crate::datafusion::LanceTableProvider; use all_asserts::assert_true; + use datafusion::common::{assert_contains, assert_not_contains}; use datafusion::prelude::SessionContext; use lance_datafusion::datagen::DatafusionDatagenExt; use lance_datafusion::udf::register_functions; @@ -7314,14 +7315,25 @@ mod tests { ]), ); - // Test with fts index + // Verify plan, should not contain ScalarIndexQuery. 
+ let results = execute_sql( + "explain select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + let plan = format!("{:?}", results); + assert_not_contains!(&plan, "ScalarIndexQuery"); + + // Test with unsuitable fts index dataset .create_index( &["text"], IndexType::Inverted, None, - &InvertedIndexParams::default().max_token_length(None), - false, + &InvertedIndexParams::default().base_tokenizer("raw".to_string()), + true, ) .await .unwrap(); @@ -7343,6 +7355,60 @@ "cat fish catch", ]), ); + + // Verify plan, should not contain ScalarIndexQuery because fts index is unsuitable. + let results = execute_sql( + "explain select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + let plan = format!("{:?}", results); + assert_not_contains!(&plan, "ScalarIndexQuery"); + + // Test with suitable fts index + dataset + .create_index( + &["text"], + IndexType::Inverted, + None, + &InvertedIndexParams::default() + .max_token_length(None) + .stem(false), + true, + ) + .await + .unwrap(); + + let results = execute_sql( + "select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + + assert_results( + results, + &StringArray::from(vec![ + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat fish catch", + ]), + ); + + // Verify plan, should contain ScalarIndexQuery. 
+ let results = execute_sql( + "explain select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + let plan = format!("{:?}", results); + assert_contains!(&plan, "ScalarIndexQuery"); } async fn execute_sql( diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 69c775be93d..e41786fa009 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -71,17 +71,17 @@ pub mod prefilter; pub mod scalar; pub mod vector; -use crate::dataset::index::LanceIndexStoreExt; -pub use crate::index::prefilter::{FilterLoader, PreFilter}; -pub use create::CreateIndexBuilder; - use self::append::merge_indices; use self::vector::remap_vector_index; +use crate::dataset::index::LanceIndexStoreExt; use crate::dataset::transaction::{Operation, Transaction}; use crate::index::frag_reuse::{load_frag_reuse_index_details, open_frag_reuse_index}; use crate::index::mem_wal::open_mem_wal_index; +pub use crate::index::prefilter::{FilterLoader, PreFilter}; use crate::session::index_caches::{FragReuseIndexKey, IndexMetadataKey}; use crate::{dataset::Dataset, Error, Result}; +pub use create::CreateIndexBuilder; +use lance_index::scalar::inverted::InvertedIndex; // Cache keys for different index types #[derive(Debug, Clone)] @@ -1374,8 +1374,19 @@ impl DatasetIndexInternalExt for Dataset { Box::new(TextQueryParser::new(index.name.clone(), true)) as Box } - ScalarIndexType::Inverted => Box::new(FtsQueryParser::new(index.name.clone())) - as Box, + ScalarIndexType::Inverted => { + let fts_index = + lance_index::scalar::expression::ScalarIndexLoader::load_index( + self, + &field.name, + &index.name, + &NoOpMetricsCollector, + ) + .await?; + let fts_index = fts_index.as_any().downcast_ref::().unwrap(); + Box::new(FtsQueryParser::new(index.name.clone(), fts_index.clone())) + as Box + } _ => continue, }, _ => {