Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 58 additions & 54 deletions src/it/java/io/weaviate/integration/AggregationITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@

import io.weaviate.ConcurrentTest;
import io.weaviate.client6.WeaviateClient;
import io.weaviate.client6.v1.api.collections.Vectors;
import io.weaviate.client6.v1.api.collections.aggregate.AggregateResponseGroup;
import io.weaviate.client6.v1.api.collections.aggregate.AggregateResponseGrouped;
import io.weaviate.client6.v1.api.collections.aggregate.Aggregation;
import io.weaviate.client6.v1.api.collections.aggregate.GroupBy;
import io.weaviate.client6.v1.api.collections.aggregate.GroupedBy;
import io.weaviate.client6.v1.api.collections.aggregate.IntegerAggregation;
import io.weaviate.client6.v1.collections.Property;
import io.weaviate.client6.v1.collections.VectorIndex;
import io.weaviate.client6.v1.collections.Vectorizer;
import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByRequest.GroupBy;
import io.weaviate.client6.v1.collections.aggregate.AggregateGroupByResponse;
import io.weaviate.client6.v1.collections.aggregate.Group;
import io.weaviate.client6.v1.collections.aggregate.GroupedBy;
import io.weaviate.client6.v1.collections.aggregate.IntegerMetric;
import io.weaviate.client6.v1.collections.aggregate.Metric;
import io.weaviate.client6.v1.collections.object.Vectors;
import io.weaviate.containers.Container;

public class AggregationITest extends ConcurrentTest {
Expand Down Expand Up @@ -56,53 +56,55 @@ public static void beforeAll() throws IOException {
public void testOverAll() {
var things = client.collections.use(COLLECTION);
var result = things.aggregate.overAll(
with -> with.metrics(
Metric.integer("price", calculate -> calculate
.median().max().count()))
.includeTotalCount());
with -> with
.metrics(
Aggregation.integer("price",
calculate -> calculate.median().max().count()))
.includeTotalCount(true));

Assertions.assertThat(result)
.as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 15L)
.as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price"))
.as("aggregated prices").extracting(p -> p.getInteger("price"))
.as("min").returns(null, IntegerMetric.Values::min)
.as("max").returns(6L, IntegerMetric.Values::max)
.as("median").returns(5D, IntegerMetric.Values::median)
.as("count").returns(15L, IntegerMetric.Values::count);
.as("'price' is IntegerAggregation").returns(true, p -> p.isInteger("price"))
.as("aggregated prices").extracting(p -> p.integer("price"))
.as("min").returns(null, IntegerAggregation.Values::min)
.as("max").returns(6L, IntegerAggregation.Values::max)
.as("median").returns(5D, IntegerAggregation.Values::median)
.as("count").returns(15L, IntegerAggregation.Values::count);
}

@Test
public void testOverAll_groupBy_category() {
var things = client.collections.use(COLLECTION);
var result = things.aggregate.overAll(
new GroupBy("category"),
with -> with.metrics(
Metric.integer("price", calculate -> calculate
.min().max().count()))
.includeTotalCount());
with -> with
.metrics(
Aggregation.integer("price",
calculate -> calculate.min().max().count()))
.includeTotalCount(true),
new GroupBy("category"));

Assertions.assertThat(result)
.extracting(AggregateGroupByResponse::groups)
.asInstanceOf(InstanceOfAssertFactories.list(Group.class))
.extracting(AggregateResponseGrouped::groups)
.asInstanceOf(InstanceOfAssertFactories.list(AggregateResponseGroup.class))
.as("group per category").hasSize(3)
.allSatisfy(group -> {
Assertions.assertThat(group)
.extracting(Group::by)
.as(group.by().property() + " is Text property").returns(true, GroupedBy::isText);
.extracting(AggregateResponseGroup::groupedBy)
.as(group.groupedBy().property() + " is Text property").returns(true, GroupedBy::isText);

String category = group.by().getAsText();
String category = group.groupedBy().text();
var expectedPrice = (long) category.length();

Function<String, Supplier<String>> desc = (String metric) -> {
return () -> "%s ('%s'.length)".formatted(metric, category);
};

Assertions.assertThat(group)
.as("'price' is IntegerMetric").returns(true, g -> g.isIntegerProperty("price"))
.as("aggregated prices").extracting(g -> g.getInteger("price"))
.as(desc.apply("max")).returns(expectedPrice, IntegerMetric.Values::max)
.as(desc.apply("min")).returns(expectedPrice, IntegerMetric.Values::min)
.as(desc.apply("count")).returns(5L, IntegerMetric.Values::count);
.as("'price' is IntegerAggregation").returns(true, g -> g.isInteger("price"))
.as("aggregated prices").extracting(g -> g.integer("price"))
.as(desc.apply("max")).returns(expectedPrice, IntegerAggregation.Values::max)
.as(desc.apply("min")).returns(expectedPrice, IntegerAggregation.Values::min)
.as(desc.apply("count")).returns(5L, IntegerAggregation.Values::count);
});
}

Expand All @@ -112,17 +114,18 @@ public void testNearVector() {
var result = things.aggregate.nearVector(
randomVector(10, -1f, 1f),
near -> near.limit(5),
with -> with.metrics(
Metric.integer("price", calculate -> calculate
.min().max().count()))
with -> with
.metrics(
Aggregation.integer("price",
calculate -> calculate.min().max().count()))
.objectLimit(4)
.includeTotalCount());
.includeTotalCount(true));

Assertions.assertThat(result)
.as("includes all objects").hasFieldOrPropertyWithValue("totalCount", 4L)
.as("'price' is IntegerMetric").returns(true, p -> p.isIntegerProperty("price"))
.as("aggregated prices").extracting(p -> p.getInteger("price"))
.as("count").returns(4L, IntegerMetric.Values::count);
.as("'price' is IntegerAggregation").returns(true, p -> p.isInteger("price"))
.as("aggregated prices").extracting(p -> p.integer("price"))
.as("count").returns(4L, IntegerAggregation.Values::count);
}

@Test
Expand All @@ -131,35 +134,36 @@ public void testNearVector_groupBy_category() {
var result = things.aggregate.nearVector(
randomVector(10, -1f, 1f),
near -> near.distance(2f),
new GroupBy("category"),
with -> with.metrics(
Metric.integer("price", calculate -> calculate
.min().max().median()))
with -> with
.metrics(
Aggregation.integer("price",
calculate -> calculate.min().max().median()))
.objectLimit(9)
.includeTotalCount());
.includeTotalCount(true),
new GroupBy("category"));

Assertions.assertThat(result)
.extracting(AggregateGroupByResponse::groups)
.asInstanceOf(InstanceOfAssertFactories.list(Group.class))
.extracting(AggregateResponseGrouped::groups)
.asInstanceOf(InstanceOfAssertFactories.list(AggregateResponseGroup.class))
.as("group per category").hasSize(3)
.allSatisfy(group -> {
Assertions.assertThat(group)
.extracting(Group::by)
.as(group.by().property() + " is Text property").returns(true, GroupedBy::isText);
.extracting(AggregateResponseGroup::groupedBy)
.as(group.groupedBy().property() + " is Text property").returns(true, GroupedBy::isText);

String category = group.by().getAsText();
String category = group.groupedBy().text();
var expectedPrice = (long) category.length();

Function<String, Supplier<String>> desc = (String metric) -> {
return () -> "%s ('%s'.length)".formatted(metric, category);
};

Assertions.assertThat(group)
.as("'price' is IntegerMetric").returns(true, g -> g.isIntegerProperty("price"))
.as("aggregated prices").extracting(g -> g.getInteger("price"))
.as(desc.apply("max")).returns(expectedPrice, IntegerMetric.Values::max)
.as(desc.apply("min")).returns(expectedPrice, IntegerMetric.Values::min)
.as(desc.apply("median")).returns((double) expectedPrice, IntegerMetric.Values::median);
.as("'price' is IntegerAggregation").returns(true, g -> g.isInteger("price"))
.as("aggregated prices").extracting(g -> g.integer("price"))
.as(desc.apply("max")).returns(expectedPrice, IntegerAggregation.Values::max)
.as(desc.apply("min")).returns(expectedPrice, IntegerAggregation.Values::min)
.as(desc.apply("median")).returns((double) expectedPrice, IntegerAggregation.Values::median);
});
}
}
13 changes: 6 additions & 7 deletions src/it/java/io/weaviate/integration/DataITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@

import io.weaviate.ConcurrentTest;
import io.weaviate.client6.WeaviateClient;
import io.weaviate.client6.v1.api.collections.Vectors;
import io.weaviate.client6.v1.collections.Property;
import io.weaviate.client6.v1.collections.VectorIndex;
import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy;
import io.weaviate.client6.v1.collections.Vectorizer;
import io.weaviate.client6.v1.collections.object.Vectors;
import io.weaviate.client6.v1.collections.object.WeaviateObject;
import io.weaviate.containers.Container;

public class DataITest extends ConcurrentTest {
Expand All @@ -39,9 +38,9 @@ public void testCreateGetDelete() throws IOException {
.id(id)
.vectors(Vectors.of(VECTOR_INDEX, vector)));

var object = artists.data.get(id, query -> query
var object = artists.query.byId(id, query -> query
.returnProperties("name")
.includeVector());
.includeVector(true));

Assertions.assertThat(object)
.as("object exists after insert").get()
Expand All @@ -59,7 +58,7 @@ public void testCreateGetDelete() throws IOException {
});

artists.data.delete(id);
object = artists.data.get(id);
object = artists.query.byId(id);
Assertions.assertThat(object).isEmpty().as("object not exists after deletion");
}

Expand All @@ -78,11 +77,11 @@ public void testBlobData() throws IOException {
"breed", "ragdoll",
"img", ragdollPng));

var got = cats.data.get(ragdoll.metadata().id(),
var got = cats.query.byId(ragdoll.metadata().id(),
cat -> cat.returnProperties("img"));

Assertions.assertThat(got).get()
.extracting(WeaviateObject::properties, InstanceOfAssertFactories.MAP)
.extracting(io.weaviate.client6.v1.api.collections.WeaviateObject::properties, InstanceOfAssertFactories.MAP)
.extractingByKey("img").isEqualTo(ragdollPng);
}

Expand Down
17 changes: 9 additions & 8 deletions src/it/java/io/weaviate/integration/ReferencesITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

import io.weaviate.ConcurrentTest;
import io.weaviate.client6.WeaviateClient;
import io.weaviate.client6.v1.api.collections.ObjectReference;
import io.weaviate.client6.v1.api.collections.WeaviateObject;
import io.weaviate.client6.v1.api.collections.query.MetadataField;
import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
import io.weaviate.client6.v1.api.collections.query.QueryReference;
import io.weaviate.client6.v1.collections.Property;
import io.weaviate.client6.v1.collections.Reference;
import io.weaviate.client6.v1.collections.ReferenceProperty;
import io.weaviate.client6.v1.collections.object.ObjectReference;
import io.weaviate.client6.v1.collections.object.WeaviateObject;
import io.weaviate.client6.v1.collections.query.MetadataField;
import io.weaviate.client6.v1.collections.query.QueryReference;
import io.weaviate.containers.Container;

/**
Expand Down Expand Up @@ -91,7 +92,7 @@ public void testReferences() throws IOException {
.extracting(ReferenceProperty::dataTypes, InstanceOfAssertFactories.list(String.class))
.containsOnly(nsMovies);

var gotAlex = artists.data.get(alex.metadata().id(),
var gotAlex = artists.query.byId(alex.metadata().id(),
opt -> opt.returnReferences(
QueryReference.multi("hasAwards", nsOscar,
ref -> ref.returnMetadata(MetadataField.ID)),
Expand All @@ -103,7 +104,7 @@ public void testReferences() throws IOException {
.extracting(WeaviateObject::references, InstanceOfAssertFactories.map(String.class, ObjectReference.class))
.as("hasAwards object reference").extractingByKey("hasAwards")
.extracting(ObjectReference::objects, InstanceOfAssertFactories.list(WeaviateObject.class))
.extracting(objects -> objects.metadata().id())
.extracting(object -> ((QueryMetadata) object.metadata()).id())
.containsOnly(
// INVESTIGATE: When references to 2+ collections are requested,
// seems to Weaviate only return references to the first one in the list.
Expand Down Expand Up @@ -154,7 +155,7 @@ public void testNestedReferences() throws IOException {
.reference("hasAwards", Reference.objects(grammy_1)));

// Assert: fetch nested references
var gotAlex = artists.data.get(alex.metadata().id(),
var gotAlex = artists.query.byId(alex.metadata().id(),
opt -> opt.returnReferences(
QueryReference.single("hasAwards",
ref -> ref
Expand All @@ -170,7 +171,7 @@ public void testNestedReferences() throws IOException {
.as("hasAwards object reference").extractingByKey("hasAwards")
.extracting(ObjectReference::objects, InstanceOfAssertFactories.list(WeaviateObject.class))
.hasSize(1).allSatisfy(award -> Assertions.assertThat(award)
.returns(grammy_1.metadata().id(), g -> g.metadata().id())
.returns(grammy_1.metadata().id(), grammy -> ((QueryMetadata) grammy.metadata()).id())
.extracting(WeaviateObject::references,
InstanceOfAssertFactories.map(String.class, ObjectReference.class))
.extractingByKey("presentedBy")
Expand Down
44 changes: 21 additions & 23 deletions src/it/java/io/weaviate/integration/SearchITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@

import io.weaviate.ConcurrentTest;
import io.weaviate.client6.WeaviateClient;
import io.weaviate.client6.v1.api.collections.Vectors;
import io.weaviate.client6.v1.api.collections.WeaviateObject;
import io.weaviate.client6.v1.api.collections.query.GroupBy;
import io.weaviate.client6.v1.api.collections.query.MetadataField;
import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup;
import io.weaviate.client6.v1.collections.Property;
import io.weaviate.client6.v1.collections.Reference;
import io.weaviate.client6.v1.collections.VectorIndex;
import io.weaviate.client6.v1.collections.VectorIndex.IndexingStrategy;
import io.weaviate.client6.v1.collections.Vectorizer;
import io.weaviate.client6.v1.collections.object.Vectors;
import io.weaviate.client6.v1.collections.query.GroupedQueryResult;
import io.weaviate.client6.v1.collections.query.MetadataField;
import io.weaviate.client6.v1.collections.query.NearText;
import io.weaviate.client6.v1.collections.query.NearVector;
import io.weaviate.containers.Container;
import io.weaviate.containers.Container.ContainerGroup;
import io.weaviate.containers.Contextionary;
Expand Down Expand Up @@ -69,32 +69,32 @@ public void testNearVector() {
.limit(3)
.returnMetadata(MetadataField.DISTANCE));

Assertions.assertThat(result.objects).hasSize(3);
float maxDistance = Collections.max(result.objects,
Comparator.comparing(obj -> obj.metadata.distance())).metadata.distance();
Assertions.assertThat(result.objects()).hasSize(3);
float maxDistance = Collections.max(result.objects(),
Comparator.comparing(obj -> obj.metadata().distance())).metadata().distance();
Assertions.assertThat(maxDistance).isLessThanOrEqualTo(2f);
}

@Test
public void testNearVector_groupBy() {
var things = client.collections.use(COLLECTION);
var result = things.query.nearVector(searchVector,
new NearVector.GroupBy("category", 2, 5),
opt -> opt.distance(10f));
opt -> opt.distance(10f),
GroupBy.property("category", 2, 5));

Assertions.assertThat(result.groups)
Assertions.assertThat(result.groups())
.as("group per category").containsOnlyKeys(CATEGORIES)
.hasSizeLessThanOrEqualTo(2)
.allSatisfy((category, group) -> {
Assertions.assertThat(group)
.as("group name").returns(category, GroupedQueryResult.Group::name);
.as("group name").returns(category, QueryResponseGroup::name);
Assertions.assertThat(group.numberOfObjects())
.as("[%s] has 1+ object", category).isLessThanOrEqualTo(5L);
});

Assertions.assertThat(result.objects)
Assertions.assertThat(result.objects())
.as("object belongs a group")
.allMatch(obj -> result.groups.get(obj.belongsToGroup).objects().contains(obj));
.allMatch(obj -> result.groups().get(obj.belongsToGroup()).objects().contains(obj));
}

/**
Expand Down Expand Up @@ -151,8 +151,8 @@ public void testNearText() throws IOException {
.moveAway(.4f, away -> away.uuids(submarine.metadata().id()))
.returnProperties("title"));

Assertions.assertThat(result.objects).hasSize(2)
.extracting(obj -> obj.properties).allSatisfy(
Assertions.assertThat(result.objects()).hasSize(2)
.extracting(WeaviateObject::properties).allSatisfy(
properties -> Assertions.assertThat(properties)
.allSatisfy((_k, v) -> Assertions.assertThat((String) v).contains("Jungle")));
}
Expand Down Expand Up @@ -185,18 +185,16 @@ public void testNearText_groupBy() throws IOException {
s -> s.reference("performedBy", Reference.objects(ccr)));

var result = songs.query.nearText("nature",
new NearText.GroupBy("performedBy", 2, 1),
opt -> opt
.returnProperties("title"));
opt -> opt.returnProperties("title"),
GroupBy.property("performedBy", 2, 1));

Assertions.assertThat(result.groups).hasSize(2)
Assertions.assertThat(result.groups()).hasSize(2)
.containsOnlyKeys(
"weaviate://localhost/%s/%s".formatted(nsArtists, beatles.metadata().id()),
"weaviate://localhost/%s/%s".formatted(nsArtists, ccr.metadata().id()));
}

@Test
// @Ignore("no fitting image to test with")
public void testNearImage() throws IOException {
var nsCats = ns("Cats");

Expand All @@ -218,8 +216,8 @@ public void testNearImage() throws IOException {
var got = cats.query.nearImage(EncodedMedia.IMAGE,
opt -> opt.returnProperties("breed"));

Assertions.assertThat(got.objects).hasSize(1).first()
.extracting(obj -> obj.properties, InstanceOfAssertFactories.MAP)
Assertions.assertThat(got.objects()).hasSize(1).first()
.extracting(WeaviateObject::properties, InstanceOfAssertFactories.MAP)
.extractingByKey("breed").isEqualTo("ragdoll");
}
}
Loading