diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs
index a8f3c807ed7..122824252cd 100644
--- a/java/lance-jni/src/blocking_scanner.rs
+++ b/java/lance-jni/src/blocking_scanner.rs
@@ -5,6 +5,7 @@ use std::sync::Arc;
 
 use crate::error::{Error, Result};
 use crate::ffi::JNIEnvExt;
+use crate::traits::{import_vec_from_method, import_vec_to_rust};
 use arrow::array::Float32Array;
 use arrow::{ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream};
 use arrow_schema::SchemaRef;
@@ -12,6 +13,12 @@ use jni::objects::{JObject, JString};
 use jni::sys::{jboolean, jint, JNI_TRUE};
 use jni::{sys::jlong, JNIEnv};
 use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
+use lance_index::scalar::inverted::query::{
+    BooleanQuery as FtsBooleanQuery, BoostQuery as FtsBoostQuery, FtsQuery,
+    MatchQuery as FtsMatchQuery, MultiMatchQuery as FtsMultiMatchQuery, Occur as FtsOccur,
+    PhraseQuery as FtsPhraseQuery,
+};
+use lance_index::scalar::FullTextSearchQuery;
 use lance_io::ffi::to_ffi_arrow_array_stream;
 use lance_linalg::distance::DistanceType;
 
@@ -51,6 +58,149 @@ impl BlockingScanner {
     }
 }
 
+/// Converts a Java `org.lance.ipc.FullTextQuery` object into a Lance [`FtsQuery`].
+///
+/// Dispatches on the name of the enum returned by `getType()`. `BOOST` and `BOOLEAN`
+/// queries are converted recursively, so nested queries of any supported type are accepted.
+///
+/// # Errors
+/// Returns an input error for an unsupported type name or a null nested query, and
+/// propagates any failure from the underlying JNI calls or query constructors.
+fn build_full_text_search_query<'a>(env: &mut JNIEnv<'a>, java_obj: JObject) -> Result<FtsQuery> {
+    let type_obj = env
+        .call_method(
+            &java_obj,
+            "getType",
+            "()Lorg/lance/ipc/FullTextQuery$Type;",
+            &[],
+        )?
+        .l()?;
+    let type_name = env.get_string_from_method(&type_obj, "name")?;
+
+    match type_name.as_str() {
+        "MATCH" => {
+            let query_text = env.get_string_from_method(&java_obj, "getQueryText")?;
+            let column = env.get_string_from_method(&java_obj, "getColumn")?;
+            let boost = env.get_f32_from_method(&java_obj, "getBoost")?;
+            let fuzziness = env.get_optional_u32_from_method(&java_obj, "getFuzziness")?;
+            let max_expansions = env.get_int_as_usize_from_method(&java_obj, "getMaxExpansions")?;
+            let operator = env.get_fts_operator_from_method(&java_obj)?;
+            let prefix_length = env.get_u32_from_method(&java_obj, "getPrefixLength")?;
+
+            let mut query = FtsMatchQuery::new(query_text);
+            query = query.with_column(Some(column));
+            query = query
+                .with_boost(boost)
+                .with_fuzziness(fuzziness)
+                .with_max_expansions(max_expansions)
+                .with_operator(operator)
+                .with_prefix_length(prefix_length);
+
+            Ok(FtsQuery::Match(query))
+        }
+        "MATCH_PHRASE" => {
+            let query_text = env.get_string_from_method(&java_obj, "getQueryText")?;
+            let column = env.get_string_from_method(&java_obj, "getColumn")?;
+            let slop = env.get_u32_from_method(&java_obj, "getSlop")?;
+
+            let mut query = FtsPhraseQuery::new(query_text);
+            query = query.with_column(Some(column));
+            query = query.with_slop(slop);
+
+            Ok(FtsQuery::Phrase(query))
+        }
+        "MULTI_MATCH" => {
+            let query_text = env.get_string_from_method(&java_obj, "getQueryText")?;
+            let columns: Vec<String> =
+                import_vec_from_method(env, &java_obj, "getColumns", |env, elem| {
+                    let jstr = JString::from(elem);
+                    let value: String = env.get_string(&jstr)?.into();
+                    Ok(value)
+                })?;
+
+            let boosts: Option<Vec<f32>> =
+                env.get_optional_from_method(&java_obj, "getBoosts", |env, list_obj| {
+                    import_vec_to_rust(env, &list_obj, |env, elem| {
+                        env.get_f32_from_method(&elem, "floatValue")
+                    })
+                })?;
+            let operator = env.get_fts_operator_from_method(&java_obj)?;
+
+            let mut query = FtsMultiMatchQuery::try_new(query_text, columns)?;
+            if let Some(boosts) = boosts {
+                query = query.try_with_boosts(boosts)?;
+            }
+            query = query.with_operator(operator);
+
+            Ok(FtsQuery::MultiMatch(query))
+        }
+        "BOOST" => {
+            let positive_obj = env
+                .call_method(
+                    &java_obj,
+                    "getPositive",
+                    "()Lorg/lance/ipc/FullTextQuery;",
+                    &[],
+                )?
+                .l()?;
+            if positive_obj.is_null() {
+                return Err(Error::input_error(
+                    "positive query must not be null in BOOST FullTextQuery".to_string(),
+                ));
+            }
+            let negative_obj = env
+                .call_method(
+                    &java_obj,
+                    "getNegative",
+                    "()Lorg/lance/ipc/FullTextQuery;",
+                    &[],
+                )?
+                .l()?;
+            if negative_obj.is_null() {
+                return Err(Error::input_error(
+                    "negative query must not be null in BOOST FullTextQuery".to_string(),
+                ));
+            }
+
+            let positive = build_full_text_search_query(env, positive_obj)?;
+            let negative = build_full_text_search_query(env, negative_obj)?;
+            let negative_boost = env.get_f32_from_method(&java_obj, "getNegativeBoost")?;
+
+            let query = FtsBoostQuery::new(positive, negative, Some(negative_boost));
+            Ok(FtsQuery::Boost(query))
+        }
+        "BOOLEAN" => {
+            let clauses: Vec<(FtsOccur, FtsQuery)> =
+                import_vec_from_method(env, &java_obj, "getClauses", |env, clause_obj| {
+                    let occur = env.get_occur_from_method(&clause_obj)?;
+
+                    let query_obj = env
+                        .call_method(
+                            &clause_obj,
+                            "getQuery",
+                            "()Lorg/lance/ipc/FullTextQuery;",
+                            &[],
+                        )?
+                        .l()?;
+                    if query_obj.is_null() {
+                        return Err(Error::input_error(
+                            "BooleanClause query must not be null".to_string(),
+                        ));
+                    }
+                    let query = build_full_text_search_query(env, query_obj)?;
+                    Ok((occur, query))
+                })?;
+
+            let boolean_query = FtsBooleanQuery::new(clauses);
+            Ok(FtsQuery::Boolean(boolean_query))
+        }
+        other => Err(Error::input_error(format!(
+            "Unsupported FullTextQuery type: {}",
+            other
+        ))),
+    }
+}
+
 ///////////////////
 // Write Methods //
 ///////////////////
@@ -67,6 +217,7 @@ pub extern "system" fn Java_org_lance_ipc_LanceScanner_createScanner<'local>(
     limit_obj: JObject, // Optional
     offset_obj: JObject, // Optional
     query_obj: JObject, // Optional
+    fts_query_obj: JObject, // Optional<FullTextQuery>
     with_row_id: jboolean, // boolean
     with_row_address: jboolean, // boolean
     batch_readahead: jint, // int
@@ -85,6 +236,7 @@ pub extern "system" fn Java_org_lance_ipc_LanceScanner_createScanner<'local>(
             limit_obj,
             offset_obj,
             query_obj,
+            fts_query_obj,
             with_row_id,
             with_row_address,
             batch_readahead,
@@ -105,6 +257,7 @@ fn inner_create_scanner<'local>(
     limit_obj: JObject,
     offset_obj: JObject,
     query_obj: JObject,
+    fts_query_obj: JObject,
     with_row_id: jboolean,
     with_row_address: jboolean,
     batch_readahead: jint,
@@ -204,6 +357,13 @@ fn inner_create_scanner<'local>(
         Ok(())
     })?;
 
+    env.get_optional(&fts_query_obj, |env, java_obj| {
+        let fts_query = build_full_text_search_query(env, java_obj)?;
+        let full_text_query = FullTextSearchQuery::new_query(fts_query);
+        scanner.full_text_search(full_text_query)?;
+        Ok(())
+    })?;
+
     scanner.batch_readahead(batch_readahead as usize);
 
     env.get_optional(&column_orderings, |env, java_obj| {
diff --git a/java/lance-jni/src/ffi.rs b/java/lance-jni/src/ffi.rs
index 5889e562c6b..371e44563ed 100644
--- a/java/lance-jni/src/ffi.rs
+++ b/java/lance-jni/src/ffi.rs
@@ -9,6 +9,7 @@ use crate::Error;
 use jni::objects::{JByteBuffer, JFloatArray, JObjectArray, JString};
 use jni::sys::jobjectArray;
 use jni::{objects::JObject, JNIEnv};
+use
lance_index::scalar::inverted::query::{Occur, Operator};
 
 /// Extend JNIEnv with helper functions.
 pub trait JNIEnvExt {
@@ -97,6 +98,8 @@ pub trait JNIEnvExt {
         obj: &JObject,
         method_name: &str,
     ) -> Result>;
+    /// Call a Java method that returns a primitive `float` and convert the result to `f32`.
+    fn get_f32_from_method(&mut self, obj: &JObject, method_name: &str) -> Result<f32>;
 
     fn get_optional_integer_from_method(
         &mut self,
@@ -146,6 +149,12 @@ pub trait JNIEnvExt {
     fn get_optional(&mut self, obj: &JObject, f: F) -> Result>
     where
         F: FnOnce(&mut JNIEnv, JObject) -> Result;
+
+    /// Read the `FullTextQuery.Operator` enum returned by `getOperator()` on `obj`.
+    fn get_fts_operator_from_method(&mut self, obj: &JObject) -> Result<Operator>;
+
+    /// Read the `FullTextQuery.Occur` enum returned by `getOccur()` on `obj`.
+    fn get_occur_from_method(&mut self, obj: &JObject) -> Result<Occur>;
 }
 
 impl JNIEnvExt for JNIEnv<'_> {
@@ -278,6 +287,34 @@ impl JNIEnvExt for JNIEnv<'_> {
         })
     }
 
+    fn get_fts_operator_from_method(&mut self, obj: &JObject) -> Result<Operator> {
+        let operator_obj = self
+            .call_method(
+                obj,
+                "getOperator",
+                "()Lorg/lance/ipc/FullTextQuery$Operator;",
+                &[],
+            )?
+            .l()?;
+        let operator_str = self.get_string_from_method(&operator_obj, "name")?;
+        Operator::try_from(operator_str.as_str())
+            .map_err(|e| Error::input_error(format!("Invalid operator: {:?}", e)))
+    }
+
+    fn get_occur_from_method(&mut self, obj: &JObject) -> Result<Occur> {
+        let occur_obj = self
+            .call_method(
+                obj,
+                "getOccur",
+                "()Lorg/lance/ipc/FullTextQuery$Occur;",
+                &[],
+            )?
+            .l()?;
+        let occur_str = self.get_string_from_method(&occur_obj, "name")?;
+        Occur::try_from(occur_str.as_str())
+            .map_err(|e| Error::input_error(format!("Invalid occur: {:?}", e)))
+    }
+
     fn get_string_from_method(&mut self, obj: &JObject, method_name: &str) -> Result {
         let string_obj = self
             .call_method(obj, method_name, "()Ljava/lang/String;", &[])?
@@ -335,6 +372,12 @@ impl JNIEnvExt for JNIEnv<'_> {
         self.get_optional_integer_from_method(obj, method_name)
     }
 
+    fn get_f32_from_method(&mut self, obj: &JObject, method_name: &str) -> Result<f32> {
+        let float_obj = self.call_method(obj, method_name, "()F", &[])?;
+        let float_value = float_obj.f()?;
+        Ok(float_value)
+    }
+
     fn get_optional_integer_from_method(
         &mut self,
         obj: &JObject,
diff --git a/java/src/main/java/org/lance/ipc/FullTextQuery.java b/java/src/main/java/org/lance/ipc/FullTextQuery.java
new file mode 100755
index 00000000000..e28e12c2189
--- /dev/null
+++ b/java/src/main/java/org/lance/ipc/FullTextQuery.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.ipc;
+
+import com.google.common.base.MoreObjects;
+import org.apache.arrow.util.Preconditions;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
+/** Base type for full text search queries used by Lance scanner.
*/
+public abstract class FullTextQuery {
+  public enum Type {
+    MATCH,
+    MATCH_PHRASE,
+    BOOST,
+    MULTI_MATCH,
+    BOOLEAN
+  }
+
+  public enum Operator {
+    AND,
+    OR
+  }
+
+  public enum Occur {
+    SHOULD,
+    MUST,
+    MUST_NOT
+  }
+
+  public static final class BooleanClause {
+    private final Occur occur;
+    private final FullTextQuery query;
+
+    public BooleanClause(Occur occur, FullTextQuery query) {
+      this.occur = Objects.requireNonNull(occur, "occur must not be null");
+      this.query = Objects.requireNonNull(query, "query must not be null");
+    }
+
+    public Occur getOccur() {
+      return occur;
+    }
+
+    public FullTextQuery getQuery() {
+      return query;
+    }
+  }
+
+  public abstract Type getType();
+
+  public static FullTextQuery match(String queryText, String column) {
+    return match(queryText, column, 1.0f, Optional.empty(), 50, Operator.OR, 0);
+  }
+
+  public static FullTextQuery match(
+      String queryText,
+      String column,
+      float boost,
+      Optional<Integer> fuzziness,
+      int maxExpansions,
+      Operator operator,
+      int prefixLength) {
+    return new MatchQuery(
+        queryText, column, boost, fuzziness, maxExpansions, operator, prefixLength);
+  }
+
+  public static FullTextQuery phrase(String queryText, String column) {
+    return phrase(queryText, column, 0);
+  }
+
+  public static FullTextQuery phrase(String queryText, String column, int slop) {
+    return new PhraseQuery(queryText, column, slop);
+  }
+
+  public static FullTextQuery multiMatch(String queryText, List<String> columns) {
+    return multiMatch(queryText, columns, null, Operator.OR);
+  }
+
+  public static FullTextQuery multiMatch(
+      String queryText, List<String> columns, List<Float> boosts, Operator operator) {
+    return new MultiMatchQuery(queryText, columns, boosts, operator);
+  }
+
+  public static FullTextQuery boost(FullTextQuery positive, FullTextQuery negative) {
+    return boost(positive, negative, 0.5f);
+  }
+
+  public static FullTextQuery boost(
+      FullTextQuery positive, FullTextQuery negative, float negativeBoost) {
+    return new BoostQuery(positive, negative, negativeBoost);
+  }
+
+  public static FullTextQuery booleanQuery(List<BooleanClause> clauses) {
+    return new BooleanQuery(clauses);
+  }
+
+  /** Match query on a single column. */
+  public static final class MatchQuery extends FullTextQuery {
+    private final String queryText;
+    private final String column;
+    private final float boost;
+    private final Optional<Integer> fuzziness;
+    private final int maxExpansions;
+    private final Operator operator;
+    private final int prefixLength;
+
+    MatchQuery(
+        String queryText,
+        String column,
+        float boost,
+        Optional<Integer> fuzziness,
+        int maxExpansions,
+        Operator operator,
+        int prefixLength) {
+      Preconditions.checkArgument(
+          queryText != null && !queryText.isEmpty(), "queryText must not be null or empty");
+      Preconditions.checkArgument(
+          column != null && !column.isEmpty(), "column must not be null or empty");
+      Preconditions.checkArgument(maxExpansions >= 1, "maxExpansions must be >= 1");
+      Preconditions.checkArgument(prefixLength >= 0, "prefixLength must be >= 0");
+
+      this.queryText = queryText;
+      this.column = column;
+      this.boost = boost;
+      this.fuzziness = fuzziness == null ? Optional.empty() : fuzziness;
+      this.maxExpansions = maxExpansions;
+      this.operator = operator == null ? Operator.OR : operator;
+      this.prefixLength = prefixLength;
+    }
+
+    @Override
+    public Type getType() {
+      return Type.MATCH;
+    }
+
+    public String getQueryText() {
+      return queryText;
+    }
+
+    public String getColumn() {
+      return column;
+    }
+
+    public float getBoost() {
+      return boost;
+    }
+
+    public Optional<Integer> getFuzziness() {
+      return fuzziness;
+    }
+
+    public int getMaxExpansions() {
+      return maxExpansions;
+    }
+
+    public Operator getOperator() {
+      return operator;
+    }
+
+    public int getPrefixLength() {
+      return prefixLength;
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("type", getType())
+          .add("queryText", queryText)
+          .add("column", column)
+          .add("boost", boost)
+          .add("fuzziness", fuzziness)
+          .add("maxExpansions", maxExpansions)
+          .add("operator", operator)
+          .add("prefixLength", prefixLength)
+          .toString();
+    }
+  }
+
+  /** Phrase query on a single column. */
+  public static final class PhraseQuery extends FullTextQuery {
+    private final String queryText;
+    private final String column;
+    private final int slop;
+
+    PhraseQuery(String queryText, String column, int slop) {
+      Preconditions.checkArgument(
+          queryText != null && !queryText.isEmpty(), "queryText must not be null or empty");
+      Preconditions.checkArgument(
+          column != null && !column.isEmpty(), "column must not be null or empty");
+      Preconditions.checkArgument(slop >= 0, "slop must be >= 0");
+
+      this.queryText = queryText;
+      this.column = column;
+      this.slop = slop;
+    }
+
+    @Override
+    public Type getType() {
+      return Type.MATCH_PHRASE;
+    }
+
+    public String getQueryText() {
+      return queryText;
+    }
+
+    public String getColumn() {
+      return column;
+    }
+
+    public int getSlop() {
+      return slop;
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("type", getType())
+          .add("queryText", queryText)
+          .add("column", column)
+          .add("slop", slop)
+          .toString();
+    }
+  }
+
+  /** Multi-match query across multiple columns. */
+  public static final class MultiMatchQuery extends FullTextQuery {
+    private final String queryText;
+    private final List<String> columns;
+    private final Optional<List<Float>> boosts;
+    private final Operator operator;
+
+    MultiMatchQuery(String queryText, List<String> columns, List<Float> boosts, Operator operator) {
+      Preconditions.checkArgument(
+          queryText != null && !queryText.isEmpty(), "queryText must not be null or empty");
+      Preconditions.checkArgument(
+          columns != null && !columns.isEmpty(), "columns must not be null or empty");
+
+      this.queryText = queryText;
+      this.columns =
+          Collections.unmodifiableList(new java.util.ArrayList<>(Objects.requireNonNull(columns)));
+      this.boosts = boosts == null ? Optional.empty() : Optional.of(boosts);
+      this.operator = operator == null ? Operator.OR : operator;
+    }
+
+    @Override
+    public Type getType() {
+      return Type.MULTI_MATCH;
+    }
+
+    public String getQueryText() {
+      return queryText;
+    }
+
+    public List<String> getColumns() {
+      return columns;
+    }
+
+    public Optional<List<Float>> getBoosts() {
+      return boosts;
+    }
+
+    public Operator getOperator() {
+      return operator;
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("type", getType())
+          .add("queryText", queryText)
+          .add("columns", columns)
+          .add("boosts", boosts)
+          .add("operator", operator)
+          .toString();
+    }
+  }
+
+  /** Boost query combining positive and negative queries. */
+  public static final class BoostQuery extends FullTextQuery {
+    private final FullTextQuery positive;
+    private final FullTextQuery negative;
+    private final float negativeBoost;
+
+    BoostQuery(FullTextQuery positive, FullTextQuery negative, float negativeBoost) {
+      this.positive = Objects.requireNonNull(positive, "positive must not be null");
+      this.negative = Objects.requireNonNull(negative, "negative must not be null");
+      this.negativeBoost = negativeBoost;
+    }
+
+    @Override
+    public Type getType() {
+      return Type.BOOST;
+    }
+
+    public FullTextQuery getPositive() {
+      return positive;
+    }
+
+    public FullTextQuery getNegative() {
+      return negative;
+    }
+
+    public float getNegativeBoost() {
+      return negativeBoost;
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("type", getType())
+          .add("positive", positive)
+          .add("negative", negative)
+          .add("negativeBoost", negativeBoost)
+          .toString();
+    }
+  }
+
+  /** Boolean query composed of multiple clauses. */
+  public static final class BooleanQuery extends FullTextQuery {
+    private final List<BooleanClause> clauses;
+
+    BooleanQuery(List<BooleanClause> clauses) {
+      Preconditions.checkArgument(
+          clauses != null && !clauses.isEmpty(), "clauses must not be null or empty");
+      this.clauses =
+          Collections.unmodifiableList(new java.util.ArrayList<>(Objects.requireNonNull(clauses)));
+    }
+
+    @Override
+    public Type getType() {
+      return Type.BOOLEAN;
+    }
+
+    public List<BooleanClause> getClauses() {
+      return clauses;
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("type", getType())
+          .add("clauses", clauses)
+          .toString();
+    }
+  }
+}
diff --git a/java/src/main/java/org/lance/ipc/LanceScanner.java b/java/src/main/java/org/lance/ipc/LanceScanner.java
index 60a619d9063..804b7ea22f3 100644
--- a/java/src/main/java/org/lance/ipc/LanceScanner.java
+++ b/java/src/main/java/org/lance/ipc/LanceScanner.java
@@ -68,6 +68,7 @@ public static LanceScanner create(
             options.getLimit(),
             options.getOffset(),
             options.getNearest(),
+            options.getFullTextQuery(),
             options.isWithRowId(),
             options.isWithRowAddress(),
             options.getBatchReadahead(),
@@ -88,6 +89,7 @@ static native LanceScanner createScanner(
       Optional limit,
       Optional offset,
       Optional query,
+      Optional<FullTextQuery> fullTextQuery,
       boolean withRowId,
       boolean withRowAddress,
       int batchReadahead,
diff --git a/java/src/main/java/org/lance/ipc/ScanOptions.java b/java/src/main/java/org/lance/ipc/ScanOptions.java
index 615a96a2bb5..b73f9e8104e 100644
--- a/java/src/main/java/org/lance/ipc/ScanOptions.java
+++ b/java/src/main/java/org/lance/ipc/ScanOptions.java
@@ -30,6 +30,7 @@ public class ScanOptions {
   private final Optional limit;
   private final Optional offset;
   private final Optional nearest;
+  private final Optional<FullTextQuery> fullTextQuery;
   private final boolean withRowId;
   private final boolean withRowAddress;
   private final int batchReadahead;
@@ -61,6 +62,7 @@ public ScanOptions(
       Optional limit,
       Optional offset,
       Optional nearest,
+      Optional<FullTextQuery> fullTextQuery,
       boolean withRowId,
       boolean withRowAddress,
       int batchReadahead,
@@ -76,6 +78,7 @@ public ScanOptions(
     this.limit = limit;
     this.offset = offset;
     this.nearest = nearest;
+    this.fullTextQuery = fullTextQuery;
     this.withRowId = withRowId;
     this.withRowAddress = withRowAddress;
     this.batchReadahead = batchReadahead;
@@ -154,6 +157,15 @@ public Optional getNearest() {
     return nearest;
   }
 
+  /**
+   * Get the full text search query.
+   *
+   * @return Optional containing the full text search query if specified, otherwise empty.
+   */
+  public Optional<FullTextQuery> getFullTextQuery() {
+    return fullTextQuery;
+  }
+
   /**
    * Get whether to include the row ID.
    *
@@ -198,6 +210,7 @@ public String toString() {
         .add("limit", limit.orElse(null))
         .add("offset", offset.orElse(null))
         .add("nearest", nearest.orElse(null))
+        .add("fullTextQuery", fullTextQuery.orElse(null))
         .add("withRowId", withRowId)
         .add("WithRowAddress", withRowAddress)
         .add("batchReadahead", batchReadahead)
@@ -215,6 +228,7 @@ public static class Builder {
     private Optional limit = Optional.empty();
     private Optional offset = Optional.empty();
     private Optional nearest = Optional.empty();
+    private Optional<FullTextQuery> fullTextQuery = Optional.empty();
     private boolean withRowId = false;
     private boolean withRowAddress = false;
     private int batchReadahead = 16;
@@ -236,6 +250,7 @@ public Builder(ScanOptions options) {
       this.limit = options.getLimit();
       this.offset = options.getOffset();
       this.nearest = options.getNearest();
+      this.fullTextQuery = options.getFullTextQuery();
       this.withRowId = options.isWithRowId();
       this.withRowAddress = options.isWithRowAddress();
       this.batchReadahead = options.getBatchReadahead();
@@ -330,6 +345,17 @@ public Builder nearest(Query nearest) {
       return this;
     }
 
+    /**
+     * Set the full text search query.
+     *
+     * @param fullTextQuery full text search query definition.
+     * @return Builder instance for method chaining.
+     */
+    public Builder fullTextQuery(FullTextQuery fullTextQuery) {
+      this.fullTextQuery = Optional.ofNullable(fullTextQuery);
+      return this;
+    }
+
     /**
      * Set whether to include the row ID.
      *
@@ -383,6 +409,7 @@ public ScanOptions build() {
           limit,
           offset,
           nearest,
+          fullTextQuery,
           withRowId,
           withRowAddress,
           batchReadahead,
diff --git a/java/src/test/java/org/lance/ipc/FullTextQueryTest.java b/java/src/test/java/org/lance/ipc/FullTextQueryTest.java
new file mode 100755
index 00000000000..595e99eccd8
--- /dev/null
+++ b/java/src/test/java/org/lance/ipc/FullTextQueryTest.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.ipc;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Optional;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class FullTextQueryTest {
+
+  @Test
+  void testMatchQueryDefaults() {
+    FullTextQuery.MatchQuery q =
+        (FullTextQuery.MatchQuery) FullTextQuery.match("hello world", "body");
+
+    assertEquals(FullTextQuery.Type.MATCH, q.getType());
+    assertEquals("hello world", q.getQueryText());
+    assertEquals("body", q.getColumn());
+    assertEquals(1.0f, q.getBoost());
+    assertFalse(q.getFuzziness().isPresent());
+    assertEquals(50, q.getMaxExpansions());
+    assertEquals(FullTextQuery.Operator.OR, q.getOperator());
+    assertEquals(0, q.getPrefixLength());
+  }
+
+  @Test
+  void testMatchQueryCustomParameters() {
+    FullTextQuery.MatchQuery q =
+        (FullTextQuery.MatchQuery)
+            FullTextQuery.match(
+                "hello", "title", 2.0f, Optional.of(1), 10, FullTextQuery.Operator.AND, 3);
+
+    assertEquals(FullTextQuery.Type.MATCH, q.getType());
+    assertEquals("hello", q.getQueryText());
+    assertEquals("title", q.getColumn());
+    assertEquals(2.0f, q.getBoost());
+    assertEquals(Optional.of(1), q.getFuzziness());
+    assertEquals(10, q.getMaxExpansions());
+    assertEquals(FullTextQuery.Operator.AND, q.getOperator());
+    assertEquals(3, q.getPrefixLength());
+  }
+
+  @Test
+  void testPhraseQueryDefaults() {
+    FullTextQuery.PhraseQuery q =
+        (FullTextQuery.PhraseQuery) FullTextQuery.phrase("exact match", "content");
+
+    assertEquals(FullTextQuery.Type.MATCH_PHRASE, q.getType());
+    assertEquals("exact match", q.getQueryText());
+    assertEquals("content", q.getColumn());
+    assertEquals(0, q.getSlop());
+  }
+
+  @Test
+  void testPhraseQueryCustomSlop() {
+    FullTextQuery.PhraseQuery q =
+        (FullTextQuery.PhraseQuery) FullTextQuery.phrase("ordered terms", "content", 2);
+
+    assertEquals(FullTextQuery.Type.MATCH_PHRASE, q.getType());
+    assertEquals("ordered terms", q.getQueryText());
+    assertEquals("content", q.getColumn());
+    assertEquals(2, q.getSlop());
+  }
+
+  @Test
+  void testMultiMatchWithoutBoosts() {
+    FullTextQuery.MultiMatchQuery q =
+        (FullTextQuery.MultiMatchQuery)
+            FullTextQuery.multiMatch("hello", Arrays.asList("title", "body"));
+
+    assertEquals(FullTextQuery.Type.MULTI_MATCH, q.getType());
+    assertEquals("hello", q.getQueryText());
+    assertEquals(Arrays.asList("title", "body"), q.getColumns());
+    assertFalse(q.getBoosts().isPresent());
+    assertEquals(FullTextQuery.Operator.OR, q.getOperator());
+  }
+
+  @Test
+  void testMultiMatchWithBoosts() {
+    FullTextQuery.MultiMatchQuery q =
+        (FullTextQuery.MultiMatchQuery)
+            FullTextQuery.multiMatch(
+                "hello",
+                Arrays.asList("title", "body"),
+                Arrays.asList(2.0f, 0.5f),
+                FullTextQuery.Operator.AND);
+
+    assertEquals(FullTextQuery.Type.MULTI_MATCH, q.getType());
+    assertTrue(q.getBoosts().isPresent());
+    assertEquals(2, q.getBoosts().get().size());
+    assertEquals(2.0f, q.getBoosts().get().get(0));
+    assertEquals(0.5f, q.getBoosts().get().get(1));
+    assertEquals(FullTextQuery.Operator.AND, q.getOperator());
+    assertNotNull(q.toString());
+  }
+
+  @Test
+  void testBoostQuery() {
+    FullTextQuery.MatchQuery positive =
+        (FullTextQuery.MatchQuery) FullTextQuery.match("good", "body");
+    FullTextQuery.MatchQuery negative =
+        (FullTextQuery.MatchQuery) FullTextQuery.match("bad", "body");
+
+    FullTextQuery.BoostQuery q =
+        (FullTextQuery.BoostQuery) FullTextQuery.boost(positive, negative, 0.3f);
+
+    assertEquals(FullTextQuery.Type.BOOST, q.getType());
+    assertEquals(positive, q.getPositive());
+    assertEquals(negative, q.getNegative());
+    assertEquals(0.3f, q.getNegativeBoost());
+  }
+
+  @Test
+  void testBooleanQuery() {
+    FullTextQuery.MatchQuery match =
+        (FullTextQuery.MatchQuery) FullTextQuery.match("hello", "body");
+    FullTextQuery.MatchQuery mustNot =
+        (FullTextQuery.MatchQuery) FullTextQuery.match("spam", "body");
+
+    FullTextQuery.BooleanClause shouldClause =
+        new FullTextQuery.BooleanClause(FullTextQuery.Occur.SHOULD, match);
+    FullTextQuery.BooleanClause mustNotClause =
+        new FullTextQuery.BooleanClause(FullTextQuery.Occur.MUST_NOT, mustNot);
+
+    FullTextQuery.BooleanQuery q =
+        (FullTextQuery.BooleanQuery)
+            FullTextQuery.booleanQuery(Arrays.asList(shouldClause, mustNotClause));
+
+    assertEquals(FullTextQuery.Type.BOOLEAN, q.getType());
+    assertNotNull(q.getClauses());
+    assertEquals(2, q.getClauses().size());
+    assertEquals(FullTextQuery.Occur.SHOULD, q.getClauses().get(0).getOccur());
+    assertEquals(FullTextQuery.Type.MATCH, q.getClauses().get(0).getQuery().getType());
+    assertEquals(FullTextQuery.Occur.MUST_NOT, q.getClauses().get(1).getOccur());
+  }
+
+  @Test
+  void testBooleanQuerySingleClause() {
+    FullTextQuery.MatchQuery match =
+        (FullTextQuery.MatchQuery) FullTextQuery.match("hello", "body");
+    FullTextQuery.BooleanClause shouldClause =
+        new FullTextQuery.BooleanClause(FullTextQuery.Occur.SHOULD, match);
+
+    FullTextQuery.BooleanQuery q =
+        (FullTextQuery.BooleanQuery)
+            FullTextQuery.booleanQuery(Collections.singletonList(shouldClause));
+
+    assertEquals(FullTextQuery.Type.BOOLEAN, q.getType());
+    assertEquals(1, q.getClauses().size());
+    assertEquals(FullTextQuery.Occur.SHOULD, q.getClauses().get(0).getOccur());
+  }
+}
diff --git a/java/src/test/java/org/lance/ipc/LanceScannerFullTextSearchTest.java b/java/src/test/java/org/lance/ipc/LanceScannerFullTextSearchTest.java
new file mode 100755
index 00000000000..1c46b399195
--- /dev/null
+++ b/java/src/test/java/org/lance/ipc/LanceScannerFullTextSearchTest.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.ipc;
+
+import org.lance.Dataset;
+import org.lance.WriteParams;
+import org.lance.index.IndexOptions;
+import org.lance.index.IndexParams;
+import org.lance.index.IndexType;
+import org.lance.index.scalar.ScalarIndexParams;
+
+import org.apache.arrow.c.ArrowArrayStream;
+import org.apache.arrow.c.Data;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.jupiter.api.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class LanceScannerFullTextSearchTest {
+
+  @Test
+  void testMatchQuery() throws Exception {
+    runFtsQuery("memory://fts_java_match", FullTextQuery.match("hello", "doc"), 2L);
+  }
+
+  @Test
+  void testPhraseQuery() throws Exception {
+    runFtsQuery("memory://fts_java_phrase", FullTextQuery.phrase("hello world", "doc", 0), 1L);
+  }
+
+  @Test
+  void testBoostQuery() throws Exception {
+    FullTextQuery positive = FullTextQuery.match("hello", "doc");
+    FullTextQuery negative = FullTextQuery.match("world", "doc");
+    FullTextQuery boosted = FullTextQuery.boost(positive, negative, 0.3f);
+
+    runFtsQuery("memory://fts_java_boost", boosted, 2L);
+  }
+
+  @Test
+  void testMultiMatch() throws Exception {
+    FullTextQuery multiMatch = FullTextQuery.multiMatch("hello", Arrays.asList("doc", "title"));
+    runFtsQuery("memory://fts_java_multimatch", multiMatch, 3L);
+  }
+
+  @Test
+  void testBooleanQuery() throws Exception {
+    FullTextQuery.MatchQuery shouldMatch =
+        (FullTextQuery.MatchQuery) FullTextQuery.match("hello", "doc");
+    FullTextQuery.MatchQuery mustNotMatch =
+        (FullTextQuery.MatchQuery) FullTextQuery.match("lance", "doc");
+
+    FullTextQuery.BooleanClause shouldClause =
+        new FullTextQuery.BooleanClause(FullTextQuery.Occur.SHOULD, shouldMatch);
+    FullTextQuery.BooleanClause mustNotClause =
+        new FullTextQuery.BooleanClause(FullTextQuery.Occur.MUST_NOT, mustNotMatch);
+
+    FullTextQuery booleanQuery =
+        FullTextQuery.booleanQuery(Arrays.asList(shouldClause, mustNotClause));
+
+    runFtsQuery("memory://fts_java_boolean", booleanQuery, 1L);
+  }
+
+  private void runFtsQuery(String uri, FullTextQuery query, long expectedTotal) throws Exception {
+
+    Schema schema =
+        new Schema(
+            Arrays.asList(
+                Field.nullable("doc", ArrowType.Utf8.INSTANCE),
+                Field.nullable("title", ArrowType.Utf8.INSTANCE)),
+            null);
+
+    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+        VarCharVector docVector = (VarCharVector) root.getVector("doc");
+        VarCharVector titleVector = (VarCharVector) root.getVector("title");
+
+        docVector.allocateNew();
+        docVector.setSafe(0, "hello world".getBytes(StandardCharsets.UTF_8));
+        docVector.setSafe(1, "hello lance".getBytes(StandardCharsets.UTF_8));
+        docVector.setSafe(2, "other text".getBytes(StandardCharsets.UTF_8));
+
+        titleVector.allocateNew();
+        titleVector.setSafe(0, "bye world".getBytes(StandardCharsets.UTF_8));
+        titleVector.setSafe(1, "bye lance".getBytes(StandardCharsets.UTF_8));
+        titleVector.setSafe(2, "say hello".getBytes(StandardCharsets.UTF_8));
+
+        root.setRowCount(3);
+
+        ByteArrayOutputStream out = new ByteArrayOutputStream();
+        try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, out)) {
+          writer.start();
+          writer.writeBatch();
+          writer.end();
+        }
+
+        byte[] arrowData = out.toByteArray();
+        ByteArrayInputStream in = new ByteArrayInputStream(arrowData);
+        try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator);
+            ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) {
+          Data.exportArrayStream(allocator, reader, stream);
+
+          WriteParams writeParams =
+              new WriteParams.Builder().withMode(WriteParams.WriteMode.CREATE).build();
+
+          try (Dataset dataset = Dataset.create(allocator, stream, uri, writeParams)) {
+            ScalarIndexParams scalarParams =
+                ScalarIndexParams.create(
+                    "inverted",
+                    "{\"base_tokenizer\":\"simple\",\"language\":\"English\",\"with_position\":true}");
+            IndexParams indexParams =
+                IndexParams.builder().setScalarIndexParams(scalarParams).build();
+
+            dataset.createIndex(
+                IndexOptions.builder(
+                        Collections.singletonList("doc"), IndexType.INVERTED, indexParams)
+                    .withIndexName("doc_idx")
+                    .build());
+
+            dataset.createIndex(
+                IndexOptions.builder(
+                        Collections.singletonList("title"), IndexType.INVERTED, indexParams)
+                    .withIndexName("title_idx")
+                    .build());
+
+            ScanOptions scanOptions = new ScanOptions.Builder().fullTextQuery(query).build();
+
+            try (LanceScanner scanner = dataset.newScan(scanOptions)) {
+              long total = 0L;
+              try (ArrowReader arrowReader = scanner.scanBatches()) {
+                while (arrowReader.loadNextBatch()) {
+                  total += arrowReader.getVectorSchemaRoot().getRowCount();
+                }
+              }
+              assertEquals(expectedTotal, total);
+            }
+          }
+        }
+      }
+    }
+  }
+}