From a16a0a3f5a658d0f09158ff24747f0bed92cd48b Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Thu, 9 Oct 2025 14:22:31 +0200 Subject: [PATCH 01/25] feat: implement generative query with bm25 filter --- .../io/weaviate/integration/SearchITest.java | 98 ++++++ .../v1/api/collections/CollectionHandle.java | 4 + .../v1/api/collections/Generative.java | 6 +- .../generate/AbstractGenerateClient.java | 137 +++++++++ .../collections/generate/GenerativeDebug.java | 4 + .../generate/GenerativeObject.java | 12 + .../generate/GenerativeRequest.java | 43 +++ .../generate/GenerativeResponse.java | 64 ++++ .../generate/GenerativeResponseGroup.java | 26 ++ .../generate/GenerativeResponseGrouped.java | 77 +++++ .../collections/generate/GenerativeTask.java | 157 ++++++++++ .../generate/ProviderMetadata.java | 7 + .../api/collections/generate/TaskOutput.java | 8 + .../generate/WeaviateGenerateClient.java | 41 +++ .../generative/DummyGenerative.java | 24 ++ .../collections/query/BaseQueryOptions.java | 22 ++ .../collections/query/GenerativeSearch.java | 157 ++++++++++ .../collections/query/QueryObjectGrouped.java | 6 +- .../api/collections/query/QueryOperator.java | 2 +- .../api/collections/query/QueryRequest.java | 278 ++---------------- .../api/collections/query/QueryResponse.java | 201 ++++++++++++- .../query/QueryResponseGrouped.java | 47 ++- 22 files changed, 1162 insertions(+), 259 deletions(-) create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeDebug.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeObject.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeRequest.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGroup.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGrouped.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/ProviderMetadata.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/query/GenerativeSearch.java diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java index c08c92a6d..46a2ded83 100644 --- a/src/it/java/io/weaviate/integration/SearchITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -27,6 +27,9 @@ import io.weaviate.client6.v1.api.collections.WeaviateMetadata; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.data.Reference; +import io.weaviate.client6.v1.api.collections.generate.GenerativeObject; +import io.weaviate.client6.v1.api.collections.generate.TaskOutput; +import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.Metadata; import io.weaviate.client6.v1.api.collections.query.QueryMetadata; @@ -47,6 +50,7 @@ public class SearchITest extends ConcurrentTest { Weaviate.custom() .withContextionaryUrl(Contextionary.URL) .withImageInference(Img2VecNeural.URL, Img2VecNeural.MODULE) + .addModules("generative-dummy") .build(), Container.IMG2VEC_NEURAL, Container.CONTEXTIONARY); @@ -549,4 +553,98 @@ public void testNearVector_targetVectors() throws IOException { .hasSize(1).extracting(WeaviateObject::uuid) .containsExactly(thing456.uuids().get(0)); } + + @Test + public void testGenerative_bm25() throws IOException { + // Arrange + var nsThings = ns("Things"); + + client.collections.create(nsThings, + c -> c + .properties(Property.text("title")) + .generativeModule(new DummyGenerative()) + .vectorConfig(VectorConfig.text2vecContextionary( + t2v -> t2v.sourceProperties("title")))); + + var things = client.collections.use(nsThings); + + things.data.insertMany( + Map.of("title", "Salad Fork"), + Map.of("title", "Dessert Fork")); + + // Act + var french = things.generate.bm25( + "fork", + bm25 -> bm25.queryProperties("title").limit(2), + generate -> generate + .singlePrompt("translate to French") + .groupedTask("count letters R")); + + // Assert + Assertions.assertThat(french.objects()) + .as("individual results") + .hasSize(2) + .extracting(GenerativeObject::generated) + .allSatisfy(generated -> { + Assertions.assertThat(generated.text()).isNotBlank(); + }); + + Assertions.assertThat(french.generated()) + .as("summary") + .extracting(TaskOutput::text, InstanceOfAssertFactories.STRING) + .isNotBlank(); + } + + @Test + public void testGenerative_bm25_groupBy() throws IOException { + // Arrange + var nsThings = ns("Things"); + + client.collections.create(nsThings, + c -> c + .properties(Property.text("title")) + .generativeModule(new DummyGenerative()) + .vectorConfig(VectorConfig.text2vecContextionary( + t2v -> t2v.sourceProperties("title")))); + + var things = client.collections.use(nsThings); + + things.data.insertMany( + Map.of("title", "Salad Fork"), + Map.of("title", "Dessert Fork")); + + // Act + var french = things.generate.bm25( + "fork", + bm25 -> bm25.queryProperties("title").limit(2), + generate -> generate + .singlePrompt("translate to French") + .groupedTask("count letters R"), + GroupBy.property("title", 5, 5)); + + // Assert + Assertions.assertThat(french.objects()) + .as("individual results") + .hasSize(2); + + Assertions.assertThat(french.groups()) + .as("grouped results") + .hasSize(2) + .allSatisfy((groupName, group) -> { + Assertions.assertThat(group.objects()) + .describedAs("objects in group %s", groupName) + .hasSize(1); + + Assertions.assertThat(group.generated()) + .describedAs("summary group %s", groupName) + .extracting(TaskOutput::text, InstanceOfAssertFactories.STRING) + .isNotBlank(); + + }); + + Assertions.assertThat(french.generated()) + .as("summary") + .extracting(TaskOutput::text, InstanceOfAssertFactories.STRING) + .isNotBlank(); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java index 08f41eca3..7af8ed549 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java @@ -6,6 +6,7 @@ import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClient; import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClient; import io.weaviate.client6.v1.api.collections.data.WeaviateDataClient; +import io.weaviate.client6.v1.api.collections.generate.WeaviateGenerateClient; import io.weaviate.client6.v1.api.collections.pagination.Paginator; import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient; @@ -20,6 +21,7 @@ public class CollectionHandle { public final WeaviateDataClient data; public final WeaviateQueryClient query; public final WeaviateAggregateClient aggregate; + public final WeaviateGenerateClient generate; public final WeaviateTenantsClient tenants; private final CollectionHandleDefaults defaults; @@ -32,6 +34,7 @@ public CollectionHandle( this.config = new WeaviateConfigClient(collection, restTransport, grpcTransport, defaults); this.aggregate = new WeaviateAggregateClient(collection, grpcTransport, defaults); this.query = new WeaviateQueryClient<>(collection, grpcTransport, defaults); + this.generate = new WeaviateGenerateClient<>(collection, grpcTransport, defaults); this.data = new WeaviateDataClient<>(collection, restTransport, grpcTransport, defaults); this.defaults = defaults; @@ -43,6 +46,7 @@ private CollectionHandle(CollectionHandle c, CollectionHandleDefaul this.config = new WeaviateConfigClient(c.config, defaults); this.aggregate = new WeaviateAggregateClient(c.aggregate, defaults); this.query = new WeaviateQueryClient<>(c.query, defaults); + this.generate = new WeaviateGenerateClient<>(c.generate, defaults); this.data = new WeaviateDataClient<>(c.data, defaults); this.defaults = defaults; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index 632713cdd..74ffb2f4d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -14,12 +14,14 @@ import com.google.gson.stream.JsonWriter; import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; +import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.json.JsonEnum; public interface Generative { public enum Kind implements JsonEnum { - COHERE("generative-cohere"); + COHERE("generative-cohere"), + DUMMY("generative-dummy"); private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); private final String jsonValue; @@ -68,6 +70,7 @@ private final void addAdapter(Gson gson, Generative.Kind kind, Class { + protected final CollectionDescriptor collection; + protected final GrpcTransport grpcTransport; + protected final CollectionHandleDefaults defaults; + + AbstractGenerateClient(CollectionDescriptor collection, GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { + this.collection = collection; + this.grpcTransport = grpcTransport; + this.defaults = defaults; + } + + /** Copy constructor that sets new defaults. */ + AbstractGenerateClient( + AbstractGenerateClient c, + CollectionHandleDefaults defaults) { + this(c.collection, c.grpcTransport, defaults); + } + + protected abstract ResponseT performRequest(QueryOperator operator, GenerativeTask generate); + + protected abstract GroupedResponseT performRequest(QueryOperator operator, GenerativeTask generate, GroupBy groupBy); + + // BM25 queries ------------------------------------------------------------- + + /** + * Query collection objects using keyword (BM25) search. + * + * @param query Query string. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT bm25(String query, + Function> generateFn) { + return bm25(Bm25.of(query), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using keyword (BM25) search. + * + * @param query Query string. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT bm25( + String query, + Function> fn, + Function> generateFn) { + return bm25(Bm25.of(query, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using keyword (BM25) search. + * + * @param query BM25 query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT bm25(Bm25 query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using keyword (BM25) search. + * + * @param query Query string. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT bm25(String query, + Function> generateFn, + GroupBy groupBy) { + return bm25(Bm25.of(query), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using keyword (BM25) search. + * + * @param query Query string. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT bm25(String query, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return bm25(Bm25.of(query, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using keyword (BM25) search. + * + * @param query BM25 query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT bm25(Bm25 query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeDebug.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeDebug.java new file mode 100644 index 000000000..4b08cdc12 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeDebug.java @@ -0,0 +1,4 @@ +package io.weaviate.client6.v1.api.collections.generate; + +public record GenerativeDebug(String fullPrompt) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeObject.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeObject.java new file mode 100644 index 000000000..1767865e9 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeObject.java @@ -0,0 +1,12 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import io.weaviate.client6.v1.api.collections.query.QueryMetadata; + +public record GenerativeObject( + /** Object properties. */ + PropertiesT properties, + /** Object metadata. */ + QueryMetadata metadata, + /** Generative task output. */ + TaskOutput generated) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeRequest.java new file mode 100644 index 000000000..b09ce9f4e --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeRequest.java @@ -0,0 +1,43 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.query.GroupBy; +import io.weaviate.client6.v1.api.collections.query.QueryOperator; +import io.weaviate.client6.v1.api.collections.query.QueryRequest; +import io.weaviate.client6.v1.internal.grpc.Rpc; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record GenerativeRequest(QueryOperator operator, GenerativeTask generate, GroupBy groupBy) { + static Rpc, WeaviateProtoSearchGet.SearchReply> rpc( + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + return Rpc.of( + request -> { + var query = QueryRequest.marshal( + new QueryRequest(request.operator, request.groupBy), + collection, defaults); + var generative = WeaviateProtoGenerative.GenerativeSearch.newBuilder(); + request.generate.appendTo(generative); + var builder = query.toBuilder(); + builder.setGenerative(generative); + return builder.build(); + }, + reply -> GenerativeResponse.unmarshal(reply, collection), + () -> WeaviateBlockingStub::search, + () -> WeaviateFutureStub::search); + } + + static Rpc, WeaviateProtoSearchGet.SearchReply> grouped( + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var rpc = rpc(collection, defaults); + return Rpc.of( + request -> rpc.marshal(request), + reply -> GenerativeResponseGrouped.unmarshal(reply, collection, defaults), + () -> rpc.method(), () -> rpc.methodAsync()); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java new file mode 100644 index 000000000..8bc76b862 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java @@ -0,0 +1,64 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.List; + +import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; +import io.weaviate.client6.v1.api.collections.query.QueryResponse; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record GenerativeResponse( + float took, + List> objects, + TaskOutput generated) { + static GenerativeResponse unmarshal( + WeaviateProtoSearchGet.SearchReply reply, + CollectionDescriptor collection) { + var objects = reply + .getResultsList() + .stream() + .map(result -> { + var object = QueryResponse.unmarshalResultObject( + result.getProperties(), result.getMetadata(), collection); + TaskOutput generative = null; + if (result.hasGenerative()) { + generative = GenerativeResponse.unmarshalTaskOutput(result.getGenerative()); + } + return new GenerativeObject<>( + object.properties(), + object.metadata(), + generative); + }) + .toList(); + + TaskOutput summary = null; + if (reply.hasGenerativeGroupedResults()) { + summary = GenerativeResponse.unmarshalTaskOutput(reply.getGenerativeGroupedResults()); + } + return new GenerativeResponse<>(reply.getTook(), objects, summary); + } + + static TaskOutput unmarshalTaskOutput(List values) { + if (values.isEmpty()) { + return null; + } + var generated = values.get(0); + + var metadata = generated.getMetadata(); + ProviderMetadata providerMetadata = null; + if (metadata.hasDummy()) { + providerMetadata = new DummyGenerative.Metadata(); + } + + GenerativeDebug debug = null; + if (generated.getDebug() != null && generated.getDebug().getFullPrompt() != null) { + debug = new GenerativeDebug(generated.getDebug().getFullPrompt()); + } + return new TaskOutput(generated.getResult(), providerMetadata, debug); + } + + static TaskOutput unmarshalTaskOutput(WeaviateProtoGenerative.GenerativeResult result) { + return unmarshalTaskOutput(result.getValuesList()); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGroup.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGroup.java new file mode 100644 index 000000000..4a9620530 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGroup.java @@ -0,0 +1,26 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.List; + +import io.weaviate.client6.v1.api.collections.query.QueryObjectGrouped; + +public record GenerativeResponseGroup( + /** Group name. */ + String name, + /** + * The smallest distance value among all objects in the group, indicating the + * most similar object in that group to the query + */ + Float minDistance, + /** + * The largest distance value among all objects in the group, indicating the + * least similar object in that group to the query. + */ + Float maxDistance, + /** The size of the group. */ + long numberOfObjects, + /** Objects retrieved in the query. */ + List> objects, + /** Output of the summary task for this group. */ + TaskOutput generated) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGrouped.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGrouped.java new file mode 100644 index 000000000..673cddbde --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGrouped.java @@ -0,0 +1,77 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.query.QueryObjectGrouped; +import io.weaviate.client6.v1.api.collections.query.QueryResponse; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record GenerativeResponseGrouped( + /** Execution time of the request as measure by the server. */ + float took, + /** Objects returned by the associated query. */ + List> objects, + /** Grouped results with per-group generated output. */ + Map> groups, + /** Output of the summary group task. */ + TaskOutput generated) { + + static GenerativeResponseGrouped unmarshal( + WeaviateProtoSearchGet.SearchReply reply, + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var allObjects = new ArrayList>(); + var groups = reply.getGroupByResultsList().stream() + .map(group -> { + var groupName = group.getName(); + List> objects = group.getObjectsList().stream() + .map(object -> QueryResponse.unmarshalResultObject( + object.getProperties(), + object.getMetadata(), + collection)) + .map(object -> new QueryObjectGrouped<>( + object.properties(), + object.metadata(), + groupName)) + .toList(); + + allObjects.addAll(objects); + + TaskOutput generative = null; + if (group.hasGenerativeResult()) { + generative = GenerativeResponse.unmarshalTaskOutput(group.getGenerativeResult()); + } else if (group.hasGenerative()) { + // As of today the server continues to use the deprecated field in response. + generative = GenerativeResponse.unmarshalTaskOutput(List.of(group.getGenerative())); + } + + return new GenerativeResponseGroup<>( + groupName, + group.getMinDistance(), + group.getMaxDistance(), + group.getNumberOfObjects(), + objects, + generative); + }) + // Collectors.toMap() throws an NPE if either key or value in the map are null. + // In this specific case it is safe to use it, as the function in the map above + // always returns a QueryResponseGroup. + // The name of the group should not be null either, that's something we assume + // about the server's response. + .collect(Collectors.toMap(GenerativeResponseGroup::name, Function.identity())); + + TaskOutput summary = null; + if (reply.hasGenerativeGroupedResults()) { + summary = GenerativeResponse.unmarshalTaskOutput(reply.getGenerativeGroupedResults()); + } else if (reply.hasGenerativeGroupedResult()) { + summary = new TaskOutput(reply.getGenerativeGroupedResult(), null, null); + } + return new GenerativeResponseGrouped<>(reply.getTook(), allObjects, groups, summary); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java new file mode 100644 index 000000000..e5f4dfcf9 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java @@ -0,0 +1,157 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record GenerativeTask(Single single, Grouped grouped) { + public static GenerativeTask of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public GenerativeTask(Builder builder) { + this(builder.single, builder.grouped); + } + + public static class Builder implements ObjectBuilder { + private Single single; + private Grouped grouped; + + public Builder singlePrompt(String prompt) { + this.single = Single.of(prompt); + return this; + } + + public Builder singlePrompt(String prompt, Function> fn) { + this.single = Single.of(prompt, fn); + return this; + + } + + public Builder groupedTask(String prompt) { + this.grouped = Grouped.of(prompt); + return this; + } + + public Builder groupedTask(String prompt, Function> fn) { + this.grouped = Grouped.of(prompt, fn); + return this; + } + + @Override + public GenerativeTask build() { + return new GenerativeTask(this); + } + } + + void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + if (single != null) { + single.appendTo(req); + } + if (grouped != null) { + grouped.appendTo(req); + } + } + + public record Single(String prompt, boolean debug) { + public static Single of(String prompt) { + return of(prompt, ObjectBuilder.identity()); + } + + public static Single of(String prompt, Function> fn) { + return fn.apply(new Builder(prompt)).build(); + } + + public Single(Builder builder) { + this(builder.prompt, builder.debug); + } + + public static class Builder implements ObjectBuilder { + private final String prompt; + private boolean debug = false; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder debug(boolean enable) { + this.debug = enable; + return this; + } + + @Override + public Single build() { + return new Single(this); + } + } + + public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + req.setSingle( + WeaviateProtoGenerative.GenerativeSearch.Single.newBuilder() + .setPrompt(prompt) + .setDebug(debug)); + } + } + + public record Grouped(String prompt, boolean debug, List properties) { + public static Grouped of(String prompt) { + return of(prompt, ObjectBuilder.identity()); + } + + public static Grouped of(String prompt, Function> fn) { + return fn.apply(new Builder(prompt)).build(); + } + + public Grouped(Builder builder) { + this(builder.prompt, builder.debug, builder.properties); + } + + public static class Builder implements ObjectBuilder { + private final String prompt; + private final List properties = new ArrayList<>(); + private boolean debug = false; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder properties(String... properties) { + return properties(Arrays.asList(properties)); + } + + public Builder properties(List properties) { + this.properties.addAll(properties); + return this; + } + + public Builder debug(boolean enable) { + this.debug = enable; + return this; + } + + @Override + public Grouped build() { + return new Grouped(this); + } + } + + public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + var grouped = WeaviateProtoGenerative.GenerativeSearch.Grouped.newBuilder() + .setTask(prompt) + .setDebug(debug); + + if (properties != null && !properties.isEmpty()) { + grouped.setProperties( + WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(properties)); + + } + req.setGrouped(grouped); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/ProviderMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/ProviderMetadata.java new file mode 100644 index 000000000..d56908e7d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/ProviderMetadata.java @@ -0,0 +1,7 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import io.weaviate.client6.v1.api.collections.Generative; + +public interface ProviderMetadata { + Generative.Kind _kind(); +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java new file mode 100644 index 000000000..a3becb78a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java @@ -0,0 +1,8 @@ +package io.weaviate.client6.v1.api.collections.generate; + +public record TaskOutput( + String text, + ProviderMetadata metadata, + GenerativeDebug debug) { + +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java new file mode 100644 index 000000000..eef92cbfb --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java @@ -0,0 +1,41 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.Optional; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.api.collections.query.GroupBy; +import io.weaviate.client6.v1.api.collections.query.QueryMetadata; +import io.weaviate.client6.v1.api.collections.query.QueryOperator; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public class WeaviateGenerateClient + extends + AbstractGenerateClient>, GenerativeResponse, GenerativeResponseGrouped> { + + public WeaviateGenerateClient( + CollectionDescriptor collection, + GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { + super(collection, grpcTransport, defaults); + } + + /** Copy constructor that sets new defaults. */ + public WeaviateGenerateClient(WeaviateGenerateClient c, CollectionHandleDefaults defaults) { + super(c, defaults); + } + + @Override + protected final GenerativeResponse performRequest(QueryOperator operator, GenerativeTask generate) { + var request = new GenerativeRequest(operator, generate, null); + return this.grpcTransport.performRequest(request, GenerativeRequest.rpc(collection, defaults)); + } + + @Override + protected final GenerativeResponseGrouped performRequest(QueryOperator operator, GenerativeTask generate, + GroupBy groupBy) { + var request = new GenerativeRequest(operator, generate, groupBy); + return this.grpcTransport.performRequest(request, GenerativeRequest.grouped(collection, defaults)); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java new file mode 100644 index 000000000..f6d1f915a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java @@ -0,0 +1,24 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.ProviderMetadata; + +public record DummyGenerative() implements Generative { + @Override + public Kind _kind() { + return Generative.Kind.DUMMY; + } + + @Override + public Object _self() { + return this; + } + + public static record Metadata() implements ProviderMetadata { + + @Override + public Kind _kind() { + return Generative.Kind.DUMMY; + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java index 036ce1165..e7f433dea 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java @@ -3,12 +3,14 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Function; import org.apache.commons.lang3.StringUtils; import io.weaviate.client6.v1.api.collections.query.Metadata.MetadataField; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; public record BaseQueryOptions( @@ -18,6 +20,7 @@ public record BaseQueryOptions( String after, ConsistencyLevel consistencyLevel, Where where, + GenerativeSearch generativeSearch, List returnProperties, List returnReferences, List returnMetadata) { @@ -30,6 +33,7 @@ private BaseQueryOptions(Builder, T> builder.after, builder.consistencyLevel, builder.where, + builder.generativeSearch, builder.returnProperties, builder.returnReferences, builder.returnMetadata); @@ -44,6 +48,7 @@ public static abstract class Builder, T extends Ob private String after; private ConsistencyLevel consistencyLevel; private Where where; + private GenerativeSearch generativeSearch; private List returnProperties = new ArrayList<>(); private List returnReferences = new ArrayList<>(); private List returnMetadata = new ArrayList<>(); @@ -102,6 +107,17 @@ public final SELF consistencyLevel(ConsistencyLevel consistencyLevel) { return (SELF) this; } + /** + * Add arguments for generative query. + * Builders which support this parameter should make the method public. + * + * @param fn Lambda expression for optional parameters. + */ + protected SELF generate(Function> fn) { + this.generativeSearch = GenerativeSearch.of(fn); + return (SELF) this; + } + /** * Filter result set using traditional filtering operators: {@code eq}, * {@code gte}, {@code like}, etc. @@ -194,6 +210,12 @@ final void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) { req.setFilters(filter); } + if (generativeSearch != null) { + var generative = WeaviateProtoGenerative.GenerativeSearch.newBuilder(); + generativeSearch.appendTo(generative); + req.setGenerative(generative); + } + var metadata = WeaviateProtoSearchGet.MetadataRequest.newBuilder(); returnMetadata.forEach(m -> m.appendTo(metadata)); req.setMetadata(metadata); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/GenerativeSearch.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/GenerativeSearch.java new file mode 100644 index 000000000..359dd282c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/GenerativeSearch.java @@ -0,0 +1,157 @@ +package io.weaviate.client6.v1.api.collections.query; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record GenerativeSearch(Single single, Grouped grouped) { + public static GenerativeSearch of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public GenerativeSearch(Builder builder) { + this(builder.single, builder.grouped); + } + + public static class Builder implements ObjectBuilder { + private Single single; + private Grouped grouped; + + public Builder singlePrompt(String prompt) { + this.single = Single.of(prompt); + return this; + } + + public Builder singlePrompt(String prompt, Function> fn) { + this.single = Single.of(prompt, fn); + return this; + + } + + public Builder groupedTask(String prompt) { + this.grouped = Grouped.of(prompt); + return this; + } + + public Builder groupedTask(String prompt, Function> fn) { + this.grouped = Grouped.of(prompt, fn); + return this; + } + + @Override + public GenerativeSearch build() { + return new GenerativeSearch(this); + } + } + + void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + if (single != null) { + single.appendTo(req); + } + if (grouped != null) { + grouped.appendTo(req); + } + } + + public record Single(String prompt, boolean debug) { + public static Single of(String prompt) { + return of(prompt, ObjectBuilder.identity()); + } + + public static Single of(String prompt, Function> fn) { + return fn.apply(new Builder(prompt)).build(); + } + + public Single(Builder builder) { + this(builder.prompt, builder.debug); + } + + public static class Builder implements ObjectBuilder { + private final String prompt; + private boolean debug = false; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder debug(boolean enable) { + this.debug = enable; + return this; + } + + @Override + public Single build() { + return new Single(this); + } + } + + public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + req.setSingle( + WeaviateProtoGenerative.GenerativeSearch.Single.newBuilder() + .setPrompt(prompt) + .setDebug(debug)); + } + } + + public record Grouped(String prompt, boolean debug, List properties) { + public static Grouped of(String prompt) { + return of(prompt, ObjectBuilder.identity()); + } + + public static Grouped of(String prompt, Function> fn) { + return fn.apply(new Builder(prompt)).build(); + } + + public Grouped(Builder builder) { + this(builder.prompt, builder.debug, builder.properties); + } + + public static class Builder implements ObjectBuilder { + private final String prompt; + private final List properties = new ArrayList<>(); + private boolean debug = false; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder properties(String... properties) { + return properties(Arrays.asList(properties)); + } + + public Builder properties(List properties) { + this.properties.addAll(properties); + return this; + } + + public Builder debug(boolean enable) { + this.debug = enable; + return this; + } + + @Override + public Grouped build() { + return new Grouped(this); + } + } + + public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + var grouped = WeaviateProtoGenerative.GenerativeSearch.Grouped.newBuilder() + .setTask(prompt) + .setDebug(debug); + + if (properties != null && !properties.isEmpty()) { + grouped.setProperties( + WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(properties)); + + } + req.setGrouped(grouped); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java index f35f3e824..f98014af1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java @@ -2,15 +2,15 @@ import io.weaviate.client6.v1.api.collections.WeaviateObject; -public record QueryObjectGrouped( +public record QueryObjectGrouped( /** Object properties. */ - T properties, + PropertiesT properties, /** Object metadata. */ QueryMetadata metadata, /** Name of the group that the object belongs to. */ String belongsToGroup) { - QueryObjectGrouped(WeaviateObject object, + QueryObjectGrouped(WeaviateObject object, String belongsToGroup) { this(object.properties(), object.metadata(), belongsToGroup); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java index a3844b4eb..a2fe89d11 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java @@ -2,7 +2,7 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; -interface QueryOperator { +public interface QueryOperator { /** Append QueryOperator to the request message. */ void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index 11d31b4e5..fcc553f9e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -1,277 +1,57 @@ package io.weaviate.client6.v1.api.collections.query; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.UUID; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; - import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; -import io.weaviate.client6.v1.api.collections.ObjectMetadata; -import io.weaviate.client6.v1.api.collections.Vectors; -import io.weaviate.client6.v1.api.collections.WeaviateObject; -import io.weaviate.client6.v1.internal.DateUtil; -import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; import io.weaviate.client6.v1.internal.grpc.Rpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoProperties; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; -import io.weaviate.client6.v1.internal.orm.PropertiesBuilder; public record QueryRequest(QueryOperator operator, GroupBy groupBy) { - static Rpc, WeaviateProtoSearchGet.SearchReply> rpc( - CollectionDescriptor collection, + static Rpc, WeaviateProtoSearchGet.SearchReply> rpc( + CollectionDescriptor collection, CollectionHandleDefaults defaults) { return Rpc.of( - request -> { - var message = WeaviateProtoSearchGet.SearchRequest.newBuilder(); - message.setUses127Api(true); - message.setUses125Api(true); - message.setUses123Api(true); - message.setCollection(collection.collectionName()); - request.operator.appendTo(message); - - if (defaults.tenant() != null) { - message.setTenant(defaults.tenant()); - } - if (defaults.consistencyLevel() != null) { - defaults.consistencyLevel().appendTo(message); - } - - if (request.groupBy != null) { - request.groupBy.appendTo(message); - } - return message.build(); - }, - reply -> { - var objects = reply - .getResultsList() - .stream() - .map(obj -> QueryRequest.unmarshalResultObject( - obj.getProperties(), obj.getMetadata(), collection)) - .toList(); - return new QueryResponse<>(objects); - }, + request -> QueryRequest.marshal(request, collection, defaults), + reply -> QueryResponse.unmarshal(reply, collection), () -> WeaviateBlockingStub::search, () -> WeaviateFutureStub::search); } - static Rpc, WeaviateProtoSearchGet.SearchReply> grouped( - CollectionDescriptor collection, + public static WeaviateProtoSearchGet.SearchRequest marshal( + QueryRequest request, + CollectionDescriptor collection, CollectionHandleDefaults defaults) { - var rpc = rpc(collection, defaults); - return Rpc.of( - request -> rpc.marshal(request), - reply -> { - var allObjects = new ArrayList>(); - var groups = reply.getGroupByResultsList() - .stream().map(group -> { - var name = group.getName(); - List> objects = group.getObjectsList().stream() - .map(obj -> QueryRequest.unmarshalResultObject( - obj.getProperties(), - obj.getMetadata(), - collection)) - .map(obj -> new QueryObjectGrouped<>(obj, name)) - .toList(); - - allObjects.addAll(objects); - return new QueryResponseGroup<>( - name, - group.getMinDistance(), - group.getMaxDistance(), - group.getNumberOfObjects(), - objects); - }) - // Collectors.toMap() throws an NPE if either key or value in the map are null. - // In this specific case it is safe to use it, as the function in the map above - // always returns a QueryResponseGroup. - // The name of the group should not be null either, that's something we assume - // about the server's response. - .collect(Collectors.toMap(QueryResponseGroup::name, Function.identity())); + var message = WeaviateProtoSearchGet.SearchRequest.newBuilder(); + message.setUses127Api(true); + message.setUses125Api(true); + message.setUses123Api(true); + message.setCollection(collection.collectionName()); + request.operator.appendTo(message); - return new QueryResponseGrouped(allObjects, groups); - }, () -> rpc.method(), () -> rpc.methodAsync()); - } - - private static WeaviateObject unmarshalResultObject( - WeaviateProtoSearchGet.PropertiesResult propertiesResult, - WeaviateProtoSearchGet.MetadataResult metadataResult, - CollectionDescriptor collection) { - var object = unmarshalWithReferences(propertiesResult, metadataResult, collection); - var metadata = new QueryMetadata.Builder() - .uuid(object.metadata().uuid()) - .vectors(object.metadata().vectors()); - - if (metadataResult.getCreationTimeUnixPresent()) { - metadata.creationTimeUnix(metadataResult.getCreationTimeUnix()); - } - if (metadataResult.getLastUpdateTimeUnixPresent()) { - metadata.lastUpdateTimeUnix(metadataResult.getLastUpdateTimeUnix()); - } - if (metadataResult.getDistancePresent()) { - metadata.distance(metadataResult.getDistance()); - } - if (metadataResult.getCertaintyPresent()) { - metadata.certainty(metadataResult.getCertainty()); + if (defaults.tenant() != null) { + message.setTenant(defaults.tenant()); } - if (metadataResult.getScorePresent()) { - metadata.score(metadataResult.getScore()); + if (defaults.consistencyLevel() != null) { + defaults.consistencyLevel().appendTo(message); } - if (metadataResult.getExplainScorePresent()) { - metadata.explainScore(metadataResult.getExplainScore()); - } - return new WeaviateObject<>(collection.collectionName(), object.properties(), object.references(), - metadata.build()); - } - - private static WeaviateObject unmarshalWithReferences( - WeaviateProtoSearchGet.PropertiesResult propertiesResult, - WeaviateProtoSearchGet.MetadataResult metadataResult, - CollectionDescriptor descriptor) { - var properties = descriptor.propertiesBuilder(); - propertiesResult.getNonRefProps().getFieldsMap().entrySet().stream() - .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); - - // In case a reference is multi-target, there will be a separate - // "reference property" for each of the targets, so instead of - // `collect` we need to `reduce` the map, merging related references - // as we go. - // I.e. { "ref": A-1 } , { "ref": B-1 } => { "ref": [A-1, B-1] } - var referenceProperties = propertiesResult.getRefPropsList() - .stream().reduce( - new HashMap>(), - (map, ref) -> { - var refObjects = ref.getPropertiesList().stream() - .map(property -> { - var reference = unmarshalWithReferences( - property, property.getMetadata(), - // TODO: this should be possible to configure for ODM? - CollectionDescriptor.ofMap(property.getTargetCollection())); - return (Object) new WeaviateObject<>( - reference.collection(), - (Object) reference.properties(), - reference.references(), - reference.metadata()); - }) - .toList(); - - // Merge ObjectReferences by joining the underlying WeaviateObjects. - map.merge( - ref.getPropName(), - refObjects, - (left, right) -> { - var joined = Stream.concat( - left.stream(), - right.stream()).toList(); - return joined; - }); - return map; - }, - (left, right) -> { - left.putAll(right); - return left; - }); - - ObjectMetadata metadata = null; - if (metadataResult != null) { - var metadataBuilder = new ObjectMetadata.Builder() - .uuid(metadataResult.getId()); - var vectors = new Vectors[metadataResult.getVectorsList().size()]; - var i = 0; - for (final var vector : metadataResult.getVectorsList()) { - var vectorName = vector.getName(); - var vbytes = vector.getVectorBytes(); - switch (vector.getType()) { - case VECTOR_TYPE_SINGLE_FP32: - vectors[i++] = Vectors.of(vectorName, ByteStringUtil.decodeVectorSingle(vbytes)); - break; - case VECTOR_TYPE_MULTI_FP32: - vectors[i++] = Vectors.of(vectorName, ByteStringUtil.decodeVectorMulti(vbytes)); - break; - default: - continue; - } - } - metadataBuilder.vectors(vectors); - metadata = metadataBuilder.build(); + if (request.groupBy != null) { + request.groupBy.appendTo(message); } - var obj = new WeaviateObject.Builder() - .collection(descriptor.collectionName()) - .properties(properties.build()) - .references(referenceProperties) - .metadata(metadata); - return obj.build(); + return message.build(); } - private static void setProperty(String property, WeaviateProtoProperties.Value value, - PropertiesBuilder builder, CollectionDescriptor descriptor) { - if (value.hasNullValue()) { - builder.setNull(property); - } else if (value.hasTextValue()) { - builder.setText(property, value.getTextValue()); - } else if (value.hasBoolValue()) { - builder.setBoolean(property, value.getBoolValue()); - } else if (value.hasIntValue()) { - builder.setLong(property, value.getIntValue()); - } else if (value.hasNumberValue()) { - builder.setDouble(property, value.getNumberValue()); - } else if (value.hasBlobValue()) { - builder.setBlob(property, value.getBlobValue()); - } else if (value.hasDateValue()) { - builder.setOffsetDateTime(property, DateUtil.fromISO8601(value.getDateValue())); - } else if (value.hasUuidValue()) { - builder.setUuid(property, UUID.fromString(value.getUuidValue())); - } else if (value.hasListValue()) { - var list = value.getListValue(); - if (list.hasTextValues()) { - builder.setTextArray(property, list.getTextValues().getValuesList()); - } else if (list.hasIntValues()) { - var ints = Arrays.stream( - ByteStringUtil.decodeIntValues(list.getIntValues().getValues())) - .boxed().toList(); - builder.setLongArray(property, ints); - } else if (list.hasNumberValues()) { - var numbers = Arrays.stream( - ByteStringUtil.decodeNumberValues(list.getNumberValues().getValues())) - .boxed().toList(); - builder.setDoubleArray(property, numbers); - } else if (list.hasUuidValues()) { - var uuids = list.getUuidValues().getValuesList().stream() - .map(UUID::fromString).toList(); - builder.setUuidArray(property, uuids); - } else if (list.hasBoolValues()) { - builder.setBooleanArray(property, list.getBoolValues().getValuesList()); - } else if (list.hasDateValues()) { - var dates = list.getDateValues().getValuesList().stream() - .map(DateUtil::fromISO8601).toList(); - builder.setOffsetDateTimeArray(property, dates); - } else if (list.hasObjectValues()) { - List objects = list.getObjectValues().getValuesList().stream() - .map(object -> { - var properties = descriptor.propertiesBuilder(); - object.getFieldsMap().entrySet().stream() - .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); - return properties.build(); - }).toList(); - builder.setNestedObjectArray(property, objects); - } - } else if (value.hasObjectValue()) { - var object = value.getObjectValue(); - var properties = descriptor.propertiesBuilder(); - object.getFieldsMap().entrySet().stream() - .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); - builder.setNestedObject(property, properties.build()); - } else { - throw new IllegalArgumentException(property + " data type is not supported"); - } + static Rpc, WeaviateProtoSearchGet.SearchReply> grouped( + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var rpc = rpc(collection, defaults); + return Rpc.of( + request -> rpc.marshal(request), + reply -> QueryResponseGrouped.unmarshal(reply, collection, defaults), + () -> rpc.method(), () -> rpc.methodAsync()); } + } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java index b5d465f8c..1160b0255 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java @@ -1,9 +1,206 @@ package io.weaviate.client6.v1.api.collections.query; +import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; +import io.weaviate.client6.v1.api.collections.ObjectMetadata; +import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.internal.DateUtil; +import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoProperties; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.orm.PropertiesBuilder; -public record QueryResponse( - List> objects) { +public record QueryResponse( + List> objects) { + + static QueryResponse unmarshal(WeaviateProtoSearchGet.SearchReply reply, + CollectionDescriptor collection) { + var objects = reply + .getResultsList() + .stream() + .map(obj -> QueryResponse.unmarshalResultObject( + obj.getProperties(), obj.getMetadata(), collection)) + .toList(); + return new QueryResponse<>(objects); + } + + public static WeaviateObject unmarshalResultObject( + WeaviateProtoSearchGet.PropertiesResult propertiesResult, + WeaviateProtoSearchGet.MetadataResult metadataResult, + CollectionDescriptor collection) { + var object = unmarshalWithReferences(propertiesResult, metadataResult, collection); + var metadata = new QueryMetadata.Builder() + .uuid(object.metadata().uuid()) + .vectors(object.metadata().vectors()); + + if (metadataResult.getCreationTimeUnixPresent()) { + metadata.creationTimeUnix(metadataResult.getCreationTimeUnix()); + } + if (metadataResult.getLastUpdateTimeUnixPresent()) { + metadata.lastUpdateTimeUnix(metadataResult.getLastUpdateTimeUnix()); + } + if (metadataResult.getDistancePresent()) { + metadata.distance(metadataResult.getDistance()); + } + if (metadataResult.getCertaintyPresent()) { + metadata.certainty(metadataResult.getCertainty()); + } + if (metadataResult.getScorePresent()) { + metadata.score(metadataResult.getScore()); + } + if (metadataResult.getExplainScorePresent()) { + metadata.explainScore(metadataResult.getExplainScore()); + } + return new WeaviateObject<>(collection.collectionName(), object.properties(), object.references(), + metadata.build()); + } + + static WeaviateObject unmarshalWithReferences( + WeaviateProtoSearchGet.PropertiesResult propertiesResult, + WeaviateProtoSearchGet.MetadataResult metadataResult, + CollectionDescriptor descriptor) { + var properties = descriptor.propertiesBuilder(); + propertiesResult.getNonRefProps().getFieldsMap().entrySet().stream() + .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); + + // In case a reference is multi-target, there will be a separate + // "reference property" for each of the targets, so instead of + // `collect` we need to `reduce` the map, merging related references + // as we go. + // I.e. { "ref": A-1 } , { "ref": B-1 } => { "ref": [A-1, B-1] } + var referenceProperties = propertiesResult.getRefPropsList() + .stream().reduce( + new HashMap>(), + (map, ref) -> { + var refObjects = ref.getPropertiesList().stream() + .map(property -> { + var reference = unmarshalWithReferences( + property, property.getMetadata(), + CollectionDescriptor.ofMap(property.getTargetCollection())); + return (Object) new WeaviateObject<>( + reference.collection(), + (Object) reference.properties(), + reference.references(), + reference.metadata()); + }) + .toList(); + + // Merge ObjectReferences by joining the underlying WeaviateObjects. + map.merge( + ref.getPropName(), + refObjects, + (left, right) -> { + var joined = Stream.concat( + left.stream(), + right.stream()).toList(); + return joined; + }); + return map; + }, + (left, right) -> { + left.putAll(right); + return left; + }); + + ObjectMetadata metadata = null; + if (metadataResult != null) { + var metadataBuilder = new ObjectMetadata.Builder() + .uuid(metadataResult.getId()); + + var vectors = new Vectors[metadataResult.getVectorsList().size()]; + var i = 0; + for (final var vector : metadataResult.getVectorsList()) { + var vectorName = vector.getName(); + var vbytes = vector.getVectorBytes(); + switch (vector.getType()) { + case VECTOR_TYPE_SINGLE_FP32: + vectors[i++] = Vectors.of(vectorName, ByteStringUtil.decodeVectorSingle(vbytes)); + break; + case VECTOR_TYPE_MULTI_FP32: + vectors[i++] = Vectors.of(vectorName, ByteStringUtil.decodeVectorMulti(vbytes)); + break; + default: + continue; + } + } + metadataBuilder.vectors(vectors); + metadata = metadataBuilder.build(); + } + + var obj = new WeaviateObject.Builder() + .collection(descriptor.collectionName()) + .properties(properties.build()) + .references(referenceProperties) + .metadata(metadata); + return obj.build(); + } + + static void setProperty(String property, WeaviateProtoProperties.Value value, + PropertiesBuilder builder, CollectionDescriptor descriptor) { + if (value.hasNullValue()) { + builder.setNull(property); + } else if (value.hasTextValue()) { + builder.setText(property, value.getTextValue()); + } else if (value.hasBoolValue()) { + builder.setBoolean(property, value.getBoolValue()); + } else if (value.hasIntValue()) { + builder.setLong(property, value.getIntValue()); + } else if (value.hasNumberValue()) { + builder.setDouble(property, value.getNumberValue()); + } else if (value.hasBlobValue()) { + builder.setBlob(property, value.getBlobValue()); + } else if (value.hasDateValue()) { + builder.setOffsetDateTime(property, DateUtil.fromISO8601(value.getDateValue())); + } else if (value.hasUuidValue()) { + builder.setUuid(property, UUID.fromString(value.getUuidValue())); + } else if (value.hasListValue()) { + var list = value.getListValue(); + if (list.hasTextValues()) { + builder.setTextArray(property, list.getTextValues().getValuesList()); + } else if (list.hasIntValues()) { + var ints = Arrays.stream( + ByteStringUtil.decodeIntValues(list.getIntValues().getValues())) + .boxed().toList(); + builder.setLongArray(property, ints); + } else if (list.hasNumberValues()) { + var numbers = Arrays.stream( + ByteStringUtil.decodeNumberValues(list.getNumberValues().getValues())) + .boxed().toList(); + builder.setDoubleArray(property, numbers); + } else if (list.hasUuidValues()) { + var uuids = list.getUuidValues().getValuesList().stream() + .map(UUID::fromString).toList(); + builder.setUuidArray(property, uuids); + } else if (list.hasBoolValues()) { + builder.setBooleanArray(property, list.getBoolValues().getValuesList()); + } else if (list.hasDateValues()) { + var dates = list.getDateValues().getValuesList().stream() + .map(DateUtil::fromISO8601).toList(); + builder.setOffsetDateTimeArray(property, dates); + } else if (list.hasObjectValues()) { + List objects = list.getObjectValues().getValuesList().stream() + .map(object -> { + var properties = descriptor.propertiesBuilder(); + object.getFieldsMap().entrySet().stream() + .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); + return properties.build(); + }).toList(); + builder.setNestedObjectArray(property, objects); + } + } else if (value.hasObjectValue()) { + var object = value.getObjectValue(); + var properties = descriptor.propertiesBuilder(); + object.getFieldsMap().entrySet().stream() + .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); + builder.setNestedObject(property, properties.build()); + } else { + throw new IllegalArgumentException(property + " data type is not supported"); + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponseGrouped.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponseGrouped.java index 9ee8442fa..587be10ac 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponseGrouped.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponseGrouped.java @@ -1,11 +1,52 @@ package io.weaviate.client6.v1.api.collections.query; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; -public record QueryResponseGrouped( +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record QueryResponseGrouped( /** All objects retrieved in the query. */ - List> objects, + List> objects, /** Grouped response objects. */ - Map> groups) { + Map> groups) { + + static QueryResponseGrouped unmarshal( + WeaviateProtoSearchGet.SearchReply reply, + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var allObjects = new ArrayList>(); + var groups = reply.getGroupByResultsList() + .stream().map(group -> { + var name = group.getName(); + List> objects = group.getObjectsList().stream() + .map(obj -> QueryResponse.unmarshalResultObject( + obj.getProperties(), + obj.getMetadata(), + collection)) + .map(obj -> new QueryObjectGrouped<>(obj, name)) + .toList(); + + allObjects.addAll(objects); + return new QueryResponseGroup<>( + name, + group.getMinDistance(), + group.getMaxDistance(), + group.getNumberOfObjects(), + objects); + }) + // Collectors.toMap() throws an NPE if either key or value in the map are null. + // In this specific case it is safe to use it, as the function in the map above + // always returns a QueryResponseGroup. + // The name of the group should not be null either, that's something we assume + // about the server's response. + .collect(Collectors.toMap(QueryResponseGroup::name, Function.identity())); + + return new QueryResponseGrouped(allObjects, groups); + } } From 7f61a620184a159001a5625125b6a9d88d462f07 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 11:39:03 +0200 Subject: [PATCH 02/25] feat: add generative w/ fetchObjects --- .../generate/AbstractGenerateClient.java | 68 ++++++++++++++++++- .../generate/WeaviateGenerateClient.java | 6 +- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java index e318b323a..6364d0aa4 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -5,6 +5,7 @@ import io.weaviate.client6.v1.api.WeaviateApiException; import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; import io.weaviate.client6.v1.api.collections.query.Bm25; +import io.weaviate.client6.v1.api.collections.query.FetchObjects; import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.QueryOperator; import io.weaviate.client6.v1.api.collections.query.QueryResponseGrouped; @@ -12,7 +13,7 @@ import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; -abstract class AbstractGenerateClient { +abstract class AbstractGenerateClient { protected final CollectionDescriptor collection; protected final GrpcTransport grpcTransport; protected final CollectionHandleDefaults defaults; @@ -26,7 +27,7 @@ abstract class AbstractGenerateClient c, + AbstractGenerateClient c, CollectionHandleDefaults defaults) { this(c.collection, c.grpcTransport, defaults); } @@ -35,6 +36,69 @@ abstract class AbstractGenerateClient> fn, + Function> generateFn) { + return fetchObjects(FetchObjects.of(fn), GenerativeTask.of(generateFn)); + } + + /** + * Retrieve objects without applying a Vector Search or Keyword Search filter. + * + * @param query FetchObjects query. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT fetchObjects(FetchObjects query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Retrieve objects without applying a Vector Search or Keyword Search filter. + * + * + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT fetchObjects(Function> fn, + Function> generateFn, + GroupBy groupBy) { + return fetchObjects(FetchObjects.of(fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Retrieve objects without applying a Vector Search or Keyword Search filter. + * + * @param query FetchObjects query. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT fetchObjects(FetchObjects query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } // BM25 queries ------------------------------------------------------------- /** diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java index eef92cbfb..391007b5a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java @@ -1,18 +1,14 @@ package io.weaviate.client6.v1.api.collections.generate; -import java.util.Optional; - import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; -import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.query.GroupBy; -import io.weaviate.client6.v1.api.collections.query.QueryMetadata; import io.weaviate.client6.v1.api.collections.query.QueryOperator; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; public class WeaviateGenerateClient extends - AbstractGenerateClient>, GenerativeResponse, GenerativeResponseGrouped> { + AbstractGenerateClient, GenerativeResponseGrouped> { public WeaviateGenerateClient( CollectionDescriptor collection, From c6d08651105ece25d7842dbbb9c6c19c92f01132 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 11:46:52 +0200 Subject: [PATCH 03/25] feat: add generative w/ hybrid query --- .../generate/AbstractGenerateClient.java | 175 ++++++++++++++++++ 1 file changed, 175 insertions(+) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java index 6364d0aa4..cef44b400 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -7,8 +7,10 @@ import io.weaviate.client6.v1.api.collections.query.Bm25; import io.weaviate.client6.v1.api.collections.query.FetchObjects; import io.weaviate.client6.v1.api.collections.query.GroupBy; +import io.weaviate.client6.v1.api.collections.query.Hybrid; import io.weaviate.client6.v1.api.collections.query.QueryOperator; import io.weaviate.client6.v1.api.collections.query.QueryResponseGrouped; +import io.weaviate.client6.v1.api.collections.query.Target; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; @@ -198,4 +200,177 @@ public GroupedResponseT bm25(String query, public GroupedResponseT bm25(Bm25 query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); } + + // Hybrid queries ----------------------------------------------------------- + + /** + * Query collection objects using hybrid search. + * + * @param query Query string. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid(String query, + Function> generateFn) { + return hybrid(Hybrid.of(query), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using hybrid search. + * + * @param query Query string. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid( + String query, + Function> fn, + Function> generateFn) { + return hybrid(Hybrid.of(query, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using hybrid search. + * + * @param searchTarget Query target. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid( + Target searchTarget, + Function> generateFn) { + return hybrid(Hybrid.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using hybrid search. + * + * @param searchTarget Query target. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid( + Target searchTarget, + Function> fn, + Function> generateFn) { + return hybrid(Hybrid.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using hybrid search. + * + * @param query Hybrid query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid(Hybrid query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using hybrid search. + * + * @param query Query string. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT hybrid( + String query, + Function> generateFn, + GroupBy groupBy) { + return hybrid(Hybrid.of(query), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using hybrid search. + * + * @param query Query string. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT hybrid( + String query, + Function> generateFn, + Function> fn, GroupBy groupBy) { + return hybrid(Hybrid.of(query, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using hybrid search. + * + * @param searchTarget Query target. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT hybrid( + Target searchTarget, + Function> generateFn, + GroupBy groupBy) { + return hybrid(Hybrid.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using hybrid search. + * + * @param searchTarget Query target. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT hybrid( + Target searchTarget, + Function> generateFn, + Function> fn, + GroupBy groupBy) { + return hybrid(Hybrid.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using hybrid search. + * + * @param query Query string. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT hybrid(Hybrid query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } } From 9707ddf09aabe79cd8dc82814b6b689a3ad2438e Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 12:00:02 +0200 Subject: [PATCH 04/25] feat: add generative w/ nearVector query --- .../generate/AbstractGenerateClient.java | 179 +++++++++++++++++- .../query/AbstractQueryClient.java | 8 +- 2 files changed, 178 insertions(+), 9 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java index cef44b400..5d48c25fb 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -8,6 +8,8 @@ import io.weaviate.client6.v1.api.collections.query.FetchObjects; import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.Hybrid; +import io.weaviate.client6.v1.api.collections.query.NearVector; +import io.weaviate.client6.v1.api.collections.query.NearVectorTarget; import io.weaviate.client6.v1.api.collections.query.QueryOperator; import io.weaviate.client6.v1.api.collections.query.QueryResponseGrouped; import io.weaviate.client6.v1.api.collections.query.Target; @@ -167,7 +169,7 @@ public GroupedResponseT bm25(String query, * Query collection objects using keyword (BM25) search. * * @param query Query string. - * @param fn Lambda expression for optional parameters. + * @param fn Lambda expression for optional search parameters. * @param generateFn Lambda expression for generative task parameters. * @param groupBy Group-by clause. * @return Grouped query result. @@ -220,7 +222,7 @@ public ResponseT hybrid(String query, * Query collection objects using hybrid search. * * @param query Query string. - * @param fn Lambda expression for optional parameters. + * @param fn Lambda expression for optional search parameters. * @param generateFn Lambda expression for generative task parameters. * @throws WeaviateApiException in case the server returned with an * error status code. @@ -250,7 +252,7 @@ public ResponseT hybrid( * Query collection objects using hybrid search. * * @param searchTarget Query target. - * @param fn Lambda expression for optional parameters. + * @param fn Lambda expression for optional search parameters. * @param generateFn Lambda expression for generative task parameters. * @throws WeaviateApiException in case the server returned with an * error status code. @@ -298,7 +300,7 @@ public GroupedResponseT hybrid( * Query collection objects using hybrid search. * * @param query Query string. - * @param fn Lambda expression for optional parameters. + * @param fn Lambda expression for optional search parameters. * @param generateFn Lambda expression for generative task parameters. * @param groupBy Group-by clause. * @return Grouped query result. @@ -339,7 +341,7 @@ public GroupedResponseT hybrid( * Query collection objects using hybrid search. * * @param searchTarget Query target. - * @param fn Lambda expression for optional parameters. + * @param fn Lambda expression for optional search parameters. * @param generateFn Lambda expression for generative task parameters. * @param groupBy Group-by clause. * @return Grouped query result. @@ -373,4 +375,171 @@ public GroupedResponseT hybrid( public GroupedResponseT hybrid(Hybrid query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); } + + // NearVector queries ------------------------------------------------------- + + /** + * Query collection objects using near vector search. + * + * @param vector Query vector. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(float[] vector, + Function> generateFn) { + return nearVector(Target.vector(vector), generateFn); + } + + /** + * Query collection objects using near vector search. + * + * @param vector Query vector. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(float[] vector, + Function> fn, + Function> generateFn) { + return nearVector(Target.vector(vector), fn, generateFn); + } + + /** + * Query collection objects using near vector search. + * + * @param searchTarget Target query vectors. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(NearVectorTarget searchTarget, + Function> generateFn) { + return nearVector(NearVector.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near vector search. + * + * @param searchTarget Target query vectors. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(NearVectorTarget searchTarget, + Function> fn, + Function> generateFn) { + return nearVector(NearVector.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near vector search. + * + * @param query Near vector query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(NearVector query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near vector search. + * + * @param vector Query vector. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearVector(float[] vector, + Function> generateFn, + GroupBy groupBy) { + return nearVector(Target.vector(vector), generateFn, groupBy); + } + + /** + * Query collection objects using near vector search. + * + * @param vector Query vector. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearVector(float[] vector, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearVector(Target.vector(vector), fn, generateFn, groupBy); + } + + /** + * Query collection objects using near vector search. + * + * @param searchTarget Target query vectors. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearVector( + NearVectorTarget searchTarget, + Function> generateFn, + GroupBy groupBy) { + return nearVector(NearVector.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near vector search. + * + * @param searchTarget Target query vectors. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearVector(NearVectorTarget searchTarget, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearVector(NearVector.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near vector search. + * + * @param query Near vector query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearVector(NearVector query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java index 8774ef508..db684e54f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java @@ -398,7 +398,7 @@ public ResponseT nearVector(float[] vector, Function> fn) { - return performRequest(NearVector.of(searchTarget, fn)); + return nearVector(NearVector.of(searchTarget, fn)); } /** @@ -473,7 +473,7 @@ public GroupedResponseT nearVector(float[] vector, Function> fn, GroupBy groupBy) { - return performRequest(NearVector.of(searchTarget, fn), groupBy); + return nearVector(NearVector.of(searchTarget, fn), groupBy); } /** From 82111dd8992d873dcc3a2a16118339b86666d6de Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 12:08:43 +0200 Subject: [PATCH 05/25] feat: add generative w/ nearObject query --- .../generate/AbstractGenerateClient.java | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java index 5d48c25fb..b00127a0d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -8,6 +8,7 @@ import io.weaviate.client6.v1.api.collections.query.FetchObjects; import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.Hybrid; +import io.weaviate.client6.v1.api.collections.query.NearObject; import io.weaviate.client6.v1.api.collections.query.NearVector; import io.weaviate.client6.v1.api.collections.query.NearVectorTarget; import io.weaviate.client6.v1.api.collections.query.QueryOperator; @@ -542,4 +543,103 @@ public GroupedResponseT nearVector(NearVectorTarget searchTarget, public GroupedResponseT nearVector(NearVector query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); } + + // NearObject queries ------------------------------------------------------- + + /** + * Query collection objects using near object search. + * + * @param uuid Query object UUID. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearObject(String uuid, + Function> generateFn) { + return nearObject(NearObject.of(uuid), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near object search. + * + * @param uuid Query object UUID. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearObject(String uuid, + Function> fn, + Function> generateFn) { + return nearObject(NearObject.of(uuid, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near object search. + * + * @param query Near object query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearObject(NearObject query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near object search. + * + * @param uuid Query object UUID. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearObject(String uuid, + Function> generateFn, + GroupBy groupBy) { + return nearObject(NearObject.of(uuid), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near object search. + * + * @param uuid Query object UUID. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearObject(String uuid, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearObject(NearObject.of(uuid, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near object search. + * + * @param query Near object query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearObject(NearObject query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } } From e9bb5c5d3899f6e304b61610208aa38292f193e6 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 12:19:31 +0200 Subject: [PATCH 06/25] feat: add generative w/ nearText query --- .../generate/AbstractGenerateClient.java | 209 ++++++++++++++++++ .../query/AbstractQueryClient.java | 2 +- 2 files changed, 210 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java index b00127a0d..ff0201615 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -1,5 +1,6 @@ package io.weaviate.client6.v1.api.collections.generate; +import java.util.List; import java.util.function.Function; import io.weaviate.client6.v1.api.WeaviateApiException; @@ -9,6 +10,7 @@ import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.Hybrid; import io.weaviate.client6.v1.api.collections.query.NearObject; +import io.weaviate.client6.v1.api.collections.query.NearText; import io.weaviate.client6.v1.api.collections.query.NearVector; import io.weaviate.client6.v1.api.collections.query.NearVectorTarget; import io.weaviate.client6.v1.api.collections.query.QueryOperator; @@ -642,4 +644,211 @@ public GroupedResponseT nearObject(String uuid, public GroupedResponseT nearObject(NearObject query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); } + + // NearText queries --------------------------------------------------------- + + /** + * Query collection objects using near text search. + * + * @param text Query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(String text, + Function> fn, + Function> generateFn) { + return nearText(Target.text(List.of(text)), fn, generateFn); + } + + /** + * Query collection objects using near text search. + * + * @param text Query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(List text, + Function> fn, + Function> generateFn) { + return nearText(Target.text(text), fn, generateFn); + } + + /** + * Query collection objects using near text search. + * + * @param searchTarget Target query concepts. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(Target searchTarget, + Function> generateFn) { + return nearText(NearText.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near text search. + * + * @param searchTarget Target query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(Target searchTarget, + Function> fn, + Function> generateFn) { + return nearText(NearText.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near text search. + * + * @param query Near text query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(NearText query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near text search. + * + * @param text Query concepts. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearText(String text, + Function> generateFn, + GroupBy groupBy) { + return nearText(Target.text(List.of(text)), generateFn, groupBy); + } + + /** + * Query collection objects using near text search. + * + * @param text Query concepts. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearText(List text, + Function> generateFn, GroupBy groupBy) { + return nearText(Target.text(text), generateFn, groupBy); + } + + /** + * Query collection objects using near text search. + * + * @param text Query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearText(String text, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearText(Target.text(List.of(text)), fn, generateFn, groupBy); + } + + /** + * Query collection objects using near text search. + * + * @param text Query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearText(List text, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearText(Target.text(text), fn, generateFn, groupBy); + } + + /** + * Query collection objects using near text search. + * + * @param searchTarget Target query concepts. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearText(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearText(NearText.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near text search. + * + * @param searchTarget Target query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearText(Target searchTarget, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearText(NearText.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near text search. + * + * @param query Near text query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see QueryResponseGrouped + */ + public GroupedResponseT nearText(NearText query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java index db684e54f..887bc53fc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java @@ -734,7 +734,7 @@ public GroupedResponseT nearText(String text, public GroupedResponseT nearText(List text, Function> fn, GroupBy groupBy) { - return nearText(Target.text(text), groupBy); + return nearText(Target.text(text), fn, groupBy); } /** From c7219b8723521b651ca03e043266c1536e8f8c7f Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 12:43:55 +0200 Subject: [PATCH 07/25] feat: add generative w/ nearImage + nearAudio queries --- .../generate/AbstractGenerateClient.java | 375 ++++++++++++++++-- 1 file changed, 349 insertions(+), 26 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java index ff0201615..d85790573 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -9,12 +9,13 @@ import io.weaviate.client6.v1.api.collections.query.FetchObjects; import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.Hybrid; +import io.weaviate.client6.v1.api.collections.query.NearAudio; +import io.weaviate.client6.v1.api.collections.query.NearImage; import io.weaviate.client6.v1.api.collections.query.NearObject; import io.weaviate.client6.v1.api.collections.query.NearText; import io.weaviate.client6.v1.api.collections.query.NearVector; import io.weaviate.client6.v1.api.collections.query.NearVectorTarget; import io.weaviate.client6.v1.api.collections.query.QueryOperator; -import io.weaviate.client6.v1.api.collections.query.QueryResponseGrouped; import io.weaviate.client6.v1.api.collections.query.Target; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.GrpcTransport; @@ -82,7 +83,7 @@ public ResponseT fetchObjects(FetchObjects query, GenerativeTask generate) { * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT fetchObjects(Function> fn, Function> generateFn, @@ -101,7 +102,7 @@ public GroupedResponseT fetchObjects(Function> generateFn, @@ -180,7 +181,7 @@ public GroupedResponseT bm25(String query, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT bm25(String query, Function> fn, @@ -200,7 +201,7 @@ public GroupedResponseT bm25(String query, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT bm25(Bm25 query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); @@ -290,7 +291,7 @@ public ResponseT hybrid(Hybrid query, GenerativeTask generate) { * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT hybrid( String query, @@ -311,7 +312,7 @@ public GroupedResponseT hybrid( * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT hybrid( String query, @@ -331,7 +332,7 @@ public GroupedResponseT hybrid( * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT hybrid( Target searchTarget, @@ -352,7 +353,7 @@ public GroupedResponseT hybrid( * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT hybrid( Target searchTarget, @@ -373,7 +374,7 @@ public GroupedResponseT hybrid( * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT hybrid(Hybrid query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); @@ -460,7 +461,7 @@ public ResponseT nearVector(NearVector query, GenerativeTask generate) { * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearVector(float[] vector, Function> generateFn, @@ -480,7 +481,7 @@ public GroupedResponseT nearVector(float[] vector, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearVector(float[] vector, Function> fn, @@ -500,7 +501,7 @@ public GroupedResponseT nearVector(float[] vector, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearVector( NearVectorTarget searchTarget, @@ -520,7 +521,7 @@ public GroupedResponseT nearVector( * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearVector(NearVectorTarget searchTarget, Function> fn, @@ -540,7 +541,7 @@ public GroupedResponseT nearVector(NearVectorTarget searchTarget, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearVector(NearVector query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); @@ -599,7 +600,7 @@ public ResponseT nearObject(NearObject query, GenerativeTask generate) { * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearObject(String uuid, Function> generateFn, @@ -619,7 +620,7 @@ public GroupedResponseT nearObject(String uuid, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearObject(String uuid, Function> fn, @@ -639,7 +640,7 @@ public GroupedResponseT nearObject(String uuid, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearObject(NearObject query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); @@ -728,7 +729,7 @@ public ResponseT nearText(NearText query, GenerativeTask generate) { * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearText(String text, Function> generateFn, @@ -747,7 +748,7 @@ public GroupedResponseT nearText(String text, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearText(List text, Function> generateFn, GroupBy groupBy) { @@ -766,7 +767,7 @@ public GroupedResponseT nearText(List text, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearText(String text, Function> fn, @@ -787,7 +788,7 @@ public GroupedResponseT nearText(String text, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearText(List text, Function> fn, @@ -807,7 +808,7 @@ public GroupedResponseT nearText(List text, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearText(Target searchTarget, Function> generateFn, GroupBy groupBy) { @@ -826,7 +827,7 @@ public GroupedResponseT nearText(Target searchTarget, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearText(Target searchTarget, Function> fn, @@ -846,9 +847,331 @@ public GroupedResponseT nearText(Target searchTarget, * error status code. * * @see GroupBy - * @see QueryResponseGrouped + * @see GenerativeResponseGrouped */ public GroupedResponseT nearText(NearText query, GenerativeTask generate, GroupBy groupBy) { return performRequest(query, generate, groupBy); } + + // NearImage queries -------------------------------------------------------- + + /** + * Query collection objects using near image search. + * + * @param image Query image (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(String image, + Function> generateFn) { + return nearImage(NearImage.of(image), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near image search. + * + * @param image Query image (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(String image, Function> fn, + Function> generateFn) { + return nearImage(NearImage.of(image, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near image search. + * + * @param searchTarget Query target (base64-encoded image). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(Target searchTarget, + Function> generateFn) { + return nearImage(NearImage.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near image search. + * + * @param searchTarget Query target (base64-encoded image). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(Target searchTarget, Function> fn, + Function> generateFn) { + return nearImage(NearImage.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near image search. + * + * @param query Near image query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(NearImage query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near image search. + * + * @param image Query image (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearImage(String image, + Function> generateFn, GroupBy groupBy) { + return nearImage(NearImage.of(image), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near image search. + * + * @param searchTarget Query target (base64-encoded image). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImage(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearImage(NearImage.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near image search. + * + * @param searchTarget Query target (base64-encoded image). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearImage(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearImage(NearImage.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near image search. + * + * @param image Query image (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImage(String image, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearImage(NearImage.of(image, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near image search. + * + * @param query Near image query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImage(NearImage query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearAudio queries -------------------------------------------------------- + + /** + * Query collection objects using near audio search. + * + * @param audio Query audio (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(String audio, + Function> generateFn) { + return nearAudio(NearAudio.of(audio), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near audio search. + * + * @param audio Query audio (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(String audio, Function> fn, + Function> generateFn) { + return nearAudio(NearAudio.of(audio, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near audio search. + * + * @param searchTarget Query target (base64-encoded audio). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(Target searchTarget, + Function> generateFn) { + return nearAudio(NearAudio.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near audio search. + * + * @param searchTarget Query target (base64-encoded audio). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(Target searchTarget, Function> fn, + Function> generateFn) { + return nearAudio(NearAudio.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near audio search. + * + * @param query Near audio query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(NearAudio query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near audio search. + * + * @param audio Query audio (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearAudio(String audio, + Function> generateFn, GroupBy groupBy) { + return nearAudio(NearAudio.of(audio), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near audio search. + * + * @param searchTarget Query target (base64-encoded audio). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearAudio(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearAudio(NearAudio.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near audio search. + * + * @param searchTarget Query target (base64-encoded audio). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearAudio(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearAudio(NearAudio.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near audio search. + * + * @param audio Query audio (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearAudio(String audio, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearAudio(NearAudio.of(audio, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near audio search. + * + * @param query Near audio query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearAudio(NearAudio query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } } From dde9e11824237cf803279d89c0447b4cd6449eba Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 12:53:11 +0200 Subject: [PATCH 08/25] feat: add generative w/ nearVideo/Thermal/Depth/Imu queries --- .../generate/AbstractGenerateClient.java | 648 ++++++++++++++++++ 1 file changed, 648 insertions(+) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java index d85790573..cde5101bb 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -10,11 +10,15 @@ import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.Hybrid; import io.weaviate.client6.v1.api.collections.query.NearAudio; +import io.weaviate.client6.v1.api.collections.query.NearDepth; import io.weaviate.client6.v1.api.collections.query.NearImage; +import io.weaviate.client6.v1.api.collections.query.NearImu; import io.weaviate.client6.v1.api.collections.query.NearObject; import io.weaviate.client6.v1.api.collections.query.NearText; +import io.weaviate.client6.v1.api.collections.query.NearThermal; import io.weaviate.client6.v1.api.collections.query.NearVector; import io.weaviate.client6.v1.api.collections.query.NearVectorTarget; +import io.weaviate.client6.v1.api.collections.query.NearVideo; import io.weaviate.client6.v1.api.collections.query.QueryOperator; import io.weaviate.client6.v1.api.collections.query.Target; import io.weaviate.client6.v1.internal.ObjectBuilder; @@ -1174,4 +1178,648 @@ public GroupedResponseT nearAudio(String audio, Function> generateFn) { + return nearVideo(NearVideo.of(video), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near video search. + * + * @param video Query video (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(String video, Function> fn, + Function> generateFn) { + return nearVideo(NearVideo.of(video, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near video search. + * + * @param searchTarget Query target (base64-encoded video). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(Target searchTarget, + Function> generateFn) { + return nearVideo(NearVideo.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near video search. + * + * @param searchTarget Query target (base64-encoded video). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(Target searchTarget, Function> fn, + Function> generateFn) { + return nearVideo(NearVideo.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near video search. + * + * @param query Near video query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(NearVideo query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near video search. + * + * @param video Query video (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearVideo(String video, + Function> generateFn, GroupBy groupBy) { + return nearVideo(NearVideo.of(video), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near video search. + * + * @param searchTarget Query target (base64-encoded video). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVideo(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearVideo(NearVideo.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near video search. + * + * @param searchTarget Query target (base64-encoded video). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearVideo(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearVideo(NearVideo.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near video search. + * + * @param video Query video (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVideo(String video, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearVideo(NearVideo.of(video, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near video search. + * + * @param query Near video query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVideo(NearVideo query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearThermal queries -------------------------------------------------------- + + /** + * Query collection objects using near thermal search. + * + * @param thermal Query thermal (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(String thermal, + Function> generateFn) { + return nearThermal(NearThermal.of(thermal), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near thermal search. + * + * @param thermal Query thermal (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(String thermal, Function> fn, + Function> generateFn) { + return nearThermal(NearThermal.of(thermal, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near thermal search. + * + * @param searchTarget Query target (base64-encoded thermal). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(Target searchTarget, + Function> generateFn) { + return nearThermal(NearThermal.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near thermal search. + * + * @param searchTarget Query target (base64-encoded thermal). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(Target searchTarget, Function> fn, + Function> generateFn) { + return nearThermal(NearThermal.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near thermal search. + * + * @param query Near thermal query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(NearThermal query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near thermal search. + * + * @param thermal Query thermal (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearThermal(String thermal, + Function> generateFn, GroupBy groupBy) { + return nearThermal(NearThermal.of(thermal), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near thermal search. + * + * @param searchTarget Query target (base64-encoded thermal). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearThermal(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearThermal(NearThermal.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near thermal search. + * + * @param searchTarget Query target (base64-encoded thermal). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearThermal(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearThermal(NearThermal.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near thermal search. + * + * @param thermal Query thermal (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearThermal(String thermal, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearThermal(NearThermal.of(thermal, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near thermal search. + * + * @param query Near thermal query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearThermal(NearThermal query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearDepth queries -------------------------------------------------------- + + /** + * Query collection objects using near depth search. + * + * @param depth Query depth (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(String depth, + Function> generateFn) { + return nearDepth(NearDepth.of(depth), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near depth search. + * + * @param depth Query depth (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(String depth, Function> fn, + Function> generateFn) { + return nearDepth(NearDepth.of(depth, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near depth search. + * + * @param searchTarget Query target (base64-encoded depth). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(Target searchTarget, + Function> generateFn) { + return nearDepth(NearDepth.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near depth search. + * + * @param searchTarget Query target (base64-encoded depth). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(Target searchTarget, Function> fn, + Function> generateFn) { + return nearDepth(NearDepth.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near depth search. + * + * @param query Near depth query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(NearDepth query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near depth search. + * + * @param depth Query depth (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearDepth(String depth, + Function> generateFn, GroupBy groupBy) { + return nearDepth(NearDepth.of(depth), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near depth search. + * + * @param searchTarget Query target (base64-encoded depth). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearDepth(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearDepth(NearDepth.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near depth search. + * + * @param searchTarget Query target (base64-encoded depth). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearDepth(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearDepth(NearDepth.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near depth search. + * + * @param depth Query depth (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearDepth(String depth, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearDepth(NearDepth.of(depth, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near depth search. + * + * @param query Near depth query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearDepth(NearDepth query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearImu queries -------------------------------------------------------- + + /** + * Query collection objects using near IMU search. + * + * @param imu Query IMU (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(String imu, + Function> generateFn) { + return nearImu(NearImu.of(imu), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near IMU search. + * + * @param imu Query IMU (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(String imu, Function> fn, + Function> generateFn) { + return nearImu(NearImu.of(imu, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near IMU search. + * + * @param searchTarget Query target (base64-encoded IMU). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(Target searchTarget, + Function> generateFn) { + return nearImu(NearImu.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near IMU search. + * + * @param searchTarget Query target (base64-encoded IMU). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(Target searchTarget, Function> fn, + Function> generateFn) { + return nearImu(NearImu.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Query collection objects using near IMU search. + * + * @param query Near IMU query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(NearImu query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Query collection objects using near IMU search. + * + * @param imu Query IMU (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearImu(String imu, + Function> generateFn, GroupBy groupBy) { + return nearImu(NearImu.of(imu), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near IMU search. + * + * @param searchTarget Query target (base64-encoded IMU). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImu(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearImu(NearImu.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near IMU search. + * + * @param searchTarget Query target (base64-encoded IMU). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearImu(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearImu(NearImu.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near IMU search. + * + * @param imu Query IMU (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImu(String imu, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearImu(NearImu.of(imu, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Query collection objects using near IMU search. + * + * @param query Near IMU query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImu(NearImu query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } } From f8f7bdf803cea5da49ffbcd4c07c8c83dd665319 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 13:21:19 +0200 Subject: [PATCH 09/25] feat: add configurations for generative modules: - Mistral - Anyscale - Databricks --- .../v1/api/collections/Generative.java | 9 +++ .../generative/AnyscaleGenerative.java | 65 +++++++++++++++ .../generative/DatabricksGenerative.java | 81 +++++++++++++++++++ .../generative/MistralGenerative.java | 73 +++++++++++++++++ 4 files changed, 228 insertions(+) create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index 74ffb2f4d..e3f72af65 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -13,14 +13,20 @@ import com.google.gson.stream.JsonToken; import com.google.gson.stream.JsonWriter; +import io.weaviate.client6.v1.api.collections.generative.AnyscaleGenerative; import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; +import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; +import io.weaviate.client6.v1.api.collections.generative.MistralGenerative; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.json.JsonEnum; public interface Generative { public enum Kind implements JsonEnum { + ANYSCALE("generative-anyscale"), COHERE("generative-cohere"), + DATABRICKS("generative-databricks"), + MISTRAL("generative-mistral"), DUMMY("generative-dummy"); private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); @@ -69,7 +75,10 @@ private final void addAdapter(Gson gson, Generative.Kind kind, Class> fn) { + return fn.apply(new Builder()).build(); + } + + public AnyscaleGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Float temperature; + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AnyscaleGenerative build() { + return new AnyscaleGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java new file mode 100644 index 000000000..d17bf6f0f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java @@ -0,0 +1,81 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record DatabricksGenerative( + @SerializedName("endpoint") String baseUrl, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("topK") Integer topK, + @SerializedName("topP") Float topP, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.DATABRICKS; + } + + @Override + public Object _self() { + return this; + } + + public static DatabricksGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static DatabricksGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public DatabricksGenerative(Builder builder) { + this( + builder.endpoint, + builder.maxTokens, + builder.topK, + builder.topP, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String endpoint; + private Integer maxTokens; + private Integer topK; + private Float topP; + private Float temperature; + + public Builder endpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public DatabricksGenerative build() { + return new DatabricksGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java new file mode 100644 index 000000000..593a4d7b9 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java @@ -0,0 +1,73 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record MistralGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.MISTRAL; + } + + @Override + public Object _self() { + return this; + } + + public static MistralGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static MistralGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public MistralGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Integer maxTokens; + private Float temperature; + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public MistralGenerative build() { + return new MistralGenerative(this); + } + } +} From d59a7c24c2486f4a882df3094ecfa51f4ae20f2e Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 15:23:50 +0200 Subject: [PATCH 10/25] feat: add generative modules - Anthropic - AWS - Azure / OpenAI - Friendliai - Google - Nvidia - Ollama - Xai --- .../v1/api/collections/Generative.java | 27 +++++ .../generative/AnthropicGenerative.java | 94 +++++++++++++++ .../generative/AnyscaleGenerative.java | 4 + .../collections/generative/AwsGenerative.java | 71 +++++++++++ .../generative/AzureOpenAiGenerative.java | 110 ++++++++++++++++++ .../generative/CohereGenerative.java | 59 ++++++---- .../generative/DatabricksGenerative.java | 25 ++-- .../generative/FriendliaiGenerative.java | 80 +++++++++++++ .../generative/GoogleGenerative.java | 106 +++++++++++++++++ .../generative/MistralGenerative.java | 7 ++ .../generative/NvidiaGenerative.java | 80 +++++++++++++ .../generative/OllamaGenerative.java | 59 ++++++++++ .../generative/OpenAiGenerative.java | 106 +++++++++++++++++ .../collections/generative/XaiGenerative.java | 80 +++++++++++++ 14 files changed, 877 insertions(+), 31 deletions(-) create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index e3f72af65..49fda595e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -13,20 +13,38 @@ import com.google.gson.stream.JsonToken; import com.google.gson.stream.JsonWriter; +import io.weaviate.client6.v1.api.collections.generative.AnthropicGenerative; import io.weaviate.client6.v1.api.collections.generative.AnyscaleGenerative; +import io.weaviate.client6.v1.api.collections.generative.AwsGenerative; +import io.weaviate.client6.v1.api.collections.generative.AzureOpenAiGenerative; import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; +import io.weaviate.client6.v1.api.collections.generative.FriendliaiGenerative; +import io.weaviate.client6.v1.api.collections.generative.GoogleGenerative; import io.weaviate.client6.v1.api.collections.generative.MistralGenerative; +import io.weaviate.client6.v1.api.collections.generative.NvidiaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OllamaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OpenAiGenerative; +import io.weaviate.client6.v1.api.collections.generative.XaiGenerative; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.json.JsonEnum; public interface Generative { public enum Kind implements JsonEnum { ANYSCALE("generative-anyscale"), + AWS("generative-aws"), + ANTHROPIC("generative-anthropic"), COHERE("generative-cohere"), DATABRICKS("generative-databricks"), + FRIENDLIAI("generative-friendliai"), + GOOGLE("generative-google"), MISTRAL("generative-mistral"), + NVIDIA("generative-nvidia"), + OLLAMA("generative-ollama"), + OPENAI("generative-openai"), + AZURE_OPENAI("generative-openai"), + XAI("generative-xai"), DUMMY("generative-dummy"); private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); @@ -76,9 +94,18 @@ private final void addAdapter(Gson gson, Generative.Kind kind, Class stopSequences) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.ANTHROPIC; + } + + @Override + public Object _self() { + return this; + } + + public static AnthropicGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static AnthropicGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public AnthropicGenerative(Builder builder) { + this( + builder.model, + builder.maxTokens, + builder.temperature, + builder.topK, + builder.stopSequences); + } + + public static class Builder implements ObjectBuilder { + private Integer topK; + private String model; + private Integer maxTokens; + private Float temperature; + private List stopSequences = new ArrayList<>(); + + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AnthropicGenerative build() { + return new AnthropicGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java index b69f18cb2..e480c3ac8 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java @@ -52,6 +52,10 @@ public Builder model(String model) { return this; } + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ public Builder temperature(float temperature) { this.temperature = temperature; return this; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java new file mode 100644 index 000000000..0f1bf16d2 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java @@ -0,0 +1,71 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record AwsGenerative( + @SerializedName("region") String region, + @SerializedName("service") String service, + @SerializedName("endpoint") String baseURL, + @SerializedName("model") String model) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.AWS; + } + + @Override + public Object _self() { + return this; + } + + public static AwsGenerative of(String region, String service) { + return of(region, service, ObjectBuilder.identity()); + } + + public static AwsGenerative of(String region, String service, Function> fn) { + return fn.apply(new Builder(region, service)).build(); + } + + public AwsGenerative(Builder builder) { + this( + builder.service, + builder.region, + builder.baseUrl, + builder.model); + } + + public static class Builder implements ObjectBuilder { + private final String region; + private final String service; + + public Builder(String service, String region) { + this.service = service; + this.region = region; + } + + private String baseUrl; + private String model; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + @Override + public AwsGenerative build() { + return new AwsGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java new file mode 100644 index 000000000..a63195ff4 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java @@ -0,0 +1,110 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record AzureOpenAiGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("frequencyPenaltyProperty") Float frequencyPenalty, + @SerializedName("presencePenaltyProperty") Float presencePenalty, + @SerializedName("maxTokensProperty") Integer maxTokens, + @SerializedName("temperatureProperty") Float temperature, + @SerializedName("topPProperty") Float topP, + + @SerializedName("resourceName") String resourceName, + @SerializedName("deploymentId") String deploymentId) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.AZURE_OPENAI; + } + + @Override + public Object _self() { + return this; + } + + public static AzureOpenAiGenerative of(String resourceName, String deploymentId) { + return of(resourceName, deploymentId, ObjectBuilder.identity()); + } + + public static AzureOpenAiGenerative of(String resourceName, String deploymentId, + Function> fn) { + return fn.apply(new Builder(resourceName, deploymentId)).build(); + } + + public AzureOpenAiGenerative(Builder builder) { + this( + builder.baseUrl, + builder.frequencyPenalty, + builder.presencePenalty, + builder.maxTokens, + builder.temperature, + builder.topP, + builder.resourceName, + builder.deploymentId); + } + + public static class Builder implements ObjectBuilder { + private final String resourceName; + private final String deploymentId; + + private String baseUrl; + private Float frequencyPenalty; + private Float presencePenalty; + private Integer maxTokens; + private Float temperature; + private Float topP; + + public Builder(String resourceName, String deploymentId) { + this.resourceName = resourceName; + this.deploymentId = deploymentId; + } + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + @Override + public AzureOpenAiGenerative build() { + return new AzureOpenAiGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java index b95ffc601..39d463317 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java @@ -11,12 +11,13 @@ import io.weaviate.client6.v1.internal.ObjectBuilder; public record CohereGenerative( - @SerializedName("kProperty") String kProperty, + @SerializedName("baseURL") String baseUrl, + @SerializedName("kProperty") Integer k, @SerializedName("model") String model, - @SerializedName("maxTokensProperty") Integer maxTokensProperty, + @SerializedName("maxTokensProperty") Integer maxTokens, + @SerializedName("temperatureProperty") Float temperature, @SerializedName("returnLikelihoodsProperty") String returnLikelihoodsProperty, - @SerializedName("stopSequencesProperty") List stopSequencesProperty, - @SerializedName("temperatureProperty") String temperatureProperty) implements Generative { + @SerializedName("stopSequencesProperty") List stopSequences) implements Generative { @Override public Kind _kind() { @@ -38,34 +39,44 @@ public static CohereGenerative of(Function { - private String kProperty; + private String baseUrl; + private Integer k; private String model; - private Integer maxTokensProperty; + private Integer maxTokens; + private Float temperature; private String returnLikelihoodsProperty; - private List stopSequencesProperty = new ArrayList<>(); - private String temperatureProperty; + private List stopSequences = new ArrayList<>(); - public Builder kProperty(String kProperty) { - this.kProperty = kProperty; + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; return this; } + public Builder k(int k) { + this.k = k; + return this; + } + + /** Select generative model. */ public Builder model(String model) { this.model = model; return this; } - public Builder maxTokensProperty(int maxTokensProperty) { - this.maxTokensProperty = maxTokensProperty; + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; return this; } @@ -74,17 +85,21 @@ public Builder returnLikelihoodsProperty(String returnLikelihoodsProperty) { return this; } - public Builder stopSequencesProperty(String... stopSequencesProperty) { - return stopSequencesProperty(Arrays.asList(stopSequencesProperty)); + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); } - public Builder stopSequencesProperty(List stopSequencesProperty) { - this.stopSequencesProperty = stopSequencesProperty; + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; return this; } - public Builder temperatureProperty(String temperatureProperty) { - this.temperatureProperty = temperatureProperty; + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; return this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java index d17bf6f0f..43f344662 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java @@ -24,17 +24,17 @@ public Object _self() { return this; } - public static DatabricksGenerative of() { - return of(ObjectBuilder.identity()); + public static DatabricksGenerative of(String baseURL) { + return of(baseURL, ObjectBuilder.identity()); } - public static DatabricksGenerative of(Function> fn) { - return fn.apply(new Builder()).build(); + public static DatabricksGenerative of(String baseURL, Function> fn) { + return fn.apply(new Builder(baseURL)).build(); } public DatabricksGenerative(Builder builder) { this( - builder.endpoint, + builder.baseURL, builder.maxTokens, builder.topK, builder.topP, @@ -42,32 +42,39 @@ public DatabricksGenerative(Builder builder) { } public static class Builder implements ObjectBuilder { - private String endpoint; + private final String baseURL; + private Integer maxTokens; private Integer topK; private Float topP; private Float temperature; - public Builder endpoint(String endpoint) { - this.endpoint = endpoint; - return this; + public Builder(String baseURL) { + this.baseURL = baseURL; } + /** Limit the number of tokens to generate in the response. */ public Builder maxTokens(int maxTokens) { this.maxTokens = maxTokens; return this; } + /** Top K value for sampling. */ public Builder topK(int topK) { this.topK = topK; return this; } + /** Top P value for nucleus sampling. */ public Builder topP(float topP) { this.topP = topP; return this; } + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ public Builder temperature(float temperature) { this.temperature = temperature; return this; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java new file mode 100644 index 000000000..45b6e4e60 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java @@ -0,0 +1,80 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record FriendliaiGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.FRIENDLIAI; + } + + @Override + public Object _self() { + return this; + } + + public static FriendliaiGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static FriendliaiGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public FriendliaiGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public FriendliaiGenerative build() { + return new FriendliaiGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java new file mode 100644 index 000000000..d7404134e --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java @@ -0,0 +1,106 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record GoogleGenerative( + @SerializedName("apiEndpoint") String baseUrl, + @SerializedName("modelId") String model, + @SerializedName("projectId") String projectId, + @SerializedName("maxOutputTokens") Integer maxTokens, + @SerializedName("topK") Integer topK, + @SerializedName("topP") Float topP, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.GOOGLE; + } + + @Override + public Object _self() { + return this; + } + + public static GoogleGenerative of(String projectId) { + return of(projectId, ObjectBuilder.identity()); + } + + public static GoogleGenerative of(String projectId, Function> fn) { + return fn.apply(new Builder(projectId)).build(); + } + + public GoogleGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.projectId, + builder.maxTokens, + builder.topK, + builder.topP, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private final String projectId; + + private String baseUrl; + private String model; + private Integer maxTokens; + private Integer topK; + private Float topP; + private Float temperature; + + public Builder(String projectId) { + this.projectId = projectId; + } + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public GoogleGenerative build() { + return new GoogleGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java index 593a4d7b9..85c76e1fa 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java @@ -45,21 +45,28 @@ public static class Builder implements ObjectBuilder { private Integer maxTokens; private Float temperature; + /** Base URL of the generative provider. */ public Builder baseUrl(String baseUrl) { this.baseUrl = baseUrl; return this; } + /** Limit the number of tokens to generate in the response. */ public Builder maxTokens(int maxTokens) { this.maxTokens = maxTokens; return this; } + /** Select generative model. */ public Builder model(String model) { this.model = model; return this; } + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ public Builder temperature(float temperature) { this.temperature = temperature; return this; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java new file mode 100644 index 000000000..0d79c3d7e --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java @@ -0,0 +1,80 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record NvidiaGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.NVIDIA; + } + + @Override + public Object _self() { + return this; + } + + public static NvidiaGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static NvidiaGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public NvidiaGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public NvidiaGenerative build() { + return new NvidiaGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java new file mode 100644 index 000000000..2fd3986c7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java @@ -0,0 +1,59 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record OllamaGenerative( + @SerializedName("apiEndpoint") String apiEndpoint, + @SerializedName("model") String model) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.OLLAMA; + } + + @Override + public Object _self() { + return this; + } + + public static OllamaGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static OllamaGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public OllamaGenerative(Builder builder) { + this( + builder.apiEndpoint, + builder.model); + } + + public static class Builder implements ObjectBuilder { + private String apiEndpoint; + private String model; + + /** Destination endpoint of the generative provider. */ + public Builder apiEndpoint(String apiEndpoint) { + this.apiEndpoint = apiEndpoint; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + @Override + public OllamaGenerative build() { + return new OllamaGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java new file mode 100644 index 000000000..a38ca7d68 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java @@ -0,0 +1,106 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record OpenAiGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("frequencyPenaltyProperty") Float frequencyPenalty, + @SerializedName("presencePenaltyProperty") Float presencePenalty, + @SerializedName("maxTokensProperty") Integer maxTokens, + @SerializedName("temperatureProperty") Float temperature, + @SerializedName("topPProperty") Float topP, + + @SerializedName("model") String model) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.OPENAI; + } + + @Override + public Object _self() { + return this; + } + + public static OpenAiGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static OpenAiGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public OpenAiGenerative(Builder builder) { + this( + builder.baseUrl, + builder.frequencyPenalty, + builder.presencePenalty, + builder.maxTokens, + builder.temperature, + builder.topP, + builder.model); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Float frequencyPenalty; + private Float presencePenalty; + private Integer maxTokens; + private Float temperature; + private Float topP; + private String model; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + @Override + public OpenAiGenerative build() { + return new OpenAiGenerative(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java new file mode 100644 index 000000000..b68a94e82 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java @@ -0,0 +1,80 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record XaiGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.XAI; + } + + @Override + public Object _self() { + return this; + } + + public static XaiGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static XaiGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public XaiGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public XaiGenerative build() { + return new XaiGenerative(this); + } + } +} From 15724ba79dcbd5c7c1c7a68e8d22cc0eec5aad5b Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 15:34:19 +0200 Subject: [PATCH 11/25] test: update test code --- .../weaviate/client6/v1/internal/json/JSONTest.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index 24b79dc15..9ca331158 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -430,21 +430,21 @@ public static Object[][] testCases() { { Generative.class, Generative.cohere(generate -> generate - .kProperty("k-property") - .maxTokensProperty(10) + .k(1) + .maxTokens(10) .model("example-model") .returnLikelihoodsProperty("likelihood") - .stopSequencesProperty("stop", "halt") - .temperatureProperty("celcius")), + .stopSequences("stop", "halt") + .temperature(.2f)), """ { "generative-cohere": { - "kProperty": "k-property", + "kProperty": 1, "maxTokensProperty": 10, "model": "example-model", "returnLikelihoodsProperty": "likelihood", "stopSequencesProperty": ["stop", "halt"], - "temperatureProperty": "celcius" + "temperatureProperty": 0.2 } } """, From c8b07703868da3d26975bfd5290670b6640fb790 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 19:58:43 +0200 Subject: [PATCH 12/25] fix: rename google's 'kind' to 'palm' --- .../java/io/weaviate/client6/v1/api/collections/Generative.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index 49fda595e..67a5a6154 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -38,7 +38,7 @@ public enum Kind implements JsonEnum { COHERE("generative-cohere"), DATABRICKS("generative-databricks"), FRIENDLIAI("generative-friendliai"), - GOOGLE("generative-google"), + GOOGLE("generative-palm"), MISTRAL("generative-mistral"), NVIDIA("generative-nvidia"), OLLAMA("generative-ollama"), From bd2deb7914eac26a64dc1b19935a2415bf011503 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 20:57:35 +0200 Subject: [PATCH 13/25] feat: read provider metadata from generative response --- .../generate/GenerativeResponse.java | 110 ++++++++++++++++++ .../generate/ProviderMetadata.java | 7 -- .../api/collections/generate/TaskOutput.java | 2 + .../generative/AnthropicGenerative.java | 11 ++ .../generative/AnyscaleGenerative.java | 9 ++ .../collections/generative/AwsGenerative.java | 11 +- .../generative/CohereGenerative.java | 19 +++ .../generative/DatabricksGenerative.java | 8 ++ .../generative/FriendliaiGenerative.java | 8 ++ .../generative/GoogleGenerative.java | 17 +++ .../generative/MistralGenerative.java | 8 ++ .../generative/NvidiaGenerative.java | 8 ++ .../generative/OllamaGenerative.java | 9 ++ .../generative/OpenAiGenerative.java | 8 ++ .../generative/ProviderMetadata.java | 13 +++ .../collections/generative/XaiGenerative.java | 8 ++ src/main/proto/v1/generative.proto | 57 ++++----- 17 files changed, 271 insertions(+), 42 deletions(-) delete mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/ProviderMetadata.java create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java index 8bc76b862..82a75cefe 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java @@ -1,8 +1,22 @@ package io.weaviate.client6.v1.api.collections.generate; +import java.util.ArrayList; import java.util.List; +import io.weaviate.client6.v1.api.collections.generative.AnthropicGenerative; +import io.weaviate.client6.v1.api.collections.generative.AnyscaleGenerative; +import io.weaviate.client6.v1.api.collections.generative.AwsGenerative; +import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; +import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; +import io.weaviate.client6.v1.api.collections.generative.FriendliaiGenerative; +import io.weaviate.client6.v1.api.collections.generative.GoogleGenerative; +import io.weaviate.client6.v1.api.collections.generative.MistralGenerative; +import io.weaviate.client6.v1.api.collections.generative.NvidiaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OllamaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OpenAiGenerative; +import io.weaviate.client6.v1.api.collections.generative.ProviderMetadata; +import io.weaviate.client6.v1.api.collections.generative.XaiGenerative; import io.weaviate.client6.v1.api.collections.query.QueryResponse; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; @@ -49,6 +63,102 @@ static TaskOutput unmarshalTaskOutput(List(cohere.getWarnings().getValuesList())); + } else if (metadata.hasDatabricks()) { + var databricks = metadata.getDatabricks(); + var usage = databricks.getUsage(); + providerMetadata = new DatabricksGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasFriendliai()) { + var friendliai = metadata.getFriendliai(); + var usage = friendliai.getUsage(); + providerMetadata = new FriendliaiGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasMistral()) { + var mistral = metadata.getMistral(); + var usage = mistral.getUsage(); + providerMetadata = new MistralGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasNvidia()) { + var nvidia = metadata.getNvidia(); + var usage = nvidia.getUsage(); + providerMetadata = new NvidiaGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasOllama()) { + providerMetadata = new OllamaGenerative.Metadata(); + } else if (metadata.hasOpenai()) { + var openai = metadata.getOpenai(); + var usage = openai.getUsage(); + providerMetadata = new OpenAiGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasGoogle()) { + var google = metadata.getGoogle(); + var tokens = google.getMetadata().getTokenMetadata(); + var usage = google.getUsageMetadata(); + providerMetadata = new GoogleGenerative.Metadata( + new GoogleGenerative.Metadata.TokenMetadata( + new GoogleGenerative.Metadata.TokenCount( + tokens.getInputTokenCount().hasTotalBillableCharacters() + ? tokens.getInputTokenCount().getTotalBillableCharacters() + : null, + tokens.getInputTokenCount().hasTotalTokens() + ? tokens.getInputTokenCount().getTotalTokens() + : null), + new GoogleGenerative.Metadata.TokenCount( + tokens.getOutputTokenCount().hasTotalBillableCharacters() + ? tokens.getOutputTokenCount().getTotalBillableCharacters() + : null, + tokens.getOutputTokenCount().hasTotalTokens() + ? tokens.getOutputTokenCount().getTotalTokens() + : null)), + new GoogleGenerative.Metadata.Usage( + usage.hasPromptTokenCount() ? usage.getPromptTokenCount() : null, + usage.hasCandidatesTokenCount() ? usage.getCandidatesTokenCount() : null, + usage.hasTotalTokenCount() ? usage.getTotalTokenCount() : null)); + } else if (metadata.hasXai()) { + var xai = metadata.getXai(); + var usage = xai.getUsage(); + providerMetadata = new XaiGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); } GenerativeDebug debug = null; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/ProviderMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/ProviderMetadata.java deleted file mode 100644 index d56908e7d..000000000 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/ProviderMetadata.java +++ /dev/null @@ -1,7 +0,0 @@ -package io.weaviate.client6.v1.api.collections.generate; - -import io.weaviate.client6.v1.api.collections.Generative; - -public interface ProviderMetadata { - Generative.Kind _kind(); -} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java index a3becb78a..379e7ebc7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java @@ -1,5 +1,7 @@ package io.weaviate.client6.v1.api.collections.generate; +import io.weaviate.client6.v1.api.collections.generative.ProviderMetadata; + public record TaskOutput( String text, ProviderMetadata metadata, diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java index aa4c58ede..131eb059d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java @@ -91,4 +91,15 @@ public AnthropicGenerative build() { return new AnthropicGenerative(this); } } + + public static record Metadata(Usage usage) implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.ANTHROPIC; + } + + public static record Usage(Long inputTokens, Long outputTokens) { + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java index e480c3ac8..576114bd2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java @@ -5,6 +5,7 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.ProviderMetadata; import io.weaviate.client6.v1.internal.ObjectBuilder; public record AnyscaleGenerative( @@ -66,4 +67,12 @@ public AnyscaleGenerative build() { return new AnyscaleGenerative(this); } } + + public static record Metadata() implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.ANYSCALE; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java index 0f1bf16d2..f00ba18df 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java @@ -5,6 +5,7 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.ProviderMetadata; import io.weaviate.client6.v1.internal.ObjectBuilder; public record AwsGenerative( @@ -14,7 +15,7 @@ public record AwsGenerative( @SerializedName("model") String model) implements Generative { @Override - public Kind _kind() { + public Generative.Kind _kind() { return Generative.Kind.AWS; } @@ -68,4 +69,12 @@ public AwsGenerative build() { return new AwsGenerative(this); } } + + public static record Metadata() implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.AWS; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java index 39d463317..d6d9a5e5d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java @@ -108,4 +108,23 @@ public CohereGenerative build() { return new CohereGenerative(this); } } + + public static record Metadata(ApiVersion apiVersion, BilledUnits billedUnits, Tokens tokens, List warnings) + implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.COHERE; + } + + public static record ApiVersion(String version, Boolean deprecated, Boolean experimental) { + } + + public static record BilledUnits(Double inputTokens, Double outputTokens, Double searchUnits, + Double classifications) { + } + + public static record Tokens(Double inputTokens, Double outputTokens) { + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java index 43f344662..da82c3bf7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java @@ -85,4 +85,12 @@ public DatabricksGenerative build() { return new DatabricksGenerative(this); } } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.DATABRICKS; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java index 45b6e4e60..300622525 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java @@ -77,4 +77,12 @@ public FriendliaiGenerative build() { return new FriendliaiGenerative(this); } } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.FRIENDLIAI; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java index d7404134e..2ced65d90 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java @@ -103,4 +103,21 @@ public GoogleGenerative build() { return new GoogleGenerative(this); } } + + public static record Metadata(TokenMetadata tokens, Usage usage) implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.GOOGLE; + } + + public static record TokenCount(Long totalBillableCharacters, Long totalTokens) { + } + + public static record TokenMetadata(TokenCount inputTokens, TokenCount outputTokens) { + } + + public static record Usage(Long promptTokenCount, Long candidatesTokenCount, Long totalTokenCount) { + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java index 85c76e1fa..60357301d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java @@ -77,4 +77,12 @@ public MistralGenerative build() { return new MistralGenerative(this); } } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.MISTRAL; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java index 0d79c3d7e..80e09eff1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java @@ -77,4 +77,12 @@ public NvidiaGenerative build() { return new NvidiaGenerative(this); } } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.NVIDIA; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java index 2fd3986c7..3508f69e4 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java @@ -5,6 +5,7 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.ProviderMetadata; import io.weaviate.client6.v1.internal.ObjectBuilder; public record OllamaGenerative( @@ -56,4 +57,12 @@ public OllamaGenerative build() { return new OllamaGenerative(this); } } + + public static record Metadata() implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.OLLAMA; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java index a38ca7d68..78da286bc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java @@ -103,4 +103,12 @@ public OpenAiGenerative build() { return new OpenAiGenerative(this); } } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.OPENAI; + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java new file mode 100644 index 000000000..0d3dc27a1 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java @@ -0,0 +1,13 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import io.weaviate.client6.v1.api.collections.Generative; + +public interface ProviderMetadata { + Generative.Kind _kind(); + + record Usage( + Long promptTokens, + Long completionTokens, + Long totalTokens) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java index b68a94e82..c7b0a81d6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java @@ -77,4 +77,12 @@ public XaiGenerative build() { return new XaiGenerative(this); } } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.XAI; + } + } } diff --git a/src/main/proto/v1/generative.proto b/src/main/proto/v1/generative.proto index 4e6e32525..fb3deda67 100644 --- a/src/main/proto/v1/generative.proto +++ b/src/main/proto/v1/generative.proto @@ -23,9 +23,9 @@ message GenerativeSearch { bool debug = 4; } - string single_response_prompt = 1 [deprecated = true]; - string grouped_response_task = 2 [deprecated = true]; - repeated string grouped_properties = 3 [deprecated = true]; + string single_response_prompt = 1 [ deprecated = true ]; + string grouped_response_task = 2 [ deprecated = true ]; + repeated string grouped_properties = 3 [ deprecated = true ]; Single single = 4; Grouped grouped = 5; } @@ -49,7 +49,7 @@ message GenerativeProvider { } } -message GenerativeAnthropic{ +message GenerativeAnthropic { optional string base_url = 1; optional int64 max_tokens = 2; optional string model = 3; @@ -61,13 +61,13 @@ message GenerativeAnthropic{ optional TextArray image_properties = 9; } -message GenerativeAnyscale{ +message GenerativeAnyscale { optional string base_url = 1; optional string model = 2; optional double temperature = 3; } -message GenerativeAWS{ +message GenerativeAWS { optional string model = 3; optional double temperature = 8; optional string service = 9; @@ -79,7 +79,7 @@ message GenerativeAWS{ optional TextArray image_properties = 15; } -message GenerativeCohere{ +message GenerativeCohere { optional string base_url = 1; optional double frequency_penalty = 2; optional int64 max_tokens = 3; @@ -91,10 +91,9 @@ message GenerativeCohere{ optional double temperature = 9; } -message GenerativeDummy{ -} +message GenerativeDummy {} -message GenerativeMistral{ +message GenerativeMistral { optional string base_url = 1; optional int64 max_tokens = 2; optional string model = 3; @@ -102,7 +101,7 @@ message GenerativeMistral{ optional double top_p = 5; } -message GenerativeOllama{ +message GenerativeOllama { optional string api_endpoint = 1; optional string model = 2; optional double temperature = 3; @@ -110,7 +109,7 @@ message GenerativeOllama{ optional TextArray image_properties = 5; } -message GenerativeOpenAI{ +message GenerativeOpenAI { optional double frequency_penalty = 1; optional int64 max_tokens = 2; optional string model = 3; @@ -128,7 +127,7 @@ message GenerativeOpenAI{ optional TextArray image_properties = 15; } -message GenerativeGoogle{ +message GenerativeGoogle { optional double frequency_penalty = 1; optional int64 max_tokens = 2; optional string model = 3; @@ -145,7 +144,7 @@ message GenerativeGoogle{ optional TextArray image_properties = 14; } -message GenerativeDatabricks{ +message GenerativeDatabricks { optional string endpoint = 1; optional string model = 2; optional double frequency_penalty = 3; @@ -159,7 +158,7 @@ message GenerativeDatabricks{ optional double top_p = 11; } -message GenerativeFriendliAI{ +message GenerativeFriendliAI { optional string base_url = 1; optional string model = 2; optional int64 max_tokens = 3; @@ -168,7 +167,7 @@ message GenerativeFriendliAI{ optional double top_p = 6; } -message GenerativeNvidia{ +message GenerativeNvidia { optional string base_url = 1; optional string model = 2; optional double temperature = 3; @@ -176,7 +175,7 @@ message GenerativeNvidia{ optional int64 max_tokens = 5; } -message GenerativeXAI{ +message GenerativeXAI { optional string base_url = 1; optional string model = 2; optional double temperature = 3; @@ -194,11 +193,9 @@ message GenerativeAnthropicMetadata { Usage usage = 1; } -message GenerativeAnyscaleMetadata { -} +message GenerativeAnyscaleMetadata {} -message GenerativeAWSMetadata { -} +message GenerativeAWSMetadata {} message GenerativeCohereMetadata { message ApiVersion { @@ -222,8 +219,7 @@ message GenerativeCohereMetadata { optional TextArray warnings = 4; } -message GenerativeDummyMetadata { -} +message GenerativeDummyMetadata {} message GenerativeMistralMetadata { message Usage { @@ -234,8 +230,7 @@ message GenerativeMistralMetadata { optional Usage usage = 1; } -message GenerativeOllamaMetadata { -} +message GenerativeOllamaMetadata {} message GenerativeOpenAIMetadata { message Usage { @@ -255,9 +250,7 @@ message GenerativeGoogleMetadata { optional TokenCount input_token_count = 1; optional TokenCount output_token_count = 2; } - message Metadata { - optional TokenMetadata token_metadata = 1; - } + message Metadata { optional TokenMetadata token_metadata = 1; } message UsageMetadata { optional int64 prompt_token_count = 1; optional int64 candidates_token_count = 2; @@ -327,10 +320,6 @@ message GenerativeReply { optional GenerativeMetadata metadata = 3; } -message GenerativeResult { - repeated GenerativeReply values = 1; -} +message GenerativeResult { repeated GenerativeReply values = 1; } -message GenerativeDebug { - optional string full_prompt = 1; -} +message GenerativeDebug { optional string full_prompt = 1; } From 5a05ea02809ef78658f8c653510643a8948f6839 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Tue, 21 Oct 2025 21:02:05 +0200 Subject: [PATCH 14/25] chore: paraphrase generative javadoc --- .../generate/AbstractGenerateClient.java | 221 +++++++++--------- 1 file changed, 113 insertions(+), 108 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java index cde5101bb..c2ef15ea1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -51,7 +51,8 @@ abstract class AbstractGenerateClient // Object queries ----------------------------------------------------------- /** - * Retrieve objects without applying a Vector Search or Keyword Search filter. + * Retrieve objects without applying a Vector Search or Keyword Search filter + * and run a generative task on the query results. * * @param fn Lambda expression for optional search parameters. * @param generateFn Lambda expression for generative task parameters. @@ -64,7 +65,8 @@ public ResponseT fetchObjects(Function text, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param searchTarget Target query concepts. * @param generateFn Lambda expression for generative task parameters. @@ -696,7 +701,7 @@ public ResponseT nearText(Target searchTarget, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param searchTarget Target query concepts. * @param fn Lambda expression for optional parameters. @@ -711,7 +716,7 @@ public ResponseT nearText(Target searchTarget, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param query Near text query request. * @param generate Generative task. @@ -723,7 +728,7 @@ public ResponseT nearText(NearText query, GenerativeTask generate) { } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param text Query concepts. * @param generateFn Lambda expression for generative task parameters. @@ -742,7 +747,7 @@ public GroupedResponseT nearText(String text, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param text Query concepts. * @param generateFn Lambda expression for generative task parameters. @@ -760,7 +765,7 @@ public GroupedResponseT nearText(List text, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param text Query concepts. * @param fn Lambda expression for optional parameters. @@ -781,7 +786,7 @@ public GroupedResponseT nearText(String text, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param text Query concepts. * @param fn Lambda expression for optional parameters. @@ -802,7 +807,7 @@ public GroupedResponseT nearText(List text, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param searchTarget Target query concepts. * @param generateFn Lambda expression for generative task parameters. @@ -820,7 +825,7 @@ public GroupedResponseT nearText(Target searchTarget, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param searchTarget Target query concepts. * @param fn Lambda expression for optional parameters. @@ -841,7 +846,7 @@ public GroupedResponseT nearText(Target searchTarget, } /** - * Query collection objects using near text search. + * Run a generative task on the results of a near text search. * * @param query Near text query request. * @param generate Generative task. @@ -860,7 +865,7 @@ public GroupedResponseT nearText(NearText query, GenerativeTask generate, GroupB // NearImage queries -------------------------------------------------------- /** - * Query collection objects using near image search. + * Run a generative task on the results of a near image search. * * @param image Query image (base64-encoded). * @param generateFn Lambda expression for generative task parameters. @@ -873,7 +878,7 @@ public ResponseT nearImage(String image, } /** - * Query collection objects using near image search. + * Run a generative task on the results of a near image search. * * @param image Query image (base64-encoded). * @param fn Lambda expression for optional search parameters. @@ -887,7 +892,7 @@ public ResponseT nearImage(String image, Function Date: Tue, 21 Oct 2025 21:33:42 +0200 Subject: [PATCH 15/25] feat: provide static factories for generative providers --- .../v1/api/collections/Generative.java | 199 +++++++++++++++++- .../generative/AnyscaleGenerative.java | 1 - .../collections/generative/AwsGenerative.java | 1 - 3 files changed, 197 insertions(+), 4 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index 67a5a6154..43f0e99f6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -68,13 +68,65 @@ public static Kind valueOfJson(String jsonValue) { Object _self(); - /** Configure a default Cohere generative module. */ + /** Configure a default {@code generative-anthropic} module. */ + public static Generative anthropic() { + return AnthropicGenerative.of(); + } + + /** + * Configure a {@code generative-anthropic} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative anthropic(Function> fn) { + return AnthropicGenerative.of(fn); + } + + /** Configure a default {@code generative-anyscale} module. */ + public static Generative anyscale() { + return AnyscaleGenerative.of(); + } + + /** + * Configure a {@code generative-anyscale} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative anyscale(Function> fn) { + return AnyscaleGenerative.of(fn); + } + + /** + * Configure a default {@code generative-aws} module. + * + * @param region AWS region. + * @param service AWS service to use, e.g. {@code "bedrock"} or + * {@code "sagemaker"}. + */ + public static Generative aws(String region, String service) { + return AwsGenerative.of(region, service); + } + + /** + * Configure a {@code generative-aws} module. + * + * @param region AWS region. + * @param service AWS service to use, e.g. {@code "bedrock"} or + * {@code "sagemaker"}. + * @param fn Lambda expression for optional parameters. + */ + public static Generative aws(String region, String service, + Function> fn) { + return AwsGenerative.of(region, service, fn); + } + + /** Configure a default {@code generative-cohere} module. */ public static Generative cohere() { return CohereGenerative.of(); } /** - * Configure a Cohere generative module. + * Configure a {@code generative-cohere} module. * * @param fn Lambda expression for optional parameters. */ @@ -82,6 +134,149 @@ public static Generative cohere(Function> fn) { + return DatabricksGenerative.of(baseURL, fn); + } + + /** Configure a default {@code generative-frienliai} module. */ + public static Generative frienliai() { + return FriendliaiGenerative.of(); + } + + /** + * Configure a {@code generative-frienliai} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative frienliai(Function> fn) { + return FriendliaiGenerative.of(fn); + } + + /** Configure a default {@code generative-palm} module. */ + public static Generative google(String projectId) { + return GoogleGenerative.of(projectId); + } + + /** + * Configure a {@code generative-palm} module. + * + * @param projectId Project ID. + * @param fn Lambda expression for optional parameters. + */ + public static Generative google(String projectId, + Function> fn) { + return GoogleGenerative.of(projectId, fn); + } + + /** Configure a default {@code generative-mistral} module. */ + public static Generative mistral() { + return MistralGenerative.of(); + } + + /** + * Configure a {@code generative-mistral} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative mistral(Function> fn) { + return MistralGenerative.of(fn); + } + + /** Configure a default {@code generative-nvidia} module. */ + public static Generative nvidia() { + return NvidiaGenerative.of(); + } + + /** + * Configure a {@code generative-nvidia} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative nvidia(Function> fn) { + return NvidiaGenerative.of(fn); + } + + /** Configure a default {@code generative-ollama} module. */ + public static Generative ollama() { + return OllamaGenerative.of(); + } + + /** + * Configure a {@code generative-ollama} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative ollama(Function> fn) { + return OllamaGenerative.of(fn); + } + + /** Configure a default {@code generative-openai} module. */ + public static Generative openai() { + return OpenAiGenerative.of(); + } + + /** + * Configure a {@code generative-openai} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative openai(Function> fn) { + return OpenAiGenerative.of(fn); + } + + /** + * Configure a default {@code generative-openai} module + * hosted on Microsoft Azure. + * + * @param resourceName Name of the Azure OpenAI resource. + * @param deploymentId Azure OpenAI deployment ID. + */ + public static Generative azure(String resourceName, String deploymentId) { + return AzureOpenAiGenerative.of(resourceName, deploymentId); + } + + /** + * Configure a {@code generative-openai} module hosted on Microsoft Azure. + * + * @param resourceName Name of the Azure OpenAI resource. + * @param deploymentId Azure OpenAI deployment ID. + * @param fn Lambda expression for optional parameters. + */ + public static Generative azure(String resourceName, String deploymentId, + Function> fn) { + return AzureOpenAiGenerative.of(resourceName, deploymentId, fn); + } + + /** Configure a default {@code generative-xai} module. */ + public static Generative xai() { + return XaiGenerative.of(); + } + + /** + * Configure a {@code generative-xai} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative xai(Function> fn) { + return XaiGenerative.of(fn); + } + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { INSTANCE; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java index 576114bd2..028412092 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java @@ -5,7 +5,6 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; -import io.weaviate.client6.v1.api.collections.generate.ProviderMetadata; import io.weaviate.client6.v1.internal.ObjectBuilder; public record AnyscaleGenerative( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java index f00ba18df..02661d4d4 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java @@ -5,7 +5,6 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; -import io.weaviate.client6.v1.api.collections.generate.ProviderMetadata; import io.weaviate.client6.v1.internal.ObjectBuilder; public record AwsGenerative( From 30744ec4fc83918fab46707f51cae8f4678e6039 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 10:15:53 +0200 Subject: [PATCH 16/25] chore: delete redundant import --- .../client6/v1/api/collections/generative/OllamaGenerative.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java index 3508f69e4..c538e5acf 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java @@ -5,7 +5,6 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; -import io.weaviate.client6.v1.api.collections.generate.ProviderMetadata; import io.weaviate.client6.v1.internal.ObjectBuilder; public record OllamaGenerative( From aee94ba80730bd08e06d0f10e47bf9aa71430ce8 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 12:29:41 +0200 Subject: [PATCH 17/25] feat: add methods to cast Generative to specific classes --- .../v1/api/collections/Generative.java | 207 +++++++++++++++++- .../generative/AnthropicGenerative.java | 1 + .../client6/v1/internal/TaggedUnion.java | 27 +++ 3 files changed, 228 insertions(+), 7 deletions(-) create mode 100644 src/main/java/io/weaviate/client6/v1/internal/TaggedUnion.java diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index 43f0e99f6..31701ee9c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -28,9 +28,10 @@ import io.weaviate.client6.v1.api.collections.generative.OpenAiGenerative; import io.weaviate.client6.v1.api.collections.generative.XaiGenerative; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.TaggedUnion; import io.weaviate.client6.v1.internal.json.JsonEnum; -public interface Generative { +public interface Generative extends TaggedUnion { public enum Kind implements JsonEnum { ANYSCALE("generative-anyscale"), AWS("generative-aws"), @@ -64,10 +65,6 @@ public static Kind valueOfJson(String jsonValue) { } } - Kind _kind(); - - Object _self(); - /** Configure a default {@code generative-anthropic} module. */ public static Generative anthropic() { return AnthropicGenerative.of(); @@ -277,6 +274,201 @@ public static Generative xai(Function TypeAdapter create(Gson gson, TypeToken type) { init(gson); } - final TypeAdapter writeAdapter = (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(rawType)); + final TypeAdapter writeAdapter = (TypeAdapter) gson.getDelegateAdapter(this, + TypeToken.get(rawType)); return (TypeAdapter) new TypeAdapter() { @Override public void write(JsonWriter out, Generative value) throws IOException { out.beginObject(); out.name(value._kind().jsonValue()); - writeAdapter.write(out, (T) value._self()); + writeAdapter.write(out, value._self()); out.endObject(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java index 131eb059d..04533b1a2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java @@ -9,6 +9,7 @@ import io.weaviate.client6.v1.api.collections.Generative; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.TaggedUnion; public record AnthropicGenerative( @SerializedName("model") String model, diff --git a/src/main/java/io/weaviate/client6/v1/internal/TaggedUnion.java b/src/main/java/io/weaviate/client6/v1/internal/TaggedUnion.java new file mode 100644 index 000000000..42b944294 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/TaggedUnion.java @@ -0,0 +1,27 @@ +package io.weaviate.client6.v1.internal; + +public interface TaggedUnion, SelfT> { + KindT _kind(); + + SelfT _self(); + + /** Does the current instance have the kind? */ + default boolean _is(KindT kind) { + return _kind() == kind; + } + + /** Convert tagged union instance to one of its variants. */ + default > Value _as(KindT kind) { + return TaggedUnion.as(this, kind); + } + + /** Convert tagged union instance to one of its variants. */ + public static , Tag extends Enum, Value> Value as(Union union, Tag kind) { + if (union._is(kind)) { + @SuppressWarnings("unchecked") + Value value = (Value) union._self(); + return value; + } + throw new IllegalStateException("Cannot convert '%s' variant to '%s'".formatted(union._kind(), kind)); + } +} From f18ddf2280045f20d460ab7abb4a42bd43c3acc5 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 13:25:11 +0200 Subject: [PATCH 18/25] test: add JSON tests for Generative.CustomTypeAdapterFactory --- .../v1/api/collections/Generative.java | 38 +-- .../generative/AnthropicGenerative.java | 1 - .../collections/generative/AwsGenerative.java | 2 +- .../generative/CohereGenerative.java | 10 +- .../generative/DatabricksGenerative.java | 8 +- .../generative/DummyGenerative.java | 1 - .../generative/OllamaGenerative.java | 12 +- .../client6/v1/internal/json/JSONTest.java | 286 ++++++++++++++++-- 8 files changed, 299 insertions(+), 59 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index 31701ee9c..da7818c75 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -6,11 +6,11 @@ import java.util.function.Function; import com.google.gson.Gson; +import com.google.gson.JsonParser; import com.google.gson.TypeAdapter; import com.google.gson.TypeAdapterFactory; import com.google.gson.reflect.TypeToken; import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonToken; import com.google.gson.stream.JsonWriter; import io.weaviate.client6.v1.api.collections.generative.AnthropicGenerative; @@ -161,7 +161,7 @@ public static Generative frienliai() { * * @param fn Lambda expression for optional parameters. */ - public static Generative frienliai(Function> fn) { + public static Generative friendliai(Function> fn) { return FriendliaiGenerative.of(fn); } @@ -508,7 +508,7 @@ public TypeAdapter create(Gson gson, TypeToken type) { init(gson); } - final TypeAdapter writeAdapter = (TypeAdapter) gson.getDelegateAdapter(this, + final TypeAdapter writeAdapter = (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(rawType)); return (TypeAdapter) new TypeAdapter() { @@ -516,27 +516,31 @@ public TypeAdapter create(Gson gson, TypeToken type) { public void write(JsonWriter out, Generative value) throws IOException { out.beginObject(); out.name(value._kind().jsonValue()); - writeAdapter.write(out, value._self()); + writeAdapter.write(out, (T) value._self()); out.endObject(); } @Override public Generative read(JsonReader in) throws IOException { - in.beginObject(); - var moduleName = in.nextName(); - try { - var kind = Generative.Kind.valueOfJson(moduleName); - var adapter = readAdapters.get(kind); - assert adapter != null : "no generative adapter for kind " + kind; - return adapter.read(in); - } catch (IllegalArgumentException e) { - return null; - } finally { - if (in.peek() == JsonToken.BEGIN_OBJECT) { - in.beginObject(); + var jsonObject = JsonParser.parseReader(in).getAsJsonObject(); + var provider = jsonObject.keySet().iterator().next(); + + var generative = jsonObject.get(provider).getAsJsonObject(); + Generative.Kind kind; + if (provider.equals(Generative.Kind.OPENAI.jsonValue())) { + kind = generative.has("deploymentId") && generative.has("resourceName") + ? Generative.Kind.AZURE_OPENAI + : Generative.Kind.OPENAI; + } else { + try { + kind = Generative.Kind.valueOfJson(provider); + } catch (IllegalArgumentException e) { + return null; } - in.endObject(); } + var adapter = readAdapters.get(kind); + assert adapter != null : "no generative adapter for kind " + kind; + return adapter.fromJsonTree(generative); } }.nullSafe(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java index 04533b1a2..131eb059d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java @@ -9,7 +9,6 @@ import io.weaviate.client6.v1.api.collections.Generative; import io.weaviate.client6.v1.internal.ObjectBuilder; -import io.weaviate.client6.v1.internal.TaggedUnion; public record AnthropicGenerative( @SerializedName("model") String model, diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java index 02661d4d4..013d4aac6 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java @@ -10,7 +10,7 @@ public record AwsGenerative( @SerializedName("region") String region, @SerializedName("service") String service, - @SerializedName("endpoint") String baseURL, + @SerializedName("endpoint") String baseUrl, @SerializedName("model") String model) implements Generative { @Override diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java index d6d9a5e5d..f63c72c78 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java @@ -12,7 +12,7 @@ public record CohereGenerative( @SerializedName("baseURL") String baseUrl, - @SerializedName("kProperty") Integer k, + @SerializedName("kProperty") Integer topK, @SerializedName("model") String model, @SerializedName("maxTokensProperty") Integer maxTokens, @SerializedName("temperatureProperty") Float temperature, @@ -40,7 +40,7 @@ public static CohereGenerative of(Function { private String baseUrl; - private Integer k; + private Integer topK; private String model; private Integer maxTokens; private Float temperature; @@ -63,8 +63,8 @@ public Builder baseUrl(String baseUrl) { return this; } - public Builder k(int k) { - this.k = k; + public Builder topK(int topK) { + this.topK = topK; return this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java index da82c3bf7..ffceeddb0 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java @@ -34,7 +34,7 @@ public static DatabricksGenerative of(String baseURL, Function { - private final String baseURL; + private final String baseUrl; private Integer maxTokens; private Integer topK; private Float topP; private Float temperature; - public Builder(String baseURL) { - this.baseURL = baseURL; + public Builder(String baseUrl) { + this.baseUrl = baseUrl; } /** Limit the number of tokens to generate in the response. */ diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java index f6d1f915a..cfe300a5e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java @@ -1,7 +1,6 @@ package io.weaviate.client6.v1.api.collections.generative; import io.weaviate.client6.v1.api.collections.Generative; -import io.weaviate.client6.v1.api.collections.generate.ProviderMetadata; public record DummyGenerative() implements Generative { @Override diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java index c538e5acf..25fc6f3c7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java @@ -8,7 +8,7 @@ import io.weaviate.client6.v1.internal.ObjectBuilder; public record OllamaGenerative( - @SerializedName("apiEndpoint") String apiEndpoint, + @SerializedName("apiEndpoint") String baseUrl, @SerializedName("model") String model) implements Generative { @Override @@ -31,17 +31,17 @@ public static OllamaGenerative of(Function { - private String apiEndpoint; + private String baseUrl; private String model; - /** Destination endpoint of the generative provider. */ - public Builder apiEndpoint(String apiEndpoint) { - this.apiEndpoint = apiEndpoint; + /** Base URL of the generative model. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; return this; } diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index 9ca331158..f6f15e0dd 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -426,30 +426,6 @@ public static Object[][] testCases() { """, }, - // Generative.CustomTypeAdapterFactory - { - Generative.class, - Generative.cohere(generate -> generate - .k(1) - .maxTokens(10) - .model("example-model") - .returnLikelihoodsProperty("likelihood") - .stopSequences("stop", "halt") - .temperature(.2f)), - """ - { - "generative-cohere": { - "kProperty": 1, - "maxTokensProperty": 10, - "model": "example-model", - "returnLikelihoodsProperty": "likelihood", - "stopSequencesProperty": ["stop", "halt"], - "temperatureProperty": 0.2 - } - } - """, - }, - // BatchReference.CustomTypeAdapterFactory { BatchReference.class, @@ -917,6 +893,268 @@ public static Object[][] testCases() { } """ }, + + // Generative.CustomTypeAdapterFactory + { + Generative.class, + Generative.anyscale(cfg -> cfg + .baseUrl("https://example.com") + .model("example-model") + .temperature(3f)), + """ + { + "generative-anyscale": { + "baseURL": "https://example.com", + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.anthropic(cfg -> cfg + .topK(1) + .maxTokens(2) + .temperature(3f) + .model("example-model") + .stopSequences("stop", "halt")), + """ + { + "generative-anthropic": { + "topK": 1, + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model", + "stopSequences": ["stop", "halt"] + } + } + """, + }, + { + Generative.class, + Generative.aws( + "aws-region", + "aws-service", + cfg -> cfg + .baseUrl("https://example.com") + .model("example-model")), + """ + { + "generative-aws": { + "endpoint": "https://example.com", + "model": "example-model", + "region": "aws-region", + "service": "aws-service" + } + } + """, + }, + { + Generative.class, + Generative.cohere(cfg -> cfg + .topK(1) + .maxTokens(2) + .temperature(3f) + .model("example-model") + .returnLikelihoodsProperty("likelihood") + .stopSequences("stop", "halt")), + """ + { + "generative-cohere": { + "kProperty": 1, + "maxTokensProperty": 2, + "temperatureProperty": 3.0, + "model": "example-model", + "returnLikelihoodsProperty": "likelihood", + "stopSequencesProperty": ["stop", "halt"] + } + } + """, + }, + { + Generative.class, + Generative.databricks( + "https://example.com", + cfg -> cfg + .topK(1) + .maxTokens(2) + .temperature(3f) + .topP(4f)), + """ + { + "generative-databricks": { + "endpoint": "https://example.com", + "topK": 1, + "maxTokens": 2, + "temperature": 3.0, + "topP": 4.0 + } + } + """, + }, + { + Generative.class, + Generative.friendliai(cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .model("example-model")), + """ + { + "generative-friendliai": { + "baseURL": "https://example.com", + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.mistral(cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .model("example-model")), + """ + { + "generative-mistral": { + "baseURL": "https://example.com", + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.nvidia(cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .model("example-model")), + """ + { + "generative-nvidia": { + "baseURL": "https://example.com", + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.google( + "google-project", + cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .topK(4) + .topP(5f) + .model("example-model")), + """ + { + "generative-palm": { + "apiEndpoint": "https://example.com", + "maxOutputTokens": 2, + "temperature": 3.0, + "topK": 4, + "topP": 5, + "projectId": "google-project", + "modelId": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.ollama(cfg -> cfg + .baseUrl("https://example.com") + .model("example-model")), + """ + { + "generative-ollama": { + "apiEndpoint": "https://example.com", + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.xai(cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .model("example-model")), + """ + { + "generative-xai": { + "baseURL": "https://example.com", + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.openai(cfg -> cfg + .baseUrl("https://example.com") + .frequencyPenalty(1f) + .presencePenalty(2f) + .temperature(3f) + .topP(4f) + .maxTokens(5) + .model("o3-mini")), + """ + { + "generative-openai": { + "baseURL": "https://example.com", + "frequencyPenaltyProperty": 1.0, + "presencePenaltyProperty": 2.0, + "temperatureProperty": 3.0, + "topPProperty": 4.0, + "maxTokensProperty": 5, + "model": "o3-mini" + } + } + """ + }, + { + Generative.class, + Generative.azure( + "azure-resource", + "azure-deployment", + cfg -> cfg + .baseUrl("https://example.com") + .frequencyPenalty(1f) + .presencePenalty(2f) + .temperature(3f) + .topP(4f) + .maxTokens(5)), + """ + { + "generative-openai": { + "baseURL": "https://example.com", + "frequencyPenaltyProperty": 1.0, + "presencePenaltyProperty": 2.0, + "temperatureProperty": 3.0, + "topPProperty": 4.0, + "maxTokensProperty": 5, + "resourceName": "azure-resource", + "deploymentId": "azure-deployment" + } + } + """ + }, }; } From 478b7a67cfc1f61f0a7597413e9150d319bc6cde Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 14:42:33 +0200 Subject: [PATCH 19/25] feat(rag): add dynamic providers for Anthropic/Anyscale/Aws/Cohere --- .../collections/generate/GenerativeTask.java | 43 ++++- .../generative/AnthropicGenerative.java | 180 +++++++++++++++++- .../generative/AnyscaleGenerative.java | 68 +++++++ .../collections/generative/AwsGenerative.java | 148 ++++++++++++++ .../generative/CohereGenerative.java | 159 ++++++++++++++++ 5 files changed, 591 insertions(+), 7 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java index e5f4dfcf9..238894775 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java @@ -58,7 +58,7 @@ void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { } } - public record Single(String prompt, boolean debug) { + public record Single(String prompt, boolean debug, List providers) { public static Single of(String prompt) { return of(prompt, ObjectBuilder.identity()); } @@ -68,11 +68,12 @@ public static Single of(String prompt, Function> } public Single(Builder builder) { - this(builder.prompt, builder.debug); + this(builder.prompt, builder.debug, builder.providers); } public static class Builder implements ObjectBuilder { private final String prompt; + private final List providers = new ArrayList<>(); private boolean debug = false; public Builder(String prompt) { @@ -84,6 +85,12 @@ public Builder debug(boolean enable) { return this; } + public Builder generativeProvider(DynamicProvider provider) { + providers.clear(); // Protobuf allows `repeated` but the server expects there to be 1. + providers.add(provider); + return this; + } + @Override public Single build() { return new Single(this); @@ -91,14 +98,23 @@ public Single build() { } public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + var ragProviders = providers.stream() + .map(provider -> { + var proto = WeaviateProtoGenerative.GenerativeProvider.newBuilder(); + provider.appendTo(proto); + return proto.build(); + }) + .toList(); + req.setSingle( WeaviateProtoGenerative.GenerativeSearch.Single.newBuilder() .setPrompt(prompt) - .setDebug(debug)); + .setDebug(debug) + .addAllQueries(ragProviders)); } } - public record Grouped(String prompt, boolean debug, List properties) { + public record Grouped(String prompt, boolean debug, List properties, List providers) { public static Grouped of(String prompt) { return of(prompt, ObjectBuilder.identity()); } @@ -108,11 +124,12 @@ public static Grouped of(String prompt, Function } public Grouped(Builder builder) { - this(builder.prompt, builder.debug, builder.properties); + this(builder.prompt, builder.debug, builder.properties, builder.providers); } public static class Builder implements ObjectBuilder { private final String prompt; + private final List providers = new ArrayList<>(); private final List properties = new ArrayList<>(); private boolean debug = false; @@ -129,6 +146,12 @@ public Builder properties(List properties) { return this; } + public Builder generativeProvider(DynamicProvider provider) { + providers.clear(); // Protobuf allows `repeated` but the server expects there to be 1. + providers.add(provider); + return this; + } + public Builder debug(boolean enable) { this.debug = enable; return this; @@ -151,6 +174,16 @@ public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { .addAllValues(properties)); } + + var ragProviders = providers.stream() + .map(provider -> { + var proto = WeaviateProtoGenerative.GenerativeProvider.newBuilder(); + provider.appendTo(proto); + return proto.build(); + }) + .toList(); + grouped.addAllQueries(ragProviders); + req.setGrouped(grouped); } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java index 131eb059d..957ba708c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java @@ -8,13 +8,17 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record AnthropicGenerative( @SerializedName("model") String model, @SerializedName("maxTokens") Integer maxTokens, @SerializedName("temperature") Float temperature, @SerializedName("topK") Integer topK, + @SerializedName("topP") Float topP, @SerializedName("stopSequences") List stopSequences) implements Generative { @Override @@ -41,21 +45,30 @@ public AnthropicGenerative(Builder builder) { builder.maxTokens, builder.temperature, builder.topK, + builder.topP, builder.stopSequences); } public static class Builder implements ObjectBuilder { private Integer topK; + private Float topP; private String model; private Integer maxTokens; private Float temperature; - private List stopSequences = new ArrayList<>(); + private final List stopSequences = new ArrayList<>(); + /** Top K value for sampling. */ public Builder topK(int topK) { this.topK = topK; return this; } + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + /** Select generative model. */ public Builder model(String model) { this.model = model; @@ -68,12 +81,18 @@ public Builder maxTokens(int maxTokens) { return this; } + /** + * Set tokens which should signal the model to stop generating further output. + */ public Builder stopSequences(String... stopSequences) { return stopSequences(Arrays.asList(stopSequences)); } + /** + * Set tokens which should signal the model to stop generating further output. + */ public Builder stopSequences(List stopSequences) { - this.stopSequences = stopSequences; + this.stopSequences.addAll(stopSequences); return this; } @@ -102,4 +121,161 @@ public Generative.Kind _kind() { public static record Usage(Long inputTokens, Long outputTokens) { } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer topK, + Float topP, + List stopSequences, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeAnthropic.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topK != null) { + provider.setTopK(topK); + } + if (topP != null) { + provider.setTopP(topP); + } + + if (stopSequences != null) { + provider.setStopSequences(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + if (images != null) { + provider.setImages(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(images)); + } + if (imageProperties != null) { + provider.setImageProperties(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(imageProperties)); + } + req.setAnthropic(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topK, + builder.topP, + builder.stopSequences, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer topK; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private final List stopSequences = new ArrayList<>(); + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AnthropicGenerative.Provider build() { + return new AnthropicGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java index 028412092..acfb810cc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java @@ -5,7 +5,9 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record AnyscaleGenerative( @SerializedName("baseURL") String baseUrl, @@ -74,4 +76,70 @@ public Generative.Kind _kind() { return Generative.Kind.ANYSCALE; } } + + public static record Provider( + String baseUrl, + String model, + Float temperature) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeAnyscale.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + req.setAnyscale(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AnyscaleGenerative.Provider build() { + return new AnyscaleGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java index 013d4aac6..12605908e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java @@ -1,11 +1,17 @@ package io.weaviate.client6.v1.api.collections.generative; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record AwsGenerative( @SerializedName("region") String region, @@ -76,4 +82,146 @@ public Generative.Kind _kind() { return Generative.Kind.AWS; } } + + public static record Provider( + String region, + String service, + String baseUrl, + String model, + String targetModel, + String targetModelVariant, + Float temperature, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeAWS.newBuilder(); + if (region != null) { + provider.setRegion(region); + } + if (service != null) { + provider.setService(service); + } + if (baseUrl != null) { + provider.setEndpoint(baseUrl); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (targetModel != null) { + provider.setTargetModel(targetModel); + } + if (targetModelVariant != null) { + provider.setTargetVariant(targetModelVariant); + } + if (images != null) { + provider.setImages(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(images)); + } + if (imageProperties != null) { + provider.setImageProperties(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(imageProperties)); + } + req.setAws(provider); + } + + public Provider(Builder builder) { + this( + builder.region, + builder.service, + builder.baseUrl, + builder.model, + builder.targetModel, + builder.targetModelVariant, + builder.temperature, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String region; + private String service; + private String baseUrl; + private String model; + private String targetModel; + private String targetModelVariant; + private Float temperature; + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + public Builder region(String region) { + this.region = region; + return this; + } + + public Builder service(String service) { + this.service = service; + return this; + } + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder targetModel(String targetModel) { + this.targetModel = targetModel; + return this; + } + + public Builder targetModelVariant(String targetModelVariant) { + this.targetModelVariant = targetModelVariant; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AwsGenerative.Provider build() { + return new AwsGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java index f63c72c78..4c84b4339 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java @@ -8,7 +8,10 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record CohereGenerative( @SerializedName("baseURL") String baseUrl, @@ -63,6 +66,7 @@ public Builder baseUrl(String baseUrl) { return this; } + /** Top K value for sampling. */ public Builder topK(int topK) { this.topK = topK; return this; @@ -85,10 +89,16 @@ public Builder returnLikelihoodsProperty(String returnLikelihoodsProperty) { return this; } + /** + * Set tokens which should signal the model to stop generating further output. + */ public Builder stopSequences(String... stopSequences) { return stopSequences(Arrays.asList(stopSequences)); } + /** + * Set tokens which should signal the model to stop generating further output. + */ public Builder stopSequences(List stopSequences) { this.stopSequences = stopSequences; return this; @@ -127,4 +137,153 @@ public static record BilledUnits(Double inputTokens, Double outputTokens, Double public static record Tokens(Double inputTokens, Double outputTokens) { } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer topK, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + List stopSequences) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeCohere.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topK != null) { + provider.setK(topK); + } + if (topP != null) { + provider.setP(topP); + } + + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + + if (stopSequences != null) { + provider.setStopSequences(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + req.setCohere(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topK, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.stopSequences); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer topK; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private final List stopSequences = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public CohereGenerative.Provider build() { + return new CohereGenerative.Provider(this); + } + } + } } From 3131dafcf04c6f393434305fb75bffc5b28c23a4 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 16:15:44 +0200 Subject: [PATCH 20/25] feat: add move dynamic generative providers Azure Databricks Friendliai Google Mistral Nvidia Ollama OpenAI XAI --- .../collections/generate/DynamicProvider.java | 21 ++ .../generative/AzureOpenAiGenerative.java | 210 ++++++++++++++++++ .../generative/DatabricksGenerative.java | 174 +++++++++++++++ .../generative/FriendliaiGenerative.java | 103 +++++++++ .../generative/GoogleGenerative.java | 209 +++++++++++++++++ .../generative/MistralGenerative.java | 92 ++++++++ .../generative/NvidiaGenerative.java | 92 ++++++++ .../generative/OllamaGenerative.java | 104 +++++++++ .../generative/OpenAiGenerative.java | 210 ++++++++++++++++++ .../collections/generative/XaiGenerative.java | 128 +++++++++++ 10 files changed, 1343 insertions(+) create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java new file mode 100644 index 000000000..d5930af65 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java @@ -0,0 +1,21 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.function.Function; + +import io.weaviate.client6.v1.api.collections.generative.AnthropicGenerative; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public interface DynamicProvider { + void appendTo(WeaviateProtoGenerative.GenerativeProvider.Builder req); + + /** + * Configure {@code generative-anthropic} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider anthropic( + Function> fn) { + return AnthropicGenerative.Provider.of(fn); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java index a63195ff4..94d0a3c0a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java @@ -1,11 +1,17 @@ package io.weaviate.client6.v1.api.collections.generative; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record AzureOpenAiGenerative( @SerializedName("baseURL") String baseUrl, @@ -107,4 +113,208 @@ public AzureOpenAiGenerative build() { return new AzureOpenAiGenerative(this); } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer n, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + String apiVersion, + String resourceName, + String deploymentId, + List stopSequences, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeOpenAI.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (n != null) { + provider.setN(n); + } + if (topP != null) { + provider.setTopP(topP); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (apiVersion != null) { + provider.setApiVersion(apiVersion); + } + if (resourceName != null) { + provider.setResourceName(resourceName); + } + if (deploymentId != null) { + provider.setDeploymentId(deploymentId); + } + if (stopSequences != null) { + provider.setStop(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + provider.setIsAzure(true); + req.setOpenai(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.n, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.apiVersion, + builder.resourceName, + builder.deploymentId, + builder.stopSequences, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer n; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private String apiVersion; + private String resourceName; + private String deploymentId; + private final List stopSequences = new ArrayList<>(); + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder n(int n) { + this.n = n; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + + public Builder resourceName(String resourceName) { + this.resourceName = resourceName; + return this; + } + + public Builder deploymentId(String deploymentId) { + this.deploymentId = deploymentId; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AzureOpenAiGenerative.Provider build() { + return new AzureOpenAiGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java index ffceeddb0..6005970ae 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java @@ -1,11 +1,17 @@ package io.weaviate.client6.v1.api.collections.generative; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record DatabricksGenerative( @SerializedName("endpoint") String baseUrl, @@ -93,4 +99,172 @@ public Generative.Kind _kind() { return Generative.Kind.DATABRICKS; } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer n, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + Boolean logProbs, + Integer topLogProbs, + List stopSequences) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeDatabricks.newBuilder(); + if (baseUrl != null) { + provider.setEndpoint(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (n != null) { + provider.setN(n); + } + if (topP != null) { + provider.setTopP(topP); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (logProbs != null) { + provider.setLogProbs(logProbs); + } + if (topLogProbs != null) { + provider.setTopLogProbs(topLogProbs); + } + if (stopSequences != null) { + provider.setStop(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + req.setDatabricks(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.n, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.logProbs, + builder.topLogProbs, + builder.stopSequences); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer n; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private Boolean logProbs; + private Integer topLogProbs; + private final List stopSequences = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder n(int n) { + this.n = n; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder logProbs(boolean logProbs) { + this.logProbs = logProbs; + return this; + } + + public Builder topLogProbs(int topLogProbs) { + this.topLogProbs = topLogProbs; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public DatabricksGenerative.Provider build() { + return new DatabricksGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java index 300622525..5e0d3c16c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java @@ -5,7 +5,9 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record FriendliaiGenerative( @SerializedName("baseURL") String baseUrl, @@ -85,4 +87,105 @@ public Generative.Kind _kind() { return Generative.Kind.FRIENDLIAI; } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer n, + Float topP) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeFriendliAI.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (n != null) { + provider.setN(n); + } + if (topP != null) { + provider.setTopP(topP); + } + req.setFriendliai(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.n, + builder.topP); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer n; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder n(int n) { + this.n = n; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public FriendliaiGenerative.Provider build() { + return new FriendliaiGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java index 2ced65d90..9fba3987e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java @@ -1,11 +1,17 @@ package io.weaviate.client6.v1.api.collections.generative; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record GoogleGenerative( @SerializedName("apiEndpoint") String baseUrl, @@ -120,4 +126,207 @@ public static record TokenMetadata(TokenCount inputTokens, TokenCount outputToke public static record Usage(Long promptTokenCount, Long candidatesTokenCount, Long totalTokenCount) { } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer topK, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + String projectId, + String endpointId, + String region, + List stopSequences, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeGoogle.newBuilder(); + if (baseUrl != null) { + provider.setApiEndpoint(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topK != null) { + provider.setTopK(topK); + } + if (topP != null) { + provider.setTopP(topP); + } + if (projectId != null) { + provider.setProjectId(projectId); + } + if (endpointId != null) { + provider.setEndpointId(endpointId); + } + if (region != null) { + provider.setRegion(region); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (stopSequences != null) { + provider.setStopSequences(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + req.setGoogle(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topK, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.projectId, + builder.endpointId, + builder.region, + builder.stopSequences, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer topK; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private String projectId; + private String endpointId; + private String region; + private final List stopSequences = new ArrayList<>(); + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder endpointId(String endpointId) { + this.endpointId = endpointId; + return this; + } + + public Builder region(String region) { + this.region = region; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public GoogleGenerative.Provider build() { + return new GoogleGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java index 60357301d..7e8cc1404 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java @@ -5,7 +5,9 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record MistralGenerative( @SerializedName("baseURL") String baseUrl, @@ -85,4 +87,94 @@ public Generative.Kind _kind() { return Generative.Kind.MISTRAL; } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Float topP) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeMistral.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topP != null) { + provider.setTopP(topP); + } + req.setMistral(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topP); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public MistralGenerative.Provider build() { + return new MistralGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java index 80e09eff1..77a60e734 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java @@ -5,7 +5,9 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record NvidiaGenerative( @SerializedName("baseURL") String baseUrl, @@ -85,4 +87,94 @@ public Generative.Kind _kind() { return Generative.Kind.NVIDIA; } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Float topP) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeNvidia.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topP != null) { + provider.setTopP(topP); + } + req.setNvidia(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topP); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public NvidiaGenerative.Provider build() { + return new NvidiaGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java index 25fc6f3c7..4e0ada76a 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java @@ -1,11 +1,17 @@ package io.weaviate.client6.v1.api.collections.generative; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record OllamaGenerative( @SerializedName("apiEndpoint") String baseUrl, @@ -64,4 +70,102 @@ public Generative.Kind _kind() { return Generative.Kind.OLLAMA; } } + + public static record Provider( + String baseUrl, + String model, + Float temperature, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeOllama.newBuilder(); + if (baseUrl != null) { + provider.setApiEndpoint(baseUrl); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (images != null) { + provider.setImages(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(images)); + } + if (imageProperties != null) { + provider.setImageProperties(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(imageProperties)); + } + req.setOllama(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.temperature, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Float temperature; + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public OllamaGenerative.Provider build() { + return new OllamaGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java index 78da286bc..68010bf4c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java @@ -1,11 +1,17 @@ package io.weaviate.client6.v1.api.collections.generative; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record OpenAiGenerative( @SerializedName("baseURL") String baseUrl, @@ -111,4 +117,208 @@ public Generative.Kind _kind() { return Generative.Kind.OPENAI; } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer n, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + String apiVersion, + String resourceName, + String deploymentId, + List stopSequences, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeOpenAI.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (n != null) { + provider.setN(n); + } + if (topP != null) { + provider.setTopP(topP); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (apiVersion != null) { + provider.setApiVersion(apiVersion); + } + if (resourceName != null) { + provider.setResourceName(resourceName); + } + if (deploymentId != null) { + provider.setDeploymentId(deploymentId); + } + if (stopSequences != null) { + provider.setStop(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + provider.setIsAzure(false); + req.setOpenai(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.n, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.apiVersion, + builder.resourceName, + builder.deploymentId, + builder.stopSequences, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer n; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private String apiVersion; + private String resourceName; + private String deploymentId; + private final List stopSequences = new ArrayList<>(); + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder n(int n) { + this.n = n; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + + public Builder resourceName(String resourceName) { + this.resourceName = resourceName; + return this; + } + + public Builder deploymentId(String deploymentId) { + this.deploymentId = deploymentId; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public OpenAiGenerative.Provider build() { + return new OpenAiGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java index c7b0a81d6..795aa659f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java @@ -1,11 +1,17 @@ package io.weaviate.client6.v1.api.collections.generative; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.function.Function; import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record XaiGenerative( @SerializedName("baseURL") String baseUrl, @@ -85,4 +91,126 @@ public Generative.Kind _kind() { return Generative.Kind.XAI; } } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Float topP, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeXAI.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topP != null) { + provider.setTopP(topP); + } + if (images != null) { + provider.setImages(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(images)); + } + if (imageProperties != null) { + provider.setImageProperties(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(imageProperties)); + } + req.setXai(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topP, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public XaiGenerative.Provider build() { + return new XaiGenerative.Provider(this); + } + } + } } From 4f0bf5e51e943d1ce549024ee7e5b459fc01d3c9 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 16:16:36 +0200 Subject: [PATCH 21/25] chore: remove Azure-related config from OpenAI --- .../generative/OpenAiGenerative.java | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java index 68010bf4c..dbcf4cb0c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java @@ -127,9 +127,6 @@ public static record Provider( Float topP, Float frequencyPenalty, Float presencePenalty, - String apiVersion, - String resourceName, - String deploymentId, List stopSequences, List images, List imageProperties) implements DynamicProvider { @@ -167,15 +164,6 @@ public void appendTo( if (presencePenalty != null) { provider.setPresencePenalty(presencePenalty); } - if (apiVersion != null) { - provider.setApiVersion(apiVersion); - } - if (resourceName != null) { - provider.setResourceName(resourceName); - } - if (deploymentId != null) { - provider.setDeploymentId(deploymentId); - } if (stopSequences != null) { provider.setStop(WeaviateProtoBase.TextArray.newBuilder() .addAllValues(stopSequences)); @@ -194,9 +182,6 @@ public Provider(Builder builder) { builder.topP, builder.frequencyPenalty, builder.presencePenalty, - builder.apiVersion, - builder.resourceName, - builder.deploymentId, builder.stopSequences, builder.images, builder.imageProperties); @@ -211,9 +196,6 @@ public static class Builder implements ObjectBuilder private Float temperature; private Float frequencyPenalty; private Float presencePenalty; - private String apiVersion; - private String resourceName; - private String deploymentId; private final List stopSequences = new ArrayList<>(); private final List images = new ArrayList<>(); private final List imageProperties = new ArrayList<>(); @@ -273,21 +255,6 @@ public Builder stopSequences(List stopSequences) { return this; } - public Builder apiVersion(String apiVersion) { - this.apiVersion = apiVersion; - return this; - } - - public Builder resourceName(String resourceName) { - this.resourceName = resourceName; - return this; - } - - public Builder deploymentId(String deploymentId) { - this.deploymentId = deploymentId; - return this; - } - public Builder images(String... images) { return images(Arrays.asList(images)); } From af31af4c3c281a2617fab260303f48e7ae0f5e5d Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 16:22:14 +0200 Subject: [PATCH 22/25] feat: add dynamic provider static factories --- .../collections/generate/DynamicProvider.java | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java index d5930af65..883ba5ed1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java @@ -3,6 +3,18 @@ import java.util.function.Function; import io.weaviate.client6.v1.api.collections.generative.AnthropicGenerative; +import io.weaviate.client6.v1.api.collections.generative.AnyscaleGenerative; +import io.weaviate.client6.v1.api.collections.generative.AwsGenerative; +import io.weaviate.client6.v1.api.collections.generative.AzureOpenAiGenerative; +import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; +import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; +import io.weaviate.client6.v1.api.collections.generative.FriendliaiGenerative; +import io.weaviate.client6.v1.api.collections.generative.GoogleGenerative; +import io.weaviate.client6.v1.api.collections.generative.MistralGenerative; +import io.weaviate.client6.v1.api.collections.generative.NvidiaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OllamaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OpenAiGenerative; +import io.weaviate.client6.v1.api.collections.generative.XaiGenerative; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; @@ -18,4 +30,125 @@ public static DynamicProvider anthropic( Function> fn) { return AnthropicGenerative.Provider.of(fn); } + + /** + * Configure {@code generative-anyscale} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider anyscale( + Function> fn) { + return AnyscaleGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-aws} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider aws( + Function> fn) { + return AwsGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-cohere} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider cohere( + Function> fn) { + return CohereGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-databricks} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider databricks( + Function> fn) { + return DatabricksGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-friendliai} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider friendliai( + Function> fn) { + return FriendliaiGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-palm} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider google( + Function> fn) { + return GoogleGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-mistral} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider mistral( + Function> fn) { + return MistralGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-nvidia} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider nvidia( + Function> fn) { + return NvidiaGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-ollama} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider ollama( + Function> fn) { + return OllamaGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-openai} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider openai( + Function> fn) { + return OpenAiGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-openai} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider azure( + Function> fn) { + return AzureOpenAiGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-xai} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider xai( + Function> fn) { + return XaiGenerative.Provider.of(fn); + } + } From f26b55d9a97daab0e18a7cc60f0c9b5de4b48682 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 16:25:47 +0200 Subject: [PATCH 23/25] feat: extend generative functionality to async client --- .../collections/CollectionHandleAsync.java | 4 ++ .../generate/WeaviateGenerateClientAsync.java | 41 +++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClientAsync.java diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java index 5e8196dd2..83d18ed2f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java @@ -8,6 +8,7 @@ import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClientAsync; import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClientAsync; import io.weaviate.client6.v1.api.collections.data.WeaviateDataClientAsync; +import io.weaviate.client6.v1.api.collections.generate.WeaviateGenerateClientAsync; import io.weaviate.client6.v1.api.collections.pagination.AsyncPaginator; import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClientAsync; @@ -21,6 +22,7 @@ public class CollectionHandleAsync { public final WeaviateConfigClientAsync config; public final WeaviateDataClientAsync data; public final WeaviateQueryClientAsync query; + public final WeaviateGenerateClientAsync generate; public final WeaviateAggregateClientAsync aggregate; public final WeaviateTenantsClientAsync tenants; @@ -35,6 +37,7 @@ public CollectionHandleAsync( this.config = new WeaviateConfigClientAsync(collection, restTransport, grpcTransport, defaults); this.aggregate = new WeaviateAggregateClientAsync(collection, grpcTransport, defaults); this.query = new WeaviateQueryClientAsync<>(collection, grpcTransport, defaults); + this.generate = new WeaviateGenerateClientAsync<>(collection, grpcTransport, defaults); this.data = new WeaviateDataClientAsync<>(collection, restTransport, grpcTransport, defaults); this.defaults = defaults; @@ -46,6 +49,7 @@ private CollectionHandleAsync(CollectionHandleAsync c, CollectionHa this.config = new WeaviateConfigClientAsync(c.config, defaults); this.aggregate = new WeaviateAggregateClientAsync(c.aggregate, defaults); this.query = new WeaviateQueryClientAsync<>(c.query, defaults); + this.generate = new WeaviateGenerateClientAsync<>(c.generate, defaults); this.data = new WeaviateDataClientAsync<>(c.data, defaults); this.defaults = defaults; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClientAsync.java new file mode 100644 index 000000000..eff3866a7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClientAsync.java @@ -0,0 +1,41 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.concurrent.CompletableFuture; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.query.GroupBy; +import io.weaviate.client6.v1.api.collections.query.QueryOperator; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public class WeaviateGenerateClientAsync + extends + AbstractGenerateClient>, CompletableFuture>> { + + public WeaviateGenerateClientAsync( + CollectionDescriptor collection, + GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { + super(collection, grpcTransport, defaults); + } + + /** Copy constructor that sets new defaults. */ + public WeaviateGenerateClientAsync(WeaviateGenerateClientAsync c, CollectionHandleDefaults defaults) { + super(c, defaults); + } + + @Override + protected final CompletableFuture> performRequest(QueryOperator operator, + GenerativeTask generate) { + var request = new GenerativeRequest(operator, generate, null); + return this.grpcTransport.performRequestAsync(request, GenerativeRequest.rpc(collection, defaults)); + } + + @Override + protected final CompletableFuture> performRequest(QueryOperator operator, + GenerativeTask generate, + GroupBy groupBy) { + var request = new GenerativeRequest(operator, generate, groupBy); + return this.grpcTransport.performRequestAsync(request, GenerativeRequest.grouped(collection, defaults)); + } +} From fe07d37d22a2a47f6172f39c559fa9231c07cf0a Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 16:42:21 +0200 Subject: [PATCH 24/25] chore: rename generativeProvider -> dynamicProvider --- .../client6/v1/api/collections/generate/GenerativeTask.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java index 238894775..013697b66 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java @@ -85,7 +85,7 @@ public Builder debug(boolean enable) { return this; } - public Builder generativeProvider(DynamicProvider provider) { + public Builder dynamicProvider(DynamicProvider provider) { providers.clear(); // Protobuf allows `repeated` but the server expects there to be 1. providers.add(provider); return this; @@ -146,7 +146,7 @@ public Builder properties(List properties) { return this; } - public Builder generativeProvider(DynamicProvider provider) { + public Builder dynamicProvider(DynamicProvider provider) { providers.clear(); // Protobuf allows `repeated` but the server expects there to be 1. providers.add(provider); return this; From e66e30440f599af53d81016b0c00d062560c88f2 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 22 Oct 2025 16:56:13 +0200 Subject: [PATCH 25/25] chore: remove redundant method from ProviderMetadata interface --- .../v1/api/collections/generative/AnthropicGenerative.java | 6 ------ .../v1/api/collections/generative/AnyscaleGenerative.java | 5 ----- .../v1/api/collections/generative/AwsGenerative.java | 5 ----- .../v1/api/collections/generative/CohereGenerative.java | 5 ----- .../v1/api/collections/generative/DatabricksGenerative.java | 5 ----- .../v1/api/collections/generative/DummyGenerative.java | 5 ----- .../v1/api/collections/generative/FriendliaiGenerative.java | 5 ----- .../v1/api/collections/generative/GoogleGenerative.java | 5 ----- .../v1/api/collections/generative/MistralGenerative.java | 5 ----- .../v1/api/collections/generative/NvidiaGenerative.java | 5 ----- .../v1/api/collections/generative/OllamaGenerative.java | 5 ----- .../v1/api/collections/generative/OpenAiGenerative.java | 5 ----- .../v1/api/collections/generative/ProviderMetadata.java | 4 ---- .../v1/api/collections/generative/XaiGenerative.java | 5 ----- 14 files changed, 70 deletions(-) diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java index 957ba708c..bb836d49f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java @@ -112,12 +112,6 @@ public AnthropicGenerative build() { } public static record Metadata(Usage usage) implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.ANTHROPIC; - } - public static record Usage(Long inputTokens, Long outputTokens) { } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java index acfb810cc..a2279e0a2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java @@ -70,11 +70,6 @@ public AnyscaleGenerative build() { } public static record Metadata() implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.ANYSCALE; - } } public static record Provider( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java index 12605908e..1589b15db 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java @@ -76,11 +76,6 @@ public AwsGenerative build() { } public static record Metadata() implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.AWS; - } } public static record Provider( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java index 4c84b4339..9a3d8860e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java @@ -122,11 +122,6 @@ public CohereGenerative build() { public static record Metadata(ApiVersion apiVersion, BilledUnits billedUnits, Tokens tokens, List warnings) implements ProviderMetadata { - @Override - public Generative.Kind _kind() { - return Generative.Kind.COHERE; - } - public static record ApiVersion(String version, Boolean deprecated, Boolean experimental) { } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java index 6005970ae..df2b44f14 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java @@ -93,11 +93,6 @@ public DatabricksGenerative build() { } public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.DATABRICKS; - } } public static record Provider( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java index cfe300a5e..8f45163fe 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java @@ -14,10 +14,5 @@ public Object _self() { } public static record Metadata() implements ProviderMetadata { - - @Override - public Kind _kind() { - return Generative.Kind.DUMMY; - } } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java index 5e0d3c16c..d154dc8b3 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java @@ -81,11 +81,6 @@ public FriendliaiGenerative build() { } public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.FRIENDLIAI; - } } public static record Provider( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java index 9fba3987e..084bab0b2 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java @@ -112,11 +112,6 @@ public GoogleGenerative build() { public static record Metadata(TokenMetadata tokens, Usage usage) implements ProviderMetadata { - @Override - public Generative.Kind _kind() { - return Generative.Kind.GOOGLE; - } - public static record TokenCount(Long totalBillableCharacters, Long totalTokens) { } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java index 7e8cc1404..3f64cd06e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java @@ -81,11 +81,6 @@ public MistralGenerative build() { } public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.MISTRAL; - } } public static record Provider( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java index 77a60e734..6bc156a16 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java @@ -81,11 +81,6 @@ public NvidiaGenerative build() { } public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.NVIDIA; - } } public static record Provider( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java index 4e0ada76a..89a356b8d 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java @@ -64,11 +64,6 @@ public OllamaGenerative build() { } public static record Metadata() implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.OLLAMA; - } } public static record Provider( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java index dbcf4cb0c..0417aacda 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java @@ -111,11 +111,6 @@ public OpenAiGenerative build() { } public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.OPENAI; - } } public static record Provider( diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java index 0d3dc27a1..884dd6dd3 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java @@ -1,10 +1,6 @@ package io.weaviate.client6.v1.api.collections.generative; -import io.weaviate.client6.v1.api.collections.Generative; - public interface ProviderMetadata { - Generative.Kind _kind(); - record Usage( Long promptTokens, Long completionTokens, diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java index 795aa659f..d736c658c 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java @@ -85,11 +85,6 @@ public XaiGenerative build() { } public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { - - @Override - public Generative.Kind _kind() { - return Generative.Kind.XAI; - } } public static record Provider(