diff --git a/pom.xml b/pom.xml index 4339bba88..3a04bd412 100644 --- a/pom.xml +++ b/pom.xml @@ -349,6 +349,7 @@ ${project.basedir}/src/it/java + ${project.basedir}/src/it/resources diff --git a/src/it/java/io/weaviate/containers/Container.java b/src/it/java/io/weaviate/containers/Container.java index 9ef69cfe8..2c4218a9f 100644 --- a/src/it/java/io/weaviate/containers/Container.java +++ b/src/it/java/io/weaviate/containers/Container.java @@ -16,6 +16,7 @@ public class Container { public static final Weaviate WEAVIATE = Weaviate.createDefault(); public static final Contextionary CONTEXTIONARY = Contextionary.createDefault(); + public static final Img2VecNeural IMG2VEC_NEURAL = Img2VecNeural.createDefault(); static { startAll(); @@ -39,21 +40,23 @@ static void stopAll() { WEAVIATE.stop(); } - public static Group compose(Weaviate weaviate, GenericContainer... containers) { - return new Group(weaviate, containers); + public static ContainerGroup compose(Weaviate weaviate, GenericContainer... containers) { + return new ContainerGroup(weaviate, containers); } public static TestRule asTestRule(Startable container) { return new PerTestSuite(container); }; - public static class Group implements Startable { + public static class ContainerGroup implements Startable { private final Weaviate weaviate; private final List> containers; - private Group(Weaviate weaviate, GenericContainer... containers) { + private ContainerGroup(Weaviate weaviate, GenericContainer... 
containers) { this.weaviate = weaviate; this.containers = Arrays.asList(containers); + + weaviate.dependsOn(containers); setSharedNetwork(); } @@ -63,8 +66,7 @@ public WeaviateClient getClient() { @Override public void start() { - containers.forEach(GenericContainer::start); - weaviate.start(); + weaviate.start(); // testcontainers will resolve dependencies } @Override diff --git a/src/it/java/io/weaviate/containers/Contextionary.java b/src/it/java/io/weaviate/containers/Contextionary.java index 76ec5aefd..69abde7df 100644 --- a/src/it/java/io/weaviate/containers/Contextionary.java +++ b/src/it/java/io/weaviate/containers/Contextionary.java @@ -35,7 +35,7 @@ public Contextionary build() { .withEnv("EXTENSIONS_STORAGE_ORIGIN", "http://weaviate:8080") .withEnv("NEIGHBOR_OCCURRENCE_IGNORE_PERCENTILE", "5") .withEnv("ENABLE_COMPOUND_SPLITTING", "'false'"); - container.withCreateContainerCmdModifier(cmd -> cmd.withHostName("contextionary")); + container.withCreateContainerCmdModifier(cmd -> cmd.withHostName(HOST_NAME)); return container; } } diff --git a/src/it/java/io/weaviate/containers/Img2VecNeural.java b/src/it/java/io/weaviate/containers/Img2VecNeural.java new file mode 100644 index 000000000..28bddd990 --- /dev/null +++ b/src/it/java/io/weaviate/containers/Img2VecNeural.java @@ -0,0 +1,38 @@ +package io.weaviate.containers; + +import org.testcontainers.containers.GenericContainer; + +public class Img2VecNeural extends GenericContainer { + public static final String DOCKER_IMAGE = "cr.weaviate.io/semitechnologies/img2vec-pytorch"; + public static final String VERSION = "resnet50"; + + public static final String MODULE = "img2vec-neural"; + public static final String HOST_NAME = MODULE; + public static final String URL = HOST_NAME + ":8080"; + + static Img2VecNeural createDefault() { + return new Builder().build(); + } + + static Img2VecNeural.Builder custom() { + return new Builder(); + } + + public static class Builder { + private String versionTag; + + public 
Builder() { + this.versionTag = VERSION; + } + + public Img2VecNeural build() { + var container = new Img2VecNeural(DOCKER_IMAGE + ":" + versionTag); + container.withCreateContainerCmdModifier(cmd -> cmd.withHostName(HOST_NAME)); + return container; + } + } + + public Img2VecNeural(String image) { + super(image); + } +} diff --git a/src/it/java/io/weaviate/containers/Weaviate.java b/src/it/java/io/weaviate/containers/Weaviate.java index 6126d5150..d6251c028 100644 --- a/src/it/java/io/weaviate/containers/Weaviate.java +++ b/src/it/java/io/weaviate/containers/Weaviate.java @@ -1,7 +1,10 @@ package io.weaviate.containers; import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; import org.testcontainers.weaviate.WeaviateContainer; @@ -10,7 +13,7 @@ import io.weaviate.client6.WeaviateClient; public class Weaviate extends WeaviateContainer { - private static WeaviateClient clientInstance; + private WeaviateClient clientInstance; public static final String VERSION = "1.29.0"; public static final String DOCKER_IMAGE = "semitechnologies/weaviate"; @@ -42,14 +45,14 @@ public static Weaviate.Builder custom() { public static class Builder { private String versionTag; - private Set enableModules; + private Set enableModules = new HashSet<>(); private String defaultVectorizerModule; - private String contextionaryUrl; private boolean telemetry; + private Map environment = new HashMap<>(); + public Builder() { this.versionTag = VERSION; - this.enableModules = new HashSet<>(); this.telemetry = false; } @@ -58,24 +61,31 @@ public Builder withVersion(String version) { return this; } - public Builder addModule(String module) { - enableModules.add(module); + public Builder addModules(String... 
modules) { + enableModules.addAll(Arrays.asList(modules)); return this; } public Builder withDefaultVectorizer(String module) { - addModule(module); - defaultVectorizerModule = module; + addModules(module); + environment.put("DEFAULT_VECTORIZER_MODULE", module); return this; } public Builder withContextionaryUrl(String url) { - contextionaryUrl = url; + addModules(Contextionary.MODULE); + environment.put("CONTEXTIONARY_URL", url); + return this; + } + + public Builder withImageInference(String url, String module) { + addModules(module); + environment.put("IMAGE_INFERENCE_API", "http://" + url); return this; } - public Builder enableTelemetry() { - telemetry = true; + public Builder enableTelemetry(boolean enable) { + telemetry = enable; return this; } @@ -83,18 +93,14 @@ public Weaviate build() { var c = new Weaviate(DOCKER_IMAGE + ":" + versionTag); if (!enableModules.isEmpty()) { + c.withEnv("ENABLE_API_BASED_MODULES", "true"); /* plain value like DISABLE_TELEMETRY below; quote characters inside the literal would be part of the value itself */ c.withEnv("ENABLE_MODULES", String.join(",", enableModules)); } - if (defaultVectorizerModule != null) { - c.withEnv("DEFAULT_VECTORIZER_MODULE", defaultVectorizerModule); - } - if (contextionaryUrl != null) { - c.withEnv("CONTEXTIONARY_URL", contextionaryUrl); - } if (!telemetry) { c.withEnv("DISABLE_TELEMETRY", "true"); } + environment.forEach((name, value) -> c.withEnv(name, value)); c.withCreateContainerCmdModifier(cmd -> cmd.withHostName("weaviate")); return c; } diff --git a/src/it/java/io/weaviate/integration/CollectionsITest.java b/src/it/java/io/weaviate/integration/CollectionsITest.java index fc70e0fde..59083c44b 100644 --- a/src/it/java/io/weaviate/integration/CollectionsITest.java +++ b/src/it/java/io/weaviate/integration/CollectionsITest.java @@ -60,7 +60,6 @@ public void testCrossReferences() throws IOException { // Assert: Things --ownedBy-> Owners Assertions.assertThat(things.config.get()) - // Assertions.assertThat(client.collections.getConfig(nsOwners)) .as("after create Things").get() .satisfies(c -> { 
Assertions.assertThat(c.references()) diff --git a/src/it/java/io/weaviate/integration/DataITest.java b/src/it/java/io/weaviate/integration/DataITest.java index 50fc5205d..aed87e17c 100644 --- a/src/it/java/io/weaviate/integration/DataITest.java +++ b/src/it/java/io/weaviate/integration/DataITest.java @@ -15,6 +15,7 @@ import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; import io.weaviate.client6.v1.collections.Vectorizer; import io.weaviate.client6.v1.collections.object.Vectors; +import io.weaviate.client6.v1.collections.object.WeaviateObject; import io.weaviate.containers.Container; public class DataITest extends ConcurrentTest { @@ -62,6 +63,29 @@ public void testCreateGetDelete() throws IOException { Assertions.assertThat(object).isEmpty().as("object not exists after deletion"); } + @Test + public void testBlobData() throws IOException { + var nsCats = ns("Cats"); + + client.collections.create(nsCats, + collection -> collection.properties( + Property.text("breed"), + Property.blob("img"))); + + var cats = client.collections.use(nsCats); + var ragdollPng = EncodedMedia.IMAGE; + var ragdoll = cats.data.insert(Map.of( + "breed", "ragdoll", + "img", ragdollPng)); + + var got = cats.data.get(ragdoll.metadata().id(), + cat -> cat.returnProperties("img")); + + Assertions.assertThat(got).get() + .extracting(WeaviateObject::properties, InstanceOfAssertFactories.MAP) + .extractingByKey("img").isEqualTo(ragdollPng); + } + private static void createTestCollections() throws IOException { var awardsGrammy = unique("Grammy"); client.collections.create(awardsGrammy); diff --git a/src/it/java/io/weaviate/integration/EncodedMedia.java b/src/it/java/io/weaviate/integration/EncodedMedia.java new file mode 100644 index 000000000..d075de08f --- /dev/null +++ b/src/it/java/io/weaviate/integration/EncodedMedia.java @@ -0,0 +1,6 @@ +package io.weaviate.integration; + +class EncodedMedia { + public static final String IMAGE = ""; + +} diff --git 
a/src/it/java/io/weaviate/integration/NearVectorQueryITest.java b/src/it/java/io/weaviate/integration/NearVectorQueryITest.java deleted file mode 100644 index 863652781..000000000 --- a/src/it/java/io/weaviate/integration/NearVectorQueryITest.java +++ /dev/null @@ -1,120 +0,0 @@ -package io.weaviate.integration; - -import java.io.IOException; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.assertj.core.api.Assertions; -import org.junit.BeforeClass; -import org.junit.Test; - -import io.weaviate.ConcurrentTest; -import io.weaviate.client6.WeaviateClient; -import io.weaviate.client6.v1.collections.Property; -import io.weaviate.client6.v1.collections.VectorIndex; -import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; -import io.weaviate.client6.v1.collections.Vectorizer; -import io.weaviate.client6.v1.collections.object.Vectors; -import io.weaviate.client6.v1.collections.query.GroupedQueryResult; -import io.weaviate.client6.v1.collections.query.MetadataField; -import io.weaviate.client6.v1.collections.query.NearVector; -import io.weaviate.containers.Container; - -public class NearVectorQueryITest extends ConcurrentTest { - private static final WeaviateClient client = Container.WEAVIATE.getClient(); - - private static final String COLLECTION = unique("Things"); - private static final String VECTOR_INDEX = "bring_your_own"; - private static final List CATEGORIES = List.of("red", "green"); - - /** - * One of the inserted vectors which will be used as target vector for search. 
- */ - private static Float[] searchVector; - - @BeforeClass - public static void beforeAll() throws IOException { - createTestCollection(); - var created = populateTest(10); - searchVector = created.values().iterator().next(); - } - - @Test - public void testNearVector() { - // TODO: test that we return the results in the expected order - // Because re-ranking should work correctly - var things = client.collections.use(COLLECTION); - var result = things.query.nearVector(searchVector, - opt -> opt - .distance(2f) - .limit(3) - .returnMetadata(MetadataField.DISTANCE)); - - Assertions.assertThat(result.objects).hasSize(3); - float maxDistance = Collections.max(result.objects, - Comparator.comparing(obj -> obj.metadata.distance())).metadata.distance(); - Assertions.assertThat(maxDistance).isLessThanOrEqualTo(2f); - } - - @Test - public void testNearVector_groupBy() { - // TODO: test that we return the results in the expected order - // Because re-ranking should work correctly - var things = client.collections.use(COLLECTION); - var result = things.query.nearVector(searchVector, - new NearVector.GroupBy("category", 2, 5), - opt -> opt.distance(10f)); - - Assertions.assertThat(result.groups) - .as("group per category").containsOnlyKeys(CATEGORIES) - .hasSizeLessThanOrEqualTo(2) - .allSatisfy((category, group) -> { - Assertions.assertThat(group) - .as("group name").returns(category, GroupedQueryResult.Group::name); - Assertions.assertThat(group.numberOfObjects()) - .as("[%s] has 1+ object", category).isLessThanOrEqualTo(5L); - }); - - Assertions.assertThat(result.objects) - .as("object belongs a group") - .allMatch(obj -> result.groups.get(obj.belongsToGroup).objects().contains(obj)); - - } - - /** - * Insert 10 objects with random vectors. - * - * @returns IDs of inserted objects and their corresponding vectors. 
- */ - private static Map populateTest(int n) throws IOException { - var created = new HashMap(); - - var things = client.collections.use(COLLECTION); - for (int i = 0; i < n; i++) { - var vector = randomVector(10, -.01f, .001f); - var object = things.data.insert( - Map.of("category", CATEGORIES.get(i % CATEGORIES.size())), - metadata -> metadata - .id(randomUUID()) - .vectors(Vectors.of(VECTOR_INDEX, vector))); - - created.put(object.metadata().id(), vector); - } - - return created; - } - - /** - * Create {@link COLLECTION} with {@link VECTOR_INDEX} vector index. - * - * @throws IOException - */ - private static void createTestCollection() throws IOException { - client.collections.create(COLLECTION, cfg -> cfg - .properties(Property.text("category")) - .vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); - } -} diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java new file mode 100644 index 000000000..38b83c26a --- /dev/null +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -0,0 +1,225 @@ +package io.weaviate.integration; + +import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.assertj.core.api.Assertions; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TestRule; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.collections.Property; +import io.weaviate.client6.v1.collections.Reference; +import io.weaviate.client6.v1.collections.VectorIndex; +import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; +import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.collections.object.Vectors; +import 
io.weaviate.client6.v1.collections.query.GroupedQueryResult; +import io.weaviate.client6.v1.collections.query.MetadataField; +import io.weaviate.client6.v1.collections.query.NearText; +import io.weaviate.client6.v1.collections.query.NearVector; +import io.weaviate.containers.Container; +import io.weaviate.containers.Container.ContainerGroup; +import io.weaviate.containers.Contextionary; +import io.weaviate.containers.Img2VecNeural; +import io.weaviate.containers.Weaviate; + +public class SearchITest extends ConcurrentTest { + private static final ContainerGroup compose = Container.compose( + Weaviate.custom() + .withContextionaryUrl(Contextionary.URL) + .withImageInference(Img2VecNeural.URL, Img2VecNeural.MODULE) + .build(), + Container.IMG2VEC_NEURAL, + Container.CONTEXTIONARY); + @ClassRule // Bind containers to the lifetime of the test + public static final TestRule _rule = compose.asTestRule(); + private static final WeaviateClient client = compose.getClient(); + + private static final String COLLECTION = unique("Things"); + private static final String VECTOR_INDEX = "bring_your_own"; + private static final List CATEGORIES = List.of("red", "green"); + + /** + * One of the inserted vectors which will be used as target vector for search. 
+ */ + private static Float[] searchVector; + + @BeforeClass + public static void beforeAll() throws IOException { + createTestCollection(); + var created = populateTest(10); + searchVector = created.values().iterator().next(); + } + + @Test + public void testNearVector() { + var things = client.collections.use(COLLECTION); + var result = things.query.nearVector(searchVector, + opt -> opt + .distance(2f) + .limit(3) + .returnMetadata(MetadataField.DISTANCE)); + + Assertions.assertThat(result.objects).hasSize(3); + float maxDistance = Collections.max(result.objects, + Comparator.comparing(obj -> obj.metadata.distance())).metadata.distance(); + Assertions.assertThat(maxDistance).isLessThanOrEqualTo(2f); + } + + @Test + public void testNearVector_groupBy() { + var things = client.collections.use(COLLECTION); + var result = things.query.nearVector(searchVector, + new NearVector.GroupBy("category", 2, 5), + opt -> opt.distance(10f)); + + Assertions.assertThat(result.groups) + .as("group per category").containsOnlyKeys(CATEGORIES) + .hasSizeLessThanOrEqualTo(2) + .allSatisfy((category, group) -> { + Assertions.assertThat(group) + .as("group name").returns(category, GroupedQueryResult.Group::name); + Assertions.assertThat(group.numberOfObjects()) + .as("[%s] has 1+ object", category).isLessThanOrEqualTo(5L); + }); + + Assertions.assertThat(result.objects) + .as("object belongs a group") + .allMatch(obj -> result.groups.get(obj.belongsToGroup).objects().contains(obj)); + } + + /** + * Insert {@code n} objects with random vectors. + * + * @return IDs of inserted objects and their corresponding vectors. 
+ */ + private static Map populateTest(int n) throws IOException { + var created = new HashMap(); + + var things = client.collections.use(COLLECTION); + for (int i = 0; i < n; i++) { + var vector = randomVector(10, -.01f, .001f); + var object = things.data.insert( + Map.of("category", CATEGORIES.get(i % CATEGORIES.size())), + metadata -> metadata + .id(randomUUID()) + .vectors(Vectors.of(VECTOR_INDEX, vector))); + + created.put(object.metadata().id(), vector); + } + + return created; + } + + /** + * Create {@link #COLLECTION} with {@link #VECTOR_INDEX} vector index. + * + * @throws IOException propagated from the collection-create call + */ + private static void createTestCollection() throws IOException { + client.collections.create(COLLECTION, cfg -> cfg + .properties(Property.text("category")) + .vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); + } + + @Test + public void testNearText() throws IOException { + var nsSongs = ns("Songs"); + client.collections.create(nsSongs, + col -> col + .properties(Property.text("title")) + .vector(new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.text2vecContextionary()))); + + var songs = client.collections.use(nsSongs); + var submarine = songs.data.insert(Map.of("title", "Yellow Submarine")); + songs.data.insert(Map.of("title", "Run Through The Jungle")); + songs.data.insert(Map.of("title", "Welcome To The Jungle")); + + var result = songs.query.nearText("forest", + opt -> opt + .distance(0.5f) + .moveTo(.98f, to -> to.concepts("tropical")) + .moveAway(.4f, away -> away.uuids(submarine.metadata().id())) + .returnProperties("title")); + + Assertions.assertThat(result.objects).hasSize(2) + .extracting(obj -> obj.properties).allSatisfy( + properties -> Assertions.assertThat(properties) + .allSatisfy((_k, v) -> Assertions.assertThat((String) v).contains("Jungle"))); + } + + @Test + public void testNearText_groupBy() throws IOException { + var vectorIndex = new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.text2vecContextionary()); 
+ + var nsArtists = ns("Artists"); + client.collections.create(nsArtists, + col -> col + .properties(Property.text("name")) + .vector(vectorIndex)); + + var artists = client.collections.use(nsArtists); + var beatles = artists.data.insert(Map.of("name", "Beatles")); + var ccr = artists.data.insert(Map.of("name", "CCR")); + + var nsSongs = ns("Songs"); + client.collections.create(nsSongs, + col -> col + .properties(Property.text("title")) + .references(Property.reference("performedBy", nsArtists)) + .vector(vectorIndex)); + + var songs = client.collections.use(nsSongs); + songs.data.insert(Map.of("title", "Yellow Submarine"), + s -> s.reference("performedBy", Reference.objects(beatles))); + songs.data.insert(Map.of("title", "Run Through The Jungle"), + s -> s.reference("performedBy", Reference.objects(ccr))); + + var result = songs.query.nearText("nature", + new NearText.GroupBy("performedBy", 2, 1), + opt -> opt + .returnProperties("title")); + + Assertions.assertThat(result.groups).hasSize(2) + .containsOnlyKeys( + "weaviate://localhost/%s/%s".formatted(nsArtists, beatles.metadata().id()), + "weaviate://localhost/%s/%s".formatted(nsArtists, ccr.metadata().id())); + } + + @Test + // @Ignore("no fitting image to test with") + public void testNearImage() throws IOException { + var nsCats = ns("Cats"); + + client.collections.create(nsCats, + collection -> collection + .properties( + Property.text("breed"), + Property.blob("img")) + .vector(new VectorIndex<>( + IndexingStrategy.hnsw(), + Vectorizer.img2VecNeuralVectorizer( + i2v -> i2v.imageFields("img"))))); + + var cats = client.collections.use(nsCats); + cats.data.insert(Map.of( + "breed", "ragdoll", + "img", EncodedMedia.IMAGE)); + + var got = cats.query.nearImage(EncodedMedia.IMAGE, + opt -> opt.returnProperties("breed")); + + Assertions.assertThat(got.objects).hasSize(1).first() + .extracting(obj -> obj.properties, InstanceOfAssertFactories.MAP) + .extractingByKey("breed").isEqualTo("ragdoll"); + } +} diff --git 
a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java index ca6769f01..9e30ef515 100644 --- a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java @@ -3,6 +3,7 @@ import org.apache.commons.lang3.StringUtils; import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch.NearTextSearch; import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet; import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; @@ -10,6 +11,8 @@ import io.weaviate.client6.internal.GRPC; import io.weaviate.client6.internal.codec.grpc.GrpcMarshaler; import io.weaviate.client6.v1.collections.query.CommonQueryOptions; +import io.weaviate.client6.v1.collections.query.NearImage; +import io.weaviate.client6.v1.collections.query.NearText; import io.weaviate.client6.v1.collections.query.NearVector; public class SearchMarshaler implements GrpcMarshaler { @@ -31,6 +34,15 @@ public SearchMarshaler addGroupBy(NearVector.GroupBy gb) { return this; } + public SearchMarshaler addGroupBy(NearText.GroupBy gb) { + var groupBy = WeaviateProtoSearchGet.GroupBy.newBuilder(); + groupBy.addPath(gb.property()); + groupBy.setNumberOfGroups(gb.maxGroups()); + groupBy.setObjectsPerGroup(gb.maxObjectsPerGroup()); + req.setGroupBy(groupBy); + return this; + } + public SearchMarshaler addNearVector(NearVector nv) { setCommon(nv.common()); @@ -43,10 +55,56 @@ public SearchMarshaler addNearVector(NearVector nv) { nearVector.setDistance(nv.distance()); } + // TODO: add targets, vector_for_targets req.setNearVector(nearVector); return this; } + public SearchMarshaler addNearImage(NearImage ni) { + setCommon(ni.common()); + 
+ var nearImage = WeaviateProtoBaseSearch.NearImageSearch.newBuilder(); + nearImage.setImage(ni.image()); + + if (ni.certainty() != null) { + nearImage.setCertainty(ni.certainty()); + } else if (ni.distance() != null) { + nearImage.setDistance(ni.distance()); + } + + req.setNearImage(nearImage); + return this; + } + + public SearchMarshaler addNearText(NearText nt) { + setCommon(nt.common()); + + var nearText = WeaviateProtoBaseSearch.NearTextSearch.newBuilder(); + nearText.addAllQuery(nt.text()); + + if (nt.certainty() != null) { + nearText.setCertainty(nt.certainty()); + } else if (nt.distance() != null) { + nearText.setDistance(nt.distance()); + } + + // TODO: add targets + if (nt.moveTo() != null) { + var to = NearTextSearch.Move.newBuilder(); + nt.moveTo().appendTo(to); + nearText.setMoveTo(to); + } + + if (nt.moveAway() != null) { + var away = NearTextSearch.Move.newBuilder(); + nt.moveAway().appendTo(away); + nearText.setMoveAway(away); + } + + req.setNearText(nearText); + return this; + } + private void setCommon(CommonQueryOptions o) { if (o.limit() != null) { req.setLimit(o.limit()); diff --git a/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java b/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java index da54b1c28..38c33ed22 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java +++ b/src/main/java/io/weaviate/client6/v1/collections/AtomicDataType.java @@ -6,10 +6,13 @@ public enum AtomicDataType { @SerializedName("text") TEXT, @SerializedName("int") - INT; + INT, + @SerializedName("blob") + BLOB; public static boolean isAtomic(String type) { return type.equals(TEXT.name().toLowerCase()) - || type.equals(INT.name().toLowerCase()); + || type.equals(INT.name().toLowerCase()) + || type.equals(BLOB.name().toLowerCase()); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Collection.java b/src/main/java/io/weaviate/client6/v1/collections/Collection.java index 81298a029..870701311 
100644 --- a/src/main/java/io/weaviate/client6/v1/collections/Collection.java +++ b/src/main/java/io/weaviate/client6/v1/collections/Collection.java @@ -9,48 +9,47 @@ public record Collection(String name, List properties, List references, Vectors vectors) { - public static Collection with(String name, Consumer options) { - var config = new Configuration(options); + public static Collection with(String name, Consumer options) { + var config = new Builder(options); return new Collection(name, config.properties, config.references, config.vectors); } - // Tucked Builder for additional collection configuration. - public static class Configuration { - public List properties = new ArrayList<>(); + public static class Builder { + private List properties = new ArrayList<>(); public List references = new ArrayList<>(); - public Vectors vectors; + private Vectors vectors; - public Configuration properties(Property... properties) { + public Builder properties(Property... properties) { this.properties = Arrays.asList(properties); return this; } - public Configuration references(ReferenceProperty... references) { + public Builder references(ReferenceProperty... 
references) { this.references = Arrays.asList(references); return this; } - public Configuration vectors(Vectors vectors) { + public Builder vectors(Vectors vectors) { this.vectors = vectors; return this; } - public Configuration vector(VectorIndex vector) { + public Builder vector(VectorIndex vector) { this.vectors = Vectors.of(vector); return this; } - public Configuration vector(String name, VectorIndex vector) { + public Builder vector(String name, VectorIndex vector) { this.vectors = new Vectors(name, vector); return this; } - public Configuration vectors(Consumer named) { + public Builder vectors(Consumer named) { this.vectors = Vectors.with(named); return this; } - Configuration(Consumer options) { + Builder(Consumer options) { options.accept(this); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionsClient.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionsClient.java index fa3ff38de..48f41c4df 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionsClient.java +++ b/src/main/java/io/weaviate/client6/v1/collections/CollectionsClient.java @@ -73,6 +73,7 @@ private static class VectorizerSerde @Override public Vectorizer deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) throws JsonParseException { + // TODO: deserialize different kinds of vectorizers return Vectorizer.none(); } @@ -116,7 +117,7 @@ public void create(String name) throws IOException { }); } - public void create(String name, Consumer options) throws IOException { + public void create(String name, Consumer options) throws IOException { var collection = Collection.with(name, options); ClassicHttpRequest httpPost = ClassicRequestBuilder .post(config.baseUrl() + "/schema") diff --git a/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java new file mode 100644 index 000000000..1bb580ada --- /dev/null +++ 
b/src/main/java/io/weaviate/client6/v1/collections/ContextionaryVectorizer.java @@ -0,0 +1,37 @@ +package io.weaviate.client6.v1.collections; + +import java.util.Map; +import java.util.function.Consumer; + +import com.google.gson.annotations.SerializedName; + +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class ContextionaryVectorizer extends Vectorizer { + @SerializedName("text2vec-contextionary") + private Map configuration; + + public static ContextionaryVectorizer of() { + return new Builder().build(); + } + + public static ContextionaryVectorizer of(Consumer fn) { + var builder = new Builder(); + fn.accept(builder); + return builder.build(); + } + + public static class Builder { + private boolean vectorizeCollectionName = false; + + public Builder vectorizeCollectionName() { + this.vectorizeCollectionName = true; + return this; + } + + public ContextionaryVectorizer build() { + return new ContextionaryVectorizer(Map.of("vectorizeClassName", vectorizeCollectionName)); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/HNSW.java b/src/main/java/io/weaviate/client6/v1/collections/HNSW.java index cac6b45b4..938c44ffa 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/HNSW.java +++ b/src/main/java/io/weaviate/client6/v1/collections/HNSW.java @@ -17,26 +17,26 @@ public enum Distance { this(null, null); } - static HNSW with(Consumer options) { - var opt = new Options(options); + static HNSW with(Consumer options) { + var opt = new Builder(options); return new HNSW(opt.distance, opt.skip); } - public static class Options { - public Distance distance; - public Boolean skip; + public static class Builder { + private Distance distance; + private Boolean skip; - public Options distance(Distance distance) { + public Builder distance(Distance distance) { this.distance = distance; return this; } - public Options disableIndexation() { + public Builder disableIndexation() { this.skip = true; return this; } - public 
Options(Consumer options) { + public Builder(Consumer options) { options.accept(this); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Img2VecNeuralVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Img2VecNeuralVectorizer.java new file mode 100644 index 000000000..a0efc5c61 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/Img2VecNeuralVectorizer.java @@ -0,0 +1,40 @@ +package io.weaviate.client6.v1.collections; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import com.google.gson.annotations.SerializedName; + +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class Img2VecNeuralVectorizer extends Vectorizer { + @SerializedName("img2vec-neural") + private Map configuration; + + public static Img2VecNeuralVectorizer of() { + return new Builder().build(); + } + + public static Img2VecNeuralVectorizer of(Consumer fn) { + var builder = new Builder(); + fn.accept(builder); + return builder.build(); + } + + public static class Builder { + private List imageFields = new ArrayList<>(); + + public Builder imageFields(String... 
fields) { + this.imageFields = Arrays.asList(fields); + return this; + } + + public Img2VecNeuralVectorizer build() { + return new Img2VecNeuralVectorizer(Map.of("imageFields", imageFields)); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Multi2VecClipVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Multi2VecClipVectorizer.java new file mode 100644 index 000000000..305e8373a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/Multi2VecClipVectorizer.java @@ -0,0 +1,100 @@ +package io.weaviate.client6.v1.collections; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import com.google.gson.annotations.SerializedName; + +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class Multi2VecClipVectorizer extends Vectorizer { + @SerializedName("multi2vec-clip") + private Map configuration; + + public static Multi2VecClipVectorizer of() { + return new Builder().build(); + } + + public static Multi2VecClipVectorizer of(Consumer fn) { + var builder = new Builder(); + fn.accept(builder); + return builder.build(); + } + + public static class Builder { + private boolean vectorizeCollectionName = false; + private String inferenceUrl; + private Map imageFields = new HashMap<>(); + private Map textFields = new HashMap<>(); + + public Builder inferenceUrl(String inferenceUrl) { + this.inferenceUrl = inferenceUrl; + return this; + } + + public Builder imageFields(String... fields) { + Arrays.stream(fields).forEach(f -> imageFields.put(f, null)); + return this; + } + + public Builder imageField(String field, float weight) { + imageFields.put(field, weight); + return this; + } + + public Builder textFields(String... 
fields) { + Arrays.stream(fields).forEach(f -> textFields.put(f, null)); + return this; + } + + public Builder textField(String field, float weight) { + textFields.put(field, weight); + return this; + } + + public Builder vectorizeCollectionName() { + this.vectorizeCollectionName = true; + return this; + } + + public Multi2VecClipVectorizer build() { + return new Multi2VecClipVectorizer(new HashMap<>() { + { + put("vectorizeClassName", vectorizeCollectionName); + if (inferenceUrl != null) { + put("inferenceUrl", inferenceUrl); + } + + var _imageFields = new ArrayList(); + var _imageWeights = new ArrayList(); + splitEntries(imageFields, _imageFields, _imageWeights); + + var _textFields = new ArrayList(); + var _textWeights = new ArrayList(); + splitEntries(textFields, _textFields, _textWeights); /* text entries are split separately from image entries */ + + put("imageFields", _imageFields); + put("textFields", _textFields); + put("weights", Map.of( + "imageWeights", _imageWeights, + "textWeights", _textWeights)); + } + }); + } + + private void splitEntries(Map map, List keys, List values) { + map.entrySet().forEach(entry -> { + keys.add(entry.getKey()); + var value = entry.getValue(); + if (value != null) { + values.add(value); + } + }); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/Property.java b/src/main/java/io/weaviate/client6/v1/collections/Property.java index 9c2456eac..bb9293abb 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/Property.java +++ b/src/main/java/io/weaviate/client6/v1/collections/Property.java @@ -19,6 +19,11 @@ public static Property integer(String name) { return new Property(name, AtomicDataType.INT); } + /** Add blob property with default configuration. */ + public static Property blob(String name) { + return new Property(name, AtomicDataType.BLOB); + } + public static ReferenceProperty reference(String name, String... 
collections) { return new ReferenceProperty(name, Arrays.asList(collections)); } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Text2VecWeaviateVectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Text2VecWeaviateVectorizer.java new file mode 100644 index 000000000..db1f9a0f3 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/Text2VecWeaviateVectorizer.java @@ -0,0 +1,72 @@ +package io.weaviate.client6.v1.collections; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import com.google.gson.annotations.SerializedName; + +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class Text2VecWeaviateVectorizer extends Vectorizer { + @SerializedName("text2vec-weaviate") + private Map configuration; + + public static Text2VecWeaviateVectorizer of() { + return new Builder().build(); + } + + public static Text2VecWeaviateVectorizer of(Consumer fn) { + var builder = new Builder(); + fn.accept(builder); + return builder.build(); + } + + public static final String SNOWFLAKE_ARCTIC_EMBED_L_20 = "Snowflake/snowflake-arctic-embed-l-v2.0"; + public static final String SNOWFLAKE_ARCTIC_EMBED_M_15 = "Snowflake/snowflake-arctic-embed-m-v1.5"; + + public static class Builder { + private boolean vectorizeCollectionName = false; + private String baseUrl; + private Integer dimensions; + private String model; + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder dimensions(int dimensions) { + this.dimensions = dimensions; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder vectorizeCollectionName() { + this.vectorizeCollectionName = true; + return this; + } + + public Text2VecWeaviateVectorizer build() { + return new Text2VecWeaviateVectorizer(new HashMap<>() { + { + put("vectorizeClassName", vectorizeCollectionName); + if (baseUrl != null) { + put("baseURL", baseUrl); 
+ } + if (dimensions != null) { + put("dimensions", dimensions); + } + if (model != null) { + put("model", model); + } + } + }); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java index 5db348263..b8aea7f0c 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java +++ b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java @@ -29,7 +29,7 @@ public static IndexingStrategy hnsw() { return new HNSW(); } - public static IndexingStrategy hnsw(Consumer options) { + public static IndexingStrategy hnsw(Consumer options) { return HNSW.with(options); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java b/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java index ad9c4260f..f2e07be5a 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java +++ b/src/main/java/io/weaviate/client6/v1/collections/Vectorizer.java @@ -1,8 +1,44 @@ package io.weaviate.client6.v1.collections; +import java.util.function.Consumer; + // This class is WIP, I haven't decided how to structure it yet. 
public abstract class Vectorizer { public static NoneVectorizer none() { return new NoneVectorizer(); } + + public static ContextionaryVectorizer text2vecContextionary() { + return ContextionaryVectorizer.of(); + } + + public static ContextionaryVectorizer text2vecContextionary(Consumer fn) { + return ContextionaryVectorizer.of(fn); + } + + // TODO: add test cases + public static Text2VecWeaviateVectorizer text2vecWeaviate() { + return Text2VecWeaviateVectorizer.of(); + } + + public static Text2VecWeaviateVectorizer text2vecWeaviate(Consumer fn) { + return Text2VecWeaviateVectorizer.of(fn); + } + + // TODO: add test cases + public static Multi2VecClipVectorizer multi2vecClip() { + return Multi2VecClipVectorizer.of(); + } + + public static Multi2VecClipVectorizer multi2vecClip(Consumer fn) { + return Multi2VecClipVectorizer.of(fn); + } + + public static Img2VecNeuralVectorizer img2VecNeuralVectorizer() { + return Img2VecNeuralVectorizer.of(); + } + + public static Img2VecNeuralVectorizer img2VecNeuralVectorizer(Consumer fn) { + return Img2VecNeuralVectorizer.of(fn); + } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java b/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java index e8878e1e8..e6030f4cd 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java +++ b/src/main/java/io/weaviate/client6/v1/collections/data/DataClient.java @@ -207,7 +207,11 @@ private static Object convertProtoValue(Value value) { return value.getIntValue(); } else if (value.hasNumberValue()) { return value.getNumberValue(); - } else if (value.hasDateValue()) { + } else if (value.hasBlobValue()) { + return value.getBlobValue(); + } else if (value.hasDateValue()) + + { OffsetDateTime 
offsetDateTime = OffsetDateTime.parse(value.getDateValue()); return Date.from(offsetDateTime.toInstant()); } else { diff --git a/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java b/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java index e4cf36602..61ffcb9de 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java +++ b/src/main/java/io/weaviate/client6/v1/collections/object/ObjectMetadata.java @@ -10,8 +10,8 @@ public static ObjectMetadata with(Consumer options) { } public static class Builder { - public String id; - public Vectors vectors; + private String id; + private Vectors vectors; public Builder id(String id) { this.id = id; diff --git a/src/main/java/io/weaviate/client6/v1/collections/query/NearImage.java b/src/main/java/io/weaviate/client6/v1/collections/query/NearImage.java new file mode 100644 index 000000000..d6b978cc6 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/query/NearImage.java @@ -0,0 +1,30 @@ +package io.weaviate.client6.v1.collections.query; + +import java.util.function.Consumer; + +public record NearImage(String image, Float distance, Float certainty, CommonQueryOptions common) { + + public static NearImage with(String image, Consumer fn) { + var opt = new Builder(); + fn.accept(opt); + return new NearImage(image, opt.distance, opt.certainty, new CommonQueryOptions(opt)); + } + + public static class Builder extends CommonQueryOptions.Builder { + private Float distance; + private Float certainty; + + public Builder distance(float distance) { + this.distance = distance; + return this; + } + + public Builder certainty(float certainty) { + this.certainty = certainty; + return this; + } + } + + public static record GroupBy(String property, int maxGroups, int maxObjectsPerGroup) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/query/NearText.java b/src/main/java/io/weaviate/client6/v1/collections/query/NearText.java new 
file mode 100644 index 000000000..6185b772b --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/query/NearText.java @@ -0,0 +1,87 @@ +package io.weaviate.client6.v1.collections.query; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; + +public record NearText(List text, Float distance, Float certainty, Move moveTo, Move moveAway, + CommonQueryOptions common) { + + public static NearText with(String text, Consumer fn) { + return with(List.of(text), fn); + } + + public static NearText with(List text, Consumer fn) { + var opt = new Builder(); + fn.accept(opt); + return new NearText(text, opt.distance, opt.certainty, opt.moveTo, opt.moveAway, new CommonQueryOptions(opt)); + } + + public static class Builder extends CommonQueryOptions.Builder { + private Float distance; + private Float certainty; + private Move moveTo; + private Move moveAway; + + public Builder distance(float distance) { + this.distance = distance; + return this; + } + + public Builder certainty(float certainty) { + this.certainty = certainty; + return this; + } + + public Builder moveTo(float force, Consumer fn) { + var move = new Move(force); + fn.accept(move); + this.moveTo = move; + return this; + } + + public Builder moveAway(float force, Consumer fn) { + var move = new Move(force); + fn.accept(move); + this.moveAway = move; + return this; + } + + } + + public static class Move { + private final Float force; + private List objects = new ArrayList<>(); + private List concepts = new ArrayList<>(); + + Move(float force) { + this.force = force; + } + + public Move uuids(String... uuids) { + this.objects = Arrays.asList(uuids); + return this; + } + + public Move concepts(String... 
concepts) { + this.concepts = Arrays.asList(concepts); + return this; + } + + public void appendTo(WeaviateProtoBaseSearch.NearTextSearch.Move.Builder move) { + move.setForce(force); + if (!objects.isEmpty()) { + move.addAllUuids(objects); + } + if (!concepts.isEmpty()) { + move.addAllConcepts(concepts); + } + } + } + + public static record GroupBy(String property, int maxGroups, int maxObjectsPerGroup) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/query/QueryClient.java b/src/main/java/io/weaviate/client6/v1/collections/query/QueryClient.java index 66a8a6540..f701d5b2b 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/query/QueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/collections/query/QueryClient.java @@ -31,6 +31,13 @@ public QueryClient(String collectionName, GrpcClient grpc) { this.collectionName = collectionName; } + public QueryResult nearVector(Float[] vector) { + var query = NearVector.with(vector, opt -> { + }); + var req = new SearchMarshaler(collectionName).addNearVector(query); + return search(req.marshal()); + } + public QueryResult nearVector(Float[] vector, Consumer options) { var query = NearVector.with(vector, options); var req = new SearchMarshaler(collectionName).addNearVector(query); @@ -53,6 +60,33 @@ public GroupedQueryResult nearVector(Float[] vector, NearVector.GroupBy group return searchGrouped(req.marshal()); } + public QueryResult nearText(String text, Consumer fn) { + var query = NearText.with(text, fn); + var req = new SearchMarshaler(collectionName).addNearText(query); + return search(req.marshal()); + } + + public GroupedQueryResult nearText(String text, NearText.GroupBy groupBy, Consumer fn) { + var query = NearText.with(text, fn); + var req = new SearchMarshaler(collectionName) + .addNearText(query) + .addGroupBy(groupBy); + return searchGrouped(req.marshal()); + } + + public QueryResult nearText(String text) { + var query = NearText.with(text, opt -> { + }); + var req = new 
SearchMarshaler(collectionName).addNearText(query); + return search(req.marshal()); + } + + public QueryResult nearImage(String image, Consumer fn) { + var query = NearImage.with(image, fn); + var req = new SearchMarshaler(collectionName).addNearImage(query); + return search(req.marshal()); + } + private QueryResult search(SearchRequest req) { var reply = grpcClient.grpc.search(req); return deserializeUntyped(reply);