From 7ec22e3ec2b8e47ff3f193f8acd507bf90572544 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 14 May 2025 11:56:58 +0200 Subject: [PATCH 01/54] Fixed default value for FieldValueFeature --- .../java/org/apache/solr/ltr/feature/FieldValueFeature.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java index 7a41916f29a7..bc10e6622a86 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java @@ -125,6 +125,9 @@ public FieldValueFeatureWeight( super(FieldValueFeature.this, searcher, request, originalQuery, efi); if (searcher instanceof SolrIndexSearcher) { schemaField = ((SolrIndexSearcher) searcher).getSchema().getFieldOrNull(field); + if (schemaField.getDefaultValue() != null) { + setDefaultValue(schemaField.getDefaultValue()); + } } else { // some tests pass a null or a non-SolrIndexSearcher searcher schemaField = null; } From 70ecde1320c878bcdc7c81e480ff290b0654d13c Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 14 May 2025 14:51:18 +0200 Subject: [PATCH 02/54] Fixed remaining tests --- .../java/org/apache/solr/ltr/feature/FieldValueFeature.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java index bc10e6622a86..a6583d3a0721 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java @@ -125,8 +125,10 @@ public FieldValueFeatureWeight( super(FieldValueFeature.this, searcher, request, originalQuery, efi); if (searcher instanceof SolrIndexSearcher) { schemaField = ((SolrIndexSearcher) searcher).getSchema().getFieldOrNull(field); - if (schemaField.getDefaultValue() != null) { - setDefaultValue(schemaField.getDefaultValue()); + if (schemaField != null) { + if (schemaField.getDefaultValue() != null) { + setDefaultValue(schemaField.getDefaultValue()); + } } } else { // some tests pass a null or a non-SolrIndexSearcher searcher schemaField = null; From 07865770b43ebf6d751ee402953fe1b7402d5ca7 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Thu, 15 May 2025 14:34:54 +0200 Subject: [PATCH 03/54] Removed default from schema --- .../java/org/apache/solr/ltr/feature/FieldValueFeature.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java index a6583d3a0721..7a41916f29a7 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java @@ -125,11 +125,6 @@ public FieldValueFeatureWeight( super(FieldValueFeature.this, searcher, request, originalQuery, efi); if (searcher instanceof SolrIndexSearcher) { schemaField = ((SolrIndexSearcher) searcher).getSchema().getFieldOrNull(field); - if (schemaField != null) { - if (schemaField.getDefaultValue() != null) { - setDefaultValue(schemaField.getDefaultValue()); - } - } } else { // some tests pass a null or a non-SolrIndexSearcher searcher schemaField = null; } From 193908dd13a42081f2ca10934c6d96759b207c55 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Tue, 10 Jun 2025 18:36:51 +0200 Subject: [PATCH 04/54] Added rerankingFeatureVectorCache and loggingFeatureVectorCache --- .../java/org/apache/solr/core/SolrConfig.java | 16 +++++++++- .../apache/solr/search/SolrIndexSearcher.java | 32 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/solr/core/src/java/org/apache/solr/core/SolrConfig.java b/solr/core/src/java/org/apache/solr/core/SolrConfig.java index c9482838d4ac..b37fff82bbf3 100644 --- a/solr/core/src/java/org/apache/solr/core/SolrConfig.java +++ b/solr/core/src/java/org/apache/solr/core/SolrConfig.java @@ -301,6 +301,12 @@ private SolrConfig(SolrResourceLoader loader, String name, Properties substituta queryResultCacheConfig = CacheConfig.getConfig( this, get("query").get("queryResultCache"), "query/queryResultCache"); + rerankingFeatureVectorCacheConfig = + CacheConfig.getConfig( + this, get("query").get("rerankingFeatureVectorCache"), "query/rerankingFeatureVectorCache"); + loggingFeatureVectorCacheConfig = + CacheConfig.getConfig( + this, get("query").get("loggingFeatureVectorCache"), "query/loggingFeatureVectorCache"); documentCacheConfig = CacheConfig.getConfig(this, get("query").get("documentCache"), "query/documentCache"); CacheConfig conf = @@ -662,6 +668,8 @@ public SolrRequestParsers getRequestParsers() { public final CacheConfig queryResultCacheConfig; public final CacheConfig documentCacheConfig; public final CacheConfig fieldValueCacheConfig; + public final CacheConfig rerankingFeatureVectorCacheConfig; + public final CacheConfig loggingFeatureVectorCacheConfig; public final Map userCacheConfigs; // SolrIndexSearcher - more... public final boolean useFilterForSortedQuery; @@ -998,7 +1006,13 @@ public Map toMap(Map result) { } addCacheConfig( - m, filterCacheConfig, queryResultCacheConfig, documentCacheConfig, fieldValueCacheConfig); + m, + filterCacheConfig, + queryResultCacheConfig, + documentCacheConfig, + fieldValueCacheConfig, + rerankingFeatureVectorCacheConfig, + loggingFeatureVectorCacheConfig); m = new LinkedHashMap<>(); result.put("requestDispatcher", m); if (httpCachingConfig != null) m.put("httpCaching", httpCachingConfig); diff --git a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java index 5498d9202a83..3470867fe0fb 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java +++ b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java @@ -165,6 +165,8 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI private final SolrCache filterCache; private final SolrCache queryResultCache; private final SolrCache fieldValueCache; + private final SolrCache rerankingFeatureVectorCache; + private final SolrCache loggingFeatureVectorCache; private final LongAdder fullSortCount = new LongAdder(); private final LongAdder skipSortCount = new LongAdder(); private final LongAdder liveDocsNaiveCacheHitCount = new LongAdder(); @@ -449,6 +451,16 @@ public SolrIndexSearcher( ? null : solrConfig.queryResultCacheConfig.newInstance(); if (queryResultCache != null) clist.add(queryResultCache); + rerankingFeatureVectorCache = + solrConfig.rerankingFeatureVectorCacheConfig == null + ? null + : solrConfig.rerankingFeatureVectorCacheConfig.newInstance(); + if (rerankingFeatureVectorCache != null) clist.add(rerankingFeatureVectorCache); + loggingFeatureVectorCache = + solrConfig.loggingFeatureVectorCacheConfig == null + ? null + : solrConfig.loggingFeatureVectorCacheConfig.newInstance(); + if (loggingFeatureVectorCache != null) clist.add(loggingFeatureVectorCache); SolrCache documentCache = docFetcher.getDocumentCache(); if (documentCache != null) clist.add(documentCache); @@ -470,6 +482,8 @@ public SolrIndexSearcher( this.filterCache = null; this.queryResultCache = null; this.fieldValueCache = null; + this.rerankingFeatureVectorCache = null; + this.loggingFeatureVectorCache = null; this.cacheMap = NO_GENERIC_CACHES; this.cacheList = NO_CACHES; } @@ -685,6 +699,14 @@ public SolrCache getFilterCache() { return filterCache; } + public SolrCache getRerankingFeatureVectorCache() { + return rerankingFeatureVectorCache; + } + + public SolrCache getLoggingFeatureVectorCache() { + return loggingFeatureVectorCache; + } + // // Set default regenerators on filter and query caches if they don't have any // @@ -727,6 +749,16 @@ public boolean regenerateItem( }); } + if (solrConfig.rerankingFeatureVectorCacheConfig != null + && solrConfig.rerankingFeatureVectorCacheConfig.getRegenerator() == null) { + solrConfig.rerankingFeatureVectorCacheConfig.setRegenerator(new NoOpRegenerator()); + } + + if (solrConfig.loggingFeatureVectorCacheConfig != null + && solrConfig.loggingFeatureVectorCacheConfig.getRegenerator() == null) { + solrConfig.loggingFeatureVectorCacheConfig.setRegenerator(new NoOpRegenerator()); + } + if (solrConfig.queryResultCacheConfig != null && solrConfig.queryResultCacheConfig.getRegenerator() == null) { final int queryResultWindowSize = solrConfig.queryResultWindowSize; From f4e7ba7152311c5f6dcfda6af25dc9c2a4ff105b Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 14 May 2025 11:56:58 +0200 Subject: [PATCH 05/54] Fixed default value for FieldValueFeature --- .../java/org/apache/solr/ltr/feature/FieldValueFeature.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java index 7a41916f29a7..bc10e6622a86 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java @@ -125,6 +125,9 @@ public FieldValueFeatureWeight( super(FieldValueFeature.this, searcher, request, originalQuery, efi); if (searcher instanceof SolrIndexSearcher) { schemaField = ((SolrIndexSearcher) searcher).getSchema().getFieldOrNull(field); + if (schemaField.getDefaultValue() != null) { + setDefaultValue(schemaField.getDefaultValue()); + } } else { // some tests pass a null or a non-SolrIndexSearcher searcher schemaField = null; } From 8330a9d83a8601d20691fccae420e52640f3e5b4 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 14 May 2025 14:51:18 +0200 Subject: [PATCH 06/54] Fixed remaining tests --- .../java/org/apache/solr/ltr/feature/FieldValueFeature.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java index bc10e6622a86..a6583d3a0721 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java @@ -125,8 +125,10 @@ public FieldValueFeatureWeight( super(FieldValueFeature.this, searcher, request, originalQuery, efi); if (searcher instanceof SolrIndexSearcher) { schemaField = ((SolrIndexSearcher) searcher).getSchema().getFieldOrNull(field); - if (schemaField.getDefaultValue() != null) { - setDefaultValue(schemaField.getDefaultValue()); + if (schemaField != null) { + if (schemaField.getDefaultValue() != null) { + setDefaultValue(schemaField.getDefaultValue()); + } } } else { // some tests pass a null or a non-SolrIndexSearcher searcher schemaField = null; From 3ac57b2d30ee4b627255b168983efbaf8db5dcb9 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Thu, 15 May 2025 14:34:54 +0200 Subject: [PATCH 07/54] Removed default from schema --- .../java/org/apache/solr/ltr/feature/FieldValueFeature.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java index a6583d3a0721..7a41916f29a7 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java @@ -125,11 +125,6 @@ public FieldValueFeatureWeight( super(FieldValueFeature.this, searcher, request, originalQuery, efi); if (searcher instanceof SolrIndexSearcher) { schemaField = ((SolrIndexSearcher) searcher).getSchema().getFieldOrNull(field); - if (schemaField != null) { - if (schemaField.getDefaultValue() != null) { - setDefaultValue(schemaField.getDefaultValue()); - } - } } else { // some tests pass a null or a non-SolrIndexSearcher searcher schemaField = null; } From 996ecdb31b47ece5adfe2bafa8ddfcdc43de319d Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 25 Jun 2025 13:24:18 +0200 Subject: [PATCH 08/54] Implemented first cache lookup without ltr --- .../org/apache/solr/ltr/FeatureLogger.java | 7 ++++++- .../org/apache/solr/ltr/LTRScoringQuery.java | 21 +++++++------------ .../LTRFeatureLoggerTransformerFactory.java | 3 ++- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java index 9be531c1ef32..5daf786a70e2 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java @@ -80,7 +80,12 @@ private static int fvCacheKey(LTRScoringQuery scoringQuery, int docid) { * @return String representation of the list of features calculated for docid */ public String getFeatureVector( - int docid, LTRScoringQuery scoringQuery, SolrIndexSearcher searcher) { + int docid, LTRScoringQuery scoringQuery, SolrIndexSearcher searcher, LTRScoringQuery.ModelWeight modelWeights) { + // CHANGE CACHE KEY + if (searcher.cacheLookup(fvCacheName, fvCacheKey(scoringQuery, docid)) == null) { + final String featureVector = makeFeatureVector(modelWeights.getFeaturesInfo()); + searcher.cacheInsert(fvCacheName, fvCacheKey(scoringQuery, docid), featureVector); + } return (String) searcher.cacheLookup(fvCacheName, fvCacheKey(scoringQuery, docid)); } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 85b33fc3ebdb..cc6de9a8870d 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -18,12 +18,7 @@ import java.io.IOException; import java.lang.invoke.MethodHandles; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.Callable; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; @@ -133,13 +128,13 @@ public SolrQueryRequest getRequest() { @Override public int hashCode() { final int prime = 31; - int result = classHash(); - result = (prime * result) + ((ltrScoringModel == null) ? 0 : ltrScoringModel.hashCode()); - result = (prime * result) + ((originalQuery == null) ? 0 : originalQuery.hashCode()); - if (efi == null) { - result = (prime * result) + 0; - } else { - for (final Map.Entry entry : efi.entrySet()) { + int result = ltrScoringModel.getFeatureStoreName().hashCode(); + result = (prime * result) + (ltrScoringModel.getName().hashCode()); + result = (prime * result) + (this.getFeatureLogger().logAll.hashCode()); + result = (prime * result) + (this.getFeatureLogger().featureFormat.hashCode()); + if (efi != null) { + TreeMap sorted = new TreeMap<>(efi); + for (final Map.Entry entry : sorted.entrySet()) { final String key = entry.getKey(); final String[] values = entry.getValue(); result = (prime * result) + key.hashCode(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 21ead4756091..609c99a915e2 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -423,7 +423,8 @@ private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo } } if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { - Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher); + // WHEN COULD WE HAVE MULTIPLE MODEL WEIGHTS? + Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher, modelWeights[0]); if (featureVector == null) { // FV for this document was not in the cache featureVector = featureLogger.makeFeatureVector( From c14d4eea41ebb6fe1d491183f99e1f279cc526c8 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Fri, 27 Jun 2025 11:50:52 +0200 Subject: [PATCH 09/54] Starting integrating cache for logging --- .../org/apache/solr/ltr/CSVFeatureLogger.java | 11 +- .../org/apache/solr/ltr/FeatureLogger.java | 57 +---- .../java/org/apache/solr/ltr/LTRRescorer.java | 51 +---- .../org/apache/solr/ltr/LTRScoringQuery.java | 198 +++++++++++------- .../interleaving/LTRInterleavingRescorer.java | 8 +- .../LTRFeatureLoggerTransformerFactory.java | 30 +-- .../solr/collection1/conf/solrconfig-ltr.xml | 1 - .../solr/ltr/TestSelectiveWeightCreation.java | 2 +- 8 files changed, 152 insertions(+), 206 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java index 22ddcb8724a2..57a86a10e1c8 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java @@ -23,21 +23,20 @@ public class CSVFeatureLogger extends FeatureLogger { private final char keyValueSep; private final char featureSep; - public CSVFeatureLogger(String fvCacheName, FeatureFormat f, Boolean logAll) { - super(fvCacheName, f, logAll); + public CSVFeatureLogger(FeatureFormat f, Boolean logAll) { + super(f, logAll); this.keyValueSep = DEFAULT_KEY_VALUE_SEPARATOR; this.featureSep = DEFAULT_FEATURE_SEPARATOR; } - public CSVFeatureLogger( - String fvCacheName, FeatureFormat f, Boolean logAll, char keyValueSep, char featureSep) { - super(fvCacheName, f, logAll); + public CSVFeatureLogger(FeatureFormat f, Boolean logAll, char keyValueSep, char featureSep) { + super(f, logAll); this.keyValueSep = keyValueSep; this.featureSep = featureSep; } @Override - public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) { + public String printFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) { // Allocate the buffer to a size based on the number of features instead of the // default 16. You need space for the name, value, and two separators per feature, // but not all the features are expected to fire, so this is just a naive estimate. diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java index 5daf786a70e2..16c554df0f3f 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java @@ -16,16 +16,10 @@ */ package org.apache.solr.ltr; -import org.apache.solr.search.SolrIndexSearcher; - /** * FeatureLogger can be registered in a model and provide a strategy for logging the feature values. */ public abstract class FeatureLogger { - - /** the name of the cache using for storing the feature value */ - private final String fvCacheName; - public enum FeatureFormat { DENSE, SPARSE @@ -35,59 +29,12 @@ public enum FeatureFormat { protected Boolean logAll; - protected FeatureLogger(String fvCacheName, FeatureFormat f, Boolean logAll) { - this.fvCacheName = fvCacheName; + protected FeatureLogger(FeatureFormat f, Boolean logAll) { this.featureFormat = f; this.logAll = logAll; } - /** - * Log will be called every time that the model generates the feature values for a document and a - * query. - * - * @param docid Solr document id whose features we are saving - * @param featuresInfo List of all the {@link LTRScoringQuery.FeatureInfo} objects which contain - * name and value for all the features triggered by the result set - * @return true if the logger successfully logged the features, false otherwise. - */ - public boolean log( - int docid, - LTRScoringQuery scoringQuery, - SolrIndexSearcher searcher, - LTRScoringQuery.FeatureInfo[] featuresInfo) { - final String featureVector = makeFeatureVector(featuresInfo); - if (featureVector == null) { - return false; - } - - if (null == searcher.cacheInsert(fvCacheName, fvCacheKey(scoringQuery, docid), featureVector)) { - return false; - } - - return true; - } - - public abstract String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo); - - private static int fvCacheKey(LTRScoringQuery scoringQuery, int docid) { - return scoringQuery.hashCode() + (31 * docid); - } - - /** - * populate the document with its feature vector - * - * @param docid Solr document id - * @return String representation of the list of features calculated for docid - */ - public String getFeatureVector( - int docid, LTRScoringQuery scoringQuery, SolrIndexSearcher searcher, LTRScoringQuery.ModelWeight modelWeights) { - // CHANGE CACHE KEY - if (searcher.cacheLookup(fvCacheName, fvCacheKey(scoringQuery, docid)) == null) { - final String featureVector = makeFeatureVector(modelWeights.getFeaturesInfo()); - searcher.cacheInsert(fvCacheName, fvCacheKey(scoringQuery, docid), featureVector); - } - return (String) searcher.cacheLookup(fvCacheName, fvCacheKey(scoringQuery, docid)); - } + public abstract String printFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo); public Boolean isLoggingAll() { return logAll; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index a21c107438ca..de1a1668c0ff 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -33,7 +33,7 @@ import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery; import org.apache.solr.search.IncompleteRerankingException; import org.apache.solr.search.QueryLimits; -import org.apache.solr.search.SolrIndexSearcher; +import org.apache.solr.search.SolrCache; /** * Implements the rescoring logic. The top documents returned by solr with their original scores, @@ -138,7 +138,7 @@ private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPass (LTRScoringQuery.ModelWeight) searcher.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1); - scoreFeatures(searcher, topN, modelWeight, firstPassResults, leaves, reranked); + scoreFeatures(topN, modelWeight, firstPassResults, leaves, reranked); // Must sort all documents that we reranked, and then select the top Arrays.sort(reranked, scoreComparator); return reranked; @@ -153,7 +153,6 @@ protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) { } public void scoreFeatures( - IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, @@ -182,36 +181,13 @@ public void scoreFeatures( docBase = readerContext.docBase; scorer = modelWeight.modelScorer(readerContext); } - if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) { - logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery); - } + scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked); hitUpto++; } } - /** - * Call this method if the {@link #scoreSingleHit(int, int, int, ScoreDoc, int, - * org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer, ScoreDoc[])} method indicated that - * the document's feature info should be logged. - */ - protected static void logSingleHit( - IndexSearcher indexSearcher, - LTRScoringQuery.ModelWeight modelWeight, - int docid, - LTRScoringQuery scoringQuery) { - final FeatureLogger featureLogger = scoringQuery.getFeatureLogger(); - if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { - featureLogger.log( - docid, scoringQuery, (SolrIndexSearcher) indexSearcher, modelWeight.getFeaturesInfo()); - } - } - - /** - * Scores a single document and returns true if the document's feature info should be logged via - * the {@link #logSingleHit(IndexSearcher, org.apache.solr.ltr.LTRScoringQuery.ModelWeight, int, - * LTRScoringQuery)} method. Feature info logging is only necessary for the topN documents. - */ - protected static boolean scoreSingleHit( + /** Scores a single document. */ + protected static void scoreSingleHit( int topN, int docBase, int hitUpto, @@ -232,8 +208,6 @@ protected static boolean scoreSingleHit( scorer.docID(); scorer.iterator().advance(targetDoc); - boolean logHit = false; - scorer.getDocInfo().setOriginalDocScore(hit.score); hit.score = scorer.score(); if (QueryLimits.getCurrentLimits() @@ -245,26 +219,19 @@ protected static boolean scoreSingleHit( } if (hitUpto < topN) { reranked[hitUpto] = hit; - // if the heap is not full, maybe I want to log the features for this - // document - logHit = true; } else if (hitUpto == topN) { // collected topN document, I create the heap heapify(reranked, topN); } if (hitUpto >= topN) { - // once that heap is ready, if the score of this document is lower that - // the minimum - // i don't want to log the feature. Otherwise I replace it with the - // minimum and fix the - // heap. + // once that heap is ready, if the score of this document is greater that + // the minimum I replace it with the + // minimum and fix the heap. if (hit.score > reranked[0].score) { reranked[0] = hit; heapAdjust(reranked, topN, 0); - logHit = true; } } - return logHit; } @Override @@ -291,6 +258,7 @@ protected static Explanation getExplanation( } public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo( + SolrCache loggingCache, LTRScoringQuery.ModelWeight modelWeight, int docid, Float originalDocScore, @@ -308,6 +276,7 @@ public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo( // score, which some features can use instead of recalculating it r.getDocInfo().setOriginalDocScore(originalDocScore); } + r.setCache(loggingCache); r.score(); return modelWeight.getFeaturesInfo(); } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index cc6de9a8870d..ca7ad4ae63d6 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -42,6 +42,7 @@ import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.util.SolrDefaultScorerSupplier; +import org.apache.solr.search.SolrCache; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -129,9 +130,16 @@ public SolrQueryRequest getRequest() { public int hashCode() { final int prime = 31; int result = ltrScoringModel.getFeatureStoreName().hashCode(); - result = (prime * result) + (ltrScoringModel.getName().hashCode()); - result = (prime * result) + (this.getFeatureLogger().logAll.hashCode()); + if (!this.getFeatureLogger().logAll) { + result = (prime * result) + (ltrScoringModel.getName().hashCode()); + } result = (prime * result) + (this.getFeatureLogger().featureFormat.hashCode()); + result = addEfisHash(result, prime); + result = (prime * result) + this.toString().hashCode(); + return result; + } + + private int addEfisHash(int result, int prime) { if (efi != null) { TreeMap sorted = new TreeMap<>(efi); for (final Map.Entry entry : sorted.entrySet()) { @@ -141,7 +149,6 @@ public int hashCode() { result = (prime * result) + Arrays.hashCode(values); } } - result = (prime * result) + this.toString().hashCode(); return result; } @@ -364,17 +371,7 @@ public class ModelWeight extends Weight { private final float[] modelFeatureValuesNormalized; private final Feature.FeatureWeight[] extractedFeatureWeights; - // List of all the feature names, values - used for both scoring and logging - /* - * What is the advantage of using a hashmap here instead of an array of objects? - * A set of arrays was used earlier and the elements were accessed using the featureId. - * With the updated logic to create weights selectively, - * the number of elements in the array can be fewer than the total number of features. - * When [features] are not requested, only the model features are extracted. - * In this case, the indexing by featureId, fails. For this reason, - * we need a map which holds just the features that were triggered by the documents in the result set. - * - */ + // All the features private final FeatureInfo[] featuresInfo; /* @@ -495,7 +492,7 @@ public ModelScorer modelScorer(LeafReaderContext context) throws IOException { // score on the model for every document, since 0 features matching could // return a // non 0 score for a given model. - ModelScorer mscorer = new ModelScorer(this, featureScorers); + ModelScorer mscorer = new ModelScorer(this, featureScorers, context); return mscorer; } @@ -507,24 +504,30 @@ public boolean isCacheable(LeafReaderContext ctx) { public class ModelScorer extends Scorer { private final DocInfo docInfo; private final Scorer featureTraversalScorer; + private SolrCache featureVectorCache; public DocInfo getDocInfo() { return docInfo; } - public ModelScorer(Weight weight, List featureScorers) { + public ModelScorer(Weight weight, List featureScorers, LeafReaderContext leafContext) { + featureVectorCache = null; docInfo = new DocInfo(); for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) { subScorer.setDocInfo(docInfo); } if (featureScorers.size() <= 1) { // future enhancement: allow the use of dense features in other cases - featureTraversalScorer = new DenseModelScorer(weight, featureScorers); + featureTraversalScorer = new DenseModelScorer(weight, featureScorers, leafContext); } else { - featureTraversalScorer = new SparseModelScorer(weight, featureScorers); + featureTraversalScorer = new SparseModelScorer(weight, featureScorers, leafContext); } } + public void setCache(SolrCache cacheToUse) { + this.featureVectorCache = cacheToUse; + } + @Override public Collection getChildren() throws IOException { return featureTraversalScorer.getChildren(); @@ -550,16 +553,85 @@ public DocIdSetIterator iterator() { return featureTraversalScorer.iterator(); } - private class SparseModelScorer extends Scorer { + private LTRScoringQuery getScoringQuery() { + return LTRScoringQuery.this; + } + + abstract class FeatureTraversalScorer extends Scorer { + protected int targetDoc = -1; + protected int activeDoc = -1; + protected LeafReaderContext leafContext; + + protected FeatureTraversalScorer(Weight weight, LeafReaderContext leafContext) { + this.leafContext = leafContext; + } + + @Override + public float score() throws IOException { + reset(); + fillFeaturesInfo(); + return makeNormalizedFeaturesAndScore(); + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return Float.POSITIVE_INFINITY; + } + + private void fillFeaturesInfo() throws IOException { + if (activeDoc == targetDoc) { + SolrCache featureVectorCache = null; + float[] featureVector; + + if (featureVectorCache != null) { + int docId = activeDoc + leafContext.docBase; + int fvCacheKey = fvCacheKey(getScoringQuery(), docId); + featureVector = featureVectorCache.get(fvCacheKey); + if (featureVector == null) { + featureVector = extractFeatureVector(); + featureVectorCache.put(fvCacheKey, featureVector); + } + } else { + featureVector = extractFeatureVector(); + } + + for (int i = 0; i < extractedFeatureWeights.length; i++) { + int featureId = extractedFeatureWeights[i].getIndex(); + float featureValue = featureVector[featureId]; + if (!Float.isNaN(featureValue) + && featureValue != extractedFeatureWeights[i].getDefaultValue()) { + featuresInfo[featureId].setValue(featureValue); + featuresInfo[featureId].setIsDefaultValue(false); + } + } + } + } + + private int fvCacheKey(LTRScoringQuery scoringQuery, int docId) { + return (31 * scoringQuery.hashCode()) + docId; + } + + protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { + float[] featureVector = new float[featuresInfo.length]; + for (int i = 0; i < featuresInfos.length; i++) { + if (featuresInfos[i] != null) { + featureVector[i] = featuresInfos[i].getValue(); + } + } + return featureVector; + } + + protected abstract float[] extractFeatureVector() throws IOException; + } + + private class SparseModelScorer extends FeatureTraversalScorer { private final DisiPriorityQueue subScorers; private final List wrappers; - private final ScoringQuerySparseIterator itr; - - private int targetDoc = -1; - private int activeDoc = -1; + private final ScoringQuerySparseIterator sparseIterator; private SparseModelScorer( - Weight unusedWeight, List featureScorers) { + Weight unusedWeight, List featureScorers, LeafReaderContext leafContext) { + super(unusedWeight, leafContext); if (featureScorers.size() <= 1) { throw new IllegalArgumentException("There must be at least 2 subScorers"); } @@ -571,46 +643,31 @@ private SparseModelScorer( wrappers.add(w); } - itr = new ScoringQuerySparseIterator(wrappers); + sparseIterator = new ScoringQuerySparseIterator(wrappers); } @Override public int docID() { - return itr.docID(); + return sparseIterator.docID(); } - @Override - public float score() throws IOException { + protected float[] extractFeatureVector() throws IOException { final DisiWrapper topList = subScorers.topList(); - // If target doc we wanted to advance to match the actual doc - // the underlying features advanced to, perform the feature - // calculations, - // otherwise just continue with the model's scoring process with empty - // features. - reset(); - if (activeDoc == targetDoc) { - for (DisiWrapper w = topList; w != null; w = w.next) { - final Feature.FeatureWeight.FeatureScorer subScorer = - (Feature.FeatureWeight.FeatureScorer) w.scorer; - Feature.FeatureWeight scFW = subScorer.getWeight(); - final int featureId = scFW.getIndex(); - featuresInfo[featureId].setValue(subScorer.score()); - if (featuresInfo[featureId].getValue() != scFW.getDefaultValue()) { - featuresInfo[featureId].setIsDefaultValue(false); - } - } + float[] featureVector = initFeatureVector(featuresInfo); + for (DisiWrapper w = topList; w != null; w = w.next) { + final Feature.FeatureWeight.FeatureScorer subScorer = + (Feature.FeatureWeight.FeatureScorer) w.scorer; + Feature.FeatureWeight scFW = subScorer.getWeight(); + final int featureId = scFW.getIndex(); + float featureValue = subScorer.score(); + featureVector[featureId] = featureValue; } - return makeNormalizedFeaturesAndScore(); - } - - @Override - public float getMaxScore(int upTo) throws IOException { - return Float.POSITIVE_INFINITY; + return featureVector; } @Override public DocIdSetIterator iterator() { - return itr; + return sparseIterator; } @Override @@ -717,14 +774,12 @@ public long cost() { } } - private class DenseModelScorer extends Scorer { - private int activeDoc = -1; // The doc that our scorer's are actually at - private int targetDoc = -1; // The doc we were most recently told to go to - private int freq = -1; + private class DenseModelScorer extends FeatureTraversalScorer { private final List featureScorers; private DenseModelScorer( - Weight unusedWeight, List featureScorers) { + Weight unusedWeight, List featureScorers, LeafReaderContext leafContext) { + super(unusedWeight, leafContext); this.featureScorers = featureScorers; } @@ -733,26 +788,19 @@ public int docID() { return targetDoc; } - @Override - public float score() throws IOException { - reset(); - freq = 0; - if (targetDoc == activeDoc) { - for (final Scorer scorer : featureScorers) { - if (scorer.docID() == activeDoc) { - freq++; - Feature.FeatureWeight.FeatureScorer featureScorer = - (Feature.FeatureWeight.FeatureScorer) scorer; - Feature.FeatureWeight scFW = featureScorer.getWeight(); - final int featureId = scFW.getIndex(); - featuresInfo[featureId].setValue(scorer.score()); - if (featuresInfo[featureId].getValue() != scFW.getDefaultValue()) { - featuresInfo[featureId].setIsDefaultValue(false); - } - } + protected float[] extractFeatureVector() throws IOException { + float[] featureVector = initFeatureVector(featuresInfo); + for (final Scorer scorer : featureScorers) { + if (scorer.docID() == activeDoc) { + Feature.FeatureWeight.FeatureScorer featureScorer = + (Feature.FeatureWeight.FeatureScorer) scorer; + Feature.FeatureWeight scFW = featureScorer.getWeight(); + final int featureId = scFW.getIndex(); + float featureValue = scorer.score(); + featureVector[featureId] = featureValue; } } - return makeNormalizedFeaturesAndScore(); + return featureVector; } @Override diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java index 78803afd9332..6b2be9345dbe 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -103,7 +103,7 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa searcher.createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1); } } - scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves, reRankedPerModel); + scoreFeatures(topN, modelWeights, firstPassResults, leaves, reRankedPerModel); for (int i = 0; i < rerankingQueries.length; i++) { if (originalRankingIndex == null || originalRankingIndex != i) { @@ -115,7 +115,6 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa } public void scoreFeatures( - IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, @@ -151,10 +150,7 @@ public void scoreFeatures( for (int i = 0; i < rerankingQueries.length; i++) { if (modelWeights[i] != null) { final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex); - if (scoreSingleHit( - topN, docBase, hitUpto, hit_i, docID, scorers[i], rerankedPerModel[i])) { - logSingleHit(indexSearcher, modelWeights[i], hit_i.doc, rerankingQueries[i]); - } + scoreSingleHit(topN, docBase, hitUpto, hit_i, docID, scorers[i], rerankedPerModel[i]); } } hitUpto++; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 609c99a915e2..2abf2c22fb32 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -79,7 +79,6 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { private static final boolean DEFAULT_NO_RERANKING_LOGGING_ALL = true; - private String fvCacheName; private String loggingModelName = DEFAULT_LOGGING_MODEL_NAME; private String defaultStore; private FeatureLogger.FeatureFormat defaultFormat = FeatureLogger.FeatureFormat.DENSE; @@ -88,10 +87,6 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { private LTRThreadModule threadManager = null; - public void setFvCacheName(String fvCacheName) { - this.fvCacheName = fvCacheName; - } - public void setLoggingModelName(String loggingModelName) { this.loggingModelName = loggingModelName; } @@ -161,11 +156,7 @@ private FeatureLogger createFeatureLogger(String formatStr, Boolean logAll) { } else { format = this.defaultFormat; } - if (fvCacheName == null) { - throw new IllegalArgumentException("a fvCacheName must be configured"); - } - return new CSVFeatureLogger( - fvCacheName, format, logAll, csvKeyValueDelimiter, csvFeatureSeparator); + return new CSVFeatureLogger(format, logAll, csvKeyValueDelimiter, csvFeatureSeparator); } class FeatureTransformer extends DocTransformer { @@ -423,17 +414,14 @@ private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo } } if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { - // WHEN COULD WE HAVE MULTIPLE MODEL WEIGHTS? - Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher, modelWeights[0]); - if (featureVector == null) { // FV for this document was not in the cache - featureVector = - featureLogger.makeFeatureVector( - LTRRescorer.extractFeaturesInfo( - rerankingModelWeight, - docid, - (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, - leafContexts)); - } + String featureVector = + featureLogger.printFeatureVector( + LTRRescorer.extractFeaturesInfo( + req.getSearcher().getLoggingFeatureVectorCache(), + rerankingModelWeight, + docid, + (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, + leafContexts)); doc.addField(name, featureVector); } } diff --git a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml index c20ee2026f67..c3ed0a061596 100644 --- a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml +++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml @@ -40,7 +40,6 @@ vector you will have to specify that you want the field (e.g., fl="*,[fv]) --> ${solr.ltr.transformer.fv.defaultFormat:dense} - QUERY_DOC_FV + + + ${tests.luceneMatchVersion:LATEST} + ${solr.data.dir:} + + + + + + + + + + + + + + + + + + + + ${solr.ltr.transformer.fv.defaultFormat:dense} + + + + + + + + + 15000 + false + + + 1000 + + + ${solr.data.dir:} + + + + + + + + explicit + json + true + id + + + + \ No newline at end of file diff --git a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml index c3ed0a061596..09808f9bcf3f 100644 --- a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml +++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml @@ -11,7 +11,7 @@ language governing permissions and limitations under the License. --> - ${tests.luceneMatchVersion:LATEST} + ${tests.luceneMatchVersion:LATEST} ${solr.data.dir:} @@ -29,8 +29,6 @@ - - QUERY_DOC_FV diff --git a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml index 911db9a9f557..04ee42b3834a 100644 --- a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml +++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml @@ -28,8 +28,6 @@ - 1 @@ -43,7 +41,6 @@ enclosed between brackets (in this case [fv]). In order to get the feature vector you will have to specify that you want the field (e.g., fl="*,[fv]) --> - QUERY_DOC_FV diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java new file mode 100644 index 000000000000..5cc8d417a339 --- /dev/null +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import org.apache.solr.client.solrj.SolrQuery; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestFeatureVectorCache extends TestRerankBase { + @Before + public void before() throws Exception { + setupFeatureVectorCachetest(false); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", "1")); + assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", "3")); + assertU(commit()); + + loadFeatures("featurevectorcache_features.json"); + loadModels("featurevectorcache_linear_model.json"); + } + + @After + public void after() throws Exception { + aftertest(); + } + + @Test + public void testFeatureVectorCacheLogging() throws Exception { + final String doc1_feature_vector = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", "1.0", + "efi_feature", "3.0", + "match_w1_title", "1.0", + "popularity_value", "1.0"); + + final String doc3_feature_vector = + FeatureLoggerTestUtils.toFeatureVector( + "value_feature_1", "1.0", + "efi_feature", "3.0", + "match_w1_title", "0.0", + "popularity_value", "3.0"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add("fl", "[fv format=dense efi.efi_feature=3]"); + + // Feature vectors without caching + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'" + doc1_feature_vector + "'}"); + + // Feature vectors with caching + query.add("sort", "popularity desc"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'" + doc3_feature_vector + "'}"); + } + + @Test + public void testFeatureVectorCacheRerank() throws Exception { + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("rows", "3"); + query.add("fl", "*,score"); + query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); + + // Feature vectors without caching + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==4.2"); + + // Feature vectors with caching + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==3.1"); + } +} \ No newline at end of file diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java index e3b30b24043e..1792a5336559 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java @@ -120,6 +120,12 @@ protected static void setupPersistenttest(boolean bulkIndex) throws Exception { if (bulkIndex) bulkIndex(); } + protected static void setupFeatureVectorCachetest(boolean bulkIndex) throws Exception { + chooseDefaultFeatureFormat(); + setuptest("solrconfig-ltr-featurevectorcache.xml", "schema.xml"); + if (bulkIndex) bulkIndex(); + } + public static ManagedFeatureStore getManagedFeatureStore() { try (SolrCore core = solrClientTestRule.getCoreContainer().getCore(DEFAULT_TEST_CORENAME)) { return ManagedFeatureStore.getManagedFeatureStore(core); From 5f5c0f16f25885a84fb9df06873bd4d6ea4db4e3 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Fri, 11 Jul 2025 16:08:34 +0200 Subject: [PATCH 11/54] Fixed efi when doing both reranking and logging --- .../org/apache/solr/ltr/LTRScoringQuery.java | 58 ++++++++++++------- .../LTRFeatureLoggerTransformerFactory.java | 57 +++++++++--------- 2 files changed, 67 insertions(+), 48 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 3faa2aaa6062..bfdb1ed45249 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -129,25 +129,20 @@ public SolrQueryRequest getRequest() { @Override public int hashCode() { final int prime = 31; - int result = ltrScoringModel.getFeatureStoreName().hashCode(); - result = (prime * result) + addEfisHash(result, prime); - if (logger != null) { - result = (prime * result) + logger.featureFormat.hashCode(); - } - result = (prime * result) + this.toString().hashCode(); - return result; - } - - private int addEfisHash(int result, int prime) { - if (efi != null) { - TreeMap sorted = new TreeMap<>(efi); - for (final Map.Entry entry : sorted.entrySet()) { + int result = classHash(); + result = (prime * result) + ((ltrScoringModel == null) ? 0 : ltrScoringModel.hashCode()); + result = (prime * result) + ((originalQuery == null) ? 0 : originalQuery.hashCode()); + if (efi == null) { + result = (prime * result) + 0; + } else { + for (final Map.Entry entry : efi.entrySet()) { final String key = entry.getKey(); final String[] values = entry.getValue(); result = (prime * result) + key.hashCode(); result = (prime * result) + Arrays.hashCode(values); } } + result = (prime * result) + this.toString().hashCode(); return result; } @@ -527,6 +522,10 @@ public void setCache(SolrCache cacheToUse) { this.featureVectorCache = cacheToUse; } + public SolrCache getCache() { + return featureVectorCache; + } + @Override public Collection getChildren() throws IOException { return featureTraversalScorer.getChildren(); @@ -552,10 +551,6 @@ public DocIdSetIterator iterator() { return featureTraversalScorer.iterator(); } - private LTRScoringQuery getScoringQuery() { - return LTRScoringQuery.this; - } - abstract class FeatureTraversalScorer extends Scorer { protected int targetDoc = -1; protected int activeDoc = -1; @@ -583,7 +578,7 @@ private void fillFeaturesInfo() throws IOException { if (featureVectorCache != null) { int docId = activeDoc + leafContext.docBase; - int fvCacheKey = fvCacheKey(getScoringQuery(), docId); + int fvCacheKey = fvCacheKey(docId); featureVector = featureVectorCache.get(fvCacheKey); if (featureVector == null) { featureVector = extractFeatureVector(); @@ -605,8 +600,31 @@ private void fillFeaturesInfo() throws IOException { } } - private int fvCacheKey(LTRScoringQuery scoringQuery, int docId) { - return (31 * scoringQuery.hashCode()) + docId; + private int fvCacheKey(int docId) { + int prime = 31; + int result = docId; + if (Objects.equals(featureVectorCache.name(), "rerankingFeatureVectorCache")) { + result = (prime * result) + ltrScoringModel.getName().hashCode(); + } + if (Objects.equals(featureVectorCache.name(), "loggingFeatureVectorCache")) { + result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); + result = (prime * result) + logger.featureFormat.hashCode(); + } + result = (prime * result) + addEfisHash(result, prime); + return result; + } + + private int addEfisHash(int result, int prime) { + if (efi != null) { + TreeMap sorted = new TreeMap<>(efi); + for (final Map.Entry entry : sorted.entrySet()) { + final String key = entry.getKey(); + final String[] values = entry.getValue(); + result = (prime * result) + key.hashCode(); + result = (prime * result) + Arrays.hashCode(values); + } + } + return result; } protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 2abf2c22fb32..63f0b8502e5d 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -258,7 +258,7 @@ public void setContext(ResultContext context) { modelFeatures, docsWereReranked); setupRerankingQueriesForLogging( - transformerFeatureStore, transformerExternalFeatureInfo, loggingModel); + transformerFeatureStore, transformerExternalFeatureInfo, loggingModel); setupRerankingWeightsForLogging(context, featureLogger); } @@ -329,41 +329,42 @@ private LoggingModel createLoggingModel( * @param transformerExternalFeatureInfo explicit efi for the transformer */ private void setupRerankingQueriesForLogging( - String transformerFeatureStore, - Map transformerExternalFeatureInfo, - LoggingModel loggingModel) { + String transformerFeatureStore, + Map transformerExternalFeatureInfo, + LoggingModel loggingModel) { if (!docsWereReranked) { // no reranking query LTRScoringQuery loggingQuery = - new LTRScoringQuery(loggingModel, transformerExternalFeatureInfo, threadManager); + new LTRScoringQuery(loggingModel, transformerExternalFeatureInfo, threadManager); rerankingQueries = new LTRScoringQuery[] {loggingQuery}; } else { rerankingQueries = new LTRScoringQuery[rerankingQueriesFromContext.length]; System.arraycopy( - rerankingQueriesFromContext, - 0, - rerankingQueries, - 0, - rerankingQueriesFromContext.length); - - if (transformerFeatureStore != null) { // explicit feature store for the transformer - LTRScoringModel matchingRerankingModel = loggingModel; - for (LTRScoringQuery rerankingQuery : rerankingQueries) { - if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) - && transformerFeatureStore.equals( - rerankingQuery.getScoringModel().getFeatureStoreName())) { - matchingRerankingModel = rerankingQuery.getScoringModel(); - } + rerankingQueriesFromContext, + 0, + rerankingQueries, + 0, + rerankingQueriesFromContext.length); + + if (transformerFeatureStore == null) { + transformerFeatureStore = FeatureStore.DEFAULT_FEATURE_STORE_NAME; + } + LTRScoringModel matchingRerankingModel = loggingModel; + for (LTRScoringQuery rerankingQuery : rerankingQueries) { + if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) + && transformerFeatureStore.equals( + rerankingQuery.getScoringModel().getFeatureStoreName())) { + matchingRerankingModel = rerankingQuery.getScoringModel(); } + } - for (int i = 0; i < rerankingQueries.length; i++) { - rerankingQueries[i] = - new LTRScoringQuery( - matchingRerankingModel, - (!transformerExternalFeatureInfo.isEmpty() - ? transformerExternalFeatureInfo - : rerankingQueries[i].getExternalFeatureInfo()), - threadManager); - } + for (int i = 0; i < rerankingQueries.length; i++) { + rerankingQueries[i] = + new LTRScoringQuery( + matchingRerankingModel, + (!transformerExternalFeatureInfo.isEmpty() + ? transformerExternalFeatureInfo + : rerankingQueries[i].getExternalFeatureInfo()), + threadManager); } } } From fa5ce792517b4fde9764af560bf5af68d792f053 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Fri, 11 Jul 2025 16:36:13 +0200 Subject: [PATCH 12/54] Fixed efi when doing both reranking and logging. First implementation was not 100% correct. --- .../org/apache/solr/ltr/LTRScoringQuery.java | 6 +- .../LTRFeatureLoggerTransformerFactory.java | 61 ++++++++++--------- .../featurevectorcache_features.json | 7 --- 3 files changed, 38 insertions(+), 36 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index bfdb1ed45249..16e04f0e2309 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -66,7 +66,7 @@ public class LTRScoringQuery extends Query implements Accountable { private FeatureLogger logger; // Map of external parameters, such as query intent, that can be used by // features - private final Map efi; + private Map efi; // Original solr query used to fetch matching documents private Query originalQuery; // Original solr request @@ -118,6 +118,10 @@ public Map getExternalFeatureInfo() { return efi; } + public void setExternalFeatureInfo(Map efi) { + this.efi = efi; + } + public void setRequest(SolrQueryRequest request) { this.request = request; } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 63f0b8502e5d..2f12897f099e 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -329,42 +329,47 @@ private LoggingModel createLoggingModel( * @param transformerExternalFeatureInfo explicit efi for the transformer */ private void setupRerankingQueriesForLogging( - String transformerFeatureStore, - Map transformerExternalFeatureInfo, - LoggingModel loggingModel) { + String transformerFeatureStore, + Map transformerExternalFeatureInfo, + LoggingModel loggingModel) { if (!docsWereReranked) { // no reranking query LTRScoringQuery loggingQuery = - new LTRScoringQuery(loggingModel, transformerExternalFeatureInfo, threadManager); + new LTRScoringQuery(loggingModel, transformerExternalFeatureInfo, threadManager); rerankingQueries = new LTRScoringQuery[] {loggingQuery}; } else { rerankingQueries = new LTRScoringQuery[rerankingQueriesFromContext.length]; System.arraycopy( - rerankingQueriesFromContext, - 0, - rerankingQueries, - 0, - rerankingQueriesFromContext.length); - - if (transformerFeatureStore == null) { - transformerFeatureStore = FeatureStore.DEFAULT_FEATURE_STORE_NAME; - } - LTRScoringModel matchingRerankingModel = loggingModel; - for (LTRScoringQuery rerankingQuery : rerankingQueries) { - if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) - && transformerFeatureStore.equals( - rerankingQuery.getScoringModel().getFeatureStoreName())) { - matchingRerankingModel = rerankingQuery.getScoringModel(); + rerankingQueriesFromContext, + 0, + rerankingQueries, + 0, + rerankingQueriesFromContext.length); + + if (transformerFeatureStore != null) { // explicit feature store for the transformer + LTRScoringModel matchingRerankingModel = loggingModel; + for (LTRScoringQuery rerankingQuery : rerankingQueries) { + if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) + && transformerFeatureStore.equals( + rerankingQuery.getScoringModel().getFeatureStoreName())) { + matchingRerankingModel = rerankingQuery.getScoringModel(); + } } - } - for (int i = 0; i < rerankingQueries.length; i++) { - rerankingQueries[i] = - new LTRScoringQuery( - matchingRerankingModel, - (!transformerExternalFeatureInfo.isEmpty() - ? transformerExternalFeatureInfo - : rerankingQueries[i].getExternalFeatureInfo()), - threadManager); + for (int i = 0; i < rerankingQueries.length; i++) { + rerankingQueries[i] = + new LTRScoringQuery( + matchingRerankingModel, + (!transformerExternalFeatureInfo.isEmpty() + ? transformerExternalFeatureInfo + : rerankingQueries[i].getExternalFeatureInfo()), + threadManager); + } + } else { + for (int i = 0; i < rerankingQueries.length; i++) { + if (!transformerExternalFeatureInfo.isEmpty()) { + rerankingQueries[i].setExternalFeatureInfo(transformerExternalFeatureInfo); + } + } } } } diff --git a/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json b/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json index f63fe791f40d..58e30763bb7e 100644 --- a/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json +++ b/solr/modules/ltr/src/test-files/featureExamples/featurevectorcache_features.json @@ -1,11 +1,4 @@ [ - { - "name": "value_feature_2", - "class": "org.apache.solr.ltr.feature.ValueFeature", - "params": { - "value": 1 - } - }, { "name": "value_feature_1", "class": "org.apache.solr.ltr.feature.ValueFeature", From ac4507d6547b7c3cfcbdfb5d661938b473c956e8 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 14 Jul 2025 14:18:44 +0200 Subject: [PATCH 13/54] WORKING VERSION WITH 2 CACHES --- .../ltr/src/java/org/apache/solr/ltr/LTRRescorer.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index e7da5dc585de..902d46c0153f 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -181,21 +181,20 @@ public void scoreFeatures( docBase = readerContext.docBase; scorer = modelWeight.modelScorer(readerContext); } - scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked, scoringQuery); + scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked); hitUpto++; } } /** Scores a single document. */ - protected static void scoreSingleHit( + protected void scoreSingleHit( int topN, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, - ScoreDoc[] reranked, - LTRScoringQuery scoringQuery) + ScoreDoc[] reranked) throws IOException { /* * Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to call From abf6706284fb354675897a18af13b265d65f482b Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 14 Jul 2025 16:18:53 +0200 Subject: [PATCH 14/54] Unique cache for feature vector --- .../java/org/apache/solr/core/SolrConfig.java | 13 ++----- .../apache/solr/search/SolrIndexSearcher.java | 38 ++++++------------- .../java/org/apache/solr/ltr/LTRRescorer.java | 4 -- .../org/apache/solr/ltr/LTRScoringQuery.java | 17 ++------- .../LTRFeatureLoggerTransformerFactory.java | 3 +- .../solrconfig-ltr-featurevectorcache.xml | 3 +- 6 files changed, 21 insertions(+), 57 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/core/SolrConfig.java b/solr/core/src/java/org/apache/solr/core/SolrConfig.java index b37fff82bbf3..4be0350efd3f 100644 --- a/solr/core/src/java/org/apache/solr/core/SolrConfig.java +++ b/solr/core/src/java/org/apache/solr/core/SolrConfig.java @@ -301,12 +301,9 @@ private SolrConfig(SolrResourceLoader loader, String name, Properties substituta queryResultCacheConfig = CacheConfig.getConfig( this, get("query").get("queryResultCache"), "query/queryResultCache"); - rerankingFeatureVectorCacheConfig = + featureVectorCacheConfig = CacheConfig.getConfig( - this, get("query").get("rerankingFeatureVectorCache"), "query/rerankingFeatureVectorCache"); - loggingFeatureVectorCacheConfig = - CacheConfig.getConfig( - this, get("query").get("loggingFeatureVectorCache"), "query/loggingFeatureVectorCache"); + this, get("query").get("featureVectorCache"), "query/featureVectorCache"); documentCacheConfig = CacheConfig.getConfig(this, get("query").get("documentCache"), "query/documentCache"); CacheConfig conf = @@ -668,8 +665,7 @@ public SolrRequestParsers getRequestParsers() { public final CacheConfig queryResultCacheConfig; public final CacheConfig documentCacheConfig; public final CacheConfig fieldValueCacheConfig; - public final CacheConfig rerankingFeatureVectorCacheConfig; - public final CacheConfig loggingFeatureVectorCacheConfig; + public final CacheConfig featureVectorCacheConfig; public final Map userCacheConfigs; // SolrIndexSearcher - more... public final boolean useFilterForSortedQuery; @@ -1011,8 +1007,7 @@ public Map toMap(Map result) { queryResultCacheConfig, documentCacheConfig, fieldValueCacheConfig, - rerankingFeatureVectorCacheConfig, - loggingFeatureVectorCacheConfig); + featureVectorCacheConfig); m = new LinkedHashMap<>(); result.put("requestDispatcher", m); if (httpCachingConfig != null) m.put("httpCaching", httpCachingConfig); diff --git a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java index 3470867fe0fb..8f9800fb016c 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java +++ b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java @@ -165,8 +165,7 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI private final SolrCache filterCache; private final SolrCache queryResultCache; private final SolrCache fieldValueCache; - private final SolrCache rerankingFeatureVectorCache; - private final SolrCache loggingFeatureVectorCache; + private final SolrCache featureVectorCache; private final LongAdder fullSortCount = new LongAdder(); private final LongAdder skipSortCount = new LongAdder(); private final LongAdder liveDocsNaiveCacheHitCount = new LongAdder(); @@ -451,16 +450,11 @@ public SolrIndexSearcher( ? null : solrConfig.queryResultCacheConfig.newInstance(); if (queryResultCache != null) clist.add(queryResultCache); - rerankingFeatureVectorCache = - solrConfig.rerankingFeatureVectorCacheConfig == null + featureVectorCache = + solrConfig.featureVectorCacheConfig == null ? null - : solrConfig.rerankingFeatureVectorCacheConfig.newInstance(); - if (rerankingFeatureVectorCache != null) clist.add(rerankingFeatureVectorCache); - loggingFeatureVectorCache = - solrConfig.loggingFeatureVectorCacheConfig == null - ? null - : solrConfig.loggingFeatureVectorCacheConfig.newInstance(); - if (loggingFeatureVectorCache != null) clist.add(loggingFeatureVectorCache); + : solrConfig.featureVectorCacheConfig.newInstance(); + if (featureVectorCache != null) clist.add(featureVectorCache); SolrCache documentCache = docFetcher.getDocumentCache(); if (documentCache != null) clist.add(documentCache); @@ -482,8 +476,7 @@ public SolrIndexSearcher( this.filterCache = null; this.queryResultCache = null; this.fieldValueCache = null; - this.rerankingFeatureVectorCache = null; - this.loggingFeatureVectorCache = null; + this.featureVectorCache = null; this.cacheMap = NO_GENERIC_CACHES; this.cacheList = NO_CACHES; } @@ -699,12 +692,8 @@ public SolrCache getFilterCache() { return filterCache; } - public SolrCache getRerankingFeatureVectorCache() { - return rerankingFeatureVectorCache; - } - - public SolrCache getLoggingFeatureVectorCache() { - return loggingFeatureVectorCache; + public SolrCache getFeatureVectorCache() { + return featureVectorCache; } // @@ -749,14 +738,9 @@ public boolean regenerateItem( }); } - if (solrConfig.rerankingFeatureVectorCacheConfig != null - && solrConfig.rerankingFeatureVectorCacheConfig.getRegenerator() == null) { - solrConfig.rerankingFeatureVectorCacheConfig.setRegenerator(new NoOpRegenerator()); - } - - if (solrConfig.loggingFeatureVectorCacheConfig != null - && solrConfig.loggingFeatureVectorCacheConfig.getRegenerator() == null) { - solrConfig.loggingFeatureVectorCacheConfig.setRegenerator(new NoOpRegenerator()); + if (solrConfig.featureVectorCacheConfig != null + && solrConfig.featureVectorCacheConfig.getRegenerator() == null) { + solrConfig.featureVectorCacheConfig.setRegenerator(new NoOpRegenerator()); } if (solrConfig.queryResultCacheConfig != null diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index 902d46c0153f..3fa99eed80b5 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -33,7 +33,6 @@ import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery; import org.apache.solr.search.IncompleteRerankingException; import org.apache.solr.search.QueryLimits; -import org.apache.solr.search.SolrCache; /** * Implements the rescoring logic. The top documents returned by solr with their original scores, @@ -209,7 +208,6 @@ protected void scoreSingleHit( scorer.iterator().advance(targetDoc); scorer.getDocInfo().setOriginalDocScore(hit.score); - scorer.setCache(scoringQuery.getRequest().getSearcher().getRerankingFeatureVectorCache()); hit.score = scorer.score(); if (QueryLimits.getCurrentLimits() .maybeExitWithPartialResults( @@ -259,7 +257,6 @@ protected static Explanation getExplanation( } public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo( - SolrCache loggingCache, LTRScoringQuery.ModelWeight modelWeight, int docid, Float originalDocScore, @@ -277,7 +274,6 @@ public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo( // score, which some features can use instead of recalculating it r.getDocInfo().setOriginalDocScore(originalDocScore); } - r.setCache(loggingCache); r.score(); return modelWeight.getFeaturesInfo(); } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 16e04f0e2309..b83090a815f6 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -40,6 +40,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.util.SolrDefaultScorerSupplier; import org.apache.solr.search.SolrCache; @@ -502,14 +503,12 @@ public boolean isCacheable(LeafReaderContext ctx) { public class ModelScorer extends Scorer { private final DocInfo docInfo; private final Scorer featureTraversalScorer; - private SolrCache featureVectorCache; public DocInfo getDocInfo() { return docInfo; } public ModelScorer(Weight weight, List featureScorers, LeafReaderContext leafContext) { - featureVectorCache = null; docInfo = new DocInfo(); for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) { subScorer.setDocInfo(docInfo); @@ -522,14 +521,6 @@ public ModelScorer(Weight weight, List feat } } - public void setCache(SolrCache cacheToUse) { - this.featureVectorCache = cacheToUse; - } - - public SolrCache getCache() { - return featureVectorCache; - } - @Override public Collection getChildren() throws IOException { return featureTraversalScorer.getChildren(); @@ -580,6 +571,7 @@ private void fillFeaturesInfo() throws IOException { if (activeDoc == targetDoc) { float[] featureVector; + SolrCache featureVectorCache = request.getSearcher().getFeatureVectorCache(); if (featureVectorCache != null) { int docId = activeDoc + leafContext.docBase; int fvCacheKey = fvCacheKey(docId); @@ -607,12 +599,11 @@ private void fillFeaturesInfo() throws IOException { private int fvCacheKey(int docId) { int prime = 31; int result = docId; - if (Objects.equals(featureVectorCache.name(), "rerankingFeatureVectorCache")) { + if (!Objects.equals(ltrScoringModel.getName(), LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME)) { result = (prime * result) + ltrScoringModel.getName().hashCode(); } - if (Objects.equals(featureVectorCache.name(), "loggingFeatureVectorCache")) { + else { result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); - result = (prime * result) + logger.featureFormat.hashCode(); } result = (prime * result) + addEfisHash(result, prime); return result; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 2f12897f099e..ed542039755a 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -75,7 +75,7 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { // used inside fl to specify to log (all|model only) features private static final String FV_LOG_ALL = "logAll"; - private static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model"; + public static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model"; private static final boolean DEFAULT_NO_RERANKING_LOGGING_ALL = true; @@ -423,7 +423,6 @@ private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo String featureVector = featureLogger.printFeatureVector( LTRRescorer.extractFeaturesInfo( - req.getSearcher().getLoggingFeatureVectorCache(), rerankingModelWeight, docid, (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, diff --git a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml index 47fc0a4a402c..f78d5996d3db 100644 --- a/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml +++ b/solr/modules/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr-featurevectorcache.xml @@ -28,8 +28,7 @@ - - + - ${tests.luceneMatchVersion:LATEST} + ${tests.luceneMatchVersion:LATEST} ${solr.data.dir:} @@ -41,9 +41,9 @@ From ccca6d4e70f05a6ecc3d5ebb9217c233e9b95efb Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Thu, 31 Jul 2025 12:47:44 +0200 Subject: [PATCH 22/54] Small refactors for variable and methods names --- .../java/org/apache/solr/ltr/LTRRescorer.java | 4 +- .../org/apache/solr/ltr/LTRScoringQuery.java | 72 +++++++++---------- .../LTRFeatureLoggerTransformerFactory.java | 2 +- .../apache/solr/ltr/TestLTRScoringQuery.java | 2 +- .../solr/ltr/TestSelectiveWeightCreation.java | 4 +- ...FeatureExtractionFromMultipleSegments.java | 2 +- 6 files changed, 43 insertions(+), 43 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index 520016ffabb4..8d69e36a8c61 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -256,7 +256,7 @@ protected static Explanation getExplanation( return rankingWeight.explain(context, deBasedDoc); } - public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo( + public static LTRScoringQuery.FeatureInfo[] extractFeatures( LTRScoringQuery.ModelWeight modelWeight, int docid, Float originalDocScore, @@ -276,7 +276,7 @@ public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo( } r.setIsLogging(true); r.score(); - return modelWeight.getFeaturesInfo(); + return modelWeight.getAllFeaturesInStore(); } } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index b7ae071b031e..8be866500538 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -377,8 +377,8 @@ public class ModelWeight extends Weight { private final float[] modelFeatureValuesNormalized; private final Feature.FeatureWeight[] extractedFeatureWeights; - // All the features - private final FeatureInfo[] featuresInfo; + // Array of all the features in the feature store of reference + private final FeatureInfo[] allFeaturesInStore; /* * @param modelFeatureWeights @@ -398,7 +398,7 @@ public ModelWeight( this.extractedFeatureWeights = extractedFeatureWeights; this.modelFeatureWeights = modelFeatureWeights; this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length]; - this.featuresInfo = new FeatureInfo[allFeaturesSize]; + this.allFeaturesInStore = new FeatureInfo[allFeaturesSize]; setFeaturesInfo(); } @@ -407,12 +407,12 @@ private void setFeaturesInfo() { String featName = extractedFeatureWeights[i].getName(); int featId = extractedFeatureWeights[i].getIndex(); float value = extractedFeatureWeights[i].getDefaultValue(); - featuresInfo[featId] = new FeatureInfo(featName, value, true); + allFeaturesInStore[featId] = new FeatureInfo(featName, value, true); } } - public FeatureInfo[] getFeaturesInfo() { - return featuresInfo; + public FeatureInfo[] getAllFeaturesInStore() { + return allFeaturesInStore; } // for test use @@ -434,22 +434,21 @@ Feature.FeatureWeight[] getExtractedFeatureWeights() { * Goes through all the stored feature values, and calculates the normalized values for all the * features that will be used for scoring. Then calculate and return the model's score. */ - private float makeNormalizedFeaturesAndScore() { + private void normalizeFeatures() { int pos = 0; for (final Feature.FeatureWeight feature : modelFeatureWeights) { final int featureId = feature.getIndex(); - FeatureInfo fInfo = featuresInfo[featureId]; + FeatureInfo fInfo = allFeaturesInStore[featureId]; modelFeatureValuesNormalized[pos] = fInfo.getValue(); pos++; } ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized); - return ltrScoringModel.score(modelFeatureValuesNormalized); } @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { - final Explanation[] explanations = new Explanation[this.featuresInfo.length]; + final Explanation[] explanations = new Explanation[this.allFeaturesInStore.length]; for (final Feature.FeatureWeight feature : extractedFeatureWeights) { explanations[feature.getIndex()] = feature.explain(context, doc); } @@ -473,8 +472,8 @@ protected void reset() { float value = extractedFeatureWeights[i].getDefaultValue(); // need to set default value everytime as the default value is used in 'dense' // mode even if used=false - featuresInfo[featId].setValue(value); - featuresInfo[featId].setIsDefaultValue(true); + allFeaturesInStore[featId].setValue(value); + allFeaturesInStore[featId].setIsDefaultValue(true); } } @@ -524,9 +523,9 @@ public ModelScorer(Weight weight, List feat } if (featureScorers.size() <= 1) { // future enhancement: allow the use of dense features in other cases - featureTraversalScorer = new DenseModelScorer(weight, featureScorers, leafContext); + featureTraversalScorer = new SingleFeatureScorer(weight, featureScorers, leafContext); } else { - featureTraversalScorer = new SparseModelScorer(weight, featureScorers, leafContext); + featureTraversalScorer = new MultiFeaturesScorer(weight, featureScorers, leafContext); } } @@ -559,6 +558,9 @@ public void setIsLogging(boolean isLogging) { this.isLogging = isLogging; } + /** + * This class is responsible for extracting features and using them to score the document. + */ abstract class FeatureTraversalScorer extends Scorer { protected int targetDoc = -1; protected int activeDoc = -1; @@ -570,9 +572,11 @@ protected FeatureTraversalScorer(Weight weight, LeafReaderContext leafContext) { @Override public float score() throws IOException { + // Initialize features to their default values and set isDefaultValue to true. reset(); fillFeaturesInfo(); - return makeNormalizedFeaturesAndScore(); + normalizeFeatures(); + return ltrScoringModel.score(modelFeatureValuesNormalized); } @Override @@ -585,15 +589,12 @@ private void fillFeaturesInfo() throws IOException { SolrCache featureVectorCache = null; float[] featureVector; - // Check added otherwise org.apache.solr.ltr.TestLTRScoringQuery.testLTRScoringQuery - // and org.apache.solr.ltr.TestSelectiveWeightCreation.testScoringQueryWeightCreation - // fail if (request != null) { featureVectorCache = request.getSearcher().getFeatureVectorCache(); } if (featureVectorCache != null) { int docId = activeDoc + leafContext.docBase; - int fvCacheKey = fvCacheKey(docId); + int fvCacheKey = computeFeatureVectorCacheKey(docId); featureVector = featureVectorCache.get(fvCacheKey); if (featureVector == null) { featureVector = extractFeatureVector(); @@ -608,14 +609,14 @@ private void fillFeaturesInfo() throws IOException { float featureValue = featureVector[featureId]; if (!Float.isNaN(featureValue) && featureValue != extractedFeatureWeights[i].getDefaultValue()) { - featuresInfo[featureId].setValue(featureValue); - featuresInfo[featureId].setIsDefaultValue(false); + allFeaturesInStore[featureId].setValue(featureValue); + allFeaturesInStore[featureId].setIsDefaultValue(false); } } } } - private int fvCacheKey(int docId) { + private int computeFeatureVectorCacheKey(int docId) { int prime = 31; int result = docId; if (Objects.equals( @@ -644,7 +645,7 @@ private int addEfisHash(int result, int prime) { } protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { - float[] featureVector = new float[featuresInfo.length]; + float[] featureVector = new float[allFeaturesInStore.length]; for (int i = 0; i < featuresInfos.length; i++) { if (featuresInfos[i] != null) { featureVector[i] = featuresInfos[i].getValue(); @@ -656,13 +657,15 @@ protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { protected abstract float[] extractFeatureVector() throws IOException; } - private class SparseModelScorer extends FeatureTraversalScorer { + private class MultiFeaturesScorer extends FeatureTraversalScorer { private final DisiPriorityQueue subScorers; private final List wrappers; private final ScoringQuerySparseIterator sparseIterator; - private SparseModelScorer( - Weight unusedWeight, List featureScorers, LeafReaderContext leafContext) { + private MultiFeaturesScorer( + Weight unusedWeight, + List featureScorers, + LeafReaderContext leafContext) { super(unusedWeight, leafContext); if (featureScorers.size() <= 1) { throw new IllegalArgumentException("There must be at least 2 subScorers"); @@ -685,7 +688,7 @@ public int docID() { protected float[] extractFeatureVector() throws IOException { final DisiWrapper topList = subScorers.topList(); - float[] featureVector = initFeatureVector(featuresInfo); + float[] featureVector = initFeatureVector(allFeaturesInStore); for (DisiWrapper w = topList; w != null; w = w.next) { final Feature.FeatureWeight.FeatureScorer subScorer = (Feature.FeatureWeight.FeatureScorer) w.scorer; @@ -806,11 +809,13 @@ public long cost() { } } - private class DenseModelScorer extends FeatureTraversalScorer { + private class SingleFeatureScorer extends FeatureTraversalScorer { private final List featureScorers; - private DenseModelScorer( - Weight unusedWeight, List featureScorers, LeafReaderContext leafContext) { + private SingleFeatureScorer( + Weight unusedWeight, + List featureScorers, + LeafReaderContext leafContext) { super(unusedWeight, leafContext); this.featureScorers = featureScorers; } @@ -821,7 +826,7 @@ public int docID() { } protected float[] extractFeatureVector() throws IOException { - float[] featureVector = initFeatureVector(featuresInfo); + float[] featureVector = initFeatureVector(allFeaturesInStore); for (final Scorer scorer : featureScorers) { if (scorer.docID() == activeDoc) { Feature.FeatureWeight.FeatureScorer featureScorer = @@ -835,11 +840,6 @@ protected float[] extractFeatureVector() throws IOException { return featureVector; } - @Override - public float getMaxScore(int upTo) throws IOException { - return Float.POSITIVE_INFINITY; - } - @Override public final Collection getChildren() { final ArrayList children = new ArrayList<>(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 26edda37bfae..d3ebfba22a65 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -422,7 +422,7 @@ private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { String featureVector = featureLogger.printFeatureVector( - LTRRescorer.extractFeaturesInfo( + LTRRescorer.extractFeatures( rerankingModelWeight, docid, (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java index 7ad5b34f49b9..583ae22f742b 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java @@ -219,7 +219,7 @@ public void testLTRScoringQuery() throws IOException, ModelException { } int[] posVals = new int[] {0, 1, 2}; int pos = 0; - for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getFeaturesInfo()) { + for (LTRScoringQuery.FeatureInfo fInfo : modelWeight.getAllFeaturesInStore()) { if (fInfo == null) { continue; } diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java index 8d911087a748..00d72666956d 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java @@ -167,7 +167,7 @@ public void testScoringQueryWeightCreation() throws IOException, ModelException searcher, hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel1)); // features not requested in response - LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getFeaturesInfo(); + LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getAllFeaturesInStore(); assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length); int nonDefaultFeatures = 0; @@ -191,7 +191,7 @@ public void testScoringQueryWeightCreation() throws IOException, ModelException // features requested in response ltrQuery2.setFeatureLogger(new CSVFeatureLogger(FeatureLogger.FeatureFormat.DENSE, true)); modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc, ltrQuery2); - featuresInfo = modelWeight.getFeaturesInfo(); + featuresInfo = modelWeight.getAllFeaturesInStore(); assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length); assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length); diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java index 3e6e986e5918..d13af84d5c7d 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java @@ -236,7 +236,7 @@ public void testFeatureExtractionFromMultipleSegments() throws Exception { query.setQuery( "{!edismax qf='description^1' boost='sum(product(pow(normHits, 0.7), 1600), .1)' v='apple'}"); // request 100 rows, if any rows are fetched from the second or subsequent segments the tests - // should succeed if LTRRescorer::extractFeaturesInfo() advances the doc iterator properly + // should succeed if LTRRescorer::extractFeatures() advances the doc iterator properly int numRows = 100; query.add("rows", Integer.toString(numRows)); query.add("wt", "json"); From 672e514f43ebdfde950da6d7434b899f27c30509 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Thu, 31 Jul 2025 18:44:13 +0200 Subject: [PATCH 23/54] First attempt to create a dedicated FeatureExtractor class --- .../java/org/apache/solr/ltr/LTRRescorer.java | 24 ------- .../org/apache/solr/ltr/LTRScoringQuery.java | 64 ++++++++++++------- .../LTRFeatureLoggerTransformerFactory.java | 28 +++++++- 3 files changed, 68 insertions(+), 48 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index 8d69e36a8c61..c6f8ec5cb35b 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -255,28 +255,4 @@ protected static Explanation getExplanation( } return rankingWeight.explain(context, deBasedDoc); } - - public static LTRScoringQuery.FeatureInfo[] extractFeatures( - LTRScoringQuery.ModelWeight modelWeight, - int docid, - Float originalDocScore, - List leafContexts) - throws IOException { - final int n = ReaderUtil.subIndex(docid, leafContexts); - final LeafReaderContext atomicContext = leafContexts.get(n); - final int deBasedDoc = docid - atomicContext.docBase; - final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.modelScorer(atomicContext); - if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) { - return new LTRScoringQuery.FeatureInfo[0]; - } else { - if (originalDocScore != null) { - // If results have not been reranked, the score passed in is the original query's - // score, which some features can use instead of recalculating it - r.getDocInfo().setOriginalDocScore(originalDocScore); - } - r.setIsLogging(true); - r.score(); - return modelWeight.getAllFeaturesInStore(); - } - } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 8be866500538..ef5eee29da58 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -508,7 +508,7 @@ public boolean isCacheable(LeafReaderContext ctx) { public class ModelScorer extends Scorer { private final DocInfo docInfo; - private final Scorer featureTraversalScorer; + private final FeatureTraversalScorer featureTraversalScorer; protected boolean isLogging; public DocInfo getDocInfo() { @@ -558,34 +558,66 @@ public void setIsLogging(boolean isLogging) { this.isLogging = isLogging; } + public void fillFeaturesInfo() throws IOException { + featureTraversalScorer.fillFeaturesInfo(); + } + /** * This class is responsible for extracting features and using them to score the document. */ - abstract class FeatureTraversalScorer extends Scorer { + private abstract class FeatureTraversalScorer extends Scorer { protected int targetDoc = -1; protected int activeDoc = -1; protected LeafReaderContext leafContext; + protected FeatureExtractor featureExtractor; protected FeatureTraversalScorer(Weight weight, LeafReaderContext leafContext) { this.leafContext = leafContext; + this.featureExtractor = new FeatureExtractor(this); + } + + void fillFeaturesInfo() throws IOException { + // Initialize features to their default values and set isDefaultValue to true. + reset(); + featureExtractor.fillFeaturesInfo(); } @Override public float score() throws IOException { // Initialize features to their default values and set isDefaultValue to true. reset(); - fillFeaturesInfo(); + featureExtractor.fillFeaturesInfo(); normalizeFeatures(); return ltrScoringModel.score(modelFeatureValuesNormalized); } @Override - public float getMaxScore(int upTo) throws IOException { + public float getMaxScore(int upTo) { return Float.POSITIVE_INFINITY; } + protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { + float[] featureVector = new float[allFeaturesInStore.length]; + for (int i = 0; i < featuresInfos.length; i++) { + if (featuresInfos[i] != null) { + featureVector[i] = featuresInfos[i].getValue(); + } + } + return featureVector; + } + + protected abstract float[] extractFeatureVector() throws IOException; + } + + private class FeatureExtractor { + private final FeatureTraversalScorer traversalScorer; + + private FeatureExtractor(FeatureTraversalScorer traversalScorer) { + this.traversalScorer = traversalScorer; + } + private void fillFeaturesInfo() throws IOException { - if (activeDoc == targetDoc) { + if (traversalScorer.activeDoc == traversalScorer.targetDoc) { SolrCache featureVectorCache = null; float[] featureVector; @@ -593,15 +625,15 @@ private void fillFeaturesInfo() throws IOException { featureVectorCache = request.getSearcher().getFeatureVectorCache(); } if (featureVectorCache != null) { - int docId = activeDoc + leafContext.docBase; + int docId = traversalScorer.activeDoc + traversalScorer.leafContext.docBase; int fvCacheKey = computeFeatureVectorCacheKey(docId); featureVector = featureVectorCache.get(fvCacheKey); if (featureVector == null) { - featureVector = extractFeatureVector(); + featureVector = traversalScorer.extractFeatureVector(); featureVectorCache.put(fvCacheKey, featureVector); } } else { - featureVector = extractFeatureVector(); + featureVector = traversalScorer.extractFeatureVector(); } for (int i = 0; i < extractedFeatureWeights.length; i++) { @@ -620,8 +652,8 @@ private int computeFeatureVectorCacheKey(int docId) { int prime = 31; int result = docId; if (Objects.equals( - ltrScoringModel.getName(), - LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) + ltrScoringModel.getName(), + LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) || (isLogging && logger.isLoggingAll())) { result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); } else { @@ -643,18 +675,6 @@ private int addEfisHash(int result, int prime) { } return result; } - - protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { - float[] featureVector = new float[allFeaturesInStore.length]; - for (int i = 0; i < featuresInfos.length; i++) { - if (featuresInfos[i] != null) { - featureVector[i] = featuresInfos[i].getValue(); - } - } - return featureVector; - } - - protected abstract float[] extractFeatureVector() throws IOException; } private class MultiFeaturesScorer extends FeatureTraversalScorer { diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index d3ebfba22a65..096422dd0ed1 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -22,6 +22,7 @@ import java.util.Locale; import java.util.Map; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.ReaderUtil; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.ScoreMode; import org.apache.solr.common.SolrDocument; @@ -30,7 +31,6 @@ import org.apache.solr.common.util.NamedList; import org.apache.solr.ltr.CSVFeatureLogger; import org.apache.solr.ltr.FeatureLogger; -import org.apache.solr.ltr.LTRRescorer; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.LTRThreadModule; import org.apache.solr.ltr.SolrQueryRequestContextUtils; @@ -407,6 +407,30 @@ public void transform(SolrDocument doc, int docid, DocIterationInfo docInfo) implTransform(doc, docid, docInfo); } + private static LTRScoringQuery.FeatureInfo[] extractFeatures( + LTRScoringQuery.ModelWeight modelWeight, + int docid, + Float originalDocScore, + List leafContexts) + throws IOException { + final int n = ReaderUtil.subIndex(docid, leafContexts); + final LeafReaderContext atomicContext = leafContexts.get(n); + final int deBasedDoc = docid - atomicContext.docBase; + final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.scorer(atomicContext); + if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) { + return new LTRScoringQuery.FeatureInfo[0]; + } else { + if (originalDocScore != null) { + // If results have not been reranked, the score passed in is the original query's + // score, which some features can use instead of recalculating it + r.getDocInfo().setOriginalDocScore(originalDocScore); + } + r.fillFeaturesInfo(); + r.setIsLogging(true); + return modelWeight.getAllFeaturesInStore(); + } + } + private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo) throws IOException { LTRScoringQuery rerankingQuery = rerankingQueries[0]; @@ -422,7 +446,7 @@ private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { String featureVector = featureLogger.printFeatureVector( - LTRRescorer.extractFeatures( + extractFeatures( rerankingModelWeight, docid, (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, From e883ef41fa060be569ad39ea56d43fd3df5e0039 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Tue, 5 Aug 2025 11:26:15 +0200 Subject: [PATCH 24/54] Changed implementation for getting DocID for feature vectore cache and removed unuseful leafContext. --- .../org/apache/solr/ltr/LTRScoringQuery.java | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index ef5eee29da58..ed99c3b7179c 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -497,7 +497,7 @@ public ModelScorer modelScorer(LeafReaderContext context) throws IOException { // score on the model for every document, since 0 features matching could // return a // non 0 score for a given model. - ModelScorer mscorer = new ModelScorer(this, featureScorers, context); + ModelScorer mscorer = new ModelScorer(this, featureScorers); return mscorer; } @@ -515,7 +515,9 @@ public DocInfo getDocInfo() { return docInfo; } - public ModelScorer(Weight weight, List featureScorers, LeafReaderContext leafContext) { + public ModelScorer( + Weight weight, + List featureScorers) { isLogging = false; docInfo = new DocInfo(); for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) { @@ -523,9 +525,9 @@ public ModelScorer(Weight weight, List feat } if (featureScorers.size() <= 1) { // future enhancement: allow the use of dense features in other cases - featureTraversalScorer = new SingleFeatureScorer(weight, featureScorers, leafContext); + featureTraversalScorer = new SingleFeatureScorer(weight, featureScorers); } else { - featureTraversalScorer = new MultiFeaturesScorer(weight, featureScorers, leafContext); + featureTraversalScorer = new MultiFeaturesScorer(weight, featureScorers); } } @@ -568,11 +570,9 @@ public void fillFeaturesInfo() throws IOException { private abstract class FeatureTraversalScorer extends Scorer { protected int targetDoc = -1; protected int activeDoc = -1; - protected LeafReaderContext leafContext; protected FeatureExtractor featureExtractor; - protected FeatureTraversalScorer(Weight weight, LeafReaderContext leafContext) { - this.leafContext = leafContext; + protected FeatureTraversalScorer(Weight weight) { this.featureExtractor = new FeatureExtractor(this); } @@ -625,8 +625,7 @@ private void fillFeaturesInfo() throws IOException { featureVectorCache = request.getSearcher().getFeatureVectorCache(); } if (featureVectorCache != null) { - int docId = traversalScorer.activeDoc + traversalScorer.leafContext.docBase; - int fvCacheKey = computeFeatureVectorCacheKey(docId); + int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.docID()); featureVector = featureVectorCache.get(fvCacheKey); if (featureVector == null) { featureVector = traversalScorer.extractFeatureVector(); @@ -684,9 +683,8 @@ private class MultiFeaturesScorer extends FeatureTraversalScorer { private MultiFeaturesScorer( Weight unusedWeight, - List featureScorers, - LeafReaderContext leafContext) { - super(unusedWeight, leafContext); + List featureScorers) { + super(unusedWeight); if (featureScorers.size() <= 1) { throw new IllegalArgumentException("There must be at least 2 subScorers"); } @@ -834,9 +832,8 @@ private class SingleFeatureScorer extends FeatureTraversalScorer { private SingleFeatureScorer( Weight unusedWeight, - List featureScorers, - LeafReaderContext leafContext) { - super(unusedWeight, leafContext); + List featureScorers) { + super(unusedWeight); this.featureScorers = featureScorers; } From 318d2b0ca4a63fda546150be72886c50bcf7b61f Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Tue, 5 Aug 2025 15:17:20 +0200 Subject: [PATCH 25/54] Second version of code refactoring to move extraction part in the FeatureExtractor class --- .../org/apache/solr/ltr/LTRScoringQuery.java | 88 +++++++++++-------- 1 file changed, 51 insertions(+), 37 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index ed99c3b7179c..362c43ab81cc 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -595,6 +595,14 @@ public float score() throws IOException { public float getMaxScore(int upTo) { return Float.POSITIVE_INFINITY; } + } + + private class FeatureExtractor { + private final FeatureTraversalScorer traversalScorer; + + private FeatureExtractor(FeatureTraversalScorer traversalScorer) { + this.traversalScorer = traversalScorer; + } protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { float[] featureVector = new float[allFeaturesInStore.length]; @@ -606,14 +614,41 @@ protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { return featureVector; } - protected abstract float[] extractFeatureVector() throws IOException; - } + private float[] extractFeatureVector() throws IOException { + if (traversalScorer instanceof MultiFeaturesScorer) { + MultiFeaturesScorer multiFeaturesScorer = (MultiFeaturesScorer) traversalScorer; + return extractMultiFeaturesVector(multiFeaturesScorer.getSubScorers()); + } else { + SingleFeatureScorer singleFeatureScorer = (SingleFeatureScorer) traversalScorer; + return extractSingleFeatureVector(singleFeatureScorer.getFeatureScorers()); + } + } - private class FeatureExtractor { - private final FeatureTraversalScorer traversalScorer; + protected float[] extractSingleFeatureVector(List featureScorers) throws IOException { + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (int i = 0; i < featureScorers.size(); i++) { + Scorer scorer = featureScorers.get(i); + if (scorer.docID() == traversalScorer.activeDoc) { + Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight(); + final int featureId = scFW.getIndex(); + float featureValue = scorer.score(); + featureVector[featureId] = featureValue; + } + } + return featureVector; + } - private FeatureExtractor(FeatureTraversalScorer traversalScorer) { - this.traversalScorer = traversalScorer; + protected float[] extractMultiFeaturesVector(DisiPriorityQueue subScorers) throws IOException { + final DisiWrapper topList = subScorers.topList(); + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (DisiWrapper w = topList; w != null; w = w.next) { + final Scorer subScorer = w.scorer; + Feature.FeatureWeight feature = (Feature.FeatureWeight) subScorer.getWeight(); + final int featureId = feature.getIndex(); + float featureValue = subScorer.score(); + featureVector[featureId] = featureValue; + } + return featureVector; } private void fillFeaturesInfo() throws IOException { @@ -628,11 +663,11 @@ private void fillFeaturesInfo() throws IOException { int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.docID()); featureVector = featureVectorCache.get(fvCacheKey); if (featureVector == null) { - featureVector = traversalScorer.extractFeatureVector(); + featureVector = extractFeatureVector(); featureVectorCache.put(fvCacheKey, featureVector); } } else { - featureVector = traversalScorer.extractFeatureVector(); + featureVector = extractFeatureVector(); } for (int i = 0; i < extractedFeatureWeights.length; i++) { @@ -699,25 +734,15 @@ private MultiFeaturesScorer( sparseIterator = new ScoringQuerySparseIterator(wrappers); } + private DisiPriorityQueue getSubScorers() { + return this.subScorers; + } + @Override public int docID() { return sparseIterator.docID(); } - protected float[] extractFeatureVector() throws IOException { - final DisiWrapper topList = subScorers.topList(); - float[] featureVector = initFeatureVector(allFeaturesInStore); - for (DisiWrapper w = topList; w != null; w = w.next) { - final Feature.FeatureWeight.FeatureScorer subScorer = - (Feature.FeatureWeight.FeatureScorer) w.scorer; - Feature.FeatureWeight scFW = subScorer.getWeight(); - final int featureId = scFW.getIndex(); - float featureValue = subScorer.score(); - featureVector[featureId] = featureValue; - } - return featureVector; - } - @Override public DocIdSetIterator iterator() { return sparseIterator; @@ -837,26 +862,15 @@ private SingleFeatureScorer( this.featureScorers = featureScorers; } + private List getFeatureScorers() { + return this.featureScorers; + } + @Override public int docID() { return targetDoc; } - protected float[] extractFeatureVector() throws IOException { - float[] featureVector = initFeatureVector(allFeaturesInStore); - for (final Scorer scorer : featureScorers) { - if (scorer.docID() == activeDoc) { - Feature.FeatureWeight.FeatureScorer featureScorer = - (Feature.FeatureWeight.FeatureScorer) scorer; - Feature.FeatureWeight scFW = featureScorer.getWeight(); - final int featureId = scFW.getIndex(); - float featureValue = scorer.score(); - featureVector[featureId] = featureValue; - } - } - return featureVector; - } - @Override public final Collection getChildren() { final ArrayList children = new ArrayList<>(); From 0088ca9bf468ed95024c8814e9a53972198ac3f4 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Tue, 5 Aug 2025 15:44:16 +0200 Subject: [PATCH 26/54] Refactoring of the code --- .../org/apache/solr/ltr/LTRScoringQuery.java | 302 +++++++++--------- 1 file changed, 149 insertions(+), 153 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 362c43ab81cc..a13757b2e4f7 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -525,9 +525,9 @@ public ModelScorer( } if (featureScorers.size() <= 1) { // future enhancement: allow the use of dense features in other cases - featureTraversalScorer = new SingleFeatureScorer(weight, featureScorers); + featureTraversalScorer = new SingleFeatureScorer(featureScorers); } else { - featureTraversalScorer = new MultiFeaturesScorer(weight, featureScorers); + featureTraversalScorer = new MultiFeaturesScorer(featureScorers); } } @@ -572,7 +572,7 @@ private abstract class FeatureTraversalScorer extends Scorer { protected int activeDoc = -1; protected FeatureExtractor featureExtractor; - protected FeatureTraversalScorer(Weight weight) { + protected FeatureTraversalScorer() { this.featureExtractor = new FeatureExtractor(this); } @@ -597,129 +597,89 @@ public float getMaxScore(int upTo) { } } - private class FeatureExtractor { - private final FeatureTraversalScorer traversalScorer; + private class SingleFeatureScorer extends FeatureTraversalScorer { + private final List featureScorers; - private FeatureExtractor(FeatureTraversalScorer traversalScorer) { - this.traversalScorer = traversalScorer; + private SingleFeatureScorer( + List featureScorers) { + this.featureScorers = featureScorers; } - protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { - float[] featureVector = new float[allFeaturesInStore.length]; - for (int i = 0; i < featuresInfos.length; i++) { - if (featuresInfos[i] != null) { - featureVector[i] = featuresInfos[i].getValue(); - } - } - return featureVector; + private List getFeatureScorers() { + return this.featureScorers; } - private float[] extractFeatureVector() throws IOException { - if (traversalScorer instanceof MultiFeaturesScorer) { - MultiFeaturesScorer multiFeaturesScorer = (MultiFeaturesScorer) traversalScorer; - return extractMultiFeaturesVector(multiFeaturesScorer.getSubScorers()); - } else { - SingleFeatureScorer singleFeatureScorer = (SingleFeatureScorer) traversalScorer; - return extractSingleFeatureVector(singleFeatureScorer.getFeatureScorers()); - } + @Override + public int docID() { + return targetDoc; } - protected float[] extractSingleFeatureVector(List featureScorers) throws IOException { - float[] featureVector = initFeatureVector(allFeaturesInStore); - for (int i = 0; i < featureScorers.size(); i++) { - Scorer scorer = featureScorers.get(i); - if (scorer.docID() == traversalScorer.activeDoc) { - Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight(); - final int featureId = scFW.getIndex(); - float featureValue = scorer.score(); - featureVector[featureId] = featureValue; - } + @Override + public final Collection getChildren() { + final ArrayList children = new ArrayList<>(); + for (final Scorer scorer : featureScorers) { + children.add(new ChildScorable(scorer, "SHOULD")); } - return featureVector; + return children; } - protected float[] extractMultiFeaturesVector(DisiPriorityQueue subScorers) throws IOException { - final DisiWrapper topList = subScorers.topList(); - float[] featureVector = initFeatureVector(allFeaturesInStore); - for (DisiWrapper w = topList; w != null; w = w.next) { - final Scorer subScorer = w.scorer; - Feature.FeatureWeight feature = (Feature.FeatureWeight) subScorer.getWeight(); - final int featureId = feature.getIndex(); - float featureValue = subScorer.score(); - featureVector[featureId] = featureValue; - } - return featureVector; + @Override + public DocIdSetIterator iterator() { + return new SingleFeatureIterator(); } - private void fillFeaturesInfo() throws IOException { - if (traversalScorer.activeDoc == traversalScorer.targetDoc) { - SolrCache featureVectorCache = null; - float[] featureVector; + private class SingleFeatureIterator extends DocIdSetIterator { - if (request != null) { - featureVectorCache = request.getSearcher().getFeatureVectorCache(); - } - if (featureVectorCache != null) { - int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.docID()); - featureVector = featureVectorCache.get(fvCacheKey); - if (featureVector == null) { - featureVector = extractFeatureVector(); - featureVectorCache.put(fvCacheKey, featureVector); - } - } else { - featureVector = extractFeatureVector(); - } + @Override + public int docID() { + return targetDoc; + } - for (int i = 0; i < extractedFeatureWeights.length; i++) { - int featureId = extractedFeatureWeights[i].getIndex(); - float featureValue = featureVector[featureId]; - if (!Float.isNaN(featureValue) - && featureValue != extractedFeatureWeights[i].getDefaultValue()) { - allFeaturesInStore[featureId].setValue(featureValue); - allFeaturesInStore[featureId].setIsDefaultValue(false); + @Override + public int nextDoc() throws IOException { + if (activeDoc <= targetDoc) { + activeDoc = NO_MORE_DOCS; + for (final Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc()); + } } } + return ++targetDoc; } - } - private int computeFeatureVectorCacheKey(int docId) { - int prime = 31; - int result = docId; - if (Objects.equals( - ltrScoringModel.getName(), - LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) - || (isLogging && logger.isLoggingAll())) { - result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); - } else { - result = (prime * result) + ltrScoringModel.getName().hashCode(); + @Override + public int advance(int target) throws IOException { + if (activeDoc < target) { + activeDoc = NO_MORE_DOCS; + for (final Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, scorer.iterator().advance(target)); + } + } + } + targetDoc = target; + return target; } - result = (prime * result) + addEfisHash(result, prime); - return result; - } - private int addEfisHash(int result, int prime) { - if (efi != null) { - TreeMap sorted = new TreeMap<>(efi); - for (final Map.Entry entry : sorted.entrySet()) { - final String key = entry.getKey(); - final String[] values = entry.getValue(); - result = (prime * result) + key.hashCode(); - result = (prime * result) + Arrays.hashCode(values); + @Override + public long cost() { + long sum = 0; + for (int i = 0; i < featureScorers.size(); i++) { + sum += featureScorers.get(i).iterator().cost(); } + return sum; } - return result; } } private class MultiFeaturesScorer extends FeatureTraversalScorer { private final DisiPriorityQueue subScorers; private final List wrappers; - private final ScoringQuerySparseIterator sparseIterator; + private final MultiFeaturesIterator multiFeaturesIteratorIterator; private MultiFeaturesScorer( - Weight unusedWeight, List featureScorers) { - super(unusedWeight); if (featureScorers.size() <= 1) { throw new IllegalArgumentException("There must be at least 2 subScorers"); } @@ -731,7 +691,7 @@ private MultiFeaturesScorer( wrappers.add(w); } - sparseIterator = new ScoringQuerySparseIterator(wrappers); + multiFeaturesIteratorIterator = new MultiFeaturesIterator(wrappers); } private DisiPriorityQueue getSubScorers() { @@ -740,12 +700,12 @@ private DisiPriorityQueue getSubScorers() { @Override public int docID() { - return sparseIterator.docID(); + return multiFeaturesIteratorIterator.docID(); } @Override public DocIdSetIterator iterator() { - return sparseIterator; + return multiFeaturesIteratorIterator; } @Override @@ -757,9 +717,9 @@ public final Collection getChildren() { return children; } - private class ScoringQuerySparseIterator extends DocIdSetIterator { + private class MultiFeaturesIterator extends DocIdSetIterator { - public ScoringQuerySparseIterator(Collection wrappers) { + public MultiFeaturesIterator(Collection wrappers) { // Initialize all wrappers to start at -1 for (DisiWrapper wrapper : wrappers) { wrapper.doc = -1; @@ -852,81 +812,117 @@ public long cost() { } } - private class SingleFeatureScorer extends FeatureTraversalScorer { - private final List featureScorers; + private class FeatureExtractor { + private final FeatureTraversalScorer traversalScorer; - private SingleFeatureScorer( - Weight unusedWeight, - List featureScorers) { - super(unusedWeight); - this.featureScorers = featureScorers; + private FeatureExtractor(FeatureTraversalScorer traversalScorer) { + this.traversalScorer = traversalScorer; } - private List getFeatureScorers() { - return this.featureScorers; + protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { + float[] featureVector = new float[allFeaturesInStore.length]; + for (int i = 0; i < featuresInfos.length; i++) { + if (featuresInfos[i] != null) { + featureVector[i] = featuresInfos[i].getValue(); + } + } + return featureVector; } - @Override - public int docID() { - return targetDoc; + private float[] extractFeatureVector() throws IOException { + if (traversalScorer instanceof MultiFeaturesScorer) { + MultiFeaturesScorer multiFeaturesScorer = (MultiFeaturesScorer) traversalScorer; + return extractMultiFeaturesVector(multiFeaturesScorer.getSubScorers()); + } else { + SingleFeatureScorer singleFeatureScorer = (SingleFeatureScorer) traversalScorer; + return extractSingleFeatureVector(singleFeatureScorer.getFeatureScorers()); + } } - @Override - public final Collection getChildren() { - final ArrayList children = new ArrayList<>(); - for (final Scorer scorer : featureScorers) { - children.add(new ChildScorable(scorer, "SHOULD")); + protected float[] extractSingleFeatureVector(List featureScorers) throws IOException { + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (int i = 0; i < featureScorers.size(); i++) { + Scorer scorer = featureScorers.get(i); + if (scorer.docID() == traversalScorer.activeDoc) { + Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight(); + final int featureId = scFW.getIndex(); + float featureValue = scorer.score(); + featureVector[featureId] = featureValue; + } } - return children; + return featureVector; } - @Override - public DocIdSetIterator iterator() { - return new DenseIterator(); + protected float[] extractMultiFeaturesVector(DisiPriorityQueue subScorers) throws IOException { + final DisiWrapper topList = subScorers.topList(); + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (DisiWrapper w = topList; w != null; w = w.next) { + final Scorer subScorer = w.scorer; + Feature.FeatureWeight feature = (Feature.FeatureWeight) subScorer.getWeight(); + final int featureId = feature.getIndex(); + float featureValue = subScorer.score(); + featureVector[featureId] = featureValue; + } + return featureVector; } - private class DenseIterator extends DocIdSetIterator { - - @Override - public int docID() { - return targetDoc; - } + private void fillFeaturesInfo() throws IOException { + if (traversalScorer.activeDoc == traversalScorer.targetDoc) { + SolrCache featureVectorCache = null; + float[] featureVector; - @Override - public int nextDoc() throws IOException { - if (activeDoc <= targetDoc) { - activeDoc = NO_MORE_DOCS; - for (final Scorer scorer : featureScorers) { - if (scorer.docID() != NO_MORE_DOCS) { - activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc()); - } + if (request != null) { + featureVectorCache = request.getSearcher().getFeatureVectorCache(); + } + if (featureVectorCache != null) { + int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.docID()); + featureVector = featureVectorCache.get(fvCacheKey); + if (featureVector == null) { + featureVector = extractFeatureVector(); + featureVectorCache.put(fvCacheKey, featureVector); } + } else { + featureVector = extractFeatureVector(); } - return ++targetDoc; - } - @Override - public int advance(int target) throws IOException { - if (activeDoc < target) { - activeDoc = NO_MORE_DOCS; - for (final Scorer scorer : featureScorers) { - if (scorer.docID() != NO_MORE_DOCS) { - activeDoc = Math.min(activeDoc, scorer.iterator().advance(target)); - } + for (int i = 0; i < extractedFeatureWeights.length; i++) { + int featureId = extractedFeatureWeights[i].getIndex(); + float featureValue = featureVector[featureId]; + if (!Float.isNaN(featureValue) + && featureValue != extractedFeatureWeights[i].getDefaultValue()) { + allFeaturesInStore[featureId].setValue(featureValue); + allFeaturesInStore[featureId].setIsDefaultValue(false); } } - targetDoc = target; - return target; } + } - @Override - public long cost() { - long sum = 0; - for (int i = 0; i < featureScorers.size(); i++) { - sum += featureScorers.get(i).iterator().cost(); + private int computeFeatureVectorCacheKey(int docId) { + int prime = 31; + int result = docId; + if (Objects.equals( + ltrScoringModel.getName(), + LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) + || (isLogging && logger.isLoggingAll())) { + result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); + } else { + result = (prime * result) + ltrScoringModel.getName().hashCode(); + } + result = (prime * result) + addEfisHash(result, prime); + return result; + } + + private int addEfisHash(int result, int prime) { + if (efi != null) { + TreeMap sorted = new TreeMap<>(efi); + for (final Map.Entry entry : sorted.entrySet()) { + final String key = entry.getKey(); + final String[] values = entry.getValue(); + result = (prime * result) + key.hashCode(); + result = (prime * result) + Arrays.hashCode(values); } - return sum; } + return result; } } } From eaa41901f82a77c4efc3cd619ce1b9c06423cee5 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Tue, 5 Aug 2025 15:47:28 +0200 Subject: [PATCH 27/54] Gradlew tidy --- .../java/org/apache/solr/ltr/LTRScoringQuery.java | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index a13757b2e4f7..ff3dfb6a1c10 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -516,7 +516,6 @@ public DocInfo getDocInfo() { } public ModelScorer( - Weight weight, List featureScorers) { isLogging = false; docInfo = new DocInfo(); @@ -564,9 +563,7 @@ public void fillFeaturesInfo() throws IOException { featureTraversalScorer.fillFeaturesInfo(); } - /** - * This class is responsible for extracting features and using them to score the document. - */ + /** This class is responsible for extracting features and using them to score the document. */ private abstract class FeatureTraversalScorer extends Scorer { protected int targetDoc = -1; protected int activeDoc = -1; @@ -839,7 +836,8 @@ private float[] extractFeatureVector() throws IOException { } } - protected float[] extractSingleFeatureVector(List featureScorers) throws IOException { + protected float[] extractSingleFeatureVector( + List featureScorers) throws IOException { float[] featureVector = initFeatureVector(allFeaturesInStore); for (int i = 0; i < featureScorers.size(); i++) { Scorer scorer = featureScorers.get(i); @@ -853,7 +851,8 @@ protected float[] extractSingleFeatureVector(List Date: Tue, 5 Aug 2025 16:49:13 +0200 Subject: [PATCH 28/54] Added missing dependency --- solr/modules/ltr/build.gradle | 2 ++ 1 file changed, 2 insertions(+) diff --git a/solr/modules/ltr/build.gradle b/solr/modules/ltr/build.gradle index 61e02bb645d0..6467da570def 100644 --- a/solr/modules/ltr/build.gradle +++ b/solr/modules/ltr/build.gradle @@ -56,6 +56,8 @@ dependencies { testImplementation libs.hamcrest.hamcrest testImplementation libs.commonsio.commonsio + + testImplementation libs.dropwizard.metrics.core } task copyPythonClientToExample(type: Sync) { From 3bb752b9a5dd865f5435d2871831671c23674ccc Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 11 Aug 2025 09:36:38 +0200 Subject: [PATCH 29/54] Updated sample_techproducts_configs solrconfig --- .../sample_techproducts_configs/conf/solrconfig.xml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml b/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml index 1dd706b13b3d..1e2fbe7af88d 100644 --- a/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml +++ b/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml @@ -401,12 +401,11 @@ https://solr.apache.org/guide/solr/latest/query-guide/learning-to-rank.html --> - + autowarmCount="0" /> - QUERY_DOC_FV From 62a51fa5c99babd325eb5197c4d1a562a0e16ca6 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 11 Aug 2025 17:23:27 +0200 Subject: [PATCH 30/54] Small changes after merge with Lucene 10 --- .../ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java | 3 +-- .../response/transform/LTRFeatureLoggerTransformerFactory.java | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index ff3dfb6a1c10..c57d012b32f9 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -597,8 +597,7 @@ public float getMaxScore(int upTo) { private class SingleFeatureScorer extends FeatureTraversalScorer { private final List featureScorers; - private SingleFeatureScorer( - List featureScorers) { + private SingleFeatureScorer(List featureScorers) { this.featureScorers = featureScorers; } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 096422dd0ed1..1ca7edde341c 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -416,7 +416,7 @@ private static LTRScoringQuery.FeatureInfo[] extractFeatures( final int n = ReaderUtil.subIndex(docid, leafContexts); final LeafReaderContext atomicContext = leafContexts.get(n); final int deBasedDoc = docid - atomicContext.docBase; - final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.scorer(atomicContext); + final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.modelScorer(atomicContext); if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) { return new LTRScoringQuery.FeatureInfo[0]; } else { From 46877f6546235cb7641a2ece6da06633f0000336 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 15 Sep 2025 14:53:52 +0200 Subject: [PATCH 31/54] Gradlew tidy and small typos --- solr/CHANGES.txt | 4 +++- .../org/apache/solr/ltr/LTRScoringQuery.java | 22 +++++++++---------- .../solr/ltr/TestFeatureVectorCache.java | 14 ++++++------ .../org/apache/solr/ltr/TestRerankBase.java | 2 +- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index 837e05caa6c2..863fb629ad22 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -323,6 +323,8 @@ New Features subset of JavaScript, pre-compiled, and that which can access the score and fields. It's powered by the Lucene Expressions module. (hossman, David Smiley, Ryan Ernst, Kevin Risden) +* SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero, Alessandro Benedetti) + Improvements --------------------- * SOLR-15751: The v2 API now has parity with the v1 "COLSTATUS" and "segments" APIs, which can be used to fetch detailed information about @@ -438,7 +440,7 @@ Bug Fixes * SOLR-17726: MoreLikeThis to support copy-fields (Ilaria Petreti via Alessandro Benedetti) -* SOLR-16667: Fixed dense/sparse representation in LTR module. (Anna Ruggero, Alessandro Benedetti) +* SOLR-17760: Fixed dense/sparse representation in LTR module. (Anna Ruggero, Alessandro Benedetti) * SOLR-17800: Security Manager should handle symlink on /tmp (Kevin Risden) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index c57d012b32f9..9497c986fcda 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -49,8 +49,8 @@ import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory; import org.apache.solr.request.SolrQueryRequest; -import org.apache.solr.util.SolrDefaultScorerSupplier; import org.apache.solr.search.SolrCache; +import org.apache.solr.util.SolrDefaultScorerSupplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -497,7 +497,7 @@ public ModelScorer modelScorer(LeafReaderContext context) throws IOException { // score on the model for every document, since 0 features matching could // return a // non 0 score for a given model. - ModelScorer mscorer = new ModelScorer(this, featureScorers); + ModelScorer mscorer = new ModelScorer(featureScorers); return mscorer; } @@ -515,8 +515,7 @@ public DocInfo getDocInfo() { return docInfo; } - public ModelScorer( - List featureScorers) { + public ModelScorer(List featureScorers) { isLogging = false; docInfo = new DocInfo(); for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) { @@ -674,8 +673,7 @@ private class MultiFeaturesScorer extends FeatureTraversalScorer { private final List wrappers; private final MultiFeaturesIterator multiFeaturesIteratorIterator; - private MultiFeaturesScorer( - List featureScorers) { + private MultiFeaturesScorer(List featureScorers) { if (featureScorers.size() <= 1) { throw new IllegalArgumentException("There must be at least 2 subScorers"); } @@ -838,10 +836,11 @@ private float[] extractFeatureVector() throws IOException { protected float[] extractSingleFeatureVector( List featureScorers) throws IOException { float[] featureVector = initFeatureVector(allFeaturesInStore); - for (int i = 0; i < featureScorers.size(); i++) { - Scorer scorer = featureScorers.get(i); + for (final Scorer scorer : featureScorers) { if (scorer.docID() == traversalScorer.activeDoc) { - Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight(); + Feature.FeatureWeight.FeatureScorer featureScorer = + (Feature.FeatureWeight.FeatureScorer) scorer; + Feature.FeatureWeight scFW = featureScorer.getWeight(); final int featureId = scFW.getIndex(); float featureValue = scorer.score(); featureVector[featureId] = featureValue; @@ -855,8 +854,9 @@ protected float[] extractMultiFeaturesVector(DisiPriorityQueue subScorers) final DisiWrapper topList = subScorers.topList(); float[] featureVector = initFeatureVector(allFeaturesInStore); for (DisiWrapper w = topList; w != null; w = w.next) { - final Scorer subScorer = w.scorer; - Feature.FeatureWeight feature = (Feature.FeatureWeight) subScorer.getWeight(); + final Feature.FeatureWeight.FeatureScorer subScorer = + (Feature.FeatureWeight.FeatureScorer) w.scorer; + Feature.FeatureWeight feature = subScorer.getWeight(); final int featureId = feature.getIndex(); float featureValue = subScorer.score(); featureVector[featureId] = featureValue; diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java index 7c7db2ba9617..4369685e14f8 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -30,7 +30,7 @@ public class TestFeatureVectorCache extends TestRerankBase { @Before public void before() throws Exception { - setupFeatureVectorCachetest(false); + setupFeatureVectorCacheTest(false); assertU(adoc("id", "1", "title", "w2", "description", "w2", "popularity", "2")); assertU(adoc("id", "2", "title", "w1", "description", "w1", "popularity", "0")); @@ -92,7 +92,7 @@ public void testFeatureVectorCache_loggingDefaultStoreNoReranking() throws Excep query.add("rows", "3"); query.add("fl", "[fv efi.efi_feature=3]"); - // No caching, we want to see lookups, an insertions and no hits + // No caching, we want to see lookups, insertions and no hits assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); @@ -127,7 +127,7 @@ public void testFeatureVectorCache_loggingExplicitStoreNoReranking() throws Exce query.add("rows", "3"); query.add("fl", "[fv store=store1 efi.efi_feature=3]"); - // No caching, we want to see lookups, an insertions and no hits + // No caching, we want to see lookups, insertions and no hits assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); @@ -172,7 +172,7 @@ public void testFeatureVectorCache_loggingModelStoreRerankingDifferentEfi() thro query.add("fl", "id,score,fv:[fv efi.efi_feature=3]"); query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); - // No caching, we want to see lookups, an insertions and no hits since the efi are different + // No caching, we want to see lookups, insertions and no hits since the efi are different assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" @@ -258,7 +258,7 @@ public void testFeatureVectorCache_loggingModelStoreRerankingSameEfi() throws Ex } @Test - public void testFeatureVectorCache_loggingAllStoreReranking() throws Exception { + public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws Exception { final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector( "value_feature_1", @@ -291,7 +291,7 @@ public void testFeatureVectorCache_loggingAllStoreReranking() throws Exception { query.add("fl", "id,score,fv:[fv logAll=true efi.efi_feature=3]"); query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); - // No caching, we want to see lookups, an insertions and no hits since the efi are different + // No caching, we want to see lookups, insertions and no hits since the efi are different assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" @@ -337,7 +337,7 @@ public void testFeatureVectorCache_loggingExplicitStoreReranking() throws Except query.add("fl", "id,score,fv:[fv store=store1 efi.efi_feature=3]"); query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); - // No caching, we want to see lookups, an insertions and no hits since the efi are different + // No caching, we want to see lookups, insertions and no hits since the efi are different assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java index 1792a5336559..2ee0c1ebdd0a 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java @@ -120,7 +120,7 @@ protected static void setupPersistenttest(boolean bulkIndex) throws Exception { if (bulkIndex) bulkIndex(); } - protected static void setupFeatureVectorCachetest(boolean bulkIndex) throws Exception { + protected static void setupFeatureVectorCacheTest(boolean bulkIndex) throws Exception { chooseDefaultFeatureFormat(); setuptest("solrconfig-ltr-featurevectorcache.xml", "schema.xml"); if (bulkIndex) bulkIndex(); From 1b590ce545e59d00d643d532a39ebf9e6cf4ca9d Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 15 Sep 2025 16:12:20 +0200 Subject: [PATCH 32/54] Refactored LTRScoringQuery to create feature extractors within scorer and do not check for the class type --- .../org/apache/solr/ltr/LTRScoringQuery.java | 111 +++++++++--------- 1 file changed, 57 insertions(+), 54 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 9497c986fcda..ba73b219ad45 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -568,10 +568,6 @@ private abstract class FeatureTraversalScorer extends Scorer { protected int activeDoc = -1; protected FeatureExtractor featureExtractor; - protected FeatureTraversalScorer() { - this.featureExtractor = new FeatureExtractor(this); - } - void fillFeaturesInfo() throws IOException { // Initialize features to their default values and set isDefaultValue to true. reset(); @@ -598,10 +594,7 @@ private class SingleFeatureScorer extends FeatureTraversalScorer { private SingleFeatureScorer(List featureScorers) { this.featureScorers = featureScorers; - } - - private List getFeatureScorers() { - return this.featureScorers; + this.featureExtractor = new SingleFeatureExtractor(this, featureScorers); } @Override @@ -686,10 +679,7 @@ private MultiFeaturesScorer(List featureSco } multiFeaturesIteratorIterator = new MultiFeaturesIterator(wrappers); - } - - private DisiPriorityQueue getSubScorers() { - return this.subScorers; + this.featureExtractor = new MultiFeaturesExtractor(this, subScorers); } @Override @@ -806,8 +796,8 @@ public long cost() { } } - private class FeatureExtractor { - private final FeatureTraversalScorer traversalScorer; + private abstract class FeatureExtractor { + protected final FeatureTraversalScorer traversalScorer; private FeatureExtractor(FeatureTraversalScorer traversalScorer) { this.traversalScorer = traversalScorer; @@ -823,46 +813,7 @@ protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { return featureVector; } - private float[] extractFeatureVector() throws IOException { - if (traversalScorer instanceof MultiFeaturesScorer) { - MultiFeaturesScorer multiFeaturesScorer = (MultiFeaturesScorer) traversalScorer; - return extractMultiFeaturesVector(multiFeaturesScorer.getSubScorers()); - } else { - SingleFeatureScorer singleFeatureScorer = (SingleFeatureScorer) traversalScorer; - return extractSingleFeatureVector(singleFeatureScorer.getFeatureScorers()); - } - } - - protected float[] extractSingleFeatureVector( - List featureScorers) throws IOException { - float[] featureVector = initFeatureVector(allFeaturesInStore); - for (final Scorer scorer : featureScorers) { - if (scorer.docID() == traversalScorer.activeDoc) { - Feature.FeatureWeight.FeatureScorer featureScorer = - (Feature.FeatureWeight.FeatureScorer) scorer; - Feature.FeatureWeight scFW = featureScorer.getWeight(); - final int featureId = scFW.getIndex(); - float featureValue = scorer.score(); - featureVector[featureId] = featureValue; - } - } - return featureVector; - } - - protected float[] extractMultiFeaturesVector(DisiPriorityQueue subScorers) - throws IOException { - final DisiWrapper topList = subScorers.topList(); - float[] featureVector = initFeatureVector(allFeaturesInStore); - for (DisiWrapper w = topList; w != null; w = w.next) { - final Feature.FeatureWeight.FeatureScorer subScorer = - (Feature.FeatureWeight.FeatureScorer) w.scorer; - Feature.FeatureWeight feature = subScorer.getWeight(); - final int featureId = feature.getIndex(); - float featureValue = subScorer.score(); - featureVector[featureId] = featureValue; - } - return featureVector; - } + protected abstract float[] extractFeatureVector() throws IOException; private void fillFeaturesInfo() throws IOException { if (traversalScorer.activeDoc == traversalScorer.targetDoc) { @@ -923,6 +874,58 @@ private int addEfisHash(int result, int prime) { return result; } } + + private class SingleFeatureExtractor extends FeatureExtractor { + List featureScorers; + + private SingleFeatureExtractor( + FeatureTraversalScorer singleFeatureScorer, + List featureScorers) { + super(singleFeatureScorer); + this.featureScorers = featureScorers; + } + + @Override + protected float[] extractFeatureVector() throws IOException { + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (final Scorer scorer : featureScorers) { + if (scorer.docID() == traversalScorer.activeDoc) { + Feature.FeatureWeight.FeatureScorer featureScorer = + (Feature.FeatureWeight.FeatureScorer) scorer; + Feature.FeatureWeight scFW = featureScorer.getWeight(); + final int featureId = scFW.getIndex(); + float featureValue = scorer.score(); + featureVector[featureId] = featureValue; + } + } + return featureVector; + } + } + + private class MultiFeaturesExtractor extends FeatureExtractor { + DisiPriorityQueue subScorers; + + private MultiFeaturesExtractor( + FeatureTraversalScorer multiFeaturesScorer, DisiPriorityQueue subScorers) { + super(multiFeaturesScorer); + this.subScorers = subScorers; + } + + @Override + protected float[] extractFeatureVector() throws IOException { + final DisiWrapper topList = subScorers.topList(); + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (DisiWrapper w = topList; w != null; w = w.next) { + final Feature.FeatureWeight.FeatureScorer subScorer = + (Feature.FeatureWeight.FeatureScorer) w.scorer; + Feature.FeatureWeight feature = subScorer.getWeight(); + final int featureId = feature.getIndex(); + float featureValue = subScorer.score(); + featureVector[featureId] = featureValue; + } + return featureVector; + } + } } } } From 2eb15fedf13dcf50f1a651c6c1d6157a78b1f67e Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 15 Sep 2025 16:27:39 +0200 Subject: [PATCH 33/54] Inserted docs into a list for tests --- .../solr/ltr/TestFeatureVectorCache.java | 76 +++++++++++-------- 1 file changed, 43 insertions(+), 33 deletions(-) diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java index 4369685e14f8..8c5f67bb7bcb 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -16,6 +16,8 @@ */ package org.apache.solr.ltr; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.core.SolrCore; @@ -27,13 +29,21 @@ public class TestFeatureVectorCache extends TestRerankBase { SolrCore core = null; + List docs; @Before public void before() throws Exception { setupFeatureVectorCacheTest(false); - assertU(adoc("id", "1", "title", "w2", "description", "w2", "popularity", "2")); - assertU(adoc("id", "2", "title", "w1", "description", "w1", "popularity", "0")); + this.docs = new ArrayList<>(); + + // Add strings to the list + docs.add(adoc("id", "1", "title", "w2", "description", "w2", "popularity", "2")); + docs.add(adoc("id", "2", "title", "w1", "description", "w1", "popularity", "0")); + + for (String doc : docs) { + assertU(doc); + } assertU(commit()); loadFeatures("featurevectorcache_features.json"); @@ -97,8 +107,8 @@ public void testFeatureVectorCache_loggingDefaultStoreNoReranking() throws Excep "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(2, (long) filterCacheMetrics.get("inserts")); - assertEquals(2, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("lookups")); assertEquals(0, (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); @@ -107,9 +117,9 @@ public void testFeatureVectorCache_loggingDefaultStoreNoReranking() throws Excep "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(2, (long) filterCacheMetrics.get("inserts")); - assertEquals(4, (long) filterCacheMetrics.get("lookups")); - assertEquals(2, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("hits")); } @Test @@ -132,8 +142,8 @@ public void testFeatureVectorCache_loggingExplicitStoreNoReranking() throws Exce "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(2, (long) filterCacheMetrics.get("inserts")); - assertEquals(2, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("lookups")); assertEquals(0, (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); @@ -142,9 +152,9 @@ public void testFeatureVectorCache_loggingExplicitStoreNoReranking() throws Exce "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(2, (long) filterCacheMetrics.get("inserts")); - assertEquals(4, (long) filterCacheMetrics.get("lookups")); - assertEquals(2, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("hits")); } @Test @@ -182,8 +192,8 @@ public void testFeatureVectorCache_loggingModelStoreRerankingDifferentEfi() thro + docs0fv_default_csv + "'}"); Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(4, (long) filterCacheMetrics.get("inserts")); - assertEquals(4, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); assertEquals(0, (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); @@ -197,9 +207,9 @@ public void testFeatureVectorCache_loggingModelStoreRerankingDifferentEfi() thro + docs0fv_default_csv + "'}"); filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(4, (long) filterCacheMetrics.get("inserts")); - assertEquals(8, (long) filterCacheMetrics.get("lookups")); - assertEquals(4, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 4, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("hits")); } @Test @@ -237,9 +247,9 @@ public void testFeatureVectorCache_loggingModelStoreRerankingSameEfi() throws Ex + docs0fv_default_csv + "'}"); Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(2, (long) filterCacheMetrics.get("inserts")); - assertEquals(4, (long) filterCacheMetrics.get("lookups")); - assertEquals(2, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); // Caching, we want to see hits and same score @@ -252,9 +262,9 @@ public void testFeatureVectorCache_loggingModelStoreRerankingSameEfi() throws Ex + docs0fv_default_csv + "'}"); filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(2, (long) filterCacheMetrics.get("inserts")); - assertEquals(8, (long) filterCacheMetrics.get("lookups")); - assertEquals(6, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 4, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size() * 3, (long) filterCacheMetrics.get("hits")); } @Test @@ -301,8 +311,8 @@ public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws E + docs0fv_default_csv + "'}"); Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(4, (long) filterCacheMetrics.get("inserts")); - assertEquals(4, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); assertEquals(0, (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); @@ -316,9 +326,9 @@ public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws E + docs0fv_default_csv + "'}"); filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(4, (long) filterCacheMetrics.get("inserts")); - assertEquals(8, (long) filterCacheMetrics.get("lookups")); - assertEquals(4, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 4, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("hits")); } @Test @@ -347,8 +357,8 @@ public void testFeatureVectorCache_loggingExplicitStoreReranking() throws Except + docs0fv_default_csv + "'}"); Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(4, (long) filterCacheMetrics.get("inserts")); - assertEquals(4, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); assertEquals(0, (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); @@ -362,8 +372,8 @@ public void testFeatureVectorCache_loggingExplicitStoreReranking() throws Except + docs0fv_default_csv + "'}"); filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(4, (long) filterCacheMetrics.get("inserts")); - assertEquals(8, (long) filterCacheMetrics.get("lookups")); - assertEquals(4, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); + assertEquals(docs.size() * 4, (long) filterCacheMetrics.get("lookups")); + assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("hits")); } } From 8cb6cb667a81884ce75c32c2ffacef45a6c4d813 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 15 Sep 2025 16:31:06 +0200 Subject: [PATCH 34/54] Removed unuseful comment --- .../src/test/org/apache/solr/ltr/TestFeatureVectorCache.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java index 8c5f67bb7bcb..dad18c040233 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -36,11 +36,8 @@ public void before() throws Exception { setupFeatureVectorCacheTest(false); this.docs = new ArrayList<>(); - - // Add strings to the list docs.add(adoc("id", "1", "title", "w2", "description", "w2", "popularity", "2")); docs.add(adoc("id", "2", "title", "w1", "description", "w1", "popularity", "0")); - for (String doc : docs) { assertU(doc); } From bfd2f7e7ac87deff57785471de4338b9936d3f3f Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 15 Sep 2025 16:33:45 +0200 Subject: [PATCH 35/54] Typos --- .../apache/solr/ltr/TestFeatureVectorCache.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java index dad18c040233..2421be7a2c6f 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -179,7 +179,7 @@ public void testFeatureVectorCache_loggingModelStoreRerankingDifferentEfi() thro query.add("fl", "id,score,fv:[fv efi.efi_feature=3]"); query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); - // No caching, we want to see lookups, insertions and no hits since the efi are different + // No caching, we want to see lookups, insertions and no hits since the efis are different assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" @@ -194,7 +194,7 @@ public void testFeatureVectorCache_loggingModelStoreRerankingDifferentEfi() thro assertEquals(0, (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); - // Caching, we want to see hits and same score + // Caching, we want to see hits and same scores as before assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" @@ -249,7 +249,7 @@ public void testFeatureVectorCache_loggingModelStoreRerankingSameEfi() throws Ex assertEquals(docs.size(), (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); - // Caching, we want to see hits and same score + // Caching, we want to see hits and same scores assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" @@ -298,7 +298,7 @@ public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws E query.add("fl", "id,score,fv:[fv logAll=true efi.efi_feature=3]"); query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); - // No caching, we want to see lookups, insertions and no hits since the efi are different + // No caching, we want to see lookups, insertions and no hits since the efis are different assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" @@ -313,7 +313,7 @@ public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws E assertEquals(0, (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); - // Caching, we want to see hits and same score + // Caching, we want to see hits and same scores assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" @@ -344,7 +344,7 @@ public void testFeatureVectorCache_loggingExplicitStoreReranking() throws Except query.add("fl", "id,score,fv:[fv store=store1 efi.efi_feature=3]"); query.add("rq", "{!ltr reRankDocs=3 model=featurevectorcache_linear_model efi.efi_feature=4}"); - // No caching, we want to see lookups, insertions and no hits since the efi are different + // No caching, we want to see lookups, insertions and no hits since the efis are different assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" @@ -359,7 +359,7 @@ public void testFeatureVectorCache_loggingExplicitStoreReranking() throws Except assertEquals(0, (long) filterCacheMetrics.get("hits")); query.add("sort", "popularity desc"); - // Caching, we want to see hits and same score + // Caching, we want to see hits and same scores assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={" From dd8042d093ca000059d41ad6119295dc462805ea Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 15 Sep 2025 16:35:19 +0200 Subject: [PATCH 36/54] Better test names --- .../test/org/apache/solr/ltr/TestFeatureVectorCache.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java index 2421be7a2c6f..0859fe202714 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -155,7 +155,7 @@ public void testFeatureVectorCache_loggingExplicitStoreNoReranking() throws Exce } @Test - public void testFeatureVectorCache_loggingModelStoreRerankingDifferentEfi() throws Exception { + public void testFeatureVectorCache_loggingModelStoreAndRerankingWithDifferentEfi() throws Exception { final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector( "value_feature_1", @@ -210,7 +210,7 @@ public void testFeatureVectorCache_loggingModelStoreRerankingDifferentEfi() thro } @Test - public void testFeatureVectorCache_loggingModelStoreRerankingSameEfi() throws Exception { + public void testFeatureVectorCache_loggingModelStoreAndRerankingWithSameEfi() throws Exception { final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector( "value_feature_1", @@ -329,7 +329,7 @@ public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws E } @Test - public void testFeatureVectorCache_loggingExplicitStoreReranking() throws Exception { + public void testFeatureVectorCache_loggingExplicitStoreAndReranking() throws Exception { final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector("match_w1_title", "0.0", "value_feature_2", "2.0"); final String docs0fv_sparse_csv = From a4c7431f951ae21be5cfca86b417427fd1e12980 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 15 Sep 2025 16:38:43 +0200 Subject: [PATCH 37/54] Gradlew tidy --- .../src/test/org/apache/solr/ltr/TestFeatureVectorCache.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java index 0859fe202714..442b3b48514e 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -155,7 +155,8 @@ public void testFeatureVectorCache_loggingExplicitStoreNoReranking() throws Exce } @Test - public void testFeatureVectorCache_loggingModelStoreAndRerankingWithDifferentEfi() throws Exception { + public void testFeatureVectorCache_loggingModelStoreAndRerankingWithDifferentEfi() + throws Exception { final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector( "value_feature_1", From 074ef4126d73d5708375d54871a79cb21dc687dd Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 17 Sep 2025 11:16:39 +0200 Subject: [PATCH 38/54] Moved feature extraction into a dedicated package --- .../feature/extraction/FeatureExtractor.java | 109 ++++++++++++++++++ .../extraction/MultiFeaturesExtractor.java | 40 +++++++ .../extraction/SingleFeatureExtractor.java | 41 +++++++ 3 files changed, 190 insertions(+) create mode 100644 solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java create mode 100644 solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java create mode 100644 solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java new file mode 100644 index 000000000000..5bb88c91e41c --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java @@ -0,0 +1,109 @@ +package org.apache.solr.ltr.feature; + +import org.apache.solr.ltr.FeatureLogger; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.search.SolrCache; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; + +public abstract class FeatureExtractor { + protected final LTRScoringQuery.ModelWeight.ModelScorer.FeatureTraversalScorer traversalScorer; + SolrQueryRequest request; + Feature.FeatureWeight[] extractedFeatureWeights; + LTRScoringQuery.FeatureInfo[] allFeaturesInStore; + LTRScoringModel ltrScoringModel; + FeatureLogger logger; + Map efi; + + FeatureExtractor( + LTRScoringQuery.ModelWeight.ModelScorer.FeatureTraversalScorer traversalScorer, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map efi) { + this.traversalScorer = traversalScorer; + this.request = request; + this.extractedFeatureWeights = extractedFeatureWeights; + this.allFeaturesInStore = allFeaturesInStore; + this.ltrScoringModel = ltrScoringModel; + this.efi = efi; + } + + protected float[] initFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfos) { + float[] featureVector = new float[featuresInfos.length]; + for (int i = 0; i < featuresInfos.length; i++) { + if (featuresInfos[i] != null) { + featureVector[i] = featuresInfos[i].getValue(); + } + } + return featureVector; + } + + protected abstract float[] extractFeatureVector() throws IOException; + + public void fillFeaturesInfo() throws IOException { + if (traversalScorer.getActiveDoc() == traversalScorer.getTargetDoc()) { + SolrCache featureVectorCache = null; + float[] featureVector; + + if (request != null) { + featureVectorCache = request.getSearcher().getFeatureVectorCache(); + } + if (featureVectorCache != null) { + int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.docID()); + featureVector = featureVectorCache.get(fvCacheKey); + if (featureVector == null) { + featureVector = extractFeatureVector(); + featureVectorCache.put(fvCacheKey, featureVector); + } + } else { + featureVector = extractFeatureVector(); + } + + for (int i = 0; i < extractedFeatureWeights.length; i++) { + int featureId = extractedFeatureWeights[i].getIndex(); + float featureValue = featureVector[featureId]; + if (!Float.isNaN(featureValue) + && featureValue != extractedFeatureWeights[i].getDefaultValue()) { + allFeaturesInStore[featureId].setValue(featureValue); + allFeaturesInStore[featureId].setIsDefaultValue(false); + } + } + } + } + + private int computeFeatureVectorCacheKey(int docId) { + int prime = 31; + int result = docId; + if (Objects.equals( + ltrScoringModel.getName(), + LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) + || (logger != null && logger.isLogFeatures() && logger.isLoggingAll())) { + result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); + } else { + result = (prime * result) + ltrScoringModel.getName().hashCode(); + } + result = (prime * result) + addEfisHash(result, prime, efi); + return result; + } + + private int addEfisHash(int result, int prime, Map efi) { + if (efi != null) { + TreeMap sorted = new TreeMap<>(efi); + for (final Map.Entry entry : sorted.entrySet()) { + final String key = entry.getKey(); + final String[] values = entry.getValue(); + result = (prime * result) + key.hashCode(); + result = (prime * result) + Arrays.hashCode(values); + } + } + return result; + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java new file mode 100644 index 000000000000..1dbffd794b81 --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java @@ -0,0 +1,40 @@ +package org.apache.solr.ltr.feature; + +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.request.SolrQueryRequest; +import java.io.IOException; +import java.util.Map; + +public class MultiFeaturesExtractor extends FeatureExtractor { + DisiPriorityQueue subScorers; + + public MultiFeaturesExtractor( + LTRScoringQuery.ModelWeight.ModelScorer.FeatureTraversalScorer multiFeaturesScorer, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map efi, + DisiPriorityQueue subScorers) { + super(multiFeaturesScorer, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi); + this.subScorers = subScorers; + } + + @Override + protected float[] extractFeatureVector() throws IOException { + final DisiWrapper topList = subScorers.topList(); + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (DisiWrapper w = topList; w != null; w = w.next) { + final Feature.FeatureWeight.FeatureScorer subScorer = + (Feature.FeatureWeight.FeatureScorer) w.scorer; + Feature.FeatureWeight feature = subScorer.getWeight(); + final int featureId = feature.getIndex(); + float featureValue = subScorer.score(); + featureVector[featureId] = featureValue; + } + return featureVector; + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java new file mode 100644 index 000000000000..a51fc50bd48d --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java @@ -0,0 +1,41 @@ +package org.apache.solr.ltr.feature; + +import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.request.SolrQueryRequest; +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class SingleFeatureExtractor extends FeatureExtractor { + List featureScorers; + + public SingleFeatureExtractor( + LTRScoringQuery.ModelWeight.ModelScorer.FeatureTraversalScorer singleFeatureScorer, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map efi, + List featureScorers) { + super(singleFeatureScorer, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi); + this.featureScorers = featureScorers; + } + + @Override + protected float[] extractFeatureVector() throws IOException { + float[] featureVector = initFeatureVector(allFeaturesInStore); + for (final Scorer scorer : featureScorers) { + if (scorer.docID() == traversalScorer.getActiveDoc()) { + Feature.FeatureWeight.FeatureScorer featureScorer = + (Feature.FeatureWeight.FeatureScorer) scorer; + Feature.FeatureWeight scFW = featureScorer.getWeight(); + final int featureId = scFW.getIndex(); + float featureValue = scorer.score(); + featureVector[featureId] = featureValue; + } + } + return featureVector; + } +} From 0a432f1904c073e2ce1897f17bd8622305039ba5 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 17 Sep 2025 14:50:18 +0200 Subject: [PATCH 39/54] First attempt to move scorer outside LTRScoringQuery --- .../org/apache/solr/ltr/FeatureLogger.java | 11 + .../org/apache/solr/ltr/LTRScoringQuery.java | 411 +----------------- .../feature/extraction/FeatureExtractor.java | 8 +- .../extraction/MultiFeaturesExtractor.java | 6 +- .../extraction/SingleFeatureExtractor.java | 6 +- .../LTRFeatureLoggerTransformerFactory.java | 15 +- .../ltr/scoring/FeatureTraversalScorer.java | 57 +++ .../solr/ltr/scoring/MultiFeaturesScorer.java | 161 +++++++ .../solr/ltr/scoring/SingleFeatureScorer.java | 96 ++++ 9 files changed, 359 insertions(+), 412 deletions(-) create mode 100644 solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java create mode 100644 solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java create mode 100644 solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java index 16c554df0f3f..54d308b665e4 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java @@ -29,9 +29,12 @@ public enum FeatureFormat { protected Boolean logAll; + protected boolean logFeatures; + protected FeatureLogger(FeatureFormat f, Boolean logAll) { this.featureFormat = f; this.logAll = logAll; + this.logFeatures = false; } public abstract String printFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo); @@ -43,4 +46,12 @@ public Boolean isLoggingAll() { public void setLogAll(Boolean logAll) { this.logAll = logAll; } + + public void setLogFeatures(boolean logFeatures) { + this.logFeatures = logFeatures; + } + + public boolean isLogFeatures() { + return logFeatures; + } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index ba73b219ad45..6936726bd5a0 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -24,8 +24,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.TreeMap; import java.util.concurrent.Callable; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; @@ -46,10 +44,14 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.extraction.FeatureExtractor; +import org.apache.solr.ltr.feature.extraction.MultiFeaturesExtractor; +import org.apache.solr.ltr.feature.extraction.SingleFeatureExtractor; import org.apache.solr.ltr.model.LTRScoringModel; -import org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory; +import org.apache.solr.ltr.scoring.FeatureTraversalScorer; +import org.apache.solr.ltr.scoring.MultiFeaturesScorer; +import org.apache.solr.ltr.scoring.SingleFeatureScorer; import org.apache.solr.request.SolrQueryRequest; -import org.apache.solr.search.SolrCache; import org.apache.solr.util.SolrDefaultScorerSupplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -375,7 +377,6 @@ public class ModelWeight extends Weight { // features used for logging. private final Feature.FeatureWeight[] modelFeatureWeights; private final float[] modelFeatureValuesNormalized; - private final Feature.FeatureWeight[] extractedFeatureWeights; // Array of all the features in the feature store of reference private final FeatureInfo[] allFeaturesInStore; @@ -395,14 +396,13 @@ public ModelWeight( Feature.FeatureWeight[] extractedFeatureWeights, int allFeaturesSize) { super(LTRScoringQuery.this); - this.extractedFeatureWeights = extractedFeatureWeights; this.modelFeatureWeights = modelFeatureWeights; this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length]; this.allFeaturesInStore = new FeatureInfo[allFeaturesSize]; - setFeaturesInfo(); + setFeaturesInfo(extractedFeatureWeights); } - private void setFeaturesInfo() { + private void setFeaturesInfo(Feature.FeatureWeight[] extractedFeatureWeights) { for (int i = 0; i < extractedFeatureWeights.length; ++i) { String featName = extractedFeatureWeights[i].getName(); int featId = extractedFeatureWeights[i].getIndex(); @@ -421,20 +421,15 @@ Feature.FeatureWeight[] getModelFeatureWeights() { } // for test use - float[] getModelFeatureValuesNormalized() { + public float[] getModelFeatureValuesNormalized() { return modelFeatureValuesNormalized; } - // for test use - Feature.FeatureWeight[] getExtractedFeatureWeights() { - return extractedFeatureWeights; - } - /** * Goes through all the stored feature values, and calculates the normalized values for all the * features that will be used for scoring. Then calculate and return the model's score. */ - private void normalizeFeatures() { + public void normalizeFeatures() { int pos = 0; for (final Feature.FeatureWeight feature : modelFeatureWeights) { final int featureId = feature.getIndex(); @@ -466,17 +461,6 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio return ltrScoringModel.explain(context, doc, finalScore, featureExplanations); } - protected void reset() { - for (int i = 0; i < extractedFeatureWeights.length; ++i) { - int featId = extractedFeatureWeights[i].getIndex(); - float value = extractedFeatureWeights[i].getDefaultValue(); - // need to set default value everytime as the default value is used in 'dense' - // mode even if used=false - allFeaturesInStore[featId].setValue(value); - allFeaturesInStore[featId].setIsDefaultValue(true); - } - } - @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { return new SolrDefaultScorerSupplier(modelScorer(context)); @@ -509,23 +493,21 @@ public boolean isCacheable(LeafReaderContext ctx) { public class ModelScorer extends Scorer { private final DocInfo docInfo; private final FeatureTraversalScorer featureTraversalScorer; - protected boolean isLogging; public DocInfo getDocInfo() { return docInfo; } public ModelScorer(List featureScorers) { - isLogging = false; docInfo = new DocInfo(); for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) { subScorer.setDocInfo(docInfo); } if (featureScorers.size() <= 1) { // future enhancement: allow the use of dense features in other cases - featureTraversalScorer = new SingleFeatureScorer(featureScorers); + featureTraversalScorer = new SingleFeatureScorer(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); } else { - featureTraversalScorer = new MultiFeaturesScorer(featureScorers); + featureTraversalScorer = new MultiFeaturesScorer(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); } } @@ -554,378 +536,9 @@ public DocIdSetIterator iterator() { return featureTraversalScorer.iterator(); } - public void setIsLogging(boolean isLogging) { - this.isLogging = isLogging; - } - public void fillFeaturesInfo() throws IOException { featureTraversalScorer.fillFeaturesInfo(); } - - /** This class is responsible for extracting features and using them to score the document. */ - private abstract class FeatureTraversalScorer extends Scorer { - protected int targetDoc = -1; - protected int activeDoc = -1; - protected FeatureExtractor featureExtractor; - - void fillFeaturesInfo() throws IOException { - // Initialize features to their default values and set isDefaultValue to true. - reset(); - featureExtractor.fillFeaturesInfo(); - } - - @Override - public float score() throws IOException { - // Initialize features to their default values and set isDefaultValue to true. - reset(); - featureExtractor.fillFeaturesInfo(); - normalizeFeatures(); - return ltrScoringModel.score(modelFeatureValuesNormalized); - } - - @Override - public float getMaxScore(int upTo) { - return Float.POSITIVE_INFINITY; - } - } - - private class SingleFeatureScorer extends FeatureTraversalScorer { - private final List featureScorers; - - private SingleFeatureScorer(List featureScorers) { - this.featureScorers = featureScorers; - this.featureExtractor = new SingleFeatureExtractor(this, featureScorers); - } - - @Override - public int docID() { - return targetDoc; - } - - @Override - public final Collection getChildren() { - final ArrayList children = new ArrayList<>(); - for (final Scorer scorer : featureScorers) { - children.add(new ChildScorable(scorer, "SHOULD")); - } - return children; - } - - @Override - public DocIdSetIterator iterator() { - return new SingleFeatureIterator(); - } - - private class SingleFeatureIterator extends DocIdSetIterator { - - @Override - public int docID() { - return targetDoc; - } - - @Override - public int nextDoc() throws IOException { - if (activeDoc <= targetDoc) { - activeDoc = NO_MORE_DOCS; - for (final Scorer scorer : featureScorers) { - if (scorer.docID() != NO_MORE_DOCS) { - activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc()); - } - } - } - return ++targetDoc; - } - - @Override - public int advance(int target) throws IOException { - if (activeDoc < target) { - activeDoc = NO_MORE_DOCS; - for (final Scorer scorer : featureScorers) { - if (scorer.docID() != NO_MORE_DOCS) { - activeDoc = Math.min(activeDoc, scorer.iterator().advance(target)); - } - } - } - targetDoc = target; - return target; - } - - @Override - public long cost() { - long sum = 0; - for (int i = 0; i < featureScorers.size(); i++) { - sum += featureScorers.get(i).iterator().cost(); - } - return sum; - } - } - } - - private class MultiFeaturesScorer extends FeatureTraversalScorer { - private final DisiPriorityQueue subScorers; - private final List wrappers; - private final MultiFeaturesIterator multiFeaturesIteratorIterator; - - private MultiFeaturesScorer(List featureScorers) { - if (featureScorers.size() <= 1) { - throw new IllegalArgumentException("There must be at least 2 subScorers"); - } - subScorers = DisiPriorityQueue.ofMaxSize(featureScorers.size()); - wrappers = new ArrayList<>(); - for (final Scorer scorer : featureScorers) { - final DisiWrapper w = new DisiWrapper(scorer, false /* impacts */); - subScorers.add(w); - wrappers.add(w); - } - - multiFeaturesIteratorIterator = new MultiFeaturesIterator(wrappers); - this.featureExtractor = new MultiFeaturesExtractor(this, subScorers); - } - - @Override - public int docID() { - return multiFeaturesIteratorIterator.docID(); - } - - @Override - public DocIdSetIterator iterator() { - return multiFeaturesIteratorIterator; - } - - @Override - public final Collection getChildren() { - final ArrayList children = new ArrayList<>(); - for (final DisiWrapper scorer : subScorers) { - children.add(new ChildScorable(scorer.scorer, "SHOULD")); - } - return children; - } - - private class MultiFeaturesIterator extends DocIdSetIterator { - - public MultiFeaturesIterator(Collection wrappers) { - // Initialize all wrappers to start at -1 - for (DisiWrapper wrapper : wrappers) { - wrapper.doc = -1; - } - } - - @Override - public int docID() { - // Return the target document ID (mimicking DisjunctionDISIApproximation behavior) - return targetDoc; - } - - @Override - public final int nextDoc() throws IOException { - // Mimic DisjunctionDISIApproximation behavior - if (targetDoc == -1) { - // First call - initialize all iterators - DisiWrapper top = subScorers.top(); - if (top != null && top.doc == -1) { - // Need to advance all iterators to their first document - DisiWrapper current = subScorers.top(); - while (current != null) { - current.doc = current.iterator.nextDoc(); - current = subScorers.updateTop(); - } - top = subScorers.top(); - activeDoc = top == null ? NO_MORE_DOCS : top.doc; - } - targetDoc = activeDoc; - return targetDoc; - } - - if (activeDoc == targetDoc) { - // Advance the underlying disjunction - DisiWrapper top = subScorers.top(); - if (top == null) { - activeDoc = NO_MORE_DOCS; - } else { - // Advance the top iterator and rebalance the queue - top.doc = top.iterator.nextDoc(); - top = subScorers.updateTop(); - activeDoc = top == null ? NO_MORE_DOCS : top.doc; - } - } else if (activeDoc < targetDoc) { - // Need to catch up to targetDoc + 1 - activeDoc = advanceInternal(targetDoc + 1); - } - return ++targetDoc; - } - - @Override - public final int advance(int target) throws IOException { - // Mimic DisjunctionDISIApproximation behavior - if (activeDoc < target) { - activeDoc = advanceInternal(target); - } - targetDoc = target; - return targetDoc; - } - - private int advanceInternal(int target) throws IOException { - // Advance the underlying disjunction to the target - DisiWrapper top; - do { - top = subScorers.top(); - if (top == null) { - return NO_MORE_DOCS; - } - if (top.doc >= target) { - return top.doc; - } - top.doc = top.iterator.advance(target); - top = subScorers.updateTop(); - if (top == null) { - return NO_MORE_DOCS; - } - } while (top.doc < target); - return top.doc; - } - - @Override - public long cost() { - // Calculate cost from all wrappers - long cost = 0; - for (DisiWrapper wrapper : wrappers) { - cost += wrapper.iterator.cost(); - } - return cost; - } - } - } - - private abstract class FeatureExtractor { - protected final FeatureTraversalScorer traversalScorer; - - private FeatureExtractor(FeatureTraversalScorer traversalScorer) { - this.traversalScorer = traversalScorer; - } - - protected float[] initFeatureVector(FeatureInfo[] featuresInfos) { - float[] featureVector = new float[allFeaturesInStore.length]; - for (int i = 0; i < featuresInfos.length; i++) { - if (featuresInfos[i] != null) { - featureVector[i] = featuresInfos[i].getValue(); - } - } - return featureVector; - } - - protected abstract float[] extractFeatureVector() throws IOException; - - private void fillFeaturesInfo() throws IOException { - if (traversalScorer.activeDoc == traversalScorer.targetDoc) { - SolrCache featureVectorCache = null; - float[] featureVector; - - if (request != null) { - featureVectorCache = request.getSearcher().getFeatureVectorCache(); - } - if (featureVectorCache != null) { - int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.docID()); - featureVector = featureVectorCache.get(fvCacheKey); - if (featureVector == null) { - featureVector = extractFeatureVector(); - featureVectorCache.put(fvCacheKey, featureVector); - } - } else { - featureVector = extractFeatureVector(); - } - - for (int i = 0; i < extractedFeatureWeights.length; i++) { - int featureId = extractedFeatureWeights[i].getIndex(); - float featureValue = featureVector[featureId]; - if (!Float.isNaN(featureValue) - && featureValue != extractedFeatureWeights[i].getDefaultValue()) { - allFeaturesInStore[featureId].setValue(featureValue); - allFeaturesInStore[featureId].setIsDefaultValue(false); - } - } - } - } - - private int computeFeatureVectorCacheKey(int docId) { - int prime = 31; - int result = docId; - if (Objects.equals( - ltrScoringModel.getName(), - LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) - || (isLogging && logger.isLoggingAll())) { - result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); - } else { - result = (prime * result) + ltrScoringModel.getName().hashCode(); - } - result = (prime * result) + addEfisHash(result, prime); - return result; - } - - private int addEfisHash(int result, int prime) { - if (efi != null) { - TreeMap sorted = new TreeMap<>(efi); - for (final Map.Entry entry : sorted.entrySet()) { - final String key = entry.getKey(); - final String[] values = entry.getValue(); - result = (prime * result) + key.hashCode(); - result = (prime * result) + Arrays.hashCode(values); - } - } - return result; - } - } - - private class SingleFeatureExtractor extends FeatureExtractor { - List featureScorers; - - private SingleFeatureExtractor( - FeatureTraversalScorer singleFeatureScorer, - List featureScorers) { - super(singleFeatureScorer); - this.featureScorers = featureScorers; - } - - @Override - protected float[] extractFeatureVector() throws IOException { - float[] featureVector = initFeatureVector(allFeaturesInStore); - for (final Scorer scorer : featureScorers) { - if (scorer.docID() == traversalScorer.activeDoc) { - Feature.FeatureWeight.FeatureScorer featureScorer = - (Feature.FeatureWeight.FeatureScorer) scorer; - Feature.FeatureWeight scFW = featureScorer.getWeight(); - final int featureId = scFW.getIndex(); - float featureValue = scorer.score(); - featureVector[featureId] = featureValue; - } - } - return featureVector; - } - } - - private class MultiFeaturesExtractor extends FeatureExtractor { - DisiPriorityQueue subScorers; - - private MultiFeaturesExtractor( - FeatureTraversalScorer multiFeaturesScorer, DisiPriorityQueue subScorers) { - super(multiFeaturesScorer); - this.subScorers = subScorers; - } - - @Override - protected float[] extractFeatureVector() throws IOException { - final DisiWrapper topList = subScorers.topList(); - float[] featureVector = initFeatureVector(allFeaturesInStore); - for (DisiWrapper w = topList; w != null; w = w.next) { - final Feature.FeatureWeight.FeatureScorer subScorer = - (Feature.FeatureWeight.FeatureScorer) w.scorer; - Feature.FeatureWeight feature = subScorer.getWeight(); - final int featureId = feature.getIndex(); - float featureValue = subScorer.score(); - featureVector[featureId] = featureValue; - } - return featureVector; - } - } } } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java index 5bb88c91e41c..1027c6777cda 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java @@ -1,9 +1,11 @@ -package org.apache.solr.ltr.feature; +package org.apache.solr.ltr.feature.extraction; import org.apache.solr.ltr.FeatureLogger; import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory; +import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.SolrCache; import java.io.IOException; @@ -13,7 +15,7 @@ import java.util.TreeMap; public abstract class FeatureExtractor { - protected final LTRScoringQuery.ModelWeight.ModelScorer.FeatureTraversalScorer traversalScorer; + protected final FeatureTraversalScorer traversalScorer; SolrQueryRequest request; Feature.FeatureWeight[] extractedFeatureWeights; LTRScoringQuery.FeatureInfo[] allFeaturesInStore; @@ -22,7 +24,7 @@ public abstract class FeatureExtractor { Map efi; FeatureExtractor( - LTRScoringQuery.ModelWeight.ModelScorer.FeatureTraversalScorer traversalScorer, + FeatureTraversalScorer traversalScorer, SolrQueryRequest request, Feature.FeatureWeight[] extractedFeatureWeights, LTRScoringQuery.FeatureInfo[] allFeaturesInStore, diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java index 1dbffd794b81..ace249f87f22 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java @@ -1,9 +1,11 @@ -package org.apache.solr.ltr.feature; +package org.apache.solr.ltr.feature.extraction; import org.apache.lucene.search.DisiPriorityQueue; import org.apache.lucene.search.DisiWrapper; import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.request.SolrQueryRequest; import java.io.IOException; import java.util.Map; @@ -12,7 +14,7 @@ public class MultiFeaturesExtractor extends FeatureExtractor { DisiPriorityQueue subScorers; public MultiFeaturesExtractor( - LTRScoringQuery.ModelWeight.ModelScorer.FeatureTraversalScorer multiFeaturesScorer, + FeatureTraversalScorer multiFeaturesScorer, SolrQueryRequest request, Feature.FeatureWeight[] extractedFeatureWeights, LTRScoringQuery.FeatureInfo[] allFeaturesInStore, diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java index a51fc50bd48d..ccf7f4be0a78 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java @@ -1,8 +1,10 @@ -package org.apache.solr.ltr.feature; +package org.apache.solr.ltr.feature.extraction; import org.apache.lucene.search.Scorer; import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.request.SolrQueryRequest; import java.io.IOException; import java.util.List; @@ -12,7 +14,7 @@ public class SingleFeatureExtractor extends FeatureExtractor { List featureScorers; public SingleFeatureExtractor( - LTRScoringQuery.ModelWeight.ModelScorer.FeatureTraversalScorer singleFeatureScorer, + FeatureTraversalScorer singleFeatureScorer, SolrQueryRequest request, Feature.FeatureWeight[] extractedFeatureWeights, LTRScoringQuery.FeatureInfo[] allFeaturesInStore, diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 1ca7edde341c..36107d5348de 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -408,6 +408,7 @@ public void transform(SolrDocument doc, int docid, DocIterationInfo docInfo) } private static LTRScoringQuery.FeatureInfo[] extractFeatures( + FeatureLogger logger, LTRScoringQuery.ModelWeight modelWeight, int docid, Float originalDocScore, @@ -426,7 +427,7 @@ private static LTRScoringQuery.FeatureInfo[] extractFeatures( r.getDocInfo().setOriginalDocScore(originalDocScore); } r.fillFeaturesInfo(); - r.setIsLogging(true); + logger.setLogFeatures(true); return modelWeight.getAllFeaturesInStore(); } } @@ -444,13 +445,15 @@ private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo } } if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { + LTRScoringQuery.FeatureInfo[] featuresInfo = extractFeatures( + featureLogger, + rerankingModelWeight, + docid, + (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, + leafContexts); String featureVector = featureLogger.printFeatureVector( - extractFeatures( - rerankingModelWeight, - docid, - (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, - leafContexts)); + featuresInfo); doc.addField(name, featureVector); } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java new file mode 100644 index 000000000000..529361ee5179 --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -0,0 +1,57 @@ +package org.apache.solr.ltr.scoring; + +import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.extraction.FeatureExtractor; +import org.apache.solr.ltr.model.LTRScoringModel; +import java.io.IOException; + +/** This class is responsible for extracting features and using them to score the document. */ +public abstract class FeatureTraversalScorer extends Scorer { + protected int targetDoc = -1; + protected int activeDoc = -1; + protected FeatureExtractor featureExtractor; + protected Feature.FeatureWeight[] extractedFeatureWeights; + protected LTRScoringQuery.FeatureInfo[] allFeaturesInStore; + protected LTRScoringModel ltrScoringModel; + + public int getTargetDoc() { + return targetDoc; + } + + public int getActiveDoc() { + return activeDoc; + } + + public void reset() { + for (int i = 0; i < extractedFeatureWeights.length; ++i) { + int featId = extractedFeatureWeights[i].getIndex(); + float value = extractedFeatureWeights[i].getDefaultValue(); + // need to set default value everytime as the default value is used in 'dense' + // mode even if used=false + allFeaturesInStore[featId].setValue(value); + allFeaturesInStore[featId].setIsDefaultValue(true); + } + } + + public void fillFeaturesInfo() throws IOException { + // Initialize features to their default values and set isDefaultValue to true. + reset(); + featureExtractor.fillFeaturesInfo(); + } + + @Override + public float score() throws IOException { + // Initialize features to their default values and set isDefaultValue to true. + reset(); + featureExtractor.fillFeaturesInfo(); + modelWeight.normalizeFeatures(); + return ltrScoringModel.score(modelWeight.getModelFeatureValuesNormalized()); + } + + @Override + public float getMaxScore(int upTo) { + return Float.POSITIVE_INFINITY; + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java new file mode 100644 index 000000000000..8e70105d0fa0 --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -0,0 +1,161 @@ +package org.apache.solr.ltr.scoring; + +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.extraction.MultiFeaturesExtractor; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.request.SolrQueryRequest; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +public class MultiFeaturesScorer extends FeatureTraversalScorer { + private final DisiPriorityQueue subScorers; + private final List wrappers; + private final MultiFeaturesIterator multiFeaturesIteratorIterator; + + public MultiFeaturesScorer( + LTRScoringQuery.ModelWeight modelWeight, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map efi, + List featureScorers) { + if (featureScorers.size() <= 1) { + throw new IllegalArgumentException("There must be at least 2 subScorers"); + } + subScorers = DisiPriorityQueue.ofMaxSize(featureScorers.size()); + wrappers = new ArrayList<>(); + for (final Scorer scorer : featureScorers) { + final DisiWrapper w = new DisiWrapper(scorer, false /* impacts */); + subScorers.add(w); + wrappers.add(w); + } + + multiFeaturesIteratorIterator = new MultiFeaturesIterator(wrappers); + this.extractedFeatureWeights = extractedFeatureWeights; + this.allFeaturesInStore = allFeaturesInStore; + this.ltrScoringModel = ltrScoringModel; + this.featureExtractor = new MultiFeaturesExtractor(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, subScorers); + } + + @Override + public int docID() { + return multiFeaturesIteratorIterator.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return multiFeaturesIteratorIterator; + } + + @Override + public final Collection getChildren() { + final ArrayList children = new ArrayList<>(); + for (final DisiWrapper scorer : subScorers) { + children.add(new ChildScorable(scorer.scorer, "SHOULD")); + } + return children; + } + + private class MultiFeaturesIterator extends DocIdSetIterator { + + public MultiFeaturesIterator(Collection wrappers) { + // Initialize all wrappers to start at -1 + for (DisiWrapper wrapper : wrappers) { + wrapper.doc = -1; + } + } + + @Override + public int docID() { + // Return the target document ID (mimicking DisjunctionDISIApproximation behavior) + return targetDoc; + } + + @Override + public final int nextDoc() throws IOException { + // Mimic DisjunctionDISIApproximation behavior + if (targetDoc == -1) { + // First call - initialize all iterators + DisiWrapper top = subScorers.top(); + if (top != null && top.doc == -1) { + // Need to advance all iterators to their first document + DisiWrapper current = subScorers.top(); + while (current != null) { + current.doc = current.iterator.nextDoc(); + current = subScorers.updateTop(); + } + top = subScorers.top(); + activeDoc = top == null ? NO_MORE_DOCS : top.doc; + } + targetDoc = activeDoc; + return targetDoc; + } + + if (activeDoc == targetDoc) { + // Advance the underlying disjunction + DisiWrapper top = subScorers.top(); + if (top == null) { + activeDoc = NO_MORE_DOCS; + } else { + // Advance the top iterator and rebalance the queue + top.doc = top.iterator.nextDoc(); + top = subScorers.updateTop(); + activeDoc = top == null ? NO_MORE_DOCS : top.doc; + } + } else if (activeDoc < targetDoc) { + // Need to catch up to targetDoc + 1 + activeDoc = advanceInternal(targetDoc + 1); + } + return ++targetDoc; + } + + @Override + public final int advance(int target) throws IOException { + // Mimic DisjunctionDISIApproximation behavior + if (activeDoc < target) { + activeDoc = advanceInternal(target); + } + targetDoc = target; + return targetDoc; + } + + private int advanceInternal(int target) throws IOException { + // Advance the underlying disjunction to the target + DisiWrapper top; + do { + top = subScorers.top(); + if (top == null) { + return NO_MORE_DOCS; + } + if (top.doc >= target) { + return top.doc; + } + top.doc = top.iterator.advance(target); + top = subScorers.updateTop(); + if (top == null) { + return NO_MORE_DOCS; + } + } while (top.doc < target); + return top.doc; + } + + @Override + public long cost() { + // Calculate cost from all wrappers + long cost = 0; + for (DisiWrapper wrapper : wrappers) { + cost += wrapper.iterator.cost(); + } + return cost; + } + } +} diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java new file mode 100644 index 000000000000..ceb3f766a7b1 --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -0,0 +1,96 @@ +package org.apache.solr.ltr.scoring; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.extraction.SingleFeatureExtractor; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.request.SolrQueryRequest; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +public class SingleFeatureScorer extends FeatureTraversalScorer { + private final List featureScorers; + + public SingleFeatureScorer( + LTRScoringQuery.ModelWeight modelWeight, + SolrQueryRequest request, + Feature.FeatureWeight[] extractedFeatureWeights, + LTRScoringQuery.FeatureInfo[] allFeaturesInStore, + LTRScoringModel ltrScoringModel, + Map efi, + List featureScorers) { + this.featureScorers = featureScorers; + this.extractedFeatureWeights = extractedFeatureWeights; + this.allFeaturesInStore = allFeaturesInStore; + this.ltrScoringModel = ltrScoringModel; + this.featureExtractor = new SingleFeatureExtractor(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); + } + + @Override + public int docID() { + return targetDoc; + } + + @Override + public final Collection getChildren() { + final ArrayList children = new ArrayList<>(); + for (final Scorer scorer : featureScorers) { + children.add(new ChildScorable(scorer, "SHOULD")); + } + return children; + } + + @Override + public DocIdSetIterator iterator() { + return new SingleFeatureIterator(); + } + + private class SingleFeatureIterator extends DocIdSetIterator { + + @Override + public int docID() { + return targetDoc; + } + + @Override + public int nextDoc() throws IOException { + if (activeDoc <= targetDoc) { + activeDoc = NO_MORE_DOCS; + for (final Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc()); + } + } + } + return ++targetDoc; + } + + @Override + public int advance(int target) throws IOException { + if (activeDoc < target) { + activeDoc = NO_MORE_DOCS; + for (final Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, scorer.iterator().advance(target)); + } + } + } + targetDoc = target; + return target; + } + + @Override + public long cost() { + long sum = 0; + for (int i = 0; i < featureScorers.size(); i++) { + sum += featureScorers.get(i).iterator().cost(); + } + return sum; + } + } +} From e6c50a594ef65687ac6892ceec21718cd610cd11 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 22 Sep 2025 09:29:13 +0200 Subject: [PATCH 40/54] Moved scorer outside ltrscoringquery --- .../java/org/apache/solr/ltr/LTRScoringQuery.java | 15 ++++++++++----- .../solr/ltr/scoring/FeatureTraversalScorer.java | 3 ++- .../solr/ltr/scoring/MultiFeaturesScorer.java | 1 + .../solr/ltr/scoring/SingleFeatureScorer.java | 1 + .../solr/ltr/TestSelectiveWeightCreation.java | 2 +- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 6936726bd5a0..adf517f5f7a0 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -82,6 +82,8 @@ public class LTRScoringQuery extends Query implements Accountable { // Original solr request private SolrQueryRequest request; + private Feature.FeatureWeight[] extractedFeatureWeights; + public LTRScoringQuery(LTRScoringModel ltrScoringModel) { this(ltrScoringModel, Collections.emptyMap(), null); } @@ -140,6 +142,10 @@ public SolrQueryRequest getRequest() { return request; } + public Feature.FeatureWeight[] getExtractedFeatureWeights() { + return extractedFeatureWeights; + } + @Override public int hashCode() { final int prime = 31; @@ -217,7 +223,7 @@ public ModelWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, flo } else { features = modelFeatures; } - final Feature.FeatureWeight[] extractedFeatureWeights = + this.extractedFeatureWeights = new Feature.FeatureWeight[features.size()]; final Feature.FeatureWeight[] modelFeaturesWeights = new Feature.FeatureWeight[modelFeatSize]; List featureWeights = new ArrayList<>(features.size()); @@ -242,7 +248,7 @@ public ModelWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, flo modelFeaturesWeights[j++] = fw; } } - return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights, allFeatures.size()); + return new ModelWeight(modelFeaturesWeights, allFeatures.size()); } private void createWeights( @@ -393,7 +399,6 @@ public class ModelWeight extends Weight { */ public ModelWeight( Feature.FeatureWeight[] modelFeatureWeights, - Feature.FeatureWeight[] extractedFeatureWeights, int allFeaturesSize) { super(LTRScoringQuery.this); this.modelFeatureWeights = modelFeatureWeights; @@ -505,9 +510,9 @@ public ModelScorer(List featureScorers) { } if (featureScorers.size() <= 1) { // future enhancement: allow the use of dense features in other cases - featureTraversalScorer = new SingleFeatureScorer(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); + featureTraversalScorer = new SingleFeatureScorer(ModelWeight.this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); } else { - featureTraversalScorer = new MultiFeaturesScorer(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); + featureTraversalScorer = new MultiFeaturesScorer(ModelWeight.this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java index 529361ee5179..3b8f1d042865 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -12,9 +12,10 @@ public abstract class FeatureTraversalScorer extends Scorer { protected int targetDoc = -1; protected int activeDoc = -1; protected FeatureExtractor featureExtractor; - protected Feature.FeatureWeight[] extractedFeatureWeights; protected LTRScoringQuery.FeatureInfo[] allFeaturesInStore; protected LTRScoringModel ltrScoringModel; + protected Feature.FeatureWeight[] extractedFeatureWeights; + protected LTRScoringQuery.ModelWeight modelWeight; public int getTargetDoc() { return targetDoc; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java index 8e70105d0fa0..4c5014724726 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -44,6 +44,7 @@ public MultiFeaturesScorer( this.allFeaturesInStore = allFeaturesInStore; this.ltrScoringModel = ltrScoringModel; this.featureExtractor = new MultiFeaturesExtractor(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, subScorers); + this.modelWeight = modelWeight; } @Override diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java index ceb3f766a7b1..11b2a9591733 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -29,6 +29,7 @@ public SingleFeatureScorer( this.allFeaturesInStore = allFeaturesInStore; this.ltrScoringModel = ltrScoringModel; this.featureExtractor = new SingleFeatureExtractor(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); + this.modelWeight = modelWeight; } @Override diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java index 00d72666956d..b5148f66b805 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java @@ -194,7 +194,7 @@ public void testScoringQueryWeightCreation() throws IOException, ModelException featuresInfo = modelWeight.getAllFeaturesInStore(); assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length); - assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length); + assertEquals(allFeatures.size(), ltrQuery2.getExtractedFeatureWeights().length); nonDefaultFeatures = 0; for (int i = 0; i < featuresInfo.length; ++i) { From 1857b26727c13237427f933fdaf06649dc82ec61 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 22 Sep 2025 14:33:52 +0200 Subject: [PATCH 41/54] Gradlew widy --- .../org/apache/solr/ltr/LTRScoringQuery.java | 32 ++++++++++++------- .../feature/extraction/FeatureExtractor.java | 14 ++++---- .../extraction/MultiFeaturesExtractor.java | 12 +++++-- .../extraction/SingleFeatureExtractor.java | 14 +++++--- .../LTRFeatureLoggerTransformerFactory.java | 17 +++++----- .../ltr/scoring/FeatureTraversalScorer.java | 2 +- .../solr/ltr/scoring/MultiFeaturesScorer.java | 20 ++++++++---- .../solr/ltr/scoring/SingleFeatureScorer.java | 20 ++++++++---- 8 files changed, 83 insertions(+), 48 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index adf517f5f7a0..c44075ca9178 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -30,8 +30,6 @@ import java.util.concurrent.RunnableFuture; import java.util.concurrent.Semaphore; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.DisiPriorityQueue; -import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; @@ -44,9 +42,6 @@ import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; import org.apache.solr.ltr.feature.Feature; -import org.apache.solr.ltr.feature.extraction.FeatureExtractor; -import org.apache.solr.ltr.feature.extraction.MultiFeaturesExtractor; -import org.apache.solr.ltr.feature.extraction.SingleFeatureExtractor; import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.ltr.scoring.MultiFeaturesScorer; @@ -223,8 +218,7 @@ public ModelWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, flo } else { features = modelFeatures; } - this.extractedFeatureWeights = - new Feature.FeatureWeight[features.size()]; + this.extractedFeatureWeights = new Feature.FeatureWeight[features.size()]; final Feature.FeatureWeight[] modelFeaturesWeights = new Feature.FeatureWeight[modelFeatSize]; List featureWeights = new ArrayList<>(features.size()); @@ -397,9 +391,7 @@ public class ModelWeight extends Weight { * @param allFeaturesSize * - total number of feature in the feature store used by this model */ - public ModelWeight( - Feature.FeatureWeight[] modelFeatureWeights, - int allFeaturesSize) { + public ModelWeight(Feature.FeatureWeight[] modelFeatureWeights, int allFeaturesSize) { super(LTRScoringQuery.this); this.modelFeatureWeights = modelFeatureWeights; this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length]; @@ -510,9 +502,25 @@ public ModelScorer(List featureScorers) { } if (featureScorers.size() <= 1) { // future enhancement: allow the use of dense features in other cases - featureTraversalScorer = new SingleFeatureScorer(ModelWeight.this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); + featureTraversalScorer = + new SingleFeatureScorer( + ModelWeight.this, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi, + featureScorers); } else { - featureTraversalScorer = new MultiFeaturesScorer(ModelWeight.this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); + featureTraversalScorer = + new MultiFeaturesScorer( + ModelWeight.this, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi, + featureScorers); } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java index 1027c6777cda..c2656cf17465 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java @@ -1,5 +1,10 @@ package org.apache.solr.ltr.feature.extraction; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; import org.apache.solr.ltr.FeatureLogger; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.feature.Feature; @@ -8,11 +13,6 @@ import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.SolrCache; -import java.io.IOException; -import java.util.Arrays; -import java.util.Map; -import java.util.Objects; -import java.util.TreeMap; public abstract class FeatureExtractor { protected final FeatureTraversalScorer traversalScorer; @@ -85,8 +85,8 @@ private int computeFeatureVectorCacheKey(int docId) { int prime = 31; int result = docId; if (Objects.equals( - ltrScoringModel.getName(), - LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) + ltrScoringModel.getName(), + LTRFeatureLoggerTransformerFactory.DEFAULT_LOGGING_MODEL_NAME) || (logger != null && logger.isLogFeatures() && logger.isLoggingAll())) { result = (prime * result) + ltrScoringModel.getFeatureStoreName().hashCode(); } else { diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java index ace249f87f22..6001d49e828d 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java @@ -1,5 +1,7 @@ package org.apache.solr.ltr.feature.extraction; +import java.io.IOException; +import java.util.Map; import org.apache.lucene.search.DisiPriorityQueue; import org.apache.lucene.search.DisiWrapper; import org.apache.solr.ltr.LTRScoringQuery; @@ -7,8 +9,6 @@ import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.request.SolrQueryRequest; -import java.io.IOException; -import java.util.Map; public class MultiFeaturesExtractor extends FeatureExtractor { DisiPriorityQueue subScorers; @@ -21,7 +21,13 @@ public MultiFeaturesExtractor( LTRScoringModel ltrScoringModel, Map efi, DisiPriorityQueue subScorers) { - super(multiFeaturesScorer, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi); + super( + multiFeaturesScorer, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi); this.subScorers = subScorers; } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java index ccf7f4be0a78..c6a554135ad9 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java @@ -1,14 +1,14 @@ package org.apache.solr.ltr.feature.extraction; +import java.io.IOException; +import java.util.List; +import java.util.Map; import org.apache.lucene.search.Scorer; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.request.SolrQueryRequest; -import java.io.IOException; -import java.util.List; -import java.util.Map; public class SingleFeatureExtractor extends FeatureExtractor { List featureScorers; @@ -21,7 +21,13 @@ public SingleFeatureExtractor( LTRScoringModel ltrScoringModel, Map efi, List featureScorers) { - super(singleFeatureScorer, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi); + super( + singleFeatureScorer, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi); this.featureScorers = featureScorers; } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 36107d5348de..e33214bb3b40 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -445,15 +445,14 @@ private void implTransform(SolrDocument doc, int docid, DocIterationInfo docInfo } } if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { - LTRScoringQuery.FeatureInfo[] featuresInfo = extractFeatures( - featureLogger, - rerankingModelWeight, - docid, - (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, - leafContexts); - String featureVector = - featureLogger.printFeatureVector( - featuresInfo); + LTRScoringQuery.FeatureInfo[] featuresInfo = + extractFeatures( + featureLogger, + rerankingModelWeight, + docid, + (!docsWereReranked && docsHaveScores) ? docInfo.score() : null, + leafContexts); + String featureVector = featureLogger.printFeatureVector(featuresInfo); doc.addField(name, featureVector); } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java index 3b8f1d042865..d2485b69f3e7 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -1,11 +1,11 @@ package org.apache.solr.ltr.scoring; +import java.io.IOException; import org.apache.lucene.search.Scorer; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.feature.extraction.FeatureExtractor; import org.apache.solr.ltr.model.LTRScoringModel; -import java.io.IOException; /** This class is responsible for extracting features and using them to score the document. */ public abstract class FeatureTraversalScorer extends Scorer { diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java index 4c5014724726..f744eb17ac73 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -1,5 +1,10 @@ package org.apache.solr.ltr.scoring; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; import org.apache.lucene.search.DisiPriorityQueue; import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DocIdSetIterator; @@ -9,11 +14,6 @@ import org.apache.solr.ltr.feature.extraction.MultiFeaturesExtractor; import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.request.SolrQueryRequest; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; public class MultiFeaturesScorer extends FeatureTraversalScorer { private final DisiPriorityQueue subScorers; @@ -43,7 +43,15 @@ public MultiFeaturesScorer( this.extractedFeatureWeights = extractedFeatureWeights; this.allFeaturesInStore = allFeaturesInStore; this.ltrScoringModel = ltrScoringModel; - this.featureExtractor = new MultiFeaturesExtractor(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, subScorers); + this.featureExtractor = + new MultiFeaturesExtractor( + this, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi, + subScorers); this.modelWeight = modelWeight; } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java index 11b2a9591733..7e7954484c88 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -1,5 +1,10 @@ package org.apache.solr.ltr.scoring; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; import org.apache.solr.ltr.LTRScoringQuery; @@ -7,11 +12,6 @@ import org.apache.solr.ltr.feature.extraction.SingleFeatureExtractor; import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.request.SolrQueryRequest; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; public class SingleFeatureScorer extends FeatureTraversalScorer { private final List featureScorers; @@ -28,7 +28,15 @@ public SingleFeatureScorer( this.extractedFeatureWeights = extractedFeatureWeights; this.allFeaturesInStore = allFeaturesInStore; this.ltrScoringModel = ltrScoringModel; - this.featureExtractor = new SingleFeatureExtractor(this, request, extractedFeatureWeights, allFeaturesInStore, ltrScoringModel, efi, featureScorers); + this.featureExtractor = + new SingleFeatureExtractor( + this, + request, + extractedFeatureWeights, + allFeaturesInStore, + ltrScoringModel, + efi, + featureScorers); this.modelWeight = modelWeight; } From 0d8f3e788504ea6c5c5bc8200c11d5b50529c58a Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Mon, 22 Sep 2025 14:56:15 +0200 Subject: [PATCH 42/54] Added package-info and did gradle tidy --- .../feature/extraction/FeatureExtractor.java | 17 ++++++++++++++++ .../extraction/MultiFeaturesExtractor.java | 17 ++++++++++++++++ .../extraction/SingleFeatureExtractor.java | 17 ++++++++++++++++ .../ltr/feature/extraction/package-info.java | 19 ++++++++++++++++++ .../ltr/scoring/FeatureTraversalScorer.java | 16 +++++++++++++++ .../solr/ltr/scoring/MultiFeaturesScorer.java | 20 +++++++++++++++++++ .../solr/ltr/scoring/SingleFeatureScorer.java | 17 ++++++++++++++++ .../apache/solr/ltr/scoring/package-info.java | 19 ++++++++++++++++++ 8 files changed, 142 insertions(+) create mode 100644 solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java create mode 100644 solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java index c2656cf17465..ff5f3e24f2de 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.solr.ltr.feature.extraction; import java.io.IOException; @@ -14,6 +30,7 @@ import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.search.SolrCache; +/** The class used to extract features for LTR feature logging. */ public abstract class FeatureExtractor { protected final FeatureTraversalScorer traversalScorer; SolrQueryRequest request; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java index 6001d49e828d..1db6d161ede8 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/MultiFeaturesExtractor.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.solr.ltr.feature.extraction; import java.io.IOException; @@ -10,6 +26,7 @@ import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.request.SolrQueryRequest; +/** The class used to extract more than one feature for LTR feature logging. */ public class MultiFeaturesExtractor extends FeatureExtractor { DisiPriorityQueue subScorers; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java index c6a554135ad9..5d3cf648afec 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/SingleFeatureExtractor.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.solr.ltr.feature.extraction; import java.io.IOException; @@ -10,6 +26,7 @@ import org.apache.solr.ltr.scoring.FeatureTraversalScorer; import org.apache.solr.request.SolrQueryRequest; +/** The class used to extract a single feature for LTR feature logging. */ public class SingleFeatureExtractor extends FeatureExtractor { List featureScorers; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java new file mode 100644 index 000000000000..83140800aacf --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the logic to extract features for logging. */ +package org.apache.solr.ltr.feature.extraction; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java index d2485b69f3e7..194079df6adf 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.solr.ltr.scoring; import java.io.IOException; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java index f744eb17ac73..fdf5a1116a34 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.solr.ltr.scoring; import java.io.IOException; @@ -15,6 +31,10 @@ import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.request.SolrQueryRequest; +/** + * This class is responsible for extracting more than one feature and using them to score the + * document. + */ public class MultiFeaturesScorer extends FeatureTraversalScorer { private final DisiPriorityQueue subScorers; private final List wrappers; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java index 7e7954484c88..414d0bb4cbd1 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.solr.ltr.scoring; import java.io.IOException; @@ -13,6 +29,7 @@ import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.request.SolrQueryRequest; +/** This class is responsible for extracting a single feature and using it to score the document. */ public class SingleFeatureScorer extends FeatureTraversalScorer { private final List featureScorers; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java new file mode 100644 index 000000000000..54e25ba080b7 --- /dev/null +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contains the logic to extract features for scoring. */ +package org.apache.solr.ltr.scoring; From c1458b75aeab66a5fb55b588fa7a5d5bae97c3d1 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Tue, 23 Sep 2025 09:11:40 +0200 Subject: [PATCH 43/54] Changed package info --- .../org/apache/solr/ltr/feature/extraction/package-info.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java index 83140800aacf..844fc0dcfe29 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/package-info.java @@ -15,5 +15,5 @@ * limitations under the License. */ -/** Contains the logic to extract features for logging. */ +/** Contains the logic to extract features. */ package org.apache.solr.ltr.feature.extraction; From 03aeabbc5de32d6c198fbd9d47d70bc99aeb3b78 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Fri, 3 Oct 2025 11:45:20 +0200 Subject: [PATCH 44/54] Attempt to fix docID with dedicated targetDoc --- .../solr/ltr/scoring/FeatureTraversalScorer.java | 10 ++-------- .../apache/solr/ltr/scoring/MultiFeaturesScorer.java | 12 ++++++++++++ .../apache/solr/ltr/scoring/SingleFeatureScorer.java | 12 ++++++++++++ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java index 194079df6adf..de816147a348 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -25,21 +25,15 @@ /** This class is responsible for extracting features and using them to score the document. */ public abstract class FeatureTraversalScorer extends Scorer { - protected int targetDoc = -1; - protected int activeDoc = -1; protected FeatureExtractor featureExtractor; protected LTRScoringQuery.FeatureInfo[] allFeaturesInStore; protected LTRScoringModel ltrScoringModel; protected Feature.FeatureWeight[] extractedFeatureWeights; protected LTRScoringQuery.ModelWeight modelWeight; - public int getTargetDoc() { - return targetDoc; - } + public abstract int getActiveDoc(); - public int getActiveDoc() { - return activeDoc; - } + public abstract int getTargetDoc(); public void reset() { for (int i = 0; i < extractedFeatureWeights.length; ++i) { diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java index fdf5a1116a34..4dcfc9a30052 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -36,6 +36,8 @@ * document. */ public class MultiFeaturesScorer extends FeatureTraversalScorer { + private int targetDoc = -1; + private int activeDoc = -1; private final DisiPriorityQueue subScorers; private final List wrappers; private final MultiFeaturesIterator multiFeaturesIteratorIterator; @@ -75,6 +77,16 @@ public MultiFeaturesScorer( this.modelWeight = modelWeight; } + @Override + public int getActiveDoc() { + return activeDoc; + } + + @Override + public int getTargetDoc() { + return targetDoc; + } + @Override public int docID() { return multiFeaturesIteratorIterator.docID(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java index 414d0bb4cbd1..b880cb0acb5b 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -31,6 +31,8 @@ /** This class is responsible for extracting a single feature and using it to score the document. */ public class SingleFeatureScorer extends FeatureTraversalScorer { + private int targetDoc = -1; + private int activeDoc = -1; private final List featureScorers; public SingleFeatureScorer( @@ -57,6 +59,16 @@ public SingleFeatureScorer( this.modelWeight = modelWeight; } + @Override + public int getActiveDoc() { + return activeDoc; + } + + @Override + public int getTargetDoc() { + return targetDoc; + } + @Override public int docID() { return targetDoc; From 4d3532998d073633c89b604806fae91fe6655871 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Fri, 3 Oct 2025 13:14:14 +0200 Subject: [PATCH 45/54] Trying to set the correct doc id --- .../ltr/src/java/org/apache/solr/ltr/LTRRescorer.java | 1 + .../src/java/org/apache/solr/ltr/LTRScoringQuery.java | 4 ++++ .../solr/ltr/feature/extraction/FeatureExtractor.java | 2 +- .../solr/ltr/scoring/FeatureTraversalScorer.java | 4 ++++ .../apache/solr/ltr/scoring/MultiFeaturesScorer.java | 11 +++++++++++ .../apache/solr/ltr/scoring/SingleFeatureScorer.java | 11 +++++++++++ 6 files changed, 32 insertions(+), 1 deletion(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index c6f8ec5cb35b..901b738a7385 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -208,6 +208,7 @@ protected void scoreSingleHit( scorer.iterator().advance(targetDoc); scorer.getDocInfo().setOriginalDocScore(hit.score); + scorer.setSolrDocID(docID); hit.score = scorer.score(); if (QueryLimits.getCurrentLimits() .maybeExitWithPartialResults( diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index c44075ca9178..8d4f8e3c9623 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -552,6 +552,10 @@ public DocIdSetIterator iterator() { public void fillFeaturesInfo() throws IOException { featureTraversalScorer.fillFeaturesInfo(); } + + public void setSolrDocID(int solrDocID) throws IOException { + featureTraversalScorer.setSolrDocID(solrDocID); + } } } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java index ff5f3e24f2de..878f0c98f722 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java @@ -76,7 +76,7 @@ public void fillFeaturesInfo() throws IOException { featureVectorCache = request.getSearcher().getFeatureVectorCache(); } if (featureVectorCache != null) { - int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.docID()); + int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.getSolrDocID()); featureVector = featureVectorCache.get(fvCacheKey); if (featureVector == null) { featureVector = extractFeatureVector(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java index de816147a348..45826167e716 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -35,6 +35,10 @@ public abstract class FeatureTraversalScorer extends Scorer { public abstract int getTargetDoc(); + public abstract int getSolrDocID(); + + public abstract void setSolrDocID(int solrDocID); + public void reset() { for (int i = 0; i < extractedFeatureWeights.length; ++i) { int featId = extractedFeatureWeights[i].getIndex(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java index 4dcfc9a30052..1437fef0406c 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -38,6 +38,7 @@ public class MultiFeaturesScorer extends FeatureTraversalScorer { private int targetDoc = -1; private int activeDoc = -1; + private int solrDocID = -1; private final DisiPriorityQueue subScorers; private final List wrappers; private final MultiFeaturesIterator multiFeaturesIteratorIterator; @@ -87,6 +88,16 @@ public int getTargetDoc() { return targetDoc; } + @Override + public int getSolrDocID() { + return solrDocID; + } + + @Override + public void setSolrDocID(int solrDocID) { + this.solrDocID = solrDocID; + } + @Override public int docID() { return multiFeaturesIteratorIterator.docID(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java index b880cb0acb5b..79cb70d0fc59 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -33,6 +33,7 @@ public class SingleFeatureScorer extends FeatureTraversalScorer { private int targetDoc = -1; private int activeDoc = -1; + private int solrDocID = -1; private final List featureScorers; public SingleFeatureScorer( @@ -69,6 +70,16 @@ public int getTargetDoc() { return targetDoc; } + @Override + public int getSolrDocID() { + return solrDocID; + } + + @Override + public void setSolrDocID(int solrDocID) { + this.solrDocID = solrDocID; + } + @Override public int docID() { return targetDoc; From af1990bcbee14411c946d2e223c38072f76f0dc6 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Fri, 3 Oct 2025 13:21:58 +0200 Subject: [PATCH 46/54] added missing solrDocID set --- .../response/transform/LTRFeatureLoggerTransformerFactory.java | 1 + 1 file changed, 1 insertion(+) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index e33214bb3b40..cba6f7b8b9c2 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -418,6 +418,7 @@ private static LTRScoringQuery.FeatureInfo[] extractFeatures( final LeafReaderContext atomicContext = leafContexts.get(n); final int deBasedDoc = docid - atomicContext.docBase; final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.modelScorer(atomicContext); + r.setSolrDocID(docid); if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) { return new LTRScoringQuery.FeatureInfo[0]; } else { From 2a46745298545e9319ee28ffef569ca11590a6a2 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Tue, 14 Oct 2025 18:27:15 +0200 Subject: [PATCH 47/54] Moved docId for cache key to DocInfo --- .../apache/solr/search/SolrIndexSearcher.java | 5 -- .../src/java/org/apache/solr/ltr/DocInfo.java | 10 ++++ .../java/org/apache/solr/ltr/LTRRescorer.java | 52 +++++++++---------- .../org/apache/solr/ltr/LTRScoringQuery.java | 10 ++-- .../feature/extraction/FeatureExtractor.java | 2 +- .../interleaving/LTRInterleavingRescorer.java | 12 ++--- .../LTRFeatureLoggerTransformerFactory.java | 2 +- .../ltr/scoring/FeatureTraversalScorer.java | 6 +-- .../solr/ltr/scoring/MultiFeaturesScorer.java | 15 +++--- .../solr/ltr/scoring/SingleFeatureScorer.java | 15 +++--- 10 files changed, 62 insertions(+), 67 deletions(-) diff --git a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java index 8f9800fb016c..cef8108f3c5f 100644 --- a/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java +++ b/solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java @@ -738,11 +738,6 @@ public boolean regenerateItem( }); } - if (solrConfig.featureVectorCacheConfig != null - && solrConfig.featureVectorCacheConfig.getRegenerator() == null) { - solrConfig.featureVectorCacheConfig.setRegenerator(new NoOpRegenerator()); - } - if (solrConfig.queryResultCacheConfig != null && solrConfig.queryResultCacheConfig.getRegenerator() == null) { final int queryResultWindowSize = solrConfig.queryResultWindowSize; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java index e454d90acc2d..ee82bb41df7d 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java @@ -22,6 +22,8 @@ public class DocInfo extends HashMap { // Name of key used to store the original score of a doc private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE"; + // Name of key used to store the original id of a doc + private static final String ORIGINAL_DOC_ID = "ORIGINAL_DOC_ID"; public DocInfo() { super(); @@ -38,4 +40,12 @@ public Float getOriginalDocScore() { public boolean hasOriginalDocScore() { return containsKey(ORIGINAL_DOC_SCORE); } + + public void setOriginalDocId(int docId) { + put(ORIGINAL_DOC_ID, docId); + } + + public int getOriginalDocId() { + return (int) get(ORIGINAL_DOC_ID); + } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index 901b738a7385..0cd0258eb52c 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -113,31 +113,31 @@ protected static void heapify(ScoreDoc[] hits, int size) { * * @param searcher current IndexSearcher * @param firstPassTopDocs documents to rerank; - * @param topN documents to return; + * @param docsToRerank documents to return; */ @Override - public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int docsToRerank) throws IOException { - if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { + if ((docsToRerank == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { return firstPassTopDocs; } final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs); - topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value())); + docsToRerank = Math.toIntExact(Math.min(docsToRerank, firstPassTopDocs.totalHits.value())); - final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults); + final ScoreDoc[] reranked = rerank(searcher, docsToRerank, firstPassResults); return new TopDocs(firstPassTopDocs.totalHits, reranked); } - private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) + private ScoreDoc[] rerank(IndexSearcher searcher, int docsToRerank, ScoreDoc[] firstPassResults) throws IOException { - final ScoreDoc[] reranked = new ScoreDoc[topN]; + final ScoreDoc[] reranked = new ScoreDoc[docsToRerank]; final List leaves = searcher.getIndexReader().leaves(); final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1); - scoreFeatures(topN, modelWeight, firstPassResults, leaves, reranked); + scoreFeatures(docsToRerank, modelWeight, firstPassResults, leaves, reranked); // Must sort all documents that we reranked, and then select the top Arrays.sort(reranked, scoreComparator); return reranked; @@ -152,7 +152,7 @@ protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) { } public void scoreFeatures( - int topN, + int docsToRerank, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List leaves, @@ -164,13 +164,12 @@ public void scoreFeatures( int docBase = 0; LTRScoringQuery.ModelWeight.ModelScorer scorer = null; - int hitUpto = 0; + int hitPosition = 0; - while (hitUpto < hits.length) { - final ScoreDoc hit = hits[hitUpto]; - final int docID = hit.doc; + while (hitPosition < hits.length) { + final ScoreDoc hit = hits[hitPosition]; LeafReaderContext readerContext = null; - while (docID >= endDoc) { + while (hit.doc >= endDoc) { readerUpto++; readerContext = leaves.get(readerUpto); endDoc = readerContext.docBase + readerContext.reader().maxDoc(); @@ -180,18 +179,17 @@ public void scoreFeatures( docBase = readerContext.docBase; scorer = modelWeight.modelScorer(readerContext); } - scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked); - hitUpto++; + scoreSingleHit(docsToRerank, docBase, hitPosition, hit, scorer, reranked); + hitPosition++; } } /** Scores a single document. */ protected void scoreSingleHit( - int topN, + int docsToRerank, int docBase, - int hitUpto, + int hitPosition, ScoreDoc hit, - int docID, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException { @@ -203,12 +201,12 @@ protected void scoreSingleHit( * needs to compute a potentially non-zero score from blank features. */ assert (scorer != null); - final int targetDoc = docID - docBase; + final int targetDoc = hit.doc - docBase; scorer.docID(); scorer.iterator().advance(targetDoc); scorer.getDocInfo().setOriginalDocScore(hit.score); - scorer.setSolrDocID(docID); + scorer.getDocInfo().setOriginalDocId(hit.doc); hit.score = scorer.score(); if (QueryLimits.getCurrentLimits() .maybeExitWithPartialResults( @@ -217,19 +215,19 @@ protected void scoreSingleHit( + " If partial results are tolerated the reranking got reverted and all documents preserved their original score and ranking.")) { throw new IncompleteRerankingException(); } - if (hitUpto < topN) { - reranked[hitUpto] = hit; - } else if (hitUpto == topN) { + if (hitPosition < docsToRerank) { + reranked[hitPosition] = hit; + } else if (hitPosition == docsToRerank) { // collected topN document, I create the heap - heapify(reranked, topN); + heapify(reranked, docsToRerank); } - if (hitUpto >= topN) { + if (hitPosition >= docsToRerank) { // once that heap is ready, if the score of this document is greater that // the minimum I replace it with the // minimum and fix the heap. if (hit.score > reranked[0].score) { reranked[0] = hit; - heapAdjust(reranked, topN, 0); + heapAdjust(reranked, docsToRerank, 0); } } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 8d4f8e3c9623..22a9a8e5dfbd 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -510,7 +510,8 @@ public ModelScorer(List featureScorers) { allFeaturesInStore, ltrScoringModel, efi, - featureScorers); + featureScorers, + docInfo); } else { featureTraversalScorer = new MultiFeaturesScorer( @@ -520,7 +521,8 @@ public ModelScorer(List featureScorers) { allFeaturesInStore, ltrScoringModel, efi, - featureScorers); + featureScorers, + docInfo); } } @@ -552,10 +554,6 @@ public DocIdSetIterator iterator() { public void fillFeaturesInfo() throws IOException { featureTraversalScorer.fillFeaturesInfo(); } - - public void setSolrDocID(int solrDocID) throws IOException { - featureTraversalScorer.setSolrDocID(solrDocID); - } } } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java index 878f0c98f722..050ed0605863 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java @@ -76,7 +76,7 @@ public void fillFeaturesInfo() throws IOException { featureVectorCache = request.getSearcher().getFeatureVectorCache(); } if (featureVectorCache != null) { - int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.getSolrDocID()); + int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.getDocInfo().getOriginalDocId()); featureVector = featureVectorCache.get(fvCacheKey); if (featureVector == null) { featureVector = extractFeatureVector(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java index 6b2be9345dbe..1660e96efbbb 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -57,12 +57,12 @@ public LTRInterleavingRescorer( * * @param searcher current IndexSearcher * @param firstPassTopDocs documents to rerank; - * @param topN documents to return; + * @param docsToRerank documents to return; */ @Override - public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int docsToRerank) throws IOException { - if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { + if ((docsToRerank == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { return firstPassTopDocs; } @@ -72,10 +72,10 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int top System.arraycopy( firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length); } - topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value())); + docsToRerank = Math.toIntExact(Math.min(docsToRerank, firstPassTopDocs.totalHits.value())); ScoreDoc[][] reRankedPerModel = - rerank(searcher, topN, getFirstPassDocsRanked(firstPassTopDocs)); + rerank(searcher, docsToRerank, getFirstPassDocsRanked(firstPassTopDocs)); if (originalRankingIndex != null) { reRankedPerModel[originalRankingIndex] = firstPassResults; } @@ -150,7 +150,7 @@ public void scoreFeatures( for (int i = 0; i < rerankingQueries.length; i++) { if (modelWeights[i] != null) { final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex); - scoreSingleHit(topN, docBase, hitUpto, hit_i, docID, scorers[i], rerankedPerModel[i]); + scoreSingleHit(topN, docBase, hitUpto, hit_i, scorers[i], rerankedPerModel[i]); } } hitUpto++; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index cba6f7b8b9c2..a85597bde089 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -418,7 +418,7 @@ private static LTRScoringQuery.FeatureInfo[] extractFeatures( final LeafReaderContext atomicContext = leafContexts.get(n); final int deBasedDoc = docid - atomicContext.docBase; final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.modelScorer(atomicContext); - r.setSolrDocID(docid); + r.getDocInfo().setOriginalDocId(docid); if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) { return new LTRScoringQuery.FeatureInfo[0]; } else { diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java index 45826167e716..9daf911e7178 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -18,6 +18,7 @@ import java.io.IOException; import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.DocInfo; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.feature.extraction.FeatureExtractor; @@ -30,14 +31,13 @@ public abstract class FeatureTraversalScorer extends Scorer { protected LTRScoringModel ltrScoringModel; protected Feature.FeatureWeight[] extractedFeatureWeights; protected LTRScoringQuery.ModelWeight modelWeight; + protected DocInfo docInfo; public abstract int getActiveDoc(); public abstract int getTargetDoc(); - public abstract int getSolrDocID(); - - public abstract void setSolrDocID(int solrDocID); + public abstract DocInfo getDocInfo(); public void reset() { for (int i = 0; i < extractedFeatureWeights.length; ++i) { diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java index 1437fef0406c..98527b1f7538 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -25,6 +25,7 @@ import org.apache.lucene.search.DisiWrapper; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.DocInfo; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.feature.extraction.MultiFeaturesExtractor; @@ -38,7 +39,6 @@ public class MultiFeaturesScorer extends FeatureTraversalScorer { private int targetDoc = -1; private int activeDoc = -1; - private int solrDocID = -1; private final DisiPriorityQueue subScorers; private final List wrappers; private final MultiFeaturesIterator multiFeaturesIteratorIterator; @@ -50,7 +50,8 @@ public MultiFeaturesScorer( LTRScoringQuery.FeatureInfo[] allFeaturesInStore, LTRScoringModel ltrScoringModel, Map efi, - List featureScorers) { + List featureScorers, + DocInfo docInfo) { if (featureScorers.size() <= 1) { throw new IllegalArgumentException("There must be at least 2 subScorers"); } @@ -76,6 +77,7 @@ public MultiFeaturesScorer( efi, subScorers); this.modelWeight = modelWeight; + this.docInfo = docInfo; } @Override @@ -89,13 +91,8 @@ public int getTargetDoc() { } @Override - public int getSolrDocID() { - return solrDocID; - } - - @Override - public void setSolrDocID(int solrDocID) { - this.solrDocID = solrDocID; + public DocInfo getDocInfo() { + return docInfo; } @Override diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java index 79cb70d0fc59..cc7645e92639 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -23,6 +23,7 @@ import java.util.Map; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; +import org.apache.solr.ltr.DocInfo; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.feature.extraction.SingleFeatureExtractor; @@ -33,7 +34,6 @@ public class SingleFeatureScorer extends FeatureTraversalScorer { private int targetDoc = -1; private int activeDoc = -1; - private int solrDocID = -1; private final List featureScorers; public SingleFeatureScorer( @@ -43,7 +43,8 @@ public SingleFeatureScorer( LTRScoringQuery.FeatureInfo[] allFeaturesInStore, LTRScoringModel ltrScoringModel, Map efi, - List featureScorers) { + List featureScorers, + DocInfo docInfo) { this.featureScorers = featureScorers; this.extractedFeatureWeights = extractedFeatureWeights; this.allFeaturesInStore = allFeaturesInStore; @@ -58,6 +59,7 @@ public SingleFeatureScorer( efi, featureScorers); this.modelWeight = modelWeight; + this.docInfo = docInfo; } @Override @@ -71,13 +73,8 @@ public int getTargetDoc() { } @Override - public int getSolrDocID() { - return solrDocID; - } - - @Override - public void setSolrDocID(int solrDocID) { - this.solrDocID = solrDocID; + public DocInfo getDocInfo() { + return docInfo; } @Override From b1c8c9862346d5310a768df116508728e4f16477 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 15 Oct 2025 12:04:38 +0200 Subject: [PATCH 48/54] Changed name variable in Interleaving --- .../interleaving/LTRInterleavingRescorer.java | 21 +++++++++---------- .../ltr/scoring/FeatureTraversalScorer.java | 1 - .../solr/ltr/scoring/MultiFeaturesScorer.java | 1 + .../solr/ltr/scoring/SingleFeatureScorer.java | 1 + ...FeatureExtractionFromMultipleSegments.java | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java index 1660e96efbbb..7b1cde86cbf0 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -90,9 +90,9 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int doc return new TopDocs(firstPassTopDocs.totalHits, interleavedResults); } - private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) + private ScoreDoc[][] rerank(IndexSearcher searcher, int docsToRerank, ScoreDoc[] firstPassResults) throws IOException { - ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN]; + ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][docsToRerank]; final List leaves = searcher.getIndexReader().leaves(); LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length]; @@ -103,7 +103,7 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa searcher.createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1); } } - scoreFeatures(topN, modelWeights, firstPassResults, leaves, reRankedPerModel); + scoreFeatures(docsToRerank, modelWeights, firstPassResults, leaves, reRankedPerModel); for (int i = 0; i < rerankingQueries.length; i++) { if (originalRankingIndex == null || originalRankingIndex != i) { @@ -115,7 +115,7 @@ private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPa } public void scoreFeatures( - int topN, + int docsToRerank, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List leaves, @@ -125,14 +125,13 @@ public void scoreFeatures( int readerUpto = -1; int endDoc = 0; int docBase = 0; - int hitUpto = 0; + int hitPosition = 0; LTRScoringQuery.ModelWeight.ModelScorer[] scorers = new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length]; - while (hitUpto < hits.length) { - final ScoreDoc hit = hits[hitUpto]; - final int docID = hit.doc; + while (hitPosition < hits.length) { + final ScoreDoc hit = hits[hitPosition]; LeafReaderContext readerContext = null; - while (docID >= endDoc) { + while (hit.doc >= endDoc) { readerUpto++; readerContext = leaves.get(readerUpto); endDoc = readerContext.docBase + readerContext.reader().maxDoc(); @@ -150,10 +149,10 @@ public void scoreFeatures( for (int i = 0; i < rerankingQueries.length; i++) { if (modelWeights[i] != null) { final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex); - scoreSingleHit(topN, docBase, hitUpto, hit_i, scorers[i], rerankedPerModel[i]); + scoreSingleHit(docsToRerank, docBase, hitPosition, hit_i, scorers[i], rerankedPerModel[i]); } } - hitUpto++; + hitPosition++; } } diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java index 9daf911e7178..2c92ff9e15e7 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/FeatureTraversalScorer.java @@ -31,7 +31,6 @@ public abstract class FeatureTraversalScorer extends Scorer { protected LTRScoringModel ltrScoringModel; protected Feature.FeatureWeight[] extractedFeatureWeights; protected LTRScoringQuery.ModelWeight modelWeight; - protected DocInfo docInfo; public abstract int getActiveDoc(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java index 98527b1f7538..2ac44c19be87 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/MultiFeaturesScorer.java @@ -39,6 +39,7 @@ public class MultiFeaturesScorer extends FeatureTraversalScorer { private int targetDoc = -1; private int activeDoc = -1; + protected DocInfo docInfo; private final DisiPriorityQueue subScorers; private final List wrappers; private final MultiFeaturesIterator multiFeaturesIteratorIterator; diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java index cc7645e92639..6619856901af 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/scoring/SingleFeatureScorer.java @@ -34,6 +34,7 @@ public class SingleFeatureScorer extends FeatureTraversalScorer { private int targetDoc = -1; private int activeDoc = -1; + protected DocInfo docInfo; private final List featureScorers; public SingleFeatureScorer( diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java index d13af84d5c7d..130c82498edd 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java @@ -236,7 +236,7 @@ public void testFeatureExtractionFromMultipleSegments() throws Exception { query.setQuery( "{!edismax qf='description^1' boost='sum(product(pow(normHits, 0.7), 1600), .1)' v='apple'}"); // request 100 rows, if any rows are fetched from the second or subsequent segments the tests - // should succeed if LTRRescorer::extractFeatures() advances the doc iterator properly + // should succeed if LTRFeatureLoggerTransformerFactory::extractFeatures() advances the doc iterator properly int numRows = 100; query.add("rows", Integer.toString(numRows)); query.add("wt", "json"); From 7d33e2d5fa7c8c632d582c0b0246cda0b277e015 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 15 Oct 2025 12:14:57 +0200 Subject: [PATCH 49/54] Added disclaimer about Lucene doc IDs to the LTR reference page --- .../modules/query-guide/pages/learning-to-rank.adoc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc index aeefd9051758..836624aa6675 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc @@ -136,9 +136,12 @@ This needs to be added in the `` section as follows. + [source,xml] ---- - + ---- +[NOTE] +The `featureVectorCache` holds Lucene Document IDs. Since they are transient, this cache is not auto-warmed. + * Declaration of the `[features]` transformer. + [source,xml] From 4b4b07e735b313c1067228d979c3c677ee4d6b3a Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 15 Oct 2025 12:17:47 +0200 Subject: [PATCH 50/54] Gradlew tidy --- .../apache/solr/ltr/feature/extraction/FeatureExtractor.java | 3 ++- .../apache/solr/ltr/interleaving/LTRInterleavingRescorer.java | 3 ++- .../ltr/feature/TestFeatureExtractionFromMultipleSegments.java | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java index 050ed0605863..03911b8bbc45 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/feature/extraction/FeatureExtractor.java @@ -76,7 +76,8 @@ public void fillFeaturesInfo() throws IOException { featureVectorCache = request.getSearcher().getFeatureVectorCache(); } if (featureVectorCache != null) { - int fvCacheKey = computeFeatureVectorCacheKey(traversalScorer.getDocInfo().getOriginalDocId()); + int fvCacheKey = + computeFeatureVectorCacheKey(traversalScorer.getDocInfo().getOriginalDocId()); featureVector = featureVectorCache.get(fvCacheKey); if (featureVector == null) { featureVector = extractFeatureVector(); diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java index 7b1cde86cbf0..8d1227056a35 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -149,7 +149,8 @@ public void scoreFeatures( for (int i = 0; i < rerankingQueries.length; i++) { if (modelWeights[i] != null) { final ScoreDoc hit_i = new ScoreDoc(hit.doc, hit.score, hit.shardIndex); - scoreSingleHit(docsToRerank, docBase, hitPosition, hit_i, scorers[i], rerankedPerModel[i]); + scoreSingleHit( + docsToRerank, docBase, hitPosition, hit_i, scorers[i], rerankedPerModel[i]); } } hitPosition++; diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java index 130c82498edd..f9fb8099cf32 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java @@ -236,7 +236,8 @@ public void testFeatureExtractionFromMultipleSegments() throws Exception { query.setQuery( "{!edismax qf='description^1' boost='sum(product(pow(normHits, 0.7), 1600), .1)' v='apple'}"); // request 100 rows, if any rows are fetched from the second or subsequent segments the tests - // should succeed if LTRFeatureLoggerTransformerFactory::extractFeatures() advances the doc iterator properly + // should succeed if LTRFeatureLoggerTransformerFactory::extractFeatures() advances the doc + // iterator properly int numRows = 100; query.add("rows", Integer.toString(numRows)); query.add("wt", "json"); From df09ef141e1a21ace6b544d2faef5141587d18d5 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 15 Oct 2025 15:31:15 +0200 Subject: [PATCH 51/54] Added element to ltr gradle for metrics testing --- solr/modules/ltr/gradle.lockfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/solr/modules/ltr/gradle.lockfile b/solr/modules/ltr/gradle.lockfile index de00f483d4d5..1601feefd07c 100644 --- a/solr/modules/ltr/gradle.lockfile +++ b/solr/modules/ltr/gradle.lockfile @@ -34,7 +34,7 @@ commons-cli:commons-cli:1.10.0=jarValidation,runtimeClasspath,runtimeLibs,solrPl commons-codec:commons-codec:1.19.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath commons-io:commons-io:2.20.0=compileClasspath,jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath io.dropwizard.metrics:metrics-annotation:4.2.26=jarValidation,testRuntimeClasspath -io.dropwizard.metrics:metrics-core:4.2.26=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath +io.dropwizard.metrics:metrics-core:4.2.26=jarValidation,testCompileClasspath,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath io.dropwizard.metrics:metrics-jetty12-ee10:4.2.26=jarValidation,testRuntimeClasspath io.dropwizard.metrics:metrics-jetty12:4.2.26=jarValidation,testRuntimeClasspath io.github.eisop:dataflow-errorprone:3.41.0-eisop1=annotationProcessor,errorprone,testAnnotationProcessor From d73f0c7d1b41a8e9b8a1ca2a0f275f99410b4064 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Wed, 15 Oct 2025 15:55:34 +0200 Subject: [PATCH 52/54] Fixed CHANGES.txt and adjust reference guide --- solr/CHANGES.txt | 4 ++-- .../modules/query-guide/pages/learning-to-rank.adoc | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index e65a90de9a35..f46576020a59 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -29,6 +29,8 @@ New Features * SOLR-17814: Add support for PatienceKnnVectorQuery. (Ilaria Petreti via Alessandro Benedetti) +* SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero, Alessandro Benedetti) + Improvements --------------------- @@ -376,8 +378,6 @@ New Features subset of JavaScript, pre-compiled, and that which can access the score and fields. It's powered by the Lucene Expressions module. (hossman, David Smiley, Ryan Ernst, Kevin Risden) -* SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero, Alessandro Benedetti) - Improvements --------------------- * SOLR-15751: The v2 API now has parity with the v1 "COLSTATUS" and "segments" APIs, which can be used to fetch detailed information about diff --git a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc index 836624aa6675..b83c224ef2cd 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc @@ -140,7 +140,8 @@ This needs to be added in the `` section as follows. ---- [NOTE] -The `featureVectorCache` holds Lucene Document IDs. Since they are transient, this cache is not auto-warmed. +The `featureVectorCache` key is computed using the Lucene Document ID (necessary for document-level features). +Since these IDs are transient, this cache does not support auto-warming. * Declaration of the `[features]` transformer. + From f8bf5a83a51ec22829f9505ea3a0ade280ff9c9a Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Fri, 17 Oct 2025 17:22:23 +0200 Subject: [PATCH 53/54] Adapted to new cache metrics access + tidy --- solr/modules/ltr/build.gradle | 4 +- solr/modules/ltr/gradle.lockfile | 4 +- .../solr/ltr/TestFeatureVectorCache.java | 109 ++++++++---------- 3 files changed, 53 insertions(+), 64 deletions(-) diff --git a/solr/modules/ltr/build.gradle b/solr/modules/ltr/build.gradle index 6467da570def..20582b1d06d6 100644 --- a/solr/modules/ltr/build.gradle +++ b/solr/modules/ltr/build.gradle @@ -55,9 +55,9 @@ dependencies { testImplementation libs.junit.junit testImplementation libs.hamcrest.hamcrest - testImplementation libs.commonsio.commonsio + testImplementation libs.prometheus.metrics.model - testImplementation libs.dropwizard.metrics.core + testImplementation libs.commonsio.commonsio } task copyPythonClientToExample(type: Sync) { diff --git a/solr/modules/ltr/gradle.lockfile b/solr/modules/ltr/gradle.lockfile index 1601feefd07c..d4364c515012 100644 --- a/solr/modules/ltr/gradle.lockfile +++ b/solr/modules/ltr/gradle.lockfile @@ -34,7 +34,7 @@ commons-cli:commons-cli:1.10.0=jarValidation,runtimeClasspath,runtimeLibs,solrPl commons-codec:commons-codec:1.19.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath commons-io:commons-io:2.20.0=compileClasspath,jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath io.dropwizard.metrics:metrics-annotation:4.2.26=jarValidation,testRuntimeClasspath -io.dropwizard.metrics:metrics-core:4.2.26=jarValidation,testCompileClasspath,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath +io.dropwizard.metrics:metrics-core:4.2.26=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath io.dropwizard.metrics:metrics-jetty12-ee10:4.2.26=jarValidation,testRuntimeClasspath io.dropwizard.metrics:metrics-jetty12:4.2.26=jarValidation,testRuntimeClasspath io.github.eisop:dataflow-errorprone:3.41.0-eisop1=annotationProcessor,errorprone,testAnnotationProcessor @@ -65,7 +65,7 @@ io.opentelemetry:opentelemetry-sdk-metrics:1.53.0=jarValidation,runtimeClasspath io.opentelemetry:opentelemetry-sdk-trace:1.53.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath io.opentelemetry:opentelemetry-sdk:1.53.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath io.prometheus:prometheus-metrics-exposition-formats:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath -io.prometheus:prometheus-metrics-model:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath +io.prometheus:prometheus-metrics-model:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath io.sgr:s2-geometry-library-java:1.0.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath io.swagger.core.v3:swagger-annotations-jakarta:2.2.22=compileClasspath,jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath jakarta.annotation:jakarta.annotation-api:2.1.1=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java index 442b3b48514e..2674f567edd9 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestFeatureVectorCache.java @@ -16,13 +16,12 @@ */ package org.apache.solr.ltr; +import io.prometheus.metrics.model.snapshots.CounterSnapshot; import java.util.ArrayList; import java.util.List; -import java.util.Map; import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.core.SolrCore; -import org.apache.solr.metrics.MetricsMap; -import org.apache.solr.metrics.SolrMetricManager; +import org.apache.solr.util.SolrMetricTestUtils; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -55,15 +54,17 @@ public void after() throws Exception { aftertest(); } - private static Map lookupFilterCacheMetrics(SolrCore core) { - return ((MetricsMap) - ((SolrMetricManager.GaugeWrapper) - core.getCoreMetricManager() - .getRegistry() - .getMetrics() - .get("CACHE.searcher.featureVectorCache")) - .getGauge()) - .getValue(); + private static CounterSnapshot.CounterDataPointSnapshot getFeatureVectorCacheInserts( + SolrCore core) { + return SolrMetricTestUtils.getCacheSearcherOpsInserts(core, "featureVectorCache"); + } + + private static double getFeatureVectorCacheLookups(SolrCore core) { + return SolrMetricTestUtils.getCacheSearcherTotalLookups(core, "featureVectorCache"); + } + + private static CounterSnapshot.CounterDataPointSnapshot getFeatureVectorCacheHits(SolrCore core) { + return SolrMetricTestUtils.getCacheSearcherOpsHits(core, "featureVectorCache"); } @Test @@ -103,20 +104,18 @@ public void testFeatureVectorCache_loggingDefaultStoreNoReranking() throws Excep assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); - Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size(), (long) filterCacheMetrics.get("lookups")); - assertEquals(0, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size(), getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); query.add("sort", "popularity desc"); // Caching, we want to see hits assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); - filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); - assertEquals(docs.size(), (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0); } @Test @@ -138,20 +137,18 @@ public void testFeatureVectorCache_loggingExplicitStoreNoReranking() throws Exce assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); - Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size(), (long) filterCacheMetrics.get("lookups")); - assertEquals(0, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size(), getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); query.add("sort", "popularity desc"); // Caching, we want to see hits assertJQ( "/query" + query.toQueryString(), "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); - filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); - assertEquals(docs.size(), (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0); } @Test @@ -189,10 +186,9 @@ public void testFeatureVectorCache_loggingModelStoreAndRerankingWithDifferentEfi + "'fv':'" + docs0fv_default_csv + "'}"); - Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); - assertEquals(0, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); query.add("sort", "popularity desc"); // Caching, we want to see hits and same scores as before @@ -204,10 +200,9 @@ public void testFeatureVectorCache_loggingModelStoreAndRerankingWithDifferentEfi + "'fv':'" + docs0fv_default_csv + "'}"); - filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 4, (long) filterCacheMetrics.get("lookups")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(), 0); } @Test @@ -244,10 +239,9 @@ public void testFeatureVectorCache_loggingModelStoreAndRerankingWithSameEfi() th + "'fv':'" + docs0fv_default_csv + "'}"); - Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); - assertEquals(docs.size(), (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size(), getFeatureVectorCacheHits(core).getValue(), 0); query.add("sort", "popularity desc"); // Caching, we want to see hits and same scores @@ -259,10 +253,9 @@ public void testFeatureVectorCache_loggingModelStoreAndRerankingWithSameEfi() th + "'fv':'" + docs0fv_default_csv + "'}"); - filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size(), (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 4, (long) filterCacheMetrics.get("lookups")); - assertEquals(docs.size() * 3, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size(), getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size() * 3, getFeatureVectorCacheHits(core).getValue(), 0); } @Test @@ -308,10 +301,9 @@ public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws E + "'fv':'" + docs0fv_default_csv + "'}"); - Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); - assertEquals(0, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); query.add("sort", "popularity desc"); // Caching, we want to see hits and same scores @@ -323,10 +315,9 @@ public void testFeatureVectorCache_loggingAllFeatureStoreAndReranking() throws E + "'fv':'" + docs0fv_default_csv + "'}"); - filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 4, (long) filterCacheMetrics.get("lookups")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(), 0); } @Test @@ -354,10 +345,9 @@ public void testFeatureVectorCache_loggingExplicitStoreAndReranking() throws Exc + "'fv':'" + docs0fv_default_csv + "'}"); - Map filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("lookups")); - assertEquals(0, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheLookups(core), 0); + assertEquals(0, getFeatureVectorCacheHits(core).getValue(), 0); query.add("sort", "popularity desc"); // Caching, we want to see hits and same scores @@ -369,9 +359,8 @@ public void testFeatureVectorCache_loggingExplicitStoreAndReranking() throws Exc + "'fv':'" + docs0fv_default_csv + "'}"); - filterCacheMetrics = lookupFilterCacheMetrics(core); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("inserts")); - assertEquals(docs.size() * 4, (long) filterCacheMetrics.get("lookups")); - assertEquals(docs.size() * 2, (long) filterCacheMetrics.get("hits")); + assertEquals(docs.size() * 2, getFeatureVectorCacheInserts(core).getValue(), 0); + assertEquals(docs.size() * 4, getFeatureVectorCacheLookups(core), 0); + assertEquals(docs.size() * 2, getFeatureVectorCacheHits(core).getValue(), 0); } } From 11075ce5070913a856b57721adc6454c95bf4ca7 Mon Sep 17 00:00:00 2001 From: Anna Ruggero Date: Tue, 21 Oct 2025 09:39:31 +0200 Subject: [PATCH 54/54] Fixed typo --- .../configsets/sample_techproducts_configs/conf/solrconfig.xml | 2 +- .../modules/query-guide/pages/learning-to-rank.adoc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml b/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml index 4c417e0a8853..d6635c05e9b2 100644 --- a/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml +++ b/solr/server/solr/configsets/sample_techproducts_configs/conf/solrconfig.xml @@ -391,7 +391,7 @@ /> --> -