diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java new file mode 100644 index 000000000..bd54ed865 --- /dev/null +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -0,0 +1,165 @@ +package io.weaviate.integration; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; + +import org.assertj.core.api.Assertions; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.collections.Property; +import io.weaviate.client6.v1.collections.VectorIndex; +import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.collections.Vectors; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest.GroupBy; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResponse; +import io.weaviate.client6.v1.collections.aggregate.Group; +import io.weaviate.client6.v1.collections.aggregate.GroupedBy; +import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; +import io.weaviate.client6.v1.collections.aggregate.Metric; +import io.weaviate.containers.Container; + +public class AggregationITest extends ConcurrentTest { + private static WeaviateClient client = Container.WEAVIATE.getClient(); + private static final String COLLECTION = unique("Things"); + + @BeforeClass + public static void beforeAll() throws IOException { + client.collections.create(COLLECTION, + collection -> collection + .properties( + Property.text("category"), + Property.integer("price")) + .vectors(Vectors.of(new VectorIndex<>(Vectorizer.none())))); + + var things = client.collections.use(COLLECTION); + for (var category : List.of("Shoes", "Hat", "Jacket")) { + for (var i = 0; i < 5; i++) { + var vector = randomVector(10, -.1f, .1f); + // For simplicity, the "price" for each items equals to the + // number of characters in the name of the category. + things.data.insert(Map.of( + "category", category, + "price", category.length()), + meta -> meta.vectors(vector)); + } + } + } + + @Test + public void testOverAll() { + var things = client.collections.use(COLLECTION); + var result = things.aggregate.overAll( + with -> with.metrics( + Metric.integer("price", calculate -> calculate + .median().max().count())) + .includeTotalCount()); + + Assertions.assertThat(result) + .as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 15L) + .as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price")) + .as("aggregated prices").extracting(p -> p.getInteger("price")) + .as("min").returns(null, IntegerMetric.Values::min) + .as("max").returns(6L, IntegerMetric.Values::max) + .as("median").returns(5D, IntegerMetric.Values::median) + .as("count").returns(15L, IntegerMetric.Values::count); + } + + @Test + public void testOverAll_groupBy_category() { + var things = client.collections.use(COLLECTION); + var result = things.aggregate.overAll( + new GroupBy("category"), + with -> with.metrics( + Metric.integer("price", calculate -> calculate + .min().max().count())) + .includeTotalCount()); + + Assertions.assertThat(result) + .extracting(AggregateGroupByResponse::groups) + .asInstanceOf(InstanceOfAssertFactories.list(Group.class)) + .as("group per category").hasSize(3) + .allSatisfy(group -> { + Assertions.assertThat(group) + .extracting(Group::by) + .as(group.by().property() + " is Text property").returns(true, GroupedBy::isText); + + String category = group.by().getAsText(); + var expectedPrice = (long) category.length(); + + Function> desc = (String metric) -> { + return () -> "%s ('%s'.length)".formatted(metric, category); + }; + + Assertions.assertThat(group) + .as("'price' is IntegerMetric").returns(true, g -> g.isIntegerProperty("price")) + .as("aggregated prices").extracting(g -> g.getInteger("price")) + .as(desc.apply("max")).returns(expectedPrice, IntegerMetric.Values::max) + .as(desc.apply("min")).returns(expectedPrice, IntegerMetric.Values::min) + .as(desc.apply("count")).returns(5L, IntegerMetric.Values::count); + }); + } + + @Test + public void testNearVector() { + var things = client.collections.use(COLLECTION); + var result = things.aggregate.nearVector( + randomVector(10, -1f, 1f), + near -> near.limit(5), + with -> with.metrics( + Metric.integer("price", calculate -> calculate + .min().max().count())) + .objectLimit(4) + .includeTotalCount()); + + Assertions.assertThat(result) + .as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 4L) + .as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price")) + .as("aggregated prices").extracting(p -> p.getInteger("price")) + .as("count").returns(4L, IntegerMetric.Values::count); + } + + @Test + public void testNearVector_groupBy_category() { + var things = client.collections.use(COLLECTION); + var result = things.aggregate.nearVector( + randomVector(10, -1f, 1f), + near -> near.distance(2f), + new GroupBy("category"), + with -> with.metrics( + Metric.integer("price", calculate -> calculate + .min().max().median())) + .objectLimit(9) + .includeTotalCount()); + + Assertions.assertThat(result) + .extracting(AggregateGroupByResponse::groups) + .asInstanceOf(InstanceOfAssertFactories.list(Group.class)) + .as("group per category").hasSize(3) + .allSatisfy(group -> { + Assertions.assertThat(group) + .extracting(Group::by) + .as(group.by().property() + " is Text property").returns(true, GroupedBy::isText); + + String category = group.by().getAsText(); + var expectedPrice = (long) category.length(); + + Function> desc = (String metric) -> { + return () -> "%s ('%s'.length)".formatted(metric, category); + }; + + Assertions.assertThat(group) + .as("'price' is IntegerMetric").returns(true, g -> g.isIntegerProperty("price")) + .as("aggregated prices").extracting(g -> g.getInteger("price")) + .as(desc.apply("max")).returns(expectedPrice, IntegerMetric.Values::max) + .as(desc.apply("min")).returns(expectedPrice, IntegerMetric.Values::min) + .as(desc.apply("median")).returns((double) expectedPrice, IntegerMetric.Values::median); + }); + } +} diff --git a/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java b/src/it/java/io/weaviate/integration/NearVectorQueryITest.java similarity index 59% rename from src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java rename to src/it/java/io/weaviate/integration/NearVectorQueryITest.java index 0b8693b75..66258810d 100644 --- a/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java +++ b/src/it/java/io/weaviate/integration/NearVectorQueryITest.java @@ -1,9 +1,10 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.integration; import java.io.IOException; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.assertj.core.api.Assertions; @@ -13,9 +14,13 @@ import io.weaviate.ConcurrentTest; import io.weaviate.client6.WeaviateClient; import io.weaviate.client6.v1.Vectors; +import io.weaviate.client6.v1.collections.Property; import io.weaviate.client6.v1.collections.VectorIndex; import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.query.GroupedQueryResult; +import io.weaviate.client6.v1.query.MetadataField; +import io.weaviate.client6.v1.query.NearVector; import io.weaviate.containers.Container; public class NearVectorQueryITest extends ConcurrentTest { @@ -23,6 +28,7 @@ public class NearVectorQueryITest extends ConcurrentTest { private static final String COLLECTION = unique("Things"); private static final String VECTOR_INDEX = "bring_your_own"; + private static final List CATEGORIES = List.of("red", "green"); /** * One of the inserted vectors which will be used as target vector for search. @@ -32,7 +38,7 @@ public class NearVectorQueryITest extends ConcurrentTest { @BeforeClass public static void beforeAll() throws IOException { createTestCollection(); - var created = createVectors(10); + var created = populateTest(10); searchVector = created.values().iterator().next(); } @@ -41,7 +47,7 @@ public void testNearVector() { // TODO: test that we return the results in the expected order // Because re-ranking should work correctly var things = client.collections.use(COLLECTION); - QueryResult> result = things.query.nearVector(searchVector, + var result = things.query.nearVector(searchVector, opt -> opt .distance(2f) .limit(3) @@ -49,23 +55,48 @@ public void testNearVector() { Assertions.assertThat(result.objects).hasSize(3); float maxDistance = Collections.max(result.objects, - Comparator.comparing(obj -> obj.metadata.distance)).metadata.distance; + Comparator.comparing(obj -> obj.metadata.distance())).metadata.distance(); Assertions.assertThat(maxDistance).isLessThanOrEqualTo(2f); } + @Test + public void testNearVector_groupBy() { + // TODO: test that we return the results in the expected order + // Because re-ranking should work correctly + var things = client.collections.use(COLLECTION); + var result = things.query.nearVector(searchVector, + new NearVector.GroupBy("category", 2, 5), + opt -> opt.distance(10f)); + + Assertions.assertThat(result.groups) + .as("group per category").containsOnlyKeys(CATEGORIES) + .hasSizeLessThanOrEqualTo(2) + .allSatisfy((category, group) -> { + Assertions.assertThat(group) + .as("group name").returns(category, GroupedQueryResult.Group::name); + Assertions.assertThat(group.numberOfObjects()) + .as("[%s] has 1+ object", category).isLessThanOrEqualTo(5L); + }); + + Assertions.assertThat(result.objects) + .as("object belongs a group") + .allMatch(obj -> result.groups.get(obj.belongsToGroup).objects().contains(obj)); + + } + /** * Insert 10 objects with random vectors. * * @returns IDs of inserted objects and their corresponding vectors. */ - private static Map createVectors(int n) throws IOException { + private static Map populateTest(int n) throws IOException { var created = new HashMap(); var things = client.collections.use(COLLECTION); for (int i = 0; i < n; i++) { var vector = randomVector(10, -.01f, .001f); var object = things.data.insert( - Map.of(), + Map.of("category", CATEGORIES.get(i % CATEGORIES.size())), metadata -> metadata .id(randomUUID()) .vectors(Vectors.of(VECTOR_INDEX, vector))); @@ -83,6 +114,7 @@ private static Map createVectors(int n) throws IOException { */ private static void createTestCollection() throws IOException { client.collections.create(COLLECTION, cfg -> cfg + .properties(Property.text("category")) .vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); } } diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/GrpcMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/GrpcMarshaler.java new file mode 100644 index 000000000..ed6624b39 --- /dev/null +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/GrpcMarshaler.java @@ -0,0 +1,5 @@ +package io.weaviate.client6.internal.codec.grpc; + +public interface GrpcMarshaler { + R marshal(); +} diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java new file mode 100644 index 000000000..446adba78 --- /dev/null +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java @@ -0,0 +1,132 @@ +package io.weaviate.client6.internal.codec.grpc.v1; + +import java.util.function.BiConsumer; + +import com.google.common.collect.ImmutableMap; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; +import io.weaviate.client6.internal.GRPC; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest.GroupBy; +import io.weaviate.client6.v1.collections.aggregate.AggregateRequest; +import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; +import io.weaviate.client6.v1.collections.aggregate.Metric; +import io.weaviate.client6.v1.collections.aggregate.TextMetric; +import io.weaviate.client6.v1.query.NearVector; + +public final class AggregateMarshaler { + private final WeaviateProtoAggregate.AggregateRequest.Builder req = WeaviateProtoAggregate.AggregateRequest + .newBuilder(); + private final String collectionName; + + public AggregateMarshaler(String collectionName) { + this.collectionName = collectionName; + } + + public WeaviateProtoAggregate.AggregateRequest marshal() { + return req.build(); + } + + public AggregateMarshaler addAggregation(AggregateRequest aggregate) { + req.setCollection(collectionName); + + if (aggregate.includeTotalCount()) { + req.setObjectsCount(true); + } + + if (aggregate.objectLimit() != null) { + req.setObjectLimit(aggregate.objectLimit()); + } + + for (Metric metric : aggregate.returnMetrics()) { + addMetric(metric); + } + + return this; + } + + public AggregateMarshaler addGroupBy(GroupBy groupBy) { + var by = WeaviateProtoAggregate.AggregateRequest.GroupBy.newBuilder(); + by.setCollection(collectionName); + by.setProperty(groupBy.property()); + req.setGroupBy(by); + return this; + } + + public AggregateMarshaler addNearVector(NearVector nv) { + var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); + nearVector.setVectorBytes(GRPC.toByteString(nv.vector())); + + if (nv.certainty() != null) { + nearVector.setCertainty(nv.certainty()); + } else if (nv.distance() != null) { + nearVector.setDistance(nv.distance()); + } + + req.setNearVector(nearVector); + + // Base query options + if (nv.common().limit() != null) { + req.setLimit(nv.common().limit()); + } + return this; + } + + private void addMetric(Metric metric) { + var aggregation = Aggregation.newBuilder(); + aggregation.setProperty(metric.property()); + + if (metric instanceof TextMetric m) { + var text = Aggregation.Text.newBuilder(); + m.functions().forEach(f -> set(f, text)); + if (m.atLeast() != null) { + text.setTopOccurencesLimit(m.atLeast()); + } + aggregation.setText(text); + } else if (metric instanceof IntegerMetric m) { + var integer = Aggregation.Integer.newBuilder(); + m.functions().forEach(f -> set(f, integer)); + aggregation.setInt(integer); + } else { + assert false : "branch not covered"; + } + + req.addAggregations(aggregation); + } + + @SuppressWarnings("unchecked") + static final void set(Enum fn, B builder) { + if (metrics.containsKey(fn)) { + ((Toggle) metrics.get(fn)).toggleOn(builder); + } + } + + static final ImmutableMap, Toggle> metrics = new ImmutableMap.Builder, Toggle>() + .put(TextMetric._Function.TYPE, new Toggle<>(Aggregation.Text.Builder::setType)) + .put(TextMetric._Function.COUNT, new Toggle<>(Aggregation.Text.Builder::setCount)) + .put(TextMetric._Function.TOP_OCCURRENCES, new Toggle<>(Aggregation.Text.Builder::setTopOccurences)) + + .put(IntegerMetric._Function.COUNT, new Toggle<>(Aggregation.Integer.Builder::setCount)) + .put(IntegerMetric._Function.MIN, new Toggle<>(Aggregation.Integer.Builder::setMinimum)) + .put(IntegerMetric._Function.MAX, new Toggle<>(Aggregation.Integer.Builder::setMaximum)) + .put(IntegerMetric._Function.MEAN, new Toggle<>(Aggregation.Integer.Builder::setMean)) + .put(IntegerMetric._Function.MEDIAN, new Toggle<>(Aggregation.Integer.Builder::setMedian)) + .put(IntegerMetric._Function.MODE, new Toggle<>(Aggregation.Integer.Builder::setMode)) + .put(IntegerMetric._Function.SUM, new Toggle<>(Aggregation.Integer.Builder::setSum)) + .build(); + + static class Toggle { + private final BiConsumer setter; + + Toggle(BiConsumer setter) { + this.setter = setter; + } + + final void toggleOn(B builder) { + setter.accept(builder, true); + } + } + +} diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateUnmarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateUnmarshaler.java new file mode 100644 index 000000000..c26c174be --- /dev/null +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateUnmarshaler.java @@ -0,0 +1,102 @@ +package io.weaviate.client6.internal.codec.grpc.v1; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResponse; +import io.weaviate.client6.v1.collections.aggregate.AggregateResponse; +import io.weaviate.client6.v1.collections.aggregate.Group; +import io.weaviate.client6.v1.collections.aggregate.GroupedBy; +import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; +import io.weaviate.client6.v1.collections.aggregate.Metric; + +public final class AggregateUnmarshaler { + private final WeaviateProtoAggregate.AggregateReply reply; + + public AggregateUnmarshaler(WeaviateProtoAggregate.AggregateReply reply) { + this.reply = reply; + } + + public AggregateResponse single() { + Long totalCount = null; + Map properties = new HashMap<>(); + + if (reply.hasSingleResult()) { + var single = reply.getSingleResult(); + totalCount = single.hasObjectsCount() ? single.getObjectsCount() : null; + var aggregations = single.getAggregations().getAggregationsList(); + for (var agg : aggregations) { + var property = agg.getProperty(); + Metric.Values value = null; + + if (agg.hasInt()) { + var metrics = agg.getInt(); + value = new IntegerMetric.Values( + metrics.hasCount() ? metrics.getCount() : null, + metrics.hasMinimum() ? metrics.getMinimum() : null, + metrics.hasMaximum() ? metrics.getMaximum() : null, + metrics.hasMean() ? metrics.getMean() : null, + metrics.hasMedian() ? metrics.getMedian() : null, + metrics.hasMode() ? metrics.getMode() : null, + metrics.hasSum() ? metrics.getSum() : null); + } else { + assert false : "branch not covered"; + } + if (value != null) { + properties.put(property, value); + } + } + } + var result = new AggregateResponse(properties, totalCount); + return result; + } + + public AggregateGroupByResponse grouped() { + List> groups = new ArrayList<>(); + if (reply.hasGroupedResults()) { + for (var result : reply.getGroupedResults().getGroupsList()) { + final Long totalCount = result.hasObjectsCount() ? result.getObjectsCount() : null; + + GroupedBy groupedBy = null; + var gb = result.getGroupedBy(); + if (gb.hasInt()) { + groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getInt()); + } else if (gb.hasText()) { + groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getText()); + } else { + assert false : "branch not covered"; + } + + Map properties = new HashMap<>(); + for (var agg : result.getAggregations().getAggregationsList()) { + var property = agg.getProperty(); + Metric.Values value = null; + + if (agg.hasInt()) { + var metrics = agg.getInt(); + value = new IntegerMetric.Values( + metrics.hasCount() ? metrics.getCount() : null, + metrics.hasMinimum() ? metrics.getMinimum() : null, + metrics.hasMaximum() ? metrics.getMaximum() : null, + metrics.hasMean() ? metrics.getMean() : null, + metrics.hasMedian() ? metrics.getMedian() : null, + metrics.hasMode() ? metrics.getMode() : null, + metrics.hasSum() ? metrics.getSum() : null); + } else { + assert false : "branch not covered"; + } + if (value != null) { + properties.put(property, value); + } + } + Group group = new Group<>(groupedBy, properties, totalCount); + groups.add(group); + + } + } + return new AggregateGroupByResponse(groups); + } +} diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java new file mode 100644 index 000000000..a85970bb1 --- /dev/null +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java @@ -0,0 +1,86 @@ +package io.weaviate.client6.internal.codec.grpc.v1; + +import org.apache.commons.lang3.StringUtils; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; +import io.weaviate.client6.internal.GRPC; +import io.weaviate.client6.internal.codec.grpc.GrpcMarshaler; +import io.weaviate.client6.v1.query.CommonQueryOptions; +import io.weaviate.client6.v1.query.NearVector; + +public class SearchMarshaler implements GrpcMarshaler { + private final WeaviateProtoSearchGet.SearchRequest.Builder req = WeaviateProtoSearchGet.SearchRequest.newBuilder(); + + public SearchMarshaler(String collectionName) { + req.setCollection(collectionName); + req.setUses123Api(true); + req.setUses125Api(true); + req.setUses127Api(true); + } + + public SearchMarshaler addGroupBy(NearVector.GroupBy gb) { + var groupBy = WeaviateProtoSearchGet.GroupBy.newBuilder(); + groupBy.addPath(gb.property()); + groupBy.setNumberOfGroups(gb.maxGroups()); + groupBy.setObjectsPerGroup(gb.maxObjectsPerGroup()); + req.setGroupBy(groupBy); + return this; + } + + public SearchMarshaler addNearVector(NearVector nv) { + setCommon(nv.common()); + + var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); + nearVector.setVectorBytes(GRPC.toByteString(nv.vector())); + + if (nv.certainty() != null) { + nearVector.setCertainty(nv.certainty()); + } else if (nv.distance() != null) { + nearVector.setDistance(nv.distance()); + } + + req.setNearVector(nearVector); + return this; + } + + private void setCommon(CommonQueryOptions o) { + if (o.limit() != null) { + req.setLimit(o.limit()); + } + if (o.offset() != null) { + req.setOffset(o.offset()); + } + if (StringUtils.isNotBlank(o.after())) { + req.setAfter(o.after()); + } + if (StringUtils.isNotBlank(o.consistencyLevel())) { + req.setConsistencyLevelValue(Integer.valueOf(o.consistencyLevel())); + } + if (o.autocut() != null) { + req.setAutocut(o.autocut()); + } + + if (!o.returnMetadata().isEmpty()) { + var metadata = MetadataRequest.newBuilder(); + o.returnMetadata().forEach(m -> m.appendTo(metadata)); + req.setMetadata(metadata); + } + + if (!o.returnProperties().isEmpty()) { + var properties = PropertiesRequest.newBuilder(); + for (String property : o.returnProperties()) { + properties.addNonRefProperties(property); + } + req.setProperties(properties); + } + } + + @Override + public SearchRequest marshal() { + return req.build(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/Collection.java b/src/main/java/io/weaviate/client6/v1/Collection.java index e12f56915..b1f40dcc4 100644 --- a/src/main/java/io/weaviate/client6/v1/Collection.java +++ b/src/main/java/io/weaviate/client6/v1/Collection.java @@ -3,15 +3,18 @@ import io.weaviate.client6.Config; import io.weaviate.client6.internal.GrpcClient; import io.weaviate.client6.internal.HttpClient; +import io.weaviate.client6.v1.collections.aggregate.WeaviateAggregate; import io.weaviate.client6.v1.data.Data; import io.weaviate.client6.v1.query.Query; public class Collection { public final Query query; public final Data data; + public final WeaviateAggregate aggregate; public Collection(String collectionName, Config config, GrpcClient grpc, HttpClient http) { this.query = new Query<>(collectionName, grpc); this.data = new Data<>(collectionName, config, http); + this.aggregate = new WeaviateAggregate(collectionName, grpc); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java index 2c6cd5c85..e0333fad0 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java +++ b/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java @@ -30,12 +30,14 @@ public CollectionDefinitionDTO(CollectionDefinition colDef) { this.properties = colDef.properties(); this.vectors = colDef.vectors(); - var unnamed = this.vectors.getUnnamed(); - if (unnamed.isPresent()) { - var index = unnamed.get(); - this.vectorIndexType = index.type(); - this.vectorIndexConfig = index.configuration(); - this.vectorizer = index.vectorizer(); + if (this.vectors != null) { + var unnamed = this.vectors.getUnnamed(); + if (unnamed.isPresent()) { + var index = unnamed.get(); + this.vectorIndexType = index.type(); + this.vectorIndexConfig = index.configuration(); + this.vectorizer = index.vectorizer(); + } } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Collections.java b/src/main/java/io/weaviate/client6/v1/collections/Collections.java index 0a2b8e97d..4f65915c9 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/Collections.java +++ b/src/main/java/io/weaviate/client6/v1/collections/Collections.java @@ -95,7 +95,11 @@ public JsonElement serialize(Vectorizer src, Type typeOfSrc, JsonSerializationCo @Override public void write(JsonWriter out, Vectors value) throws IOException { - gson.toJson(value.asMap(), Map.class, out); + if (value != null) { + gson.toJson(value.asMap(), Map.class, out); + } else { + out.nullValue(); + } } @Override diff --git a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java index ad1160dbf..5db348263 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java +++ b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java @@ -19,7 +19,7 @@ public VectorIndex(IndexingStrategy index, V vectorizer) { } public VectorIndex(V vectorizer) { - this(null, vectorizer, null); + this(IndexingStrategy.hnsw(), vectorizer); } public static sealed interface IndexingStrategy permits HNSW { diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByRequest.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByRequest.java new file mode 100644 index 000000000..0d3786f87 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByRequest.java @@ -0,0 +1,26 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.function.Consumer; + +public record AggregateGroupByRequest(AggregateRequest aggregate, GroupBy groupBy) { + + public static record GroupBy(String property) { + public static GroupBy with(Consumer options) { + var opt = new Builder(options); + return new GroupBy(opt.property); + } + + public static class Builder { + private String property; + + public Builder property(String name) { + this.property = name; + return this; + } + + Builder(Consumer options) { + options.accept(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResponse.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResponse.java new file mode 100644 index 000000000..8cfeef016 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResponse.java @@ -0,0 +1,7 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.List; + +public record AggregateGroupByResponse(List> groups) { + +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java new file mode 100644 index 000000000..3b7c75899 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java @@ -0,0 +1,47 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +public record AggregateRequest( + String collectionName, + Integer objectLimit, + boolean includeTotalCount, + List returnMetrics) { + + public static AggregateRequest with(String collectionName, Consumer options) { + var opt = new Builder(options); + return new AggregateRequest( + collectionName, + opt.objectLimit, + opt.includeTotalCount, + opt.metrics); + } + + public static class Builder { + private List metrics; + private Integer objectLimit; + private boolean includeTotalCount = false; + + Builder(Consumer options) { + options.accept(this); + } + + public Builder objectLimit(int limit) { + this.objectLimit = limit; + return this; + } + + public Builder includeTotalCount() { + this.includeTotalCount = true; + return this; + } + + @SafeVarargs + public final Builder metrics(Metric... metrics) { + this.metrics = Arrays.asList(metrics); + return this; + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResponse.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResponse.java new file mode 100644 index 000000000..f2d0cde13 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResponse.java @@ -0,0 +1,27 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.Map; + +public record AggregateResponse(Map properties, Long totalCount) { + public boolean isTextProperties(String name) { + return properties.get(name) instanceof TextMetric.Values; + } + + public boolean isIntegerProperty(String name) { + return properties.get(name) instanceof IntegerMetric.Values; + } + + public TextMetric.Values getText(String name) { + if (!isTextProperties(name)) { + throw new IllegalStateException(name + " is not a Text property"); + } + return (TextMetric.Values) this.properties.get(name); + } + + public IntegerMetric.Values getInteger(String name) { + if (!isIntegerProperty(name)) { + throw new IllegalStateException(name + " is not a Integer property"); + } + return (IntegerMetric.Values) this.properties.get(name); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java new file mode 100644 index 000000000..05f010ac7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java @@ -0,0 +1,28 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.Map; + +public record Group(GroupedBy by, Map properties, Long totalCount) { + // TODO: have DataType util method for this? + public boolean isTextProperty(String name) { + return properties.get(name) instanceof TextMetric.Values; + } + + public boolean isIntegerProperty(String name) { + return properties.get(name) instanceof IntegerMetric.Values; + } + + public TextMetric.Values getText(String name) { + if (!isTextProperty(name)) { + throw new IllegalStateException(name + " is not a Text property"); + } + return (TextMetric.Values) this.properties.get(name); + } + + public IntegerMetric.Values getInteger(String name) { + if (!isIntegerProperty(name)) { + throw new IllegalStateException(name + " is not a Integer property"); + } + return (IntegerMetric.Values) this.properties.get(name); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java new file mode 100644 index 000000000..c751dca3a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java @@ -0,0 +1,14 @@ +package io.weaviate.client6.v1.collections.aggregate; + +public record GroupedBy(String property, T value) { + public boolean isText() { + return value instanceof String; + } + + public String getAsText() { + if (!isText()) { + throw new IllegalStateException(property + " is not a Text property"); + } + return (String) value; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java new file mode 100644 index 000000000..10ef8474f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java @@ -0,0 +1,66 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +public record IntegerMetric(String property, List<_Function> functions) implements Metric { + + public record Values(Long count, Long min, Long max, Double mean, Double median, Long mode, Long sum) + implements Metric.Values { + } + + static IntegerMetric with(String property, Consumer options) { + var opt = new Builder(options); + return new IntegerMetric(property, new ArrayList<>(opt.functions)); + } + + public enum _Function { + COUNT, MIN, MAX, MEAN, MEDIAN, MODE, SUM; + } + + public static class Builder { + private final Set<_Function> functions = new HashSet<>(); + + public Builder count() { + functions.add(_Function.COUNT); + return this; + } + + public Builder min() { + functions.add(_Function.MIN); + return this; + } + + public Builder max() { + functions.add(_Function.MAX); + return this; + } + + public Builder mean() { + functions.add(_Function.MEAN); + return this; + } + + public Builder median() { + functions.add(_Function.MEDIAN); + return this; + } + + public Builder mode() { + functions.add(_Function.MODE); + return this; + } + + public Builder sum() { + functions.add(_Function.SUM); + return this; + } + + Builder(Consumer options) { + options.accept(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java new file mode 100644 index 000000000..588af7e43 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java @@ -0,0 +1,32 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.List; +import java.util.function.Consumer; + +public interface Metric { + String property(); + + List> functions(); + + public static TextMetric text(String property) { + return TextMetric.with(property, _options -> { + }); + } + + public static TextMetric text(String property, Consumer options) { + return TextMetric.with(property, options); + } + + public static IntegerMetric integer(String property) { + return IntegerMetric.with(property, _options -> { + }); + } + + public static IntegerMetric integer(String property, Consumer options) { + return IntegerMetric.with(property, options); + } + + public interface Values { + Long count(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java new file mode 100644 index 000000000..7499cff70 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java @@ -0,0 +1,63 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +public record TextMetric(String property, List<_Function> functions, boolean occurrenceCount, + Integer atLeast) + implements Metric { + + public record Values(Long count, List topOccurrences) implements Metric.Values { + } + + static TextMetric with(String property, Consumer options) { + var opt = new Builder(options); + return new TextMetric(property, + new ArrayList<>(opt.functions), + opt.occurrenceCount, opt.atLeast); + } + + public enum _Function { + COUNT, TYPE, TOP_OCCURRENCES; + } + + public static class Builder { + private final Set<_Function> functions = new HashSet<>(); + private boolean occurrenceCount = false; + private Integer atLeast; + + public Builder count() { + functions.add(_Function.COUNT); + return this; + } + + public Builder type() { + functions.add(_Function.TYPE); + return this; + } + + public Builder topOccurences() { + functions.add(_Function.TOP_OCCURRENCES); + return this; + } + + public Builder topOccurences(int atLeast) { + topOccurences(); + this.atLeast = atLeast; + return this; + } + + public Builder includeTopOccurencesCount() { + topOccurences(); + this.occurrenceCount = true; + return this; + } + + Builder(Consumer options) { + options.accept(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrence.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrence.java new file mode 100644 index 000000000..9d903ae82 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrence.java @@ -0,0 +1,4 @@ +package io.weaviate.client6.v1.collections.aggregate; + +public record TopOccurrence(String value, int occurrenceCount) { +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java new file mode 100644 index 000000000..e69de29bb diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java new file mode 100644 index 000000000..73474fb7b --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java @@ -0,0 +1,89 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.function.Consumer; + +import io.weaviate.client6.internal.GrpcClient; +import io.weaviate.client6.internal.codec.grpc.v1.AggregateMarshaler; +import io.weaviate.client6.internal.codec.grpc.v1.AggregateUnmarshaler; +import io.weaviate.client6.v1.query.NearVector; + +public class WeaviateAggregate { + private final String collectionName; + private final GrpcClient grpcClient; + + public WeaviateAggregate(String collectionName, GrpcClient grpc) { + this.collectionName = collectionName; + this.grpcClient = grpc; + } + + public AggregateResponse overAll(Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .marshal(); + var reply = grpcClient.grpc.aggregate(req); + return new AggregateUnmarshaler(reply).single(); + } + + public AggregateGroupByResponse overAll( + AggregateGroupByRequest.GroupBy groupBy, + Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .addGroupBy(groupBy) + .marshal(); + var reply = grpcClient.grpc.aggregate(req); + return new AggregateUnmarshaler(reply).grouped(); + } + + public AggregateResponse nearVector( + Float[] vector, + Consumer nearVectorOptions, + Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + var nearVector = NearVector.with(vector, nearVectorOptions); + + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .addNearVector(nearVector) + .marshal(); + var reply = grpcClient.grpc.aggregate(req); + return new AggregateUnmarshaler(reply).single(); + } + + public AggregateGroupByResponse nearVector( + Float[] vector, + AggregateGroupByRequest.GroupBy groupBy, + Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + var nearVector = NearVector.with(vector, opt -> { + }); + + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .addGroupBy(groupBy) + .addNearVector(nearVector) + .marshal(); + var reply = grpcClient.grpc.aggregate(req); + return new AggregateUnmarshaler(reply).grouped(); + } + + public AggregateGroupByResponse nearVector( + Float[] vector, + Consumer nearVectorOptions, + AggregateGroupByRequest.GroupBy groupBy, + Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + var nearVector = NearVector.with(vector, nearVectorOptions); + + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .addGroupBy(groupBy) + .addNearVector(nearVector) + .marshal(); + var reply = grpcClient.grpc.aggregate(req); + return new AggregateUnmarshaler(reply).grouped(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/data/Data.java b/src/main/java/io/weaviate/client6/v1/data/Data.java index e5f457b23..54b476b22 100644 --- a/src/main/java/io/weaviate/client6/v1/data/Data.java +++ b/src/main/java/io/weaviate/client6/v1/data/Data.java @@ -31,6 +31,11 @@ public class Data { private final Config config; private final HttpClient httpClient; + public WeaviateObject insert(T object) throws IOException { + return insert(object, opt -> { + }); + } + public WeaviateObject insert(T object, Consumer options) throws IOException { var body = new WeaviateObject<>(collectionName, object, options); ClassicHttpRequest httpPost = ClassicRequestBuilder diff --git a/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java b/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java index 9d3e3fcc7..ed9b00af8 100644 --- a/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java +++ b/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java @@ -25,20 +25,24 @@ class WeaviateObjectDTO { if (object.metadata() != null) { this.id = object.metadata().id(); - this.vectors = object.metadata().vectors().asMap(); + if (object.metadata().vectors() != null) { + this.vectors = object.metadata().vectors().asMap(); + } } } WeaviateObject toWeaviateObject() { Map arrayVectors = new HashMap<>(); - for (var entry : vectors.entrySet()) { - var value = (ArrayList) entry.getValue(); - var vector = new Float[value.size()]; - int i = 0; - for (var v : value) { - vector[i++] = v.floatValue(); + if (vectors != null) { + for (var entry : vectors.entrySet()) { + var value = (ArrayList) entry.getValue(); + var vector = new Float[value.size()]; + int i = 0; + for (var v : value) { + vector[i++] = v.floatValue(); + } + arrayVectors.put(entry.getKey(), vector); } - arrayVectors.put(entry.getKey(), vector); } return new WeaviateObject(collection, properties, new ObjectMetadata(id, Vectors.of(arrayVectors))); } diff --git a/src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java b/src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java new file mode 100644 index 000000000..ddf1e1ab1 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java @@ -0,0 +1,106 @@ +package io.weaviate.client6.v1.query; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.lang3.StringUtils; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; + +@SuppressWarnings("unchecked") +public record CommonQueryOptions( + Integer limit, + Integer offset, + Integer autocut, + String after, + String consistencyLevel /* TODO: use ConsistencyLevel enum */, + List returnProperties, + List returnMetadata) { + + public CommonQueryOptions(Builder> options) { + this( + options.limit, + options.offset, + options.autocut, + options.after, + options.consistencyLevel, + options.returnProperties, + options.returnMetadata); + + } + + public static abstract class Builder> { + private Integer limit; + private Integer offset; + private Integer autocut; + private String after; + private String consistencyLevel; + private List returnProperties = new ArrayList<>(); + private List returnMetadata = new ArrayList<>(); + + public final SELF limit(Integer limit) { + this.limit = limit; + return (SELF) this; + } + + public final SELF offset(Integer offset) { + this.offset = offset; + return (SELF) this; + } + + public final SELF autocut(Integer autocut) { + this.autocut = autocut; + return (SELF) this; + } + + public final SELF after(String after) { + this.after = after; + return (SELF) this; + } + + public final SELF consistencyLevel(String consistencyLevel) { + this.consistencyLevel = consistencyLevel; + return (SELF) this; + } + + public final SELF returnMetadata(Metadata... metadata) { + this.returnMetadata = Arrays.asList(metadata); + return (SELF) this; + } + + void appendTo(SearchRequest.Builder search) { + if (limit != null) { + search.setLimit(limit); + } + if (offset != null) { + search.setOffset(offset); + } + if (StringUtils.isNotBlank(after)) { + search.setAfter(after); + } + if (StringUtils.isNotBlank(consistencyLevel)) { + search.setConsistencyLevelValue(Integer.valueOf(consistencyLevel)); + } + if (autocut != null) { + search.setAutocut(autocut); + } + + if (!returnMetadata.isEmpty()) { + var metadata = MetadataRequest.newBuilder(); + returnMetadata.forEach(m -> m.appendTo(metadata)); + search.setMetadata(metadata); + } + + if (!returnProperties.isEmpty()) { + var properties = PropertiesRequest.newBuilder(); + for (String property : returnProperties) { + properties.addNonRefProperties(property); + } + search.setProperties(properties); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java b/src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java new file mode 100644 index 000000000..01b8e68a4 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java @@ -0,0 +1,26 @@ +package io.weaviate.client6.v1.query; + +import java.util.List; +import java.util.Map; + +import io.weaviate.client6.v1.query.QueryResult.SearchObject; +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class GroupedQueryResult { + public final List> objects; + public final Map> groups; + + public static class WithGroupSearchObject extends SearchObject { + public final String belongsToGroup; + + public WithGroupSearchObject(String group, T properties, QueryMetadata metadata) { + super(properties, metadata); + this.belongsToGroup = group; + } + } + + public record Group(String name, Float minDistance, Float maxDistance, long numberOfObjects, + List> objects) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/query/Metadata.java b/src/main/java/io/weaviate/client6/v1/query/Metadata.java index d490ee67f..4cc37bd98 100644 --- a/src/main/java/io/weaviate/client6/v1/query/Metadata.java +++ b/src/main/java/io/weaviate/client6/v1/query/Metadata.java @@ -5,7 +5,7 @@ /** * Metadata is the common base for all properties that are requestes as * "_additional". It is an inteface all metadata properties MUST implement to be - * used in {@link QueryOptions}. + * used in {@link CommonQueryOptions}. */ public interface Metadata { void appendTo(MetadataRequest.Builder metadata); diff --git a/src/main/java/io/weaviate/client6/v1/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/query/NearVector.java index e479e8809..6cfee7f8f 100644 --- a/src/main/java/io/weaviate/client6/v1/query/NearVector.java +++ b/src/main/java/io/weaviate/client6/v1/query/NearVector.java @@ -2,53 +2,29 @@ import java.util.function.Consumer; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; -import io.weaviate.client6.internal.GRPC; +public record NearVector(Float[] vector, Float distance, Float certainty, CommonQueryOptions common) { -public class NearVector { - private final Float[] vector; - private final Options options; - - void appendTo(SearchRequest.Builder search) { - var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); - - // TODO: we should only add (named) Vectors. - // Since we do not force the users to supply a name when defining an index, - // we also need a way to "get default vector name" from the collection. - // For Map (untyped query handle) we always require the name. - nearVector.setVectorBytes(GRPC.toByteString(vector)); - options.append(search, nearVector); - search.setNearVector(nearVector.build()); - } - - public NearVector(Float[] vector, Consumer options) { - this.options = new Options(); - this.vector = vector; - options.accept(this.options); + public static NearVector with(Float[] vector, Consumer options) { + var opt = new Builder(); + options.accept(opt); + return new NearVector(vector, opt.distance, opt.certainty, new CommonQueryOptions(opt)); } - public static class Options extends QueryOptions { + public static class Builder extends CommonQueryOptions.Builder { private Float distance; private Float certainty; - public Options distance(float distance) { + public Builder distance(float distance) { this.distance = distance; return this; } - public Options certainty(float certainty) { + public Builder certainty(float certainty) { this.certainty = certainty; return this; } + } - void append(SearchRequest.Builder search, WeaviateProtoBaseSearch.NearVector.Builder nearVector) { - if (certainty != null) { - nearVector.setCertainty(certainty); - } else if (distance != null) { - nearVector.setDistance(distance); - } - super.appendTo(search); - } + public static record GroupBy(String property, int maxGroups, int maxObjectsPerGroup) { } } diff --git a/src/main/java/io/weaviate/client6/v1/query/Query.java b/src/main/java/io/weaviate/client6/v1/query/Query.java index 7ac1508c7..673ed1f48 100644 --- a/src/main/java/io/weaviate/client6/v1/query/Query.java +++ b/src/main/java/io/weaviate/client6/v1/query/Query.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.query; import java.time.OffsetDateTime; +import java.util.ArrayList; import java.util.Date; import java.util.List; import java.util.Map; @@ -15,11 +16,9 @@ import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; import io.weaviate.client6.internal.GRPC; import io.weaviate.client6.internal.GrpcClient; +import io.weaviate.client6.internal.codec.grpc.v1.SearchMarshaler; public class Query { - // TODO: inject singleton as dependency - private static final Gson gson = new Gson(); - // TODO: this should be wrapped around in some TypeInspector etc. private final String collectionName; @@ -32,19 +31,26 @@ public Query(String collectionName, GrpcClient grpc) { this.collectionName = collectionName; } - public QueryResult nearVector(Float[] vector, Consumer options) { - var query = new NearVector(vector, options); + public QueryResult nearVector(Float[] vector, Consumer options) { + var query = NearVector.with(vector, options); + var req = new SearchMarshaler(collectionName).addNearVector(query); + return search(req.marshal()); + } - // TODO: Since we always need to set these values, we migth want to move the - // next block to some factory method. - var req = SearchRequest.newBuilder(); - req.setCollection(collectionName); - req.setUses123Api(true); - req.setUses125Api(true); - req.setUses127Api(true); + public GroupedQueryResult nearVector(Float[] vector, NearVector.GroupBy groupBy, + Consumer options) { + var query = NearVector.with(vector, options); + var req = new SearchMarshaler(collectionName).addNearVector(query) + .addGroupBy(groupBy); + return searchGrouped(req.marshal()); + } - query.appendTo(req); - return search(req.build()); + public GroupedQueryResult nearVector(Float[] vector, NearVector.GroupBy groupBy) { + var query = NearVector.with(vector, opt -> { + }); + var req = new SearchMarshaler(collectionName).addNearVector(query) + .addGroupBy(groupBy); + return searchGrouped(req.marshal()); } private QueryResult search(SearchRequest req) { @@ -52,6 +58,11 @@ private QueryResult search(SearchRequest req) { return deserializeUntyped(reply); } + private GroupedQueryResult searchGrouped(SearchRequest req) { + var reply = grpcClient.grpc.search(req); + return deserializeUntypedGrouped(reply); + } + public QueryResult deserializeUntyped(SearchReply reply) { List> objects = reply.getResultsList().stream() .map(res -> { @@ -70,6 +81,36 @@ public QueryResult deserializeUntyped(SearchReply reply) { return new QueryResult(objects); } + public GroupedQueryResult deserializeUntypedGrouped(SearchReply reply) { + var allObjects = new ArrayList>(); + Map> allGroups = reply.getGroupByResultsList() + .stream().map(g -> { + var groupName = g.getName(); + var groupObjects = g.getObjectsList().stream().map(res -> { + Map properties = convertProtoMap(res.getProperties().getNonRefProps().getFieldsMap()); + + MetadataResult meta = res.getMetadata(); + var metadata = new QueryResult.SearchObject.QueryMetadata( + meta.getId(), + meta.getDistancePresent() ? meta.getDistance() : null, + GRPC.fromByteString(meta.getVectorBytes())); + var obj = new GroupedQueryResult.WithGroupSearchObject(groupName, (T) properties, metadata); + + allObjects.add(obj); + + return obj; + }).toList(); + + return new GroupedQueryResult.Group<>( + groupName, + g.getMinDistance(), + g.getMaxDistance(), + g.getNumberOfObjects(), + groupObjects); + }).collect(Collectors.toMap(GroupedQueryResult.Group::name, g -> g)); + return new GroupedQueryResult<>(allObjects, allGroups); + } + /** * Convert Map to Map such that can be * (de-)serialized by {@link Gson}. diff --git a/src/main/java/io/weaviate/client6/v1/query/QueryOptions.java b/src/main/java/io/weaviate/client6/v1/query/QueryOptions.java deleted file mode 100644 index 5ae284953..000000000 --- a/src/main/java/io/weaviate/client6/v1/query/QueryOptions.java +++ /dev/null @@ -1,84 +0,0 @@ -package io.weaviate.client6.v1.query; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.apache.commons.lang3.StringUtils; - -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; - -@SuppressWarnings("unchecked") -abstract class QueryOptions> { - private Integer limit; - private Integer offset; - private Integer autocut; - private String after; - private String consistencyLevel; - private List returnProperties = new ArrayList<>(); - private List returnMetadata = new ArrayList<>(); - - public final SELF limit(Integer limit) { - this.limit = limit; - return (SELF) this; - } - - public final SELF offset(Integer offset) { - this.offset = offset; - return (SELF) this; - } - - public final SELF autocut(Integer autocut) { - this.autocut = autocut; - return (SELF) this; - } - - public final SELF after(String after) { - this.after = after; - return (SELF) this; - } - - public final SELF consistencyLevel(String consistencyLevel) { - this.consistencyLevel = consistencyLevel; - return (SELF) this; - } - - public final SELF returnMetadata(Metadata... metadata) { - this.returnMetadata = Arrays.asList(metadata); - return (SELF) this; - } - - void appendTo(SearchRequest.Builder search) { - if (limit != null) { - search.setLimit(limit); - } - if (offset != null) { - search.setOffset(offset); - } - if (StringUtils.isNotBlank(after)) { - search.setAfter(after); - } - if (StringUtils.isNotBlank(consistencyLevel)) { - search.setConsistencyLevelValue(Integer.valueOf(consistencyLevel)); - } - if (autocut != null) { - search.setAutocut(autocut); - } - - if (!returnMetadata.isEmpty()) { - var metadata = MetadataRequest.newBuilder(); - returnMetadata.forEach(m -> m.appendTo(metadata)); - search.setMetadata(metadata); - } - - if (!returnProperties.isEmpty()) { - var properties = PropertiesRequest.newBuilder(); - for (String property : returnProperties) { - properties.addNonRefProperties(property); - } - search.setProperties(properties); - } - } -} diff --git a/src/main/java/io/weaviate/client6/v1/query/QueryResult.java b/src/main/java/io/weaviate/client6/v1/query/QueryResult.java index 24b0a91e2..3d03a9840 100644 --- a/src/main/java/io/weaviate/client6/v1/query/QueryResult.java +++ b/src/main/java/io/weaviate/client6/v1/query/QueryResult.java @@ -3,7 +3,6 @@ import java.util.List; import lombok.AllArgsConstructor; -import lombok.ToString; @AllArgsConstructor public class QueryResult { @@ -14,13 +13,8 @@ public static class SearchObject { public final T properties; public final QueryMetadata metadata; - @AllArgsConstructor - @ToString - public static class QueryMetadata { - String id; - Float distance; + public record QueryMetadata(String id, Float distance, Float[] vector) { // TODO: use Vectors (to handle both Float[] and Float[][]) - Float[] vector; } } } diff --git a/src/test/java/io/weaviate/internal/DtoTypeAdapterFactoryTest.java b/src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java similarity index 96% rename from src/test/java/io/weaviate/internal/DtoTypeAdapterFactoryTest.java rename to src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java index 85cf85da6..f3ca920db 100644 --- a/src/test/java/io/weaviate/internal/DtoTypeAdapterFactoryTest.java +++ b/src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java @@ -1,4 +1,4 @@ -package io.weaviate.internal; +package io.weaviate.client6.internal; import org.assertj.core.api.Assertions; import org.junit.Test; @@ -10,8 +10,6 @@ import com.jparams.junit4.JParamsTestRunner; import com.jparams.junit4.data.DataMethod; -import io.weaviate.client6.internal.DtoTypeAdapterFactory; - @RunWith(JParamsTestRunner.class) public class DtoTypeAdapterFactoryTest { /** Person should be serialized to PersonDto. */ diff --git a/src/it/java/io/weaviate/client6/internal/GRPCTest.java b/src/test/java/io/weaviate/client6/internal/GRPCTest.java similarity index 100% rename from src/it/java/io/weaviate/client6/internal/GRPCTest.java rename to src/test/java/io/weaviate/client6/internal/GRPCTest.java diff --git a/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java b/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java index 37426ac75..8deae4893 100644 --- a/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java +++ b/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java @@ -54,7 +54,11 @@ public static Object[][] testCases() { """ { "vectorConfig": { - "default": { "vectorizer": { "none":{}}} + "default": { + "vectorizer": { "none": {}}, + "vectorIndexType": "hnsw", + "vectorIndexConfig": {} + } } } """, @@ -65,9 +69,13 @@ public static Object[][] testCases() { """ { "vectorConfig": { - "vector-1": { "vectorizer": { "none":{}}}, + "vector-1": { + "vectorizer": { "none": {}}, + "vectorIndexType": "hnsw", + "vectorIndexConfig": {} + }, "vector-2": { - "vectorizer": { "none":{}}, + "vectorizer": { "none": {}}, "vectorIndexType": "hnsw", "vectorIndexConfig": {} } @@ -83,8 +91,8 @@ public static Object[][] testCases() { """ { "vectorizer": { "none": {}}, - "vectorIndexConfig": { "distance": "COSINE", "skip": true }, - "vectorIndexType": "hnsw" + "vectorIndexType": "hnsw", + "vectorIndexConfig": { "distance": "COSINE", "skip": true } } """, collectionWithVectors(Vectors.unnamed(