diff --git a/rust/lance-datafusion/src/udf.rs b/rust/lance-datafusion/src/udf.rs index 6c02383b44d..003582006ea 100644 --- a/rust/lance-datafusion/src/udf.rs +++ b/rust/lance-datafusion/src/udf.rs @@ -3,7 +3,7 @@ //! Datafusion user defined functions -use arrow_array::{ArrayRef, BooleanArray, StringArray}; +use arrow_array::{Array, ArrayRef, BooleanArray, StringArray}; use arrow_schema::DataType; use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility}; use datafusion::prelude::SessionContext; @@ -15,20 +15,25 @@ pub fn register_functions(ctx: &SessionContext) { ctx.register_udf(CONTAINS_TOKENS_UDF.clone()); } -/// This method checks whether a string contains another string. It utilizes FTS (Full-Text Search) -/// indexes, but due to the false negative characteristic of FTS, the results may have omissions. -/// For example, "bakin" will not match documents containing "baking." -/// If the query string is a whole word, or if you prioritize better performance, `contains_tokens` -/// is the better choice. Otherwise, you can use the `contains` method to obtain accurate results. +/// This method checks whether a string contains all specified tokens. The tokens are separated by +/// punctuations and white spaces. /// +/// The functionality is equivalent to FTS MatchQuery (with fuzziness disabled, Operator::And, +/// and using the simple tokenizer). If an FTS index exists and suits the query, it will be used to +/// optimize the query. /// /// Usage /// * Use `contains_tokens` in sql. 
/// ```rust,ignore -/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'bakin')" +/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'fox jumps dog')"; /// let mut ds = Dataset::open(&ds_path).await?; -/// let mut builder = ds.sql(&sql); -/// let records = builder.clone().build().await?.into_batch_records().await?; +/// let ctx = SessionContext::new(); +/// ctx.register_table( +/// "table", +/// Arc::new(LanceTableProvider::new(dataset, false, false)), +/// )?; +/// register_functions(&ctx); +/// let df = ctx.sql(sql).await?; /// ``` fn contains_tokens() -> ScalarUDF { let function = Arc::new(make_scalar_function( @@ -44,10 +49,25 @@ fn contains_tokens() -> ScalarUDF { ), )?; - let result = column - .iter() - .enumerate() - .map(|(i, column)| column.map(|value| value.contains(scalar_str.value(i)))); + let tokens: Option> = match scalar_str.len() { + 0 => None, + _ => Some(collect_tokens(scalar_str.value(0))), + }; + + let result = column.iter().map(|text| { + text.map(|text| { + let text_tokens = collect_tokens(text); + if let Some(tokens) = &tokens { + tokens.len() + == tokens + .iter() + .filter(|token| text_tokens.contains(*token)) + .count() + } else { + true + } + }) + }); Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef) }, @@ -63,7 +83,14 @@ fn contains_tokens() -> ScalarUDF { ) } -static CONTAINS_TOKENS_UDF: LazyLock = LazyLock::new(contains_tokens); +/// Split tokens separated by punctuations and white spaces. 
+fn collect_tokens(text: &str) -> Vec<&str> { + text.split(|c: char| !c.is_alphanumeric()) + .filter(|word| !word.is_empty()) + .collect() +} + +pub static CONTAINS_TOKENS_UDF: LazyLock = LazyLock::new(contains_tokens); #[cfg(test)] mod tests { @@ -79,13 +106,19 @@ mod tests { // Prepare arguments let contains_tokens = CONTAINS_TOKENS_UDF.clone(); let text_col = Arc::new(StringArray::from(vec![ - "a cat", - "lovely cat", - "white cat", - "catch up", - "fish", + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat catchup fish", + "cat fish catch", + ])); + let token = Arc::new(StringArray::from(vec![ + " cat catch fish.", + " cat catch fish.", + " cat catch fish.", + " cat catch fish.", + " cat catch fish.", ])); - let token = Arc::new(StringArray::from(vec!["cat", "cat", "cat", "cat", "cat"])); let args = vec![ColumnarValue::Array(text_col), ColumnarValue::Array(token)]; let arg_fields = vec![ @@ -107,7 +140,7 @@ mod tests { let array = array.as_any().downcast_ref::().unwrap(); assert_eq!( array.clone(), - BooleanArray::from(vec![true, true, true, true, false]) + BooleanArray::from(vec![true, true, true, false, true]) ); } else { panic!("Expected an Array but got {:?}", values); diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 12d81585cb3..ba8cb591f29 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -40,6 +40,7 @@ pub mod zonemap; use crate::frag_reuse::FragReuseIndex; pub use inverted::tokenizer::InvertedIndexParams; +use lance_datafusion::udf::CONTAINS_TOKENS_UDF; pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index"; @@ -554,6 +555,43 @@ impl AnyQuery for TextQuery { } } +/// A query that a InvertedIndex can satisfy +#[derive(Debug, Clone, PartialEq)] +pub enum TokenQuery { + /// Retrieve all row ids where the text contains all tokens parsed from given string. The tokens + /// are separated by punctuations and white spaces. 
+ TokensContains(String), +} + +impl AnyQuery for TokenQuery { + fn as_any(&self) -> &dyn Any { + self + } + + fn format(&self, col: &str) -> String { + format!("{}", self.to_expr(col.to_string())) + } + + fn to_expr(&self, col: String) -> Expr { + match self { + Self::TokensContains(substr) => Expr::ScalarFunction(ScalarFunction { + func: Arc::new(CONTAINS_TOKENS_UDF.clone()), + args: vec![ + Expr::Column(Column::new_unqualified(col)), + Expr::Literal(ScalarValue::Utf8(Some(substr.clone())), None), + ], + }), + } + } + + fn dyn_eq(&self, other: &dyn AnyQuery) -> bool { + match other.as_any().downcast_ref::() { + Some(o) => self == o, + None => false, + } + } +} + /// The result of a search operation against a scalar index #[derive(Debug, PartialEq)] pub enum SearchResult { diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index 4c20aefaed8..e7c57f82a56 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -17,6 +17,11 @@ use datafusion_expr::{ Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF, }; +use super::{ + AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, + TextQuery, TokenQuery, +}; +use crate::scalar::inverted::InvertedIndex; use futures::join; use lance_core::{utils::mask::RowIdMask, Error, Result}; use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; @@ -24,10 +29,6 @@ use roaring::RoaringBitmap; use snafu::location; use tracing::instrument; -use super::{ - AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, TextQuery, -}; - const MAX_DEPTH: usize = 500; /// An indexed expression consists of a scalar index query with a post-scan filter @@ -427,11 +428,15 @@ impl ScalarQueryParser for TextQueryParser { #[derive(Debug, Clone)] pub struct FtsQueryParser { index_name: String, + index: InvertedIndex, } impl FtsQueryParser { - pub fn new(index_name: String) -> 
Self { - Self { index_name } + pub fn new(name: String, index: InvertedIndex) -> Self { + Self { + index_name: name, + index, + } } } @@ -478,22 +483,18 @@ impl ScalarQueryParser for FtsQueryParser { } let scalar = maybe_scalar(&args[1], data_type)?; if let ScalarValue::Utf8(Some(scalar_str)) = scalar { - // TODO(https://github.com/lancedb/lance/issues/3855): - // - // Create the contains_tokens UDF if func.name() == "contains_tokens" { - let query = TextQuery::StringContains(scalar_str); - Some(IndexedExpression::index_query( - column.to_string(), - self.index_name.clone(), - Arc::new(query), - )) - } else { - None + let query = TokenQuery::TokensContains(scalar_str); + if self.index.is_query_allowed(&query) { + return Some(IndexedExpression::index_query( + column.to_string(), + self.index_name.clone(), + Arc::new(query), + )); + } } - } else { - None } + None } } diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index bd0dd6062b0..4d4357e5b95 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -13,6 +13,8 @@ use std::{ ops::Range, }; +use crate::metrics::NoOpMetricsCollector; +use crate::prefilter::NoFilter; use arrow::{ array::LargeBinaryBuilder, datatypes::{self, Float32Type, Int32Type, UInt64Type}, @@ -33,10 +35,11 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion_common::DataFusionError; use deepsize::DeepSizeOf; use fst::{Automaton, IntoStreamer, Streamer}; -use futures::{stream, StreamExt, TryStreamExt}; +use futures::{stream, FutureExt, StreamExt, TryStreamExt}; use itertools::Itertools; use lance_arrow::{iter_str_array, RecordBatchExt}; use lance_core::cache::{CacheKey, LanceCache}; +use lance_core::utils::mask::RowIdTreeMap; use lance_core::utils::{ mask::RowIdMask, tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS}, @@ -46,6 +49,7 @@ use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD}; use 
roaring::RoaringBitmap; use snafu::location; use std::sync::LazyLock; +use tantivy::tokenizer::Language; use tracing::{info, instrument}; use super::{ @@ -69,7 +73,7 @@ use super::{ use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams}; use crate::frag_reuse::FragReuseIndex; use crate::scalar::{ - AnyQuery, IndexReader, IndexStore, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, + AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult, TokenQuery, }; use crate::Index; use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys}; @@ -94,6 +98,8 @@ pub static SCORE_FIELD: LazyLock = LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true)); pub static FTS_SCHEMA: LazyLock = LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()]))); +static ROW_ID_SCHEMA: LazyLock = + LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()]))); #[derive(Clone)] pub struct InvertedIndex { @@ -320,6 +326,43 @@ impl Index for InvertedIndex { } } +impl InvertedIndex { + /// Whether the query can use the current index. + pub fn is_query_allowed(&self, query: &TokenQuery) -> bool { + match query { + TokenQuery::TokensContains(_) => { + self.params.base_tokenizer == "simple" + && self.params.max_token_length.is_none() + && self.params.language == Language::English + && !self.params.stem + } + } + } + + /// Search docs match the input text. + async fn do_search(&self, text: &str) -> Result { + let params = FtsSearchParams::new(); + let mut tokenizer = self.tokenizer.clone(); + let tokens = collect_tokens(text, &mut tokenizer, None); + + let (doc_ids, _) = self + .bm25_search( + tokens.into(), + params.into(), + Operator::And, + Arc::new(NoFilter), + Arc::new(NoOpMetricsCollector), + ) + .boxed() + .await?; + + Ok(RecordBatch::try_new( + ROW_ID_SCHEMA.clone(), + vec![Arc::new(UInt64Array::from(doc_ids))], + )?) 
+ } +} + #[async_trait] impl ScalarIndex for InvertedIndex { // return the row ids of the documents that contain the query @@ -329,11 +372,20 @@ impl ScalarIndex for InvertedIndex { query: &dyn AnyQuery, _metrics: &dyn MetricsCollector, ) -> Result { - let query = query.as_any().downcast_ref::().unwrap(); - return Err(Error::invalid_input( - format!("unsupported query {:?} for inverted index", query), - location!(), - )); + let query = query.as_any().downcast_ref::().unwrap(); + + match query { + TokenQuery::TokensContains(text) => { + let records = self.do_search(text).await?; + let row_ids = records + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let row_ids = row_ids.iter().flatten().collect_vec(); + Ok(SearchResult::AtMost(RowIdTreeMap::from_iter(row_ids))) + } + } } fn can_remap(&self) -> bool { diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index fc49813c2af..863a87e4678 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -2080,7 +2080,7 @@ mod tests { }; use arrow_ord::sort::sort_to_indices; use arrow_schema::{ - DataType, Field as ArrowField, Fields as ArrowFields, Schema as ArrowSchema, + DataType, Field as ArrowField, Field, Fields as ArrowFields, Schema as ArrowSchema, }; use lance_arrow::bfloat16::{self, ARROW_EXT_META_KEY, ARROW_EXT_NAME_KEY, BFLOAT16_EXT_NAME}; use lance_core::datatypes::LANCE_STORAGE_CLASS_SCHEMA_META_KEY; @@ -2098,8 +2098,12 @@ mod tests { use lance_table::feature_flags; use lance_table::format::{DataFile, WriterVersion}; + use crate::datafusion::LanceTableProvider; use all_asserts::assert_true; + use datafusion::common::{assert_contains, assert_not_contains}; + use datafusion::prelude::SessionContext; use lance_datafusion::datagen::DatafusionDatagenExt; + use lance_datafusion::udf::register_functions; use lance_testing::datagen::generate_random_array; use pretty_assertions::assert_eq; use rand::seq::SliceRandom; @@ -7269,4 +7273,173 @@ mod tests { 
dataset.validate().await.unwrap(); assert_eq!(dataset.count_rows(None).await.unwrap(), 3); } + + #[tokio::test] + async fn test_sql_contains_tokens() { + let text_col = Arc::new(StringArray::from(vec![ + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat catchup fish", + "cat fish catch", + ])); + + // Prepare dataset + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![Field::new("text", DataType::Utf8, false)]).into(), + vec![text_col.clone()], + ) + .unwrap(); + let schema = batch.schema(); + let stream = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(stream, "memory://test/table", None) + .await + .unwrap(); + + // Test without fts index + let results = execute_sql( + "select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + + assert_results( + results, + &StringArray::from(vec![ + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat fish catch", + ]), + ); + + // Verify plan, should not contain ScalarIndexQuery. 
+ let results = execute_sql( + "explain select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + let plan = format!("{:?}", results); + assert_not_contains!(&plan, "ScalarIndexQuery"); + + // Test with unsuitable fts index + dataset + .create_index( + &["text"], + IndexType::Inverted, + None, + &InvertedIndexParams::default().base_tokenizer("raw".to_string()), + true, + ) + .await + .unwrap(); + + let results = execute_sql( + "select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + + assert_results( + results, + &StringArray::from(vec![ + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat fish catch", + ]), + ); + + // Verify plan, should not contain ScalarIndexQuery because fts index is unsuitable. + let results = execute_sql( + "explain select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + let plan = format!("{:?}", results); + assert_not_contains!(&plan, "ScalarIndexQuery"); + + // Test with suitable fts index + dataset + .create_index( + &["text"], + IndexType::Inverted, + None, + &InvertedIndexParams::default() + .max_token_length(None) + .stem(false), + true, + ) + .await + .unwrap(); + + let results = execute_sql( + "select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + + assert_results( + results, + &StringArray::from(vec![ + "a cat catch a fish", + "a fish catch a cat", + "a white cat catch a big fish", + "cat fish catch", + ]), + ); + + // Verify plan, should contain ScalarIndexQuery. 
+ let results = execute_sql( + "explain select * from foo where contains_tokens(text, 'cat catch fish')", + "foo".to_string(), + Arc::new(dataset.clone()), + ) + .await + .unwrap(); + let plan = format!("{:?}", results); + assert_contains!(&plan, "ScalarIndexQuery"); + } + + async fn execute_sql( + sql: &str, + table: String, + dataset: Arc, + ) -> Result> { + let ctx = SessionContext::new(); + ctx.register_table( + table, + Arc::new(LanceTableProvider::new(dataset, false, false)), + )?; + register_functions(&ctx); + + let df = ctx.sql(sql).await?; + Ok(df + .execute_stream() + .await + .unwrap() + .try_collect::>() + .await?) + } + + fn assert_results(results: Vec, values: &T) { + assert_eq!(results.len(), 1); + let results = results.into_iter().next().unwrap(); + assert_eq!(results.num_columns(), 1); + + assert_eq!( + results.column(0).as_any().downcast_ref::().unwrap(), + values + ) + } } diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 69c775be93d..e41786fa009 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -71,17 +71,17 @@ pub mod prefilter; pub mod scalar; pub mod vector; -use crate::dataset::index::LanceIndexStoreExt; -pub use crate::index::prefilter::{FilterLoader, PreFilter}; -pub use create::CreateIndexBuilder; - use self::append::merge_indices; use self::vector::remap_vector_index; +use crate::dataset::index::LanceIndexStoreExt; use crate::dataset::transaction::{Operation, Transaction}; use crate::index::frag_reuse::{load_frag_reuse_index_details, open_frag_reuse_index}; use crate::index::mem_wal::open_mem_wal_index; +pub use crate::index::prefilter::{FilterLoader, PreFilter}; use crate::session::index_caches::{FragReuseIndexKey, IndexMetadataKey}; use crate::{dataset::Dataset, Error, Result}; +pub use create::CreateIndexBuilder; +use lance_index::scalar::inverted::InvertedIndex; // Cache keys for different index types #[derive(Debug, Clone)] @@ -1374,8 +1374,19 @@ impl DatasetIndexInternalExt for Dataset { 
Box::new(TextQueryParser::new(index.name.clone(), true)) as Box } - ScalarIndexType::Inverted => Box::new(FtsQueryParser::new(index.name.clone())) - as Box, + ScalarIndexType::Inverted => { + let fts_index = + lance_index::scalar::expression::ScalarIndexLoader::load_index( + self, + &field.name, + &index.name, + &NoOpMetricsCollector, + ) + .await?; + let fts_index = fts_index.as_any().downcast_ref::().unwrap(); + Box::new(FtsQueryParser::new(index.name.clone(), fts_index.clone())) + as Box + } _ => continue, }, _ => {