Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions solr/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ New Features

* SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero, Alessandro Benedetti)

* SOLR-17815: Add parameter to regulate for ACORN-based filtering in vector search. (Anna Ruggero, Alessandro Benedetti)

Improvements
---------------------

Expand Down
30 changes: 25 additions & 5 deletions solr/core/src/java/org/apache/solr/schema/DenseVectorField.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.SeededKnnVectorQuery;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.solr.common.SolrException;
Expand Down Expand Up @@ -379,17 +380,36 @@ public Query getKnnVectorQuery(
int topK,
Query filterQuery,
Query seedQuery,
EarlyTerminationParams earlyTermination) {
EarlyTerminationParams earlyTermination,
Integer filteredSearchThreshold) {

DenseVectorParser vectorBuilder =
getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY);

final Query knnQuery =
switch (vectorEncoding) {
case FLOAT32 -> new KnnFloatVectorQuery(
fieldName, vectorBuilder.getFloatVector(), topK, filterQuery);
case BYTE -> new KnnByteVectorQuery(
fieldName, vectorBuilder.getByteVector(), topK, filterQuery);
case FLOAT32 -> {
if (filteredSearchThreshold != null) {
KnnSearchStrategy knnSearchStrategy =
new KnnSearchStrategy.Hnsw(filteredSearchThreshold);
yield new KnnFloatVectorQuery(
fieldName, vectorBuilder.getFloatVector(), topK, filterQuery, knnSearchStrategy);
} else {
yield new KnnFloatVectorQuery(
fieldName, vectorBuilder.getFloatVector(), topK, filterQuery);
}
}
case BYTE -> {
if (filteredSearchThreshold != null) {
KnnSearchStrategy knnSearchStrategy =
new KnnSearchStrategy.Hnsw(filteredSearchThreshold);
yield new KnnByteVectorQuery(
fieldName, vectorBuilder.getByteVector(), topK, filterQuery, knnSearchStrategy);
} else {
yield new KnnByteVectorQuery(
fieldName, vectorBuilder.getByteVector(), topK, filterQuery);
}
}
};

final boolean seedEnabled = (seedQuery != null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class KnnQParser extends AbstractVectorQParserBase {
protected static final String TOP_K = "topK";
protected static final int DEFAULT_TOP_K = 10;
protected static final String SEED_QUERY = "seedQuery";
protected static final String FILTERED_SEARCH_THRESHOLD = "filteredSearchThreshold";

// parameters for PatienceKnnVectorQuery, a version of knn vector query that exits early when HNSW
// queue saturates over a {@code #saturationThreshold} for more than {@code #patience} times.
Expand Down Expand Up @@ -107,13 +108,15 @@ public Query parse() throws SyntaxError {
final DenseVectorField denseVectorType = getCheckedFieldType(schemaField);
final String vectorToSearch = getVectorToSearch();
final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K);
final Integer filteredSearchThreshold = localParams.getInt(FILTERED_SEARCH_THRESHOLD);

return denseVectorType.getKnnVectorQuery(
schemaField.getName(),
vectorToSearch,
topK,
getFilterQuery(),
getSeedQuery(),
getEarlyTerminationParams());
getEarlyTerminationParams(),
filteredSearchThreshold);
}
}
287 changes: 287 additions & 0 deletions solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
import java.util.Map;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.PatienceKnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.SeededKnnVectorQuery;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.solr.client.solrj.request.JavaBinUpdateRequestCodec;
import org.apache.solr.client.solrj.request.UpdateRequest;
import org.apache.solr.common.SolrException;
Expand All @@ -35,6 +42,7 @@
import org.apache.solr.handler.loader.JavabinLoader;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.search.neural.KnnQParser;
import org.apache.solr.update.CommitUpdateCommand;
import org.apache.solr.update.processor.UpdateRequestProcessor;
import org.apache.solr.update.processor.UpdateRequestProcessorChain;
Expand Down Expand Up @@ -838,4 +846,283 @@ public void testIndexingViaJavaBin() throws Exception {
deleteCore();
}
}

@Test
public void testFilteredSearchThreshold_floatNoThresholdInInput_shouldSetDefaultThreshold()
throws Exception {
try {
Integer expectedThreshold = KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
KnnFloatVectorQuery vectorQuery =
(KnnFloatVectorQuery)
type.getKnnVectorQuery("vector", "[2, 1, 3, 4]", 3, null, null, null, null);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void testFilteredSearchThreshold_floatThresholdInInput_shouldSetCustomThreshold()
throws Exception {
try {
Integer expectedThreshold = 30;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
KnnFloatVectorQuery vectorQuery =
(KnnFloatVectorQuery)
type.getKnnVectorQuery(
"vector", "[2, 1, 3, 4]", 3, null, null, null, expectedThreshold);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void testFilteredSearchThreshold_seededFloatThresholdInInput_shouldSetCustomThreshold()
throws Exception {
try {
Query seedQuery = new BooleanQuery.Builder().build();
Integer expectedThreshold = 30;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
SeededKnnVectorQuery vectorQuery =
(SeededKnnVectorQuery)
type.getKnnVectorQuery(
"vector", "[2, 1, 3, 4]", 3, null, seedQuery, null, expectedThreshold);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void
testFilteredSearchThreshold_earlyTerminationFloatThresholdInInput_shouldSetCustomThreshold()
throws Exception {
try {
KnnQParser.EarlyTerminationParams earlyTermination =
new KnnQParser.EarlyTerminationParams(true, 0.995, 7);
Integer expectedThreshold = 30;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
PatienceKnnVectorQuery vectorQuery =
(PatienceKnnVectorQuery)
type.getKnnVectorQuery(
"vector", "[2, 1, 3, 4]", 3, null, null, earlyTermination, expectedThreshold);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void
testFilteredSearchThreshold_seededAndEarlyTerminationFloatThresholdInInput_shouldSetCustomThreshold()
throws Exception {
try {
Query seedQuery = new BooleanQuery.Builder().build();
KnnQParser.EarlyTerminationParams earlyTermination =
new KnnQParser.EarlyTerminationParams(true, 0.995, 7);
Integer expectedThreshold = 30;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
PatienceKnnVectorQuery vectorQuery =
(PatienceKnnVectorQuery)
type.getKnnVectorQuery(
"vector",
"[2, 1, 3, 4]",
3,
null,
seedQuery,
earlyTermination,
expectedThreshold);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void testFilteredSearchThreshold_byteNoThresholdInInput_shouldSetDefaultThreshold()
throws Exception {
try {
Integer expectedThreshold = KnnSearchStrategy.DEFAULT_FILTERED_SEARCH_THRESHOLD;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector_byte_encoding");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
KnnByteVectorQuery vectorQuery =
(KnnByteVectorQuery)
type.getKnnVectorQuery(
"vector_byte_encoding", "[2, 1, 3, 4]", 3, null, null, null, null);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void testFilteredSearchThreshold_byteThresholdInInput_shouldSetCustomThreshold()
throws Exception {
try {
Integer expectedThreshold = 30;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector_byte_encoding");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
KnnByteVectorQuery vectorQuery =
(KnnByteVectorQuery)
type.getKnnVectorQuery(
"vector_byte_encoding", "[2, 1, 3, 4]", 3, null, null, null, expectedThreshold);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void testFilteredSearchThreshold_seededByteThresholdInInput_shouldSetCustomThreshold()
throws Exception {
try {
Query seedQuery = new BooleanQuery.Builder().build();
Integer expectedThreshold = 30;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector_byte_encoding");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
SeededKnnVectorQuery vectorQuery =
(SeededKnnVectorQuery)
type.getKnnVectorQuery(
"vector_byte_encoding",
"[2, 1, 3, 4]",
3,
null,
seedQuery,
null,
expectedThreshold);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void
testFilteredSearchThreshold_earlyTerminationByteThresholdInInput_shouldSetCustomThreshold()
throws Exception {
try {
KnnQParser.EarlyTerminationParams earlyTermination =
new KnnQParser.EarlyTerminationParams(true, 0.995, 7);
Integer expectedThreshold = 30;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector_byte_encoding");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
PatienceKnnVectorQuery vectorQuery =
(PatienceKnnVectorQuery)
type.getKnnVectorQuery(
"vector_byte_encoding",
"[2, 1, 3, 4]",
3,
null,
null,
earlyTermination,
expectedThreshold);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}

@Test
public void
testFilteredSearchThreshold_seededAndEarlyTerminationByteThresholdInInput_shouldSetCustomThreshold()
throws Exception {
try {
Query seedQuery = new BooleanQuery.Builder().build();
KnnQParser.EarlyTerminationParams earlyTermination =
new KnnQParser.EarlyTerminationParams(true, 0.995, 7);
Integer expectedThreshold = 30;

initCore("solrconfig-basic.xml", "schema-densevector.xml");
IndexSchema schema = h.getCore().getLatestSchema();
SchemaField vectorField = schema.getField("vector_byte_encoding");
assertNotNull(vectorField);
DenseVectorField type = (DenseVectorField) vectorField.getType();
PatienceKnnVectorQuery vectorQuery =
(PatienceKnnVectorQuery)
type.getKnnVectorQuery(
"vector_byte_encoding",
"[2, 1, 3, 4]",
3,
null,
seedQuery,
earlyTermination,
expectedThreshold);
KnnSearchStrategy.Hnsw strategy = (KnnSearchStrategy.Hnsw) vectorQuery.getSearchStrategy();
Integer threshold = strategy.filteredSearchThreshold();

assertEquals(expectedThreshold, threshold);
} finally {
deleteCore();
}
}
}
Loading
Loading