From 48dbcfafd440d76982de58a6b36aa1cea01ef737 Mon Sep 17 00:00:00 2001 From: Christine Poerschke Date: Fri, 23 Aug 2024 18:31:55 +0100 Subject: [PATCH] support Lucene's (proposed) HNSW search seeding feature --- .../org/apache/solr/schema/DenseVectorField.java | 10 ++++++++-- .../search/neural/AbstractVectorQParserBase.java | 4 ++++ .../org/apache/solr/search/neural/KnnQParser.java | 2 +- .../solr/search/neural/VectorSimilarityQParser.java | 12 ++++++++++-- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 4d528361dd44..7cdc6b75221f 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -365,6 +365,11 @@ public ValueSource getValueSource(SchemaField field, QParser parser) { public Query getKnnVectorQuery( String fieldName, String vectorToSearch, int topK, Query filterQuery) { + return getKnnVectorQuery(fieldName, vectorToSearch, topK, filterQuery, null); + } + + public Query getKnnVectorQuery( + String fieldName, String vectorToSearch, int topK, Query filterQuery, Query seedQuery) { DenseVectorParser vectorBuilder = getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); @@ -372,9 +377,10 @@ public Query getKnnVectorQuery( switch (vectorEncoding) { case FLOAT32: return new KnnFloatVectorQuery( - fieldName, vectorBuilder.getFloatVector(), topK, filterQuery); + fieldName, vectorBuilder.getFloatVector(), topK, filterQuery /*, seedQuery */); case BYTE: - return new KnnByteVectorQuery(fieldName, vectorBuilder.getByteVector(), topK, filterQuery); + return new KnnByteVectorQuery( + fieldName, vectorBuilder.getByteVector(), topK, filterQuery /*, seedQuery */); default: throw new SolrException( SolrException.ErrorCode.SERVER_ERROR, diff --git a/solr/core/src/java/org/apache/solr/search/neural/AbstractVectorQParserBase.java b/solr/core/src/java/org/apache/solr/search/neural/AbstractVectorQParserBase.java index 4cafb45744e1..b57b5f5fd2ec 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/AbstractVectorQParserBase.java +++ b/solr/core/src/java/org/apache/solr/search/neural/AbstractVectorQParserBase.java @@ -76,6 +76,10 @@ protected static DenseVectorField getCheckedFieldType(SchemaField schemaField) { return (DenseVectorField) fieldType; } + protected Query getSeedQuery() throws SolrException, SyntaxError { + return null; // TODO + } + protected Query getFilterQuery() throws SolrException, SyntaxError { // Default behavior of FQ wrapping, and suitability of some local params diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index 166dada5b7f0..071f93753bb9 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -41,6 +41,6 @@ public Query parse() throws SyntaxError { final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K); return denseVectorType.getKnnVectorQuery( - schemaField.getName(), vectorToSearch, topK, getFilterQuery()); + schemaField.getName(), vectorToSearch, topK, getFilterQuery(), getSeedQuery()); } } diff --git a/solr/core/src/java/org/apache/solr/search/neural/VectorSimilarityQParser.java b/solr/core/src/java/org/apache/solr/search/neural/VectorSimilarityQParser.java index e3ec2f242f76..074aaa17a4df 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/VectorSimilarityQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/VectorSimilarityQParser.java @@ -62,10 +62,18 @@ public Query parse() throws SyntaxError { switch (vectorEncoding) { case FLOAT32: return new FloatVectorSimilarityQuery( - fieldName, vectorBuilder.getFloatVector(), minTraverse, minReturn, getFilterQuery()); + fieldName, + vectorBuilder.getFloatVector(), + minTraverse, + minReturn, + getFilterQuery() /*, getSeedQuery() */); case BYTE: return new ByteVectorSimilarityQuery( - fieldName, vectorBuilder.getByteVector(), minTraverse, minReturn, getFilterQuery()); + fieldName, + vectorBuilder.getByteVector(), + minTraverse, + minReturn, + getFilterQuery() /*, getSeedQuery() */); default: throw new SolrException( SolrException.ErrorCode.SERVER_ERROR,