Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions java/lance-jni/src/merge_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
use jni::objects::{JObject, JString, JValueGen};
use jni::sys::jlong;
use jni::JNIEnv;
use lance::dataset::scanner::LanceFilter;
use lance::dataset::scanner::ExprFilter;
use lance::dataset::{
MergeInsertBuilder, MergeStats, WhenMatched, WhenNotMatched, WhenNotMatchedBySource,
};
Expand Down Expand Up @@ -158,7 +158,7 @@ fn extract_when_not_matched_by_source_str<'local>(
fn extract_when_not_matched_by_source_delete_expr<'local>(
env: &mut JNIEnv<'local>,
jparam: &JObject,
) -> Result<Option<LanceFilter>> {
) -> Result<Option<ExprFilter>> {
let when_not_matched_by_source_delete_expr = env
.call_method(
jparam,
Expand All @@ -169,7 +169,7 @@ fn extract_when_not_matched_by_source_delete_expr<'local>(
.l()?;

if let Some(expr) = env.get_string_opt(&when_not_matched_by_source_delete_expr)? {
return Ok(Some(LanceFilter::Sql(expr)));
return Ok(Some(ExprFilter::Sql(expr)));
}

let when_not_matched_by_source_delete_substrait_expr = env
Expand All @@ -182,15 +182,15 @@ fn extract_when_not_matched_by_source_delete_expr<'local>(
.l()?;

match env.get_bytes_opt(&when_not_matched_by_source_delete_substrait_expr)? {
Some(expr) => Ok(Some(LanceFilter::Substrait(expr.to_vec()))),
Some(expr) => Ok(Some(ExprFilter::Substrait(expr.to_vec()))),
None => Ok(None),
}
}

fn extract_when_not_matched_by_source(
schema: &Schema,
when_not_matched_by_source: &str,
when_not_matched_by_source_delete_expr: Option<LanceFilter>,
when_not_matched_by_source_delete_expr: Option<ExprFilter>,
) -> Result<WhenNotMatchedBySource> {
match when_not_matched_by_source {
"Keep" => Ok(WhenNotMatchedBySource::Keep),
Expand Down
16 changes: 13 additions & 3 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2386,6 +2386,7 @@ pub fn flat_bm25_search(
query_tokens: &Tokens,
tokenizer: &mut Box<dyn LanceTokenizer>,
scorer: &mut MemBM25Scorer,
schema: SchemaRef,
) -> std::result::Result<RecordBatch, DataFusionError> {
let doc_iter = iter_str_array(&batch[doc_col]);
let mut scores = Vec::with_capacity(batch.num_rows());
Expand Down Expand Up @@ -2423,7 +2424,7 @@ pub fn flat_bm25_search(
let score_col = Arc::new(Float32Array::from(scores)) as ArrayRef;
let batch = batch
.try_with_column(SCORE_FIELD.clone(), score_col)?
.project_by_schema(&FTS_SCHEMA)?; // the scan node would probably scan some extra columns for prefilter, drop them here
.project_by_schema(&schema)?;
Ok(batch)
}

Expand All @@ -2432,6 +2433,7 @@ pub fn flat_bm25_search_stream(
doc_col: String,
query: String,
index: &Option<InvertedIndex>,
schema: SchemaRef,
) -> SendableRecordBatchStream {
let mut tokenizer = match index {
Some(index) => index.tokenizer(),
Expand Down Expand Up @@ -2466,10 +2468,18 @@ pub fn flat_bm25_search_stream(
None => MemBM25Scorer::new(0, 0, HashMap::new()),
};

let batch_schema = schema.clone();
let stream = input.map(move |batch| {
let batch = batch?;

let batch = flat_bm25_search(batch, &doc_col, &tokens, &mut tokenizer, &mut bm25_scorer)?;
let batch = flat_bm25_search(
batch,
&doc_col,
&tokens,
&mut tokenizer,
&mut bm25_scorer,
batch_schema.clone(),
)?;

// filter out rows with score 0
let score_col = batch[SCORE_COL].as_primitive::<Float32Type>();
Expand All @@ -2483,7 +2493,7 @@ pub fn flat_bm25_search_stream(
Ok(batch)
});

Box::pin(RecordBatchStreamAdapter::new(FTS_SCHEMA.clone(), stream)) as SendableRecordBatchStream
Box::pin(RecordBatchStreamAdapter::new(schema, stream)) as SendableRecordBatchStream
}

pub fn is_phrase_query(query: &str) -> bool {
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/dataset/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1754,7 +1754,7 @@ impl FileFragment {
// else if predicate is `false`, filter the predicate
// We do this on the expression level after expression optimization has
// occurred so we also catch expressions that are equivalent to `true`
if let Some(predicate) = &scanner.get_filter()? {
if let Some(predicate) = &scanner.get_expr_filter()? {
if matches!(
predicate,
Expr::Literal(ScalarValue::Boolean(Some(false)), _)
Expand Down
Loading