Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions src/it/java/io/weaviate/integration/AggregationITest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package io.weaviate.integration;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

import org.assertj.core.api.Assertions;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.BeforeClass;
import org.junit.Test;

import io.weaviate.ConcurrentTest;
import io.weaviate.client6.WeaviateClient;
import io.weaviate.client6.v1.collections.Property;
import io.weaviate.client6.v1.collections.VectorIndex;
import io.weaviate.client6.v1.collections.Vectorizer;
import io.weaviate.client6.v1.collections.Vectors;
import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest.GroupBy;
import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResponse;
import io.weaviate.client6.v1.collections.aggregate.Group;
import io.weaviate.client6.v1.collections.aggregate.GroupedBy;
import io.weaviate.client6.v1.collections.aggregate.IntegerMetric;
import io.weaviate.client6.v1.collections.aggregate.Metric;
import io.weaviate.containers.Container;

public class AggregationITest extends ConcurrentTest {
private static WeaviateClient client = Container.WEAVIATE.getClient();
private static final String COLLECTION = unique("Things");

@BeforeClass
public static void beforeAll() throws IOException {
client.collections.create(COLLECTION,
collection -> collection
.properties(
Property.text("category"),
Property.integer("price"))
.vectors(Vectors.of(new VectorIndex<>(Vectorizer.none()))));

var things = client.collections.use(COLLECTION);
for (var category : List.of("Shoes", "Hat", "Jacket")) {
for (var i = 0; i < 5; i++) {
var vector = randomVector(10, -.1f, .1f);
// For simplicity, the "price" for each items equals to the
// number of characters in the name of the category.
things.data.insert(Map.of(
"category", category,
"price", category.length()),
meta -> meta.vectors(vector));
}
}
}

@Test
public void testOverAll() {
var things = client.collections.use(COLLECTION);
var result = things.aggregate.overAll(
with -> with.metrics(
Metric.integer("price", calculate -> calculate
.median().max().count()))
.includeTotalCount());

Assertions.assertThat(result)
.as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 15L)
.as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price"))
.as("aggregated prices").extracting(p -> p.getInteger("price"))
.as("min").returns(null, IntegerMetric.Values::min)
.as("max").returns(6L, IntegerMetric.Values::max)
.as("median").returns(5D, IntegerMetric.Values::median)
.as("count").returns(15L, IntegerMetric.Values::count);
}

@Test
public void testOverAll_groupBy_category() {
var things = client.collections.use(COLLECTION);
var result = things.aggregate.overAll(
new GroupBy("category"),
with -> with.metrics(
Metric.integer("price", calculate -> calculate
.min().max().count()))
.includeTotalCount());

Assertions.assertThat(result)
.extracting(AggregateGroupByResponse::groups)
.asInstanceOf(InstanceOfAssertFactories.list(Group.class))
.as("group per category").hasSize(3)
.allSatisfy(group -> {
Assertions.assertThat(group)
.extracting(Group::by)
.as(group.by().property() + " is Text property").returns(true, GroupedBy::isText);

String category = group.by().getAsText();
var expectedPrice = (long) category.length();

Function<String, Supplier<String>> desc = (String metric) -> {
return () -> "%s ('%s'.length)".formatted(metric, category);
};

Assertions.assertThat(group)
.as("'price' is IntegerMetric").returns(true, g -> g.isIntegerProperty("price"))
.as("aggregated prices").extracting(g -> g.getInteger("price"))
.as(desc.apply("max")).returns(expectedPrice, IntegerMetric.Values::max)
.as(desc.apply("min")).returns(expectedPrice, IntegerMetric.Values::min)
.as(desc.apply("count")).returns(5L, IntegerMetric.Values::count);
});
}

@Test
public void testNearVector() {
var things = client.collections.use(COLLECTION);
var result = things.aggregate.nearVector(
randomVector(10, -1f, 1f),
near -> near.limit(5),
with -> with.metrics(
Metric.integer("price", calculate -> calculate
.min().max().count()))
.objectLimit(4)
.includeTotalCount());

Assertions.assertThat(result)
.as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 4L)
.as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price"))
.as("aggregated prices").extracting(p -> p.getInteger("price"))
.as("count").returns(4L, IntegerMetric.Values::count);
}

@Test
public void testNearVector_groupBy_category() {
var things = client.collections.use(COLLECTION);
var result = things.aggregate.nearVector(
randomVector(10, -1f, 1f),
near -> near.distance(2f),
new GroupBy("category"),
with -> with.metrics(
Metric.integer("price", calculate -> calculate
.min().max().median()))
.objectLimit(9)
.includeTotalCount());

Assertions.assertThat(result)
.extracting(AggregateGroupByResponse::groups)
.asInstanceOf(InstanceOfAssertFactories.list(Group.class))
.as("group per category").hasSize(3)
.allSatisfy(group -> {
Assertions.assertThat(group)
.extracting(Group::by)
.as(group.by().property() + " is Text property").returns(true, GroupedBy::isText);

String category = group.by().getAsText();
var expectedPrice = (long) category.length();

Function<String, Supplier<String>> desc = (String metric) -> {
return () -> "%s ('%s'.length)".formatted(metric, category);
};

Assertions.assertThat(group)
.as("'price' is IntegerMetric").returns(true, g -> g.isIntegerProperty("price"))
.as("aggregated prices").extracting(g -> g.getInteger("price"))
.as(desc.apply("max")).returns(expectedPrice, IntegerMetric.Values::max)
.as(desc.apply("min")).returns(expectedPrice, IntegerMetric.Values::min)
.as(desc.apply("median")).returns((double) expectedPrice, IntegerMetric.Values::median);
});
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package io.weaviate.client6.v1.query;
package io.weaviate.integration;

import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.assertj.core.api.Assertions;
Expand All @@ -13,16 +14,21 @@
import io.weaviate.ConcurrentTest;
import io.weaviate.client6.WeaviateClient;
import io.weaviate.client6.v1.Vectors;
import io.weaviate.client6.v1.collections.Property;
import io.weaviate.client6.v1.collections.VectorIndex;
import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy;
import io.weaviate.client6.v1.collections.Vectorizer;
import io.weaviate.client6.v1.query.GroupedQueryResult;
import io.weaviate.client6.v1.query.MetadataField;
import io.weaviate.client6.v1.query.NearVector;
import io.weaviate.containers.Container;

public class NearVectorQueryITest extends ConcurrentTest {
private static final WeaviateClient client = Container.WEAVIATE.getClient();

private static final String COLLECTION = unique("Things");
private static final String VECTOR_INDEX = "bring_your_own";
private static final List<String> CATEGORIES = List.of("red", "green");

/**
* One of the inserted vectors which will be used as target vector for search.
Expand All @@ -32,7 +38,7 @@ public class NearVectorQueryITest extends ConcurrentTest {
@BeforeClass
public static void beforeAll() throws IOException {
createTestCollection();
var created = createVectors(10);
var created = populateTest(10);
searchVector = created.values().iterator().next();
}

Expand All @@ -41,31 +47,56 @@ public void testNearVector() {
// TODO: test that we return the results in the expected order
// Because re-ranking should work correctly
var things = client.collections.use(COLLECTION);
QueryResult<Map<String, Object>> result = things.query.nearVector(searchVector,
var result = things.query.nearVector(searchVector,
opt -> opt
.distance(2f)
.limit(3)
.returnMetadata(MetadataField.DISTANCE));

Assertions.assertThat(result.objects).hasSize(3);
float maxDistance = Collections.max(result.objects,
Comparator.comparing(obj -> obj.metadata.distance)).metadata.distance;
Comparator.comparing(obj -> obj.metadata.distance())).metadata.distance();
Assertions.assertThat(maxDistance).isLessThanOrEqualTo(2f);
}

@Test
public void testNearVector_groupBy() {
// TODO: test that we return the results in the expected order
// Because re-ranking should work correctly
var things = client.collections.use(COLLECTION);
var result = things.query.nearVector(searchVector,
new NearVector.GroupBy("category", 2, 5),
opt -> opt.distance(10f));

Assertions.assertThat(result.groups)
.as("group per category").containsOnlyKeys(CATEGORIES)
.hasSizeLessThanOrEqualTo(2)
.allSatisfy((category, group) -> {
Assertions.assertThat(group)
.as("group name").returns(category, GroupedQueryResult.Group::name);
Assertions.assertThat(group.numberOfObjects())
.as("[%s] has 1+ object", category).isLessThanOrEqualTo(5L);
});

Assertions.assertThat(result.objects)
.as("object belongs a group")
.allMatch(obj -> result.groups.get(obj.belongsToGroup).objects().contains(obj));

}

/**
* Insert 10 objects with random vectors.
*
* @returns IDs of inserted objects and their corresponding vectors.
*/
private static Map<String, Float[]> createVectors(int n) throws IOException {
private static Map<String, Float[]> populateTest(int n) throws IOException {
var created = new HashMap<String, Float[]>();

var things = client.collections.use(COLLECTION);
for (int i = 0; i < n; i++) {
var vector = randomVector(10, -.01f, .001f);
var object = things.data.insert(
Map.of(),
Map.of("category", CATEGORIES.get(i % CATEGORIES.size())),
metadata -> metadata
.id(randomUUID())
.vectors(Vectors.of(VECTOR_INDEX, vector)));
Expand All @@ -83,6 +114,7 @@ private static Map<String, Float[]> createVectors(int n) throws IOException {
*/
private static void createTestCollection() throws IOException {
client.collections.create(COLLECTION, cfg -> cfg
.properties(Property.text("category"))
.vector(VECTOR_INDEX, new VectorIndex<>(IndexingStrategy.hnsw(), Vectorizer.none())));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.weaviate.client6.internal.codec.grpc;

public interface GrpcMarshaler<R> {
R marshal();
}
Loading