diff --git a/src/it/java/io/weaviate/ConcurrentTest.java b/src/it/java/io/weaviate/ConcurrentTest.java index 2e2036d18..f3a70cc37 100644 --- a/src/it/java/io/weaviate/ConcurrentTest.java +++ b/src/it/java/io/weaviate/ConcurrentTest.java @@ -2,7 +2,6 @@ import java.util.Random; import java.util.UUID; -import java.util.stream.IntStream; import org.apache.commons.lang3.RandomStringUtils; import org.junit.Rule; @@ -56,9 +55,11 @@ protected static String randomUUID() { * @param bound Value range upper bound. * @return */ - protected static Float[] randomVector(int length, float origin, float bound) { - return IntStream.range(0, length) - .mapToObj(f -> rand.nextFloat(origin, bound)) - .toArray(Float[]::new); + protected static float[] randomVector(int length, float origin, float bound) { + var vector = new float[length]; + for (var i = 0; i < length; i++) { + vector[i] = rand.nextFloat(origin, bound); + } + return vector; } } diff --git a/src/it/java/io/weaviate/integration/DataITest.java b/src/it/java/io/weaviate/integration/DataITest.java index 827236163..712a0957d 100644 --- a/src/it/java/io/weaviate/integration/DataITest.java +++ b/src/it/java/io/weaviate/integration/DataITest.java @@ -37,7 +37,7 @@ public static void beforeAll() throws IOException { public void testCreateGetDelete() throws IOException { var artists = client.collections.use(COLLECTION); var id = randomUUID(); - Float[] vector = { 1f, 2f, 3f }; + float[] vector = { 1, 2, 3 }; artists.data.insert(Map.of("name", "john doe"), metadata -> metadata @@ -56,8 +56,8 @@ public void testCreateGetDelete() throws IOException { Assertions.assertThat(obj.metadata().uuid()) .as("object id").isEqualTo(id); - Assertions.assertThat(obj.metadata().vectors()).extracting(v -> v.getSingle(VECTOR_INDEX)) - .asInstanceOf(InstanceOfAssertFactories.array(Float[].class)).containsExactly(vector); + Assertions.assertThat(obj.metadata().vectors().getSingle(VECTOR_INDEX)) + .containsExactly(vector); Assertions.assertThat(obj.properties()) .as("has expected properties") @@ -227,7 +227,7 @@ public void testUpdate() throws IOException { var authors = client.collections.use(nsAuthors); var walter = authors.data.insert(Map.of("name", "walter scott")); - var vector = new Float[] { 1f, 2f, 3f }; + var vector = new float[] { 1, 2, 3 }; var books = client.collections.use(nsBooks); diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java index 3f67240a5..480e18c55 100644 --- a/src/it/java/io/weaviate/integration/SearchITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -53,7 +53,7 @@ public class SearchITest extends ConcurrentTest { /** * One of the inserted vectors which will be used as target vector for search. */ - private static Float[] searchVector; + private static float[] searchVector; @BeforeClass public static void beforeAll() throws IOException { @@ -102,10 +102,10 @@ public void testNearVector_groupBy() { /** * Insert 10 objects with random vectors. * - * @returns IDs of inserted objects and their corresponding vectors. + * @return IDs of inserted objects and their corresponding vectors. */ - private static Map populateTest(int n) throws IOException { - var created = new HashMap(); + private static Map populateTest(int n) throws IOException { + var created = new HashMap(); var things = client.collections.use(COLLECTION); for (int i = 0; i < n; i++) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java b/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java index 9638bed49..5c3a6a778 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java @@ -25,21 +25,22 @@ */ @ToString public class Vectors { + /** Elements of this map must only be {@code float[]} or {@code float[][]}. */ private final Map namedVectors; - public static Vectors of(Float[] vector) { - return new Vectors(VectorIndex.DEFAULT_VECTOR_NAME, vector); + public static Vectors of(float[] vector) { + return of(VectorIndex.DEFAULT_VECTOR_NAME, vector); } - public static Vectors of(String name, Float[] vector) { + public static Vectors of(String name, float[] vector) { return new Vectors(name, vector); } - public static Vectors of(Float[][] vector) { - return new Vectors(VectorIndex.DEFAULT_VECTOR_NAME, vector); + public static Vectors of(float[][] vector) { + return of(VectorIndex.DEFAULT_VECTOR_NAME, vector); } - public static Vectors of(String name, Float[][] vector) { + public static Vectors of(String name, float[][] vector) { return new Vectors(name, vector); } @@ -51,20 +52,30 @@ public Vectors(Builder builder) { this.namedVectors = builder.namedVectors; } - /* + /** * Create a single named vector. - * Intended to be used by factory methods, which can statically restrict - * vector's type to {@code Float[]} and {@code Float[][]}. * - * @param name Vector name. - * - * @param vector {@code Float[]} or {@code Float[][]} vector. + *

+ * Callers must ensure that vectors are either + * {@code float[]} or {@code float[][]}. * + * @param name Vector name. + * @param vector {@code float[]} or {@code float[][]} vector. */ private Vectors(String name, Object vector) { this.namedVectors = Collections.singletonMap(name, vector); } + /** + * Create a Vectors from a map. + * + *

+ * Callers must ensure that vectors are either + * {@code float[]} or {@code float[][]}. + * + * @param name Vector name. + * @param vector Map of named vectors. + */ private Vectors(Map namedVectors) { this.namedVectors = namedVectors; } @@ -72,12 +83,12 @@ private Vectors(Map namedVectors) { public static class Builder implements ObjectBuilder { private final Map namedVectors = new HashMap<>(); - public Builder vector(String name, Float[] vector) { + public Builder vector(String name, float[] vector) { this.namedVectors.put(name, vector); return this; } - public Builder vector(String name, Float[][] vector) { + public Builder vector(String name, float[][] vector) { this.namedVectors.put(name, vector); return this; } @@ -88,22 +99,55 @@ public Vectors build() { } } - public Float[] getSingle(String name) { - return (Float[]) namedVectors.get(name); + /** + * Get 1-dimensional vector by name. + * + * @return Vector as {@code float[]} or {@code null}. + * @throws ClassCastException The underlying vector is not a {@code float[]}. + */ + public float[] getSingle(String name) { + return (float[]) namedVectors.get(name); } - public Float[] getDefaultSingle() { + /** + * Get default 1-dimensional vector. + * + * @return Vector as {@code float[]} or {@code null}. + * @throws ClassCastException if the underlying object is not a {@code float[]}. + */ + public float[] getDefaultSingle() { return getSingle(VectorIndex.DEFAULT_VECTOR_NAME); } - public Float[][] getMulti(String name) { - return (Float[][]) namedVectors.get(name); + /** + * Get 2-dimensional vector by name. + * + * @return Vector as {@code float[][]} or {@code null}. + * @throws ClassCastException if the underlying object is not a + * {@code float[][]}. + */ + public float[][] getMulti(String name) { + return (float[][]) namedVectors.get(name); } - public Float[][] getDefaultMulti() { + /** + * Get default 2-dimensional vector. + * + * @return Vector as {@code float[][]} or {@code null}. + * @throws ClassCastException if the underlying object is not a + * {@code float[][]}. + */ + public float[][] getDefaultMulti() { return getMulti(VectorIndex.DEFAULT_VECTOR_NAME); } + /** + * Get all vectors. + * Each element is either a {@code float[]} or a {@code float[][]}. + * + * + * @return Map of name-vector pairs. The returned map is immutable. + */ public Map asMap() { return Map.copyOf(namedVectors); } @@ -119,8 +163,8 @@ public TypeAdapter create(Gson gson, TypeToken type) { } final var mapAdapter = gson.getDelegateAdapter(this, new TypeToken>() { }); - final var float_1d = gson.getDelegateAdapter(this, TypeToken.get(Float[].class)); - final var float_2d = gson.getDelegateAdapter(this, TypeToken.get(Float[][].class)); + final var float_1d = gson.getDelegateAdapter(this, TypeToken.get(float[].class)); + final var float_2d = gson.getDelegateAdapter(this, TypeToken.get(float[][].class)); return (TypeAdapter) new TypeAdapter() { @Override @@ -144,6 +188,8 @@ public Vectors read(JsonReader in) throws IOException { } else { vector = float_1d.fromJsonTree(array); } + + assert (vector instanceof float[]) || (vector instanceof float[][]) : "invalid vector type"; namedVectors.put(vectorName, vector); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java index 23fce1bc5..4258947bd 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java @@ -72,11 +72,11 @@ public GroupedResponseT hybrid(Hybrid filter, Function> fn) { + public ResponseT nearVector(float[] vector, Function> fn) { return nearVector(NearVector.of(vector), fn); } - public ResponseT nearVector(Float[] vector, Function> nv, + public ResponseT nearVector(float[] vector, Function> nv, Function> fn) { return nearVector(NearVector.of(vector, nv), fn); } @@ -85,12 +85,12 @@ public ResponseT nearVector(NearVector filter, Function> fn, + public GroupedResponseT nearVector(float[] vector, Function> fn, GroupBy groupBy) { return nearVector(NearVector.of(vector), fn, groupBy); } - public GroupedResponseT nearVector(Float[] vector, Function> nv, + public GroupedResponseT nearVector(float[] vector, Function> nv, Function> fn, GroupBy groupBy) { return nearVector(NearVector.of(vector, nv), fn, groupBy); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/GroupedBy.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/GroupedBy.java index d3db6e971..f853780c7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/GroupedBy.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/GroupedBy.java @@ -13,12 +13,12 @@ public String text() { } public boolean isInteger() { - return value instanceof String; + return value instanceof Long; } - public Integer integer() { - checkPropertyType(this::isInteger, "Integer"); - return (Integer) value; + public Long integer() { + checkPropertyType(this::isInteger, "Long"); + return (Long) value; } private void checkPropertyType(Supplier check, String expected) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java index 6c6f42748..48c41ebec 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertManyRequest.java @@ -101,10 +101,10 @@ public static void buildObject(WeaviateProtoBatch.BatchObject.Builder object var vector = WeaviateProtoBase.Vectors.newBuilder() .setName(entry.getKey()); - if (value instanceof Float[] single) { + if (value instanceof float[] single) { vector.setType(VectorType.VECTOR_TYPE_SINGLE_FP32); vector.setVectorBytes(ByteStringUtil.encodeVectorSingle(single)); - } else if (value instanceof Float[][] multi) { + } else if (value instanceof float[][] multi) { vector.setVectorBytes(ByteStringUtil.encodeVectorMulti(multi)); vector.setType(VectorType.VECTOR_TYPE_MULTI_FP32); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java index cc0017527..0db66bd6b 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java @@ -110,11 +110,11 @@ public GroupedResponseT hybrid(Hybrid query, GroupBy groupBy) { // NearVector queries ------------------------------------------------------- - public ResponseT nearVector(Float[] vector) { + public ResponseT nearVector(float[] vector) { return nearVector(NearVector.of(vector)); } - public ResponseT nearVector(Float[] vector, Function> fn) { + public ResponseT nearVector(float[] vector, Function> fn) { return nearVector(NearVector.of(vector, fn)); } @@ -122,11 +122,11 @@ public ResponseT nearVector(NearVector query) { return performRequest(query); } - public GroupedResponseT nearVector(Float[] vector, GroupBy groupBy) { + public GroupedResponseT nearVector(float[] vector, GroupBy groupBy) { return nearVector(NearVector.of(vector), groupBy); } - public GroupedResponseT nearVector(Float[] vector, Function> fn, + public GroupedResponseT nearVector(float[] vector, Function> fn, GroupBy groupBy) { return nearVector(NearVector.of(vector, fn), groupBy); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java index 3a2815864..e709ee069 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java @@ -47,17 +47,17 @@ public static abstract class Builder, T extends Ob private List returnReferences = new ArrayList<>(); private List returnMetadata = new ArrayList<>(); - public final SELF limit(Integer limit) { + public final SELF limit(int limit) { this.limit = limit; return (SELF) this; } - public final SELF offset(Integer offset) { + public final SELF offset(int offset) { this.offset = offset; return (SELF) this; } - public final SELF autocut(Integer autocut) { + public final SELF autocut(int autocut) { this.autocut = autocut; return (SELF) this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java index 266e9ddbf..303729879 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java @@ -10,14 +10,14 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; -public record NearVector(Float[] vector, Float distance, Float certainty, BaseQueryOptions common) +public record NearVector(float[] vector, Float distance, Float certainty, BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter { - public static final NearVector of(Float[] vector) { + public static final NearVector of(float[] vector) { return of(vector, ObjectBuilder.identity()); } - public static final NearVector of(Float[] vector, Function> fn) { + public static final NearVector of(float[] vector, Function> fn) { return fn.apply(new Builder(vector)).build(); } @@ -27,9 +27,9 @@ public NearVector(Builder builder) { public static class Builder extends BaseVectorSearchBuilder { // Required query parameters. - private final Float[] vector; + private final float[] vector; - public Builder(Float[] vector) { + public Builder(float[] vector) { this.vector = vector; } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java index c4dbd7785..1d45bed0f 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtil.java @@ -6,8 +6,6 @@ import java.util.Arrays; import java.util.UUID; -import org.apache.commons.lang3.ArrayUtils; - import com.google.protobuf.ByteString; public class ByteStringUtil { @@ -21,32 +19,25 @@ public static UUID decodeUuid(ByteString bs) { return new UUID(most, least); } - /** Encode Float[] to ByteString. */ - public static ByteString encodeVectorSingle(Float[] vector) { + /** Encode float[] to ByteString. */ + public static ByteString encodeVectorSingle(float[] vector) { if (vector == null || vector.length == 0) { return ByteString.EMPTY; } ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES).order(BYTE_ORDER); - Arrays.stream(vector).forEach(buffer::putFloat); - return ByteString.copyFrom(buffer.array()); - } - - /** Encode float[] to ByteString. */ - public static ByteString encodeVectorSingle(float[] vector) { - ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES).order(BYTE_ORDER); - for (float f : vector) { + for (final var f : vector) { buffer.putFloat(f); } return ByteString.copyFrom(buffer.array()); } /** - * Encode Float[][] to ByteString. + * Encode float[][] to ByteString. *

* The first 2 bytes of the resulting ByteString encode the number of dimensions * (uint16 / short) followed by concatenated vectors (4 bytes per element). */ - public static ByteString encodeVectorMulti(Float[][] vectors) { + public static ByteString encodeVectorMulti(float[][] vectors) { if (vectors == null || vectors.length == 0 || vectors[0].length == 0) { return ByteString.EMPTY; } @@ -57,45 +48,67 @@ public static ByteString encodeVectorMulti(Float[][] vectors) { /* concatenated elements */ (n * dimensions * Float.BYTES); ByteBuffer buffer = ByteBuffer.allocate(capacity).order(BYTE_ORDER) .putShort(dimensions); - Arrays.stream(vectors).forEach(v -> Arrays.stream(v).forEach(buffer::putFloat)); + Arrays.stream(vectors).forEach(vector -> { + for (final var f : vector) { + buffer.putFloat(f); + } + }); return ByteString.copyFrom(buffer.array()); } /** - * Decode ByteString into a Float[]. ByteString size must be a multiple of - * {@link Float#BYTES}, throws {@link IllegalArgumentException} otherwise. + * Decode ByteString to {@code float[]}. + * + * @throws IllegalArgumentException if ByteString size is not + * a multiple of {@link Float#BYTES}. */ - public static Float[] decodeVectorSingle(ByteString bs) { + public static float[] decodeVectorSingle(ByteString bs) { if (bs.size() % Float.BYTES != 0) { throw new IllegalArgumentException( - "byte string size not a multiple of " + String.valueOf(Float.BYTES) + " (Float.BYTES)"); + "ByteString size " + bs.size() + " is not a multiple of " + String.valueOf(Float.BYTES) + " (Float.BYTES)"); } float[] vector = new float[bs.size() / Float.BYTES]; bs.asReadOnlyByteBuffer().order(BYTE_ORDER).asFloatBuffer().get(vector); - return ArrayUtils.toObject(vector); + return vector; } - /** Decode ByteString to Float[][]. */ - public static Float[][] decodeVectorMulti(ByteString bs) { + /** + * Decode ByteString to {@code float[][]}. + * + *

+ * The expected structure of the byte string of total size N is: + *

    + *
  • [2 bytes]: dimensionality of the inner vector ({@code dim}) + *
  • [N-2 bytes]: concatenated inner vectors. N-2 must be a multiple of + * {@code Float.BYTES * dim} + *
+ * + * @throws IllegalArgumentException if ByteString is not of a valid size. + */ + public static float[][] decodeVectorMulti(ByteString bs) { if (bs == null || bs.size() == 0) { - return new Float[0][0]; + return new float[0][0]; } ByteBuffer buf = bs.asReadOnlyByteBuffer().order(BYTE_ORDER); + short dim = buf.getShort(); // advances current position + if (dim == 0) { + return new float[0][0]; + } - // Dimensions are encoded in the first 2 bytes. - short dimensions = buf.getShort(); // advances current position - - FloatBuffer fbuf = buf.asFloatBuffer(); - int n = fbuf.remaining() / dimensions; // fbuf size is buf / Float.BYTES + FloatBuffer fbuf = buf.asFloatBuffer(); // fbuf size is buf / Float.BYTES + if (fbuf.remaining() % dim != 0) { + throw new IllegalArgumentException( + "Remaing ByteString size " + fbuf.remaining() + " is not a multiple of " + dim + + " (dim)"); + } + int n = fbuf.remaining() / dim; // Reading from buffer advances current position, // so we always read from offset=0. - Float[][] vectors = new Float[n][dimensions]; + float[][] vectors = new float[n][dim]; for (int i = 0; i < n; i++) { - float[] v = new float[dimensions]; - fbuf.get(v, 0, dimensions); - vectors[i] = ArrayUtils.toObject(v); + fbuf.get(vectors[i], 0, dim); } return vectors; } diff --git a/src/test/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtilTest.java b/src/test/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtilTest.java new file mode 100644 index 000000000..f9c6d1f71 --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/internal/grpc/ByteStringUtilTest.java @@ -0,0 +1,96 @@ +package io.weaviate.client6.v1.internal.grpc; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; + +import com.google.protobuf.ByteString; + +/** + * Note: Java's {@code byte} is signed (int8) and is different from {@code byte} + * in Go, which is an alias for uint8. + * + * For this tests purposes the distinction is immaterial, as "want" arrays + * are "golden values" meant to be a readable respresentation for the test. + */ +public class ByteStringUtilTest { + @Test + public void test_encodeVector_1d() { + float[] vector = { 1f, 2f, 3f }; + byte[] want = { 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64 }; + byte[] got = ByteStringUtil.encodeVectorSingle(vector).toByteArray(); + assertArrayEquals(want, got); + } + + @Test + public void test_decodeVector_1d() { + byte[] bytes = { 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64 }; + float[] want = { 1f, 2f, 3f }; + float[] got = ByteStringUtil.decodeVectorSingle(ByteString.copyFrom(bytes)); + assertArrayEquals(want, got, 0); + } + + @Test + public void test_encodeVector_2d() { + float[][] vector = { { 1f, 2f, 3f }, { 4f, 5f, 6f } }; + byte[] want = { 3, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, -128, 64, 0, 0, -96, 64, 0, 0, -64, 64 }; + byte[] got = ByteStringUtil.encodeVectorMulti(vector).toByteArray(); + assertArrayEquals(want, got); + } + + @Test + public void test_decodeVector_2d() { + byte[] bytes = { 3, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, -128, 64, 0, 0, -96, 64, 0, 0, -64, 64 }; + float[][] want = { { 1f, 2f, 3f }, { 4f, 5f, 6f } }; + float[][] got = ByteStringUtil.decodeVectorMulti(ByteString.copyFrom(bytes)); + assertArrayEquals(want, got); + } + + @Test + public void test_decodeUuid() { + byte[] bytes = { 38, 19, -74, 24, -114, -19, 73, 43, -112, -60, 47, 96, 83, -89, -35, -23 }; + String want = "2613b618-8eed-492b-90c4-2f6053a7dde9"; + String got = ByteStringUtil.decodeUuid(ByteString.copyFrom(bytes)).toString(); + assertEquals(want, got); + } + + @Test + public void test_decodeVector_1d_empty() { + byte[] bytes = new byte[0]; + float[] got = ByteStringUtil.decodeVectorSingle(ByteString.copyFrom(bytes)); + assertEquals(0, got.length); + } + + @Test + public void test_decodeVector_2d_empty() { + byte[] bytes = new byte[0]; + float[][] got = ByteStringUtil.decodeVectorMulti(ByteString.copyFrom(bytes)); + assertEquals(0, got.length); + } + + @Test + public void test_decodeVector_2d_dim_zero() { + byte[] bytes = new byte[] { 0, 0 }; + float[][] got = ByteStringUtil.decodeVectorMulti(ByteString.copyFrom(bytes)); + assertEquals(0, got.length); + } + + @Test(expected = IllegalArgumentException.class) + public void test_decodeVector_1d_illegal() { + byte[] bytes = new byte[Float.BYTES - 1]; // must be a multiple of Float.BYTES + ByteStringUtil.decodeVectorSingle(ByteString.copyFrom(bytes)); + } + + @Test(expected = IllegalArgumentException.class) + public void test_decodeVector_2d_illegal() { + // The first Short.BYTES is the dimensionality of each array. + // The size of the rest must be a multiple of Float.BYTES * dimensionality. + var dimensionality = 5; + byte[] bytes = new byte[Short.BYTES + (Float.BYTES * dimensionality - 1)]; + bytes[0] = 0; + bytes[1] = (byte) dimensionality; + + ByteStringUtil.decodeVectorMulti(ByteString.copyFrom(bytes)); + } +} diff --git a/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java b/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java deleted file mode 100644 index 1bc5d76a4..000000000 --- a/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java +++ /dev/null @@ -1,57 +0,0 @@ -package io.weaviate.client6.v1.internal.grpc; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; - -import org.junit.Test; - -import com.google.protobuf.ByteString; - -/** - * Note: Java's {@code byte} is signed (int8) and is different from {@code byte} - * in Go, which is an alias for uint8. - * - * For this tests purposes the distinction is immaterial, as "want" arrays - * are "golden values" meant to be a readable respresentation for the test. - */ -public class GRPCTest { - @Test - public void test_encodeVector_1d() { - Float[] vector = { 1f, 2f, 3f }; - byte[] want = { 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64 }; - byte[] got = ByteStringUtil.encodeVectorSingle(vector).toByteArray(); - assertArrayEquals(want, got); - } - - @Test - public void test_decodeVector_1d() { - byte[] bytes = { 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64 }; - Float[] want = { 1f, 2f, 3f }; - Float[] got = ByteStringUtil.decodeVectorSingle(ByteString.copyFrom(bytes)); - assertArrayEquals(want, got); - } - - @Test - public void test_encodeVector_2d() { - Float[][] vector = { { 1f, 2f, 3f }, { 4f, 5f, 6f } }; - byte[] want = { 3, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, -128, 64, 0, 0, -96, 64, 0, 0, -64, 64 }; - byte[] got = ByteStringUtil.encodeVectorMulti(vector).toByteArray(); - assertArrayEquals(want, got); - } - - @Test - public void test_decodeVector_2d() { - byte[] bytes = { 3, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, -128, 64, 0, 0, -96, 64, 0, 0, -64, 64 }; - Float[][] want = { { 1f, 2f, 3f }, { 4f, 5f, 6f } }; - Float[][] got = ByteStringUtil.decodeVectorMulti(ByteString.copyFrom(bytes)); - assertArrayEquals(want, got); - } - - @Test - public void test_decodeUuid() { - byte[] bytes = { 38, 19, -74, 24, -114, -19, 73, 43, -112, -60, 47, 96, 83, -89, -35, -23 }; - String want = "2613b618-8eed-492b-90c4-2f6053a7dde9"; - String got = ByteStringUtil.decodeUuid(ByteString.copyFrom(bytes)).toString(); - assertEquals(want, got); - } -} diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index 82a205ae9..b931db440 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -188,33 +188,33 @@ public static Object[][] testCases() { // Vectors.CustomTypeAdapterFactory { Vectors.class, - Vectors.of(new Float[] { 1f, 2f }), + Vectors.of(new float[] { 1f, 2f }), "{\"default\": [1.0, 2.0]}", (CustomAssert) JSONTest::compareVectors, }, { Vectors.class, - Vectors.of(new Float[][] { { 1f, 2f }, { 3f, 4f } }), + Vectors.of(new float[][] { { 1f, 2f }, { 3f, 4f } }), "{\"default\": [[1.0, 2.0], [3.0, 4.0]]}", (CustomAssert) JSONTest::compareVectors, }, { Vectors.class, - Vectors.of("custom", new Float[] { 1f, 2f }), + Vectors.of("custom", new float[] { 1f, 2f }), "{\"custom\": [1.0, 2.0]}", (CustomAssert) JSONTest::compareVectors, }, { Vectors.class, - Vectors.of("custom", new Float[][] { { 1f, 2f }, { 3f, 4f } }), + Vectors.of("custom", new float[][] { { 1f, 2f }, { 3f, 4f } }), "{\"custom\": [[1.0, 2.0], [3.0, 4.0]]}", (CustomAssert) JSONTest::compareVectors, }, { Vectors.class, Vectors.of(named -> named - .vector("1d", new Float[] { 1f, 2f }) - .vector("2d", new Float[][] { { 1f, 2f }, { 3f, 4f } })), + .vector("1d", new float[] { 1f, 2f }) + .vector("2d", new float[][] { { 1f, 2f }, { 3f, 4f } })), "{\"1d\": [1.0, 2.0], \"2d\": [[1.0, 2.0], [3.0, 4.0]]}", (CustomAssert) JSONTest::compareVectors, }, @@ -382,13 +382,13 @@ private static void assertEqualJson(String want, String got) { /** * Custom assert function that uses deep array equality - * to correctly compare Float[] and Float[][] nested in the object. + * to correctly compare float[] and float[][] nested in the object. */ private static void compareVectors(Object got, Object want) { Assertions.assertThat(got) .usingRecursiveComparison() - .withEqualsForType(Arrays::equals, Float[].class) - .withEqualsForType(Arrays::deepEquals, Float[][].class) + .withEqualsForType(Arrays::equals, float[].class) + .withEqualsForType(Arrays::deepEquals, float[][].class) .isEqualTo(want); }