Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 54 additions & 21 deletions rust/lance-datafusion/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

//! Datafusion user defined functions

use arrow_array::{ArrayRef, BooleanArray, StringArray};
use arrow_array::{Array, ArrayRef, BooleanArray, StringArray};
use arrow_schema::DataType;
use datafusion::logical_expr::{create_udf, ScalarUDF, Volatility};
use datafusion::prelude::SessionContext;
Expand All @@ -15,20 +15,25 @@ pub fn register_functions(ctx: &SessionContext) {
ctx.register_udf(CONTAINS_TOKENS_UDF.clone());
}

/// This method checks whether a string contains another string. It utilizes FTS (Full-Text Search)
/// indexes, but due to the false negative characteristic of FTS, the results may have omissions.
/// For example, "bakin" will not match documents containing "baking."
/// If the query string is a whole word, or if you prioritize better performance, `contains_tokens`
/// is the better choice. Otherwise, you can use the `contains` method to obtain accurate results.
/// This method checks whether a string contains all specified tokens. The tokens are separated by
/// punctuations and white spaces.
///
/// The functionality is equivalent to FTS MatchQuery (with fuzziness disabled, Operator::And,
/// and using the simple tokenizer). If an FTS index exists and suits the query, it will be used to
/// optimize the query.
///
/// Usage
/// * Use `contains_tokens` in sql.
/// ```rust,ignore
/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'bakin')"
/// let sql = "SELECT * FROM table WHERE contains_tokens(text_col, 'fox jumps dog')";
/// let mut ds = Dataset::open(&ds_path).await?;
/// let mut builder = ds.sql(&sql);
/// let records = builder.clone().build().await?.into_batch_records().await?;
/// let ctx = SessionContext::new();
/// ctx.register_table(
/// "table",
/// Arc::new(LanceTableProvider::new(dataset, false, false)),
/// )?;
/// register_functions(&ctx);
/// let df = ctx.sql(sql).await?;
/// ```
fn contains_tokens() -> ScalarUDF {
let function = Arc::new(make_scalar_function(
Expand All @@ -44,10 +49,25 @@ fn contains_tokens() -> ScalarUDF {
),
)?;

let result = column
.iter()
.enumerate()
.map(|(i, column)| column.map(|value| value.contains(scalar_str.value(i))));
let tokens: Option<Vec<&str>> = match scalar_str.len() {
0 => None,
_ => Some(collect_tokens(scalar_str.value(0))),
};

let result = column.iter().map(|text| {
text.map(|text| {
let text_tokens = collect_tokens(text);
if let Some(tokens) = &tokens {
tokens.len()
== tokens
.iter()
.filter(|token| text_tokens.contains(*token))
.count()
} else {
true
}
})
});

Ok(Arc::new(BooleanArray::from_iter(result)) as ArrayRef)
},
Expand All @@ -63,7 +83,14 @@ fn contains_tokens() -> ScalarUDF {
)
}

static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);
/// Split `text` into tokens, treating every non-alphanumeric character
/// (punctuation, whitespace, etc.) as a separator; empty fragments are dropped.
fn collect_tokens(text: &str) -> Vec<&str> {
    let is_separator = |c: char| !c.is_alphanumeric();
    text.split(is_separator)
        .filter(|token| !token.is_empty())
        .collect()
}

/// The `contains_tokens` scalar UDF, constructed lazily once and shared by all
/// registrations (see `register_functions`).
pub static CONTAINS_TOKENS_UDF: LazyLock<ScalarUDF> = LazyLock::new(contains_tokens);

#[cfg(test)]
mod tests {
Expand All @@ -79,13 +106,19 @@ mod tests {
// Prepare arguments
let contains_tokens = CONTAINS_TOKENS_UDF.clone();
let text_col = Arc::new(StringArray::from(vec![
"a cat",
"lovely cat",
"white cat",
"catch up",
"fish",
"a cat catch a fish",
"a fish catch a cat",
"a white cat catch a big fish",
"cat catchup fish",
"cat fish catch",
]));
let token = Arc::new(StringArray::from(vec![
" cat catch fish.",
" cat catch fish.",
" cat catch fish.",
" cat catch fish.",
" cat catch fish.",
]));
let token = Arc::new(StringArray::from(vec!["cat", "cat", "cat", "cat", "cat"]));

let args = vec![ColumnarValue::Array(text_col), ColumnarValue::Array(token)];
let arg_fields = vec![
Expand All @@ -107,7 +140,7 @@ mod tests {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
assert_eq!(
array.clone(),
BooleanArray::from(vec![true, true, true, true, false])
BooleanArray::from(vec![true, true, true, false, true])
);
} else {
panic!("Expected an Array but got {:?}", values);
Expand Down
38 changes: 38 additions & 0 deletions rust/lance-index/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub mod zonemap;

use crate::frag_reuse::FragReuseIndex;
pub use inverted::tokenizer::InvertedIndexParams;
use lance_datafusion::udf::CONTAINS_TOKENS_UDF;

pub const LANCE_SCALAR_INDEX: &str = "__lance_scalar_index";

Expand Down Expand Up @@ -554,6 +555,43 @@ impl AnyQuery for TextQuery {
}
}

/// A query that an InvertedIndex can satisfy
#[derive(Debug, Clone, PartialEq)]
pub enum TokenQuery {
/// Retrieve all row ids where the text contains all tokens parsed from the given string. The tokens
Comment on lines +559 to +561
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be a few pieces of glue missing. I think we need a TokenQueryParser (see https://github.com/lancedb/lance/blob/60711f360b7f8692df44a0e84c98c8fdff2897a3/rust/lance-index/src/scalar/expression.rs#L348 for the TextQueryParser).

We also need to register the token query parser here: https://github.com/lancedb/lance/blob/60711f360b7f8692df44a0e84c98c8fdff2897a3/rust/lance/src/index.rs#L1361

Can we call is_query_allowed in the registration function (scalar_index_info)? This way we can skip the scalar index entirely if it is not eligible. Returning AtLeast with zero rows might lead to bad performance (the planner will think we are doing a scalar index optimized search and make certain decisions based on that)

Copy link
Copy Markdown
Contributor Author

@wojiaodoubao wojiaodoubao Aug 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need a TokenQueryParser

We already have an FtsQueryParser which parses contains_tokens into TokenQuery::TokensContains. Actually you implemented it (^-^). Shall we just rely on FtsQueryParser?

Can we call is_query_allowed in the registration function ...

Thanks for your nice suggestion, let me fix it.

/// are separated by punctuation and whitespace.
TokensContains(String),
}

impl AnyQuery for TokenQuery {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Render the query as the SQL-like expression it is equivalent to.
    fn format(&self, col: &str) -> String {
        // `Expr` implements `Display`; calling `to_string()` directly avoids the
        // redundant `format!("{}", ..)` wrapper (clippy: `useless_format`).
        self.to_expr(col.to_string()).to_string()
    }

    /// Convert the query back into a logical expression: a call to the
    /// `contains_tokens(col, '<tokens>')` UDF, so it can still be evaluated as a
    /// plain filter when no index is used.
    fn to_expr(&self, col: String) -> Expr {
        match self {
            Self::TokensContains(substr) => Expr::ScalarFunction(ScalarFunction {
                func: Arc::new(CONTAINS_TOKENS_UDF.clone()),
                args: vec![
                    Expr::Column(Column::new_unqualified(col)),
                    Expr::Literal(ScalarValue::Utf8(Some(substr.clone())), None),
                ],
            }),
        }
    }

    /// Type-aware equality: queries of a different concrete type are never equal.
    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(o) => self == o,
            None => false,
        }
    }
}

/// The result of a search operation against a scalar index
#[derive(Debug, PartialEq)]
pub enum SearchResult {
Expand Down
39 changes: 20 additions & 19 deletions rust/lance-index/src/scalar/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ use datafusion_expr::{
Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF,
};

use super::{
AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult,
TextQuery, TokenQuery,
};
use crate::scalar::inverted::InvertedIndex;
use futures::join;
use lance_core::{utils::mask::RowIdMask, Error, Result};
use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner};
use roaring::RoaringBitmap;
use snafu::location;
use tracing::instrument;

use super::{
AnyQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, SearchResult, TextQuery,
};

const MAX_DEPTH: usize = 500;

/// An indexed expression consists of a scalar index query with a post-scan filter
Expand Down Expand Up @@ -427,11 +428,15 @@ impl ScalarQueryParser for TextQueryParser {
/// Query parser that recognizes `contains_tokens(col, '...')` calls in filter
/// expressions and rewrites them into index queries served by an inverted (FTS)
/// index.
#[derive(Debug, Clone)]
pub struct FtsQueryParser {
    // Name of the scalar index, recorded on the produced `IndexedExpression`.
    index_name: String,
    // The inverted index itself; consulted via `is_query_allowed` to decide
    // whether it can serve a given query.
    index: InvertedIndex,
}

impl FtsQueryParser {
pub fn new(index_name: String) -> Self {
Self { index_name }
pub fn new(name: String, index: InvertedIndex) -> Self {
Self {
index_name: name,
index,
}
}
}

Expand Down Expand Up @@ -478,22 +483,18 @@ impl ScalarQueryParser for FtsQueryParser {
}
let scalar = maybe_scalar(&args[1], data_type)?;
if let ScalarValue::Utf8(Some(scalar_str)) = scalar {
// TODO(https://github.com/lancedb/lance/issues/3855):
//
// Create the contains_tokens UDF
if func.name() == "contains_tokens" {
let query = TextQuery::StringContains(scalar_str);
Some(IndexedExpression::index_query(
column.to_string(),
self.index_name.clone(),
Arc::new(query),
))
} else {
None
let query = TokenQuery::TokensContains(scalar_str);
if self.index.is_query_allowed(&query) {
return Some(IndexedExpression::index_query(
column.to_string(),
self.index_name.clone(),
Arc::new(query),
));
}
}
} else {
None
}
None
}
}

Expand Down
66 changes: 59 additions & 7 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::{
ops::Range,
};

use crate::metrics::NoOpMetricsCollector;
use crate::prefilter::NoFilter;
use arrow::{
array::LargeBinaryBuilder,
datatypes::{self, Float32Type, Int32Type, UInt64Type},
Expand All @@ -33,10 +35,11 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_common::DataFusionError;
use deepsize::DeepSizeOf;
use fst::{Automaton, IntoStreamer, Streamer};
use futures::{stream, StreamExt, TryStreamExt};
use futures::{stream, FutureExt, StreamExt, TryStreamExt};
use itertools::Itertools;
use lance_arrow::{iter_str_array, RecordBatchExt};
use lance_core::cache::{CacheKey, LanceCache};
use lance_core::utils::mask::RowIdTreeMap;
use lance_core::utils::{
mask::RowIdMask,
tracing::{IO_TYPE_LOAD_SCALAR_PART, TRACE_IO_EVENTS},
Expand All @@ -46,6 +49,7 @@ use lance_core::{Error, Result, ROW_ID, ROW_ID_FIELD};
use roaring::RoaringBitmap;
use snafu::location;
use std::sync::LazyLock;
use tantivy::tokenizer::Language;
use tracing::{info, instrument};

use super::{
Expand All @@ -69,7 +73,7 @@ use super::{
use super::{wand::*, InvertedIndexBuilder, InvertedIndexParams};
use crate::frag_reuse::FragReuseIndex;
use crate::scalar::{
AnyQuery, IndexReader, IndexStore, MetricsCollector, SargableQuery, ScalarIndex, SearchResult,
AnyQuery, IndexReader, IndexStore, MetricsCollector, ScalarIndex, SearchResult, TokenQuery,
};
use crate::Index;
use crate::{prefilter::PreFilter, scalar::inverted::iter::take_fst_keys};
Expand All @@ -94,6 +98,8 @@ pub static SCORE_FIELD: LazyLock<Field> =
LazyLock::new(|| Field::new(SCORE_COL, DataType::Float32, true));
pub static FTS_SCHEMA: LazyLock<SchemaRef> =
LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone(), SCORE_FIELD.clone()])));
// Single-column schema (just the row-id field) used for index search result batches.
static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
    LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));

#[derive(Clone)]
pub struct InvertedIndex {
Expand Down Expand Up @@ -320,6 +326,43 @@ impl Index for InvertedIndex {
}
}

impl InvertedIndex {
    /// Whether the given query can be answered correctly by this index.
    ///
    /// `TokensContains` is only allowed when the index was built with the
    /// "simple" base tokenizer, no max token length, English language settings,
    /// and stemming disabled — any other tokenizer configuration could alter the
    /// token stream relative to `contains_tokens` semantics and yield wrong
    /// results.
    pub fn is_query_allowed(&self, query: &TokenQuery) -> bool {
        match query {
            TokenQuery::TokensContains(_) => {
                self.params.base_tokenizer == "simple"
                    && self.params.max_token_length.is_none()
                    && self.params.language == Language::English
                    && !self.params.stem
            }
        }
    }

    /// Search for documents matching the input text.
    ///
    /// Tokenizes `text` with this index's tokenizer and runs a BM25 search
    /// requiring all tokens to be present (`Operator::And`), with no prefilter
    /// and no metrics collection. Returns a single-column `RecordBatch` of
    /// matching row ids (schema: `ROW_ID_SCHEMA`).
    async fn do_search(&self, text: &str) -> Result<RecordBatch> {
        // Default search parameters; presumably unlimited result count — confirm
        // against `FtsSearchParams::new()` defaults if a limit is ever needed.
        let params = FtsSearchParams::new();
        let mut tokenizer = self.tokenizer.clone();
        let tokens = collect_tokens(text, &mut tokenizer, None);

        let (doc_ids, _) = self
            .bm25_search(
                tokens.into(),
                params.into(),
                Operator::And,
                Arc::new(NoFilter),
                Arc::new(NoOpMetricsCollector),
            )
            .boxed()
            .await?;

        Ok(RecordBatch::try_new(
            ROW_ID_SCHEMA.clone(),
            vec![Arc::new(UInt64Array::from(doc_ids))],
        )?)
    }
}

#[async_trait]
impl ScalarIndex for InvertedIndex {
// return the row ids of the documents that contain the query
Expand All @@ -329,11 +372,20 @@ impl ScalarIndex for InvertedIndex {
query: &dyn AnyQuery,
_metrics: &dyn MetricsCollector,
) -> Result<SearchResult> {
let query = query.as_any().downcast_ref::<SargableQuery>().unwrap();
return Err(Error::invalid_input(
format!("unsupported query {:?} for inverted index", query),
location!(),
));
let query = query.as_any().downcast_ref::<TokenQuery>().unwrap();

match query {
TokenQuery::TokensContains(text) => {
let records = self.do_search(text).await?;
let row_ids = records
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
let row_ids = row_ids.iter().flatten().collect_vec();
Ok(SearchResult::AtMost(RowIdTreeMap::from_iter(row_ids)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably do this better in the future: if the tokenizer doesn't change the original text, we can return AtLeast directly.

}
}
}

fn can_remap(&self) -> bool {
Expand Down
Loading
Loading