diff --git a/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java b/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java deleted file mode 100644 index dd50d69e1..000000000 --- a/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java +++ /dev/null @@ -1,41 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.io.IOException; - -import org.assertj.core.api.Assertions; -import org.junit.Test; - -import io.weaviate.ConcurrentTest; -import io.weaviate.client6.WeaviateClient; -import io.weaviate.client6.v1.collections.VectorIndex.IndexType; -import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; -import io.weaviate.containers.Container; - -public class CollectionsITest extends ConcurrentTest { - private static WeaviateClient client = Container.WEAVIATE.getClient(); - - @Test - public void testCreateGetDelete() throws IOException { - var collectionName = ns("Things_1"); - client.collections.create(collectionName, - col -> col - .properties(Property.text("username"), Property.integer("age")) - .vector(new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); - - var thingsCollection = client.collections.getConfig(collectionName); - - Assertions.assertThat(thingsCollection).get() - .hasFieldOrPropertyWithValue("name", collectionName) - .extracting(CollectionDefinition::vectors).extracting(Vectors::getDefault) - .as("default vector").satisfies(defaultVector -> { - Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) - .as("has none vectorizer").isInstanceOf(NoneVectorizer.class); - Assertions.assertThat(defaultVector).extracting(VectorIndex::configuration) - .as("has hnsw index").returns(IndexType.HNSW, IndexingStrategy::type); - }); - - client.collections.delete(collectionName); - var noCollection = client.collections.getConfig(collectionName); - Assertions.assertThat(noCollection).as("after delete").isEmpty(); - } -} diff --git 
a/src/it/java/io/weaviate/containers/Container.java b/src/it/java/io/weaviate/containers/Container.java index 9ef69cfe8..974e56a9e 100644 --- a/src/it/java/io/weaviate/containers/Container.java +++ b/src/it/java/io/weaviate/containers/Container.java @@ -54,6 +54,8 @@ public static class Group implements Startable { private Group(Weaviate weaviate, GenericContainer... containers) { this.weaviate = weaviate; this.containers = Arrays.asList(containers); + + weaviate.dependsOn(containers); setSharedNetwork(); } @@ -63,8 +65,7 @@ public WeaviateClient getClient() { @Override public void start() { - containers.forEach(GenericContainer::start); - weaviate.start(); + weaviate.start(); // testcontainers will resolve dependencies } @Override diff --git a/src/it/java/io/weaviate/containers/Contextionary.java b/src/it/java/io/weaviate/containers/Contextionary.java index 76ec5aefd..69abde7df 100644 --- a/src/it/java/io/weaviate/containers/Contextionary.java +++ b/src/it/java/io/weaviate/containers/Contextionary.java @@ -35,7 +35,7 @@ public Contextionary build() { .withEnv("EXTENSIONS_STORAGE_ORIGIN", "http://weaviate:8080") .withEnv("NEIGHBOR_OCCURRENCE_IGNORE_PERCENTILE", "5") .withEnv("ENABLE_COMPOUND_SPLITTING", "'false'"); - container.withCreateContainerCmdModifier(cmd -> cmd.withHostName("contextionary")); + container.withCreateContainerCmdModifier(cmd -> cmd.withHostName(HOST_NAME)); return container; } } diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java index 6126d5150..5c7c8e648 100644 --- a/src/it/java/io/weaviate/containers/Weaviate.java +++ b/src/it/java/io/weaviate/containers/Weaviate.java @@ -69,7 +69,13 @@ public Builder withDefaultVectorizer(String module) { return this; } + public Builder withContextionary() { + addModule(Contextionary.MODULE); + return this; + } + public Builder withContextionaryUrl(String url) { + withContextionary(); contextionaryUrl = url; return this; } diff --git 
a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java index bd54ed865..035e26b8b 100644 --- a/src/it/java/io/weaviate/integration/AggregationITest.java +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -16,13 +16,13 @@ import io.weaviate.client6.v1.collections.Property; import io.weaviate.client6.v1.collections.VectorIndex; import io.weaviate.client6.v1.collections.Vectorizer; -import io.weaviate.client6.v1.collections.Vectors; import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest.GroupBy; import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResponse; import io.weaviate.client6.v1.collections.aggregate.Group; import io.weaviate.client6.v1.collections.aggregate.GroupedBy; import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; import io.weaviate.client6.v1.collections.aggregate.Metric; +import io.weaviate.client6.v1.collections.object.Vectors; import io.weaviate.containers.Container; public class AggregationITest extends ConcurrentTest { @@ -36,7 +36,7 @@ public static void beforeAll() throws IOException { .properties( Property.text("category"), Property.integer("price")) - .vectors(Vectors.of(new VectorIndex<>(Vectorizer.none())))); + .vectors(io.weaviate.client6.v1.collections.Vectors.of(new VectorIndex<>(Vectorizer.none())))); var things = client.collections.use(COLLECTION); for (var category : List.of("Shoes", "Hat", "Jacket")) { @@ -47,7 +47,7 @@ public static void beforeAll() throws IOException { things.data.insert(Map.of( "category", category, "price", category.length()), - meta -> meta.vectors(vector)); + meta -> meta.vectors(Vectors.of(vector))); } } } diff --git a/src/it/java/io/weaviate/integration/CollectionsITest.java b/src/it/java/io/weaviate/integration/CollectionsITest.java new file mode 100644 index 000000000..80c19c36c --- /dev/null +++ b/src/it/java/io/weaviate/integration/CollectionsITest.java @@ -0,0 +1,91 @@ 
+package io.weaviate.integration; + +import java.io.IOException; + +import org.assertj.core.api.Assertions; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.Test; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.collections.Collection; +import io.weaviate.client6.v1.collections.NoneVectorizer; +import io.weaviate.client6.v1.collections.Property; +import io.weaviate.client6.v1.collections.VectorIndex; +import io.weaviate.client6.v1.collections.VectorIndex.IndexType; +import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; +import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.collections.Vectors; +import io.weaviate.containers.Container; + +public class CollectionsITest extends ConcurrentTest { + private static WeaviateClient client = Container.WEAVIATE.getClient(); + + @Test + public void testCreateGetDelete() throws IOException { + var collectionName = ns("Things"); + client.collections.create(collectionName, + col -> col + .properties(Property.text("username"), Property.integer("age")) + .vector(new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); + + var thingsCollection = client.collections.getConfig(collectionName); + + Assertions.assertThat(thingsCollection).get() + .hasFieldOrPropertyWithValue("name", collectionName) + .extracting(Collection::vectors).extracting(Vectors::getDefault) + .as("default vector").satisfies(defaultVector -> { + Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) + .as("has none vectorizer").isInstanceOf(NoneVectorizer.class); // collection was created with Vectorizer.none() + Assertions.assertThat(defaultVector).extracting(VectorIndex::configuration) + .as("has hnsw index").returns(IndexType.HNSW, IndexingStrategy::type); + }); + + client.collections.delete(collectionName); + var noCollection = client.collections.getConfig(collectionName); + Assertions.assertThat(noCollection).as("after delete").isEmpty(); + } + + @Test + public 
void testCrossReferences() throws IOException { + // Arrange: Create Owners collection + var nsOwners = ns("Owners"); + client.collections.create(nsOwners); + + // Act: Create Things collection with owner -> owners + var nsThings = ns("Things"); + client.collections.create(nsThings, + col -> col.properties(Property.reference("ownedBy", nsOwners))); + var things = client.collections.use(nsThings); + + // Assert: Things --ownedBy-> Owners + Assertions.assertThat(things.config.get()) + .as("after create Things").get() + .satisfies(c -> { + Assertions.assertThat(c.properties()) + .as("ownedBy").filteredOn(p -> p.name().equals("ownedBy")).first() + .extracting(p -> p.dataTypes()).asInstanceOf(InstanceOfAssertFactories.LIST) + .containsOnly(nsOwners); + }); + + // Arrange: Create OnlineStores and Markets collections + var nsOnlineStores = ns("OnlineStores"); + client.collections.create(nsOnlineStores); + + var nsMarkets = ns("Markets"); + client.collections.create(nsMarkets); + + // Act: Update Things collections to add polymorphic reference + things.config.addProperty(Property.reference("soldIn", nsOnlineStores, nsMarkets)); + + // Assert: Things --soldIn-> [OnlineStores, Markets] + Assertions.assertThat(things.config.get()) + .as("after add property").get() + .satisfies(c -> { + Assertions.assertThat(c.properties()) + .as("soldIn").filteredOn(p -> p.name().equals("soldIn")).first() + .extracting(p -> p.dataTypes()).asInstanceOf(InstanceOfAssertFactories.LIST) + .containsOnly(nsOnlineStores, nsMarkets); + }); + } +} diff --git a/src/it/java/io/weaviate/client6/v1/DataITest.java b/src/it/java/io/weaviate/integration/DataITest.java similarity index 66% rename from src/it/java/io/weaviate/client6/v1/DataITest.java rename to src/it/java/io/weaviate/integration/DataITest.java index f64702f1e..f0ccd3765 100644 --- a/src/it/java/io/weaviate/client6/v1/DataITest.java +++ b/src/it/java/io/weaviate/integration/DataITest.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1; 
+package io.weaviate.integration; import java.io.IOException; import java.util.Map; @@ -14,30 +14,34 @@ import io.weaviate.client6.v1.collections.VectorIndex; import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.collections.object.Vectors; import io.weaviate.containers.Container; public class DataITest extends ConcurrentTest { private static WeaviateClient client = Container.WEAVIATE.getClient(); - private static final String COLLECTION = unique("Things"); + private static final String COLLECTION = unique("Artists"); private static final String VECTOR_INDEX = "bring_your_own"; @BeforeClass public static void beforeAll() throws IOException { - createTestCollection(); + createTestCollections(); } @Test public void testCreateGetDelete() throws IOException { - var things = client.collections.use(COLLECTION); + var artists = client.collections.use(COLLECTION); var id = randomUUID(); Float[] vector = { 1f, 2f, 3f }; - things.data.insert(Map.of("username", "john doe"), metadata -> metadata + artists.data.insert(Map.of("name", "john doe"), metadata -> metadata .id(id) .vectors(Vectors.of(VECTOR_INDEX, vector))); - var object = things.data.get(id, query -> query.withVector()); + var object = artists.data.get(id, query -> query + .returnProperties("name") + .includeVector()); + Assertions.assertThat(object) .as("object exists after insert").get() .satisfies(obj -> { @@ -50,18 +54,27 @@ public void testCreateGetDelete() throws IOException { Assertions.assertThat(obj.properties()) .as("has expected properties") - .containsEntry("username", "john doe"); + .containsEntry("name", "john doe"); }); - things.data.delete(id); - object = things.data.get(id); + artists.data.delete(id); + object = artists.data.get(id); Assertions.assertThat(object).isEmpty().as("object not exists after deletion"); } - private static void createTestCollection() throws IOException { + private static 
void createTestCollections() throws IOException { + var awardsGrammy = unique("Grammy"); + client.collections.create(awardsGrammy); + + var awardsOscar = unique("Oscar"); + client.collections.create(awardsOscar); + client.collections.create(COLLECTION, col -> col - .properties(Property.text("username"), Property.integer("age")) + .properties( + Property.text("name"), + Property.integer("age"), + Property.reference("hasAwards", awardsGrammy, awardsOscar)) .vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); } } diff --git a/src/it/java/io/weaviate/integration/ReferencesITest.java b/src/it/java/io/weaviate/integration/ReferencesITest.java new file mode 100644 index 000000000..533d2693d --- /dev/null +++ b/src/it/java/io/weaviate/integration/ReferencesITest.java @@ -0,0 +1,181 @@ +package io.weaviate.integration; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; + +import org.assertj.core.api.Assertions; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.Test; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.collections.Property; +import io.weaviate.client6.v1.collections.Reference; +import io.weaviate.client6.v1.collections.object.ObjectReference; +import io.weaviate.client6.v1.collections.object.WeaviateObject; +import io.weaviate.client6.v1.collections.query.MetadataField; +import io.weaviate.client6.v1.collections.query.QueryReference; +import io.weaviate.containers.Container; + +/** + * Scenarios related to reference properties: + * + */ +public class ReferencesITest extends ConcurrentTest { + private static final WeaviateClient client = Container.WEAVIATE.getClient(); + + @Test + public void testReferences() throws IOException { + // Arrange: create collection with cross-references + var nsArtists = ns("Artists"); + var nsGrammy = ns("Grammy"); + var nsOscar = ns("Oscar"); + + client.collections.create(nsOscar); 
+ client.collections.create(nsGrammy); + + // Act: create Artists collection with hasAwards reference + client.collections.create(nsArtists, + col -> col + .properties( + Property.text("name"), + Property.integer("age"), + Property.reference("hasAwards", nsGrammy, nsOscar))); + + var artists = client.collections.use(nsArtists); + var grammies = client.collections.use(nsGrammy); + var oscars = client.collections.use(nsOscar); + + // Act: check collection configuration is correct + var collectionArtists = artists.config.get(); + Assertions.assertThat(collectionArtists).get() + .as("Artists: create collection") + .extracting(c -> c.properties().stream().filter(Property::isReference).findFirst()) + .as("has one reference property").extracting(Optional::get) + .returns("hasAwards", Property::name) + .extracting(Property::dataTypes, InstanceOfAssertFactories.list(String.class)) + .containsOnly(nsGrammy, nsOscar); + + // Act: insert some data + var grammy_1 = grammies.data.insert(Map.of()); + var grammy_2 = grammies.data.insert(Map.of()); + var oscar_1 = oscars.data.insert(Map.of()); + var oscar_2 = oscars.data.insert(Map.of()); + + var alex = artists.data.insert( + Map.of("name", "Alex"), + opt -> opt + .reference("hasAwards", Reference.uuids( + grammy_1.metadata().id(), oscar_1.metadata().id())) + .reference("hasAwards", Reference.objects(grammy_2, oscar_2))); + + // Act: add one more reference + var nsMovies = ns("Movies"); + client.collections.create(nsMovies); + artists.config.addProperty(Property.reference("featuredIn", nsMovies)); + + collectionArtists = artists.config.get(); + Assertions.assertThat(collectionArtists).get() + .as("Artists: add reference to Movies") + .extracting(c -> c.properties().stream() + .filter(property -> property.name().equals("featuredIn")).findFirst()) + .as("featuredIn reference property").extracting(Optional::get) + .returns(true, Property::isReference) + .extracting(Property::dataTypes, InstanceOfAssertFactories.list(String.class)) + 
.containsOnly(nsMovies); + + var gotAlex = artists.data.get(alex.metadata().id(), + opt -> opt.returnReferences( + QueryReference.multi("hasAwards", nsOscar, + ref -> ref.returnMetadata(MetadataField.ID)), + QueryReference.multi("hasAwards", nsGrammy, + ref -> ref.returnMetadata(MetadataField.ID)))); + + Assertions.assertThat(gotAlex).get() + .as("Artists: fetch by id including hasAwards references") + .extracting(WeaviateObject::references, InstanceOfAssertFactories.map(String.class, ObjectReference.class)) + .as("hasAwards object reference").extractingByKey("hasAwards") + .extracting(ObjectReference::objects, InstanceOfAssertFactories.list(WeaviateObject.class)) + .extracting(objects -> objects.metadata().id()) + .containsOnly( + // INVESTIGATE: When references to 2+ collections are requested, + // Weaviate seems to return references only to the first one in the list. + // In this case we request { "hasAwards": Oscars } and { "hasAwards": Grammys } + // so the latter will not be in the response. 
+ // + // grammy_1.metadata().id(), grammy_2.metadata().id(), + oscar_1.metadata().id(), oscar_2.metadata().id()); + } + + @Test + public void testNestedReferences() throws IOException { + // Arrange: create collection with cross-references + var nsArtists = ns("Artists"); + var nsGrammy = ns("Grammy"); + var nsAcademy = ns("Academy"); + + client.collections.create(nsAcademy, + opt -> opt + .properties(Property.text("ceo"))); + + // Act: create Artists collection with hasAwards reference + client.collections.create(nsGrammy, + col -> col + .properties(Property.reference("presentedBy", nsAcademy))); + + client.collections.create(nsArtists, + col -> col + .properties( + Property.text("name"), + Property.integer("age"), + Property.reference("hasAwards", nsGrammy))); + + var artists = client.collections.use(nsArtists); + var grammies = client.collections.use(nsGrammy); + var academies = client.collections.use(nsAcademy); + + // Act: insert some data + var musicAcademy = academies.data.insert(Map.of("ceo", "Harvy Mason")); + + var grammy_1 = grammies.data.insert(Map.of(), + opt -> opt.reference("presentedBy", Reference.objects(musicAcademy))); + + var alex = artists.data.insert( + Map.of("name", "Alex"), + opt -> opt + .reference("hasAwards", Reference.objects(grammy_1))); + + // Assert: fetch nested references + var gotAlex = artists.data.get(alex.metadata().id(), + opt -> opt.returnReferences( + QueryReference.single("hasAwards", + ref -> ref + // Name of the CEO of the presenting academy + .returnReferences( + QueryReference.single("presentedBy", r -> r.returnProperties("ceo"))) + // Grammy ID + .returnMetadata(MetadataField.ID)))); + + Assertions.assertThat(gotAlex).get() + .as("Artists: fetch by id including nested references") + .extracting(WeaviateObject::references, InstanceOfAssertFactories.map(String.class, ObjectReference.class)) + .as("hasAwards object reference").extractingByKey("hasAwards") + .extracting(ObjectReference::objects, 
InstanceOfAssertFactories.list(WeaviateObject.class)) + .hasSize(1).allSatisfy(award -> Assertions.assertThat(award) + .returns(grammy_1.metadata().id(), g -> g.metadata().id()) + .extracting(WeaviateObject::references, + InstanceOfAssertFactories.map(String.class, ObjectReference.class)) + .extractingByKey("presentedBy") + .extracting(ObjectReference::objects, InstanceOfAssertFactories.list(WeaviateObject.class)) + .hasSize(1).extracting(WeaviateObject::properties) + .allSatisfy(properties -> Assertions.assertThat(properties) + .asInstanceOf(InstanceOfAssertFactories.map(String.class, Object.class)) + .containsEntry("ceo", "Harvy Mason"))); + } +} diff --git a/src/it/java/io/weaviate/integration/NearVectorQueryITest.java b/src/it/java/io/weaviate/integration/SearchITest.java similarity index 68% rename from src/it/java/io/weaviate/integration/NearVectorQueryITest.java rename to src/it/java/io/weaviate/integration/SearchITest.java index 66258810d..2496e2801 100644 --- a/src/it/java/io/weaviate/integration/NearVectorQueryITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -9,22 +9,32 @@ import org.assertj.core.api.Assertions; import org.junit.BeforeClass; +import org.junit.ClassRule; import org.junit.Test; +import org.junit.rules.TestRule; import io.weaviate.ConcurrentTest; import io.weaviate.client6.WeaviateClient; -import io.weaviate.client6.v1.Vectors; import io.weaviate.client6.v1.collections.Property; import io.weaviate.client6.v1.collections.VectorIndex; import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; import io.weaviate.client6.v1.collections.Vectorizer; -import io.weaviate.client6.v1.query.GroupedQueryResult; -import io.weaviate.client6.v1.query.MetadataField; -import io.weaviate.client6.v1.query.NearVector; +import io.weaviate.client6.v1.collections.object.Vectors; +import io.weaviate.client6.v1.collections.query.GroupedQueryResult; +import io.weaviate.client6.v1.collections.query.MetadataField; +import 
io.weaviate.client6.v1.collections.query.NearVector; import io.weaviate.containers.Container; - -public class NearVectorQueryITest extends ConcurrentTest { - private static final WeaviateClient client = Container.WEAVIATE.getClient(); +import io.weaviate.containers.Container.Group; +import io.weaviate.containers.Contextionary; +import io.weaviate.containers.Weaviate; + +public class SearchITest extends ConcurrentTest { + private static final Group compose = Container.compose( + Weaviate.custom().withContextionaryUrl(Contextionary.URL).build(), + Container.CONTEXTIONARY); + @ClassRule // Bind the containers' lifetime to the test class + public static final TestRule _rule = compose.asTestRule(); + private static final WeaviateClient client = compose.getClient(); private static final String COLLECTION = unique("Things"); private static final String VECTOR_INDEX = "bring_your_own"; @@ -81,7 +91,6 @@ public void testNearVector_groupBy() { Assertions.assertThat(result.objects) .as("object belongs a group") .allMatch(obj -> result.groups.get(obj.belongsToGroup).objects().contains(obj)); - } /** @@ -117,4 +126,28 @@ private static void createTestCollection() throws IOException { .properties(Property.text("category")) .vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); } + + @Test + public void testNearText() throws IOException { + var nsSongs = ns("Songs"); + client.collections.create(nsSongs, + col -> col + .properties(Property.text("title")) + .vector(new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.contextionary()))); + + var songs = client.collections.use(nsSongs); + songs.data.insert(Map.of("title", "Yellow Submarine")); + songs.data.insert(Map.of("title", "Run Through The Jungle")); + songs.data.insert(Map.of("title", "Welcome To The Jungle")); + + var result = songs.query.nearText("forest", + opt -> opt + .distance(0.5f) + .returnProperties("title")); + + Assertions.assertThat(result.objects).hasSize(2) + .extracting(obj -> 
obj.properties).allSatisfy( + properties -> Assertions.assertThat(properties) + .allSatisfy((_k, v) -> Assertions.assertThat((String) v).contains("Jungle"))); + } } diff --git a/src/main/java/io/weaviate/client6/WeaviateClient.java b/src/main/java/io/weaviate/client6/WeaviateClient.java index 8dba725a5..724fc65e9 100644 --- a/src/main/java/io/weaviate/client6/WeaviateClient.java +++ b/src/main/java/io/weaviate/client6/WeaviateClient.java @@ -5,18 +5,18 @@ import io.weaviate.client6.internal.GrpcClient; import io.weaviate.client6.internal.HttpClient; -import io.weaviate.client6.v1.collections.Collections; +import io.weaviate.client6.v1.collections.CollectionsClient; public class WeaviateClient implements Closeable { private final HttpClient http; private final GrpcClient grpc; - public final Collections collections; + public final CollectionsClient collections; public WeaviateClient(Config config) { this.http = new HttpClient(); this.grpc = new GrpcClient(config); - this.collections = new Collections(config, http, grpc); + this.collections = new CollectionsClient(config, http, grpc); } @Override diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java index 446adba78..bf046b0dd 100644 --- a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java @@ -8,13 +8,12 @@ import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; import io.weaviate.client6.internal.GRPC; -import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest; import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest.GroupBy; import io.weaviate.client6.v1.collections.aggregate.AggregateRequest; import 
io.weaviate.client6.v1.collections.aggregate.IntegerMetric; import io.weaviate.client6.v1.collections.aggregate.Metric; import io.weaviate.client6.v1.collections.aggregate.TextMetric; -import io.weaviate.client6.v1.query.NearVector; +import io.weaviate.client6.v1.collections.query.NearVector; public final class AggregateMarshaler { private final WeaviateProtoAggregate.AggregateRequest.Builder req = WeaviateProtoAggregate.AggregateRequest diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java index a85970bb1..b1d532218 100644 --- a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java @@ -9,8 +9,9 @@ import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; import io.weaviate.client6.internal.GRPC; import io.weaviate.client6.internal.codec.grpc.GrpcMarshaler; -import io.weaviate.client6.v1.query.CommonQueryOptions; -import io.weaviate.client6.v1.query.NearVector; +import io.weaviate.client6.v1.collections.query.CommonQueryOptions; +import io.weaviate.client6.v1.collections.query.NearText; +import io.weaviate.client6.v1.collections.query.NearVector; public class SearchMarshaler implements GrpcMarshaler { private final WeaviateProtoSearchGet.SearchRequest.Builder req = WeaviateProtoSearchGet.SearchRequest.newBuilder(); @@ -43,10 +44,28 @@ public SearchMarshaler addNearVector(NearVector nv) { nearVector.setDistance(nv.distance()); } + // TODO: add targets, vector_for_targets req.setNearVector(nearVector); return this; } + public SearchMarshaler addNearText(NearText nt) { + setCommon(nt.common()); + + var nearText = WeaviateProtoBaseSearch.NearTextSearch.newBuilder(); + nearText.addAllQuery(nt.text()); + + if (nt.certainty() != null) { + nearText.setCertainty(nt.certainty()); + } else if (nt.distance() != null) { + 
nearText.setDistance(nt.distance()); + } + + // TODO: add move_to, move_away, targets + req.setNearText(nearText); + return this; + } + private void setCommon(CommonQueryOptions o) { if (o.limit() != null) { req.setLimit(o.limit()); diff --git a/src/main/java/io/weaviate/client6/v1/Collection.java b/src/main/java/io/weaviate/client6/v1/Collection.java deleted file mode 100644 index b1f40dcc4..000000000 --- a/src/main/java/io/weaviate/client6/v1/Collection.java +++ /dev/null @@ -1,20 +0,0 @@ -package io.weaviate.client6.v1; - -import io.weaviate.client6.Config; -import io.weaviate.client6.internal.GrpcClient; -import io.weaviate.client6.internal.HttpClient; -import io.weaviate.client6.v1.collections.aggregate.WeaviateAggregate; -import io.weaviate.client6.v1.data.Data; -import io.weaviate.client6.v1.query.Query; - -public class Collection { - public final Query query; - public final Data data; - public final WeaviateAggregate aggregate; - - public Collection(String collectionName, Config config, GrpcClient grpc, HttpClient http) { - this.query = new Query<>(collectionName, grpc); - this.data = new Data<>(collectionName, config, http); - this.aggregate = new WeaviateAggregate(collectionName, grpc); - } -} diff --git a/src/main/java/io/weaviate/client6/v1/ObjectMetadata.java b/src/main/java/io/weaviate/client6/v1/ObjectMetadata.java deleted file mode 100644 index d63b0225a..000000000 --- a/src/main/java/io/weaviate/client6/v1/ObjectMetadata.java +++ /dev/null @@ -1,55 +0,0 @@ -package io.weaviate.client6.v1; - -import java.util.function.Consumer; - -public record ObjectMetadata(String id, Vectors vectors) { - - public static ObjectMetadata with(Consumer options) { - var opt = new Builder(options); - return new ObjectMetadata(opt.id, opt.vectors); - } - - public static class Builder { - public String id; - public Vectors vectors; - - public Builder id(String id) { - this.id = id; - return this; - } - - public Builder vectors(Vectors vectors) { - this.vectors = vectors; 
- return this; - } - - public Builder vectors(Float[] vector) { - this.vectors = Vectors.of(vector); - return this; - } - - public Builder vectors(Float[][] vector) { - this.vectors = Vectors.of(vector); - return this; - } - - public Builder vectors(String name, Float[] vector) { - this.vectors = Vectors.of(name, vector); - return this; - } - - public Builder vectors(String name, Float[][] vector) { - this.vectors = Vectors.of(name, vector); - return this; - } - - public Builder vectors(Consumer named) { - this.vectors = Vectors.with(named); - return this; - } - - private Builder(Consumer options) { - options.accept(this); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java b/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java new file mode 100644 index 000000000..da54b1c28 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java @@ -0,0 +1,15 @@ +package io.weaviate.client6.v1.collections; + +import com.google.gson.annotations.SerializedName; + +public enum AtomicDataType { + @SerializedName("text") + TEXT, + @SerializedName("int") + INT; + + public static boolean isAtomic(String type) { + return type.equals(TEXT.name().toLowerCase(java.util.Locale.ROOT)) + || type.equals(INT.name().toLowerCase(java.util.Locale.ROOT)); // Locale.ROOT keeps "INT" -> "int" under a Turkish default locale + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Collection.java b/src/main/java/io/weaviate/client6/v1/collections/Collection.java new file mode 100644 index 000000000..d0fe10903 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/Collection.java @@ -0,0 +1,50 @@ +package io.weaviate.client6.v1.collections; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +import io.weaviate.client6.v1.collections.Vectors.NamedVectors; + +public record Collection(String name, List properties, Vectors vectors) { + + public static Collection with(String name, Consumer options) { + var config = new 
Builder(options); + return new Collection(name, config.properties, config.vectors); + } + + public static class Builder { + private List properties = new ArrayList<>(); + private Vectors vectors; + + public Builder properties(Property... properties) { + this.properties = Arrays.asList(properties); + return this; + } + + public Builder vectors(Vectors vectors) { + this.vectors = vectors; + return this; + } + + public Builder vector(VectorIndex vector) { + this.vectors = Vectors.of(vector); + return this; + } + + public Builder vector(String name, VectorIndex vector) { + this.vectors = new Vectors(name, vector); + return this; + } + + public Builder vectors(Consumer named) { + this.vectors = Vectors.with(named); + return this; + } + + Builder(Consumer options) { + options.accept(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionClient.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionClient.java new file mode 100644 index 000000000..4f8cd6fdf --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/CollectionClient.java @@ -0,0 +1,22 @@ +package io.weaviate.client6.v1.collections; + +import io.weaviate.client6.Config; +import io.weaviate.client6.internal.GrpcClient; +import io.weaviate.client6.internal.HttpClient; +import io.weaviate.client6.v1.collections.aggregate.AggregateClient; +import io.weaviate.client6.v1.collections.data.DataClient; +import io.weaviate.client6.v1.collections.query.QueryClient; + +public class CollectionClient { + public final QueryClient query; + public final DataClient data; + public final CollectionConfigClient config; + public final AggregateClient aggregate; + + public CollectionClient(String collectionName, Config config, GrpcClient grpc, HttpClient http) { + this.query = new QueryClient<>(collectionName, grpc); + this.data = new DataClient<>(collectionName, config, http, grpc); + this.config = new CollectionConfigClient(collectionName, config, http); + this.aggregate = 
new AggregateClient(collectionName, grpc); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionConfigClient.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionConfigClient.java new file mode 100644 index 000000000..186ea3ee2 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/CollectionConfigClient.java @@ -0,0 +1,68 @@ +package io.weaviate.client6.v1.collections; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.Optional; + +import org.apache.hc.core5.http.ClassicHttpRequest; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpStatus; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.support.ClassicRequestBuilder; + +import com.google.gson.Gson; + +import io.weaviate.client6.Config; +import io.weaviate.client6.internal.DtoTypeAdapterFactory; +import io.weaviate.client6.internal.HttpClient; +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class CollectionConfigClient { + // TODO: hide behind an internal HttpClient + private final String collectionName; + private final Config config; + private final HttpClient httpClient; + + private static final Gson gson = new Gson(); + static { + DtoTypeAdapterFactory.register( + Collection.class, + CollectionDefinitionDTO.class, + m -> { + return new CollectionDefinitionDTO(m); + }); + } + + public Optional get() throws IOException { + ClassicHttpRequest httpGet = ClassicRequestBuilder + .get(config.baseUrl() + "/schema/" + collectionName) + .build(); + + return httpClient.http.execute(httpGet, response -> { + if (response.getCode() == HttpStatus.SC_NOT_FOUND) { + return Optional.empty(); + } + try (var r = new InputStreamReader(response.getEntity().getContent())) { + var collection = gson.fromJson(r, Collection.class); + return Optional.ofNullable(collection); + } + }); + } + + public void addProperty(Property property) throws IOException { +
ClassicHttpRequest httpPost = ClassicRequestBuilder + .post(config.baseUrl() + "/schema/" + collectionName + "/properties") + .setEntity(gson.toJson(property), ContentType.APPLICATION_JSON) + .build(); + + httpClient.http.execute(httpPost, response -> { + var entity = response.getEntity(); + if (response.getCode() != HttpStatus.SC_SUCCESS) { + var message = EntityUtils.toString(entity); + throw new RuntimeException("HTTP " + response.getCode() + ": " + message); + } + return null; + }); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinition.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinition.java deleted file mode 100644 index b599f4ceb..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinition.java +++ /dev/null @@ -1,51 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.function.Consumer; - -import io.weaviate.client6.v1.collections.Vectors.NamedVectors; - -public record CollectionDefinition(String name, List properties, Vectors vectors) { - - public static CollectionDefinition with(String name, Consumer options) { - var config = new Configuration(options); - return new CollectionDefinition(name, config.properties, config.vectors); - } - - // Tucked Builder for additional collection configuration. - public static class Configuration { - public List properties = new ArrayList<>(); - public Vectors vectors; - - public Configuration properties(Property... 
properties) { - this.properties = Arrays.asList(properties); - return this; - } - - public Configuration vectors(Vectors vectors) { - this.vectors = vectors; - return this; - } - - public Configuration vector(VectorIndex vector) { - this.vectors = Vectors.of(vector); - return this; - } - - public Configuration vector(String name, VectorIndex vector) { - this.vectors = new Vectors(name, vector); - return this; - } - - public Configuration vectors(Consumer named) { - this.vectors = Vectors.with(named); - return this; - } - - Configuration(Consumer options) { - options.accept(this); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java index e0333fad0..d35dce478 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java +++ b/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java @@ -6,7 +6,7 @@ import io.weaviate.client6.internal.DtoTypeAdapterFactory; -class CollectionDefinitionDTO implements DtoTypeAdapterFactory.Dto { +class CollectionDefinitionDTO implements DtoTypeAdapterFactory.Dto { @SerializedName("class") String collection; @@ -25,7 +25,7 @@ class CollectionDefinitionDTO implements DtoTypeAdapterFactory.Dto { return new CollectionDefinitionDTO(m); @@ -76,6 +75,7 @@ private static class VectorizerSerde @Override public Vectorizer deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) throws JsonParseException { + // TODO: deserialize different kinds of vectorizers return Vectorizer.none(); } @@ -112,9 +112,13 @@ public Vectors read(JsonReader in) throws IOException { }) .create(); - public void create(String name, Consumer options) throws IOException { - var collection = CollectionDefinition.with(name, options); + public void create(String name) throws IOException { + create(name, opt -> { + }); + } + public void create(String name, Consumer options) 
throws IOException { + var collection = Collection.with(name, options); ClassicHttpRequest httpPost = ClassicRequestBuilder .post(config.baseUrl() + "/schema") .setEntity(gson.toJson(collection), ContentType.APPLICATION_JSON) @@ -131,7 +135,7 @@ public void create(String name, Consumer opt }); } - public Optional getConfig(String name) throws IOException { + public Optional getConfig(String name) throws IOException { ClassicHttpRequest httpGet = ClassicRequestBuilder .get(config.baseUrl() + "/schema/" + name) .build(); @@ -141,7 +145,7 @@ public Optional getConfig(String name) throws IOException return Optional.empty(); } try (var r = new InputStreamReader(response.getEntity().getContent())) { - var collection = gson.fromJson(r, CollectionDefinition.class); + var collection = gson.fromJson(r, Collection.class); return Optional.ofNullable(collection); } }); @@ -162,7 +166,7 @@ public void delete(String name) throws IOException { }); } - public Collection> use(String name) { - return new Collection<>(name, config, grpcClient, httpClient); + public CollectionClient> use(String name) { + return new CollectionClient<>(name, config, grpcClient, httpClient); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java new file mode 100644 index 000000000..1bb580ada --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java @@ -0,0 +1,37 @@ +package io.weaviate.client6.v1.collections; + +import java.util.Map; +import java.util.function.Consumer; + +import com.google.gson.annotations.SerializedName; + +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class ContextionaryVectorizer extends Vectorizer { + @SerializedName("text2vec-contextionary") + private Map configuration; + + public static ContextionaryVectorizer of() { + return new Builder().build(); + } + + public static ContextionaryVectorizer of(Consumer 
fn) { + var builder = new Builder(); + fn.accept(builder); + return builder.build(); + } + + public static class Builder { + private boolean vectorizeCollectionName = false; + + public Builder vectorizeCollectionName() { + this.vectorizeCollectionName = true; + return this; + } + + public ContextionaryVectorizer build() { + return new ContextionaryVectorizer(Map.of("vectorizeClassName", vectorizeCollectionName)); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/DataType.java b/src/main/java/io/weaviate/client6/v1/collections/DataType.java deleted file mode 100644 index 8ec96470f..000000000 --- a/src/main/java/io/weaviate/client6/v1/collections/DataType.java +++ /dev/null @@ -1,10 +0,0 @@ -package io.weaviate.client6.v1.collections; - -import com.google.gson.annotations.SerializedName; - -public enum DataType { - @SerializedName("text") - TEXT, - @SerializedName("int") - INT; -} diff --git a/src/main/java/io/weaviate/client6/v1/collections/HNSW.java b/src/main/java/io/weaviate/client6/v1/collections/HNSW.java index cac6b45b4..938c44ffa 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/HNSW.java +++ b/src/main/java/io/weaviate/client6/v1/collections/HNSW.java @@ -17,26 +17,26 @@ public enum Distance { this(null, null); } - static HNSW with(Consumer options) { - var opt = new Options(options); + static HNSW with(Consumer options) { + var opt = new Builder(options); return new HNSW(opt.distance, opt.skip); } - public static class Options { - public Distance distance; - public Boolean skip; + public static class Builder { + private Distance distance; + private Boolean skip; - public Options distance(Distance distance) { + public Builder distance(Distance distance) { this.distance = distance; return this; } - public Options disableIndexation() { + public Builder disableIndexation() { this.skip = true; return this; } - public Options(Consumer options) { + public Builder(Consumer options) { options.accept(this); } } diff --git 
a/src/main/java/io/weaviate/client6/v1/collections/Property.java b/src/main/java/io/weaviate/client6/v1/collections/Property.java index 81d23a5df..a1a811ddd 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/Property.java +++ b/src/main/java/io/weaviate/client6/v1/collections/Property.java @@ -2,46 +2,36 @@ import java.util.Arrays; import java.util.List; -import java.util.function.Consumer; import com.google.gson.annotations.SerializedName; -public class Property { - @SerializedName("name") - public final String name; - - @SerializedName("dataType") - public final List dataTypes; +public record Property( + @SerializedName("name") String name, + @SerializedName("dataType") List dataTypes) { /** Add text property with default configuration. */ public static Property text(String name) { - return new Property(name, DataType.TEXT); + return new Property(name, AtomicDataType.TEXT); } /** Add integer property with default configuration. */ public static Property integer(String name) { - return new Property(name, DataType.INT); + return new Property(name, AtomicDataType.INT); } - public static final class Configuration { - private List dataTypes; - - public Configuration dataTypes(DataType... types) { - this.dataTypes = Arrays.asList(types); - return this; - } + public static Property reference(String name, String... collections) { + return new Property(name, collections); } - private Property(String name, DataType type) { - this.name = name; - this.dataTypes = List.of(type); + public boolean isReference() { + return dataTypes.stream().noneMatch(t -> AtomicDataType.isAtomic(t)); } - public Property(String name, Consumer options) { - var config = new Configuration(); - options.accept(config); + private Property(String name, AtomicDataType type) { + this(name, List.of(type.name().toLowerCase(java.util.Locale.ROOT))); // Locale.ROOT: INT must become "int" under any default locale (e.g. Turkish) + } - this.name = name; - this.dataTypes = config.dataTypes; + private Property(String name, String...
collections) { + this(name, Arrays.asList(collections)); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Reference.java b/src/main/java/io/weaviate/client6/v1/collections/Reference.java new file mode 100644 index 000000000..b17799911 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/Reference.java @@ -0,0 +1,59 @@ +package io.weaviate.client6.v1.collections; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import com.google.gson.stream.JsonWriter; + +import io.weaviate.client6.v1.collections.object.WeaviateObject; + +public record Reference(String collection, List uuids) { + + public Reference(String collection, String uuid) { + this(collection, List.of(uuid)); + } + + /** + * Create reference to objects by their UUIDs. + *
<p>
+ * Weaviate will search each of the existing collections to identify + * the objects before inserting the references, so this may include + * some performance overhead. + */ + public static Reference uuids(String... uuids) { + return new Reference(null, Arrays.asList(uuids)); + } + + /** Create references to {@link WeaviateObject}. */ + public static Reference[] objects(WeaviateObject... objects) { + return Arrays.stream(objects) + .map(o -> new Reference(o.collection(), o.metadata().id())) + .toArray(Reference[]::new); + } + + /** Create references to objects in a collection by their UUIDs. */ + public static Reference collection(String collection, String... uuids) { + return new Reference(collection, Arrays.asList(uuids)); + } + + // TODO: put this in a type adapter. + /** Write one beacon object per UUID. Assumes a JSON array has been started and will be ended by the caller. */ + public void writeValue(JsonWriter w) throws IOException { + for (var uuid : uuids) { + w.beginObject(); + w.name("beacon"); + w.value(toBeacon(uuid)); + w.endObject(); + } + } + + private String toBeacon(String uuid) { + var beacon = "weaviate://localhost/"; + if (collection != null) { + beacon += collection + "/"; + } + beacon += uuid; + return beacon; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java index 5db348263..b8aea7f0c 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java +++ b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java @@ -29,7 +29,7 @@ public static IndexingStrategy hnsw() { return new HNSW(); } - public static IndexingStrategy hnsw(Consumer options) { + public static IndexingStrategy hnsw(Consumer options) { return HNSW.with(options); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java index ad9c4260f..d421d926a 100644 ---
a/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java @@ -1,8 +1,18 @@ package io.weaviate.client6.v1.collections; +import java.util.function.Consumer; + // This class is WIP, I haven't decided how to structure it yet. public abstract class Vectorizer { public static NoneVectorizer none() { return new NoneVectorizer(); } + + public static ContextionaryVectorizer contextionary() { + return ContextionaryVectorizer.of(); + } + + public static ContextionaryVectorizer contextionary(Consumer fn) { + return ContextionaryVectorizer.of(fn); + } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateClient.java similarity index 95% rename from src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java rename to src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateClient.java index 73474fb7b..cfc774eae 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateClient.java @@ -5,13 +5,13 @@ import io.weaviate.client6.internal.GrpcClient; import io.weaviate.client6.internal.codec.grpc.v1.AggregateMarshaler; import io.weaviate.client6.internal.codec.grpc.v1.AggregateUnmarshaler; -import io.weaviate.client6.v1.query.NearVector; +import io.weaviate.client6.v1.collections.query.NearVector; -public class WeaviateAggregate { +public class AggregateClient { private final String collectionName; private final GrpcClient grpcClient; - public WeaviateAggregate(String collectionName, GrpcClient grpc) { + public AggregateClient(String collectionName, GrpcClient grpc) { this.collectionName = collectionName; this.grpcClient = grpc; } @@ -55,11 +55,11 @@ public AggregateResponse nearVector( public AggregateGroupByResponse nearVector( Float[] vector, + 
Consumer nearVectorOptions, AggregateGroupByRequest.GroupBy groupBy, Consumer options) { var aggregation = AggregateRequest.with(collectionName, options); - var nearVector = NearVector.with(vector, opt -> { - }); + var nearVector = NearVector.with(vector, nearVectorOptions); var req = new AggregateMarshaler(aggregation.collectionName()) .addAggregation(aggregation) @@ -72,11 +72,11 @@ public AggregateGroupByResponse nearVector( public AggregateGroupByResponse nearVector( Float[] vector, - Consumer nearVectorOptions, AggregateGroupByRequest.GroupBy groupBy, Consumer options) { var aggregation = AggregateRequest.with(collectionName, options); - var nearVector = NearVector.with(vector, nearVectorOptions); + var nearVector = NearVector.with(vector, opt -> { + }); var req = new AggregateMarshaler(aggregation.collectionName()) .addAggregation(aggregation) diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/main/java/io/weaviate/client6/v1/data/ConsistencyLevel.java b/src/main/java/io/weaviate/client6/v1/collections/data/ConsistencyLevel.java similarity index 51% rename from src/main/java/io/weaviate/client6/v1/data/ConsistencyLevel.java rename to src/main/java/io/weaviate/client6/v1/collections/data/ConsistencyLevel.java index 1011a7a02..3347b012c 100644 --- a/src/main/java/io/weaviate/client6/v1/data/ConsistencyLevel.java +++ b/src/main/java/io/weaviate/client6/v1/collections/data/ConsistencyLevel.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.data; +package io.weaviate.client6.v1.collections.data; public enum ConsistencyLevel { ONE, QUORUM, ALL diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java b/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java new file mode 100644 index 000000000..e8878e1e8 --- /dev/null +++ 
b/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java @@ -0,0 +1,233 @@ +package io.weaviate.client6.v1.collections.data; + +import java.io.IOException; +import java.time.OffsetDateTime; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; +import org.apache.hc.client5.http.impl.classic.HttpClients; +import org.apache.hc.core5.http.ClassicHttpRequest; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpStatus; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.io.support.ClassicRequestBuilder; + +import com.google.gson.Gson; + +import io.weaviate.client6.Config; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBase.Vectors.VectorType; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoProperties.Value; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataResult; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesResult; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; +import io.weaviate.client6.internal.GRPC; +import io.weaviate.client6.internal.GrpcClient; +import io.weaviate.client6.internal.HttpClient; +import io.weaviate.client6.v1.collections.object.ObjectMetadata; +import io.weaviate.client6.v1.collections.object.ObjectReference; +import io.weaviate.client6.v1.collections.object.Vectors; +import io.weaviate.client6.v1.collections.object.WeaviateObject; +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class DataClient { + // TODO: inject singleton as dependency + private static final Gson gson = new Gson(); + + // TODO: this should be wrapped around in some TypeInspector etc. 
+ private final String collectionName; + + // TODO: hide behind an internal HttpClient + private final Config config; + private final HttpClient httpClient; + private final GrpcClient grpcClient; + + public WeaviateObject insert(T properties) throws IOException { + return insert(properties, opt -> { + }); + } + + public WeaviateObject insert(T properties, Consumer> fn) throws IOException { + return insert(InsertObjectRequest.of(collectionName, properties, fn)); + } + + public WeaviateObject insert(InsertObjectRequest request) throws IOException { + ClassicHttpRequest httpPost = ClassicRequestBuilder + .post(config.baseUrl() + "/objects") + .setEntity(request.serialize(gson), ContentType.APPLICATION_JSON) + .build(); + + return httpClient.http.execute(httpPost, response -> { + var entity = response.getEntity(); + if (response.getCode() != HttpStatus.SC_SUCCESS) { // Does not return 201 + var message = EntityUtils.toString(entity); + throw new RuntimeException("HTTP " + response.getCode() + ": " + message); + } + + return WeaviateObject.fromJson(gson, entity.getContent()); + }); + } + + public Optional> get(String id) throws IOException { + return get(id, q -> { + }); + } + + public Optional> get(String id, Consumer fn) throws IOException { + return findById(FetchByIdRequest.of(collectionName, id, fn)); + } + + private Optional> findById(FetchByIdRequest request) { + var req = SearchRequest.newBuilder(); + req.setUses127Api(true); + req.setUses125Api(true); + req.setUses123Api(true); + request.appendTo(req); + var result = grpcClient.grpc.search(req.build()); + var objects = result.getResultsList().stream().map(r -> { + var tempObj = readPropertiesResult(r.getProperties()); + MetadataResult meta = r.getMetadata(); + Vectors vectors; + if (!meta.getVectorBytes().isEmpty()) { + vectors = Vectors.of(GRPC.fromByteString(meta.getVectorBytes())); + } else { + vectors = Vectors.of(meta.getVectorsList().stream().collect( + Collectors.toMap( +
io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBase.Vectors::getName, + v -> { + if (v.getType().equals(VectorType.VECTOR_TYPE_SINGLE_FP32)) { + return GRPC.fromByteString(v.getVectorBytes()); + } else { + return GRPC.fromByteStringMulti(v.getVectorBytes()); + } + }))); + } + var metadata = new ObjectMetadata(meta.getId(), vectors); + return new WeaviateObject<>( + tempObj.collection(), + tempObj.properties(), + tempObj.references(), + metadata); + }).toList(); + if (objects.isEmpty()) { + return Optional.empty(); + } + return Optional.ofNullable((WeaviateObject) objects.get(0)); + } + + private static WeaviateObject readPropertiesResult(PropertiesResult res) { + var collection = res.getTargetCollection(); + var objectProperties = convertProtoMap(res.getNonRefProps().getFieldsMap()); + + // In case a reference is multi-target, there will be a separate + // "reference property" for each of the targets, so instead of + // `collect` we need to `reduce` the map, merging related references + // as we go. + // I.e. { "ref": A-1 } , { "ref": B-1 } => { "ref": [A-1, B-1] } + var referenceProperties = res.getRefPropsList().stream().reduce( + new HashMap(), + (map, ref) -> { + var refObjects = ref.getPropertiesList().stream() + .map(DataClient::readPropertiesResult) + .toList(); + + // Merge ObjectReferences by joining the underlying WeaviateObjects. 
+ map.merge( + ref.getPropName(), + new ObjectReference((List>) refObjects), + (left, right) -> { + var joined = Stream.concat( + left.objects().stream(), + right.objects().stream()).toList(); + return new ObjectReference(joined); + }); + return map; + }, + (left, right) -> { + left.putAll(right); + return left; + }); + + MetadataResult meta = res.getMetadata(); + Vectors vectors; + if (!meta.getVectorBytes().isEmpty()) { // protobuf getters never return null; an unset bytes field is empty + vectors = Vectors.of(GRPC.fromByteString(meta.getVectorBytes())); + } else { + vectors = Vectors.of(meta.getVectorsList().stream().collect( + Collectors.toMap( + io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBase.Vectors::getName, + v -> { + if (v.getType().equals(VectorType.VECTOR_TYPE_SINGLE_FP32)) { // same single/multi mapping as findById + return GRPC.fromByteString(v.getVectorBytes()); + } else { + return GRPC.fromByteStringMulti(v.getVectorBytes()); + } + }))); + } + var metadata = new ObjectMetadata(meta.getId(), vectors); + return new WeaviateObject<>(collection, objectProperties, referenceProperties, metadata); + } + + /** + * Convert a map of protobuf Value stubs to a map of plain Objects such that it can be + * (de-)serialized by {@link Gson}. + */ + private static Map convertProtoMap(Map map) { + return map.entrySet().stream() + // We cannot use Collectors.toMap() here, because convertProtoValue may + // return null (a collection property can be null), which breaks toMap(). + // See: https://bugs.openjdk.org/browse/JDK-8148463 + .collect( + HashMap::new, + (m, e) -> m.put(e.getKey(), convertProtoValue(e.getValue())), + HashMap::putAll); + } + + /** + * Convert protobuf's Value stub to an Object by extracting the first available + * field. The checks are non-exhaustive and only cover null, text, boolean, + * integer, number, and date values.
+ */ + private static Object convertProtoValue(Value value) { + if (value.hasNullValue()) { + // return value.getNullValue(); + return null; + } else if (value.hasTextValue()) { + return value.getTextValue(); + } else if (value.hasBoolValue()) { + return value.getBoolValue(); + } else if (value.hasIntValue()) { + return value.getIntValue(); + } else if (value.hasNumberValue()) { + return value.getNumberValue(); + } else if (value.hasDateValue()) { + OffsetDateTime offsetDateTime = OffsetDateTime.parse(value.getDateValue()); + return Date.from(offsetDateTime.toInstant()); + } else { + assert false : "branch not covered"; + } + return null; + } + + public void delete(String id) throws IOException { + try (CloseableHttpClient httpclient = HttpClients.createDefault()) { + ClassicHttpRequest httpGet = ClassicRequestBuilder + .delete(config.baseUrl() + "/objects/" + collectionName + "/" + id) + .build(); + + httpClient.http.execute(httpGet, response -> { + if (response.getCode() != HttpStatus.SC_NO_CONTENT) { + throw new RuntimeException(EntityUtils.toString(response.getEntity())); + } + return null; + }); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/FetchByIdRequest.java b/src/main/java/io/weaviate/client6/v1/collections/data/FetchByIdRequest.java new file mode 100644 index 000000000..28cb8635f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/data/FetchByIdRequest.java @@ -0,0 +1,112 @@ +package io.weaviate.client6.v1.collections.data; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBase.FilterTarget; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBase.Filters; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBase.Filters.Operator; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; +import 
io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.RefPropertiesRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; +import io.weaviate.client6.v1.collections.query.QueryReference; + +public record FetchByIdRequest( + String collection, + String id, + boolean includeVector, + List includeVectors, + List returnProperties, + List returnReferences) { + + public FetchByIdRequest(Builder options) { + this( + options.collection, + options.uuid, + options.includeVector, + options.includeVectors, + options.returnProperties, + options.returnReferences); + } + + public static FetchByIdRequest of(String collection, String uuid, Consumer fn) { + var builder = new Builder(collection, uuid); + fn.accept(builder); + return new FetchByIdRequest(builder); + } + + public static class Builder { + private final String collection; + private final String uuid; + + public Builder(String collection, String uuid) { + this.collection = collection; + this.uuid = uuid; + } + + private boolean includeVector; + private List includeVectors = new ArrayList<>(); + private List returnProperties = new ArrayList<>(); + private List returnReferences = new ArrayList<>(); + + public final Builder includeVector() { + this.includeVector = true; + return this; + } + + public final Builder includeVectors(String... vectors) { + this.includeVectors = Arrays.asList(vectors); + return this; + } + + public final Builder returnProperties(String... properties) { + this.returnProperties = Arrays.asList(properties); + return this; + } + + public final Builder returnReferences(QueryReference... 
references) { + this.returnReferences = Arrays.asList(references); + return this; + } + + } + + void appendTo(SearchRequest.Builder req) { + req.setLimit(1); + req.setCollection(collection); + + req.setFilters(Filters.newBuilder() + .setTarget(FilterTarget.newBuilder().setProperty("_id")) + .setValueText(id) + .setOperator(Operator.OPERATOR_EQUAL)); + + if (!returnProperties.isEmpty() || !returnReferences.isEmpty()) { + var properties = PropertiesRequest.newBuilder(); + + if (!returnProperties.isEmpty()) { + properties.addAllNonRefProperties(returnProperties); + } + + if (!returnReferences.isEmpty()) { + returnReferences.forEach(r -> { + var references = RefPropertiesRequest.newBuilder(); + r.appendTo(references); + properties.addRefProperties(references); + }); + } + req.setProperties(properties); + } + + // Always request UUID back in this request. + var metadata = MetadataRequest.newBuilder().setUuid(true); + if (includeVector) { + metadata.setVector(true); + } else if (!includeVectors.isEmpty()) { + metadata.addAllVectors(includeVectors); + } + req.setMetadata(metadata); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/data/GetParameters.java b/src/main/java/io/weaviate/client6/v1/collections/data/GetParameters.java similarity index 97% rename from src/main/java/io/weaviate/client6/v1/data/GetParameters.java rename to src/main/java/io/weaviate/client6/v1/collections/data/GetParameters.java index 819ffc09c..a0e4afd43 100644 --- a/src/main/java/io/weaviate/client6/v1/data/GetParameters.java +++ b/src/main/java/io/weaviate/client6/v1/collections/data/GetParameters.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.data; +package io.weaviate.client6.v1.collections.data; import java.util.LinkedHashSet; import java.util.List; diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/InsertObjectRequest.java b/src/main/java/io/weaviate/client6/v1/collections/data/InsertObjectRequest.java new file mode 100644 index 000000000..a04f993ea --- /dev/null 
+++ b/src/main/java/io/weaviate/client6/v1/collections/data/InsertObjectRequest.java @@ -0,0 +1,149 @@ +package io.weaviate.client6.v1.collections.data; + +import java.io.IOException; +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; + +import io.weaviate.client6.v1.collections.Reference; +import io.weaviate.client6.v1.collections.object.Vectors; + +public record InsertObjectRequest(String collection, T properties, String id, Vectors vectors, + Map> references) { + + /** Create InsertObjectRequest from Builder options. */ + public InsertObjectRequest(Builder builder) { + this(builder.collection, builder.properties, builder.id, builder.vectors, builder.references); + } + + /** + * Construct InsertObjectRequest with optional parameters. + * + * @param Shape of the object properties, e.g. + * {@code Map} + * @param collection Collection to insert to. + * @param properties Object properties. + * @param fn Optional parameters + * @return InsertObjectRequest + */ + static InsertObjectRequest of(String collection, T properties, Consumer> fn) { + var builder = new Builder<>(collection, properties); + fn.accept(builder); + return builder.build(); + } + + public static class Builder { + private final String collection; // Required + private final T properties; // Required + + private String id; + private Vectors vectors; + private final Map> references = new HashMap<>(); + + Builder(String collection, T properties) { + this.collection = collection; + this.properties = properties; + } + + /** Define custom object id. Must be a valid UUID. */ + public Builder id(String id) { + this.id = id; + return this; + } + + /** + * Supply one or more (named) vectors. Calls to {@link #vectors} are not + * chainable. Use {@link Vectors#of(Consumer)} to pass multiple vectors. 
+ */ + public Builder vectors(Vectors vectors) { + this.vectors = vectors; + return this; + } + + /** + * Add a reference. Calls to {@link #reference} can be chained + * to add multiple references. + */ + public Builder reference(String property, Reference... references) { + for (var ref : references) { + addReference(property, ref); + } + return this; + } + + private void addReference(String property, Reference reference) { + if (!references.containsKey(property)) { + references.put(property, new ArrayList<>()); + } + references.get(property).add(reference); + } + + /** Build a new InsertObjectRequest. */ + public InsertObjectRequest build() { + return new InsertObjectRequest<>(this); + } + } + + // Here we're just rawdogging JSON serialization just to get a good feel for it. + public String serialize(Gson gson) throws IOException { + var buf = new StringWriter(); + var w = gson.newJsonWriter(buf); + + w.beginObject(); + + w.name("class"); + w.value(collection); + + if (id != null) { + w.name("id"); + w.value(id); + } + + if (vectors != null) { + var unnamed = vectors.getUnnamed(); + if (unnamed.isPresent()) { + w.name("vector"); + gson.getAdapter(Float[].class).write(w, unnamed.get()); + } else { + w.name("vectors"); + gson.getAdapter(new TypeToken>() { + }).write(w, vectors.getNamed()); + } + } + + if (properties != null || references != null) { + w.name("properties"); + w.beginObject(); + + if (properties != null) { + assert properties instanceof Map : "properties not a Map"; + for (var entry : ((Map) properties).entrySet()) { + w.name(entry.getKey()); + gson.getAdapter(Object.class).write(w, entry.getValue()); + } + + } + if (references != null && !references.isEmpty()) { + for (var entry : references.entrySet()) { + w.name(entry.getKey()); + w.beginArray(); + for (var ref : entry.getValue()) { + ref.writeValue(w); + } + w.endArray(); + } + } + + w.endObject(); + } + + w.endObject(); + return buf.toString(); + } +} diff --git 
a/src/main/java/io/weaviate/client6/v1/data/QueryParameters.java b/src/main/java/io/weaviate/client6/v1/collections/data/QueryParameters.java similarity index 95% rename from src/main/java/io/weaviate/client6/v1/data/QueryParameters.java rename to src/main/java/io/weaviate/client6/v1/collections/data/QueryParameters.java index 2b94c834c..6f7eda826 100644 --- a/src/main/java/io/weaviate/client6/v1/data/QueryParameters.java +++ b/src/main/java/io/weaviate/client6/v1/collections/data/QueryParameters.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.data; +package io.weaviate.client6.v1.collections.data; import java.io.UnsupportedEncodingException; import java.net.URLEncoder; diff --git a/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java b/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java new file mode 100644 index 000000000..61ffcb9de --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java @@ -0,0 +1,30 @@ +package io.weaviate.client6.v1.collections.object; + +import java.util.function.Consumer; + +public record ObjectMetadata(String id, Vectors vectors) { + + public static ObjectMetadata with(Consumer options) { + var opt = new Builder(options); + return new ObjectMetadata(opt.id, opt.vectors); + } + + public static class Builder { + private String id; + private Vectors vectors; + + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder vectors(Vectors vectors) { + this.vectors = vectors; + return this; + } + + private Builder(Consumer options) { + options.accept(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/object/ObjectReference.java b/src/main/java/io/weaviate/client6/v1/collections/object/ObjectReference.java new file mode 100644 index 000000000..bc5c82c04 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/object/ObjectReference.java @@ -0,0 +1,6 @@ +package 
io.weaviate.client6.v1.collections.object; + +import java.util.List; + +public record ObjectReference(List> objects) { +} diff --git a/src/main/java/io/weaviate/client6/v1/Vectors.java b/src/main/java/io/weaviate/client6/v1/collections/object/Vectors.java similarity index 87% rename from src/main/java/io/weaviate/client6/v1/Vectors.java rename to src/main/java/io/weaviate/client6/v1/collections/object/Vectors.java index 9a69a5fa9..db0669e90 100644 --- a/src/main/java/io/weaviate/client6/v1/Vectors.java +++ b/src/main/java/io/weaviate/client6/v1/collections/object/Vectors.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1; +package io.weaviate.client6.v1.collections.object; import java.util.Collections; import java.util.HashMap; @@ -6,10 +6,13 @@ import java.util.Optional; import java.util.function.Consumer; +import lombok.ToString; + /** * Vectors is an abstraction over named vectors. * It may contain both 1-dimensional and 2-dimensional vectors. */ +@ToString public class Vectors { private static final String DEFAULT = "default"; @@ -53,7 +56,7 @@ private Optional getOnly() { return Optional.ofNullable(namedVectors.values().iterator().next()); } - public Map asMap() { + public Map getNamed() { return Map.copyOf(namedVectors); } @@ -73,9 +76,8 @@ private Vectors(Map vectors) { this.namedVectors = Collections.unmodifiableMap(vectors); } - static Vectors with(Consumer named) { - var vectors = new NamedVectors(named); - return new Vectors(vectors.namedVectors); + private Vectors(NamedVectors named) { + this.namedVectors = named.namedVectors; } /** @@ -106,6 +108,12 @@ public static Vectors of(Map vectors) { return new Vectors(vectors); } + public static Vectors of(Consumer fn) { + var named = new NamedVectors(); + fn.accept(named); + return named.build(); + } + public static class NamedVectors { private Map namedVectors = new HashMap<>(); @@ -119,8 +127,8 @@ public NamedVectors vector(String name, Float[][] vector) { return this; } - NamedVectors(Consumer options) { 
- options.accept(this); + public Vectors build() { + return new Vectors(this); } } } diff --git a/src/main/java/io/weaviate/client6/v1/data/WeaviateObject.java b/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObject.java similarity index 59% rename from src/main/java/io/weaviate/client6/v1/data/WeaviateObject.java rename to src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObject.java index edc08d251..28d5cc3b2 100644 --- a/src/main/java/io/weaviate/client6/v1/data/WeaviateObject.java +++ b/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObject.java @@ -1,21 +1,23 @@ -package io.weaviate.client6.v1.data; +package io.weaviate.client6.v1.collections.object; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; +import java.util.Map; import java.util.function.Consumer; import com.google.common.reflect.TypeToken; import com.google.gson.Gson; -import io.weaviate.client6.v1.ObjectMetadata; +public record WeaviateObject( + String collection, + T properties, + Map references, + ObjectMetadata metadata) { -// TODO: unify this with collections.SearchObject - -public record WeaviateObject(String collection, T properties, ObjectMetadata metadata) { - - WeaviateObject(String collection, T properties, Consumer options) { - this(collection, properties, ObjectMetadata.with(options)); + public WeaviateObject(String collection, T properties, Map references, + Consumer options) { + this(collection, properties, references, ObjectMetadata.with(options)); } // JSON serialization ---------------- diff --git a/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java b/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObjectDTO.java similarity index 78% rename from src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java rename to src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObjectDTO.java index ed9b00af8..e57afc1b5 100644 --- 
a/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java +++ b/src/main/java/io/weaviate/client6/v1/collections/object/WeaviateObjectDTO.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.data; +package io.weaviate.client6.v1.collections.object; import java.util.ArrayList; import java.util.HashMap; @@ -6,9 +6,6 @@ import com.google.gson.annotations.SerializedName; -import io.weaviate.client6.v1.ObjectMetadata; -import io.weaviate.client6.v1.Vectors; - class WeaviateObjectDTO { @SerializedName("class") String collection; @@ -26,7 +23,7 @@ class WeaviateObjectDTO { if (object.metadata() != null) { this.id = object.metadata().id(); if (object.metadata().vectors() != null) { - this.vectors = object.metadata().vectors().asMap(); + this.vectors = object.metadata().vectors().getNamed(); } } } @@ -44,6 +41,9 @@ WeaviateObject toWeaviateObject() { arrayVectors.put(entry.getKey(), vector); } } - return new WeaviateObject(collection, properties, new ObjectMetadata(id, Vectors.of(arrayVectors))); + + return new WeaviateObject(collection, properties, + /* no references through HTTP */ new HashMap<>(), + new ObjectMetadata(id, Vectors.of(arrayVectors))); } } diff --git a/src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java b/src/main/java/io/weaviate/client6/v1/collections/query/CommonQueryOptions.java similarity index 85% rename from src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java rename to src/main/java/io/weaviate/client6/v1/collections/query/CommonQueryOptions.java index ddf1e1ab1..930ef3836 100644 --- a/src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java +++ b/src/main/java/io/weaviate/client6/v1/collections/query/CommonQueryOptions.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.client6.v1.collections.query; import java.util.ArrayList; import java.util.Arrays; @@ -18,6 +18,7 @@ public record CommonQueryOptions( String after, String consistencyLevel /* TODO: use ConsistencyLevel 
enum */, List returnProperties, + List returnReferences, List returnMetadata) { public CommonQueryOptions(Builder> options) { @@ -28,6 +29,7 @@ public CommonQueryOptions(Builder> options) { options.after, options.consistencyLevel, options.returnProperties, + options.returnReferences, options.returnMetadata); } @@ -39,6 +41,7 @@ public static abstract class Builder> { private String after; private String consistencyLevel; private List returnProperties = new ArrayList<>(); + private List returnReferences = new ArrayList<>(); private List returnMetadata = new ArrayList<>(); public final SELF limit(Integer limit) { @@ -66,6 +69,16 @@ public final SELF consistencyLevel(String consistencyLevel) { return (SELF) this; } + public final SELF returnProperties(String... properties) { + this.returnProperties = Arrays.asList(properties); + return (SELF) this; + } + + public final SELF returnReferences(QueryReference references) { + this.returnReferences = Arrays.asList(references); + return (SELF) this; + } + public final SELF returnMetadata(Metadata... 
metadata) { this.returnMetadata = Arrays.asList(metadata); return (SELF) this; diff --git a/src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java b/src/main/java/io/weaviate/client6/v1/collections/query/GroupedQueryResult.java similarity index 84% rename from src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java rename to src/main/java/io/weaviate/client6/v1/collections/query/GroupedQueryResult.java index 01b8e68a4..cc50cf7a9 100644 --- a/src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java +++ b/src/main/java/io/weaviate/client6/v1/collections/query/GroupedQueryResult.java @@ -1,9 +1,9 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.client6.v1.collections.query; import java.util.List; import java.util.Map; -import io.weaviate.client6.v1.query.QueryResult.SearchObject; +import io.weaviate.client6.v1.collections.query.QueryResult.SearchObject; import lombok.AllArgsConstructor; @AllArgsConstructor diff --git a/src/main/java/io/weaviate/client6/v1/query/Metadata.java b/src/main/java/io/weaviate/client6/v1/collections/query/Metadata.java similarity index 88% rename from src/main/java/io/weaviate/client6/v1/query/Metadata.java rename to src/main/java/io/weaviate/client6/v1/collections/query/Metadata.java index 4cc37bd98..fe1d40889 100644 --- a/src/main/java/io/weaviate/client6/v1/query/Metadata.java +++ b/src/main/java/io/weaviate/client6/v1/collections/query/Metadata.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.client6.v1.collections.query; import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; diff --git a/src/main/java/io/weaviate/client6/v1/query/MetadataField.java b/src/main/java/io/weaviate/client6/v1/collections/query/MetadataField.java similarity index 93% rename from src/main/java/io/weaviate/client6/v1/query/MetadataField.java rename to src/main/java/io/weaviate/client6/v1/collections/query/MetadataField.java index fbec8a04c..bf4e43986 100644 
--- a/src/main/java/io/weaviate/client6/v1/query/MetadataField.java +++ b/src/main/java/io/weaviate/client6/v1/collections/query/MetadataField.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.client6.v1.collections.query; import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; diff --git a/src/main/java/io/weaviate/client6/v1/collections/query/NearText.java b/src/main/java/io/weaviate/client6/v1/collections/query/NearText.java new file mode 100644 index 000000000..0a585a281 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/query/NearText.java @@ -0,0 +1,35 @@ +package io.weaviate.client6.v1.collections.query; + +import java.util.List; +import java.util.function.Consumer; + +public record NearText(List text, Float distance, Float certainty, CommonQueryOptions common) { + + public static NearText with(String text, Consumer fn) { + return with(List.of(text), fn); + } + + public static NearText with(List text, Consumer fn) { + var opt = new Builder(); + fn.accept(opt); + return new NearText(text, opt.distance, opt.certainty, new CommonQueryOptions(opt)); + } + + public static class Builder extends CommonQueryOptions.Builder { + private Float distance; + private Float certainty; + + public Builder distance(float distance) { + this.distance = distance; + return this; + } + + public Builder certainty(float certainty) { + this.certainty = certainty; + return this; + } + } + + public static record GroupBy(String property, int maxGroups, int maxObjectsPerGroup) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/collections/query/NearVector.java similarity index 94% rename from src/main/java/io/weaviate/client6/v1/query/NearVector.java rename to src/main/java/io/weaviate/client6/v1/collections/query/NearVector.java index 6cfee7f8f..3bcc4fef0 100644 --- a/src/main/java/io/weaviate/client6/v1/query/NearVector.java +++ 
b/src/main/java/io/weaviate/client6/v1/collections/query/NearVector.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.client6.v1.collections.query; import java.util.function.Consumer; diff --git a/src/main/java/io/weaviate/client6/v1/query/Query.java b/src/main/java/io/weaviate/client6/v1/collections/query/QueryClient.java similarity index 87% rename from src/main/java/io/weaviate/client6/v1/query/Query.java rename to src/main/java/io/weaviate/client6/v1/collections/query/QueryClient.java index 673ed1f48..cb767d635 100644 --- a/src/main/java/io/weaviate/client6/v1/query/Query.java +++ b/src/main/java/io/weaviate/client6/v1/collections/query/QueryClient.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.client6.v1.collections.query; import java.time.OffsetDateTime; import java.util.ArrayList; @@ -18,7 +18,7 @@ import io.weaviate.client6.internal.GrpcClient; import io.weaviate.client6.internal.codec.grpc.v1.SearchMarshaler; -public class Query { +public class QueryClient { // TODO: this should be wrapped around in some TypeInspector etc. 
private final String collectionName; @@ -26,11 +26,18 @@ public class Query { // (probably on a "higher" level); private final GrpcClient grpcClient; - public Query(String collectionName, GrpcClient grpc) { + public QueryClient(String collectionName, GrpcClient grpc) { this.grpcClient = grpc; this.collectionName = collectionName; } + public QueryResult nearVector(Float[] vector) { + var query = NearVector.with(vector, opt -> { + }); + var req = new SearchMarshaler(collectionName).addNearVector(query); + return search(req.marshal()); + } + public QueryResult nearVector(Float[] vector, Consumer options) { var query = NearVector.with(vector, options); var req = new SearchMarshaler(collectionName).addNearVector(query); @@ -53,6 +60,19 @@ public GroupedQueryResult nearVector(Float[] vector, NearVector.GroupBy group return searchGrouped(req.marshal()); } + public QueryResult nearText(String text, Consumer fn) { + var query = NearText.with(text, fn); + var req = new SearchMarshaler(collectionName).addNearText(query); + return search(req.marshal()); + } + + public QueryResult nearText(String text) { + var query = NearText.with(text, opt -> { + }); + var req = new SearchMarshaler(collectionName).addNearText(query); + return search(req.marshal()); + } + private QueryResult search(SearchRequest req) { var reply = grpcClient.grpc.search(req); return deserializeUntyped(reply); diff --git a/src/main/java/io/weaviate/client6/v1/collections/query/QueryReference.java b/src/main/java/io/weaviate/client6/v1/collections/query/QueryReference.java new file mode 100644 index 000000000..82902f491 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/query/QueryReference.java @@ -0,0 +1,132 @@ +package io.weaviate.client6.v1.collections.query; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; +import 
io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.RefPropertiesRequest; + +public record QueryReference( + String property, + String collection, + boolean includeVector, List includeVectors, + List returnProperties, + List returnReferences, + List returnMetadata) { + + public QueryReference(Builder options) { + this( + options.property, + options.collection, + options.includeVector, + options.includeVectors, + options.returnProperties, + options.returnReferences, + options.returnMetadata); + } + + public static QueryReference single(String property) { + return single(property, opt -> { + }); + } + + public static QueryReference single(String property, Consumer fn) { + var builder = new Builder(null, property); + fn.accept(builder); + return new QueryReference(builder); + } + + // TODO: check if we can supply multiple collections + public static QueryReference multi(String property, String collection) { + return multi(property, collection, opt -> { + }); + } + + public static QueryReference multi(String property, String collection, Consumer fn) { + var builder = new Builder(collection, property); + fn.accept(builder); + return new QueryReference(builder); + } + + public static QueryReference[] multi(String property, Consumer fn, String... 
collections) { + return Arrays.stream(collections).map(collection -> { + var builder = new Builder(collection, property); + fn.accept(builder); + return new QueryReference(builder); + }).toArray(QueryReference[]::new); + } + + public static class Builder { + private final String property; + private final String collection; + + public Builder(String collection, String property) { + this.property = property; + this.collection = collection; + } + + private boolean includeVector; + private List includeVectors = new ArrayList<>(); + private List returnProperties = new ArrayList<>(); + private List returnReferences = new ArrayList<>(); + private List returnMetadata = new ArrayList<>(); + + public final Builder includeVector() { + this.includeVector = true; + return this; + } + + public final Builder includeVectors(String... vectors) { + this.includeVectors = Arrays.asList(vectors); + return this; + } + + public final Builder returnProperties(String... properties) { + this.returnProperties = Arrays.asList(properties); + return this; + } + + public final Builder returnReferences(QueryReference... references) { + this.returnReferences = Arrays.asList(references); + return this; + } + + public final Builder returnMetadata(Metadata... 
metadata) { + this.returnMetadata = Arrays.asList(metadata); + return this; + } + } + + public void appendTo(RefPropertiesRequest.Builder references) { + references.setReferenceProperty(property); + if (collection != null) { + references.setTargetCollection(collection); + } + + if (!returnMetadata.isEmpty()) { + var metadata = MetadataRequest.newBuilder(); + returnMetadata.forEach(m -> m.appendTo(metadata)); + references.setMetadata(metadata); + } + + if (!returnProperties.isEmpty() || !returnReferences.isEmpty()) { + var properties = PropertiesRequest.newBuilder(); + + if (!returnProperties.isEmpty()) { + properties.addAllNonRefProperties(returnProperties); + } + + if (!returnReferences.isEmpty()) { + returnReferences.forEach(r -> { + var ref = RefPropertiesRequest.newBuilder(); + r.appendTo(ref); + properties.addRefProperties(ref); + }); + } + references.setProperties(properties); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/query/QueryResult.java b/src/main/java/io/weaviate/client6/v1/collections/query/QueryResult.java similarity index 89% rename from src/main/java/io/weaviate/client6/v1/query/QueryResult.java rename to src/main/java/io/weaviate/client6/v1/collections/query/QueryResult.java index 3d03a9840..0fac388f1 100644 --- a/src/main/java/io/weaviate/client6/v1/query/QueryResult.java +++ b/src/main/java/io/weaviate/client6/v1/collections/query/QueryResult.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.client6.v1.collections.query; import java.util.List; diff --git a/src/main/java/io/weaviate/client6/v1/data/Data.java b/src/main/java/io/weaviate/client6/v1/data/Data.java deleted file mode 100644 index 54b476b22..000000000 --- a/src/main/java/io/weaviate/client6/v1/data/Data.java +++ /dev/null @@ -1,94 +0,0 @@ -package io.weaviate.client6.v1.data; - -import java.io.IOException; -import java.util.Optional; -import java.util.function.Consumer; - -import 
org.apache.hc.client5.http.impl.classic.CloseableHttpClient; -import org.apache.hc.client5.http.impl.classic.HttpClients; -import org.apache.hc.core5.http.ClassicHttpRequest; -import org.apache.hc.core5.http.ContentType; -import org.apache.hc.core5.http.HttpStatus; -import org.apache.hc.core5.http.io.entity.EntityUtils; -import org.apache.hc.core5.http.io.support.ClassicRequestBuilder; - -import com.google.gson.Gson; - -import io.weaviate.client6.Config; -import io.weaviate.client6.internal.HttpClient; -import io.weaviate.client6.v1.ObjectMetadata; -import lombok.AllArgsConstructor; - -@AllArgsConstructor -public class Data { - // TODO: inject singleton as dependency - private static final Gson gson = new Gson(); - - // TODO: this should be wrapped around in some TypeInspector etc. - private final String collectionName; - - // TODO: hide befind an internal HttpClient - private final Config config; - private final HttpClient httpClient; - - public WeaviateObject insert(T object) throws IOException { - return insert(object, opt -> { - }); - } - - public WeaviateObject insert(T object, Consumer options) throws IOException { - var body = new WeaviateObject<>(collectionName, object, options); - ClassicHttpRequest httpPost = ClassicRequestBuilder - .post(config.baseUrl() + "/objects") - .setEntity(body.toJson(gson), ContentType.APPLICATION_JSON) - .build(); - - return httpClient.http.execute(httpPost, response -> { - var entity = response.getEntity(); - if (response.getCode() != HttpStatus.SC_SUCCESS) { // Does not return 201 - var message = EntityUtils.toString(entity); - throw new RuntimeException("HTTP " + response.getCode() + ": " + message); - } - - return WeaviateObject.fromJson(gson, entity.getContent()); - }); - } - - public Optional> get(String id) throws IOException { - return get(id, q -> { - }); - } - - public Optional> get(String id, Consumer query) throws IOException { - try (CloseableHttpClient httpclient = HttpClients.createDefault()) { - 
ClassicHttpRequest httpGet = ClassicRequestBuilder - .get(config.baseUrl() + "/objects/" + collectionName + "/" + id + QueryParameters.encodeGet(query)) - .build(); - - return httpClient.http.execute(httpGet, response -> { - if (response.getCode() == HttpStatus.SC_NOT_FOUND) { - return Optional.empty(); - } - - WeaviateObject object = WeaviateObject.fromJson( - gson, response.getEntity().getContent()); - return Optional.ofNullable(object); - }); - } - } - - public void delete(String id) throws IOException { - try (CloseableHttpClient httpclient = HttpClients.createDefault()) { - ClassicHttpRequest httpGet = ClassicRequestBuilder - .delete(config.baseUrl() + "/objects/" + collectionName + "/" + id) - .build(); - - httpClient.http.execute(httpGet, response -> { - if (response.getCode() != HttpStatus.SC_NO_CONTENT) { - throw new RuntimeException(EntityUtils.toString(response.getEntity())); - } - return null; - }); - } - } -} diff --git a/src/test/java/io/weaviate/client6/v1/ObjectMetadataTest.java b/src/test/java/io/weaviate/client6/v1/ObjectMetadataTest.java index 542f5f7da..f95446fdc 100644 --- a/src/test/java/io/weaviate/client6/v1/ObjectMetadataTest.java +++ b/src/test/java/io/weaviate/client6/v1/ObjectMetadataTest.java @@ -5,6 +5,9 @@ import org.assertj.core.api.Assertions; import org.junit.Test; +import io.weaviate.client6.v1.collections.object.ObjectMetadata; +import io.weaviate.client6.v1.collections.object.Vectors; + public class ObjectMetadataTest { @Test @@ -28,7 +31,7 @@ public final void testVectorsMetadata_unnamed() { @Test public final void testVectorsMetadata_default() { Float[] vector = { 1f, 2f, 3f }; - var metadata = ObjectMetadata.with(m -> m.vectors(vector)); + var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.of(vector))); Assertions.assertThat(metadata.vectors()) .as("default vector").isNotNull() @@ -40,7 +43,7 @@ public final void testVectorsMetadata_default() { @Test public final void testVectorsMetadata_default_2d() { Float[][] vector 
= { { 1f, 2f, 3f }, { 1f, 2f, 3f } }; - var metadata = ObjectMetadata.with(m -> m.vectors(vector)); + var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.of(vector))); Assertions.assertThat(metadata.vectors()) .as("default 2d vector").isNotNull() @@ -52,7 +55,7 @@ public final void testVectorsMetadata_default_2d() { @Test public final void testVectorsMetadata_named() { Float[] vector = { 1f, 2f, 3f }; - var metadata = ObjectMetadata.with(m -> m.vectors("vector-1", vector)); + var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.of("vector-1", vector))); Assertions.assertThat(metadata.vectors()) .as("named vector").isNotNull() @@ -64,7 +67,7 @@ public final void testVectorsMetadata_named() { @Test public final void testVectorsMetadata_named_2d() { Float[][] vector = { { 1f, 2f, 3f }, { 1f, 2f, 3f } }; - var metadata = ObjectMetadata.with(m -> m.vectors("vector-1", vector)); + var metadata = ObjectMetadata.with(m -> m.vectors(Vectors.of("vector-1", vector))); Assertions.assertThat(metadata.vectors()) .as("named 2d vector").isNotNull() @@ -78,9 +81,9 @@ public final void testVectorsMetadata_multiple_named() { Float[][] vector_1 = { { 1f, 2f, 3f }, { 1f, 2f, 3f } }; Float[] vector_2 = { 4f, 5f, 6f }; var metadata = ObjectMetadata.with(m -> m.vectors( - named -> named + Vectors.of(named -> named .vector("vector-1", vector_1) - .vector("vector-2", vector_2))); + .vector("vector-2", vector_2)))); Assertions.assertThat(metadata.vectors()) .as("multiple named vectors").isNotNull() diff --git a/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java b/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java index 8deae4893..be7ef7174 100644 --- a/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java +++ b/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java @@ -25,7 +25,7 @@ public class VectorsTest { // private static final Gson gson = new Gson(); static { - DtoTypeAdapterFactory.register(CollectionDefinition.class, 
CollectionDefinitionDTO.class, + DtoTypeAdapterFactory.register(Collection.class, CollectionDefinitionDTO.class, m -> new CollectionDefinitionDTO(m)); } private static final Gson gson = new GsonBuilder() @@ -108,13 +108,13 @@ public static Object[][] testCases() { @Test @DataMethod(source = VectorsTest.class, method = "testCases") - public void test_toJson(String want, CollectionDefinition collection, String... compareKeys) { + public void test_toJson(String want, Collection collection, String... compareKeys) { var got = gson.toJson(collection); assertEqual(want, got, compareKeys); } - private static CollectionDefinition collectionWithVectors(Vectors vectors) { - return new CollectionDefinition("Things", List.of(), vectors); + private static Collection collectionWithVectors(Vectors vectors) { + return new Collection("Things", List.of(), vectors); } private void assertEqual(String wantJson, String gotJson, String... compareKeys) { diff --git a/src/test/java/io/weaviate/client6/v1/data/QueryParametersTest.java b/src/test/java/io/weaviate/client6/v1/collections/data/QueryParametersTest.java similarity index 96% rename from src/test/java/io/weaviate/client6/v1/data/QueryParametersTest.java rename to src/test/java/io/weaviate/client6/v1/collections/data/QueryParametersTest.java index 6c6caa676..275c8a9b7 100644 --- a/src/test/java/io/weaviate/client6/v1/data/QueryParametersTest.java +++ b/src/test/java/io/weaviate/client6/v1/collections/data/QueryParametersTest.java @@ -1,4 +1,4 @@ -package io.weaviate.client6.v1.data; +package io.weaviate.client6.v1.collections.data; import org.assertj.core.api.Assertions; import org.junit.Test;