diff --git a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java index 773c1e6337d1..ea134f67e955 100644 --- a/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java +++ b/solr/core/src/java/org/apache/solr/schema/DenseVectorField.java @@ -367,6 +367,11 @@ public ValueSource getValueSource(SchemaField field, QParser parser) { public Query getKnnVectorQuery( String fieldName, String vectorToSearch, int topK, Query filterQuery) { + return getKnnVectorQuery(fieldName, vectorToSearch, topK, filterQuery, null); + } + + public Query getKnnVectorQuery( + String fieldName, String vectorToSearch, int topK, Query filterQuery, Query seedQuery) { DenseVectorParser vectorBuilder = getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY); @@ -374,9 +379,10 @@ public Query getKnnVectorQuery( switch (vectorEncoding) { case FLOAT32: return new KnnFloatVectorQuery( - fieldName, vectorBuilder.getFloatVector(), topK, filterQuery); + fieldName, vectorBuilder.getFloatVector(), topK, filterQuery /*, seedQuery */); case BYTE: - return new KnnByteVectorQuery(fieldName, vectorBuilder.getByteVector(), topK, filterQuery); + return new KnnByteVectorQuery( + fieldName, vectorBuilder.getByteVector(), topK, filterQuery /*, seedQuery */); default: throw new SolrException( SolrException.ErrorCode.SERVER_ERROR, diff --git a/solr/core/src/java/org/apache/solr/search/neural/AbstractVectorQParserBase.java b/solr/core/src/java/org/apache/solr/search/neural/AbstractVectorQParserBase.java index 4cafb45744e1..b57b5f5fd2ec 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/AbstractVectorQParserBase.java +++ b/solr/core/src/java/org/apache/solr/search/neural/AbstractVectorQParserBase.java @@ -76,6 +76,10 @@ protected static DenseVectorField getCheckedFieldType(SchemaField schemaField) { return (DenseVectorField) fieldType; } + protected Query getSeedQuery() throws SolrException, SyntaxError { + return null; // TODO + } + protected Query getFilterQuery() throws SolrException, SyntaxError { // Default behavior of FQ wrapping, and suitability of some local params diff --git a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java index b6d9f2541cd0..97580a083eb3 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/KnnQParser.java @@ -41,6 +41,6 @@ public Query parse() throws SyntaxError { final int topK = localParams.getInt(TOP_K, DEFAULT_TOP_K); return denseVectorType.getKnnVectorQuery( - schemaField.getName(), vectorToSearch, topK, getFilterQuery()); + schemaField.getName(), vectorToSearch, topK, getFilterQuery(), getSeedQuery()); } } diff --git a/solr/core/src/java/org/apache/solr/search/neural/VectorSimilarityQParser.java b/solr/core/src/java/org/apache/solr/search/neural/VectorSimilarityQParser.java index e3ec2f242f76..074aaa17a4df 100644 --- a/solr/core/src/java/org/apache/solr/search/neural/VectorSimilarityQParser.java +++ b/solr/core/src/java/org/apache/solr/search/neural/VectorSimilarityQParser.java @@ -62,10 +62,18 @@ public Query parse() throws SyntaxError { switch (vectorEncoding) { case FLOAT32: return new FloatVectorSimilarityQuery( - fieldName, vectorBuilder.getFloatVector(), minTraverse, minReturn, getFilterQuery()); + fieldName, + vectorBuilder.getFloatVector(), + minTraverse, + minReturn, + getFilterQuery() /*, getSeedQuery() */); case BYTE: return new ByteVectorSimilarityQuery( - fieldName, vectorBuilder.getByteVector(), minTraverse, minReturn, getFilterQuery()); + fieldName, + vectorBuilder.getByteVector(), + minTraverse, + minReturn, + getFilterQuery() /*, getSeedQuery() */); default: throw new SolrException( SolrException.ErrorCode.SERVER_ERROR,