From c01aa019112abdc7dd3599f4e645076a0f690fe3 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 19 Mar 2025 19:43:56 +0100 Subject: [PATCH 1/7] feat: aggregate over all objects in collection Supports integer aggregations. Also committing a bunch of group-by related code, although that's WIP. --- .../v1/collections/CollectionsITest.java | 55 ++++++---- .../v1/query/NearVectorQueryITest.java | 10 +- .../integration/AggregationITest.java | 79 ++++++++++++++ .../io/weaviate/client6/v1/Collection.java | 3 + .../collections/CollectionDefinitionDTO.java | 14 +-- .../client6/v1/collections/Collections.java | 6 +- .../aggregate/AggregateGroupByResult.java | 7 ++ .../aggregate/AggregateRequest.java | 97 +++++++++++++++++ .../aggregate/AggregateResult.java | 27 +++++ .../v1/collections/aggregate/Group.java | 17 +++ .../v1/collections/aggregate/GroupedBy.java | 7 ++ .../collections/aggregate/IntegerMetric.java | 100 ++++++++++++++++++ .../v1/collections/aggregate/Metric.java | 37 +++++++ .../v1/collections/aggregate/TextMetric.java | 87 +++++++++++++++ .../collections/aggregate/TopOccurrence.java | 4 + .../collections/aggregate/TopOccurrences.java | 0 .../aggregate/WeaviateAggregate.java | 68 ++++++++++++ .../io/weaviate/client6/v1/data/Data.java | 5 + .../client6/v1/data/WeaviateObjectDTO.java | 20 ++-- .../java/io/weaviate}/internal/GRPCTest.java | 0 20 files changed, 603 insertions(+), 40 deletions(-) create mode 100644 src/it/java/io/weaviate/integration/AggregationITest.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResult.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResult.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrence.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java rename src/{it/java/io/weaviate/client6 => test/java/io/weaviate}/internal/GRPCTest.java (100%) diff --git a/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java b/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java index dd50d69e1..a214c6337 100644 --- a/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java +++ b/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java @@ -17,25 +17,40 @@ public class CollectionsITest extends ConcurrentTest { @Test public void testCreateGetDelete() throws IOException { var collectionName = ns("Things_1"); - client.collections.create(collectionName, - col -> col - .properties(Property.text("username"), Property.integer("age")) - .vector(new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); - - var thingsCollection = client.collections.getConfig(collectionName); - - Assertions.assertThat(thingsCollection).get() - .hasFieldOrPropertyWithValue("name", collectionName) - .extracting(CollectionDefinition::vectors).extracting(Vectors::getDefault) - .as("default vector").satisfies(defaultVector -> { - Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) - .as("has none vectorizer").isInstanceOf(NoneVectorizer.class); - Assertions.assertThat(defaultVector).extracting(VectorIndex::configuration) - .as("has hnsw index").returns(IndexType.HNSW, IndexingStrategy::type); - }); - - client.collections.delete(collectionName); - var noCollection = client.collections.getConfig(collectionName); - Assertions.assertThat(noCollection).as("after delete").isEmpty(); + +// -------------------------------------------- +var defaultIndex = new VectorIndex<>(Vectorizer.none()); +var hnswIndex = new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()); +// -------------------------------------------- + +client.collections.create(collectionName, + collection -> collection + .properties(Property.text("username"), Property.integer("age")) + .vector(defaultIndex) + .vector("only-one", hnswIndex) + .vectors(named -> named + .vector("vector-a", hnswIndex) + .vector("vector-b", hnswIndex))); + +// -------------------------------------------- +var thingsCollection = client.collections.getConfig(collectionName); +// -------------------------------------------- + +Assertions.assertThat(thingsCollection).get() + .hasFieldOrPropertyWithValue("name", collectionName) + .extracting(CollectionDefinition::vectors).extracting(Vectors::getDefault) + .as("default vector").satisfies(defaultVector -> { + Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) + .as("has none vectorizer").isInstanceOf(NoneVectorizer.class); + Assertions.assertThat(defaultVector).extracting(VectorIndex::configuration) + .as("has hnsw index").returns(IndexType.HNSW, IndexingStrategy::type); + }); + +// -------------------------------------------- +client.collections.delete(collectionName); +// -------------------------------------------- + +var noCollection = client.collections.getConfig(collectionName); +Assertions.assertThat(noCollection).as("after delete").isEmpty(); } } diff --git a/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java b/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java index 0b8693b75..93e945a2c 100644 --- a/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java +++ b/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java @@ -41,11 +41,11 @@ public void testNearVector() { // TODO: test that we return the results in the expected order // Because re-ranking should work correctly var things = client.collections.use(COLLECTION); - QueryResult> result = things.query.nearVector(searchVector, - opt -> opt - .distance(2f) - .limit(3) - .returnMetadata(MetadataField.DISTANCE)); +QueryResult> result = things.query.nearVector(searchVector, + opt -> opt + .distance(2f) + .limit(3) + .returnMetadata(MetadataField.DISTANCE)); Assertions.assertThat(result.objects).hasSize(3); float maxDistance = Collections.max(result.objects, diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java new file mode 100644 index 000000000..e3147ad45 --- /dev/null +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -0,0 +1,79 @@ +package io.weaviate.integration; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.assertj.core.api.Assertions; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.weaviate.ConcurrentTest; +import io.weaviate.client6.WeaviateClient; +import io.weaviate.client6.v1.collections.Property; +import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; +import io.weaviate.client6.v1.collections.aggregate.Metric; +import io.weaviate.containers.Container; + +public class AggregationITest extends ConcurrentTest { + private static WeaviateClient client = Container.WEAVIATE.getClient(); + private static final String COLLECTION = unique("Things"); + + @BeforeClass + public static void beforeAll() throws IOException { + client.collections.create(COLLECTION, + collection -> collection + .properties( + Property.text("category"), + Property.integer("price"))); + + var things = client.collections.use(COLLECTION); + for (var category : List.of("Shoes", "Hat", "Jacket")) { + for (var i = 0; i < 5; i++) { + // For simplicity, the "price" for each items equals to the + // number of characters in the name of the category. + things.data.insert(Map.of( + "category", category, + "price", category.length())); + } + } + } + + @Test + public void testOverAll() { + var things = client.collections.use(COLLECTION); + var result = things.aggregate.overAll( + with -> with.metrics( + Metric.integer("price", calculate -> calculate + .median().max().count())) + .includeTotalCount()); + + Assertions.assertThat(result) + .as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 15L) + .as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price")) + .as("aggregated prices").extracting(p -> p.getInteger("price")) + .as("min").returns(null, IntegerMetric.Values::min) + .as("max").returns(6L, IntegerMetric.Values::max) + .as("median").returns(5D, IntegerMetric.Values::median) + .as("count").returns(15L, IntegerMetric.Values::count); + } + + // @Test + public void testOverAll_groupBy_category() { + var things = client.collections.use(COLLECTION); + var result = things.aggregate.overAll( + with -> with.metrics( + Metric.integer("price", calculate -> calculate + .median().max().count())) + .includeTotalCount()); + + Assertions.assertThat(result) + .as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 15L) + .as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price")) + .as("aggregated prices").extracting(p -> p.getInteger("price")) + .as("min").returns(null, IntegerMetric.Values::min) + .as("max").returns(6L, IntegerMetric.Values::max) + .as("median").returns(5f, IntegerMetric.Values::median) + .as("count").returns(15L, IntegerMetric.Values::count); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/Collection.java b/src/main/java/io/weaviate/client6/v1/Collection.java index e12f56915..b1f40dcc4 100644 --- a/src/main/java/io/weaviate/client6/v1/Collection.java +++ b/src/main/java/io/weaviate/client6/v1/Collection.java @@ -3,15 +3,18 @@ import io.weaviate.client6.Config; import io.weaviate.client6.internal.GrpcClient; import io.weaviate.client6.internal.HttpClient; +import io.weaviate.client6.v1.collections.aggregate.WeaviateAggregate; import io.weaviate.client6.v1.data.Data; import io.weaviate.client6.v1.query.Query; public class Collection { public final Query query; public final Data data; + public final WeaviateAggregate aggregate; public Collection(String collectionName, Config config, GrpcClient grpc, HttpClient http) { this.query = new Query<>(collectionName, grpc); this.data = new Data<>(collectionName, config, http); + this.aggregate = new WeaviateAggregate(collectionName, grpc); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java b/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java index 2c6cd5c85..e0333fad0 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java +++ b/src/main/java/io/weaviate/client6/v1/collections/CollectionDefinitionDTO.java @@ -30,12 +30,14 @@ public CollectionDefinitionDTO(CollectionDefinition colDef) { this.properties = colDef.properties(); this.vectors = colDef.vectors(); - var unnamed = this.vectors.getUnnamed(); - if (unnamed.isPresent()) { - var index = unnamed.get(); - this.vectorIndexType = index.type(); - this.vectorIndexConfig = index.configuration(); - this.vectorizer = index.vectorizer(); + if (this.vectors != null) { + var unnamed = this.vectors.getUnnamed(); + if (unnamed.isPresent()) { + var index = unnamed.get(); + this.vectorIndexType = index.type(); + this.vectorIndexConfig = index.configuration(); + this.vectorizer = index.vectorizer(); + } } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/Collections.java b/src/main/java/io/weaviate/client6/v1/collections/Collections.java index 0a2b8e97d..4f65915c9 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/Collections.java +++ b/src/main/java/io/weaviate/client6/v1/collections/Collections.java @@ -95,7 +95,11 @@ public JsonElement serialize(Vectorizer src, Type typeOfSrc, JsonSerializationCo @Override public void write(JsonWriter out, Vectors value) throws IOException { - gson.toJson(value.asMap(), Map.class, out); + if (value != null) { + gson.toJson(value.asMap(), Map.class, out); + } else { + out.nullValue(); + } } @Override diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResult.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResult.java new file mode 100644 index 000000000..4eb30712b --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResult.java @@ -0,0 +1,7 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.List; + +public record AggregateGroupByResult(List> groups) { + +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java new file mode 100644 index 000000000..2c0af8297 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java @@ -0,0 +1,97 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; + +public record AggregateRequest(String collectionName, Integer objectLimit, GroupBy groupBy, + List>> returnMetrics, boolean includeTotalCount) { + + public static AggregateRequest with(String collectionName, Consumer options) { + var opt = new Builder(options); + return new AggregateRequest(collectionName, opt.objectLimit, null, opt.metrics, opt.includeTotalCount); + } + + public static AggregateRequest with(String collectionName, Consumer options, + Consumer groupByOptions) { + var opt = new Builder(options); + return new AggregateRequest(collectionName, opt.objectLimit, GroupBy.with(groupByOptions), opt.metrics, + opt.includeTotalCount); + } + + void appendTo(io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Builder req) { + if (groupBy != null) { + var by = io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.GroupBy.newBuilder(); + by.setCollection(collectionName); + groupBy.appendTo(by); + req.setGroupBy(by); + } + + if (includeTotalCount) { + req.setObjectsCount(true); + } + + if (objectLimit != null) { + req.setObjectLimit(objectLimit); + } + + for (Metric metric : returnMetrics) { + var agg = Aggregation.newBuilder(); + metric.appendTo(agg); + req.addAggregations(agg); + } + } + + public static class Builder { + private List>> metrics; + private Integer objectLimit; + private boolean includeTotalCount = false; + + Builder(Consumer options) { + options.accept(this); + } + + public Builder objectLimit(int limit) { + this.objectLimit = limit; + return this; + } + + public Builder includeTotalCount() { + this.includeTotalCount = true; + return this; + } + + @SafeVarargs + public final Builder metrics(Metric>... metrics) { + this.metrics = Arrays.asList(metrics); + return this; + } + } + + public static record GroupBy(String property) { + public static GroupBy with(Consumer options) { + var opt = new Builder(options); + return new GroupBy(opt.property); + } + + public void appendTo( + io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.GroupBy.Builder groupBy) { + groupBy.setProperty(property); + } + + public static class Builder { + private String property; + + public Builder property(String name) { + this.property = name; + return this; + } + + Builder(Consumer options) { + options.accept(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResult.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResult.java new file mode 100644 index 000000000..99ceae35e --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResult.java @@ -0,0 +1,27 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.Map; + +public record AggregateResult(Map properties, Long totalCount) { + public boolean isTextProperties(String name) { + return properties.get(name) instanceof TextMetric.Values; + } + + public boolean isIntegerProperty(String name) { + return properties.get(name) instanceof IntegerMetric.Values; + } + + public TextMetric.Values getText(String name) { + if (!isTextProperties(name)) { + throw new IllegalStateException(name + " is not a Text property"); + } + return (TextMetric.Values) this.properties.get(name); + } + + public IntegerMetric.Values getInteger(String name) { + if (!isIntegerProperty(name)) { + throw new IllegalStateException(name + " is not a Integer property"); + } + return (IntegerMetric.Values) this.properties.get(name); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java new file mode 100644 index 000000000..3f7f763b9 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java @@ -0,0 +1,17 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.Map; + +public record Group(GroupedBy by, Map properties, int totalCount) { + // TODO: have DataType util method for this? + public boolean isTextProperties(String name) { + return properties.get(name) instanceof TextMetric.Values; + } + + public TextMetric.Values getText(String name) { + if (!isTextProperties(name)) { + throw new IllegalStateException(name + " is not a Text property"); + } + return (TextMetric.Values) this.properties.get(name); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java new file mode 100644 index 000000000..7796a1c10 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java @@ -0,0 +1,7 @@ +package io.weaviate.client6.v1.collections.aggregate; + +public record GroupedBy(String property, T value) { + public boolean isText() { + return value instanceof String; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java new file mode 100644 index 000000000..e65da9b93 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java @@ -0,0 +1,100 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.HashSet; +import java.util.Set; +import java.util.function.Consumer; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; + +public class IntegerMetric extends Metric { + private final Set functions; + + public record Values(Long count, Long min, Long max, Double mean, Double median, Long mode, Long sum) + implements Metric.Values { + } + + IntegerMetric(String property, Consumer options) { + super(property); + var opt = new Builder(options); + this.functions = opt.functions; + } + + private enum AggregateFunction { + COUNT, MIN, MAX, MEAN, MEDIAN, MODE, SUM + } + + public static class Builder { + private final Set functions = new HashSet<>(); + + public Builder count() { + functions.add(AggregateFunction.COUNT); + return this; + } + + public Builder min() { + functions.add(AggregateFunction.MIN); + return this; + } + + public Builder max() { + functions.add(AggregateFunction.MAX); + return this; + } + + public Builder mean() { + functions.add(AggregateFunction.MEAN); + return this; + } + + public Builder median() { + functions.add(AggregateFunction.MEDIAN); + return this; + } + + public Builder mode() { + functions.add(AggregateFunction.MODE); + return this; + } + + public Builder sum() { + functions.add(AggregateFunction.SUM); + return this; + } + + Builder(Consumer options) { + options.accept(this); + } + } + + void appendTo(Aggregation.Builder aggregation) { + aggregation.setProperty(property); + var integer = io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation.Integer + .newBuilder(); + for (var f : functions) { + switch (f) { + case COUNT: + integer.setCount(true); + break; + case MIN: + integer.setMinimum(true); + break; + case MAX: + integer.setMaximum(true); + break; + case MEAN: + integer.setMean(true); + break; + case MODE: + integer.setMode(true); + break; + case MEDIAN: + integer.setMedian(true); + break; + case SUM: + integer.setSum(true); + break; + } + } + aggregation.setInt(integer); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java new file mode 100644 index 000000000..0b1989e02 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java @@ -0,0 +1,37 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.function.Consumer; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; + +public abstract class Metric> { + protected final String property; + + abstract void appendTo(Aggregation.Builder aggregation); + + public Metric(String property) { + this.property = property; + } + + public static TextMetric text(String property) { + return new TextMetric(property, _options -> { + }); + } + + public static TextMetric text(String property, Consumer options) { + return new TextMetric(property, options); + } + + public static IntegerMetric integer(String property) { + return new IntegerMetric(property, _options -> { + }); + } + + public static IntegerMetric integer(String property, Consumer options) { + return new IntegerMetric(property, options); + } + + public interface Values { + Long count(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java new file mode 100644 index 000000000..61424e198 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java @@ -0,0 +1,87 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation.Text; + +public class TextMetric extends Metric { + private final Set functions; + private final boolean occurrenceCount; + private final Integer atLeast; + + public record Values(Long count, List topOccurrences) implements Metric.Values { + } + + TextMetric(String property, Consumer options) { + super(property); + + var opt = new Builder(options); + this.functions = opt.functions; + this.occurrenceCount = opt.occurrenceCount; + this.atLeast = opt.atLeast; + } + + private enum AggregateFunction { + COUNT, TYPE, TOP_OCCURENCES + } + + public static class Builder { + private final Set functions = new HashSet<>(); + private boolean occurrenceCount = false; + private Integer atLeast; + + public Builder count() { + functions.add(AggregateFunction.COUNT); + return this; + } + + public Builder type() { + functions.add(AggregateFunction.TYPE); + return this; + } + + public Builder topOccurences() { + functions.add(AggregateFunction.TOP_OCCURENCES); + return this; + } + + public Builder topOccurences(int atLeast) { + topOccurences(); + this.atLeast = atLeast; + return this; + } + + public Builder includeTopOccurencesCount() { + topOccurences(); + this.occurrenceCount = true; + return this; + } + + Builder(Consumer options) { + options.accept(this); + } + } + + void appendTo(Aggregation.Builder aggregation) { + aggregation.setProperty(property); + var text = Text.newBuilder(); + for (var f : functions) { + switch (f) { + case TYPE: + text.setType(true); + case COUNT: + text.setCount(true); + case TOP_OCCURENCES: + text.setTopOccurences(true); + if (atLeast != null) { + text.setTopOccurencesLimit(atLeast); + } + } + } + aggregation.setText(text); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrence.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrence.java new file mode 100644 index 000000000..9d903ae82 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrence.java @@ -0,0 +1,4 @@ +package io.weaviate.client6.v1.collections.aggregate; + +public record TopOccurrence(String value, int occurrenceCount) { +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TopOccurrences.java new file mode 100644 index 000000000..e69de29bb diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java new file mode 100644 index 000000000..08414077f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java @@ -0,0 +1,68 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import io.weaviate.client6.internal.GrpcClient; + +public class WeaviateAggregate { + private final String collectionName; + private final GrpcClient grpcClient; + + public WeaviateAggregate(String collectionName, GrpcClient grpc) { + this.collectionName = collectionName; + this.grpcClient = grpc; + } + + public AggregateResult overAll(Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + var req = io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.newBuilder(); + req.setCollection(collectionName); + aggregation.appendTo(req); + var reply = grpcClient.grpc.aggregate(req.build()); + + Long totalCount = null; + Map properties = new HashMap<>(); + + if (reply.hasSingleResult()) { + var single = reply.getSingleResult(); + totalCount = single.hasObjectsCount() ? single.getObjectsCount() : null; + var aggregations = single.getAggregations().getAggregationsList(); + for (var agg : aggregations) { + var property = agg.getProperty(); + Metric.Values value = null; + + if (agg.hasInt()) { + var metrics = agg.getInt(); + value = new IntegerMetric.Values( + metrics.hasCount() ? metrics.getCount() : null, + metrics.hasMinimum() ? metrics.getMinimum() : null, + metrics.hasMaximum() ? metrics.getMaximum() : null, + metrics.hasMean() ? metrics.getMean() : null, + metrics.hasMedian() ? metrics.getMedian() : null, + metrics.hasMode() ? metrics.getMode() : null, + metrics.hasSum() ? metrics.getSum() : null); + } else { + assert false : "branch not covered"; + } + + if (value != null) { + properties.put(property, value); + } + } + } + var result = new AggregateResult(properties, totalCount); + return result; + } + + public AggregateGroupByResult overAll(Consumer groupByOptions, + Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options, groupByOptions); + var req = io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.newBuilder(); + req.setCollection(collectionName); + aggregation.appendTo(req); + var reply = grpcClient.grpc.aggregate(req.build()); + return null; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/data/Data.java b/src/main/java/io/weaviate/client6/v1/data/Data.java index e5f457b23..54b476b22 100644 --- a/src/main/java/io/weaviate/client6/v1/data/Data.java +++ b/src/main/java/io/weaviate/client6/v1/data/Data.java @@ -31,6 +31,11 @@ public class Data { private final Config config; private final HttpClient httpClient; + public WeaviateObject insert(T object) throws IOException { + return insert(object, opt -> { + }); + } + public WeaviateObject insert(T object, Consumer options) throws IOException { var body = new WeaviateObject<>(collectionName, object, options); ClassicHttpRequest httpPost = ClassicRequestBuilder diff --git a/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java b/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java index 9d3e3fcc7..ed9b00af8 100644 --- a/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java +++ b/src/main/java/io/weaviate/client6/v1/data/WeaviateObjectDTO.java @@ -25,20 +25,24 @@ class WeaviateObjectDTO { if (object.metadata() != null) { this.id = object.metadata().id(); - this.vectors = object.metadata().vectors().asMap(); + if (object.metadata().vectors() != null) { + this.vectors = object.metadata().vectors().asMap(); + } } } WeaviateObject toWeaviateObject() { Map arrayVectors = new HashMap<>(); - for (var entry : vectors.entrySet()) { - var value = (ArrayList) entry.getValue(); - var vector = new Float[value.size()]; - int i = 0; - for (var v : value) { - vector[i++] = v.floatValue(); + if (vectors != null) { + for (var entry : vectors.entrySet()) { + var value = (ArrayList) entry.getValue(); + var vector = new Float[value.size()]; + int i = 0; + for (var v : value) { + vector[i++] = v.floatValue(); + } + arrayVectors.put(entry.getKey(), vector); } - arrayVectors.put(entry.getKey(), vector); } return new WeaviateObject(collection, properties, new ObjectMetadata(id, Vectors.of(arrayVectors))); } diff --git a/src/it/java/io/weaviate/client6/internal/GRPCTest.java b/src/test/java/io/weaviate/internal/GRPCTest.java similarity index 100% rename from src/it/java/io/weaviate/client6/internal/GRPCTest.java rename to src/test/java/io/weaviate/internal/GRPCTest.java From 1ecd08612ff6ef722d5f2d4288d9cab28e3fa7a7 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 19 Mar 2025 20:44:20 +0100 Subject: [PATCH 2/7] feat: add aggregation with groupBy clause Only supports aggregation on text / int properties at the moment --- .../integration/AggregationITest.java | 40 ++++++++++++---- .../v1/collections/aggregate/Group.java | 17 +++++-- .../v1/collections/aggregate/GroupedBy.java | 7 +++ .../aggregate/WeaviateAggregate.java | 48 ++++++++++++++++++- 4 files changed, 98 insertions(+), 14 deletions(-) diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java index e3147ad45..50c5af2e0 100644 --- a/src/it/java/io/weaviate/integration/AggregationITest.java +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -3,14 +3,20 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; import org.assertj.core.api.Assertions; +import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.BeforeClass; import org.junit.Test; import io.weaviate.ConcurrentTest; import io.weaviate.client6.WeaviateClient; import io.weaviate.client6.v1.collections.Property; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResult; +import io.weaviate.client6.v1.collections.aggregate.Group; +import io.weaviate.client6.v1.collections.aggregate.GroupedBy; import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; import io.weaviate.client6.v1.collections.aggregate.Metric; import io.weaviate.containers.Container; @@ -58,22 +64,38 @@ public void testOverAll() { .as("count").returns(15L, IntegerMetric.Values::count); } - // @Test + @Test public void testOverAll_groupBy_category() { var things = client.collections.use(COLLECTION); var result = things.aggregate.overAll( + groupBy -> groupBy.property("category"), with -> with.metrics( Metric.integer("price", calculate -> calculate - .median().max().count())) + .min().max().count())) .includeTotalCount()); Assertions.assertThat(result) - .as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 15L) - .as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price")) - .as("aggregated prices").extracting(p -> p.getInteger("price")) - .as("min").returns(null, IntegerMetric.Values::min) - .as("max").returns(6L, IntegerMetric.Values::max) - .as("median").returns(5f, IntegerMetric.Values::median) - .as("count").returns(15L, IntegerMetric.Values::count); + .extracting(AggregateGroupByResult::groups) + .asInstanceOf(InstanceOfAssertFactories.list(Group.class)) + .as("group per category").hasSize(3) + .allSatisfy(group -> { + Assertions.assertThat(group) + .extracting(Group::by) + .as(group.by().property() + " is Text property").returns(true, GroupedBy::isText); + + String category = group.by().getAsText(); + var expectedPrice = (long) category.length(); + + Function> desc = (String metric) -> { + return () -> "%s ('%s'.length)".formatted(metric, category); + }; + + Assertions.assertThat(group) + .as("'price' is IntegerMetric").returns(true, g -> g.isIntegerProperty("price")) + .as("aggregated prices").extracting(g -> g.getInteger("price")) + .as(desc.apply("max")).returns(expectedPrice, IntegerMetric.Values::max) + .as(desc.apply("min")).returns(expectedPrice, IntegerMetric.Values::min) + .as(desc.apply("count")).returns(5L, IntegerMetric.Values::count); + }); } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java index 3f7f763b9..05f010ac7 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Group.java @@ -2,16 +2,27 @@ import java.util.Map; -public record Group(GroupedBy by, Map properties, int totalCount) { +public record Group(GroupedBy by, Map properties, Long totalCount) { // TODO: have DataType util method for this? - public boolean isTextProperties(String name) { + public boolean isTextProperty(String name) { return properties.get(name) instanceof TextMetric.Values; } + public boolean isIntegerProperty(String name) { + return properties.get(name) instanceof IntegerMetric.Values; + } + public TextMetric.Values getText(String name) { - if (!isTextProperties(name)) { + if (!isTextProperty(name)) { throw new IllegalStateException(name + " is not a Text property"); } return (TextMetric.Values) this.properties.get(name); } + + public IntegerMetric.Values getInteger(String name) { + if (!isIntegerProperty(name)) { + throw new IllegalStateException(name + " is not a Integer property"); + } + return (IntegerMetric.Values) this.properties.get(name); + } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java index 7796a1c10..c751dca3a 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/GroupedBy.java @@ -4,4 +4,11 @@ public record GroupedBy(String property, T value) { public boolean isText() { return value instanceof String; } + + public String getAsText() { + if (!isText()) { + throw new IllegalStateException(property + " is not a Text property"); + } + return (String) value; + } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java index 08414077f..99ce116e0 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java @@ -1,6 +1,8 @@ package io.weaviate.client6.v1.collections.aggregate; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -46,7 +48,6 @@ public AggregateResult overAll(Consumer options) { } else { assert false : "branch not covered"; } - if (value != null) { properties.put(property, value); } @@ -63,6 +64,49 @@ public AggregateGroupByResult overAll(Consumer req.setCollection(collectionName); aggregation.appendTo(req); var reply = grpcClient.grpc.aggregate(req.build()); - return null; + + List> groups = new ArrayList<>(); + if (reply.hasGroupedResults()) { + for (var result : reply.getGroupedResults().getGroupsList()) { + final Long totalCount = result.hasObjectsCount() ? result.getObjectsCount() : null; + + GroupedBy groupedBy = null; + var gb = result.getGroupedBy(); + if (gb.hasInt()) { + groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getInt()); + } else if (gb.hasText()) { + groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getText()); + } else { + assert false : "branch not covered"; + } + + Map properties = new HashMap<>(); + for (var agg : result.getAggregations().getAggregationsList()) { + var property = agg.getProperty(); + Metric.Values value = null; + + if (agg.hasInt()) { + var metrics = agg.getInt(); + value = new IntegerMetric.Values( + metrics.hasCount() ? metrics.getCount() : null, + metrics.hasMinimum() ? metrics.getMinimum() : null, + metrics.hasMaximum() ? metrics.getMaximum() : null, + metrics.hasMean() ? metrics.getMean() : null, + metrics.hasMedian() ? metrics.getMedian() : null, + metrics.hasMode() ? metrics.getMode() : null, + metrics.hasSum() ? metrics.getSum() : null); + } else { + assert false : "branch not covered"; + } + if (value != null) { + properties.put(property, value); + } + } + Group group = new Group<>(groupedBy, properties, totalCount); + groups.add(group); + + } + } + return new AggregateGroupByResult(groups); } } From 1b5f48ae5da8848dcb0b2a28a13004205e0532ec Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 19 Mar 2025 20:48:58 +0100 Subject: [PATCH 3/7] chore: fix tests and import paths --- .../v1/collections/CollectionsITest.java | 55 +++++++------------ .../internal/DtoTypeAdapterFactoryTest.java | 4 +- .../{ => client6}/internal/GRPCTest.java | 0 3 files changed, 21 insertions(+), 38 deletions(-) rename src/test/java/io/weaviate/{ => client6}/internal/DtoTypeAdapterFactoryTest.java (96%) rename src/test/java/io/weaviate/{ => client6}/internal/GRPCTest.java (100%) diff --git a/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java b/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java index a214c6337..dd50d69e1 100644 --- a/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java +++ b/src/it/java/io/weaviate/client6/v1/collections/CollectionsITest.java @@ -17,40 +17,25 @@ public class CollectionsITest extends ConcurrentTest { @Test public void testCreateGetDelete() throws IOException { var collectionName = ns("Things_1"); - -// -------------------------------------------- -var defaultIndex = new VectorIndex<>(Vectorizer.none()); -var hnswIndex = new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()); -// -------------------------------------------- - -client.collections.create(collectionName, - collection -> collection - .properties(Property.text("username"), Property.integer("age")) - .vector(defaultIndex) - .vector("only-one", hnswIndex) - .vectors(named -> named - .vector("vector-a", hnswIndex) - .vector("vector-b", hnswIndex))); - -// -------------------------------------------- -var thingsCollection = client.collections.getConfig(collectionName); -// -------------------------------------------- - -Assertions.assertThat(thingsCollection).get() - .hasFieldOrPropertyWithValue("name", collectionName) - .extracting(CollectionDefinition::vectors).extracting(Vectors::getDefault) - .as("default vector").satisfies(defaultVector -> { - Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) - .as("has none vectorizer").isInstanceOf(NoneVectorizer.class); - Assertions.assertThat(defaultVector).extracting(VectorIndex::configuration) - .as("has hnsw index").returns(IndexType.HNSW, IndexingStrategy::type); - }); - -// -------------------------------------------- -client.collections.delete(collectionName); -// -------------------------------------------- - -var noCollection = client.collections.getConfig(collectionName); -Assertions.assertThat(noCollection).as("after delete").isEmpty(); + client.collections.create(collectionName, + col -> col + .properties(Property.text("username"), Property.integer("age")) + .vector(new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); + + var thingsCollection = client.collections.getConfig(collectionName); + + Assertions.assertThat(thingsCollection).get() + .hasFieldOrPropertyWithValue("name", collectionName) + .extracting(CollectionDefinition::vectors).extracting(Vectors::getDefault) + .as("default vector").satisfies(defaultVector -> { + Assertions.assertThat(defaultVector).extracting(VectorIndex::vectorizer) + .as("has none vectorizer").isInstanceOf(NoneVectorizer.class); + Assertions.assertThat(defaultVector).extracting(VectorIndex::configuration) + .as("has hnsw index").returns(IndexType.HNSW, IndexingStrategy::type); + }); + + client.collections.delete(collectionName); + var noCollection = client.collections.getConfig(collectionName); + Assertions.assertThat(noCollection).as("after delete").isEmpty(); } } diff --git a/src/test/java/io/weaviate/internal/DtoTypeAdapterFactoryTest.java b/src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java similarity index 96% rename from src/test/java/io/weaviate/internal/DtoTypeAdapterFactoryTest.java rename to src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java index 85cf85da6..f3ca920db 100644 --- a/src/test/java/io/weaviate/internal/DtoTypeAdapterFactoryTest.java +++ b/src/test/java/io/weaviate/client6/internal/DtoTypeAdapterFactoryTest.java @@ -1,4 +1,4 @@ -package io.weaviate.internal; +package io.weaviate.client6.internal; import org.assertj.core.api.Assertions; import org.junit.Test; @@ -10,8 +10,6 @@ import com.jparams.junit4.JParamsTestRunner; import com.jparams.junit4.data.DataMethod; -import io.weaviate.client6.internal.DtoTypeAdapterFactory; - @RunWith(JParamsTestRunner.class) public class DtoTypeAdapterFactoryTest { /** Person should be serialized to PersonDto. */ diff --git a/src/test/java/io/weaviate/internal/GRPCTest.java b/src/test/java/io/weaviate/client6/internal/GRPCTest.java similarity index 100% rename from src/test/java/io/weaviate/internal/GRPCTest.java rename to src/test/java/io/weaviate/client6/internal/GRPCTest.java From 37f530fcaa72acac9a9d57017e0ef76b3ce6d68e Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Mon, 24 Mar 2025 20:08:38 +0100 Subject: [PATCH 4/7] feat: aggregate with near vector filter --- .../v1/query/NearVectorQueryITest.java | 10 +- .../integration/AggregationITest.java | 33 ++++- .../internal/codec/grpc/GrpcMarshaler.java | 5 + .../codec/grpc/v1/AggregateMarshaler.java | 140 ++++++++++++++++++ .../codec/grpc/v1/SearchMarshaler.java | 77 ++++++++++ .../client6/v1/collections/VectorIndex.java | 2 +- .../aggregate/AggregateGroupByRequest.java | 26 ++++ ...ult.java => AggregateGroupByResponse.java} | 2 +- .../aggregate/AggregateRequest.java | 74 ++------- ...gateResult.java => AggregateResponse.java} | 2 +- .../collections/aggregate/IntegerMetric.java | 64 ++------ .../v1/collections/aggregate/Metric.java | 21 +-- .../v1/collections/aggregate/TextMetric.java | 52 ++----- .../aggregate/WeaviateAggregate.java | 67 +++++++-- .../client6/v1/query/CommonQueryOptions.java | 106 +++++++++++++ .../weaviate/client6/v1/query/Metadata.java | 2 +- .../weaviate/client6/v1/query/NearVector.java | 43 +----- .../io/weaviate/client6/v1/query/Query.java | 21 +-- .../client6/v1/query/QueryOptions.java | 84 ----------- .../client6/v1/collections/VectorsTest.java | 18 ++- 20 files changed, 521 insertions(+), 328 deletions(-) create mode 100644 src/main/java/io/weaviate/client6/internal/codec/grpc/GrpcMarshaler.java create mode 100644 src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java create mode 100644 src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java create mode 100644 src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByRequest.java rename src/main/java/io/weaviate/client6/v1/collections/aggregate/{AggregateGroupByResult.java => AggregateGroupByResponse.java} (56%) rename src/main/java/io/weaviate/client6/v1/collections/aggregate/{AggregateResult.java => AggregateResponse.java} (88%) create mode 100644 src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java delete mode 100644 src/main/java/io/weaviate/client6/v1/query/QueryOptions.java diff --git a/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java b/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java index 93e945a2c..0b8693b75 100644 --- a/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java +++ b/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java @@ -41,11 +41,11 @@ public void testNearVector() { // TODO: test that we return the results in the expected order // Because re-ranking should work correctly var things = client.collections.use(COLLECTION); -QueryResult> result = things.query.nearVector(searchVector, - opt -> opt - .distance(2f) - .limit(3) - .returnMetadata(MetadataField.DISTANCE)); + QueryResult> result = things.query.nearVector(searchVector, + opt -> opt + .distance(2f) + .limit(3) + .returnMetadata(MetadataField.DISTANCE)); Assertions.assertThat(result.objects).hasSize(3); float maxDistance = Collections.max(result.objects, diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java index 50c5af2e0..db7ba6242 100644 --- a/src/it/java/io/weaviate/integration/AggregationITest.java +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -14,7 +14,10 @@ import io.weaviate.ConcurrentTest; import io.weaviate.client6.WeaviateClient; import io.weaviate.client6.v1.collections.Property; -import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResult; +import io.weaviate.client6.v1.collections.VectorIndex; +import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.collections.Vectors; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResponse; import io.weaviate.client6.v1.collections.aggregate.Group; import io.weaviate.client6.v1.collections.aggregate.GroupedBy; import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; @@ -31,16 +34,19 @@ public static void beforeAll() throws IOException { collection -> collection .properties( Property.text("category"), - Property.integer("price"))); + Property.integer("price")) + .vectors(Vectors.of(new VectorIndex<>(Vectorizer.none())))); var things = client.collections.use(COLLECTION); for (var category : List.of("Shoes", "Hat", "Jacket")) { for (var i = 0; i < 5; i++) { + var vector = randomVector(10, -.1f, .1f); // For simplicity, the "price" for each items equals to the // number of characters in the name of the category. things.data.insert(Map.of( "category", category, - "price", category.length())); + "price", category.length()), + meta -> meta.vectors(vector)); } } } @@ -75,7 +81,7 @@ public void testOverAll_groupBy_category() { .includeTotalCount()); Assertions.assertThat(result) - .extracting(AggregateGroupByResult::groups) + .extracting(AggregateGroupByResponse::groups) .asInstanceOf(InstanceOfAssertFactories.list(Group.class)) .as("group per category").hasSize(3) .allSatisfy(group -> { @@ -98,4 +104,23 @@ public void testOverAll_groupBy_category() { .as(desc.apply("count")).returns(5L, IntegerMetric.Values::count); }); } + + @Test + public void testNearVector() { + var things = client.collections.use(COLLECTION); + var result = things.aggregate.nearVector( + randomVector(10, -1f, 1f), + near -> near.limit(5), + with -> with.metrics( + Metric.integer("price", calculate -> calculate + .min().max().count())) + .objectLimit(4) + .includeTotalCount()); + + Assertions.assertThat(result) + .as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 4L) + .as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price")) + .as("aggregated prices").extracting(p -> p.getInteger("price")) + .as("count").returns(4L, IntegerMetric.Values::count); + } } diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/GrpcMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/GrpcMarshaler.java new file mode 100644 index 000000000..ed6624b39 --- /dev/null +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/GrpcMarshaler.java @@ -0,0 +1,5 @@ +package io.weaviate.client6.internal.codec.grpc; + +public interface GrpcMarshaler { + R marshal(); +} diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java new file mode 100644 index 000000000..9dea88c3a --- /dev/null +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java @@ -0,0 +1,140 @@ +package io.weaviate.client6.internal.codec.grpc.v1; + +import java.util.function.BiConsumer; + +import com.google.common.collect.ImmutableMap; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; +import io.weaviate.client6.internal.GRPC; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest.GroupBy; +import io.weaviate.client6.v1.collections.aggregate.AggregateRequest; +import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; +import io.weaviate.client6.v1.collections.aggregate.Metric; +import io.weaviate.client6.v1.collections.aggregate.TextMetric; +import io.weaviate.client6.v1.query.NearVector; + +public final class AggregateMarshaler { + private final WeaviateProtoAggregate.AggregateRequest.Builder req = WeaviateProtoAggregate.AggregateRequest + .newBuilder(); + + public WeaviateProtoAggregate.AggregateRequest marshal(AggregateGroupByRequest aggregateGroupBy) { + var aggregate = aggregateGroupBy.aggregate(); + if (aggregateGroupBy.groupBy() != null) { + addGroupBy(aggregate.collectionName(), aggregateGroupBy.groupBy(), req); + } + return marshal(aggregate); + } + + public WeaviateProtoAggregate.AggregateRequest marshal(NearVector nearVector, + AggregateGroupByRequest aggregateGroupBy) { + var aggregate = aggregateGroupBy.aggregate(); + if (aggregateGroupBy.groupBy() != null) { + addGroupBy(aggregate.collectionName(), aggregateGroupBy.groupBy(), req); + } + return marshal(nearVector, aggregate); + } + + public WeaviateProtoAggregate.AggregateRequest marshal(NearVector nearVector, AggregateRequest aggregate) { + req.setNearVector(buildNearVector(nearVector)); + if (nearVector.common().limit() != null) { + req.setLimit(nearVector.common().limit()); + } + return marshal(aggregate); + } + + private WeaviateProtoBaseSearch.NearVector buildNearVector(NearVector nv) { + var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); + nearVector.setVectorBytes(GRPC.toByteString(nv.vector())); + + if (nv.certainty() != null) { + nearVector.setCertainty(nv.certainty()); + } else if (nv.distance() != null) { + nearVector.setDistance(nv.distance()); + } + return nearVector.build(); + } + + public WeaviateProtoAggregate.AggregateRequest marshal(AggregateRequest aggregate) { + req.setCollection(aggregate.collectionName()); + + if (aggregate.includeTotalCount()) { + req.setObjectsCount(true); + } + + if (aggregate.objectLimit() != null) { + req.setObjectLimit(aggregate.objectLimit()); + } + + for (Metric metric : aggregate.returnMetrics()) { + addMetric(metric); + } + + return req.build(); + } + + private void addMetric(Metric metric) { + var aggregation = Aggregation.newBuilder(); + aggregation.setProperty(metric.property()); + + if (metric instanceof TextMetric m) { + var text = Aggregation.Text.newBuilder(); + m.functions().forEach(f -> set(f, text)); + if (m.atLeast() != null) { + text.setTopOccurencesLimit(m.atLeast()); + } + aggregation.setText(text); + } else if (metric instanceof IntegerMetric m) { + var integer = Aggregation.Integer.newBuilder(); + m.functions().forEach(f -> set(f, integer)); + aggregation.setInt(integer); + } else { + assert false : "branch not covered"; + } + + req.addAggregations(aggregation); + } + + private void addGroupBy(String collectionName, GroupBy groupBy, WeaviateProtoAggregate.AggregateRequest.Builder req) { + var by = WeaviateProtoAggregate.AggregateRequest.GroupBy.newBuilder(); + by.setCollection(collectionName); + by.setProperty(groupBy.property()); + req.setGroupBy(by); + } + + @SuppressWarnings("unchecked") + static final void set(Enum fn, B builder) { + if (metrics.containsKey(fn)) { + ((Toggle) metrics.get(fn)).toggleOn(builder); + } + } + + static final ImmutableMap, Toggle> metrics = new ImmutableMap.Builder, Toggle>() + .put(TextMetric._Function.TYPE, new Toggle<>(Aggregation.Text.Builder::setType)) + .put(TextMetric._Function.COUNT, new Toggle<>(Aggregation.Text.Builder::setCount)) + .put(TextMetric._Function.TOP_OCCURRENCES, new Toggle<>(Aggregation.Text.Builder::setTopOccurences)) + + .put(IntegerMetric._Function.COUNT, new Toggle<>(Aggregation.Integer.Builder::setCount)) + .put(IntegerMetric._Function.MIN, new Toggle<>(Aggregation.Integer.Builder::setMinimum)) + .put(IntegerMetric._Function.MAX, new Toggle<>(Aggregation.Integer.Builder::setMaximum)) + .put(IntegerMetric._Function.MEAN, new Toggle<>(Aggregation.Integer.Builder::setMean)) + .put(IntegerMetric._Function.MEDIAN, new Toggle<>(Aggregation.Integer.Builder::setMedian)) + .put(IntegerMetric._Function.MODE, new Toggle<>(Aggregation.Integer.Builder::setMode)) + .put(IntegerMetric._Function.SUM, new Toggle<>(Aggregation.Integer.Builder::setSum)) + .build(); + + static class Toggle { + private final BiConsumer setter; + + Toggle(BiConsumer setter) { + this.setter = setter; + } + + final void toggleOn(B builder) { + setter.accept(builder, true); + } + } + +} diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java new file mode 100644 index 000000000..801358726 --- /dev/null +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java @@ -0,0 +1,77 @@ +package io.weaviate.client6.internal.codec.grpc.v1; + +import org.apache.commons.lang3.StringUtils; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; +import io.weaviate.client6.internal.GRPC; +import io.weaviate.client6.internal.codec.grpc.GrpcMarshaler; +import io.weaviate.client6.v1.query.CommonQueryOptions; +import io.weaviate.client6.v1.query.NearVector; + +public class SearchMarshaler implements GrpcMarshaler { + private final WeaviateProtoSearchGet.SearchRequest.Builder req = WeaviateProtoSearchGet.SearchRequest.newBuilder(); + + public SearchMarshaler(String collectionName) { + req.setCollection(collectionName); + req.setUses123Api(true); + req.setUses125Api(true); + req.setUses127Api(true); + } + + public SearchMarshaler addNearVector(NearVector nv) { + setCommon(nv.common()); + + var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); + nearVector.setVectorBytes(GRPC.toByteString(nv.vector())); + + if (nv.certainty() != null) { + nearVector.setCertainty(nv.certainty()); + } else if (nv.distance() != null) { + nearVector.setDistance(nv.distance()); + } + + req.setNearVector(nearVector); + return this; + } + + private void setCommon(CommonQueryOptions o) { + if (o.limit() != null) { + req.setLimit(o.limit()); + } + if (o.offset() != null) { + req.setOffset(o.offset()); + } + if (StringUtils.isNotBlank(o.after())) { + req.setAfter(o.after()); + } + if (StringUtils.isNotBlank(o.consistencyLevel())) { + req.setConsistencyLevelValue(Integer.valueOf(o.consistencyLevel())); + } + if (o.autocut() != null) { + req.setAutocut(o.autocut()); + } + + if (!o.returnMetadata().isEmpty()) { + var metadata = MetadataRequest.newBuilder(); + o.returnMetadata().forEach(m -> m.appendTo(metadata)); + req.setMetadata(metadata); + } + + if (!o.returnProperties().isEmpty()) { + var properties = PropertiesRequest.newBuilder(); + for (String property : o.returnProperties()) { + properties.addNonRefProperties(property); + } + req.setProperties(properties); + } + } + + @Override + public SearchRequest marshal() { + return req.build(); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java index ad1160dbf..5db348263 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java +++ b/src/main/java/io/weaviate/client6/v1/collections/VectorIndex.java @@ -19,7 +19,7 @@ public VectorIndex(IndexingStrategy index, V vectorizer) { } public VectorIndex(V vectorizer) { - this(null, vectorizer, null); + this(IndexingStrategy.hnsw(), vectorizer); } public static sealed interface IndexingStrategy permits HNSW { diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByRequest.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByRequest.java new file mode 100644 index 000000000..0d3786f87 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByRequest.java @@ -0,0 +1,26 @@ +package io.weaviate.client6.v1.collections.aggregate; + +import java.util.function.Consumer; + +public record AggregateGroupByRequest(AggregateRequest aggregate, GroupBy groupBy) { + + public static record GroupBy(String property) { + public static GroupBy with(Consumer options) { + var opt = new Builder(options); + return new GroupBy(opt.property); + } + + public static class Builder { + private String property; + + public Builder property(String name) { + this.property = name; + return this; + } + + Builder(Consumer options) { + options.accept(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResult.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResponse.java similarity index 56% rename from src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResult.java rename to src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResponse.java index 4eb30712b..8cfeef016 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResult.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateGroupByResponse.java @@ -2,6 +2,6 @@ import java.util.List; -public record AggregateGroupByResult(List> groups) { +public record AggregateGroupByResponse(List> groups) { } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java index 2c0af8297..3b7c75899 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateRequest.java @@ -4,48 +4,23 @@ import java.util.List; import java.util.function.Consumer; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; - -public record AggregateRequest(String collectionName, Integer objectLimit, GroupBy groupBy, - List>> returnMetrics, boolean includeTotalCount) { +public record AggregateRequest( + String collectionName, + Integer objectLimit, + boolean includeTotalCount, + List returnMetrics) { public static AggregateRequest with(String collectionName, Consumer options) { var opt = new Builder(options); - return new AggregateRequest(collectionName, opt.objectLimit, null, opt.metrics, opt.includeTotalCount); - } - - public static AggregateRequest with(String collectionName, Consumer options, - Consumer groupByOptions) { - var opt = new Builder(options); - return new AggregateRequest(collectionName, opt.objectLimit, GroupBy.with(groupByOptions), opt.metrics, - opt.includeTotalCount); - } - - void appendTo(io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Builder req) { - if (groupBy != null) { - var by = io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.GroupBy.newBuilder(); - by.setCollection(collectionName); - groupBy.appendTo(by); - req.setGroupBy(by); - } - - if (includeTotalCount) { - req.setObjectsCount(true); - } - - if (objectLimit != null) { - req.setObjectLimit(objectLimit); - } - - for (Metric metric : returnMetrics) { - var agg = Aggregation.newBuilder(); - metric.appendTo(agg); - req.addAggregations(agg); - } + return new AggregateRequest( + collectionName, + opt.objectLimit, + opt.includeTotalCount, + opt.metrics); } public static class Builder { - private List>> metrics; + private List metrics; private Integer objectLimit; private boolean includeTotalCount = false; @@ -64,34 +39,9 @@ public Builder includeTotalCount() { } @SafeVarargs - public final Builder metrics(Metric>... metrics) { + public final Builder metrics(Metric... metrics) { this.metrics = Arrays.asList(metrics); return this; } } - - public static record GroupBy(String property) { - public static GroupBy with(Consumer options) { - var opt = new Builder(options); - return new GroupBy(opt.property); - } - - public void appendTo( - io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.GroupBy.Builder groupBy) { - groupBy.setProperty(property); - } - - public static class Builder { - private String property; - - public Builder property(String name) { - this.property = name; - return this; - } - - Builder(Consumer options) { - options.accept(this); - } - } - } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResult.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResponse.java similarity index 88% rename from src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResult.java rename to src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResponse.java index 99ceae35e..f2d0cde13 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResult.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/AggregateResponse.java @@ -2,7 +2,7 @@ import java.util.Map; -public record AggregateResult(Map properties, Long totalCount) { +public record AggregateResponse(Map properties, Long totalCount) { public boolean isTextProperties(String name) { return properties.get(name) instanceof TextMetric.Values; } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java index e65da9b93..10ef8474f 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/IntegerMetric.java @@ -1,63 +1,61 @@ package io.weaviate.client6.v1.collections.aggregate; +import java.util.ArrayList; import java.util.HashSet; +import java.util.List; import java.util.Set; import java.util.function.Consumer; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; - -public class IntegerMetric extends Metric { - private final Set functions; +public record IntegerMetric(String property, List<_Function> functions) implements Metric { public record Values(Long count, Long min, Long max, Double mean, Double median, Long mode, Long sum) implements Metric.Values { } - IntegerMetric(String property, Consumer options) { - super(property); + static IntegerMetric with(String property, Consumer options) { var opt = new Builder(options); - this.functions = opt.functions; + return new IntegerMetric(property, new ArrayList<>(opt.functions)); } - private enum AggregateFunction { - COUNT, MIN, MAX, MEAN, MEDIAN, MODE, SUM + public enum _Function { + COUNT, MIN, MAX, MEAN, MEDIAN, MODE, SUM; } public static class Builder { - private final Set functions = new HashSet<>(); + private final Set<_Function> functions = new HashSet<>(); public Builder count() { - functions.add(AggregateFunction.COUNT); + functions.add(_Function.COUNT); return this; } public Builder min() { - functions.add(AggregateFunction.MIN); + functions.add(_Function.MIN); return this; } public Builder max() { - functions.add(AggregateFunction.MAX); + functions.add(_Function.MAX); return this; } public Builder mean() { - functions.add(AggregateFunction.MEAN); + functions.add(_Function.MEAN); return this; } public Builder median() { - functions.add(AggregateFunction.MEDIAN); + functions.add(_Function.MEDIAN); return this; } public Builder mode() { - functions.add(AggregateFunction.MODE); + functions.add(_Function.MODE); return this; } public Builder sum() { - functions.add(AggregateFunction.SUM); + functions.add(_Function.SUM); return this; } @@ -65,36 +63,4 @@ public Builder sum() { options.accept(this); } } - - void appendTo(Aggregation.Builder aggregation) { - aggregation.setProperty(property); - var integer = io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation.Integer - .newBuilder(); - for (var f : functions) { - switch (f) { - case COUNT: - integer.setCount(true); - break; - case MIN: - integer.setMinimum(true); - break; - case MAX: - integer.setMaximum(true); - break; - case MEAN: - integer.setMean(true); - break; - case MODE: - integer.setMode(true); - break; - case MEDIAN: - integer.setMedian(true); - break; - case SUM: - integer.setSum(true); - break; - } - } - aggregation.setInt(integer); - } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java index 0b1989e02..588af7e43 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/Metric.java @@ -1,34 +1,29 @@ package io.weaviate.client6.v1.collections.aggregate; +import java.util.List; import java.util.function.Consumer; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; +public interface Metric { + String property(); -public abstract class Metric> { - protected final String property; - - abstract void appendTo(Aggregation.Builder aggregation); - - public Metric(String property) { - this.property = property; - } + List> functions(); public static TextMetric text(String property) { - return new TextMetric(property, _options -> { + return TextMetric.with(property, _options -> { }); } public static TextMetric text(String property, Consumer options) { - return new TextMetric(property, options); + return TextMetric.with(property, options); } public static IntegerMetric integer(String property) { - return new IntegerMetric(property, _options -> { + return IntegerMetric.with(property, _options -> { }); } public static IntegerMetric integer(String property, Consumer options) { - return new IntegerMetric(property, options); + return IntegerMetric.with(property, options); } public interface Values { diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java index 61424e198..7499cff70 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/TextMetric.java @@ -1,51 +1,46 @@ package io.weaviate.client6.v1.collections.aggregate; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.function.Consumer; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.Aggregation.Text; - -public class TextMetric extends Metric { - private final Set functions; - private final boolean occurrenceCount; - private final Integer atLeast; +public record TextMetric(String property, List<_Function> functions, boolean occurrenceCount, + Integer atLeast) + implements Metric { public record Values(Long count, List topOccurrences) implements Metric.Values { } - TextMetric(String property, Consumer options) { - super(property); - + static TextMetric with(String property, Consumer options) { var opt = new Builder(options); - this.functions = opt.functions; - this.occurrenceCount = opt.occurrenceCount; - this.atLeast = opt.atLeast; + return new TextMetric(property, + new ArrayList<>(opt.functions), + opt.occurrenceCount, opt.atLeast); } - private enum AggregateFunction { - COUNT, TYPE, TOP_OCCURENCES + public enum _Function { + COUNT, TYPE, TOP_OCCURRENCES; } public static class Builder { - private final Set functions = new HashSet<>(); + private final Set<_Function> functions = new HashSet<>(); private boolean occurrenceCount = false; private Integer atLeast; public Builder count() { - functions.add(AggregateFunction.COUNT); + functions.add(_Function.COUNT); return this; } public Builder type() { - functions.add(AggregateFunction.TYPE); + functions.add(_Function.TYPE); return this; } public Builder topOccurences() { - functions.add(AggregateFunction.TOP_OCCURENCES); + functions.add(_Function.TOP_OCCURRENCES); return this; } @@ -65,23 +60,4 @@ public Builder includeTopOccurencesCount() { options.accept(this); } } - - void appendTo(Aggregation.Builder aggregation) { - aggregation.setProperty(property); - var text = Text.newBuilder(); - for (var f : functions) { - switch (f) { - case TYPE: - text.setType(true); - case COUNT: - text.setCount(true); - case TOP_OCCURENCES: - text.setTopOccurences(true); - if (atLeast != null) { - text.setTopOccurencesLimit(atLeast); - } - } - } - aggregation.setText(text); - } } diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java index 99ce116e0..5d5eb4044 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java @@ -7,6 +7,8 @@ import java.util.function.Consumer; import io.weaviate.client6.internal.GrpcClient; +import io.weaviate.client6.internal.codec.grpc.v1.AggregateMarshaler; +import io.weaviate.client6.v1.query.NearVector; public class WeaviateAggregate { private final String collectionName; @@ -17,12 +19,10 @@ public WeaviateAggregate(String collectionName, GrpcClient grpc) { this.grpcClient = grpc; } - public AggregateResult overAll(Consumer options) { + public AggregateResponse overAll(Consumer options) { var aggregation = AggregateRequest.with(collectionName, options); - var req = io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.newBuilder(); - req.setCollection(collectionName); - aggregation.appendTo(req); - var reply = grpcClient.grpc.aggregate(req.build()); + var req = new AggregateMarshaler().marshal(aggregation); + var reply = grpcClient.grpc.aggregate(req); Long totalCount = null; Map properties = new HashMap<>(); @@ -53,17 +53,17 @@ public AggregateResult overAll(Consumer options) { } } } - var result = new AggregateResult(properties, totalCount); + var result = new AggregateResponse(properties, totalCount); return result; } - public AggregateGroupByResult overAll(Consumer groupByOptions, + public AggregateGroupByResponse overAll(Consumer groupByOptions, Consumer options) { - var aggregation = AggregateRequest.with(collectionName, options, groupByOptions); - var req = io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate.AggregateRequest.newBuilder(); - req.setCollection(collectionName); - aggregation.appendTo(req); - var reply = grpcClient.grpc.aggregate(req.build()); + var aggregation = AggregateRequest.with(collectionName, options); + var groupBy = AggregateGroupByRequest.GroupBy.with(groupByOptions); + + var req = new AggregateMarshaler().marshal(new AggregateGroupByRequest(aggregation, groupBy)); + var reply = grpcClient.grpc.aggregate(req); List> groups = new ArrayList<>(); if (reply.hasGroupedResults()) { @@ -107,6 +107,47 @@ public AggregateGroupByResult overAll(Consumer } } - return new AggregateGroupByResult(groups); + return new AggregateGroupByResponse(groups); + } + + public AggregateResponse nearVector(Float[] vector, Consumer nearVectorOptions, + Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + var nearVector = NearVector.with(vector, nearVectorOptions); + + var req = new AggregateMarshaler().marshal(nearVector, aggregation); + var reply = grpcClient.grpc.aggregate(req); + + Long totalCount = null; + Map properties = new HashMap<>(); + + if (reply.hasSingleResult()) { + var single = reply.getSingleResult(); + totalCount = single.hasObjectsCount() ? single.getObjectsCount() : null; + var aggregations = single.getAggregations().getAggregationsList(); + for (var agg : aggregations) { + var property = agg.getProperty(); + Metric.Values value = null; + + if (agg.hasInt()) { + var metrics = agg.getInt(); + value = new IntegerMetric.Values( + metrics.hasCount() ? metrics.getCount() : null, + metrics.hasMinimum() ? metrics.getMinimum() : null, + metrics.hasMaximum() ? metrics.getMaximum() : null, + metrics.hasMean() ? metrics.getMean() : null, + metrics.hasMedian() ? metrics.getMedian() : null, + metrics.hasMode() ? metrics.getMode() : null, + metrics.hasSum() ? metrics.getSum() : null); + } else { + assert false : "branch not covered"; + } + if (value != null) { + properties.put(property, value); + } + } + } + var result = new AggregateResponse(properties, totalCount); + return result; } } diff --git a/src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java b/src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java new file mode 100644 index 000000000..ddf1e1ab1 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/query/CommonQueryOptions.java @@ -0,0 +1,106 @@ +package io.weaviate.client6.v1.query; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.lang3.StringUtils; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; + +@SuppressWarnings("unchecked") +public record CommonQueryOptions( + Integer limit, + Integer offset, + Integer autocut, + String after, + String consistencyLevel /* TODO: use ConsistencyLevel enum */, + List returnProperties, + List returnMetadata) { + + public CommonQueryOptions(Builder> options) { + this( + options.limit, + options.offset, + options.autocut, + options.after, + options.consistencyLevel, + options.returnProperties, + options.returnMetadata); + + } + + public static abstract class Builder> { + private Integer limit; + private Integer offset; + private Integer autocut; + private String after; + private String consistencyLevel; + private List returnProperties = new ArrayList<>(); + private List returnMetadata = new ArrayList<>(); + + public final SELF limit(Integer limit) { + this.limit = limit; + return (SELF) this; + } + + public final SELF offset(Integer offset) { + this.offset = offset; + return (SELF) this; + } + + public final SELF autocut(Integer autocut) { + this.autocut = autocut; + return (SELF) this; + } + + public final SELF after(String after) { + this.after = after; + return (SELF) this; + } + + public final SELF consistencyLevel(String consistencyLevel) { + this.consistencyLevel = consistencyLevel; + return (SELF) this; + } + + public final SELF returnMetadata(Metadata... metadata) { + this.returnMetadata = Arrays.asList(metadata); + return (SELF) this; + } + + void appendTo(SearchRequest.Builder search) { + if (limit != null) { + search.setLimit(limit); + } + if (offset != null) { + search.setOffset(offset); + } + if (StringUtils.isNotBlank(after)) { + search.setAfter(after); + } + if (StringUtils.isNotBlank(consistencyLevel)) { + search.setConsistencyLevelValue(Integer.valueOf(consistencyLevel)); + } + if (autocut != null) { + search.setAutocut(autocut); + } + + if (!returnMetadata.isEmpty()) { + var metadata = MetadataRequest.newBuilder(); + returnMetadata.forEach(m -> m.appendTo(metadata)); + search.setMetadata(metadata); + } + + if (!returnProperties.isEmpty()) { + var properties = PropertiesRequest.newBuilder(); + for (String property : returnProperties) { + properties.addNonRefProperties(property); + } + search.setProperties(properties); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/query/Metadata.java b/src/main/java/io/weaviate/client6/v1/query/Metadata.java index d490ee67f..4cc37bd98 100644 --- a/src/main/java/io/weaviate/client6/v1/query/Metadata.java +++ b/src/main/java/io/weaviate/client6/v1/query/Metadata.java @@ -5,7 +5,7 @@ /** * Metadata is the common base for all properties that are requestes as * "_additional". It is an inteface all metadata properties MUST implement to be - * used in {@link QueryOptions}. + * used in {@link CommonQueryOptions}. */ public interface Metadata { void appendTo(MetadataRequest.Builder metadata); diff --git a/src/main/java/io/weaviate/client6/v1/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/query/NearVector.java index e479e8809..acb76279b 100644 --- a/src/main/java/io/weaviate/client6/v1/query/NearVector.java +++ b/src/main/java/io/weaviate/client6/v1/query/NearVector.java @@ -2,53 +2,26 @@ import java.util.function.Consumer; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoBaseSearch; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; -import io.weaviate.client6.internal.GRPC; +public record NearVector(Float[] vector, Float distance, Float certainty, CommonQueryOptions common) { -public class NearVector { - private final Float[] vector; - private final Options options; - - void appendTo(SearchRequest.Builder search) { - var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); - - // TODO: we should only add (named) Vectors. - // Since we do not force the users to supply a name when defining an index, - // we also need a way to "get default vector name" from the collection. - // For Map (untyped query handle) we always require the name. - nearVector.setVectorBytes(GRPC.toByteString(vector)); - options.append(search, nearVector); - search.setNearVector(nearVector.build()); - } - - public NearVector(Float[] vector, Consumer options) { - this.options = new Options(); - this.vector = vector; - options.accept(this.options); + public static NearVector with(Float[] vector, Consumer options) { + var opt = new Builder(); + options.accept(opt); + return new NearVector(vector, opt.distance, opt.certainty, new CommonQueryOptions(opt)); } - public static class Options extends QueryOptions { + public static class Builder extends CommonQueryOptions.Builder { private Float distance; private Float certainty; - public Options distance(float distance) { + public Builder distance(float distance) { this.distance = distance; return this; } - public Options certainty(float certainty) { + public Builder certainty(float certainty) { this.certainty = certainty; return this; } - - void append(SearchRequest.Builder search, WeaviateProtoBaseSearch.NearVector.Builder nearVector) { - if (certainty != null) { - nearVector.setCertainty(certainty); - } else if (distance != null) { - nearVector.setDistance(distance); - } - super.appendTo(search); - } } } diff --git a/src/main/java/io/weaviate/client6/v1/query/Query.java b/src/main/java/io/weaviate/client6/v1/query/Query.java index 7ac1508c7..5de02b452 100644 --- a/src/main/java/io/weaviate/client6/v1/query/Query.java +++ b/src/main/java/io/weaviate/client6/v1/query/Query.java @@ -15,11 +15,9 @@ import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; import io.weaviate.client6.internal.GRPC; import io.weaviate.client6.internal.GrpcClient; +import io.weaviate.client6.internal.codec.grpc.v1.SearchMarshaler; public class Query { - // TODO: inject singleton as dependency - private static final Gson gson = new Gson(); - // TODO: this should be wrapped around in some TypeInspector etc. private final String collectionName; @@ -32,19 +30,10 @@ public Query(String collectionName, GrpcClient grpc) { this.collectionName = collectionName; } - public QueryResult nearVector(Float[] vector, Consumer options) { - var query = new NearVector(vector, options); - - // TODO: Since we always need to set these values, we migth want to move the - // next block to some factory method. - var req = SearchRequest.newBuilder(); - req.setCollection(collectionName); - req.setUses123Api(true); - req.setUses125Api(true); - req.setUses127Api(true); - - query.appendTo(req); - return search(req.build()); + public QueryResult nearVector(Float[] vector, Consumer options) { + var query = NearVector.with(vector, options); + var req = new SearchMarshaler(collectionName).addNearVector(query); + return search(req.marshal()); } private QueryResult search(SearchRequest req) { diff --git a/src/main/java/io/weaviate/client6/v1/query/QueryOptions.java b/src/main/java/io/weaviate/client6/v1/query/QueryOptions.java deleted file mode 100644 index 5ae284953..000000000 --- a/src/main/java/io/weaviate/client6/v1/query/QueryOptions.java +++ /dev/null @@ -1,84 +0,0 @@ -package io.weaviate.client6.v1.query; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import org.apache.commons.lang3.StringUtils; - -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.MetadataRequest; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.PropertiesRequest; -import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoSearchGet.SearchRequest; - -@SuppressWarnings("unchecked") -abstract class QueryOptions> { - private Integer limit; - private Integer offset; - private Integer autocut; - private String after; - private String consistencyLevel; - private List returnProperties = new ArrayList<>(); - private List returnMetadata = new ArrayList<>(); - - public final SELF limit(Integer limit) { - this.limit = limit; - return (SELF) this; - } - - public final SELF offset(Integer offset) { - this.offset = offset; - return (SELF) this; - } - - public final SELF autocut(Integer autocut) { - this.autocut = autocut; - return (SELF) this; - } - - public final SELF after(String after) { - this.after = after; - return (SELF) this; - } - - public final SELF consistencyLevel(String consistencyLevel) { - this.consistencyLevel = consistencyLevel; - return (SELF) this; - } - - public final SELF returnMetadata(Metadata... metadata) { - this.returnMetadata = Arrays.asList(metadata); - return (SELF) this; - } - - void appendTo(SearchRequest.Builder search) { - if (limit != null) { - search.setLimit(limit); - } - if (offset != null) { - search.setOffset(offset); - } - if (StringUtils.isNotBlank(after)) { - search.setAfter(after); - } - if (StringUtils.isNotBlank(consistencyLevel)) { - search.setConsistencyLevelValue(Integer.valueOf(consistencyLevel)); - } - if (autocut != null) { - search.setAutocut(autocut); - } - - if (!returnMetadata.isEmpty()) { - var metadata = MetadataRequest.newBuilder(); - returnMetadata.forEach(m -> m.appendTo(metadata)); - search.setMetadata(metadata); - } - - if (!returnProperties.isEmpty()) { - var properties = PropertiesRequest.newBuilder(); - for (String property : returnProperties) { - properties.addNonRefProperties(property); - } - search.setProperties(properties); - } - } -} diff --git a/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java b/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java index 37426ac75..8deae4893 100644 --- a/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java +++ b/src/test/java/io/weaviate/client6/v1/collections/VectorsTest.java @@ -54,7 +54,11 @@ public static Object[][] testCases() { """ { "vectorConfig": { - "default": { "vectorizer": { "none":{}}} + "default": { + "vectorizer": { "none": {}}, + "vectorIndexType": "hnsw", + "vectorIndexConfig": {} + } } } """, @@ -65,9 +69,13 @@ public static Object[][] testCases() { """ { "vectorConfig": { - "vector-1": { "vectorizer": { "none":{}}}, + "vector-1": { + "vectorizer": { "none": {}}, + "vectorIndexType": "hnsw", + "vectorIndexConfig": {} + }, "vector-2": { - "vectorizer": { "none":{}}, + "vectorizer": { "none": {}}, "vectorIndexType": "hnsw", "vectorIndexConfig": {} } @@ -83,8 +91,8 @@ public static Object[][] testCases() { """ { "vectorizer": { "none": {}}, - "vectorIndexConfig": { "distance": "COSINE", "skip": true }, - "vectorIndexType": "hnsw" + "vectorIndexType": "hnsw", + "vectorIndexConfig": { "distance": "COSINE", "skip": true } } """, collectionWithVectors(Vectors.unnamed( From c18c88dba97060686f062f53c01e3a205e809b97 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 25 Mar 2025 12:24:21 +0100 Subject: [PATCH 5/7] feat: add aggregation + near vector + groupBy --- .../integration/AggregationITest.java | 38 +++++++++ .../codec/grpc/v1/AggregateMarshaler.java | 65 ++++++++++++--- .../aggregate/WeaviateAggregate.java | 79 ++++++++++++++++++- 3 files changed, 169 insertions(+), 13 deletions(-) diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java index db7ba6242..f5ae4665b 100644 --- a/src/it/java/io/weaviate/integration/AggregationITest.java +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -123,4 +123,42 @@ public void testNearVector() { .as("aggregated prices").extracting(p -> p.getInteger("price")) .as("count").returns(4L, IntegerMetric.Values::count); } + + @Test + public void testNearVector_groupBy_category() { + var things = client.collections.use(COLLECTION); + var result = things.aggregate.nearVector( + randomVector(10, -1f, 1f), + near -> near.distance(2f), + groupBy -> groupBy.property("category"), + with -> with.metrics( + Metric.integer("price", calculate -> calculate + .min().max().median())) + .objectLimit(9) + .includeTotalCount()); + + Assertions.assertThat(result) + .extracting(AggregateGroupByResponse::groups) + .asInstanceOf(InstanceOfAssertFactories.list(Group.class)) + .as("group per category").hasSize(3) + .allSatisfy(group -> { + Assertions.assertThat(group) + .extracting(Group::by) + .as(group.by().property() + " is Text property").returns(true, GroupedBy::isText); + + String category = group.by().getAsText(); + var expectedPrice = (long) category.length(); + + Function> desc = (String metric) -> { + return () -> "%s ('%s'.length)".formatted(metric, category); + }; + + Assertions.assertThat(group) + .as("'price' is IntegerMetric").returns(true, g -> g.isIntegerProperty("price")) + .as("aggregated prices").extracting(g -> g.getInteger("price")) + .as(desc.apply("max")).returns(expectedPrice, IntegerMetric.Values::max) + .as(desc.apply("min")).returns(expectedPrice, IntegerMetric.Values::min) + .as(desc.apply("median")).returns((double) expectedPrice, IntegerMetric.Values::median); + }); + } } diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java index 9dea88c3a..89ac0a9d9 100644 --- a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java @@ -19,11 +19,65 @@ public final class AggregateMarshaler { private final WeaviateProtoAggregate.AggregateRequest.Builder req = WeaviateProtoAggregate.AggregateRequest .newBuilder(); + private final String collectionName; + + public AggregateMarshaler(String collectionName) { + this.collectionName = collectionName; + } + + public WeaviateProtoAggregate.AggregateRequest marshal() { + return req.build(); + } + + public AggregateMarshaler addAggregation(AggregateRequest aggregate) { + req.setCollection(collectionName); + + if (aggregate.includeTotalCount()) { + req.setObjectsCount(true); + } + + if (aggregate.objectLimit() != null) { + req.setObjectLimit(aggregate.objectLimit()); + } + + for (Metric metric : aggregate.returnMetrics()) { + addMetric(metric); + } + + return this; + } + + public AggregateMarshaler addGroupBy(GroupBy groupBy) { + var by = WeaviateProtoAggregate.AggregateRequest.GroupBy.newBuilder(); + by.setCollection(collectionName); + by.setProperty(groupBy.property()); + req.setGroupBy(by); + return this; + } + + public AggregateMarshaler addNearVector(NearVector nv) { + var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); + nearVector.setVectorBytes(GRPC.toByteString(nv.vector())); + + if (nv.certainty() != null) { + nearVector.setCertainty(nv.certainty()); + } else if (nv.distance() != null) { + nearVector.setDistance(nv.distance()); + } + + req.setNearVector(nearVector); + + // Base query options + if (nv.common().limit() != null) { + req.setLimit(nv.common().limit()); + } + return this; + } public WeaviateProtoAggregate.AggregateRequest marshal(AggregateGroupByRequest aggregateGroupBy) { var aggregate = aggregateGroupBy.aggregate(); if (aggregateGroupBy.groupBy() != null) { - addGroupBy(aggregate.collectionName(), aggregateGroupBy.groupBy(), req); + // addGroupBy(aggregate.collectionName(), aggregateGroupBy.groupBy(), req); } return marshal(aggregate); } @@ -32,7 +86,7 @@ public WeaviateProtoAggregate.AggregateRequest marshal(NearVector nearVector, AggregateGroupByRequest aggregateGroupBy) { var aggregate = aggregateGroupBy.aggregate(); if (aggregateGroupBy.groupBy() != null) { - addGroupBy(aggregate.collectionName(), aggregateGroupBy.groupBy(), req); + // addGroupBy(aggregate.collectionName(), aggregateGroupBy.groupBy(), req); } return marshal(nearVector, aggregate); } @@ -97,13 +151,6 @@ private void addMetric(Metric metric) { req.addAggregations(aggregation); } - private void addGroupBy(String collectionName, GroupBy groupBy, WeaviateProtoAggregate.AggregateRequest.Builder req) { - var by = WeaviateProtoAggregate.AggregateRequest.GroupBy.newBuilder(); - by.setCollection(collectionName); - by.setProperty(groupBy.property()); - req.setGroupBy(by); - } - @SuppressWarnings("unchecked") static final void set(Enum fn, B builder) { if (metrics.containsKey(fn)) { diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java index 5d5eb4044..02e81872d 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java @@ -21,7 +21,9 @@ public WeaviateAggregate(String collectionName, GrpcClient grpc) { public AggregateResponse overAll(Consumer options) { var aggregation = AggregateRequest.with(collectionName, options); - var req = new AggregateMarshaler().marshal(aggregation); + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .marshal(); var reply = grpcClient.grpc.aggregate(req); Long totalCount = null; @@ -62,7 +64,10 @@ public AggregateGroupByResponse overAll(Consumer> groups = new ArrayList<>(); @@ -110,12 +115,17 @@ public AggregateGroupByResponse overAll(Consumer nearVectorOptions, + public AggregateResponse nearVector( + Float[] vector, + Consumer nearVectorOptions, Consumer options) { var aggregation = AggregateRequest.with(collectionName, options); var nearVector = NearVector.with(vector, nearVectorOptions); - var req = new AggregateMarshaler().marshal(nearVector, aggregation); + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .addNearVector(nearVector) + .marshal(); var reply = grpcClient.grpc.aggregate(req); Long totalCount = null; @@ -150,4 +160,65 @@ public AggregateResponse nearVector(Float[] vector, Consumer var result = new AggregateResponse(properties, totalCount); return result; } + + public AggregateGroupByResponse nearVector( + Float[] vector, + Consumer nearVectorOptions, + Consumer groupByOptions, + Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + var nearVector = NearVector.with(vector, nearVectorOptions); + var groupBy = AggregateGroupByRequest.GroupBy.with(groupByOptions); + + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .addGroupBy(groupBy) + .addNearVector(nearVector) + .marshal(); + var reply = grpcClient.grpc.aggregate(req); + + List> groups = new ArrayList<>(); + if (reply.hasGroupedResults()) { + for (var result : reply.getGroupedResults().getGroupsList()) { + final Long totalCount = result.hasObjectsCount() ? result.getObjectsCount() : null; + + GroupedBy groupedBy = null; + var gb = result.getGroupedBy(); + if (gb.hasInt()) { + groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getInt()); + } else if (gb.hasText()) { + groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getText()); + } else { + assert false : "branch not covered"; + } + + Map properties = new HashMap<>(); + for (var agg : result.getAggregations().getAggregationsList()) { + var property = agg.getProperty(); + Metric.Values value = null; + + if (agg.hasInt()) { + var metrics = agg.getInt(); + value = new IntegerMetric.Values( + metrics.hasCount() ? metrics.getCount() : null, + metrics.hasMinimum() ? metrics.getMinimum() : null, + metrics.hasMaximum() ? metrics.getMaximum() : null, + metrics.hasMean() ? metrics.getMean() : null, + metrics.hasMedian() ? metrics.getMedian() : null, + metrics.hasMode() ? metrics.getMode() : null, + metrics.hasSum() ? metrics.getSum() : null); + } else { + assert false : "branch not covered"; + } + if (value != null) { + properties.put(property, value); + } + } + Group group = new Group<>(groupedBy, properties, totalCount); + groups.add(group); + + } + } + return new AggregateGroupByResponse(groups); + } } From 9ec90b69810ef1222ce47cdd412e9276bf8e81cd Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 25 Mar 2025 14:28:27 +0100 Subject: [PATCH 6/7] refactor: move unmarshaling code to internal/codec --- .../integration/AggregationITest.java | 5 +- .../codec/grpc/v1/AggregateMarshaler.java | 55 ------ .../codec/grpc/v1/AggregateUnmarshaler.java | 102 ++++++++++ .../aggregate/WeaviateAggregate.java | 181 +++--------------- 4 files changed, 128 insertions(+), 215 deletions(-) create mode 100644 src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateUnmarshaler.java diff --git a/src/it/java/io/weaviate/integration/AggregationITest.java b/src/it/java/io/weaviate/integration/AggregationITest.java index f5ae4665b..bd54ed865 100644 --- a/src/it/java/io/weaviate/integration/AggregationITest.java +++ b/src/it/java/io/weaviate/integration/AggregationITest.java @@ -17,6 +17,7 @@ import io.weaviate.client6.v1.collections.VectorIndex; import io.weaviate.client6.v1.collections.Vectorizer; import io.weaviate.client6.v1.collections.Vectors; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest.GroupBy; import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResponse; import io.weaviate.client6.v1.collections.aggregate.Group; import io.weaviate.client6.v1.collections.aggregate.GroupedBy; @@ -74,7 +75,7 @@ public void testOverAll() { public void testOverAll_groupBy_category() { var things = client.collections.use(COLLECTION); var result = things.aggregate.overAll( - groupBy -> groupBy.property("category"), + new GroupBy("category"), with -> with.metrics( Metric.integer("price", calculate -> calculate .min().max().count())) @@ -130,7 +131,7 @@ public void testNearVector_groupBy_category() { var result = things.aggregate.nearVector( randomVector(10, -1f, 1f), near -> near.distance(2f), - groupBy -> groupBy.property("category"), + new GroupBy("category"), with -> with.metrics( Metric.integer("price", calculate -> calculate .min().max().median())) diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java index 89ac0a9d9..446adba78 100644 --- a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateMarshaler.java @@ -74,61 +74,6 @@ public AggregateMarshaler addNearVector(NearVector nv) { return this; } - public WeaviateProtoAggregate.AggregateRequest marshal(AggregateGroupByRequest aggregateGroupBy) { - var aggregate = aggregateGroupBy.aggregate(); - if (aggregateGroupBy.groupBy() != null) { - // addGroupBy(aggregate.collectionName(), aggregateGroupBy.groupBy(), req); - } - return marshal(aggregate); - } - - public WeaviateProtoAggregate.AggregateRequest marshal(NearVector nearVector, - AggregateGroupByRequest aggregateGroupBy) { - var aggregate = aggregateGroupBy.aggregate(); - if (aggregateGroupBy.groupBy() != null) { - // addGroupBy(aggregate.collectionName(), aggregateGroupBy.groupBy(), req); - } - return marshal(nearVector, aggregate); - } - - public WeaviateProtoAggregate.AggregateRequest marshal(NearVector nearVector, AggregateRequest aggregate) { - req.setNearVector(buildNearVector(nearVector)); - if (nearVector.common().limit() != null) { - req.setLimit(nearVector.common().limit()); - } - return marshal(aggregate); - } - - private WeaviateProtoBaseSearch.NearVector buildNearVector(NearVector nv) { - var nearVector = WeaviateProtoBaseSearch.NearVector.newBuilder(); - nearVector.setVectorBytes(GRPC.toByteString(nv.vector())); - - if (nv.certainty() != null) { - nearVector.setCertainty(nv.certainty()); - } else if (nv.distance() != null) { - nearVector.setDistance(nv.distance()); - } - return nearVector.build(); - } - - public WeaviateProtoAggregate.AggregateRequest marshal(AggregateRequest aggregate) { - req.setCollection(aggregate.collectionName()); - - if (aggregate.includeTotalCount()) { - req.setObjectsCount(true); - } - - if (aggregate.objectLimit() != null) { - req.setObjectLimit(aggregate.objectLimit()); - } - - for (Metric metric : aggregate.returnMetrics()) { - addMetric(metric); - } - - return req.build(); - } - private void addMetric(Metric metric) { var aggregation = Aggregation.newBuilder(); aggregation.setProperty(metric.property()); diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateUnmarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateUnmarshaler.java new file mode 100644 index 000000000..c26c174be --- /dev/null +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/AggregateUnmarshaler.java @@ -0,0 +1,102 @@ +package io.weaviate.client6.internal.codec.grpc.v1; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import io.weaviate.client6.grpc.protocol.v1.WeaviateProtoAggregate; +import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResponse; +import io.weaviate.client6.v1.collections.aggregate.AggregateResponse; +import io.weaviate.client6.v1.collections.aggregate.Group; +import io.weaviate.client6.v1.collections.aggregate.GroupedBy; +import io.weaviate.client6.v1.collections.aggregate.IntegerMetric; +import io.weaviate.client6.v1.collections.aggregate.Metric; + +public final class AggregateUnmarshaler { + private final WeaviateProtoAggregate.AggregateReply reply; + + public AggregateUnmarshaler(WeaviateProtoAggregate.AggregateReply reply) { + this.reply = reply; + } + + public AggregateResponse single() { + Long totalCount = null; + Map properties = new HashMap<>(); + + if (reply.hasSingleResult()) { + var single = reply.getSingleResult(); + totalCount = single.hasObjectsCount() ? single.getObjectsCount() : null; + var aggregations = single.getAggregations().getAggregationsList(); + for (var agg : aggregations) { + var property = agg.getProperty(); + Metric.Values value = null; + + if (agg.hasInt()) { + var metrics = agg.getInt(); + value = new IntegerMetric.Values( + metrics.hasCount() ? metrics.getCount() : null, + metrics.hasMinimum() ? metrics.getMinimum() : null, + metrics.hasMaximum() ? metrics.getMaximum() : null, + metrics.hasMean() ? metrics.getMean() : null, + metrics.hasMedian() ? metrics.getMedian() : null, + metrics.hasMode() ? metrics.getMode() : null, + metrics.hasSum() ? metrics.getSum() : null); + } else { + assert false : "branch not covered"; + } + if (value != null) { + properties.put(property, value); + } + } + } + var result = new AggregateResponse(properties, totalCount); + return result; + } + + public AggregateGroupByResponse grouped() { + List> groups = new ArrayList<>(); + if (reply.hasGroupedResults()) { + for (var result : reply.getGroupedResults().getGroupsList()) { + final Long totalCount = result.hasObjectsCount() ? result.getObjectsCount() : null; + + GroupedBy groupedBy = null; + var gb = result.getGroupedBy(); + if (gb.hasInt()) { + groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getInt()); + } else if (gb.hasText()) { + groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getText()); + } else { + assert false : "branch not covered"; + } + + Map properties = new HashMap<>(); + for (var agg : result.getAggregations().getAggregationsList()) { + var property = agg.getProperty(); + Metric.Values value = null; + + if (agg.hasInt()) { + var metrics = agg.getInt(); + value = new IntegerMetric.Values( + metrics.hasCount() ? metrics.getCount() : null, + metrics.hasMinimum() ? metrics.getMinimum() : null, + metrics.hasMaximum() ? metrics.getMaximum() : null, + metrics.hasMean() ? metrics.getMean() : null, + metrics.hasMedian() ? metrics.getMedian() : null, + metrics.hasMode() ? metrics.getMode() : null, + metrics.hasSum() ? metrics.getSum() : null); + } else { + assert false : "branch not covered"; + } + if (value != null) { + properties.put(property, value); + } + } + Group group = new Group<>(groupedBy, properties, totalCount); + groups.add(group); + + } + } + return new AggregateGroupByResponse(groups); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java index 02e81872d..73474fb7b 100644 --- a/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java +++ b/src/main/java/io/weaviate/client6/v1/collections/aggregate/WeaviateAggregate.java @@ -1,13 +1,10 @@ package io.weaviate.client6.v1.collections.aggregate; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import java.util.function.Consumer; import io.weaviate.client6.internal.GrpcClient; import io.weaviate.client6.internal.codec.grpc.v1.AggregateMarshaler; +import io.weaviate.client6.internal.codec.grpc.v1.AggregateUnmarshaler; import io.weaviate.client6.v1.query.NearVector; public class WeaviateAggregate { @@ -25,94 +22,20 @@ public AggregateResponse overAll(Consumer options) { .addAggregation(aggregation) .marshal(); var reply = grpcClient.grpc.aggregate(req); - - Long totalCount = null; - Map properties = new HashMap<>(); - - if (reply.hasSingleResult()) { - var single = reply.getSingleResult(); - totalCount = single.hasObjectsCount() ? single.getObjectsCount() : null; - var aggregations = single.getAggregations().getAggregationsList(); - for (var agg : aggregations) { - var property = agg.getProperty(); - Metric.Values value = null; - - if (agg.hasInt()) { - var metrics = agg.getInt(); - value = new IntegerMetric.Values( - metrics.hasCount() ? metrics.getCount() : null, - metrics.hasMinimum() ? metrics.getMinimum() : null, - metrics.hasMaximum() ? metrics.getMaximum() : null, - metrics.hasMean() ? metrics.getMean() : null, - metrics.hasMedian() ? metrics.getMedian() : null, - metrics.hasMode() ? metrics.getMode() : null, - metrics.hasSum() ? metrics.getSum() : null); - } else { - assert false : "branch not covered"; - } - if (value != null) { - properties.put(property, value); - } - } - } - var result = new AggregateResponse(properties, totalCount); - return result; + return new AggregateUnmarshaler(reply).single(); } - public AggregateGroupByResponse overAll(Consumer groupByOptions, + public AggregateGroupByResponse overAll( + AggregateGroupByRequest.GroupBy groupBy, Consumer options) { var aggregation = AggregateRequest.with(collectionName, options); - var groupBy = AggregateGroupByRequest.GroupBy.with(groupByOptions); var req = new AggregateMarshaler(aggregation.collectionName()) .addAggregation(aggregation) .addGroupBy(groupBy) .marshal(); var reply = grpcClient.grpc.aggregate(req); - - List> groups = new ArrayList<>(); - if (reply.hasGroupedResults()) { - for (var result : reply.getGroupedResults().getGroupsList()) { - final Long totalCount = result.hasObjectsCount() ? result.getObjectsCount() : null; - - GroupedBy groupedBy = null; - var gb = result.getGroupedBy(); - if (gb.hasInt()) { - groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getInt()); - } else if (gb.hasText()) { - groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getText()); - } else { - assert false : "branch not covered"; - } - - Map properties = new HashMap<>(); - for (var agg : result.getAggregations().getAggregationsList()) { - var property = agg.getProperty(); - Metric.Values value = null; - - if (agg.hasInt()) { - var metrics = agg.getInt(); - value = new IntegerMetric.Values( - metrics.hasCount() ? metrics.getCount() : null, - metrics.hasMinimum() ? metrics.getMinimum() : null, - metrics.hasMaximum() ? metrics.getMaximum() : null, - metrics.hasMean() ? metrics.getMean() : null, - metrics.hasMedian() ? metrics.getMedian() : null, - metrics.hasMode() ? metrics.getMode() : null, - metrics.hasSum() ? metrics.getSum() : null); - } else { - assert false : "branch not covered"; - } - if (value != null) { - properties.put(property, value); - } - } - Group group = new Group<>(groupedBy, properties, totalCount); - groups.add(group); - - } - } - return new AggregateGroupByResponse(groups); + return new AggregateUnmarshaler(reply).grouped(); } public AggregateResponse nearVector( @@ -127,48 +50,33 @@ public AggregateResponse nearVector( .addNearVector(nearVector) .marshal(); var reply = grpcClient.grpc.aggregate(req); + return new AggregateUnmarshaler(reply).single(); + } - Long totalCount = null; - Map properties = new HashMap<>(); - - if (reply.hasSingleResult()) { - var single = reply.getSingleResult(); - totalCount = single.hasObjectsCount() ? single.getObjectsCount() : null; - var aggregations = single.getAggregations().getAggregationsList(); - for (var agg : aggregations) { - var property = agg.getProperty(); - Metric.Values value = null; + public AggregateGroupByResponse nearVector( + Float[] vector, + AggregateGroupByRequest.GroupBy groupBy, + Consumer options) { + var aggregation = AggregateRequest.with(collectionName, options); + var nearVector = NearVector.with(vector, opt -> { + }); - if (agg.hasInt()) { - var metrics = agg.getInt(); - value = new IntegerMetric.Values( - metrics.hasCount() ? metrics.getCount() : null, - metrics.hasMinimum() ? metrics.getMinimum() : null, - metrics.hasMaximum() ? metrics.getMaximum() : null, - metrics.hasMean() ? metrics.getMean() : null, - metrics.hasMedian() ? metrics.getMedian() : null, - metrics.hasMode() ? metrics.getMode() : null, - metrics.hasSum() ? metrics.getSum() : null); - } else { - assert false : "branch not covered"; - } - if (value != null) { - properties.put(property, value); - } - } - } - var result = new AggregateResponse(properties, totalCount); - return result; + var req = new AggregateMarshaler(aggregation.collectionName()) + .addAggregation(aggregation) + .addGroupBy(groupBy) + .addNearVector(nearVector) + .marshal(); + var reply = grpcClient.grpc.aggregate(req); + return new AggregateUnmarshaler(reply).grouped(); } public AggregateGroupByResponse nearVector( Float[] vector, Consumer nearVectorOptions, - Consumer groupByOptions, + AggregateGroupByRequest.GroupBy groupBy, Consumer options) { var aggregation = AggregateRequest.with(collectionName, options); var nearVector = NearVector.with(vector, nearVectorOptions); - var groupBy = AggregateGroupByRequest.GroupBy.with(groupByOptions); var req = new AggregateMarshaler(aggregation.collectionName()) .addAggregation(aggregation) @@ -176,49 +84,6 @@ public AggregateGroupByResponse nearVector( .addNearVector(nearVector) .marshal(); var reply = grpcClient.grpc.aggregate(req); - - List> groups = new ArrayList<>(); - if (reply.hasGroupedResults()) { - for (var result : reply.getGroupedResults().getGroupsList()) { - final Long totalCount = result.hasObjectsCount() ? result.getObjectsCount() : null; - - GroupedBy groupedBy = null; - var gb = result.getGroupedBy(); - if (gb.hasInt()) { - groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getInt()); - } else if (gb.hasText()) { - groupedBy = new GroupedBy(gb.getPathList().get(0), gb.getText()); - } else { - assert false : "branch not covered"; - } - - Map properties = new HashMap<>(); - for (var agg : result.getAggregations().getAggregationsList()) { - var property = agg.getProperty(); - Metric.Values value = null; - - if (agg.hasInt()) { - var metrics = agg.getInt(); - value = new IntegerMetric.Values( - metrics.hasCount() ? metrics.getCount() : null, - metrics.hasMinimum() ? metrics.getMinimum() : null, - metrics.hasMaximum() ? metrics.getMaximum() : null, - metrics.hasMean() ? metrics.getMean() : null, - metrics.hasMedian() ? metrics.getMedian() : null, - metrics.hasMode() ? metrics.getMode() : null, - metrics.hasSum() ? metrics.getSum() : null); - } else { - assert false : "branch not covered"; - } - if (value != null) { - properties.put(property, value); - } - } - Group group = new Group<>(groupedBy, properties, totalCount); - groups.add(group); - - } - } - return new AggregateGroupByResponse(groups); + return new AggregateUnmarshaler(reply).grouped(); } } From c76c726ae7a5c3467fab17a7da233bc26598ccd2 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 25 Mar 2025 19:34:17 +0100 Subject: [PATCH 7/7] feat: add GroupBy to NearVector search --- .../NearVectorQueryITest.java | 44 +++++++++++++--- .../codec/grpc/v1/SearchMarshaler.java | 9 ++++ .../client6/v1/query/GroupedQueryResult.java | 26 ++++++++++ .../weaviate/client6/v1/query/NearVector.java | 3 ++ .../io/weaviate/client6/v1/query/Query.java | 52 +++++++++++++++++++ .../client6/v1/query/QueryResult.java | 8 +-- 6 files changed, 129 insertions(+), 13 deletions(-) rename src/it/java/io/weaviate/{client6/v1/query => integration}/NearVectorQueryITest.java (59%) create mode 100644 src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java diff --git a/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java b/src/it/java/io/weaviate/integration/NearVectorQueryITest.java similarity index 59% rename from src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java rename to src/it/java/io/weaviate/integration/NearVectorQueryITest.java index 0b8693b75..66258810d 100644 --- a/src/it/java/io/weaviate/client6/v1/query/NearVectorQueryITest.java +++ b/src/it/java/io/weaviate/integration/NearVectorQueryITest.java @@ -1,9 +1,10 @@ -package io.weaviate.client6.v1.query; +package io.weaviate.integration; import java.io.IOException; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.assertj.core.api.Assertions; @@ -13,9 +14,13 @@ import io.weaviate.ConcurrentTest; import io.weaviate.client6.WeaviateClient; import io.weaviate.client6.v1.Vectors; +import io.weaviate.client6.v1.collections.Property; import io.weaviate.client6.v1.collections.VectorIndex; import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy; import io.weaviate.client6.v1.collections.Vectorizer; +import io.weaviate.client6.v1.query.GroupedQueryResult; +import io.weaviate.client6.v1.query.MetadataField; +import io.weaviate.client6.v1.query.NearVector; import io.weaviate.containers.Container; public class NearVectorQueryITest extends ConcurrentTest { @@ -23,6 +28,7 @@ public class NearVectorQueryITest extends ConcurrentTest { private static final String COLLECTION = unique("Things"); private static final String VECTOR_INDEX = "bring_your_own"; + private static final List CATEGORIES = List.of("red", "green"); /** * One of the inserted vectors which will be used as target vector for search. @@ -32,7 +38,7 @@ public class NearVectorQueryITest extends ConcurrentTest { @BeforeClass public static void beforeAll() throws IOException { createTestCollection(); - var created = createVectors(10); + var created = populateTest(10); searchVector = created.values().iterator().next(); } @@ -41,7 +47,7 @@ public void testNearVector() { // TODO: test that we return the results in the expected order // Because re-ranking should work correctly var things = client.collections.use(COLLECTION); - QueryResult> result = things.query.nearVector(searchVector, + var result = things.query.nearVector(searchVector, opt -> opt .distance(2f) .limit(3) @@ -49,23 +55,48 @@ public void testNearVector() { Assertions.assertThat(result.objects).hasSize(3); float maxDistance = Collections.max(result.objects, - Comparator.comparing(obj -> obj.metadata.distance)).metadata.distance; + Comparator.comparing(obj -> obj.metadata.distance())).metadata.distance(); Assertions.assertThat(maxDistance).isLessThanOrEqualTo(2f); } + @Test + public void testNearVector_groupBy() { + // TODO: test that we return the results in the expected order + // Because re-ranking should work correctly + var things = client.collections.use(COLLECTION); + var result = things.query.nearVector(searchVector, + new NearVector.GroupBy("category", 2, 5), + opt -> opt.distance(10f)); + + Assertions.assertThat(result.groups) + .as("group per category").containsOnlyKeys(CATEGORIES) + .hasSizeLessThanOrEqualTo(2) + .allSatisfy((category, group) -> { + Assertions.assertThat(group) + .as("group name").returns(category, GroupedQueryResult.Group::name); + Assertions.assertThat(group.numberOfObjects()) + .as("[%s] has 1+ object", category).isLessThanOrEqualTo(5L); + }); + + Assertions.assertThat(result.objects) + .as("object belongs a group") + .allMatch(obj -> result.groups.get(obj.belongsToGroup).objects().contains(obj)); + + } + /** * Insert 10 objects with random vectors. * * @returns IDs of inserted objects and their corresponding vectors. */ - private static Map createVectors(int n) throws IOException { + private static Map populateTest(int n) throws IOException { var created = new HashMap(); var things = client.collections.use(COLLECTION); for (int i = 0; i < n; i++) { var vector = randomVector(10, -.01f, .001f); var object = things.data.insert( - Map.of(), + Map.of("category", CATEGORIES.get(i % CATEGORIES.size())), metadata -> metadata .id(randomUUID()) .vectors(Vectors.of(VECTOR_INDEX, vector))); @@ -83,6 +114,7 @@ private static Map createVectors(int n) throws IOException { */ private static void createTestCollection() throws IOException { client.collections.create(COLLECTION, cfg -> cfg + .properties(Property.text("category")) .vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none()))); } } diff --git a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java index 801358726..a85970bb1 100644 --- a/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java +++ b/src/main/java/io/weaviate/client6/internal/codec/grpc/v1/SearchMarshaler.java @@ -22,6 +22,15 @@ public SearchMarshaler(String collectionName) { req.setUses127Api(true); } + public SearchMarshaler addGroupBy(NearVector.GroupBy gb) { + var groupBy = WeaviateProtoSearchGet.GroupBy.newBuilder(); + groupBy.addPath(gb.property()); + groupBy.setNumberOfGroups(gb.maxGroups()); + groupBy.setObjectsPerGroup(gb.maxObjectsPerGroup()); + req.setGroupBy(groupBy); + return this; + } + public SearchMarshaler addNearVector(NearVector nv) { setCommon(nv.common()); diff --git a/src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java b/src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java new file mode 100644 index 000000000..01b8e68a4 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/query/GroupedQueryResult.java @@ -0,0 +1,26 @@ +package io.weaviate.client6.v1.query; + +import java.util.List; +import java.util.Map; + +import io.weaviate.client6.v1.query.QueryResult.SearchObject; +import lombok.AllArgsConstructor; + +@AllArgsConstructor +public class GroupedQueryResult { + public final List> objects; + public final Map> groups; + + public static class WithGroupSearchObject extends SearchObject { + public final String belongsToGroup; + + public WithGroupSearchObject(String group, T properties, QueryMetadata metadata) { + super(properties, metadata); + this.belongsToGroup = group; + } + } + + public record Group(String name, Float minDistance, Float maxDistance, long numberOfObjects, + List> objects) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/query/NearVector.java b/src/main/java/io/weaviate/client6/v1/query/NearVector.java index acb76279b..6cfee7f8f 100644 --- a/src/main/java/io/weaviate/client6/v1/query/NearVector.java +++ b/src/main/java/io/weaviate/client6/v1/query/NearVector.java @@ -24,4 +24,7 @@ public Builder certainty(float certainty) { return this; } } + + public static record GroupBy(String property, int maxGroups, int maxObjectsPerGroup) { + } } diff --git a/src/main/java/io/weaviate/client6/v1/query/Query.java b/src/main/java/io/weaviate/client6/v1/query/Query.java index 5de02b452..673ed1f48 100644 --- a/src/main/java/io/weaviate/client6/v1/query/Query.java +++ b/src/main/java/io/weaviate/client6/v1/query/Query.java @@ -1,6 +1,7 @@ package io.weaviate.client6.v1.query; import java.time.OffsetDateTime; +import java.util.ArrayList; import java.util.Date; import java.util.List; import java.util.Map; @@ -36,11 +37,32 @@ public QueryResult nearVector(Float[] vector, Consumer op return search(req.marshal()); } + public GroupedQueryResult nearVector(Float[] vector, NearVector.GroupBy groupBy, + Consumer options) { + var query = NearVector.with(vector, options); + var req = new SearchMarshaler(collectionName).addNearVector(query) + .addGroupBy(groupBy); + return searchGrouped(req.marshal()); + } + + public GroupedQueryResult nearVector(Float[] vector, NearVector.GroupBy groupBy) { + var query = NearVector.with(vector, opt -> { + }); + var req = new SearchMarshaler(collectionName).addNearVector(query) + .addGroupBy(groupBy); + return searchGrouped(req.marshal()); + } + private QueryResult search(SearchRequest req) { var reply = grpcClient.grpc.search(req); return deserializeUntyped(reply); } + private GroupedQueryResult searchGrouped(SearchRequest req) { + var reply = grpcClient.grpc.search(req); + return deserializeUntypedGrouped(reply); + } + public QueryResult deserializeUntyped(SearchReply reply) { List> objects = reply.getResultsList().stream() .map(res -> { @@ -59,6 +81,36 @@ public QueryResult deserializeUntyped(SearchReply reply) { return new QueryResult(objects); } + public GroupedQueryResult deserializeUntypedGrouped(SearchReply reply) { + var allObjects = new ArrayList>(); + Map> allGroups = reply.getGroupByResultsList() + .stream().map(g -> { + var groupName = g.getName(); + var groupObjects = g.getObjectsList().stream().map(res -> { + Map properties = convertProtoMap(res.getProperties().getNonRefProps().getFieldsMap()); + + MetadataResult meta = res.getMetadata(); + var metadata = new QueryResult.SearchObject.QueryMetadata( + meta.getId(), + meta.getDistancePresent() ? meta.getDistance() : null, + GRPC.fromByteString(meta.getVectorBytes())); + var obj = new GroupedQueryResult.WithGroupSearchObject(groupName, (T) properties, metadata); + + allObjects.add(obj); + + return obj; + }).toList(); + + return new GroupedQueryResult.Group<>( + groupName, + g.getMinDistance(), + g.getMaxDistance(), + g.getNumberOfObjects(), + groupObjects); + }).collect(Collectors.toMap(GroupedQueryResult.Group::name, g -> g)); + return new GroupedQueryResult<>(allObjects, allGroups); + } + /** * Convert Map to Map such that can be * (de-)serialized by {@link Gson}. diff --git a/src/main/java/io/weaviate/client6/v1/query/QueryResult.java b/src/main/java/io/weaviate/client6/v1/query/QueryResult.java index 24b0a91e2..3d03a9840 100644 --- a/src/main/java/io/weaviate/client6/v1/query/QueryResult.java +++ b/src/main/java/io/weaviate/client6/v1/query/QueryResult.java @@ -3,7 +3,6 @@ import java.util.List; import lombok.AllArgsConstructor; -import lombok.ToString; @AllArgsConstructor public class QueryResult { @@ -14,13 +13,8 @@ public static class SearchObject { public final T properties; public final QueryMetadata metadata; - @AllArgsConstructor - @ToString - public static class QueryMetadata { - String id; - Float distance; + public record QueryMetadata(String id, Float distance, Float[] vector) { // TODO: use Vectors (to handle both Float[] and Float[][]) - Float[] vector; } } }