diff --git a/src/it/java/io/weaviate/containers/Container.java b/src/it/java/io/weaviate/containers/Container.java index 2c4218a9f..ee8859a18 100644 --- a/src/it/java/io/weaviate/containers/Container.java +++ b/src/it/java/io/weaviate/containers/Container.java @@ -10,7 +10,7 @@ import org.testcontainers.containers.Network; import org.testcontainers.lifecycle.Startable; -import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.api.WeaviateClient; import lombok.RequiredArgsConstructor; public class Container { diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java index d6251c028..c70342fd1 100644 --- a/src/it/java/io/weaviate/containers/Weaviate.java +++ b/src/it/java/io/weaviate/containers/Weaviate.java @@ -9,8 +9,8 @@ import org.testcontainers.weaviate.WeaviateContainer; -import io.weaviate.client6.Config; -import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.api.Config; +import io.weaviate.client6.v1.api.WeaviateClient; public class Weaviate extends WeaviateContainer { private WeaviateClient clientInstance; diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java index 29bc7b43c..7c7597c0c 100644 --- a/src/it/java/io/weaviate/integration/AggregationITest.java +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -12,7 +12,8 @@ import org.junit.Test; import io.weaviate.ConcurrentTest; -import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.aggregate.AggregateResponseGroup; import io.weaviate.client6.v1.api.collections.aggregate.AggregateResponseGrouped; @@ -20,9 +21,8 @@ import io.weaviate.client6.v1.api.collections.aggregate.GroupBy; import io.weaviate.client6.v1.api.collections.aggregate.GroupedBy; import io.weaviate.client6.v1.api.collections.aggregate.IntegerAggregation; -import io.weaviate.client6.v1.collections.Property; -import io.weaviate.client6.v1.collections.VectorIndex; -import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; +import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.containers.Container; public class AggregationITest extends ConcurrentTest { @@ -36,7 +36,7 @@ public static void beforeAll() throws IOException { .properties( Property.text("category"), Property.integer("price")) - .vectors(io.weaviate.client6.v1.collections.Vectors.of(new VectorIndex<>(Vectorizer.none())))); + .vector(Hnsw.of(new NoneVectorizer()))); var things = client.collections.use(COLLECTION); for (var category : List.of("Shoes", "Hat", "Jacket")) { diff --git a/src/it/java/io/weaviate/integration/CollectionsITest.java b/src/it/java/io/weaviate/integration/CollectionsITest.java index 59083c44b..0b5b94fbe 100644 --- a/src/it/java/io/weaviate/integration/CollectionsITest.java +++ b/src/it/java/io/weaviate/integration/CollectionsITest.java @@ -7,15 +7,12 @@ import org.junit.Test; import io.weaviate.ConcurrentTest; -import io.weaviate.client6.WeaviateClient; -import io.weaviate.client6.v1.collections.Collection; -import io.weaviate.client6.v1.collections.NoneVectorizer; -import io.weaviate.client6.v1.collections.Property; -import io.weaviate.client6.v1.collections.VectorIndex; -import io.weaviate.client6.v1.collections.VectorIndex.IndexType; -import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; -import io.weaviate.client6.v1.collections.Vectorizer; -import io.weaviate.client6.v1.collections.Vectors; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.api.collections.WeaviateCollection; +import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; +import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.containers.Container; public class CollectionsITest extends ConcurrentTest { @@ -27,18 +24,19 @@ public void testCreateGetDelete() throws IOException { client.collections.create(collectionName, col -> col .properties(Property.text("username"), Property.integer("age")) - .vector(new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); + .vector(Hnsw.of(new NoneVectorizer()))); var thingsCollection = client.collections.getConfig(collectionName); Assertions.assertThat(thingsCollection).get() .hasFieldOrPropertyWithValue("name", collectionName) - .extracting(Collection::vectors).extracting(Vectors::getDefault) - .as("default vector").satisfies(defaultVector -> { + .extracting(WeaviateCollection::vectors, InstanceOfAssertFactories.map(String.class, VectorIndex.class)) + .as("default vector").extractingByKey("default") + .satisfies(defaultVector -> { Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) .as("has none vectorizer").isInstanceOf(NoneVectorizer.class); - Assertions.assertThat(defaultVector).extracting(VectorIndex::configuration) - .as("has hnsw index").returns(IndexType.HNSW, IndexingStrategy::type); + Assertions.assertThat(defaultVector).extracting(VectorIndex::config) + .isInstanceOf(Hnsw.class); }); client.collections.delete(collectionName); diff --git a/src/it/java/io/weaviate/integration/DataITest.java b/src/it/java/io/weaviate/integration/DataITest.java index 5d2325200..1f29861aa 100644 --- a/src/it/java/io/weaviate/integration/DataITest.java +++ b/src/it/java/io/weaviate/integration/DataITest.java @@ -9,16 +9,15 @@ import org.junit.Test; import io.weaviate.ConcurrentTest; -import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.Vectors; -import io.weaviate.client6.v1.collections.Property; -import io.weaviate.client6.v1.collections.VectorIndex; -import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; -import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; +import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; import io.weaviate.containers.Container; public class DataITest extends ConcurrentTest { - private static WeaviateClient client = Container.WEAVIATE.getClient(); private static final String COLLECTION = unique("Artists"); private static final String VECTOR_INDEX = "bring_your_own"; @@ -34,9 +33,10 @@ public void testCreateGetDelete() throws IOException { var id = randomUUID(); Float[] vector = { 1f, 2f, 3f }; - artists.data.insert(Map.of("name", "john doe"), metadata -> metadata - .id(id) - .vectors(Vectors.of(VECTOR_INDEX, vector))); + artists.data.insert(Map.of("name", "john doe"), + metadata -> metadata + .uuid(id) + .vectors(Vectors.of(VECTOR_INDEX, vector))); var object = artists.query.byId(id, query -> query .returnProperties("name") @@ -45,11 +45,10 @@ public void testCreateGetDelete() throws IOException { Assertions.assertThat(object) .as("object exists after insert").get() .satisfies(obj -> { - Assertions.assertThat(obj.metadata().id()) + Assertions.assertThat(obj.metadata().uuid()) .as("object id").isEqualTo(id); - Assertions.assertThat(obj.metadata().vectors()).extracting(Vectors::getSingle) - .asInstanceOf(InstanceOfAssertFactories.OPTIONAL).as("has single vector").get() + Assertions.assertThat(obj.metadata().vectors()).extracting(v -> v.getSingle(VECTOR_INDEX)) .asInstanceOf(InstanceOfAssertFactories.array(Float[].class)).containsExactly(vector); Assertions.assertThat(obj.properties()) @@ -77,11 +76,11 @@ public void testBlobData() throws IOException { "breed", "ragdoll", "img", ragdollPng)); - var got = cats.query.byId(ragdoll.metadata().id(), + var got = cats.query.byId(ragdoll.metadata().uuid(), cat -> cat.returnProperties("img")); Assertions.assertThat(got).get() - .extracting(io.weaviate.client6.v1.api.collections.WeaviateObject::properties, InstanceOfAssertFactories.MAP) + .extracting(WeaviateObject::properties, InstanceOfAssertFactories.MAP) .extractingByKey("img").isEqualTo(ragdollPng); } @@ -99,6 +98,6 @@ private static void createTestCollections() throws IOException { Property.integer("age")) .references( Property.reference("hasAwards", awardsGrammy, awardsOscar)) - .vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); + .vectors(named -> named.vector(VECTOR_INDEX, Hnsw.of(new NoneVectorizer())))); } } diff --git a/src/it/java/io/weaviate/integration/ReferencesITest.java b/src/it/java/io/weaviate/integration/ReferencesITest.java index 7feca750b..519319e64 100644 --- a/src/it/java/io/weaviate/integration/ReferencesITest.java +++ b/src/it/java/io/weaviate/integration/ReferencesITest.java @@ -1,6 +1,7 @@ package io.weaviate.integration; import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -9,15 +10,14 @@ import org.junit.Test; import io.weaviate.ConcurrentTest; -import io.weaviate.client6.WeaviateClient; -import io.weaviate.client6.v1.api.collections.ObjectReference; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.ObjectMetadata; +import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.ReferenceProperty; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.Reference; import io.weaviate.client6.v1.api.collections.query.MetadataField; -import io.weaviate.client6.v1.api.collections.query.QueryMetadata; import io.weaviate.client6.v1.api.collections.query.QueryReference; -import io.weaviate.client6.v1.collections.Property; -import io.weaviate.client6.v1.collections.Reference; -import io.weaviate.client6.v1.collections.ReferenceProperty; import io.weaviate.containers.Container; /** @@ -75,7 +75,7 @@ public void testReferences() throws IOException { Map.of("name", "Alex"), opt -> opt .reference("hasAwards", Reference.uuids( - grammy_1.metadata().id(), oscar_1.metadata().id())) + grammy_1.metadata().uuid(), oscar_1.metadata().uuid())) .reference("hasAwards", Reference.objects(grammy_2, oscar_2))); // Act: add one more reference @@ -92,7 +92,7 @@ public void testReferences() throws IOException { .extracting(ReferenceProperty::dataTypes, InstanceOfAssertFactories.list(String.class)) .containsOnly(nsMovies); - var gotAlex = artists.query.byId(alex.metadata().id(), + var gotAlex = artists.query.byId(alex.metadata().uuid(), opt -> opt.returnReferences( QueryReference.multi("hasAwards", nsOscar, ref -> ref.returnMetadata(MetadataField.ID)), @@ -101,10 +101,13 @@ public void testReferences() throws IOException { Assertions.assertThat(gotAlex).get() .as("Artists: fetch by id including hasAwards references") - .extracting(WeaviateObject::references, InstanceOfAssertFactories.map(String.class, ObjectReference.class)) + + // Cast references to Map> + .extracting(WeaviateObject::references, InstanceOfAssertFactories.map(String.class, List.class)) .as("hasAwards object reference").extractingByKey("hasAwards") - .extracting(ObjectReference::objects, InstanceOfAssertFactories.list(WeaviateObject.class)) - .extracting(object -> ((QueryMetadata) object.metadata()).id()) + .asInstanceOf(InstanceOfAssertFactories.list(WeaviateObject.class)) + + .extracting(object -> ((ObjectMetadata) object.metadata()).uuid()) .containsOnly( // INVESTIGATE: When references to 2+ collections are requested, // seems to Weaviate only return references to the first one in the list. @@ -112,7 +115,7 @@ public void testReferences() throws IOException { // so the latter will not be in the response. // // grammy_1.metadata().id(), grammy_2.metadata().id(), - oscar_1.metadata().id(), oscar_2.metadata().id()); + oscar_1.metadata().uuid(), oscar_2.metadata().uuid()); } @Test @@ -155,7 +158,7 @@ public void testNestedReferences() throws IOException { .reference("hasAwards", Reference.objects(grammy_1))); // Assert: fetch nested references - var gotAlex = artists.query.byId(alex.metadata().id(), + var gotAlex = artists.query.byId(alex.metadata().uuid(), opt -> opt.returnReferences( QueryReference.single("hasAwards", ref -> ref @@ -167,15 +170,20 @@ public void testNestedReferences() throws IOException { Assertions.assertThat(gotAlex).get() .as("Artists: fetch by id including nested references") - .extracting(WeaviateObject::references, InstanceOfAssertFactories.map(String.class, ObjectReference.class)) + + // Cast references to Map> + .extracting(WeaviateObject::references, InstanceOfAssertFactories.map(String.class, List.class)) .as("hasAwards object reference").extractingByKey("hasAwards") - .extracting(ObjectReference::objects, InstanceOfAssertFactories.list(WeaviateObject.class)) + .asInstanceOf(InstanceOfAssertFactories.list(WeaviateObject.class)) + .hasSize(1).allSatisfy(award -> Assertions.assertThat(award) - .returns(grammy_1.metadata().id(), grammy -> ((QueryMetadata) grammy.metadata()).id()) - .extracting(WeaviateObject::references, - InstanceOfAssertFactories.map(String.class, ObjectReference.class)) - .extractingByKey("presentedBy") - .extracting(ObjectReference::objects, InstanceOfAssertFactories.list(WeaviateObject.class)) + .returns(grammy_1.metadata().uuid(), grammy -> ((ObjectMetadata) grammy.metadata()).uuid()) + + // Cast references to Map> + .extracting(WeaviateObject::references, InstanceOfAssertFactories.map(String.class, List.class)) + .as("presentedBy object reference").extractingByKey("presentedBy") + .asInstanceOf(InstanceOfAssertFactories.list(WeaviateObject.class)) + .hasSize(1).extracting(WeaviateObject::properties) .allSatisfy(properties -> Assertions.assertThat(properties) .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java index a7f8d6970..17679a5bc 100644 --- a/src/it/java/io/weaviate/integration/SearchITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -15,17 +15,18 @@ import org.junit.rules.TestRule; import io.weaviate.ConcurrentTest; -import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.api.WeaviateClient; +import io.weaviate.client6.v1.api.collections.Property; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.Reference; import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.MetadataField; import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup; -import io.weaviate.client6.v1.collections.Property; -import io.weaviate.client6.v1.collections.Reference; -import io.weaviate.client6.v1.collections.VectorIndex; -import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; -import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; +import io.weaviate.client6.v1.api.collections.vectorizers.Img2VecNeuralVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecContextionaryVectorizer; import io.weaviate.containers.Container; import io.weaviate.containers.Container.ContainerGroup; import io.weaviate.containers.Contextionary; @@ -111,10 +112,10 @@ private static Map populateTest(int n) throws IOException { var object = things.data.insert( Map.of("category", CATEGORIES.get(i % CATEGORIES.size())), metadata -> metadata - .id(randomUUID()) + .uuid(randomUUID()) .vectors(Vectors.of(VECTOR_INDEX, vector))); - created.put(object.metadata().id(), vector); + created.put(object.metadata().uuid(), vector); } return created; @@ -128,7 +129,7 @@ private static Map populateTest(int n) throws IOException { private static void createTestCollection() throws IOException { client.collections.create(COLLECTION, cfg -> cfg .properties(Property.text("category")) - .vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); + .vector(VECTOR_INDEX, Hnsw.of(new NoneVectorizer()))); } @Test @@ -137,7 +138,7 @@ public void testNearText() throws IOException { client.collections.create(nsSongs, col -> col .properties(Property.text("title")) - .vector(new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.text2vecContextionary()))); + .vector(Hnsw.of(Text2VecContextionaryVectorizer.of()))); var songs = client.collections.use(nsSongs); var submarine = songs.data.insert(Map.of("title", "Yellow Submarine")); @@ -148,7 +149,7 @@ public void testNearText() throws IOException { opt -> opt .distance(0.5f) .moveTo(.98f, to -> to.concepts("tropical")) - .moveAway(.4f, away -> away.uuids(submarine.metadata().id())) + .moveAway(.4f, away -> away.uuids(submarine.metadata().uuid())) .returnProperties("title")); Assertions.assertThat(result.objects()).hasSize(2) @@ -159,7 +160,7 @@ public void testNearText() throws IOException { @Test public void testNearText_groupBy() throws IOException { - var vectorIndex = new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.text2vecContextionary()); + var vectorIndex = Hnsw.of(Text2VecContextionaryVectorizer.of()); var nsArtists = ns("Artists"); client.collections.create(nsArtists, @@ -190,8 +191,8 @@ public void testNearText_groupBy() throws IOException { Assertions.assertThat(result.groups()).hasSize(2) .containsOnlyKeys( - "weaviate://localhost/%s/%s".formatted(nsArtists, beatles.metadata().id()), - "weaviate://localhost/%s/%s".formatted(nsArtists, ccr.metadata().id())); + "weaviate://localhost/%s/%s".formatted(nsArtists, beatles.metadata().uuid()), + "weaviate://localhost/%s/%s".formatted(nsArtists, ccr.metadata().uuid())); } @Test @@ -203,9 +204,8 @@ public void testNearImage() throws IOException { .properties( Property.text("breed"), Property.blob("img")) - .vector(new VectorIndex<>( - IndexingStrategy.hnsw(), - Vectorizer.img2VecNeuralVectorizer( + .vector(Hnsw.of( + Img2VecNeuralVectorizer.of( i2v -> i2v.imageFields("img"))))); var cats = client.collections.use(nsCats); diff --git a/src/main/java/io/weaviate/client6/Config.java b/src/main/java/io/weaviate/client6/Config.java deleted file mode 100644 index e749aa8c4..000000000 --- a/src/main/java/io/weaviate/client6/Config.java +++ /dev/null @@ -1,50 +0,0 @@ -package io.weaviate.client6; - -import java.util.Collection; -import java.util.Collections; -import java.util.Map.Entry; - -import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; - -public class Config implements GrpcChannelOptions { - private final String version = "v1"; - private final String scheme; - private final String httpHost; - private final String grpcHost; - private final Collection> headers = Collections.emptyList(); - - public Config(String scheme, String httpHost, String grpcHost) { - this.scheme = scheme; - this.httpHost = httpHost; - this.grpcHost = grpcHost; - } - - public String baseUrl() { - return scheme + "://" + httpHost + "/" + version; - } - - public String grpcAddress() { - if (grpcHost.contains(":")) { - return grpcHost; - } - // FIXME: use secure port (433) if scheme == https - return String.format("%s:80", grpcHost); - } - - // GrpcChannelOptions ------------------------------------------------------- - - @Override - public String host() { - return grpcAddress(); - } - - @Override - public Collection> headers() { - return headers; - } - - @Override - public boolean useTls() { - return scheme.equals("https"); - } -} diff --git a/src/main/java/io/weaviate/client6/WeaviateClient.java b/src/main/java/io/weaviate/client6/WeaviateClient.java deleted file mode 100644 index 484546a60..000000000 --- a/src/main/java/io/weaviate/client6/WeaviateClient.java +++ /dev/null @@ -1,36 +0,0 @@ -package io.weaviate.client6; - -import java.io.Closeable; -import java.io.IOException; - -import io.weaviate.client6.internal.GrpcClient; -import io.weaviate.client6.internal.HttpClient; -import io.weaviate.client6.v1.collections.CollectionsClient; -import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport; -import io.weaviate.client6.v1.internal.grpc.GrpcTransport; - -public class WeaviateClient implements Closeable { - private final HttpClient http; - private final GrpcClient grpc; - - public final CollectionsClient collections; - - private final GrpcTransport grpcTransport; - - public WeaviateClient(Config config) { - this.http = new HttpClient(); - this.grpc = new GrpcClient(config); - - this.grpcTransport = new DefaultGrpcTransport(config); - - this.collections = new CollectionsClient(config, http, grpc, grpcTransport); - } - - @Override - public void close() throws IOException { - this.http.close(); - this.grpc.close(); - - this.grpcTransport.close(); - } -} diff --git a/src/main/java/io/weaviate/client6/internal/DtoTypeAdapterFactory.java b/src/main/java/io/weaviate/client6/internal/DtoTypeAdapterFactory.java deleted file mode 100644 index 812f6dddd..000000000 --- a/src/main/java/io/weaviate/client6/internal/DtoTypeAdapterFactory.java +++ /dev/null @@ -1,108 +0,0 @@ -package io.weaviate.client6.internal; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import com.google.gson.Gson; -import com.google.gson.TypeAdapter; -import com.google.gson.TypeAdapterFactory; -import com.google.gson.reflect.TypeToken; -import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonWriter; - -/** - * DtoTypeAdapterFactory de-/serializes objects using their registerred DTOs. - * - *

- * DTO classes must implement {@link Dto}, which produces the original model. - * Meanwhile, models do not need to be modified, to avoid leaking - * de-/serialization details. - * - *

- * Usage: - * - *

{@code
- * public class HttpHanlder {
- *   static {
- *     DtoTypeAdapterFactory.register(
- *         MyDomainObject.class,
- *         MyDtoObject.class,
- *         domain -> new MyDtoObject(domain));
- *   }
- *   static final Gson gson = new GsonBuilder()
- *       .registerTypeAdapterFactory(new DtoTypeAdapterFactory())
- *       .create();
- * }
- * }
- */ -public class DtoTypeAdapterFactory implements TypeAdapterFactory { - private static boolean locked = false; - private static final Map, Pair> registry = new HashMap<>(); - - /** - * Register a model-DTO pair. - * - *

- * Only one DTO can be registerred per model. - * Subsequent registrations will be ignored. - */ - public static > void register(Class model, Class dto, - ModelConverter> convert) { - registry.putIfAbsent(model, new Pair(dto, convert)); - } - - /** - * Get model-DTO pair for the provided model class. Returns null if no pair is - * registerred. In this case {@link #create} should also return null. - * - *

- * Conversion to {@code Pair} is safe, as entries to {@link #registry} - * can only be added via {@link #register}, which is type-safe. - */ - @SuppressWarnings("unchecked") - private static > Pair getPair(TypeToken type) { - var cls = type.getRawType(); - if (!registry.containsKey(cls)) { - return null; - } - return (Pair) registry.get(cls); - } - - /** Dto produces a domain model. */ - public interface Dto { - M toModel(); - } - - /** ModelConverter converts domain model to a DTO. */ - @FunctionalInterface - public interface ModelConverter> { - D toDTO(M model); - } - - record Pair>(Class dto, ModelConverter> convert) { - } - - @Override - public TypeAdapter create(Gson gson, TypeToken type) { - var pair = getPair(type); - if (pair == null) { - return null; - } - var delegate = gson.getDelegateAdapter(this, TypeToken.get(pair.dto)); - return new TypeAdapter() { - - @Override - public T read(JsonReader in) throws IOException { - var dto = delegate.read(in); - return dto.toModel(); - } - - @Override - public void write(JsonWriter out, T value) throws IOException { - var dto = pair.convert.toDTO(value); - delegate.write(out, dto); - } - }; - } -} diff --git a/src/main/java/io/weaviate/client6/internal/GrpcClient.java b/src/main/java/io/weaviate/client6/internal/GrpcClient.java deleted file mode 100644 index 8d36f84f5..000000000 --- a/src/main/java/io/weaviate/client6/internal/GrpcClient.java +++ /dev/null @@ -1,37 +0,0 @@ -package io.weaviate.client6.internal; - -import java.io.Closeable; -import java.io.IOException; - -import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; -import io.grpc.stub.MetadataUtils; -import io.weaviate.client6.Config; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; - -public class GrpcClient implements Closeable { - private final ManagedChannel channel; - public final WeaviateBlockingStub grpc; - - public GrpcClient(Config config) { - this.channel = buildChannel(config); - this.grpc = buildStub(channel); - } - - @Override - public void close() throws IOException { - channel.shutdown(); - } - - private static ManagedChannel buildChannel(Config config) { - ManagedChannelBuilder channelBuilder = ManagedChannelBuilder.forTarget(config.grpcAddress()); - channelBuilder.usePlaintext(); - return channelBuilder.build(); - } - - private static WeaviateBlockingStub buildStub(ManagedChannel channel) { - return WeaviateGrpc.newBlockingStub(channel) - .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(new io.grpc.Metadata())); - } -} diff --git a/src/main/java/io/weaviate/client6/internal/HttpClient.java b/src/main/java/io/weaviate/client6/internal/HttpClient.java deleted file mode 100644 index 1d1122b37..000000000 --- a/src/main/java/io/weaviate/client6/internal/HttpClient.java +++ /dev/null @@ -1,23 +0,0 @@ -package io.weaviate.client6.internal; - -import java.io.Closeable; -import java.io.IOException; - -import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.HttpClients; - -public class HttpClient implements Closeable { - // TODO: move somewhere - // public static final Gson GSON = - - public final CloseableHttpClient http; - - public HttpClient() { - http = HttpClients.createDefault(); - } - - @Override - public void close() throws IOException { - http.close(); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/api/Config.java b/src/main/java/io/weaviate/client6/v1/api/Config.java new file mode 100644 index 000000000..2e7d9391d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/Config.java @@ -0,0 +1,68 @@ +package io.weaviate.client6.v1.api; + +import java.util.Collections; +import java.util.Map; + +import io.weaviate.client6.v1.internal.grpc.GrpcChannelOptions; +import io.weaviate.client6.v1.internal.rest.TransportOptions; + +public class Config { + private final String version = "v1"; + private final String scheme; + private final String httpHost; + private final String grpcHost; + private final Map headers = Collections.emptyMap(); + + public Config(String scheme, String httpHost, String grpcHost) { + this.scheme = scheme; + this.httpHost = httpHost; + this.grpcHost = grpcHost; + } + + public String baseUrl() { + return scheme + "://" + httpHost + "/" + version; + } + + public String grpcAddress() { + if (grpcHost.contains(":")) { + return grpcHost; + } + // FIXME: use secure port (433) if scheme == https + return String.format("%s:80", grpcHost); + } + + public TransportOptions rest() { + return new TransportOptions() { + + @Override + public String host() { + return baseUrl(); + } + + @Override + public Map headers() { + return headers; + } + + }; + } + + public GrpcChannelOptions grpc() { + return new GrpcChannelOptions() { + @Override + public String host() { + return grpcAddress(); + } + + @Override + public boolean useTls() { + return scheme.equals("https"); + } + + @Override + public Map headers() { + return headers; + } + }; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java new file mode 100644 index 000000000..f2ceeff24 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java @@ -0,0 +1,38 @@ +package io.weaviate.client6.v1.api; + +import java.io.Closeable; +import java.io.IOException; + +import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient; +import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class WeaviateClient implements Closeable { + /** Store this for {@link #async()} helper. */ + private final Config config; + + private final RestTransport restTransport; + private final GrpcTransport grpcTransport; + + public final WeaviateCollectionsClient collections; + + public WeaviateClient(Config config) { + this.config = config; + this.restTransport = new DefaultRestTransport(config.rest()); + this.grpcTransport = new DefaultGrpcTransport(config.grpc()); + + this.collections = new WeaviateCollectionsClient(restTransport, grpcTransport); + } + + public WeaviateClientAsync async() { + return new WeaviateClientAsync(config); + } + + @Override + public void close() throws IOException { + this.restTransport.close(); + this.grpcTransport.close(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java new file mode 100644 index 000000000..a33927292 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/WeaviateClientAsync.java @@ -0,0 +1,30 @@ +package io.weaviate.client6.v1.api; + +import java.io.Closeable; +import java.io.IOException; + +import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClientAsync; +import io.weaviate.client6.v1.internal.grpc.DefaultGrpcTransport; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.rest.DefaultRestTransport; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class WeaviateClientAsync implements Closeable { + private final RestTransport restTransport; + private final GrpcTransport grpcTransport; + + public final WeaviateCollectionsClientAsync collections; + + public WeaviateClientAsync(Config config) { + this.restTransport = new DefaultRestTransport(config.rest()); + this.grpcTransport = new DefaultGrpcTransport(config.grpc()); + + this.collections = new WeaviateCollectionsClientAsync(restTransport, grpcTransport); + } + + @Override + public void close() throws IOException { + this.restTransport.close(); + this.grpcTransport.close(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java new file mode 100644 index 000000000..47f569cc7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java @@ -0,0 +1,27 @@ +package io.weaviate.client6.v1.api.collections; + +import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClient; +import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClient; +import io.weaviate.client6.v1.api.collections.data.WeaviateDataClient; +import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class CollectionHandle { + public final WeaviateConfigClient config; + public final WeaviateDataClient data; + public final WeaviateQueryClient query; + public final WeaviateAggregateClient aggregate; + + public CollectionHandle( + RestTransport restTransport, + GrpcTransport grpcTransport, + CollectionDescriptor collectionDescriptor) { + + this.config = new WeaviateConfigClient(collectionDescriptor, restTransport, grpcTransport); + this.data = new WeaviateDataClient<>(collectionDescriptor, restTransport); + this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport); + this.aggregate = new WeaviateAggregateClient(collectionDescriptor, grpcTransport); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java new file mode 100644 index 000000000..95a3096c6 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java @@ -0,0 +1,27 @@ +package io.weaviate.client6.v1.api.collections; + +import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClientAsync; +import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClientAsync; +import io.weaviate.client6.v1.api.collections.data.WeaviateDataClientAsync; +import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClientAsync; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class CollectionHandleAsync { + public final WeaviateConfigClientAsync config; + public final WeaviateDataClientAsync data; + public final WeaviateQueryClientAsync query; + public final WeaviateAggregateClientAsync aggregate; + + public CollectionHandleAsync( + RestTransport restTransport, + GrpcTransport grpcTransport, + CollectionDescriptor collectionDescriptor) { + + this.config = new WeaviateConfigClientAsync(collectionDescriptor, restTransport, grpcTransport); + this.data = new WeaviateDataClientAsync<>(collectionDescriptor, restTransport); + this.query = new WeaviateQueryClientAsync<>(collectionDescriptor, grpcTransport); + this.aggregate = new WeaviateAggregateClientAsync(collectionDescriptor, grpcTransport); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CreateCollectionRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/CreateCollectionRequest.java new file mode 100644 index 000000000..25fe319ef --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CreateCollectionRequest.java @@ -0,0 +1,18 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.Collections; + +import org.apache.hc.core5.http.HttpStatus; + +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record CreateCollectionRequest(WeaviateCollection collection) { + public static final Endpoint _ENDPOINT = Endpoint.of( + request -> "POST", + request -> "/schema/", + (gson, request) -> JSON.serialize(request.collection), + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_SUCCESS, + (gson, response) -> JSON.deserialize(response, WeaviateCollection.class)); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/DataType.java b/src/main/java/io/weaviate/client6/v1/api/collections/DataType.java new file mode 100644 index 000000000..c114f0ab5 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/DataType.java @@ -0,0 +1,13 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.Set; + +import com.google.common.collect.ImmutableSet; + +public interface DataType { + public static final String TEXT = "text"; + public static final String INT = "int"; + public static final String BLOB = "blob"; + + public static final Set KNOWN_TYPES = ImmutableSet.of(TEXT, INT, BLOB); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/DeleteCollectionRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/DeleteCollectionRequest.java new file mode 100644 index 000000000..e49b52317 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/DeleteCollectionRequest.java @@ -0,0 +1,17 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.Collections; + +import org.apache.hc.core5.http.HttpStatus; + +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record DeleteCollectionRequest(String collectionName) { + public static final Endpoint _ENDPOINT = Endpoint.of( + request -> "DELETE", + request -> "/schema/" + request.collectionName, + (gson, request) -> null, + request -> Collections.emptyMap(), + status -> status != HttpStatus.SC_SUCCESS, + (gson, resopnse) -> null); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/GetConfigRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/GetConfigRequest.java new file mode 100644 index 000000000..2027428ce --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/GetConfigRequest.java @@ -0,0 +1,19 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.Collections; +import java.util.Optional; + +import org.apache.hc.core5.http.HttpStatus; + +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record GetConfigRequest(String collectionName) { + public static final Endpoint> _ENDPOINT = Endpoint.of( + request -> "GET", + request -> "/schema/" + request.collectionName, + (gson, request) -> null, + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_SUCCESS, + (gson, response) -> Optional.ofNullable(JSON.deserialize(response, WeaviateCollection.class))); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java index 4fc2d97cf..d0aa0a815 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectMetadata.java @@ -2,11 +2,18 @@ import java.util.function.Function; +import com.google.gson.annotations.SerializedName; + import io.weaviate.client6.v1.internal.ObjectBuilder; -public record ObjectMetadata(String id, Vectors vectors) { +public record ObjectMetadata( + @SerializedName("id") String uuid, + @SerializedName("vectors") Vectors vectors, + @SerializedName("creationTimeUnix") Long createdAt, + @SerializedName("lastUpdateTImeUnix") Long lastUpdatedAt) implements WeaviateMetadata { + public ObjectMetadata(Builder builder) { - this(builder.id, builder.vectors); + this(builder.id, builder.vectors, null, null); } public static ObjectMetadata of(Function> fn) { diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java index 6cc3395f9..b7f5f9128 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/ObjectReference.java @@ -2,5 +2,8 @@ import java.util.List; -public record ObjectReference(List> objects) { +import io.weaviate.client6.v1.api.collections.query.QueryMetadata; + +public record ObjectReference( + List, QueryMetadata>> objects) { } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Property.java b/src/main/java/io/weaviate/client6/v1/api/collections/Property.java similarity index 67% rename from src/main/java/io/weaviate/client6/v1/collections/Property.java rename to src/main/java/io/weaviate/client6/v1/api/collections/Property.java index bb9293abb..fb1e636b7 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/Property.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Property.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.collections; +package io.weaviate.client6.v1.api.collections; import java.util.Arrays; import java.util.List; @@ -9,19 +9,23 @@ public record Property( @SerializedName("name") String name, @SerializedName("dataType") List dataTypes) { + public Property(String name, String dataType) { + this(name, List.of(dataType)); + } + /** Add text property with default configuration. */ public static Property text(String name) { - return new Property(name, AtomicDataType.TEXT); + return new Property(name, DataType.TEXT); } /** Add integer property with default configuration. */ public static Property integer(String name) { - return new Property(name, AtomicDataType.INT); + return new Property(name, DataType.INT); } /** Add blob property with default configuration. */ public static Property blob(String name) { - return new Property(name, AtomicDataType.BLOB); + return new Property(name, DataType.BLOB); } public static ReferenceProperty reference(String name, String... collections) { @@ -31,13 +35,4 @@ public static ReferenceProperty reference(String name, String... collections) { public static ReferenceProperty reference(String name, List collections) { return new ReferenceProperty(name, collections); } - - public boolean isReference() { - return dataTypes.stream().noneMatch(t -> AtomicDataType.isAtomic(t)); - } - - private Property(String name, AtomicDataType type) { - this(name, List.of(type.name().toLowerCase())); - } - } diff --git a/src/main/java/io/weaviate/client6/v1/collections/ReferenceProperty.java b/src/main/java/io/weaviate/client6/v1/api/collections/ReferenceProperty.java similarity index 62% rename from src/main/java/io/weaviate/client6/v1/collections/ReferenceProperty.java rename to src/main/java/io/weaviate/client6/v1/api/collections/ReferenceProperty.java index 2b0c5b55a..3ed9b5aed 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/ReferenceProperty.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/ReferenceProperty.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.collections; +package io.weaviate.client6.v1.api.collections; import java.util.List; @@ -7,4 +7,8 @@ public record ReferenceProperty( @SerializedName("name") String name, @SerializedName("dataType") List dataTypes) { + + public Property toProperty() { + return new Property(name, dataTypes); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java b/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java new file mode 100644 index 000000000..05f535ad9 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java @@ -0,0 +1,126 @@ +package io.weaviate.client6.v1.api.collections; + +import java.io.IOException; +import java.util.EnumMap; +import java.util.Map; + +import com.google.gson.Gson; +import com.google.gson.JsonParser; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.internal.Streams; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.api.collections.vectorindex.Flat; +import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; +import io.weaviate.client6.v1.internal.json.JsonEnum; + +public interface VectorIndex { + static final String DEFAULT_VECTOR_NAME = "default"; + + public enum Kind implements JsonEnum { + HNSW("hnsw"), + FLAT("flat"), + DYNAMIC("dynamic"); + + private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); + private final String jsonValue; + + private Kind(String jsonValue) { + this.jsonValue = jsonValue; + } + + @Override + public String jsonValue() { + return this.jsonValue; + } + + public static Kind valueOfJson(String jsonValue) { + return JsonEnum.valueOfJson(jsonValue, jsonValueMap, Kind.class); + } + } + + VectorIndex.Kind _kind(); + + default String type() { + return _kind().jsonValue(); + } + + Vectorizer vectorizer(); + + Object config(); + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + private static final EnumMap> readAdapters = new EnumMap<>( + VectorIndex.Kind.class); + + private final void addAdapter(Gson gson, VectorIndex.Kind kind, Class cls) { + readAdapters.put(kind, (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(cls))); + } + + private final void init(Gson gson) { + addAdapter(gson, VectorIndex.Kind.HNSW, Hnsw.class); + addAdapter(gson, VectorIndex.Kind.FLAT, Flat.class); + } + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + var rawType = type.getRawType(); + if (!VectorIndex.class.isAssignableFrom(rawType)) { + return null; + } + + if (readAdapters.isEmpty()) { + init(gson); + } + + final var vectorizerAdapter = gson.getDelegateAdapter(this, TypeToken.get(Vectorizer.class)); + final var writeAdapter = gson.getDelegateAdapter(this, TypeToken.get(rawType)); + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, VectorIndex value) throws IOException { + out.beginObject(); + out.name("vectorIndexType"); + out.value(value._kind().jsonValue()); + + var config = writeAdapter.toJsonTree((T) value.config()); + config.getAsJsonObject().remove("vectorizer"); + out.name("vectorIndexConfig"); + Streams.write(config, out); + + out.name("vectorizer"); + vectorizerAdapter.write(out, value.vectorizer()); + out.endObject(); + } + + @Override + public VectorIndex read(JsonReader in) throws IOException { + var jsonObject = JsonParser.parseReader(in).getAsJsonObject(); + + VectorIndex.Kind kind; + var kindString = jsonObject.get("vectorIndexType").getAsString(); + try { + kind = VectorIndex.Kind.valueOfJson(kindString); + } catch (IllegalArgumentException e) { + return null; + } + + var adapter = readAdapters.get(kind); + if (adapter == null) { + return null; + } + + var config = jsonObject.get("vectorIndexConfig").getAsJsonObject(); + config.add("vectorizer", jsonObject.get("vectorizer")); + return adapter.fromJsonTree(config); + } + }.nullSafe(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java new file mode 100644 index 000000000..7752cfb07 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Vectorizer.java @@ -0,0 +1,113 @@ +package io.weaviate.client6.v1.api.collections; + +import java.io.IOException; +import java.util.EnumMap; +import java.util.Map; + +import com.google.gson.Gson; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonToken; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.api.collections.vectorizers.Img2VecNeuralVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Multi2VecClipVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecContextionaryVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecWeaviateVectorizer; +import io.weaviate.client6.v1.internal.json.JsonEnum; + +public interface Vectorizer { + public enum Kind implements JsonEnum { + NONE("none"), + IMG2VEC_NEURAL("img2vec-neural"), + TEXT2VEC_CONTEXTIONARY("text2vec-contextionary"), + TEXT2VEC_WEAVIATE("text2vec-weaviate"), + MULTI2VEC_CLIP("multi2vec-clip"); + + private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); + private final String jsonValue; + + private Kind(String jsonValue) { + this.jsonValue = jsonValue; + } + + @Override + public String jsonValue() { + return this.jsonValue; + } + + public static Kind valueOfJson(String jsonValue) { + return JsonEnum.valueOfJson(jsonValue, jsonValueMap, Kind.class); + } + } + + Kind _kind(); + + Object _self(); + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + private static final EnumMap> readAdapters = new EnumMap<>( + Vectorizer.Kind.class); + + private final void addAdapter(Gson gson, Vectorizer.Kind kind, Class cls) { + readAdapters.put(kind, (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(cls))); + } + + private final void init(Gson gson) { + addAdapter(gson, Vectorizer.Kind.NONE, NoneVectorizer.class); + addAdapter(gson, Vectorizer.Kind.IMG2VEC_NEURAL, Img2VecNeuralVectorizer.class); + addAdapter(gson, Vectorizer.Kind.MULTI2VEC_CLIP, Multi2VecClipVectorizer.class); + addAdapter(gson, Vectorizer.Kind.TEXT2VEC_WEAVIATE, Text2VecWeaviateVectorizer.class); + addAdapter(gson, Vectorizer.Kind.TEXT2VEC_CONTEXTIONARY, Text2VecContextionaryVectorizer.class); + } + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + final var rawType = type.getRawType(); + if (!Vectorizer.class.isAssignableFrom(rawType)) { + return null; + } + + if (readAdapters.isEmpty()) { + init(gson); + } + + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, Vectorizer value) throws IOException { + var writeAdapter = readAdapters.get(value._kind()); + + out.beginObject(); + out.name(value._kind().jsonValue()); + ((TypeAdapter) writeAdapter).write(out, (T) value._self()); + out.endObject(); + } + + @Override + public Vectorizer read(JsonReader in) throws IOException { + in.beginObject(); + var vectorizerName = in.nextName(); + try { + var kind = Vectorizer.Kind.valueOfJson(vectorizerName); + var adapter = readAdapters.get(kind); + return adapter.read(in); + } catch (IllegalArgumentException e) { + return null; + } finally { + if (in.peek() == JsonToken.BEGIN_OBJECT) { + in.beginObject(); + } + in.endObject(); + } + } + }.nullSafe(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java b/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java index dcc4850d9..cfd647ce3 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Vectors.java @@ -1,60 +1,76 @@ package io.weaviate.client6.v1.api.collections; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.Optional; import java.util.function.Function; +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + import io.weaviate.client6.v1.internal.ObjectBuilder; import lombok.ToString; /** - * Vectors is an abstraction over named vectors. - * It may contain both 1-dimensional and 2-dimensional vectors. + * Vectors is an abstraction over named vectors, which can store + * both 1-dimensional and 2-dimensional vectors. */ @ToString public class Vectors { - // TODO: define this in collection.config.Vectors - private static final String DEFAULT = "default"; - - private final Float[] unnamedVector; private final Map namedVectors; - /** - * Pass legacy unnamed vector. - * Multi-vectors can only be passed as named vectors. - */ - public static Vectors unnamed(Float[] vector) { - return new Vectors(vector); + public static Vectors of(Float[] vector) { + return new Vectors(VectorIndex.DEFAULT_VECTOR_NAME, vector); } - public static Vectors of(Float[] vector) { - return new Vectors(DEFAULT, vector); + public static Vectors of(String name, Float[] vector) { + return new Vectors(name, vector); } public static Vectors of(Float[][] vector) { - return new Vectors(DEFAULT, vector); + return new Vectors(VectorIndex.DEFAULT_VECTOR_NAME, vector); } - public static Vectors of(String name, Float[] vector) { + public static Vectors of(String name, Float[][] vector) { return new Vectors(name, vector); } - public static Vectors of(String name, Float[][] vector) { - return new Vectors(name, vector); + public static Vectors of(Function> fn) { + return fn.apply(new Builder()).build(); } - public static Vectors of(Map vectors) { - return new Vectors(vectors, null); + public Vectors(Builder builder) { + this.namedVectors = builder.namedVectors; } - public static Vectors of(Function> fn) { - return fn.apply(new Builder()).build(); + /* + * Create a single named vector. + * Intended to be used by factory methods, which can statically restrict + * vector's type to {@code Float[]} and {@code Float[][]}. + * + * @param name Vector name. + * + * @param vector {@code Float[]} or {@code Float[][]} vector. + * + */ + private Vectors(String name, Object vector) { + this.namedVectors = Collections.singletonMap(name, vector); + } + + private Vectors(Map namedVectors) { + this.namedVectors = namedVectors; } - public static class Builder { - private Map namedVectors = new HashMap<>(); + public static class Builder implements ObjectBuilder { + private final Map namedVectors = new HashMap<>(); public Builder vector(String name, Float[] vector) { this.namedVectors.put(name, vector); @@ -66,8 +82,9 @@ public Builder vector(String name, Float[][] vector) { return this; } + @Override public Vectors build() { - return new Vectors(this.namedVectors, null); + return new Vectors(this); } } @@ -76,12 +93,7 @@ public Float[] getSingle(String name) { } public Float[] getDefaultSingle() { - return getSingle(DEFAULT); - } - - @SuppressWarnings("unchecked") - public Optional getSingle() { - return (Optional) getOnly(); + return getSingle(VectorIndex.DEFAULT_VECTOR_NAME); } public Float[][] getMulti(String name) { @@ -89,43 +101,51 @@ public Float[][] getMulti(String name) { } public Float[][] getDefaultMulti() { - return getMulti(DEFAULT); - } - - @SuppressWarnings("unchecked") - public Optional getMulti() { - return (Optional) getOnly(); - } - - public Optional getUnnamed() { - return Optional.ofNullable(unnamedVector); - } - - private Optional getOnly() { - if (namedVectors == null || namedVectors.isEmpty() || namedVectors.size() > 1) { - return Optional.empty(); + return getMulti(VectorIndex.DEFAULT_VECTOR_NAME); + } + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + if (type.getRawType() != Vectors.class) { + return null; + } + final var mapAdapter = gson.getDelegateAdapter(this, new TypeToken>() { + }); + final var float_1d = gson.getDelegateAdapter(this, TypeToken.get(Float[].class)); + final var float_2d = gson.getDelegateAdapter(this, TypeToken.get(Float[][].class)); + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, Vectors value) throws IOException { + mapAdapter.write(out, value.namedVectors); + } + + @Override + public Vectors read(JsonReader in) throws IOException { + var vectorsMap = JsonParser.parseReader(in).getAsJsonObject().asMap(); + var namedVectors = new HashMap(); + + for (var entry : vectorsMap.entrySet()) { + String vectorName = entry.getKey(); + JsonElement el = entry.getValue(); + if (el.isJsonArray()) { + JsonArray array = el.getAsJsonArray(); + Object vector; + if (array.size() > 0 && array.get(0).isJsonArray()) { + vector = float_2d.fromJsonTree(array); + } else { + vector = float_1d.fromJsonTree(array); + } + namedVectors.put(vectorName, vector); + } + } + return new Vectors(namedVectors); + } + }.nullSafe(); } - return Optional.ofNullable(namedVectors.values().iterator().next()); - } - - public Map getNamed() { - return Map.copyOf(namedVectors); - } - - private Vectors(Map named) { - this(named, null); - } - - private Vectors(Float[] unnamed) { - this(Collections.emptyMap(), unnamed); - } - - private Vectors(String name, Object vector) { - this(Collections.singletonMap(name, vector)); - } - - private Vectors(Map named, Float[] unnamed) { - this.namedVectors = Map.copyOf(named); - this.unnamedVector = unnamed; } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollection.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollection.java new file mode 100644 index 000000000..6dc8fa306 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollection.java @@ -0,0 +1,167 @@ +package io.weaviate.client6.v1.api.collections; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonParser; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.annotations.SerializedName; +import com.google.gson.internal.Streams; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record WeaviateCollection( + @SerializedName("class") String name, + @SerializedName("description") String description, + @SerializedName("properties") List properties, + List references, + @SerializedName("vectorConfig") Map vectors) { + + public static WeaviateCollection of(String collectionName) { + return of(collectionName, ObjectBuilder.identity()); + } + + public static WeaviateCollection of(String collectionName, Function> fn) { + return fn.apply(new Builder(collectionName)).build(); + } + + public WeaviateCollection(Builder builder) { + this( + builder.collectionName, + builder.description, + builder.properties, + builder.references, + builder.vectors); + } + + public static class Builder implements ObjectBuilder { + // Required parameters; + private final String collectionName; + + private String description; + private List properties = new ArrayList<>(); + private List references = new ArrayList<>(); + private Map vectors = new HashMap<>(); + + public Builder(String collectionName) { + this.collectionName = collectionName; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder properties(Property... properties) { + return properties(Arrays.asList(properties)); + } + + public Builder properties(List properties) { + this.properties = properties; + return this; + } + + public Builder references(ReferenceProperty... references) { + return references(Arrays.asList(references)); + } + + public Builder references(List references) { + this.references = references; + return this; + } + + public Builder vector(VectorIndex vector) { + this.vectors.put(VectorIndex.DEFAULT_VECTOR_NAME, vector); + return this; + } + + public Builder vector(String name, VectorIndex vector) { + this.vectors.put(name, vector); + return this; + } + + public Builder vectors(Function>> fn) { + this.vectors = fn.apply(new VectorsBuilder()).build(); + return this; + } + + public static class VectorsBuilder implements ObjectBuilder> { + private Map vectors = new HashMap<>(); + + public VectorsBuilder vector(String name, VectorIndex vector) { + vectors.put(name, vector); + return this; + } + + @Override + public Map build() { + return this.vectors; + } + } + + @Override + public WeaviateCollection build() { + return new WeaviateCollection(this); + } + } + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + if (type.getRawType() != WeaviateCollection.class) { + return null; + } + + final var delegate = gson.getDelegateAdapter(this, (TypeToken) type); + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, WeaviateCollection value) throws IOException { + var jsonObject = delegate.toJsonTree(value).getAsJsonObject(); + + var references = jsonObject.remove("references").getAsJsonArray(); + var properties = jsonObject.get("properties").getAsJsonArray(); + properties.addAll(references); + + Streams.write(jsonObject, out); + } + + @Override + public WeaviateCollection read(JsonReader in) throws IOException { + var jsonObject = JsonParser.parseReader(in).getAsJsonObject(); + + var mixedProperties = jsonObject.get("properties").getAsJsonArray(); + var references = new JsonArray(); + var properties = new JsonArray(); + + for (var property : mixedProperties) { + var dataTypes = property.getAsJsonObject().get("dataType").getAsJsonArray(); + if (dataTypes.size() == 1 && DataType.KNOWN_TYPES.contains(dataTypes.get(0).getAsString())) { + properties.add(property); + } else { + references.add(property); + } + } + + jsonObject.add("properties", properties); + jsonObject.add("references", references); + return delegate.fromJsonTree(jsonObject); + } + }.nullSafe(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java new file mode 100644 index 000000000..0937572bf --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClient.java @@ -0,0 +1,47 @@ +package io.weaviate.client6.v1.api.collections; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class WeaviateCollectionsClient { + private final RestTransport restTransport; + private final GrpcTransport grpcTransport; + + public WeaviateCollectionsClient(RestTransport restTransport, GrpcTransport grpcTransport) { + this.restTransport = restTransport; + this.grpcTransport = grpcTransport; + } + + public CollectionHandle> use(String collectionName) { + return new CollectionHandle<>(restTransport, grpcTransport, CollectionDescriptor.ofMap(collectionName)); + } + + public WeaviateCollection create(String name) throws IOException { + return create(WeaviateCollection.of(name)); + } + + public WeaviateCollection create(String name, + Function> fn) throws IOException { + return create(WeaviateCollection.of(name, fn)); + } + + public WeaviateCollection create(WeaviateCollection collection) throws IOException { + return this.restTransport.performRequest(new CreateCollectionRequest(collection), + CreateCollectionRequest._ENDPOINT); + } + + public Optional getConfig(String name) throws IOException { + return this.restTransport.performRequest(new GetConfigRequest(name), GetConfigRequest._ENDPOINT); + } + + public void delete(String name) throws IOException { + this.restTransport.performRequest(new DeleteCollectionRequest(name), DeleteCollectionRequest._ENDPOINT); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java new file mode 100644 index 000000000..14f900484 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateCollectionsClientAsync.java @@ -0,0 +1,48 @@ +package io.weaviate.client6.v1.api.collections; + +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class WeaviateCollectionsClientAsync { + private final RestTransport restTransport; + private final GrpcTransport grpcTransport; + + public WeaviateCollectionsClientAsync(RestTransport restTransport, GrpcTransport grpcTransport) { + this.restTransport = restTransport; + this.grpcTransport = grpcTransport; + } + + public CollectionHandle> use(String collectionName) { + return new CollectionHandle<>(restTransport, grpcTransport, + CollectionDescriptor.ofMap(collectionName)); + } + + public CompletableFuture create(String name) { + return create(WeaviateCollection.of(name)); + } + + public CompletableFuture create(String name, + Function> fn) { + return create(WeaviateCollection.of(name, fn)); + } + + public CompletableFuture create(WeaviateCollection collection) { + return this.restTransport.performRequestAsync(new CreateCollectionRequest(collection), + CreateCollectionRequest._ENDPOINT); + } + + public CompletableFuture> getConfig(String name) { + return this.restTransport.performRequestAsync(new GetConfigRequest(name), GetConfigRequest._ENDPOINT); + } + + public CompletableFuture delete(String name) { + return this.restTransport.performRequestAsync(new DeleteCollectionRequest(name), DeleteCollectionRequest._ENDPOINT); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateMetadata.java new file mode 100644 index 000000000..9121e33c3 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateMetadata.java @@ -0,0 +1,7 @@ +package io.weaviate.client6.v1.api.collections; + +public interface WeaviateMetadata { + String uuid(); + + Vectors vectors(); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java index 7c58e7cc9..84367b67e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/WeaviateObject.java @@ -1,10 +1,177 @@ package io.weaviate.client6.v1.api.collections; +import java.io.IOException; +import java.lang.reflect.ParameterizedType; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; -public record WeaviateObject( +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.google.gson.TypeAdapter; +import com.google.gson.TypeAdapterFactory; +import com.google.gson.internal.Streams; +import com.google.gson.reflect.TypeToken; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record WeaviateObject( String collection, - T properties, - Map references, + P properties, + Map> references, M metadata) { + + public WeaviateObject(Builder builder) { + this(builder.collection, builder.properties, builder.references, builder.metadata); + } + + public static class Builder implements ObjectBuilder> { + private String collection; + private P properties; + private Map> references = new HashMap<>(); + private M metadata; + + public final Builder collection(String collection) { + this.collection = collection; + return this; + } + + public final Builder properties(P properties) { + this.properties = properties; + return this; + } + + /** + * Add a reference. Calls to {@link #reference} can be chained + * to add multiple references. + */ + @SafeVarargs + public final Builder reference(String property, R... references) { + for (var ref : references) { + addReference(property, ref); + } + return this; + } + + private final void addReference(String property, R reference) { + if (!references.containsKey(property)) { + references.put(property, new ArrayList<>()); + } + references.get(property).add(reference); + } + + public Builder references(Map> references) { + this.references = references; + return this; + } + + public Builder metadata(M metadata) { + this.metadata = metadata; + return this; + } + + @Override + public WeaviateObject build() { + return new WeaviateObject<>(this); + } + } + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken typeToken) { + var type = typeToken.getType(); + var rawType = typeToken.getRawType(); + if (rawType != WeaviateObject.class || + !(type instanceof ParameterizedType parameterized) + || parameterized.getActualTypeArguments().length < 3) { + return null; + } + + var typeParams = parameterized.getActualTypeArguments(); + final var propertiesType = typeParams[0]; + final var referencesType = typeParams[1]; + final var metadataType = typeParams[2]; + + final var propertiesAdapter = gson.getAdapter(TypeToken.get(propertiesType)); + final var metadataAdapter = gson.getAdapter(TypeToken.get(metadataType)); + final var referencesAdapter = gson.getAdapter(TypeToken.get(referencesType)); + + return (TypeAdapter) new TypeAdapter>() { + + @Override + public void write(JsonWriter out, WeaviateObject value) throws IOException { + out.beginObject(); + + out.name("class"); + out.value(value.collection()); + + out.name("properties"); + if (value.references().isEmpty()) { + ((TypeAdapter) propertiesAdapter).write(out, value.properties()); + } else { + var properties = ((TypeAdapter) propertiesAdapter).toJsonTree(value.properties()).getAsJsonObject(); + for (var refEntry : value.references().entrySet()) { + var beacons = new JsonArray(); + for (var reference : (List) refEntry.getValue()) { + var beacon = ((TypeAdapter) referencesAdapter).toJsonTree(reference); + beacons.add(beacon); + } + properties.add(refEntry.getKey(), beacons); + } + Streams.write(properties, out); + } + + // Flatten out metadata fields. + var metadata = ((TypeAdapter) metadataAdapter).toJsonTree(value.metadata); + for (var entry : metadata.getAsJsonObject().entrySet()) { + out.name(entry.getKey()); + Streams.write(entry.getValue(), out); + } + out.endObject(); + } + + @Override + public WeaviateObject read(JsonReader in) throws IOException { + var builder = new WeaviateObject.Builder<>(); + var metadata = new ObjectMetadata.Builder(); + + var object = JsonParser.parseReader(in).getAsJsonObject(); + builder.collection(object.get("class").getAsString()); + + var jsonProperties = object.get("properties").getAsJsonObject(); + var trueProperties = new JsonObject(); + for (var property : jsonProperties.entrySet()) { + var value = property.getValue(); + if (!value.isJsonArray()) { + trueProperties.add(property.getKey(), value); + continue; + } + var array = value.getAsJsonArray(); + var first = array.get(0); + if (first.isJsonObject() && first.getAsJsonObject().has("beacon")) { + for (var el : array) { + var beacon = ((TypeAdapter) referencesAdapter).fromJsonTree(el); + builder.reference(property.getKey(), beacon); + } + } + } + + builder.properties(propertiesAdapter.fromJsonTree(trueProperties)); + + metadata.id(object.get("id").getAsString()); + builder.metadata(metadata.build()); + + return builder.build(); + } + }; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java index dff37670e..c63c9ff9f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AbstractAggregateClient.java @@ -10,7 +10,7 @@ import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; -public abstract class AbstractAggregateClient { +abstract class AbstractAggregateClient { protected final CollectionDescriptor collection; protected final GrpcTransport transport; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java index a95a359ea..1cea32828 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java @@ -1,5 +1,6 @@ package io.weaviate.client6.v1.api.collections.aggregate; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -36,7 +37,7 @@ public Builder(ObjectFilter objectFilter) { this.objectFilter = objectFilter; } - private List metrics; + private List metrics = new ArrayList<>(); private Integer objectLimit; private boolean includeTotalCount = false; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/AddPropertyRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/AddPropertyRequest.java new file mode 100644 index 000000000..ec670a13c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/AddPropertyRequest.java @@ -0,0 +1,19 @@ +package io.weaviate.client6.v1.api.collections.config; + +import java.util.Collections; + +import org.apache.hc.core5.http.HttpStatus; + +import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record AddPropertyRequest(String collectionName, Property property) { + public static final Endpoint _ENDPOINT = Endpoint.of( + request -> "POST", + request -> "/schema/" + request.collectionName + "/properties", + (gson, request) -> JSON.serialize(request.property), + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_SUCCESS, + (gson, response) -> null); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java new file mode 100644 index 000000000..0fa3cb860 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClient.java @@ -0,0 +1,38 @@ +package io.weaviate.client6.v1.api.collections.config; + +import java.io.IOException; +import java.util.Optional; + +import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.WeaviateCollection; +import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClient; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class WeaviateConfigClient { + private final RestTransport transport; + private final WeaviateCollectionsClient collectionsClient; + + protected final CollectionDescriptor collection; + + public WeaviateConfigClient(CollectionDescriptor collection, RestTransport restTransport, + GrpcTransport grpcTransport) { + this.transport = restTransport; + this.collectionsClient = new WeaviateCollectionsClient(restTransport, grpcTransport); + + this.collection = collection; + } + + public Optional get() throws IOException { + return collectionsClient.getConfig(collection.name()); + } + + public void addProperty(Property property) throws IOException { + this.transport.performRequest(new AddPropertyRequest(collection.name(), property), AddPropertyRequest._ENDPOINT); + } + + public void addReference(String name, String... dataTypes) throws IOException { + this.addProperty(Property.reference(name, dataTypes).toProperty()); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java new file mode 100644 index 000000000..54e586a2b --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/config/WeaviateConfigClientAsync.java @@ -0,0 +1,40 @@ +package io.weaviate.client6.v1.api.collections.config; + +import java.io.IOException; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.WeaviateCollection; +import io.weaviate.client6.v1.api.collections.WeaviateCollectionsClientAsync; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class WeaviateConfigClientAsync { + private final RestTransport transport; + private final WeaviateCollectionsClientAsync collectionsClient; + + protected final CollectionDescriptor collection; + + public WeaviateConfigClientAsync(CollectionDescriptor collection, RestTransport restTransport, + GrpcTransport grpcTransport) { + this.transport = restTransport; + this.collectionsClient = new WeaviateCollectionsClientAsync(restTransport, grpcTransport); + + this.collection = collection; + } + + public CompletableFuture> get() throws IOException { + return collectionsClient.getConfig(collection.name()); + } + + public CompletableFuture addProperty(Property property) throws IOException { + return this.transport.performRequestAsync(new AddPropertyRequest(collection.name(), property), + AddPropertyRequest._ENDPOINT); + } + + public CompletableFuture addReference(String name, String... dataTypes) throws IOException { + return this.addProperty(Property.reference(name, dataTypes).toProperty()); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteObjectRequest.java new file mode 100644 index 000000000..a425fc2a8 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/DeleteObjectRequest.java @@ -0,0 +1,18 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.util.Collections; + +import org.apache.hc.core5.http.HttpStatus; + +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record DeleteObjectRequest(String collectionName, String uuid) { + + public static final Endpoint _ENDPOINT = Endpoint.of( + request -> "DELETE", + request -> "/objects/" + request.collectionName + "/" + request.uuid, + (gson, request) -> null, + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_NO_CONTENT, + (gson, response) -> null); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java new file mode 100644 index 000000000..6654c9c8d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectRequest.java @@ -0,0 +1,77 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.util.Collections; +import java.util.function.Function; + +import org.apache.hc.core5.http.HttpStatus; + +import com.google.gson.reflect.TypeToken; + +import io.weaviate.client6.v1.api.collections.ObjectMetadata; +import io.weaviate.client6.v1.api.collections.Vectors; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.json.JSON; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.Endpoint; + +public record InsertObjectRequest(WeaviateObject object) { + + @SuppressWarnings("unchecked") + public static final Endpoint, WeaviateObject> endpoint( + CollectionDescriptor descriptor) { + return Endpoint.of( + request -> "POST", + request -> "/objects/", + (gson, request) -> JSON.serialize(request.object, TypeToken.getParameterized( + WeaviateObject.class, descriptor.typeToken().getType(), Reference.class, ObjectMetadata.class)), + request -> Collections.emptyMap(), + code -> code != HttpStatus.SC_SUCCESS, + (gson, response) -> JSON.deserialize(response, + (TypeToken>) TypeToken.getParameterized( + WeaviateObject.class, descriptor.typeToken().getType(), Object.class, ObjectMetadata.class))); + } + + public static InsertObjectRequest of(String collectionName, T properties) { + return of(collectionName, properties, ObjectBuilder.identity()); + } + + public static InsertObjectRequest of(String collectionName, T properties, + Function, ObjectBuilder>> fn) { + return fn.apply(new Builder(collectionName, properties)).build(); + } + + public InsertObjectRequest(Builder builder) { + this(builder.object.build()); + } + + public static class Builder implements ObjectBuilder> { + private final WeaviateObject.Builder object = new WeaviateObject.Builder<>(); + private final ObjectMetadata.Builder metadata = new ObjectMetadata.Builder(); + + public Builder(String collectionName, T properties) { + this.object.collection(collectionName).properties(properties); + } + + public Builder uuid(String uuid) { + this.metadata.id(uuid); + return this; + } + + public Builder vectors(Vectors vectors) { + this.metadata.vectors(vectors); + return this; + } + + public Builder reference(String property, Reference... references) { + this.object.reference(property, references); + return this; + } + + @Override + public InsertObjectRequest build() { + this.object.metadata(this.metadata.build()); + return new InsertObjectRequest<>(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectResponse.java new file mode 100644 index 000000000..c3eb95f2f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/InsertObjectResponse.java @@ -0,0 +1,14 @@ +package io.weaviate.client6.v1.api.collections.data; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Vectors; + +public record InsertObjectResponse( + @SerializedName("class") String collectionName, + @SerializedName("properties") T properties, + @SerializedName("id") String uuid, + @SerializedName("vectors") Vectors vectors, + @SerializedName("creationTimeUnix") Long createdAt, + @SerializedName("lastUpdateTimeUnix") Long lastUpdatedAt) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/Reference.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/Reference.java new file mode 100644 index 000000000..5cf0c6b68 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/Reference.java @@ -0,0 +1,84 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import com.google.gson.TypeAdapter; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.api.collections.WeaviateObject; + +public record Reference(String collection, List uuids) { + + public Reference(String collection, String uuid) { + this(collection, List.of(uuid)); + } + + /** + * Create reference to objects by their UUIDs. + *

+ * Weaviate will search each of the existing collections to identify + * the objects before inserting the references, so this may include + * some performance overhead. + */ + public static Reference uuids(String... uuids) { + return new Reference(null, Arrays.asList(uuids)); + } + + /** Create references to {@link WeaviateObject}. */ + public static Reference[] objects(WeaviateObject... objects) { + return Arrays.stream(objects) + .map(o -> new Reference(o.collection(), o.metadata().uuid())) + .toArray(Reference[]::new); + } + + /** Create references to objects in a collection by their UUIDs. */ + public static Reference collection(String collection, String... uuids) { + return new Reference(collection, Arrays.asList(uuids)); + } + + public static final TypeAdapter TYPE_ADAPTER = new TypeAdapter() { + + @Override + public void write(JsonWriter out, Reference value) throws IOException { + for (var uuid : value.uuids()) { + out.beginObject(); + out.name("beacon"); + + var beacon = "weaviate://localhost"; + if (value.collection() != null) { + beacon += "/" + value.collection(); + } + beacon += "/" + uuid; + + out.value(beacon); + out.endObject(); + } + } + + @Override + public Reference read(JsonReader in) throws IOException { + String collection = null; + String id = null; + + in.beginObject(); + in.nextName(); // expect "beacon"? + var beacon = in.nextString(); + in.endObject(); + + beacon = beacon.replaceFirst("weaviate://localhost/", ""); + if (beacon.contains("/")) { + var parts = beacon.split("/"); + collection = parts[0]; + id = parts[1]; + } else { + id = beacon; + } + + return new Reference(collection, id); + } + + }.nullSafe(); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java new file mode 100644 index 000000000..b4466b72d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClient.java @@ -0,0 +1,39 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.io.IOException; +import java.util.function.Function; + +import io.weaviate.client6.v1.api.collections.ObjectMetadata; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class WeaviateDataClient { + private final RestTransport restTransport; + private final CollectionDescriptor collectionDescriptor; + + public WeaviateDataClient(CollectionDescriptor collectionDescriptor, RestTransport restTransport) { + this.restTransport = restTransport; + this.collectionDescriptor = collectionDescriptor; + } + + public WeaviateObject insert(T properties) throws IOException { + return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties)); + } + + public WeaviateObject insert(T properties, + Function, ObjectBuilder>> fn) + throws IOException { + return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties, fn)); + } + + public WeaviateObject insert(InsertObjectRequest request) throws IOException { + return this.restTransport.performRequest(request, InsertObjectRequest.endpoint(collectionDescriptor)); + } + + public void delete(String uuid) throws IOException { + this.restTransport.performRequest(new DeleteObjectRequest(collectionDescriptor.name(), uuid), + DeleteObjectRequest._ENDPOINT); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java new file mode 100644 index 000000000..e0ac5ed74 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/data/WeaviateDataClientAsync.java @@ -0,0 +1,41 @@ +package io.weaviate.client6.v1.api.collections.data; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import io.weaviate.client6.v1.api.collections.ObjectMetadata; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.rest.RestTransport; + +public class WeaviateDataClientAsync { + private final RestTransport restTransport; + private final CollectionDescriptor collectionDescriptor; + + public WeaviateDataClientAsync(CollectionDescriptor collectionDescriptor, RestTransport restTransport) { + this.restTransport = restTransport; + this.collectionDescriptor = collectionDescriptor; + } + + public CompletableFuture> insert(T properties) throws IOException { + return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties)); + } + + public CompletableFuture> insert(T properties, + Function, ObjectBuilder>> fn) + throws IOException { + return insert(InsertObjectRequest.of(collectionDescriptor.name(), properties, fn)); + } + + public CompletableFuture> insert(InsertObjectRequest request) + throws IOException { + return this.restTransport.performRequestAsync(request, InsertObjectRequest.endpoint(collectionDescriptor)); + } + + public CompletableFuture delete(String uuid) { + return this.restTransport.performRequestAsync(new DeleteObjectRequest(collectionDescriptor.name(), uuid), + DeleteObjectRequest._ENDPOINT); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java index 38a78ea5f..abdca72a3 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java @@ -8,7 +8,7 @@ import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; -public abstract class AbstractQueryClient { +abstract class AbstractQueryClient { protected final CollectionDescriptor collection; protected final GrpcTransport transport; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java index d77cb4fe8..0a3b27c2a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java @@ -2,9 +2,9 @@ import java.util.function.Function; -import io.weaviate.client6.internal.GRPC; import io.weaviate.client6.v1.api.collections.aggregate.ObjectFilter; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.GRPC; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java index 54980688e..59cdee22a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryMetadata.java @@ -1,9 +1,10 @@ package io.weaviate.client6.v1.api.collections.query; import io.weaviate.client6.v1.api.collections.Vectors; +import io.weaviate.client6.v1.api.collections.WeaviateMetadata; import io.weaviate.client6.v1.internal.ObjectBuilder; -public record QueryMetadata(String id, Float distance, Float certainty, Vectors vectors) { +public record QueryMetadata(String uuid, Float distance, Float certainty, Vectors vectors) implements WeaviateMetadata { private QueryMetadata(Builder builder) { this(builder.uuid, builder.distance, builder.certainty, builder.vectors); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java index 9fdaed176..1000d6321 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java @@ -7,7 +7,8 @@ public record QueryObjectGrouped( QueryMetadata metadata, String belongsToGroup) { - QueryObjectGrouped(WeaviateObject object, String belongsToGroup) { + QueryObjectGrouped(WeaviateObject object, + String belongsToGroup) { this(object.properties(), object.metadata(), belongsToGroup); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryReference.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryReference.java index e3c3a3537..fdc8a1c01 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryReference.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryReference.java @@ -3,7 +3,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.function.Consumer; import java.util.function.Function; import io.weaviate.client6.v1.internal.ObjectBuilder; @@ -47,14 +46,6 @@ public static QueryReference multi(String property, String collection, return fn.apply(new Builder(collection, property)).build(); } - public static QueryReference[] multi(String property, Consumer fn, String... collections) { - return Arrays.stream(collections).map(collection -> { - var builder = new Builder(collection, property); - fn.accept(builder); - return new QueryReference(builder); - }).toArray(QueryReference[]::new); - } - public static class Builder implements ObjectBuilder { private final String property; private final String collection; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index 4b9dd3ea8..33c271cf5 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -9,10 +9,10 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import io.weaviate.client6.internal.GRPC; -import io.weaviate.client6.v1.api.collections.ObjectReference; +import io.weaviate.client6.v1.api.collections.ObjectMetadata; import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.internal.grpc.GRPC; import io.weaviate.client6.v1.internal.grpc.Rpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; @@ -39,11 +39,11 @@ static Rpc { - List> objects = reply.getResultsList().stream() + var objects = reply + .getResultsList() + .stream() .map(obj -> QueryRequest.unmarshalResultObject( - obj.getProperties(), - obj.getMetadata(), - collection)) + obj.getProperties(), obj.getMetadata(), collection)) .toList(); return new QueryResponse<>(objects); }, @@ -82,7 +82,20 @@ static Rpc rpc.method(), () -> rpc.methodAsync()); } - private static WeaviateObject unmarshalResultObject( + private static WeaviateObject unmarshalResultObject( + WeaviateProtoSearchGet.PropertiesResult propertiesResult, + WeaviateProtoSearchGet.MetadataResult metadataResult, + CollectionDescriptor descriptor) { + var res = unmarshalReferences(propertiesResult, metadataResult, descriptor); + var metadata = new QueryMetadata.Builder() + .id(res.metadata().uuid()) + .distance(metadataResult.getDistance()) + .certainty(metadataResult.getCertainty()) + .vectors(res.metadata().vectors()); + return new WeaviateObject<>(descriptor.name(), res.properties(), res.references(), metadata.build()); + } + + private static WeaviateObject unmarshalReferences( WeaviateProtoSearchGet.PropertiesResult propertiesResult, WeaviateProtoSearchGet.MetadataResult metadataResult, CollectionDescriptor descriptor) { @@ -97,22 +110,31 @@ private static WeaviateObject unmarshalResultObject( // I.e. { "ref": A-1 } , { "ref": B-1 } => { "ref": [A-1, B-1] } var referenceProperties = propertiesResult.getRefPropsList() .stream().reduce( - new HashMap(), + new HashMap>(), (map, ref) -> { var refObjects = ref.getPropertiesList().stream() - .map(property -> unmarshalResultObject(property, property.getMetadata(), descriptor)) + .map(property -> { + var reference = unmarshalReferences( + property, property.getMetadata(), + // TODO: this should be possible to configure for ODM? + CollectionDescriptor.ofMap(property.getTargetCollection())); + return (Object) new WeaviateObject<>( + reference.collection(), + (Object) reference.properties(), + reference.references(), + reference.metadata()); + }) .toList(); // Merge ObjectReferences by joining the underlying WeaviateObjects. map.merge( ref.getPropName(), - // TODO: check if this works - new ObjectReference((List>) refObjects), + refObjects, (left, right) -> { var joined = Stream.concat( - left.objects().stream(), - right.objects().stream()).toList(); - return new ObjectReference(joined); + left.stream(), + right.stream()).toList(); + return joined; }); return map; }, @@ -121,33 +143,35 @@ private static WeaviateObject unmarshalResultObject( return left; }); - // TODO: should we return without metdata (null)? - if (metadataResult == null) { - metadataResult = propertiesResult.getMetadata(); - } - - var metadata = new QueryMetadata.Builder() - .id(metadataResult.getId()) - .distance(metadataResult.getDistance()) - .certainty(metadataResult.getCertainty()); + ObjectMetadata metadata = null; + if (metadataResult != null) { + var metadataBuilder = new ObjectMetadata.Builder() + .id(metadataResult.getId()); - var vectors = new Vectors.Builder(); - for (final var vector : metadataResult.getVectorsList()) { - var vectorName = vector.getName(); - switch (vector.getType()) { - case VECTOR_TYPE_SINGLE_FP32: - vectors.vector(vectorName, GRPC.fromByteString(vector.getVectorBytes())); - break; - case VECTOR_TYPE_MULTI_FP32: - vectors.vector(vectorName, GRPC.fromByteStringMulti(vector.getVectorBytes())); - break; - default: - continue; + var vectors = new Vectors.Builder(); + for (final var vector : metadataResult.getVectorsList()) { + var vectorName = vector.getName(); + switch (vector.getType()) { + case VECTOR_TYPE_SINGLE_FP32: + vectors.vector(vectorName, GRPC.fromByteString(vector.getVectorBytes())); + break; + case VECTOR_TYPE_MULTI_FP32: + vectors.vector(vectorName, GRPC.fromByteStringMulti(vector.getVectorBytes())); + break; + default: + continue; + } } + metadataBuilder.vectors(vectors.build()); + metadata = metadataBuilder.build(); } - metadata.vectors(vectors.build()); - return new WeaviateObject<>(descriptor.name(), properties.build(), referenceProperties, metadata.build()); + var obj = new WeaviateObject.Builder() + .collection(descriptor.name()) + .properties(properties.build()) + .references(referenceProperties) + .metadata(metadata); + return obj.build(); } private static void setProperty(String property, WeaviateProtoProperties.Value value, diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java index 552ccc594..b5d465f8c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java @@ -4,5 +4,6 @@ import io.weaviate.client6.v1.api.collections.WeaviateObject; -public record QueryResponse(List> objects) { +public record QueryResponse( + List> objects) { } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateObjectUnmarshaler.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateObjectUnmarshaler.java new file mode 100644 index 000000000..e69de29bb diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java index c2494f503..1bd313e3b 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java @@ -8,14 +8,14 @@ public class WeaviateQueryClient extends - AbstractQueryClient>, QueryResponse, QueryResponseGrouped> { + AbstractQueryClient>, QueryResponse, QueryResponseGrouped> { public WeaviateQueryClient(CollectionDescriptor collection, GrpcTransport transport) { super(collection, transport); } @Override - protected Optional> byId(ById byId) { + protected Optional> byId(ById byId) { var request = new QueryRequest(byId, null); var result = this.transport.performRequest(request, QueryRequest.rpc(collection)); return optionalFirst(result.objects()); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java index 5e5f0729f..50bc536cc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java @@ -9,14 +9,15 @@ public class WeaviateQueryClientAsync extends - AbstractQueryClient>>, CompletableFuture>, CompletableFuture>> { + AbstractQueryClient>>, CompletableFuture>, CompletableFuture>> { public WeaviateQueryClientAsync(CollectionDescriptor collection, GrpcTransport transport) { super(collection, transport); } @Override - protected CompletableFuture>> byId(ById byId) { + protected CompletableFuture>> byId( + ById byId) { var request = new QueryRequest(byId, null); var result = this.transport.performRequestAsync(request, QueryRequest.rpc(collection)); return result.thenApply(r -> optionalFirst(r.objects())); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/BaseVectorIndex.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/BaseVectorIndex.java new file mode 100644 index 000000000..49ed116c2 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/BaseVectorIndex.java @@ -0,0 +1,19 @@ +package io.weaviate.client6.v1.api.collections.vectorindex; + +import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.api.collections.Vectorizer; +import lombok.EqualsAndHashCode; + +@EqualsAndHashCode +abstract class BaseVectorIndex implements VectorIndex { + protected final Vectorizer vectorizer; + + @Override + public Vectorizer vectorizer() { + return this.vectorizer; + } + + public BaseVectorIndex(Vectorizer vectorizer) { + this.vectorizer = vectorizer; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Distance.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Distance.java new file mode 100644 index 000000000..602e90693 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Distance.java @@ -0,0 +1,16 @@ +package io.weaviate.client6.v1.api.collections.vectorindex; + +import com.google.gson.annotations.SerializedName; + +public enum Distance { + @SerializedName("cosine") + COSINE, + @SerializedName("dot") + DOT, + @SerializedName("l2-squared") + L2_SQUARED, + @SerializedName("hamming") + HAMMING, + @SerializedName("manhattan") + MANHATTAN; +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java new file mode 100644 index 000000000..90ca1990c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java @@ -0,0 +1,62 @@ +package io.weaviate.client6.v1.api.collections.vectorindex; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +@EqualsAndHashCode(callSuper = true) +@ToString +public class Flat extends BaseVectorIndex { + @SerializedName("vectorCacheMaxObjects") + Long vectorCacheMaxObjects; + + @Override + public VectorIndex.Kind _kind() { + return VectorIndex.Kind.FLAT; + } + + @Override + public Object config() { + return this; + } + + public static Flat of(Vectorizer vectorizer) { + return of(vectorizer, ObjectBuilder.identity()); + } + + public static Flat of(Vectorizer vectorizer, Function> fn) { + return fn.apply(new Builder(vectorizer)).build(); + } + + public Flat(Builder builder) { + super(builder.vectorizer); + this.vectorCacheMaxObjects = builder.vectorCacheMaxObjects; + } + + public static class Builder implements ObjectBuilder { + // Required parameters. + private final Vectorizer vectorizer; + + private Long vectorCacheMaxObjects; + + protected Builder(Vectorizer vectorizer) { + this.vectorizer = vectorizer; + } + + public Builder vectorCacheMaxObjects(long vectorCacheMaxObjects) { + this.vectorCacheMaxObjects = vectorCacheMaxObjects; + return this; + } + + @Override + public Flat build() { + return new Flat(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java new file mode 100644 index 000000000..4538ad9b4 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java @@ -0,0 +1,175 @@ +package io.weaviate.client6.v1.api.collections.vectorindex; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +@EqualsAndHashCode(callSuper = true) +@ToString +public class Hnsw extends BaseVectorIndex { + @SerializedName("distance") + private final Distance distance; + @SerializedName("ef") + private final Integer ef; + @SerializedName("efConstruction") + private final Integer efConstruction; + @SerializedName("maxConnections") + private final Integer maxConnections; + @SerializedName("vectorCacheMaxObjects") + private final Long vectorCacheMaxObjects; + @SerializedName("cleanupIntervalSeconds") + private final Integer cleanupIntervalSeconds; + @SerializedName("filterStrategy") + private final FilterStrategy filterStrategy; + + @SerializedName("dynamicEfMin") + private final Integer dynamicEfMin; + @SerializedName("dynamicEfMax") + private final Integer dynamicEfMax; + @SerializedName("dynamicEfFactor") + private final Integer dynamicEfFactor; + @SerializedName("flatSearchCutoff") + private final Integer flatSearchCutoff; + @SerializedName("skip") + Boolean skipVectorization; + + @Override + public VectorIndex.Kind _kind() { + return VectorIndex.Kind.HNSW; + } + + @Override + public Object config() { + return this; + } + + @Override + public Vectorizer vectorizer() { + return this.vectorizer; + } + + public static Hnsw of(Vectorizer vectorizer) { + return of(vectorizer, ObjectBuilder.identity()); + } + + public static Hnsw of(Vectorizer vectorizer, Function> fn) { + return fn.apply(new Builder(vectorizer)).build(); + } + + public Hnsw(Builder builder) { + super(builder.vectorizer); + this.distance = builder.distance; + this.ef = builder.ef; + this.efConstruction = builder.efConstruction; + this.maxConnections = builder.maxConnections; + this.vectorCacheMaxObjects = builder.vectorCacheMaxObjects; + this.cleanupIntervalSeconds = builder.cleanupIntervalSeconds; + this.filterStrategy = builder.filterStrategy; + this.dynamicEfMin = builder.dynamicEfMin; + this.dynamicEfMax = builder.dynamicEfMax; + this.dynamicEfFactor = builder.dynamicEfFactor; + this.flatSearchCutoff = builder.flatSearchCutoff; + this.skipVectorization = builder.skipVectorization; + } + + public static class Builder implements ObjectBuilder { + // Required parameters. + private final Vectorizer vectorizer; + + private Distance distance; + private Integer ef; + private Integer efConstruction; + private Integer maxConnections; + private Long vectorCacheMaxObjects; + private Integer cleanupIntervalSeconds; + private FilterStrategy filterStrategy; + + private Integer dynamicEfMin; + private Integer dynamicEfMax; + private Integer dynamicEfFactor; + private Integer flatSearchCutoff; + private Boolean skipVectorization; + + public Builder(Vectorizer vectorizer) { + this.vectorizer = vectorizer; + } + + public Builder distance(Distance distance) { + this.distance = distance; + return this; + } + + public Builder ef(int ef) { + this.ef = ef; + return this; + } + + public final Builder efConstruction(int efConstruction) { + this.efConstruction = efConstruction; + return this; + } + + public final Builder maxConnections(int maxConnections) { + this.maxConnections = maxConnections; + return this; + } + + public final Builder vectorCacheMaxObjects(long vectorCacheMaxObjects) { + this.vectorCacheMaxObjects = vectorCacheMaxObjects; + return this; + } + + public final Builder cleanupIntervalSeconds(int cleanupIntervalSeconds) { + this.cleanupIntervalSeconds = cleanupIntervalSeconds; + return this; + } + + public final Builder filterStrategy(FilterStrategy filterStrategy) { + this.filterStrategy = filterStrategy; + return this; + } + + public final Builder dynamicEfMin(int dynamicEfMin) { + this.dynamicEfMin = dynamicEfMin; + return this; + } + + public final Builder dynamicEfMax(int dynamicEfMax) { + this.dynamicEfMax = dynamicEfMax; + return this; + } + + public final Builder dynamicEfFactor(int dynamicEfFactor) { + this.dynamicEfFactor = dynamicEfFactor; + return this; + } + + public final Builder flatSearchCutoff(int flatSearchCutoff) { + this.flatSearchCutoff = flatSearchCutoff; + return this; + } + + public final Builder skipVectorization(boolean skip) { + this.skipVectorization = skip; + return this; + } + + @Override + public Hnsw build() { + return new Hnsw(this); + } + } + + public enum FilterStrategy { + @SerializedName("sweeping") + SWEEPING, + @SerializedName("acorn") + ACORN; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Img2VecNeuralVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Img2VecNeuralVectorizer.java new file mode 100644 index 000000000..2d9ff6beb --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Img2VecNeuralVectorizer.java @@ -0,0 +1,55 @@ +package io.weaviate.client6.v1.api.collections.vectorizers; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record Img2VecNeuralVectorizer( + @SerializedName("imageFields") List imageFields) implements Vectorizer { + + @Override + public Vectorizer.Kind _kind() { + return Vectorizer.Kind.IMG2VEC_NEURAL; + } + + @Override + public Object _self() { + return this; + } + + public static Img2VecNeuralVectorizer of() { + return of(ObjectBuilder.identity()); + } + + public static Img2VecNeuralVectorizer of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Img2VecNeuralVectorizer(Builder builder) { + this(builder.imageFields); + } + + public static class Builder implements ObjectBuilder { + private List imageFields = new ArrayList<>(); + + public Builder imageFields(List fields) { + this.imageFields = fields; + return this; + } + + public Builder imageFields(String... fields) { + return imageFields(Arrays.asList(fields)); + } + + @Override + public Img2VecNeuralVectorizer build() { + return new Img2VecNeuralVectorizer(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecClipVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecClipVectorizer.java new file mode 100644 index 000000000..945984cc4 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Multi2VecClipVectorizer.java @@ -0,0 +1,104 @@ +package io.weaviate.client6.v1.api.collections.vectorizers; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record Multi2VecClipVectorizer( + @SerializedName("vectorizeClassName") boolean vectorizeCollectionName, + @SerializedName("inferenceUrl") String inferenceUrl, + @SerializedName("imageFields") List imageFields, + @SerializedName("textFields") List textFields, + @SerializedName("weights") Weights weights) implements Vectorizer { + + private static record Weights( + @SerializedName("imageWeights") List imageWeights, + @SerializedName("textWeights") List textWeights) { + } + + @Override + public Vectorizer.Kind _kind() { + return Vectorizer.Kind.MULTI2VEC_CLIP; + } + + @Override + public Object _self() { + return this; + } + + public static Multi2VecClipVectorizer of() { + return of(ObjectBuilder.identity()); + } + + public static Multi2VecClipVectorizer of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Multi2VecClipVectorizer(Builder builder) { + this( + builder.vectorizeCollectionName, + builder.inferenceUrl, + builder.imageFields.keySet().stream().toList(), + builder.textFields.keySet().stream().toList(), + new Weights( + builder.imageFields.values().stream().toList(), + builder.textFields.values().stream().toList())); + } + + public static class Builder implements ObjectBuilder { + private boolean vectorizeCollectionName = false; + private String inferenceUrl; + private Map imageFields = new HashMap<>(); + private Map textFields = new HashMap<>(); + + public Builder inferenceUrl(String inferenceUrl) { + this.inferenceUrl = inferenceUrl; + return this; + } + + public Builder imageFields(List fields) { + fields.forEach(field -> imageFields.put(field, null)); + return this; + } + + public Builder imageFields(String... fields) { + return imageFields(Arrays.asList(fields)); + } + + public Builder imageField(String field, float weight) { + imageFields.put(field, weight); + return this; + } + + public Builder textFields(List fields) { + fields.forEach(field -> textFields.put(field, null)); + return this; + } + + public Builder textFields(String... fields) { + return textFields(Arrays.asList(fields)); + } + + public Builder textField(String field, float weight) { + textFields.put(field, weight); + return this; + } + + public Builder vectorizeCollectionName(boolean enable) { + this.vectorizeCollectionName = enable; + return this; + } + + @Override + public Multi2VecClipVectorizer build() { + return new Multi2VecClipVectorizer(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/NoneVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/NoneVectorizer.java new file mode 100644 index 000000000..6449ba89b --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/NoneVectorizer.java @@ -0,0 +1,45 @@ +package io.weaviate.client6.v1.api.collections.vectorizers; + +import java.io.IOException; + +import com.google.gson.TypeAdapter; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonToken; +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.api.collections.Vectorizer; + +public record NoneVectorizer() implements Vectorizer { + @Override + public Kind _kind() { + return Vectorizer.Kind.NONE; + } + + @Override + public Object _self() { + return this; + } + + public static final TypeAdapter TYPE_ADAPTER = new TypeAdapter() { + + @Override + public void write(JsonWriter out, NoneVectorizer value) throws IOException { + out.beginObject(); + out.name(value._kind().jsonValue()); + out.beginObject(); + out.endObject(); + out.endObject(); + } + + @Override + public NoneVectorizer read(JsonReader in) throws IOException { + // NoneVectorizer expects no parameters, so we just skip to the closing bracket. + in.beginObject(); + while (in.peek() != JsonToken.END_OBJECT) { + in.skipValue(); + } + in.endObject(); + return new NoneVectorizer(); + } + }.nullSafe(); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecContextionaryVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecContextionaryVectorizer.java new file mode 100644 index 000000000..7bbfc6c9c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecContextionaryVectorizer.java @@ -0,0 +1,48 @@ +package io.weaviate.client6.v1.api.collections.vectorizers; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record Text2VecContextionaryVectorizer( + @SerializedName("vectorizeClassName") boolean vectorizeCollectionName) implements Vectorizer { + + @Override + public Vectorizer.Kind _kind() { + return Vectorizer.Kind.TEXT2VEC_CONTEXTIONARY; + } + + @Override + public Object _self() { + return this; + } + + public static Text2VecContextionaryVectorizer of() { + return of(ObjectBuilder.identity()); + } + + public static Text2VecContextionaryVectorizer of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Text2VecContextionaryVectorizer(Builder builder) { + this(builder.vectorizeCollectionName); + } + + public static class Builder implements ObjectBuilder { + private boolean vectorizeCollectionName = false; + + public Builder vectorizeCollectionName(boolean enable) { + this.vectorizeCollectionName = enable; + return this; + } + + public Text2VecContextionaryVectorizer build() { + return new Text2VecContextionaryVectorizer(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecWeaviateVectorizer.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecWeaviateVectorizer.java new file mode 100644 index 000000000..134a6513a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorizers/Text2VecWeaviateVectorizer.java @@ -0,0 +1,74 @@ +package io.weaviate.client6.v1.api.collections.vectorizers; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record Text2VecWeaviateVectorizer( + @SerializedName("vectorizeClassName") boolean vectorizeCollectionName, + @SerializedName("baseUrl") String inferenceUrl, + @SerializedName("dimensions") Integer dimensions, + @SerializedName("model") String model) implements Vectorizer { + + @Override + public Vectorizer.Kind _kind() { + return Vectorizer.Kind.TEXT2VEC_WEAVIATE; + } + + @Override + public Object _self() { + return this; + } + + public static Text2VecWeaviateVectorizer of() { + return of(ObjectBuilder.identity()); + } + + public static Text2VecWeaviateVectorizer of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public Text2VecWeaviateVectorizer(Builder builder) { + this(builder.vectorizeCollectionName, + builder.inferenceUrl, + builder.dimensions, + builder.model); + } + + public static final String SNOWFLAKE_ARCTIC_EMBED_L_20 = "Snowflake/snowflake-arctic-embed-l-v2.0"; + public static final String SNOWFLAKE_ARCTIC_EMBED_M_15 = "Snowflake/snowflake-arctic-embed-m-v1.5"; + + public static class Builder implements ObjectBuilder { + private boolean vectorizeCollectionName = false; + private String inferenceUrl; + private Integer dimensions; + private String model; + + public Builder vectorizeCollectionName(boolean enable) { + this.vectorizeCollectionName = enable; + return this; + } + + public Builder inferenceUrl(String inferenceUrl) { + this.inferenceUrl = inferenceUrl; + return this; + } + + public Builder dimensions(int dimensions) { + this.dimensions = dimensions; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Text2VecWeaviateVectorizer build() { + return new Text2VecWeaviateVectorizer(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java b/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java deleted file mode 100644 index 38c33ed22..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java +++ /dev/null @@ -1,18 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import com.google.gson.annotations.SerializedName; - -public enum AtomicDataType { - @SerializedName("text") - TEXT, - @SerializedName("int") - INT, - @SerializedName("blob") - BLOB; - - public static boolean isAtomic(String type) { - return type.equals(TEXT.name().toLowerCase()) - || type.equals(INT.name().toLowerCase()) - || type.equals(BLOB.name().toLowerCase()); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Collection.java b/src/main/java/io/weaviate/client6/v1/collections/Collection.java deleted file mode 100644 index 870701311..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/Collection.java +++ /dev/null @@ -1,56 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.function.Consumer; - -import io.weaviate.client6.v1.collections.Vectors.NamedVectors; - -public record Collection(String name, List properties, List references, Vectors vectors) { - - public static Collection with(String name, Consumer options) { - var config = new Builder(options); - return new Collection(name, config.properties, config.references, config.vectors); - } - - public static class Builder { - private List properties = new ArrayList<>(); - public List references = new ArrayList<>(); - private Vectors vectors; - - public Builder properties(Property... properties) { - this.properties = Arrays.asList(properties); - return this; - } - - public Builder references(ReferenceProperty... references) { - this.references = Arrays.asList(references); - return this; - } - - public Builder vectors(Vectors vectors) { - this.vectors = vectors; - return this; - } - - public Builder vector(VectorIndex vector) { - this.vectors = Vectors.of(vector); - return this; - } - - public Builder vector(String name, VectorIndex vector) { - this.vectors = new Vectors(name, vector); - return this; - } - - public Builder vectors(Consumer named) { - this.vectors = Vectors.with(named); - return this; - } - - Builder(Consumer options) { - options.accept(this); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionClient.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionClient.java deleted file mode 100644 index 76d840a03..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionClient.java +++ /dev/null @@ -1,27 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import io.weaviate.client6.Config; -import io.weaviate.client6.internal.GrpcClient; -import io.weaviate.client6.internal.HttpClient; -import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClient; -import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient; -import io.weaviate.client6.v1.collections.data.DataClient; -import io.weaviate.client6.v1.internal.grpc.GrpcTransport; -import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; - -public class CollectionClient { - public final WeaviateQueryClient query; - public final WeaviateAggregateClient aggregate; - - public final DataClient data; - public final CollectionConfigClient config; - - public CollectionClient(String collectionName, Config config, GrpcClient grpc, HttpClient http, - GrpcTransport grpcTransport, CollectionDescriptor collectionDescriptor) { - this.query = new WeaviateQueryClient<>(collectionDescriptor, grpcTransport); - this.aggregate = new WeaviateAggregateClient(collectionDescriptor, grpcTransport); - - this.data = new DataClient<>(collectionName, config, http, grpc); - this.config = new CollectionConfigClient(collectionName, config, http); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionConfigClient.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionConfigClient.java deleted file mode 100644 index 926a335bb..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionConfigClient.java +++ /dev/null @@ -1,159 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.io.IOException; -import java.io.InputStreamReader; -import java.lang.reflect.Type; -import java.util.Map; -import java.util.Optional; - -import org.apache.hc.core5.http.ClassicHttpRequest; -import org.apache.hc.core5.http.ContentType; -import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.io.support.ClassicRequestBuilder; - -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; -import com.google.gson.JsonDeserializationContext; -import com.google.gson.JsonDeserializer; -import com.google.gson.JsonElement; -import com.google.gson.JsonParseException; -import com.google.gson.JsonSerializationContext; -import com.google.gson.JsonSerializer; -import com.google.gson.TypeAdapter; -import com.google.gson.reflect.TypeToken; -import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonWriter; - -import io.weaviate.client6.Config; -import io.weaviate.client6.internal.DtoTypeAdapterFactory; -import io.weaviate.client6.internal.HttpClient; -import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class CollectionConfigClient { - // TODO: hide befind an internal HttpClient - private final String collectionName; - private final Config config; - private final HttpClient httpClient; - - static { - DtoTypeAdapterFactory.register( - Collection.class, - CollectionDefinitionDTO.class, - m -> new CollectionDefinitionDTO(m)); - } - - // Gson cannot deserialize interfaces: - // https://stackoverflow.com/a/49871339/14726116 - private static class IndexingStrategySerde - implements JsonDeserializer, JsonSerializer { - - @Override - public IndexingStrategy deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) - throws JsonParseException { - return IndexingStrategy.hnsw(); - } - - @Override - public JsonElement serialize(IndexingStrategy src, Type typeOfSrc, JsonSerializationContext context) { - return context.serialize(src); - } - } - - // Gson cannot deserialize interfaces: - // https://stackoverflow.com/a/49871339/14726116 - private static class VectorizerSerde - implements JsonDeserializer, JsonSerializer { - - @Override - public Vectorizer deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) - throws JsonParseException { - return Vectorizer.none(); - } - - @Override - public JsonElement serialize(Vectorizer src, Type typeOfSrc, JsonSerializationContext context) { - return context.serialize(src); - } - } - - private static final Gson gson = new GsonBuilder() - .registerTypeAdapterFactory(new DtoTypeAdapterFactory()) - .registerTypeAdapter(Vectorizer.class, new VectorizerSerde()) - .registerTypeAdapter(IndexingStrategy.class, new IndexingStrategySerde()) - .registerTypeAdapter(Vectors.class, new TypeAdapter() { - Gson gson = new GsonBuilder() - .registerTypeAdapter(Vectorizer.class, new VectorizerSerde()) - .registerTypeAdapter(IndexingStrategy.class, new IndexingStrategySerde()) - .create(); - - @Override - public void write(JsonWriter out, Vectors value) throws IOException { - if (value != null) { - gson.toJson(value.asMap(), Map.class, out); - } else { - out.nullValue(); - } - } - - @Override - public Vectors read(JsonReader in) throws IOException { - Map> vectors = gson.fromJson(in, - new TypeToken>>() { - }.getType()); - return Vectors.of(vectors); - } - }) - .create(); - - public Optional get() throws IOException { - ClassicHttpRequest httpGet = ClassicRequestBuilder - .get(config.baseUrl() + "/schema/" + collectionName) - .build(); - - return httpClient.http.execute(httpGet, response -> { - if (response.getCode() == HttpStatus.SC_NOT_FOUND) { - return Optional.empty(); - } - try (var r = new InputStreamReader(response.getEntity().getContent())) { - var collection = gson.fromJson(r, Collection.class); - return Optional.ofNullable(collection); - } - }); - } - - public void addProperty(Property property) throws IOException { - ClassicHttpRequest httpPost = ClassicRequestBuilder - .post(config.baseUrl() + "/schema/" + collectionName + "/properties") - .setEntity(gson.toJson(property), ContentType.APPLICATION_JSON) - .build(); - - httpClient.http.execute(httpPost, response -> { - var entity = response.getEntity(); - if (response.getCode() != HttpStatus.SC_SUCCESS) { - var message = EntityUtils.toString(entity); - throw new RuntimeException("HTTP " + response.getCode() + ": " + message); - } - return null; - }); - } - - public void addReference(String name, String... dataTypes) throws IOException { - var property = Property.reference(name, dataTypes); - ClassicHttpRequest httpPost = ClassicRequestBuilder - .post(config.baseUrl() + "/schema/" + collectionName + "/properties") - .setEntity(gson.toJson(property), ContentType.APPLICATION_JSON) - .build(); - - httpClient.http.execute(httpPost, response -> { - var entity = response.getEntity(); - if (response.getCode() != HttpStatus.SC_SUCCESS) { - var message = EntityUtils.toString(entity); - throw new RuntimeException("HTTP " + response.getCode() + ": " + message); - } - return null; - }); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java deleted file mode 100644 index 9a3bb6422..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java +++ /dev/null @@ -1,63 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Stream; - -import com.google.gson.annotations.SerializedName; - -import io.weaviate.client6.internal.DtoTypeAdapterFactory; - -class CollectionDefinitionDTO implements DtoTypeAdapterFactory.Dto { - @SerializedName("class") - String collection; - - @SerializedName("properties") - List properties; - - @SerializedName("vectorConfig") - Vectors vectors; - - @SerializedName("vectorIndexType") - private VectorIndex.IndexType vectorIndexType; - - @SerializedName("vectorIndexConfig") - private VectorIndex.IndexingStrategy vectorIndexConfig; - - @SerializedName("vectorizer") - private Vectorizer vectorizer; - - public CollectionDefinitionDTO(Collection colDef) { - this.collection = colDef.name(); - this.properties = Stream.concat( - colDef.properties().stream(), - colDef.references().stream().map(r -> new Property(r.name(), - r.dataTypes()))) - .toList(); - this.vectors = colDef.vectors(); - - if (this.vectors != null) { - var unnamed = this.vectors.getUnnamed(); - if (unnamed.isPresent()) { - var index = unnamed.get(); - this.vectorIndexType = index.type(); - this.vectorIndexConfig = index.configuration(); - this.vectorizer = index.vectorizer(); - } - } - } - - public Collection toModel() { - var onlyProperties = new ArrayList(); - var references = new ArrayList(); - - for (var p : properties) { - if (p.isReference()) { - references.add(Property.reference(p.name(), p.dataTypes())); - } else { - onlyProperties.add(p); - } - } - return new Collection(collection, onlyProperties, references, vectors); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionsClient.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionsClient.java deleted file mode 100644 index 6ff496972..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionsClient.java +++ /dev/null @@ -1,178 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.io.IOException; -import java.io.InputStreamReader; -import java.lang.reflect.Type; -import java.util.Map; -import java.util.Optional; -import java.util.function.Consumer; - -import org.apache.hc.core5.http.ClassicHttpRequest; -import org.apache.hc.core5.http.ContentType; -import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.io.support.ClassicRequestBuilder; - -import com.google.common.reflect.TypeToken; -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; -import com.google.gson.JsonDeserializationContext; -import com.google.gson.JsonDeserializer; -import com.google.gson.JsonElement; -import com.google.gson.JsonParseException; -import com.google.gson.JsonSerializationContext; -import com.google.gson.JsonSerializer; -import com.google.gson.TypeAdapter; -import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonWriter; - -import io.weaviate.client6.Config; -import io.weaviate.client6.internal.DtoTypeAdapterFactory; -import io.weaviate.client6.internal.GrpcClient; -import io.weaviate.client6.internal.HttpClient; -import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; -import io.weaviate.client6.v1.internal.grpc.GrpcTransport; -import io.weaviate.client6.v1.internal.orm.MapDescriptor; -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class CollectionsClient { - // TODO: hide befind an internal HttpClient - private final Config config; - - private final HttpClient httpClient; - private final GrpcClient grpcClient; - - // TODO: Some commong AbstractWeaviateClient should hold these. - private final GrpcTransport grpcTransport; - - static { - DtoTypeAdapterFactory.register( - Collection.class, - CollectionDefinitionDTO.class, - m -> new CollectionDefinitionDTO(m)); - } - - // Gson cannot deserialize interfaces: - // https://stackoverflow.com/a/49871339/14726116 - private static class IndexingStrategySerde - implements JsonDeserializer, JsonSerializer { - - @Override - public IndexingStrategy deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) - throws JsonParseException { - return IndexingStrategy.hnsw(); - } - - @Override - public JsonElement serialize(IndexingStrategy src, Type typeOfSrc, JsonSerializationContext context) { - return context.serialize(src); - } - } - - // Gson cannot deserialize interfaces: - // https://stackoverflow.com/a/49871339/14726116 - private static class VectorizerSerde - implements JsonDeserializer, JsonSerializer { - - @Override - public Vectorizer deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) - throws JsonParseException { - // TODO: deserialize different kinds of vectorizers - return Vectorizer.none(); - } - - @Override - public JsonElement serialize(Vectorizer src, Type typeOfSrc, JsonSerializationContext context) { - return context.serialize(src); - } - } - - private static final Gson gson = new GsonBuilder() - .registerTypeAdapterFactory(new DtoTypeAdapterFactory()) - .registerTypeAdapter(IndexingStrategy.class, new IndexingStrategySerde()) - .registerTypeAdapter(Vectorizer.class, new VectorizerSerde()) - .registerTypeAdapter(Vectors.class, new TypeAdapter() { - Gson gson = new GsonBuilder() - .registerTypeAdapter(Vectorizer.class, new VectorizerSerde()) - .registerTypeAdapter(IndexingStrategy.class, new IndexingStrategySerde()) - .create(); - - @Override - public void write(JsonWriter out, Vectors value) throws IOException { - if (value != null) { - gson.toJson(value.asMap(), Map.class, out); - } else { - out.nullValue(); - } - } - - @Override - public Vectors read(JsonReader in) throws IOException { - Map> vectors = gson.fromJson(in, - new TypeToken>>() { - }.getType()); - return Vectors.of(vectors); - } - }) - .create(); - - public void create(String name) throws IOException { - create(name, opt -> { - }); - } - - public void create(String name, Consumer options) throws IOException { - var collection = Collection.with(name, options); - ClassicHttpRequest httpPost = ClassicRequestBuilder - .post(config.baseUrl() + "/schema") - .setEntity(gson.toJson(collection), ContentType.APPLICATION_JSON) - .build(); - - // TODO: do not expose Apache HttpClient directly - httpClient.http.execute(httpPost, response -> { - var entity = response.getEntity(); - if (response.getCode() != HttpStatus.SC_SUCCESS) { // Does not return 201 - var message = EntityUtils.toString(entity); - throw new RuntimeException("HTTP " + response.getCode() + ": " + message); - } - return null; - }); - } - - public Optional getConfig(String name) throws IOException { - ClassicHttpRequest httpGet = ClassicRequestBuilder - .get(config.baseUrl() + "/schema/" + name) - .build(); - - return httpClient.http.execute(httpGet, response -> { - if (response.getCode() == HttpStatus.SC_NOT_FOUND) { - return Optional.empty(); - } - try (var r = new InputStreamReader(response.getEntity().getContent())) { - var collection = gson.fromJson(r, Collection.class); - return Optional.ofNullable(collection); - } - }); - } - - public void delete(String name) throws IOException { - ClassicHttpRequest httpDelete = ClassicRequestBuilder - .delete(config.baseUrl() + "/schema/" + name) - .build(); - - httpClient.http.execute(httpDelete, response -> { - var entity = response.getEntity(); - if (response.getCode() != HttpStatus.SC_SUCCESS) { // Does not return 201 - var message = EntityUtils.toString(entity); - throw new RuntimeException("HTTP " + response.getCode() + ": " + message); - } - return null; - }); - } - - public CollectionClient> use(String collectionName) { - return new CollectionClient<>(collectionName, config, grpcClient, httpClient, grpcTransport, - new MapDescriptor(collectionName)); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java deleted file mode 100644 index 1bb580ada..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java +++ /dev/null @@ -1,37 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.Map; -import java.util.function.Consumer; - -import com.google.gson.annotations.SerializedName; - -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class ContextionaryVectorizer extends Vectorizer { - @SerializedName("text2vec-contextionary") - private Map configuration; - - public static ContextionaryVectorizer of() { - return new Builder().build(); - } - - public static ContextionaryVectorizer of(Consumer fn) { - var builder = new Builder(); - fn.accept(builder); - return builder.build(); - } - - public static class Builder { - private boolean vectorizeCollectionName = false; - - public Builder vectorizeCollectionName() { - this.vectorizeCollectionName = true; - return this; - } - - public ContextionaryVectorizer build() { - return new ContextionaryVectorizer(Map.of("vectorizeClassName", vectorizeCollectionName)); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/HNSW.java b/src/main/java/io/weaviate/client6/v1/collections/HNSW.java deleted file mode 100644 index 938c44ffa..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/HNSW.java +++ /dev/null @@ -1,43 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.function.Consumer; - -import io.weaviate.client6.v1.collections.VectorIndex.IndexType; - -public final record HNSW(Distance distance, Boolean skip) implements VectorIndex.IndexingStrategy { - public VectorIndex.IndexType type() { - return IndexType.HNSW; - } - - public enum Distance { - COSINE; - } - - HNSW() { - this(null, null); - } - - static HNSW with(Consumer options) { - var opt = new Builder(options); - return new HNSW(opt.distance, opt.skip); - } - - public static class Builder { - private Distance distance; - private Boolean skip; - - public Builder distance(Distance distance) { - this.distance = distance; - return this; - } - - public Builder disableIndexation() { - this.skip = true; - return this; - } - - public Builder(Consumer options) { - options.accept(this); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Img2VecNeuralVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Img2VecNeuralVectorizer.java deleted file mode 100644 index a0efc5c61..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/Img2VecNeuralVectorizer.java +++ /dev/null @@ -1,40 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; - -import com.google.gson.annotations.SerializedName; - -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class Img2VecNeuralVectorizer extends Vectorizer { - @SerializedName("img2vec-neural") - private Map configuration; - - public static Img2VecNeuralVectorizer of() { - return new Builder().build(); - } - - public static Img2VecNeuralVectorizer of(Consumer fn) { - var builder = new Builder(); - fn.accept(builder); - return builder.build(); - } - - public static class Builder { - private List imageFields = new ArrayList<>(); - - public Builder imageFields(String... fields) { - this.imageFields = Arrays.asList(fields); - return this; - } - - public Img2VecNeuralVectorizer build() { - return new Img2VecNeuralVectorizer(Map.of("imageFields", imageFields)); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Multi2VecClipVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Multi2VecClipVectorizer.java deleted file mode 100644 index 305e8373a..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/Multi2VecClipVectorizer.java +++ /dev/null @@ -1,100 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; - -import com.google.gson.annotations.SerializedName; - -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class Multi2VecClipVectorizer extends Vectorizer { - @SerializedName("multi2vec-clip") - private Map configuration; - - public static Multi2VecClipVectorizer of() { - return new Builder().build(); - } - - public static Multi2VecClipVectorizer of(Consumer fn) { - var builder = new Builder(); - fn.accept(builder); - return builder.build(); - } - - public static class Builder { - private boolean vectorizeCollectionName = false; - private String inferenceUrl; - private Map imageFields = new HashMap<>(); - private Map textFields = new HashMap<>(); - - public Builder inferenceUrl(String inferenceUrl) { - this.inferenceUrl = inferenceUrl; - return this; - } - - public Builder imageFields(String... fields) { - Arrays.stream(fields).forEach(f -> imageFields.put(f, null)); - return this; - } - - public Builder imageField(String field, float weight) { - imageFields.put(field, weight); - return this; - } - - public Builder textFields(String... fields) { - Arrays.stream(fields).forEach(f -> textFields.put(f, null)); - return this; - } - - public Builder textField(String field, float weight) { - textFields.put(field, weight); - return this; - } - - public Builder vectorizeCollectionName() { - this.vectorizeCollectionName = true; - return this; - } - - public Multi2VecClipVectorizer build() { - return new Multi2VecClipVectorizer(new HashMap<>() { - { - put("vectorizeClassName", vectorizeCollectionName); - if (inferenceUrl != null) { - put("inferenceUrl", inferenceUrl); - } - - var _imageFields = new ArrayList(); - var _imageWeights = new ArrayList(); - splitEntries(imageFields, _imageFields, _imageWeights); - - var _textFields = new ArrayList(); - var _textWeights = new ArrayList(); - splitEntries(imageFields, _textFields, _textWeights); - - put("imageFields", _imageFields); - put("textFields", _textFields); - put("weights", Map.of( - "imageWeights", _imageWeights, - "textWeights", _textWeights)); - } - }); - } - - private void splitEntries(Map map, List keys, List values) { - map.entrySet().forEach(entry -> { - keys.add(entry.getKey()); - var value = entry.getValue(); - if (value != null) { - values.add(value); - } - }); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/NoneVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/NoneVectorizer.java deleted file mode 100644 index 014ed7cca..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/NoneVectorizer.java +++ /dev/null @@ -1,10 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.Map; - -import com.google.gson.annotations.SerializedName; - -public class NoneVectorizer extends Vectorizer { - @SerializedName("none") - private final Map _configuration = Map.of(); -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Reference.java b/src/main/java/io/weaviate/client6/v1/collections/Reference.java deleted file mode 100644 index b17799911..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/Reference.java +++ /dev/null @@ -1,59 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.io.IOException; -import java.util.Arrays; -import java.util.List; - -import com.google.gson.stream.JsonWriter; - -import io.weaviate.client6.v1.collections.object.WeaviateObject; - -public record Reference(String collection, List uuids) { - - public Reference(String collection, String uuid) { - this(collection, List.of(uuid)); - } - - /** - * Create reference to objects by their UUIDs. - *

- * Weaviate will search each of the existing collections to identify - * the objects before inserting the references, so this may include - * some performance overhead. - */ - public static Reference uuids(String... uuids) { - return new Reference(null, Arrays.asList(uuids)); - } - - /** Create references to {@link WeaviateObject}. */ - public static Reference[] objects(WeaviateObject... objects) { - return Arrays.stream(objects) - .map(o -> new Reference(o.collection(), o.metadata().id())) - .toArray(Reference[]::new); - } - - /** Create references to objects in a collection by their UUIDs. */ - public static Reference collection(String collection, String... uuids) { - return new Reference(collection, Arrays.asList(uuids)); - } - - // TODO: put this in a type adapter. - /** writeValue assumes an array has been started will be ended by the caller. */ - public void writeValue(JsonWriter w) throws IOException { - for (var uuid : uuids) { - w.beginObject(); - w.name("beacon"); - w.value(toBeacon(uuid)); - w.endObject(); - } - } - - private String toBeacon(String uuid) { - var beacon = "weaviate://localhost/"; - if (collection != null) { - beacon += collection + "/"; - } - beacon += uuid; - return beacon; - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Text2VecWeaviateVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Text2VecWeaviateVectorizer.java deleted file mode 100644 index db1f9a0f3..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/Text2VecWeaviateVectorizer.java +++ /dev/null @@ -1,72 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.HashMap; -import java.util.Map; -import java.util.function.Consumer; - -import com.google.gson.annotations.SerializedName; - -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class Text2VecWeaviateVectorizer extends Vectorizer { - @SerializedName("text2vec-weaviate") - private Map configuration; - - public static Text2VecWeaviateVectorizer of() { - return new Builder().build(); - } - - public static Text2VecWeaviateVectorizer of(Consumer fn) { - var builder = new Builder(); - fn.accept(builder); - return builder.build(); - } - - public static final String SNOWFLAKE_ARCTIC_EMBED_L_20 = "Snowflake/snowflake-arctic-embed-l-v2.0"; - public static final String SNOWFLAKE_ARCTIC_EMBED_M_15 = "Snowflake/snowflake-arctic-embed-m-v1.5"; - - public static class Builder { - private boolean vectorizeCollectionName = false; - private String baseUrl; - private Integer dimensions; - private String model; - - public Builder baseUrl(String baseUrl) { - this.baseUrl = baseUrl; - return this; - } - - public Builder dimensions(int dimensions) { - this.dimensions = dimensions; - return this; - } - - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder vectorizeCollectionName() { - this.vectorizeCollectionName = true; - return this; - } - - public Text2VecWeaviateVectorizer build() { - return new Text2VecWeaviateVectorizer(new HashMap<>() { - { - put("vectorizeClassName", vectorizeCollectionName); - if (baseUrl != null) { - put("baseURL", baseUrl); - } - if (dimensions != null) { - put("dimensions", dimensions); - } - if (model != null) { - put("model", model); - } - } - }); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java deleted file mode 100644 index b8aea7f0c..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java +++ /dev/null @@ -1,36 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.function.Consumer; - -import com.google.gson.annotations.SerializedName; - -public record VectorIndex( - @SerializedName("vectorIndexType") IndexType type, - @SerializedName("vectorizer") V vectorizer, - @SerializedName("vectorIndexConfig") IndexingStrategy configuration) { - - public enum IndexType { - @SerializedName("hnsw") - HNSW; - } - - public VectorIndex(IndexingStrategy index, V vectorizer) { - this(index.type(), vectorizer, index); - } - - public VectorIndex(V vectorizer) { - this(IndexingStrategy.hnsw(), vectorizer); - } - - public static sealed interface IndexingStrategy permits HNSW { - IndexType type(); - - public static IndexingStrategy hnsw() { - return new HNSW(); - } - - public static IndexingStrategy hnsw(Consumer options) { - return HNSW.with(options); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java deleted file mode 100644 index f2e07be5a..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java +++ /dev/null @@ -1,44 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.function.Consumer; - -// This class is WIP, I haven't decided how to structure it yet. -public abstract class Vectorizer { - public static NoneVectorizer none() { - return new NoneVectorizer(); - } - - public static ContextionaryVectorizer text2vecContextionary() { - return ContextionaryVectorizer.of(); - } - - public static ContextionaryVectorizer text2vecContextionary(Consumer fn) { - return ContextionaryVectorizer.of(fn); - } - - // TODO: add test cases - public static Text2VecWeaviateVectorizer text2vecWeaviate() { - return Text2VecWeaviateVectorizer.of(); - } - - public static Text2VecWeaviateVectorizer text2vecWeaviate(Consumer fn) { - return Text2VecWeaviateVectorizer.of(fn); - } - - // TODO: add test cases - public static Multi2VecClipVectorizer multi2vecClip() { - return Multi2VecClipVectorizer.of(); - } - - public static Multi2VecClipVectorizer multi2vecClip(Consumer fn) { - return Multi2VecClipVectorizer.of(fn); - } - - public static Img2VecNeuralVectorizer img2VecNeuralVectorizer() { - return Img2VecNeuralVectorizer.of(); - } - - public static Img2VecNeuralVectorizer img2VecNeuralVectorizer(Consumer fn) { - return Img2VecNeuralVectorizer.of(fn); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Vectors.java b/src/main/java/io/weaviate/client6/v1/collections/Vectors.java deleted file mode 100644 index 345442914..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/Vectors.java +++ /dev/null @@ -1,80 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import java.util.function.Consumer; - -public class Vectors { - public static final String DEFAULT = "default"; - - private final VectorIndex unnamedVector; - private final Map> namedVectors; - - public static Vectors unnamed(VectorIndex vector) { - return new Vectors(vector); - } - - public static Vectors of(String name, VectorIndex vector) { - return new Vectors(name, vector); - } - - public static Vectors of(VectorIndex vector) { - return new Vectors(DEFAULT, vector); - } - - public static Vectors of(Map> vectors) { - return new Vectors(vectors); - } - - public static Vectors with(Consumer named) { - var vectors = new NamedVectors(named); - return new Vectors(vectors.namedVectors); - } - - public VectorIndex get(String name) { - return namedVectors.get(name); - } - - public Optional> getUnnamed() { - return Optional.ofNullable(unnamedVector); - } - - public VectorIndex getDefault() { - return namedVectors.get(DEFAULT); - } - - // This needs to document the fact that this only returns named vectors. - // Rename to "getNamedVectors()" - public Map asMap() { - return Map.copyOf(namedVectors); - } - - Vectors(VectorIndex vector) { - this.unnamedVector = vector; - this.namedVectors = Map.of(); - } - - Vectors(String name, VectorIndex vector) { - this.unnamedVector = null; - this.namedVectors = Map.of(name, vector); - } - - Vectors(Map> vectors) { - this.unnamedVector = null; - this.namedVectors = vectors; - } - - public static class NamedVectors { - private final Map> namedVectors = new HashMap<>(); - - public NamedVectors vector(String name, VectorIndex vector) { - this.namedVectors.put(name, vector); - return this; - } - - NamedVectors(Consumer options) { - options.accept(this); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/ConsistencyLevel.java b/src/main/java/io/weaviate/client6/v1/collections/data/ConsistencyLevel.java deleted file mode 100644 index 3347b012c..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/data/ConsistencyLevel.java +++ /dev/null @@ -1,5 +0,0 @@ -package io.weaviate.client6.v1.collections.data; - -public enum ConsistencyLevel { - ONE, QUORUM, ALL -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java b/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java deleted file mode 100644 index 5580ba95d..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java +++ /dev/null @@ -1,238 +0,0 @@ -package io.weaviate.client6.v1.collections.data; - -import java.io.IOException; -import java.time.OffsetDateTime; -import java.util.Date; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Consumer; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.HttpClients; -import org.apache.hc.core5.http.ClassicHttpRequest; -import org.apache.hc.core5.http.ContentType; -import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.io.support.ClassicRequestBuilder; - -import com.google.gson.Gson; - -import io.weaviate.client6.Config; -import io.weaviate.client6.internal.GRPC; -import io.weaviate.client6.internal.GrpcClient; -import io.weaviate.client6.internal.HttpClient; -import io.weaviate.client6.v1.collections.object.ObjectMetadata; -import io.weaviate.client6.v1.collections.object.ObjectReference; -import io.weaviate.client6.v1.collections.object.Vectors; -import io.weaviate.client6.v1.collections.object.WeaviateObject; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase.Vectors.VectorType; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoProperties.Value; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet.MetadataResult; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet.PropertiesResult; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet.SearchRequest; -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class DataClient { - // TODO: inject singleton as dependency - private static final Gson gson = new Gson(); - - // TODO: this should be wrapped around in some TypeInspector etc. - private final String collectionName; - - // TODO: hide befind an internal HttpClient - private final Config config; - private final HttpClient httpClient; - private final GrpcClient grpcClient; - - public WeaviateObject insert(T properties) throws IOException { - return insert(properties, opt -> { - }); - } - - public WeaviateObject insert(T properties, Consumer> fn) throws IOException { - return insert(InsertObjectRequest.of(collectionName, properties, fn)); - } - - public WeaviateObject insert(InsertObjectRequest request) throws IOException { - ClassicHttpRequest httpPost = ClassicRequestBuilder - .post(config.baseUrl() + "/objects") - .setEntity(request.serialize(gson), ContentType.APPLICATION_JSON) - .build(); - - return httpClient.http.execute(httpPost, response -> { - var entity = response.getEntity(); - if (response.getCode() != HttpStatus.SC_SUCCESS) { // Does not return 201 - var message = EntityUtils.toString(entity); - throw new RuntimeException("HTTP " + response.getCode() + ": " + message); - } - - return WeaviateObject.fromJson(gson, entity.getContent()); - }); - } - - public Optional> get(String id) throws IOException { - return get(id, q -> { - }); - } - - public Optional> get(String id, Consumer fn) throws IOException { - return findById(FetchByIdRequest.of(collectionName, id, fn)); - } - - private Optional> findById(FetchByIdRequest request) { - var req = SearchRequest.newBuilder(); - req.setUses127Api(true); - req.setUses125Api(true); - req.setUses123Api(true); - request.appendTo(req); - var result = grpcClient.grpc.search(req.build()); - var objects = result.getResultsList().stream().map(r -> { - var tempObj = readPropertiesResult(r.getProperties()); - MetadataResult meta = r.getMetadata(); - Vectors vectors; - if (!meta.getVectorBytes().isEmpty()) { - vectors = Vectors.of(GRPC.fromByteString(meta.getVectorBytes())); - } else { - vectors = Vectors.of(meta.getVectorsList().stream().collect( - Collectors.toMap( - WeaviateProtoBase.Vectors::getName, - v -> { - if (v.getType().equals(VectorType.VECTOR_TYPE_SINGLE_FP32)) { - return GRPC.fromByteString(v.getVectorBytes()); - } else { - return GRPC.fromByteStringMulti(v.getVectorBytes()); - } - }))); - } - var metadata = new ObjectMetadata(meta.getId(), vectors); - return new WeaviateObject<>( - tempObj.collection(), - tempObj.properties(), - tempObj.references(), - metadata); - }).toList(); - if (objects.isEmpty()) { - return Optional.empty(); - } - return Optional.ofNullable((WeaviateObject) objects.get(0)); - } - - private static WeaviateObject readPropertiesResult(PropertiesResult res) { - var collection = res.getTargetCollection(); - var objectProperties = convertProtoMap(res.getNonRefProps().getFieldsMap()); - - // In case a reference is multi-target, there will be a separate - // "reference property" for each of the targets, so instead of - // `collect` we need to `reduce` the map, merging related references - // as we go. - // I.e. { "ref": A-1 } , { "ref": B-1 } => { "ref": [A-1, B-1] } - var referenceProperties = res.getRefPropsList().stream().reduce( - new HashMap(), - (map, ref) -> { - var refObjects = ref.getPropertiesList().stream() - .map(DataClient::readPropertiesResult) - .toList(); - - // Merge ObjectReferences by joining the underlying WeaviateObjects. - map.merge( - ref.getPropName(), - new ObjectReference((List>) refObjects), - (left, right) -> { - var joined = Stream.concat( - left.objects().stream(), - right.objects().stream()).toList(); - return new ObjectReference(joined); - }); - return map; - }, - (left, right) -> { - left.putAll(right); - return left; - }); - - MetadataResult meta = res.getMetadata(); - Vectors vectors; - if (meta.getVectorBytes() != null) { - vectors = Vectors.of(GRPC.fromByteString(meta.getVectorBytes())); - } else { - vectors = Vectors.of(meta.getVectorsList().stream().collect( - Collectors.toMap( - WeaviateProtoBase.Vectors::getName, - v -> { - if (v.getType().equals(VectorType.VECTOR_TYPE_MULTI_FP32)) { - return GRPC.fromByteString(v.getVectorBytes()); - } else { - return GRPC.fromByteStringMulti(v.getVectorBytes()); - } - }))); - } - var metadata = new ObjectMetadata(meta.getId(), vectors); - return new WeaviateObject<>(collection, objectProperties, referenceProperties, metadata); - } - - /* - * Convert Map to Map such that can be - * (de-)serialized by {@link Gson}. - */ - private static Map convertProtoMap(Map map) { - return map.entrySet().stream() - // We cannot use Collectors.toMap() here, because convertProtoValue may - // return null (a collection property can be null), which breaks toMap(). - // See: https://bugs.openjdk.org/browse/JDK-8148463 - .collect( - HashMap::new, - (m, e) -> m.put(e.getKey(), convertProtoValue(e.getValue())), - HashMap::putAll); - } - - /** - * Convert protobuf's Value stub to an Object by extracting the first available - * field. The checks are non-exhaustive and only cover text, boolean, and - * integer values. - */ - private static Object convertProtoValue(Value value) { - if (value.hasNullValue()) { - // return value.getNullValue(); - return null; - } else if (value.hasTextValue()) { - return value.getTextValue(); - } else if (value.hasBoolValue()) { - return value.getBoolValue(); - } else if (value.hasIntValue()) { - return value.getIntValue(); - } else if (value.hasNumberValue()) { - return value.getNumberValue(); - } else if (value.hasBlobValue()) { - return value.getBlobValue(); - } else if (value.hasDateValue()) - - { - OffsetDateTime offsetDateTime = OffsetDateTime.parse(value.getDateValue()); - return Date.from(offsetDateTime.toInstant()); - } else { - assert false : "branch not covered"; - } - return null; - } - - public void delete(String id) throws IOException { - try (CloseableHttpClient httpclient = HttpClients.createDefault()) { - ClassicHttpRequest httpGet = ClassicRequestBuilder - .delete(config.baseUrl() + "/objects/" + collectionName + "/" + id) - .build(); - - httpClient.http.execute(httpGet, response -> { - if (response.getCode() != HttpStatus.SC_NO_CONTENT) { - throw new RuntimeException(EntityUtils.toString(response.getEntity())); - } - return null; - }); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/FetchByIdRequest.java b/src/main/java/io/weaviate/client6/v1/collections/data/FetchByIdRequest.java deleted file mode 100644 index bc0c7819f..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/data/FetchByIdRequest.java +++ /dev/null @@ -1,112 +0,0 @@ -package io.weaviate.client6.v1.collections.data; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.function.Consumer; - -import io.weaviate.client6.v1.api.collections.query.QueryReference; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase.FilterTarget; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase.Filters; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase.Filters.Operator; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet.MetadataRequest; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet.PropertiesRequest; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet.RefPropertiesRequest; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet.SearchRequest; - -public record FetchByIdRequest( - String collection, - String id, - boolean includeVector, - List includeVectors, - List returnProperties, - List returnReferences) { - - public FetchByIdRequest(Builder options) { - this( - options.collection, - options.uuid, - options.includeVector, - options.includeVectors, - options.returnProperties, - options.returnReferences); - } - - public static FetchByIdRequest of(String collection, String uuid, Consumer fn) { - var builder = new Builder(collection, uuid); - fn.accept(builder); - return new FetchByIdRequest(builder); - } - - public static class Builder { - private final String collection; - private final String uuid; - - public Builder(String collection, String uuid) { - this.collection = collection; - this.uuid = uuid; - } - - private boolean includeVector; - private List includeVectors = new ArrayList<>(); - private List returnProperties = new ArrayList<>(); - private List returnReferences = new ArrayList<>(); - - public final Builder includeVector() { - this.includeVector = true; - return this; - } - - public final Builder includeVectors(String... vectors) { - this.includeVectors = Arrays.asList(vectors); - return this; - } - - public final Builder returnProperties(String... properties) { - this.returnProperties = Arrays.asList(properties); - return this; - } - - public final Builder returnReferences(QueryReference... references) { - this.returnReferences = Arrays.asList(references); - return this; - } - - } - - void appendTo(SearchRequest.Builder req) { - req.setLimit(1); - req.setCollection(collection); - - req.setFilters(Filters.newBuilder() - .setTarget(FilterTarget.newBuilder().setProperty("_id")) - .setValueText(id) - .setOperator(Operator.OPERATOR_EQUAL)); - - if (!returnProperties.isEmpty() || !returnReferences.isEmpty()) { - var properties = PropertiesRequest.newBuilder(); - - if (!returnProperties.isEmpty()) { - properties.addAllNonRefProperties(returnProperties); - } - - if (!returnReferences.isEmpty()) { - returnReferences.forEach(r -> { - var references = RefPropertiesRequest.newBuilder(); - r.appendTo(references); - properties.addRefProperties(references); - }); - } - req.setProperties(properties); - } - - // Always request UUID back in this request. - var metadata = MetadataRequest.newBuilder().setUuid(true); - if (includeVector) { - metadata.setVector(true); - } else if (!includeVectors.isEmpty()) { - metadata.addAllVectors(includeVectors); - } - req.setMetadata(metadata); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/GetParameters.java b/src/main/java/io/weaviate/client6/v1/collections/data/GetParameters.java deleted file mode 100644 index a0e4afd43..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/data/GetParameters.java +++ /dev/null @@ -1,76 +0,0 @@ -package io.weaviate.client6.v1.collections.data; - -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Set; -import java.util.function.Consumer; - -public class GetParameters implements QueryParameters { - private enum Include { - VECTOR, CLASSIFICATION, INTERPRETATION; - - String toLowerCase() { - return this.name().toLowerCase(); - } - } - - private Set include = new LinkedHashSet<>(); // Preserves insertion order, helps testing - private ConsistencyLevel consistency; - private String nodeName; - private String tenant; - - GetParameters(Consumer options) { - options.accept(this); - } - - public GetParameters withVector() { - include.add(Include.VECTOR); - return this; - } - - public GetParameters withClassification() { - include.add(Include.CLASSIFICATION); - return this; - } - - public GetParameters withInterpretation() { - include.add(Include.INTERPRETATION); - return this; - } - - public GetParameters consistencyLevel(ConsistencyLevel consistency) { - this.consistency = consistency; - return this; - } - - public GetParameters nodeName(String name) { - this.nodeName = name; - return this; - } - - public GetParameters tenant(String name) { - this.tenant = name; - return this; - } - - @Override - public String encode() { - var sb = new StringBuilder(); - - if (!include.isEmpty()) { - List includeString = include.stream().map(Include::toLowerCase).toList(); - QueryParameters.addRaw(sb, "include", String.join(",", includeString)); - } - - if (consistency != null) { - QueryParameters.add(sb, "consistency_level", consistency.name()); - } - if (nodeName != null) { - QueryParameters.add(sb, "node_name", nodeName); - } - if (tenant != null) { - QueryParameters.add(sb, "tenant", tenant); - } - return sb.toString(); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/collections/data/InsertObjectRequest.java deleted file mode 100644 index d2dd6d482..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/data/InsertObjectRequest.java +++ /dev/null @@ -1,150 +0,0 @@ -package io.weaviate.client6.v1.collections.data; - -import java.io.IOException; -import java.io.StringWriter; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; - -import com.google.gson.Gson; -import com.google.gson.reflect.TypeToken; - -import io.weaviate.client6.v1.api.collections.Vectors; -import io.weaviate.client6.v1.collections.Reference; - -public record InsertObjectRequest(String collection, T properties, String id, Vectors vectors, - Map> references) { - - /** Create InsertObjectRequest from Builder options. */ - public InsertObjectRequest(Builder builder) { - this(builder.collection, builder.properties, builder.id, builder.vectors, builder.references); - } - - /** - * Construct InsertObjectRequest with optional parameters. - * - * @param Shape of the object properties, e.g. - * {@code Map} - * @param collection Collection to insert to. - * @param properties Object properties. - * @param fn Optional parameters - * @return InsertObjectRequest - */ - static InsertObjectRequest of(String collection, T properties, Consumer> fn) { - var builder = new Builder<>(collection, properties); - fn.accept(builder); - return builder.build(); - } - - public static class Builder { - private final String collection; // Required - private final T properties; // Required - - private String id; - private Vectors vectors; - private final Map> references = new HashMap<>(); - - Builder(String collection, T properties) { - this.collection = collection; - this.properties = properties; - } - - /** Define custom object id. Must be a valid UUID. */ - public Builder id(String id) { - this.id = id; - return this; - } - - /** - * Supply one or more (named) vectors. Calls to {@link #vectors} are not - * chainable. Use {@link Vectors#of(Function)} to pass multiple vectors. - */ - public Builder vectors(Vectors vectors) { - this.vectors = vectors; - return this; - } - - /** - * Add a reference. Calls to {@link #reference} can be chained - * to add multiple references. - */ - public Builder reference(String property, Reference... references) { - for (var ref : references) { - addReference(property, ref); - } - return this; - } - - private void addReference(String property, Reference reference) { - if (!references.containsKey(property)) { - references.put(property, new ArrayList<>()); - } - references.get(property).add(reference); - } - - /** Build a new InsertObjectRequest. */ - public InsertObjectRequest build() { - return new InsertObjectRequest<>(this); - } - } - - // Here we're just rawdogging JSON serialization just to get a good feel for it. - public String serialize(Gson gson) throws IOException { - var buf = new StringWriter(); - var w = gson.newJsonWriter(buf); - - w.beginObject(); - - w.name("class"); - w.value(collection); - - if (id != null) { - w.name("id"); - w.value(id); - } - - if (vectors != null) { - var unnamed = vectors.getUnnamed(); - if (unnamed.isPresent()) { - w.name("vector"); - gson.getAdapter(Float[].class).write(w, unnamed.get()); - } else { - w.name("vectors"); - gson.getAdapter(new TypeToken>() { - }).write(w, vectors.getNamed()); - } - } - - if (properties != null || references != null) { - w.name("properties"); - w.beginObject(); - - if (properties != null) { - assert properties instanceof Map : "properties not a Map"; - for (var entry : ((Map) properties).entrySet()) { - w.name(entry.getKey()); - gson.getAdapter(Object.class).write(w, entry.getValue()); - } - - } - if (references != null && !references.isEmpty()) { - for (var entry : references.entrySet()) { - w.name(entry.getKey()); - w.beginArray(); - for (var ref : entry.getValue()) { - ref.writeValue(w); - } - w.endArray(); - } - } - - w.endObject(); - } - - w.endObject(); - return buf.toString(); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/QueryParameters.java b/src/main/java/io/weaviate/client6/v1/collections/data/QueryParameters.java deleted file mode 100644 index 6f7eda826..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/data/QueryParameters.java +++ /dev/null @@ -1,40 +0,0 @@ -package io.weaviate.client6.v1.collections.data; - -import java.io.UnsupportedEncodingException; -import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; -import java.util.function.Consumer; - -interface QueryParameters { - /* Implementations must return an empty string if there're no parameters. */ - String encode(); - - static String encodeGet(Consumer options) { - return with(new GetParameters(options)); - } - - private static

String with(P parameters) { - var encoded = parameters.encode(); - return encoded.isEmpty() ? "" : "?" + encoded; - } - - static void add(StringBuilder sb, String key, String value) { - addRaw(sb, encode(key), encode(value)); - } - - static void addRaw(StringBuilder sb, String key, String value) { - if (!sb.isEmpty()) { - sb.append("&"); - } - sb.append(key).append("=").append(value); - } - - static String encode(String value) { - try { - return URLEncoder.encode(value, StandardCharsets.UTF_8.name()); - } catch (UnsupportedEncodingException e) { - // Will never happen, as we are using standard encoding. - return value; - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java b/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java deleted file mode 100644 index 61ffcb9de..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java +++ /dev/null @@ -1,30 +0,0 @@ -package io.weaviate.client6.v1.collections.object; - -import java.util.function.Consumer; - -public record ObjectMetadata(String id, Vectors vectors) { - - public static ObjectMetadata with(Consumer options) { - var opt = new Builder(options); - return new ObjectMetadata(opt.id, opt.vectors); - } - - public static class Builder { - private String id; - private Vectors vectors; - - public Builder id(String id) { - this.id = id; - return this; - } - - public Builder vectors(Vectors vectors) { - this.vectors = vectors; - return this; - } - - private Builder(Consumer options) { - options.accept(this); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/object/ObjectReference.java b/src/main/java/io/weaviate/client6/v1/collections/object/ObjectReference.java deleted file mode 100644 index bc5c82c04..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/object/ObjectReference.java +++ /dev/null @@ -1,6 +0,0 @@ -package io.weaviate.client6.v1.collections.object; - -import java.util.List; - -public record ObjectReference(List> objects) { -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/object/Vectors.java b/src/main/java/io/weaviate/client6/v1/collections/object/Vectors.java deleted file mode 100644 index db0669e90..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/object/Vectors.java +++ /dev/null @@ -1,134 +0,0 @@ -package io.weaviate.client6.v1.collections.object; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import java.util.function.Consumer; - -import lombok.ToString; - -/** - * Vectors is an abstraction over named vectors. - * It may contain both 1-dimensional and 2-dimensional vectors. - */ -@ToString -public class Vectors { - private static final String DEFAULT = "default"; - - private Float[] unnamedVector; - private Map namedVectors; - - public Float[] getSingle(String name) { - return (Float[]) namedVectors.get(name); - } - - public Float[] getDefaultSingle() { - return getSingle(DEFAULT); - } - - @SuppressWarnings("unchecked") - public Optional getSingle() { - return (Optional) getOnly(); - } - - public Float[][] getMulti(String name) { - return (Float[][]) namedVectors.get(name); - } - - public Float[][] getDefaultMulti() { - return getMulti(DEFAULT); - } - - @SuppressWarnings("unchecked") - public Optional getMulti() { - return (Optional) getOnly(); - } - - public Optional getUnnamed() { - return Optional.ofNullable(unnamedVector); - } - - private Optional getOnly() { - if (namedVectors == null || namedVectors.isEmpty() || namedVectors.size() > 1) { - return Optional.empty(); - } - return Optional.ofNullable(namedVectors.values().iterator().next()); - } - - public Map getNamed() { - return Map.copyOf(namedVectors); - } - - /** Creates Vectors with a single unnamed vector. */ - private Vectors(Float[] vector) { - this(Map.of()); - this.unnamedVector = vector; - } - - /** Creates Vectors with one named vector. */ - private Vectors(String name, Object vector) { - this.namedVectors = Map.of(name, vector); - } - - /** Creates immutable set of vectors. */ - private Vectors(Map vectors) { - this.namedVectors = Collections.unmodifiableMap(vectors); - } - - private Vectors(NamedVectors named) { - this.namedVectors = named.namedVectors; - } - - /** - * Pass legacy unnamed vector. - * Multi-vectors can only be passed as named vectors. - */ - public static Vectors unnamed(Float[] vector) { - return new Vectors(vector); - } - - public static Vectors of(Float[] vector) { - return new Vectors(DEFAULT, vector); - } - - public static Vectors of(Float[][] vector) { - return new Vectors(DEFAULT, vector); - } - - public static Vectors of(String name, Float[] vector) { - return new Vectors(name, vector); - } - - public static Vectors of(String name, Float[][] vector) { - return new Vectors(name, vector); - } - - public static Vectors of(Map vectors) { - return new Vectors(vectors); - } - - public static Vectors of(Consumer fn) { - var named = new NamedVectors(); - fn.accept(named); - return named.build(); - } - - public static class NamedVectors { - private Map namedVectors = new HashMap<>(); - - public NamedVectors vector(String name, Float[] vector) { - this.namedVectors.put(name, vector); - return this; - } - - public NamedVectors vector(String name, Float[][] vector) { - this.namedVectors.put(name, vector); - return this; - } - - public Vectors build() { - return new Vectors(this); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObject.java b/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObject.java deleted file mode 100644 index 28d5cc3b2..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObject.java +++ /dev/null @@ -1,35 +0,0 @@ -package io.weaviate.client6.v1.collections.object; - -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.util.Map; -import java.util.function.Consumer; - -import com.google.common.reflect.TypeToken; -import com.google.gson.Gson; - -public record WeaviateObject( - String collection, - T properties, - Map references, - ObjectMetadata metadata) { - - public WeaviateObject(String collection, T properties, Map references, - Consumer options) { - this(collection, properties, references, ObjectMetadata.with(options)); - } - - // JSON serialization ---------------- - public static WeaviateObject fromJson(Gson gson, InputStream input) throws IOException { - try (var r = new InputStreamReader(input)) { - WeaviateObjectDTO dto = gson.fromJson(r, new TypeToken>() { - }.getType()); - return dto.toWeaviateObject(); - } - } - - public String toJson(Gson gson) { - return gson.toJson(new WeaviateObjectDTO<>(this)); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObjectDTO.java b/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObjectDTO.java deleted file mode 100644 index e57afc1b5..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObjectDTO.java +++ /dev/null @@ -1,49 +0,0 @@ -package io.weaviate.client6.v1.collections.object; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; - -import com.google.gson.annotations.SerializedName; - -class WeaviateObjectDTO { - @SerializedName("class") - String collection; - @SerializedName("id") - String id; - @SerializedName("properties") - T properties; - @SerializedName("vectors") - Map vectors; - - WeaviateObjectDTO(WeaviateObject object) { - this.collection = object.collection(); - this.properties = object.properties(); - - if (object.metadata() != null) { - this.id = object.metadata().id(); - if (object.metadata().vectors() != null) { - this.vectors = object.metadata().vectors().getNamed(); - } - } - } - - WeaviateObject toWeaviateObject() { - Map arrayVectors = new HashMap<>(); - if (vectors != null) { - for (var entry : vectors.entrySet()) { - var value = (ArrayList) entry.getValue(); - var vector = new Float[value.size()]; - int i = 0; - for (var v : value) { - vector[i++] = v.floatValue(); - } - arrayVectors.put(entry.getKey(), vector); - } - } - - return new WeaviateObject(collection, properties, - /* no references through HTTP */ new HashMap<>(), - new ObjectMetadata(id, Vectors.of(arrayVectors))); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java index 442947a97..f071c9005 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/DefaultGrpcTransport.java @@ -82,7 +82,7 @@ private static ManagedChannel buildChannel(GrpcChannelOptions options) { } var headers = new Metadata(); - for (final var header : options.headers()) { + for (final var header : options.headers().entrySet()) { var key = Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER); headers.put(key, header.getValue()); diff --git a/src/main/java/io/weaviate/client6/internal/GRPC.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GRPC.java similarity index 98% rename from src/main/java/io/weaviate/client6/internal/GRPC.java rename to src/main/java/io/weaviate/client6/v1/internal/grpc/GRPC.java index bd9bdd6c1..e5a6c0b5a 100644 --- a/src/main/java/io/weaviate/client6/internal/GRPC.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GRPC.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.internal; +package io.weaviate.client6.v1.internal.grpc; import java.nio.ByteBuffer; import java.nio.ByteOrder; diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java index d6fed091a..517345844 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java +++ b/src/main/java/io/weaviate/client6/v1/internal/grpc/GrpcChannelOptions.java @@ -1,12 +1,15 @@ package io.weaviate.client6.v1.internal.grpc; -import java.util.Collection; +import java.util.Collections; import java.util.Map; +// TODO: unify with rest.TransportOptions? public interface GrpcChannelOptions { String host(); - Collection> headers(); + default Map headers() { + return Collections.emptyMap(); + } boolean useTls(); } diff --git a/src/main/java/io/weaviate/client6/v1/internal/grpc/WeaviateObjectTypeAdapter.java b/src/main/java/io/weaviate/client6/v1/internal/grpc/WeaviateObjectTypeAdapter.java new file mode 100644 index 000000000..e69de29bb diff --git a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java new file mode 100644 index 000000000..91ae742ce --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java @@ -0,0 +1,54 @@ +package io.weaviate.client6.v1.internal.json; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.reflect.TypeToken; + +public final class JSON { + private static final Gson gson; + + static { + var gsonBuilder = new GsonBuilder(); + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.api.collections.WeaviateObject.CustomTypeAdapterFactory.INSTANCE); + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.api.collections.WeaviateCollection.CustomTypeAdapterFactory.INSTANCE); + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.api.collections.Vectors.CustomTypeAdapterFactory.INSTANCE); + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.api.collections.Vectorizer.CustomTypeAdapterFactory.INSTANCE); + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.api.collections.VectorIndex.CustomTypeAdapterFactory.INSTANCE); + + gsonBuilder.registerTypeAdapter( + io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer.class, + io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer.TYPE_ADAPTER); + gsonBuilder.registerTypeAdapter( + io.weaviate.client6.v1.api.collections.data.Reference.class, + io.weaviate.client6.v1.api.collections.data.Reference.TYPE_ADAPTER); + gson = gsonBuilder.create(); + } + + public static final Gson getGson() { + return gson; + } + + public static final String serialize(Object value) { + if (value == null) { + return null; + } + return serialize(value, TypeToken.get(value.getClass())); + } + + public static final String serialize(Object value, TypeToken typeToken) { + return gson.toJson(value, typeToken.getType()); + } + + public static final T deserialize(String json, Class cls) { + return gson.fromJson(json, cls); + } + + public static final T deserialize(String json, TypeToken token) { + return gson.fromJson(json, token); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/json/JsonEnum.java b/src/main/java/io/weaviate/client6/v1/internal/json/JsonEnum.java new file mode 100644 index 000000000..c68a64892 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/json/JsonEnum.java @@ -0,0 +1,26 @@ +package io.weaviate.client6.v1.internal.json; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public interface JsonEnum> { + String jsonValue(); + + static > Map collectNames(JsonEnum[] values) { + final var jsonValueMap = new HashMap(values.length); + for (var value : values) { + @SuppressWarnings("unchecked") + var enumInstance = (E) value; + jsonValueMap.put(value.jsonValue(), enumInstance); + } + return Collections.unmodifiableMap(jsonValueMap); + } + + static > E valueOfJson(String jsonValue, Map enums, Class cls) { + if (!enums.containsKey(jsonValue)) { + throw new IllegalArgumentException("%s does not have a member with jsonValue=%s".formatted(cls, jsonValue)); + } + return enums.get(jsonValue); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/orm/CollectionDescriptor.java b/src/main/java/io/weaviate/client6/v1/internal/orm/CollectionDescriptor.java index 9ab299a16..a122c7214 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/orm/CollectionDescriptor.java +++ b/src/main/java/io/weaviate/client6/v1/internal/orm/CollectionDescriptor.java @@ -1,9 +1,19 @@ package io.weaviate.client6.v1.internal.orm; -public interface CollectionDescriptor { +import java.util.Map; + +import com.google.gson.reflect.TypeToken; + +public sealed interface CollectionDescriptor permits MapDescriptor { String name(); + TypeToken typeToken(); + PropertiesReader propertiesReader(T properties); PropertiesBuilder propertiesBuilder(); + + static CollectionDescriptor> ofMap(String collectionName) { + return new MapDescriptor(collectionName); + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/orm/MapDescriptor.java b/src/main/java/io/weaviate/client6/v1/internal/orm/MapDescriptor.java index f8c9477ee..2910f2db4 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/orm/MapDescriptor.java +++ b/src/main/java/io/weaviate/client6/v1/internal/orm/MapDescriptor.java @@ -2,7 +2,9 @@ import java.util.Map; -public class MapDescriptor implements CollectionDescriptor> { +import com.google.gson.reflect.TypeToken; + +public final class MapDescriptor implements CollectionDescriptor> { private final String collectionName; public MapDescriptor(String collectionName) { @@ -23,4 +25,10 @@ public PropertiesReader> propertiesReader(Map> propertiesBuilder() { return new MapBuilder(); } + + @SuppressWarnings("unchecked") + @Override + public TypeToken> typeToken() { + return (TypeToken>) TypeToken.getParameterized(Map.class, String.class, Object.class); + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java new file mode 100644 index 000000000..470df5e89 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/DefaultRestTransport.java @@ -0,0 +1,108 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; + +import org.apache.hc.client5.http.async.methods.SimpleHttpRequest; +import org.apache.hc.client5.http.async.methods.SimpleHttpResponse; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.client5.http.impl.async.HttpAsyncClients; +import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; +import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.core5.concurrent.FutureCallback; +import org.apache.hc.core5.http.ClassicHttpRequest; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.support.ClassicRequestBuilder; +import org.apache.hc.core5.io.CloseMode; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; + +public class DefaultRestTransport implements RestTransport { + private final CloseableHttpClient httpClient; + private final CloseableHttpAsyncClient httpClientAsync; + private final TransportOptions transportOptions; + + private static final Gson gson = new GsonBuilder().create(); + + public DefaultRestTransport(TransportOptions options) { + this.transportOptions = options; + this.httpClient = HttpClients.createDefault(); + this.httpClientAsync = HttpAsyncClients.createDefault(); + httpClientAsync.start(); + } + + @Override + public ResponseT performRequest(RequestT request, Endpoint endpoint) + throws IOException { + var req = prepareClassicRequest(request, endpoint); + // FIXME: we need to differentiate between "no body" and "soumething's wrong" + return this.httpClient.execute(req, + response -> response.getEntity() != null + ? endpoint.deserializeResponse(gson, EntityUtils.toString(response.getEntity())) + : null); + } + + @Override + public CompletableFuture performRequestAsync(RequestT request, + Endpoint endpoint) { + var req = prepareSimpleRequest(request, endpoint); + + var completable = new CompletableFuture(); + this.httpClientAsync.execute(req, new FutureCallback<>() { + + @Override + public void completed(SimpleHttpResponse result) { + completable.complete(result); + } + + @Override + public void failed(Exception ex) { + completable.completeExceptionally(ex); + } + + @Override + public void cancelled() { + completable.cancel(false); + } + + }); + // FIXME: we need to differentiate between "no body" and "soumething's wrong" + return completable.thenApply(r -> r.getBody() == null + ? endpoint.deserializeResponse(gson, r.getBody().getBodyText()) + : null); + } + + private SimpleHttpRequest prepareSimpleRequest(RequestT request, Endpoint endpoint) { + var method = endpoint.method(request); + var uri = transportOptions.host() + endpoint.requestUrl(request); + // TODO: apply options; + + var body = endpoint.body(gson, request); + var req = SimpleHttpRequest.create(method, uri); + if (body != null) { + req.setBody(body.getBytes(), ContentType.APPLICATION_JSON); + } + return req; + } + + private ClassicHttpRequest prepareClassicRequest(RequestT request, Endpoint endpoint) { + var method = endpoint.method(request); + var uri = transportOptions.host() + endpoint.requestUrl(request); + + // TODO: apply options; + var req = ClassicRequestBuilder.create(method).setUri(uri); + var body = endpoint.body(gson, request); + if (body != null) { + req.setEntity(body, ContentType.APPLICATION_JSON); + } + return req.build(); + } + + @Override + public void close() throws IOException { + httpClient.close(); + httpClientAsync.close(CloseMode.GRACEFUL); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java b/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java new file mode 100644 index 000000000..7c8998a61 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/Endpoint.java @@ -0,0 +1,65 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; + +import com.google.gson.Gson; + +public interface Endpoint { + + String method(RequestT request); + + String requestUrl(RequestT request); + + // Gson is leaking. + String body(Gson gson, RequestT request); + + Map queryParameters(RequestT request); + + /** Should this status code be considered an error? */ + boolean isError(int code); + + ResponseT deserializeResponse(Gson gson, String response); + + public static Endpoint of( + Function method, + Function requestUrl, + BiFunction body, + Function> queryParameters, + Function isError, + BiFunction deserialize) { + return new Endpoint() { + + @Override + public String method(RequestT request) { + return method.apply(request); + } + + @Override + public String requestUrl(RequestT request) { + return requestUrl.apply(request); + } + + @Override + public String body(Gson gson, RequestT request) { + return body.apply(gson, request); + } + + @Override + public Map queryParameters(RequestT request) { + return queryParameters.apply(request); + } + + @Override + public ResponseT deserializeResponse(Gson gson, String response) { + return deserialize.apply(gson, response); + } + + @Override + public boolean isError(int code) { + return isError.apply(code); + } + }; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java new file mode 100644 index 000000000..b20c98fbd --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/RestTransport.java @@ -0,0 +1,13 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.io.Closeable; +import java.io.IOException; +import java.util.concurrent.CompletableFuture; + +public interface RestTransport extends Closeable { + ResponseT performRequest(RequestT request, Endpoint endpoint) + throws IOException; + + CompletableFuture performRequestAsync(RequestT request, + Endpoint endpoint); +} diff --git a/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java b/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java new file mode 100644 index 000000000..9ddb3fa70 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/rest/TransportOptions.java @@ -0,0 +1,12 @@ +package io.weaviate.client6.v1.internal.rest; + +import java.util.Collections; +import java.util.Map; + +public interface TransportOptions { + String host(); + + default Map headers() { + return Collections.emptyMap(); + } +} diff --git a/src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java b/src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java deleted file mode 100644 index f3ca920db..000000000 --- a/src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java +++ /dev/null @@ -1,76 +0,0 @@ -package io.weaviate.client6.internal; - -import org.assertj.core.api.Assertions; -import org.junit.Test; -import org.junit.runner.RunWith; - -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; -import com.google.gson.JsonParser; -import com.jparams.junit4.JParamsTestRunner; -import com.jparams.junit4.data.DataMethod; - -@RunWith(JParamsTestRunner.class) -public class DtoTypeAdapterFactoryTest { - /** Person should be serialized to PersonDto. */ - record Person(String name) { - } - - record PersonDto(String nickname) implements DtoTypeAdapterFactory.Dto { - PersonDto(Person p) { - this(p.name); - } - - @Override - public Person toModel() { - return new Person(nickname); - } - } - - /** Car's DTO is a nested record. */ - record Car(String brand) { - record CarDto(String manufacturer, Integer version) implements DtoTypeAdapterFactory.Dto { - CarDto(Car c) { - this(c.brand, 1); - } - - @Override - public Car toModel() { - return new Car(manufacturer); - } - } - } - - /** Normal does not have a DTO and should be serialized as usual. */ - record Normal(String key, String value) { - } - - static { - DtoTypeAdapterFactory.register(Person.class, PersonDto.class, m -> new PersonDto(m)); - DtoTypeAdapterFactory.register(Car.class, Car.CarDto.class, m -> new Car.CarDto(m)); - } - - private static final Gson gson = new GsonBuilder() - .registerTypeAdapterFactory(new DtoTypeAdapterFactory()) - .create(); - - public static Object[][] testCases() { - return new Object[][] { - { new Person("Josh"), "{\"nickname\": \"Josh\"}" }, - { new Car("Porsche"), "{\"manufacturer\": \"Porsche\", \"version\": 1}" }, - { new Normal("foo", "bar"), "{\"key\": \"foo\", \"value\": \"bar\"}" }, - }; - } - - @Test - @DataMethod(source = DtoTypeAdapterFactoryTest.class, method = "testCases") - public void testRoundtrip(Object model, String wantJson) { - var gotJson = gson.toJson(model); - Assertions.assertThat(JsonParser.parseString(gotJson)) - .as("serialized") - .isEqualTo(JsonParser.parseString(wantJson)); - - var deserialized = gson.fromJson(gotJson, model.getClass()); - Assertions.assertThat(deserialized).as("deserialized").isEqualTo(model); - } -} diff --git a/src/test/java/io/weaviate/client6/v1/ObjectMetadataTest.java b/src/test/java/io/weaviate/client6/v1/ObjectMetadataTest.java deleted file mode 100644 index f95446fdc..000000000 --- a/src/test/java/io/weaviate/client6/v1/ObjectMetadataTest.java +++ /dev/null @@ -1,96 +0,0 @@ -package io.weaviate.client6.v1; - -import java.util.Optional; - -import org.assertj.core.api.Assertions; -import org.junit.Test; - -import io.weaviate.client6.v1.collections.object.ObjectMetadata; -import io.weaviate.client6.v1.collections.object.Vectors; - -public class ObjectMetadataTest { - - @Test - public final void testMetadata_id() { - var metadata = ObjectMetadata.with(m -> m.id("object-1")); - Assertions.assertThat(metadata.id()) - .as("object id").isEqualTo("object-1"); - } - - @Test - public final void testVectorsMetadata_unnamed() { - Float[] vector = { 1f, 2f, 3f }; - var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.unnamed(vector))); - - Assertions.assertThat(metadata.vectors()) - .as("unnamed vector").isNotNull() - .returns(Optional.of(vector), Vectors::getUnnamed) - .returns(Optional.empty(), Vectors::getSingle); - } - - @Test - public final void testVectorsMetadata_default() { - Float[] vector = { 1f, 2f, 3f }; - var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.of(vector))); - - Assertions.assertThat(metadata.vectors()) - .as("default vector").isNotNull() - .returns(vector, Vectors::getDefaultSingle) - .returns(Optional.of(vector), Vectors::getSingle) - .returns(Optional.empty(), Vectors::getUnnamed); - } - - @Test - public final void testVectorsMetadata_default_2d() { - Float[][] vector = { { 1f, 2f, 3f }, { 1f, 2f, 3f } }; - var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.of(vector))); - - Assertions.assertThat(metadata.vectors()) - .as("default 2d vector").isNotNull() - .returns(vector, Vectors::getDefaultMulti) - .returns(Optional.of(vector), Vectors::getMulti) - .returns(Optional.empty(), Vectors::getUnnamed); - } - - @Test - public final void testVectorsMetadata_named() { - Float[] vector = { 1f, 2f, 3f }; - var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.of("vector-1", vector))); - - Assertions.assertThat(metadata.vectors()) - .as("named vector").isNotNull() - .returns(vector, v -> v.getSingle("vector-1")) - .returns(Optional.of(vector), Vectors::getSingle) - .returns(null, Vectors::getDefaultSingle); - } - - @Test - public final void testVectorsMetadata_named_2d() { - Float[][] vector = { { 1f, 2f, 3f }, { 1f, 2f, 3f } }; - var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.of("vector-1", vector))); - - Assertions.assertThat(metadata.vectors()) - .as("named 2d vector").isNotNull() - .returns(vector, v -> v.getMulti("vector-1")) - .returns(Optional.of(vector), Vectors::getMulti) - .returns(null, Vectors::getDefaultMulti); - } - - @Test - public final void testVectorsMetadata_multiple_named() { - Float[][] vector_1 = { { 1f, 2f, 3f }, { 1f, 2f, 3f } }; - Float[] vector_2 = { 4f, 5f, 6f }; - var metadata = ObjectMetadata.with(m -> m.vectors( - Vectors.of(named -> named - .vector("vector-1", vector_1) - .vector("vector-2", vector_2)))); - - Assertions.assertThat(metadata.vectors()) - .as("multiple named vectors").isNotNull() - .returns(vector_1, v -> v.getMulti("vector-1")) - .returns(vector_2, v -> v.getSingle("vector-2")) - .returns(Optional.empty(), Vectors::getMulti) - .returns(Optional.empty(), Vectors::getSingle) - .returns(null, Vectors::getDefaultMulti); - } -} diff --git a/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java b/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java deleted file mode 100644 index 949a8dbb4..000000000 --- a/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java +++ /dev/null @@ -1,135 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.io.IOException; -import java.util.List; -import java.util.Map; - -import org.assertj.core.api.Assertions; -import org.junit.Test; -import org.junit.runner.RunWith; - -import com.google.gson.Gson; -import com.google.gson.GsonBuilder; -import com.google.gson.JsonParser; -import com.google.gson.TypeAdapter; -import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonWriter; -import com.jparams.junit4.JParamsTestRunner; -import com.jparams.junit4.data.DataMethod; - -import io.weaviate.client6.internal.DtoTypeAdapterFactory; -import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; - -@RunWith(JParamsTestRunner.class) -public class VectorsTest { - // private static final Gson gson = new Gson(); - - static { - DtoTypeAdapterFactory.register(Collection.class, CollectionDefinitionDTO.class, - m -> new CollectionDefinitionDTO(m)); - } - private static final Gson gson = new GsonBuilder() - .registerTypeAdapterFactory(new DtoTypeAdapterFactory()) - // TODO: create TypeAdapters via TypeAdapterFactory - .registerTypeAdapter(Vectors.class, new TypeAdapter() { - Gson gson = new Gson(); - - @Override - public void write(JsonWriter out, Vectors value) throws IOException { - gson.toJson(value.asMap(), Map.class, out); - } - - @Override - public Vectors read(JsonReader in) throws IOException { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'read'"); - } - - }) - .create(); - - public static Object[][] testCases() { - return new Object[][] { - { - """ - { - "vectorConfig": { - "default": { - "vectorizer": { "none": {}}, - "vectorIndexType": "hnsw", - "vectorIndexConfig": {} - } - } - } - """, - collectionWithVectors(Vectors.of(new VectorIndex<>(Vectorizer.none()))), - new String[] { "vectorConfig" }, - }, - { - """ - { - "vectorConfig": { - "vector-1": { - "vectorizer": { "none": {}}, - "vectorIndexType": "hnsw", - "vectorIndexConfig": {} - }, - "vector-2": { - "vectorizer": { "none": {}}, - "vectorIndexType": "hnsw", - "vectorIndexConfig": {} - } - } - } - """, - collectionWithVectors(Vectors.with(named -> named - .vector("vector-1", new VectorIndex<>(Vectorizer.none())) - .vector("vector-2", new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none())))), - new String[] { "vectorConfig" }, - }, - { - """ - { - "vectorizer": { "none": {}}, - "vectorIndexType": "hnsw", - "vectorIndexConfig": { "distance": "COSINE", "skip": true } - } - """, - collectionWithVectors(Vectors.unnamed( - new VectorIndex<>( - IndexingStrategy.hnsw(opt -> opt - .distance(HNSW.Distance.COSINE) - .disableIndexation()), - Vectorizer.none()))), - new String[] { "vectorIndexType", "vectorIndexConfig", "vectorizer" }, - }, - }; - } - - @Test - @DataMethod(source = VectorsTest.class, method = "testCases") - public void test_toJson(String want, Collection collection, String... compareKeys) { - var got = gson.toJson(collection); - assertEqual(want, got, compareKeys); - } - - private static Collection collectionWithVectors(Vectors vectors) { - return new Collection("Things", List.of(), List.of(), vectors); - } - - private void assertEqual(String wantJson, String gotJson, String... compareKeys) { - var want = JsonParser.parseString(wantJson).getAsJsonObject(); - var got = JsonParser.parseString(gotJson).getAsJsonObject(); - - if (compareKeys == null || compareKeys.length == 0) { - Assertions.assertThat(got).isEqualTo(want); - return; - } - - for (var key : compareKeys) { - Assertions.assertThat(got.get(key)) - .isEqualTo(want.get(key)) - .as(key); - } - } -} diff --git a/src/test/java/io/weaviate/client6/v1/collections/data/QueryParametersTest.java b/src/test/java/io/weaviate/client6/v1/collections/data/QueryParametersTest.java deleted file mode 100644 index 275c8a9b7..000000000 --- a/src/test/java/io/weaviate/client6/v1/collections/data/QueryParametersTest.java +++ /dev/null @@ -1,48 +0,0 @@ -package io.weaviate.client6.v1.collections.data; - -import org.assertj.core.api.Assertions; -import org.junit.Test; -import org.junit.runner.RunWith; - -import com.jparams.junit4.JParamsTestRunner; -import com.jparams.junit4.data.DataMethod; - -@RunWith(JParamsTestRunner.class) -public class QueryParametersTest { - - public static Object[][] testCases() { - return new Object[][] { - { - QueryParameters.encodeGet(q -> q - .withVector() - .nodeName("node-1")), - "?include=vector&node_name=node-1", - }, - { - QueryParameters.encodeGet(q -> q - .withVector() - .withClassification() - .tenant("JohnDoe")), - "?include=vector,classification&tenant=JohnDoe", - }, - { - QueryParameters.encodeGet(q -> q - .consistencyLevel(ConsistencyLevel.ALL) - .nodeName("node-1") - .tenant("JohnDoe")), - "?consistency_level=ALL&node_name=node-1&tenant=JohnDoe", - }, - { - QueryParameters.encodeGet(q -> { - }), - "", - }, - }; - } - - @Test - @DataMethod(source = QueryParametersTest.class, method = "testCases") - public void testEncode(String got, String want) { - Assertions.assertThat(got).isEqualTo(want).as("expected query parameters"); - } -} diff --git a/src/test/java/io/weaviate/client6/internal/GRPCTest.java b/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java similarity index 97% rename from src/test/java/io/weaviate/client6/internal/GRPCTest.java rename to src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java index ab14b6aaa..d18f4a00e 100644 --- a/src/test/java/io/weaviate/client6/internal/GRPCTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/grpc/GRPCTest.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.internal; +package io.weaviate.client6.v1.internal.grpc; import static org.junit.Assert.assertArrayEquals; diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java new file mode 100644 index 000000000..0be90cba7 --- /dev/null +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -0,0 +1,309 @@ +package io.weaviate.client6.v1.internal.json; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import org.assertj.core.api.Assertions; +import org.junit.Test; +import org.junit.runner.RunWith; + +import com.google.gson.JsonParser; +import com.google.gson.reflect.TypeToken; +import com.jparams.junit4.JParamsTestRunner; +import com.jparams.junit4.data.DataMethod; + +import io.weaviate.client6.v1.api.collections.ObjectMetadata; +import io.weaviate.client6.v1.api.collections.Property; +import io.weaviate.client6.v1.api.collections.VectorIndex; +import io.weaviate.client6.v1.api.collections.Vectorizer; +import io.weaviate.client6.v1.api.collections.Vectors; +import io.weaviate.client6.v1.api.collections.WeaviateCollection; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.data.Reference; +import io.weaviate.client6.v1.api.collections.vectorindex.Distance; +import io.weaviate.client6.v1.api.collections.vectorindex.Flat; +import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; +import io.weaviate.client6.v1.api.collections.vectorizers.Img2VecNeuralVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Multi2VecClipVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.NoneVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecContextionaryVectorizer; +import io.weaviate.client6.v1.api.collections.vectorizers.Text2VecWeaviateVectorizer; + +/** Unit tests for custom POJO-to-JSON serialization. */ +@RunWith(JParamsTestRunner.class) +public class JSONTest { + public static Object[][] testCases() { + return new Object[][] { + // Vectorizer.CustomTypeAdapterFactory + { + Vectorizer.class, + new NoneVectorizer(), + "{\"none\": {}}", + }, + { + Vectorizer.class, + Img2VecNeuralVectorizer.of(i2v -> i2v.imageFields("jpeg", "png")), + """ + {"img2vec-neural": { + "imageFields": ["jpeg", "png"] + }} + """, + }, + { + Vectorizer.class, + Multi2VecClipVectorizer.of(m2v -> m2v + .inferenceUrl("http://example.com") + .imageField("img", 1f) + .textField("txt", 2f) + .vectorizeCollectionName(true)), + """ + {"multi2vec-clip": { + "inferenceUrl": "http://example.com", + "vectorizeClassName": true, + "imageFields": ["img"], + "textFields": ["txt"], + "weights": { + "imageWeights": [1.0], + "textWeights": [2.0] + } + }} + """, + }, + { + Vectorizer.class, + Text2VecContextionaryVectorizer.of(t2v -> t2v + .vectorizeCollectionName(true)), + """ + {"text2vec-contextionary": { + "vectorizeClassName": true + }} + """, + }, + { + Vectorizer.class, + Text2VecWeaviateVectorizer.of(t2v -> t2v + .inferenceUrl("http://example.com") + .dimensions(4) + .model("very-good-model") + .vectorizeCollectionName(true)), + """ + {"text2vec-weaviate": { + "baseUrl": "http://example.com", + "vectorizeClassName": true, + "dimensions": 4, + "model": "very-good-model" + }} + """, + }, + + // VectorIndex.CustomTypeAdapterFactory + { + VectorIndex.class, + Flat.of(new NoneVectorizer(), flat -> flat + .vectorCacheMaxObjects(100)), + """ + { + "vectorIndexType": "flat", + "vectorizer": {"none": {}}, + "vectorIndexConfig": {"vectorCacheMaxObjects": 100} + } + """, + }, + { + VectorIndex.class, + Hnsw.of(new NoneVectorizer(), hnsw -> hnsw + .distance(Distance.DOT) + .ef(1) + .efConstruction(2) + .maxConnections(3) + .vectorCacheMaxObjects(4) + .cleanupIntervalSeconds(5) + .dynamicEfMin(6) + .dynamicEfMax(7) + .dynamicEfFactor(8) + .flatSearchCutoff(9) + .skipVectorization(true) + .filterStrategy(Hnsw.FilterStrategy.ACORN)), + """ + { + "vectorIndexType": "hnsw", + "vectorizer": {"none": {}}, + "vectorIndexConfig": { + "distance": "dot", + "ef": 1, + "efConstruction": 2, + "maxConnections": 3, + "vectorCacheMaxObjects": 4, + "cleanupIntervalSeconds": 5, + "dynamicEfMin": 6, + "dynamicEfMax": 7, + "dynamicEfFactor": 8, + "flatSearchCutoff": 9, + "skip": true, + "filterStrategy":"acorn" + } + } + """, + }, + + // Vectors.CustomTypeAdapterFactory + { + Vectors.class, + Vectors.of(new Float[] { 1f, 2f }), + "{\"default\": [1.0, 2.0]}", + (CustomAssert) JSONTest::compareVectors, + }, + { + Vectors.class, + Vectors.of(new Float[][] { { 1f, 2f }, { 3f, 4f } }), + "{\"default\": [[1.0, 2.0], [3.0, 4.0]]}", + (CustomAssert) JSONTest::compareVectors, + }, + { + Vectors.class, + Vectors.of("custom", new Float[] { 1f, 2f }), + "{\"custom\": [1.0, 2.0]}", + (CustomAssert) JSONTest::compareVectors, + }, + { + Vectors.class, + Vectors.of("custom", new Float[][] { { 1f, 2f }, { 3f, 4f } }), + "{\"custom\": [[1.0, 2.0], [3.0, 4.0]]}", + (CustomAssert) JSONTest::compareVectors, + }, + { + Vectors.class, + Vectors.of(named -> named + .vector("1d", new Float[] { 1f, 2f }) + .vector("2d", new Float[][] { { 1f, 2f }, { 3f, 4f } })), + "{\"1d\": [1.0, 2.0], \"2d\": [[1.0, 2.0], [3.0, 4.0]]}", + (CustomAssert) JSONTest::compareVectors, + }, + + // WeaviateCollection.CustomTypeAdapterFactory + { + WeaviateCollection.class, + WeaviateCollection.of("Things", things -> things + .description("A collection of things") + .properties( + Property.text("shape"), + Property.integer("size")) + .references( + Property.reference("owner", "Person", "Company")) + .vectors(named -> named + .vector("v-shape", Hnsw.of(Img2VecNeuralVectorizer.of( + i2v -> i2v.imageFields("img")))))), + """ + { + "class": "Things", + "description": "A collection of things", + "properties": [ + {"name": "shape", "dataType": ["text"]}, + {"name": "size", "dataType": ["int"]}, + {"name": "owner", "dataType": ["Person", "Company"]} + ], + "vectorConfig": { + "v-shape": { + "vectorIndexType": "hnsw", + "vectorIndexConfig": {}, + "vectorizer": {"img2vec-neural": { + "imageFields": ["img"] + }} + } + } + } + """, + }, + + // Reference.TYPE_ADAPTER + { + Reference.class, + Reference.uuids("id-1"), + "{\"beacon\": \"weaviate://localhost/id-1\"}", + }, + { + Reference.class, + Reference.collection("Doodlebops", "id-1"), + "{\"beacon\": \"weaviate://localhost/Doodlebops/id-1\"}", + }, + + // WeaviateObject.CustomTypeAdapterFactory.INSTANCE + { + new TypeToken, Reference, ObjectMetadata>>() { + }, + new WeaviateObject<>( + "Things", + Map.of("title", "ThingOne"), + Map.of("hasRef", List.of(Reference.uuids("ref-1"))), + ObjectMetadata.of(meta -> meta.id("thing-1"))), + """ + { + "class": "Things", + "properties": { + "title": "ThingOne", + "hasRef": [{"beacon": "weaviate://localhost/ref-1"}] + }, + "id": "thing-1" + } + """, + }, + }; + } + + @Test + @DataMethod(source = JSONTest.class, method = "testCases") + public void test_serialize(Object cls, Object in, String want) { + String got; + if (cls instanceof TypeToken typeToken) { + got = JSON.serialize(in, typeToken); + } else { + got = JSON.serialize(in); + } + assertEqualJson(want, got); + + } + + private interface CustomAssert extends BiConsumer { + } + + @Test + @SuppressWarnings("unchecked") + @DataMethod(source = JSONTest.class, method = "testCases") + public void test_deserialize(Object target, Object want, String in, CustomAssert assertion) { + + Object got; + if (target instanceof Class targetClass) { + got = JSON.deserialize(in, targetClass); + } else if (target instanceof TypeToken targetToken) { + got = JSON.deserialize(in, targetToken); + } else { + throw new IllegalArgumentException("target must be either Class or TypeToken"); + } + + if (assertion != null) { + assertion.accept(got, want); + } else { + Assertions.assertThat(got).isEqualTo(want); + } + } + + private static void assertEqualJson(String want, String got) { + var wantJson = JsonParser.parseString(want); + var gotJson = JsonParser.parseString(got); + Assertions.assertThat(gotJson).isEqualTo(wantJson); + } + + /** + * Custom assert function that uses deep array equality + * to correctly compare Float[] and Float[][] nested in the object. + */ + private static void compareVectors(Object got, Object want) { + Assertions.assertThat(got) + .usingRecursiveComparison() + .withEqualsForType(Arrays::equals, Float[].class) + .withEqualsForType(Arrays::deepEquals, Float[][].class) + .isEqualTo(want); + } +}