-
Notifications
You must be signed in to change notification settings - Fork 0
Add ES93HnswVectorsFormat #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the "Elastic License | ||
| * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side | ||
| * Public License v 1"; you may not use this file except in compliance with, at | ||
| * your election, the "Elastic License 2.0", the "GNU Affero General Public | ||
| * License v3.0 only", or the "Server Side Public License, v 1". | ||
| */ | ||
|
|
||
| package org.elasticsearch.index.codec.vectors.es93; | ||
|
|
||
| import org.apache.lucene.codecs.KnnVectorsReader; | ||
| import org.apache.lucene.codecs.KnnVectorsWriter; | ||
| import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; | ||
| import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; | ||
| import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; | ||
| import org.apache.lucene.index.SegmentReadState; | ||
| import org.apache.lucene.index.SegmentWriteState; | ||
| import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.concurrent.ExecutorService; | ||
|
|
||
/**
 * An HNSW vector-index format that layers the Lucene99 HNSW graph writer/reader
 * over a pluggable flat (raw) vector storage format, optionally storing vectors
 * as bfloat16 and reading them with direct IO.
 */
public class ES93HnswVectorsFormat extends AbstractHnswVectorsFormat {

    /** Format name recorded in the index metadata; must remain stable across releases. */
    static final String NAME = "ES93HnswVectorsFormat";

    // Storage format for the raw vector data that the HNSW graph nodes point into.
    private final FlatVectorsFormat flatVectorsFormat;

    /** Creates the format with default graph parameters and default flat-vector storage. */
    public ES93HnswVectorsFormat() {
        super(NAME);
        flatVectorsFormat = new ES93GenericFlatVectorsFormat();
    }

    /**
     * @param maxConn     maximum number of connections per node in the HNSW graph
     * @param beamWidth   size of the candidate list used while building the graph
     * @param bfloat16    presumably stores vectors as bfloat16 instead of float32 —
     *                    confirm against {@code ES93GenericFlatVectorsFormat}
     * @param useDirectIO presumably reads vector data using direct IO —
     *                    confirm against {@code ES93GenericFlatVectorsFormat}
     */
    public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO) {
        super(NAME, maxConn, beamWidth);
        flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
    }

    /**
     * @param maxConn         maximum number of connections per node in the HNSW graph
     * @param beamWidth       size of the candidate list used while building the graph
     * @param bfloat16        see the four-argument constructor
     * @param useDirectIO     see the four-argument constructor
     * @param numMergeWorkers number of workers used to build the graph concurrently on merge
     * @param mergeExec       executor the merge workers run on
     */
    public ES93HnswVectorsFormat(
        int maxConn,
        int beamWidth,
        boolean bfloat16,
        boolean useDirectIO,
        int numMergeWorkers,
        ExecutorService mergeExec
    ) {
        super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec);
        flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
    }

    @Override
    protected FlatVectorsFormat flatVectorsFormat() {
        return flatVectorsFormat;
    }

    // Graph construction/search is delegated to the Lucene99 HNSW implementations,
    // layered over this format's flat vector storage.
    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
    }

    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
    }
}
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,99 @@ | ||||||
| /* | ||||||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||||||
| * or more contributor license agreements. Licensed under the "Elastic License | ||||||
| * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side | ||||||
| * Public License v 1"; you may not use this file except in compliance with, at | ||||||
| * your election, the "Elastic License 2.0", the "GNU Affero General Public | ||||||
| * License v3.0 only", or the "Server Side Public License, v 1". | ||||||
| */ | ||||||
|
|
||||||
| package org.elasticsearch.index.codec.vectors.es93; | ||||||
|
|
||||||
| import org.apache.lucene.index.VectorEncoding; | ||||||
|
|
||||||
| import java.util.regex.Matcher; | ||||||
| import java.util.regex.Pattern; | ||||||
|
|
||||||
| import static org.hamcrest.Matchers.closeTo; | ||||||
|
|
||||||
| public class ES93HnswBFloat16VectorsFormatTests extends ES93HnswVectorsFormatTests { | ||||||
|
|
||||||
| @Override | ||||||
| protected boolean useBFloat16() { | ||||||
| return true; | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| protected VectorEncoding randomVectorEncoding() { | ||||||
| return VectorEncoding.FLOAT32; | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testEmptyByteVectorData() throws Exception { | ||||||
| // no bytes | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testMergingWithDifferentByteKnnFields() throws Exception { | ||||||
| // no bytes | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testByteVectorScorerIteration() throws Exception { | ||||||
| // no bytes | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testSortedIndexBytes() throws Exception { | ||||||
| // no bytes | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testMismatchedFields() throws Exception { | ||||||
| // no bytes | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testRandomBytes() throws Exception { | ||||||
| // no bytes | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testRandom() throws Exception { | ||||||
| AssertionError err = expectThrows(AssertionError.class, super::testRandom); | ||||||
| assertFloatsWithinBounds(err); | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testRandomWithUpdatesAndGraph() throws Exception { | ||||||
| AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph); | ||||||
| assertFloatsWithinBounds(err); | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testSparseVectors() throws Exception { | ||||||
| AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors); | ||||||
| assertFloatsWithinBounds(err); | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void testVectorValuesReportCorrectDocs() throws Exception { | ||||||
| AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs); | ||||||
| assertFloatsWithinBounds(err); | ||||||
| } | ||||||
|
|
||||||
| private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>"); | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The regular expression used to parse floating-point numbers from the assertion error message is not robust enough. It will fail to parse numbers represented in scientific notation (e.g.,
Suggested change
|
||||||
|
|
||||||
| private static void assertFloatsWithinBounds(AssertionError error) { | ||||||
| Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage()); | ||||||
| if (m.matches() == false) { | ||||||
| throw error; // nothing to do with us, just rethrow | ||||||
| } | ||||||
|
|
||||||
| // numbers just need to be in the same vicinity | ||||||
| double expected = Double.parseDouble(m.group(1)); | ||||||
| double actual = Double.parseDouble(m.group(2)); | ||||||
| double allowedError = expected * 0.01; // within 1% | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion: The tolerance for comparing floating point values is computed as Severity Level: Minor
Suggested change
Why it matters? ⭐The current code computes allowedError as expected * 0.01. If expected is negative this produces a negative delta, which is incorrect for Hamcrest's closeTo matcher (delta must be non-negative) and will make the comparison behave wrongly or always fail for negative expected values. Using Math.abs(expected) (or otherwise ensuring a non-negative delta) fixes a real logic bug in the test relaxation introduced for bfloat16 rounding differences. This is not merely cosmetic — it affects correctness of assertions when expected < 0. Prompt for AI Agent 🤖This is a comment left during a code review.
**Path:** server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java
**Line:** 96:96
**Comment:**
*Logic Error: The tolerance for comparing floating point values is computed as `expected * 0.01`, which becomes negative when `expected` is negative; passing a negative delta into `closeTo` causes the matcher to always fail (since it effectively requires `|diff| <= negative_value`), so for negative expected values your "within 1%" relaxation will never succeed and these tests will still fail instead of allowing small differences due to bfloat16 rounding.
Validate the correctness of the flagged issue. If correct, How can I resolve this? If you propose a fix, implement it and please make it concise. |
||||||
| assertThat(error.getMessage(), actual, closeTo(expected, allowedError)); | ||||||
| } | ||||||
| } | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,117 @@ | ||||||
| /* | ||||||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||||||
| * or more contributor license agreements. Licensed under the "Elastic License | ||||||
| * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side | ||||||
| * Public License v 1"; you may not use this file except in compliance with, at | ||||||
| * your election, the "Elastic License 2.0", the "GNU Affero General Public | ||||||
| * License v3.0 only", or the "Server Side Public License, v 1". | ||||||
| */ | ||||||
|
|
||||||
| package org.elasticsearch.index.codec.vectors.es93; | ||||||
|
|
||||||
| import org.apache.lucene.codecs.Codec; | ||||||
| import org.apache.lucene.codecs.FilterCodec; | ||||||
| import org.apache.lucene.codecs.KnnVectorsFormat; | ||||||
| import org.apache.lucene.codecs.KnnVectorsReader; | ||||||
| import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; | ||||||
| import org.apache.lucene.document.Document; | ||||||
| import org.apache.lucene.document.KnnFloatVectorField; | ||||||
| import org.apache.lucene.index.CodecReader; | ||||||
| import org.apache.lucene.index.DirectoryReader; | ||||||
| import org.apache.lucene.index.IndexReader; | ||||||
| import org.apache.lucene.index.IndexWriter; | ||||||
| import org.apache.lucene.index.LeafReader; | ||||||
| import org.apache.lucene.store.Directory; | ||||||
| import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; | ||||||
| import org.apache.lucene.tests.util.TestUtil; | ||||||
| import org.apache.lucene.util.SameThreadExecutorService; | ||||||
| import org.elasticsearch.common.logging.LogConfigurator; | ||||||
| import org.elasticsearch.index.codec.vectors.BFloat16; | ||||||
|
|
||||||
| import java.io.IOException; | ||||||
| import java.util.Locale; | ||||||
|
|
||||||
| import static java.lang.String.format; | ||||||
| import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; | ||||||
| import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; | ||||||
| import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; | ||||||
| import static org.hamcrest.Matchers.is; | ||||||
| import static org.hamcrest.Matchers.oneOf; | ||||||
|
|
||||||
| public class ES93HnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase { | ||||||
|
|
||||||
| static { | ||||||
| LogConfigurator.loadLog4jPlugins(); | ||||||
| LogConfigurator.configureESLogging(); // native access requires logging to be initialized | ||||||
| } | ||||||
|
|
||||||
| private KnnVectorsFormat format; | ||||||
|
|
||||||
| protected boolean useBFloat16() { | ||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| public void setUp() throws Exception { | ||||||
| format = new ES93HnswVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, useBFloat16(), random().nextBoolean()); | ||||||
| super.setUp(); | ||||||
| } | ||||||
|
|
||||||
| @Override | ||||||
| protected Codec getCodec() { | ||||||
| return TestUtil.alwaysKnnVectorsFormat(format); | ||||||
| } | ||||||
|
|
||||||
| public void testToString() { | ||||||
| FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { | ||||||
| @Override | ||||||
| public KnnVectorsFormat knnVectorsFormat() { | ||||||
| return new ES93HnswVectorsFormat(10, 20, false, false); | ||||||
| } | ||||||
| }; | ||||||
| String expectedPattern = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20," | ||||||
| + " flatVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat," | ||||||
| + " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=%s())))"; | ||||||
| var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); | ||||||
| var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); | ||||||
| assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); | ||||||
| } | ||||||
|
|
||||||
| public void testLimits() { | ||||||
| expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(-1, 20, false, false)); | ||||||
| expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(0, 20, false, false)); | ||||||
| expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 0, false, false)); | ||||||
| expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, -1, false, false)); | ||||||
| expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(512 + 1, 20, false, false)); | ||||||
| expectThrows(IllegalArgumentException.class, () -> new ES93HnswVectorsFormat(20, 3201, false, false)); | ||||||
| expectThrows( | ||||||
| IllegalArgumentException.class, | ||||||
| () -> new ES93HnswVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService()) | ||||||
| ); | ||||||
| } | ||||||
|
|
||||||
| public void testSimpleOffHeapSize() throws IOException { | ||||||
| float[] vector = randomVector(random().nextInt(12, 500)); | ||||||
| try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { | ||||||
| Document doc = new Document(); | ||||||
| doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); | ||||||
| w.addDocument(doc); | ||||||
| w.commit(); | ||||||
| try (IndexReader reader = DirectoryReader.open(w)) { | ||||||
| LeafReader r = getOnlyLeafReader(reader); | ||||||
| if (r instanceof CodecReader codecReader) { | ||||||
| KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); | ||||||
| if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { | ||||||
| knnVectorsReader = fieldsReader.getFieldReader("f"); | ||||||
| } | ||||||
| var fieldInfo = r.getFieldInfos().fieldInfo("f"); | ||||||
| var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); | ||||||
| int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES; | ||||||
| assertEquals(vector.length * bytes, (long) offHeap.get("vec")); | ||||||
| assertEquals(1L, (long) offHeap.get("vex")); | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The key |
||||||
| assertEquals(2, offHeap.size()); | ||||||
|
Comment on lines
+111
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggestion: The test Severity Level: Minor
Suggested change
Why it matters? ⭐The current test makes a brittle assumption about the exact keys and count returned by getOffHeapByteSize: casting offHeap.get("vex") to long will NPE if that key is absent and the fixed-size assertion will break with any change to the map contract. The suggested change (only asserting the essential "vec" entry) removes a fragile assertion and focuses the test on the meaningful quantity (vector bytes). This is a legitimate, non-cosmetic improvement because it prevents spurious failures tied to incidental map entries rather than the behavior under test. Prompt for AI Agent 🤖This is a comment left during a code review.
**Path:** server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java
**Line:** 111:112
**Comment:**
*Logic Error: The test `testSimpleOffHeapSize` assumes that `getOffHeapByteSize` returns an entry with key `vex` and that the map has exactly two entries; if the underlying implementation only exposes a `vec` entry (as in other flat vector readers) or adds different keys in future, `offHeap.get("vex")` will return null and the `(long)` cast will throw a NullPointerException and/or the fixed size assertion will fail, making the test incorrect and brittle with respect to the actual API contract of `getOffHeapByteSize`.
Validate the correctness of the flagged issue. If correct, How can I resolve this? If you propose a fix, implement it and please make it concise. |
||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The constructors in this class have duplicated logic for initializing
flatVectorsFormat. You can refactor them to use constructor chaining, which will reduce code duplication and improve maintainability.