diff --git a/src/it/java/io/weaviate/integration/SearchITest.java b/src/it/java/io/weaviate/integration/SearchITest.java index c08c92a6d..46a2ded83 100644 --- a/src/it/java/io/weaviate/integration/SearchITest.java +++ b/src/it/java/io/weaviate/integration/SearchITest.java @@ -27,6 +27,9 @@ import io.weaviate.client6.v1.api.collections.WeaviateMetadata; import io.weaviate.client6.v1.api.collections.WeaviateObject; import io.weaviate.client6.v1.api.collections.data.Reference; +import io.weaviate.client6.v1.api.collections.generate.GenerativeObject; +import io.weaviate.client6.v1.api.collections.generate.TaskOutput; +import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; import io.weaviate.client6.v1.api.collections.query.GroupBy; import io.weaviate.client6.v1.api.collections.query.Metadata; import io.weaviate.client6.v1.api.collections.query.QueryMetadata; @@ -47,6 +50,7 @@ public class SearchITest extends ConcurrentTest { Weaviate.custom() .withContextionaryUrl(Contextionary.URL) .withImageInference(Img2VecNeural.URL, Img2VecNeural.MODULE) + .addModules("generative-dummy") .build(), Container.IMG2VEC_NEURAL, Container.CONTEXTIONARY); @@ -549,4 +553,98 @@ public void testNearVector_targetVectors() throws IOException { .hasSize(1).extracting(WeaviateObject::uuid) .containsExactly(thing456.uuids().get(0)); } + + @Test + public void testGenerative_bm25() throws IOException { + // Arrange + var nsThings = ns("Things"); + + client.collections.create(nsThings, + c -> c + .properties(Property.text("title")) + .generativeModule(new DummyGenerative()) + .vectorConfig(VectorConfig.text2vecContextionary( + t2v -> t2v.sourceProperties("title")))); + + var things = client.collections.use(nsThings); + + things.data.insertMany( + Map.of("title", "Salad Fork"), + Map.of("title", "Dessert Fork")); + + // Act + var french = things.generate.bm25( + "fork", + bm25 -> bm25.queryProperties("title").limit(2), + generate -> generate + .singlePrompt("translate to French") + .groupedTask("count letters R")); + + // Assert + Assertions.assertThat(french.objects()) + .as("individual results") + .hasSize(2) + .extracting(GenerativeObject::generated) + .allSatisfy(generated -> { + Assertions.assertThat(generated.text()).isNotBlank(); + }); + + Assertions.assertThat(french.generated()) + .as("summary") + .extracting(TaskOutput::text, InstanceOfAssertFactories.STRING) + .isNotBlank(); + } + + @Test + public void testGenerative_bm25_groupBy() throws IOException { + // Arrange + var nsThings = ns("Things"); + + client.collections.create(nsThings, + c -> c + .properties(Property.text("title")) + .generativeModule(new DummyGenerative()) + .vectorConfig(VectorConfig.text2vecContextionary( + t2v -> t2v.sourceProperties("title")))); + + var things = client.collections.use(nsThings); + + things.data.insertMany( + Map.of("title", "Salad Fork"), + Map.of("title", "Dessert Fork")); + + // Act + var french = things.generate.bm25( + "fork", + bm25 -> bm25.queryProperties("title").limit(2), + generate -> generate + .singlePrompt("translate to French") + .groupedTask("count letters R"), + GroupBy.property("title", 5, 5)); + + // Assert + Assertions.assertThat(french.objects()) + .as("individual results") + .hasSize(2); + + Assertions.assertThat(french.groups()) + .as("grouped results") + .hasSize(2) + .allSatisfy((groupName, group) -> { + Assertions.assertThat(group.objects()) + .describedAs("objects in group %s", groupName) + .hasSize(1); + + Assertions.assertThat(group.generated()) + .describedAs("summary group %s", groupName) + .extracting(TaskOutput::text, InstanceOfAssertFactories.STRING) + .isNotBlank(); + + }); + + Assertions.assertThat(french.generated()) + .as("summary") + .extracting(TaskOutput::text, InstanceOfAssertFactories.STRING) + .isNotBlank(); + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java index 08f41eca3..7af8ed549 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandle.java @@ -6,6 +6,7 @@ import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClient; import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClient; import io.weaviate.client6.v1.api.collections.data.WeaviateDataClient; +import io.weaviate.client6.v1.api.collections.generate.WeaviateGenerateClient; import io.weaviate.client6.v1.api.collections.pagination.Paginator; import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClient; @@ -20,6 +21,7 @@ public class CollectionHandle { public final WeaviateDataClient data; public final WeaviateQueryClient query; public final WeaviateAggregateClient aggregate; + public final WeaviateGenerateClient generate; public final WeaviateTenantsClient tenants; private final CollectionHandleDefaults defaults; @@ -32,6 +34,7 @@ public CollectionHandle( this.config = new WeaviateConfigClient(collection, restTransport, grpcTransport, defaults); this.aggregate = new WeaviateAggregateClient(collection, grpcTransport, defaults); this.query = new WeaviateQueryClient<>(collection, grpcTransport, defaults); + this.generate = new WeaviateGenerateClient<>(collection, grpcTransport, defaults); this.data = new WeaviateDataClient<>(collection, restTransport, grpcTransport, defaults); this.defaults = defaults; @@ -43,6 +46,7 @@ private CollectionHandle(CollectionHandle c, CollectionHandleDefaul this.config = new WeaviateConfigClient(c.config, defaults); this.aggregate = new WeaviateAggregateClient(c.aggregate, defaults); this.query = new WeaviateQueryClient<>(c.query, defaults); + this.generate = new WeaviateGenerateClient<>(c.generate, defaults); this.data = new WeaviateDataClient<>(c.data, defaults); this.defaults = defaults; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java index 5e8196dd2..83d18ed2f 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/CollectionHandleAsync.java @@ -8,6 +8,7 @@ import io.weaviate.client6.v1.api.collections.aggregate.WeaviateAggregateClientAsync; import io.weaviate.client6.v1.api.collections.config.WeaviateConfigClientAsync; import io.weaviate.client6.v1.api.collections.data.WeaviateDataClientAsync; +import io.weaviate.client6.v1.api.collections.generate.WeaviateGenerateClientAsync; import io.weaviate.client6.v1.api.collections.pagination.AsyncPaginator; import io.weaviate.client6.v1.api.collections.query.ConsistencyLevel; import io.weaviate.client6.v1.api.collections.query.WeaviateQueryClientAsync; @@ -21,6 +22,7 @@ public class CollectionHandleAsync { public final WeaviateConfigClientAsync config; public final WeaviateDataClientAsync data; public final WeaviateQueryClientAsync query; + public final WeaviateGenerateClientAsync generate; public final WeaviateAggregateClientAsync aggregate; public final WeaviateTenantsClientAsync tenants; @@ -35,6 +37,7 @@ public CollectionHandleAsync( this.config = new WeaviateConfigClientAsync(collection, restTransport, grpcTransport, defaults); this.aggregate = new WeaviateAggregateClientAsync(collection, grpcTransport, defaults); this.query = new WeaviateQueryClientAsync<>(collection, grpcTransport, defaults); + this.generate = new WeaviateGenerateClientAsync<>(collection, grpcTransport, defaults); this.data = new WeaviateDataClientAsync<>(collection, restTransport, grpcTransport, defaults); this.defaults = defaults; @@ -46,6 +49,7 @@ private CollectionHandleAsync(CollectionHandleAsync c, CollectionHa this.config = new WeaviateConfigClientAsync(c.config, defaults); this.aggregate = new WeaviateAggregateClientAsync(c.aggregate, defaults); this.query = new WeaviateQueryClientAsync<>(c.query, defaults); + this.generate = new WeaviateGenerateClientAsync<>(c.generate, defaults); this.data = new WeaviateDataClientAsync<>(c.data, defaults); this.defaults = defaults; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java index 632713cdd..da7818c75 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Generative.java @@ -6,20 +6,47 @@ import java.util.function.Function; import com.google.gson.Gson; +import com.google.gson.JsonParser; import com.google.gson.TypeAdapter; import com.google.gson.TypeAdapterFactory; import com.google.gson.reflect.TypeToken; import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonToken; import com.google.gson.stream.JsonWriter; +import io.weaviate.client6.v1.api.collections.generative.AnthropicGenerative; +import io.weaviate.client6.v1.api.collections.generative.AnyscaleGenerative; +import io.weaviate.client6.v1.api.collections.generative.AwsGenerative; +import io.weaviate.client6.v1.api.collections.generative.AzureOpenAiGenerative; import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; +import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; +import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; +import io.weaviate.client6.v1.api.collections.generative.FriendliaiGenerative; +import io.weaviate.client6.v1.api.collections.generative.GoogleGenerative; +import io.weaviate.client6.v1.api.collections.generative.MistralGenerative; +import io.weaviate.client6.v1.api.collections.generative.NvidiaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OllamaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OpenAiGenerative; +import io.weaviate.client6.v1.api.collections.generative.XaiGenerative; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.TaggedUnion; import io.weaviate.client6.v1.internal.json.JsonEnum; -public interface Generative { +public interface Generative extends TaggedUnion { public enum Kind implements JsonEnum { - COHERE("generative-cohere"); + ANYSCALE("generative-anyscale"), + AWS("generative-aws"), + ANTHROPIC("generative-anthropic"), + COHERE("generative-cohere"), + DATABRICKS("generative-databricks"), + FRIENDLIAI("generative-friendliai"), + GOOGLE("generative-palm"), + MISTRAL("generative-mistral"), + NVIDIA("generative-nvidia"), + OLLAMA("generative-ollama"), + OPENAI("generative-openai"), + AZURE_OPENAI("generative-openai"), + XAI("generative-xai"), + DUMMY("generative-dummy"); private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); private final String jsonValue; @@ -38,17 +65,65 @@ public static Kind valueOfJson(String jsonValue) { } } - Kind _kind(); + /** Configure a default {@code generative-anthropic} module. */ + public static Generative anthropic() { + return AnthropicGenerative.of(); + } + + /** + * Configure a {@code generative-anthropic} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative anthropic(Function> fn) { + return AnthropicGenerative.of(fn); + } + + /** Configure a default {@code generative-anyscale} module. */ + public static Generative anyscale() { + return AnyscaleGenerative.of(); + } + + /** + * Configure a {@code generative-anyscale} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative anyscale(Function> fn) { + return AnyscaleGenerative.of(fn); + } + + /** + * Configure a default {@code generative-aws} module. + * + * @param region AWS region. + * @param service AWS service to use, e.g. {@code "bedrock"} or + * {@code "sagemaker"}. + */ + public static Generative aws(String region, String service) { + return AwsGenerative.of(region, service); + } - Object _self(); + /** + * Configure a {@code generative-aws} module. + * + * @param region AWS region. + * @param service AWS service to use, e.g. {@code "bedrock"} or + * {@code "sagemaker"}. + * @param fn Lambda expression for optional parameters. + */ + public static Generative aws(String region, String service, + Function> fn) { + return AwsGenerative.of(region, service, fn); + } - /** Configure a default Cohere generative module. */ + /** Configure a default {@code generative-cohere} module. */ public static Generative cohere() { return CohereGenerative.of(); } /** - * Configure a Cohere generative module. + * Configure a {@code generative-cohere} module. * * @param fn Lambda expression for optional parameters. */ @@ -56,6 +131,344 @@ public static Generative cohere(Function> fn) { + return DatabricksGenerative.of(baseURL, fn); + } + + /** Configure a default {@code generative-frienliai} module. */ + public static Generative frienliai() { + return FriendliaiGenerative.of(); + } + + /** + * Configure a {@code generative-frienliai} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative friendliai(Function> fn) { + return FriendliaiGenerative.of(fn); + } + + /** Configure a default {@code generative-palm} module. */ + public static Generative google(String projectId) { + return GoogleGenerative.of(projectId); + } + + /** + * Configure a {@code generative-palm} module. + * + * @param projectId Project ID. + * @param fn Lambda expression for optional parameters. + */ + public static Generative google(String projectId, + Function> fn) { + return GoogleGenerative.of(projectId, fn); + } + + /** Configure a default {@code generative-mistral} module. */ + public static Generative mistral() { + return MistralGenerative.of(); + } + + /** + * Configure a {@code generative-mistral} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative mistral(Function> fn) { + return MistralGenerative.of(fn); + } + + /** Configure a default {@code generative-nvidia} module. */ + public static Generative nvidia() { + return NvidiaGenerative.of(); + } + + /** + * Configure a {@code generative-nvidia} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative nvidia(Function> fn) { + return NvidiaGenerative.of(fn); + } + + /** Configure a default {@code generative-ollama} module. */ + public static Generative ollama() { + return OllamaGenerative.of(); + } + + /** + * Configure a {@code generative-ollama} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative ollama(Function> fn) { + return OllamaGenerative.of(fn); + } + + /** Configure a default {@code generative-openai} module. */ + public static Generative openai() { + return OpenAiGenerative.of(); + } + + /** + * Configure a {@code generative-openai} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative openai(Function> fn) { + return OpenAiGenerative.of(fn); + } + + /** + * Configure a default {@code generative-openai} module + * hosted on Microsoft Azure. + * + * @param resourceName Name of the Azure OpenAI resource. + * @param deploymentId Azure OpenAI deployment ID. + */ + public static Generative azure(String resourceName, String deploymentId) { + return AzureOpenAiGenerative.of(resourceName, deploymentId); + } + + /** + * Configure a {@code generative-openai} module hosted on Microsoft Azure. + * + * @param resourceName Name of the Azure OpenAI resource. + * @param deploymentId Azure OpenAI deployment ID. + * @param fn Lambda expression for optional parameters. + */ + public static Generative azure(String resourceName, String deploymentId, + Function> fn) { + return AzureOpenAiGenerative.of(resourceName, deploymentId, fn); + } + + /** Configure a default {@code generative-xai} module. */ + public static Generative xai() { + return XaiGenerative.of(); + } + + /** + * Configure a {@code generative-xai} module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Generative xai(Function> fn) { + return XaiGenerative.of(fn); + } + + /** Is this a {@code generative-anyscale} provider? */ + default boolean isAnyscale() { + return _is(Generative.Kind.ANYSCALE); + } + + /** + * Get as {@link AnyscaleGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-anyscale}. + */ + default AnyscaleGenerative asAnyscale() { + return _as(Generative.Kind.ANYSCALE); + } + + /** Is this a {@code generative-aws} provider? */ + default boolean isAws() { + return _is(Generative.Kind.AWS); + } + + /** + * Get as {@link AwsGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-aws}. + */ + default AwsGenerative asAws() { + return _as(Generative.Kind.AWS); + } + + /** Is this a {@code generative-anthropic} provider? */ + default boolean isAnthropic() { + return _is(Generative.Kind.ANTHROPIC); + } + + /** + * Get as {@link AnthropicGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-anthropic}. + */ + default AnthropicGenerative asAnthropic() { + return _as(Generative.Kind.ANTHROPIC); + } + + /** Is this a {@code generative-cohere} provider? */ + default boolean isCohere() { + return _is(Generative.Kind.COHERE); + } + + /** + * Get as {@link CohereGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-cohere}. + */ + default CohereGenerative asCohere() { + return _as(Generative.Kind.COHERE); + } + + /** Is this a {@code generative-databricks} provider? */ + default boolean isDatabricks() { + return _is(Generative.Kind.DATABRICKS); + } + + /** + * Get as {@link DatabricksGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-databricks}. + */ + default DatabricksGenerative asDatabricks() { + return _as(Generative.Kind.DATABRICKS); + } + + /** Is this a {@code generative-friendliai} provider? */ + default boolean isFriendliai() { + return _is(Generative.Kind.FRIENDLIAI); + } + + /** + * Get as {@link FriendliaiGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-friendliai}. + */ + default FriendliaiGenerative asFriendliai() { + return _as(Generative.Kind.FRIENDLIAI); + } + + /** Is this a {@code generative-palm} provider? */ + default boolean isGoogle() { + return _is(Generative.Kind.GOOGLE); + } + + /** + * Get as {@link GoogleGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-palm}. + */ + default GoogleGenerative asGoogle() { + return _as(Generative.Kind.GOOGLE); + } + + /** Is this a {@code generative-mistral} provider? */ + default boolean isMistral() { + return _is(Generative.Kind.MISTRAL); + } + + /** + * Get as {@link MistralGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-mistral}. + */ + default MistralGenerative asMistral() { + return _as(Generative.Kind.MISTRAL); + } + + /** Is this a {@code generative-nvidia} provider? */ + default boolean isNvidia() { + return _is(Generative.Kind.NVIDIA); + } + + /** + * Get as {@link NvidiaGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-nvidia}. + */ + default NvidiaGenerative asNvidia() { + return _as(Generative.Kind.NVIDIA); + } + + /** Is this a {@code generative-ollama} provider? */ + default boolean isOllama() { + return _is(Generative.Kind.OLLAMA); + } + + /** + * Get as {@link OllamaGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-ollama}. + */ + default OllamaGenerative asOllama() { + return _as(Generative.Kind.OLLAMA); + } + + /** Is this a {@code generative-openai} provider? */ + default boolean isOpenAi() { + return _is(Generative.Kind.OPENAI); + } + + /** + * Get as {@link OpenAiGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-openai}. + */ + default OpenAiGenerative asOpenAi() { + return _as(Generative.Kind.OPENAI); + } + + /** Is this an Azure-specific {@code generative-openai} provider? */ + default boolean isAzure() { + return _is(Generative.Kind.AZURE_OPENAI); + } + + /** + * Get as {@link AzureOpenAiGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-openai}. + */ + default AzureOpenAiGenerative asAzure() { + return _as(Generative.Kind.AZURE_OPENAI); + } + + /** Is this a {@code generative-xai} provider? */ + default boolean isXai() { + return _is(Generative.Kind.XAI); + } + + /** + * Get as {@link XaiGenerative} instance. + * + * @throws IllegalStateException if the current kind is not + * {@code generative-xai}. + */ + default XaiGenerative asXai() { + return _as(Generative.Kind.XAI); + } + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { INSTANCE; @@ -67,7 +480,20 @@ private final void addAdapter(Gson gson, Generative.Kind kind, Class TypeAdapter create(Gson gson, TypeToken type) { init(gson); } - final TypeAdapter writeAdapter = (TypeAdapter) gson.getDelegateAdapter(this, TypeToken.get(rawType)); + final TypeAdapter writeAdapter = (TypeAdapter) gson.getDelegateAdapter(this, + TypeToken.get(rawType)); return (TypeAdapter) new TypeAdapter() { @Override @@ -95,20 +522,25 @@ public void write(JsonWriter out, Generative value) throws IOException { @Override public Generative read(JsonReader in) throws IOException { - in.beginObject(); - var moduleName = in.nextName(); - try { - var kind = Generative.Kind.valueOfJson(moduleName); - var adapter = readAdapters.get(kind); - return adapter.read(in); - } catch (IllegalArgumentException e) { - return null; - } finally { - if (in.peek() == JsonToken.BEGIN_OBJECT) { - in.beginObject(); + var jsonObject = JsonParser.parseReader(in).getAsJsonObject(); + var provider = jsonObject.keySet().iterator().next(); + + var generative = jsonObject.get(provider).getAsJsonObject(); + Generative.Kind kind; + if (provider.equals(Generative.Kind.OPENAI.jsonValue())) { + kind = generative.has("deploymentId") && generative.has("resourceName") + ? Generative.Kind.AZURE_OPENAI + : Generative.Kind.OPENAI; + } else { + try { + kind = Generative.Kind.valueOfJson(provider); + } catch (IllegalArgumentException e) { + return null; } - in.endObject(); } + var adapter = readAdapters.get(kind); + assert adapter != null : "no generative adapter for kind " + kind; + return adapter.fromJsonTree(generative); } }.nullSafe(); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java new file mode 100644 index 000000000..c2ef15ea1 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/AbstractGenerateClient.java @@ -0,0 +1,1830 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.List; +import java.util.function.Function; + +import io.weaviate.client6.v1.api.WeaviateApiException; +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.query.Bm25; +import io.weaviate.client6.v1.api.collections.query.FetchObjects; +import io.weaviate.client6.v1.api.collections.query.GroupBy; +import io.weaviate.client6.v1.api.collections.query.Hybrid; +import io.weaviate.client6.v1.api.collections.query.NearAudio; +import io.weaviate.client6.v1.api.collections.query.NearDepth; +import io.weaviate.client6.v1.api.collections.query.NearImage; +import io.weaviate.client6.v1.api.collections.query.NearImu; +import io.weaviate.client6.v1.api.collections.query.NearObject; +import io.weaviate.client6.v1.api.collections.query.NearText; +import io.weaviate.client6.v1.api.collections.query.NearThermal; +import io.weaviate.client6.v1.api.collections.query.NearVector; +import io.weaviate.client6.v1.api.collections.query.NearVectorTarget; +import io.weaviate.client6.v1.api.collections.query.NearVideo; +import io.weaviate.client6.v1.api.collections.query.QueryOperator; +import io.weaviate.client6.v1.api.collections.query.Target; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +abstract class AbstractGenerateClient { + protected final CollectionDescriptor collection; + protected final GrpcTransport grpcTransport; + protected final CollectionHandleDefaults defaults; + + AbstractGenerateClient(CollectionDescriptor collection, GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { + this.collection = collection; + this.grpcTransport = grpcTransport; + this.defaults = defaults; + } + + /** Copy constructor that sets new defaults. */ + AbstractGenerateClient( + AbstractGenerateClient c, + CollectionHandleDefaults defaults) { + this(c.collection, c.grpcTransport, defaults); + } + + protected abstract ResponseT performRequest(QueryOperator operator, GenerativeTask generate); + + protected abstract GroupedResponseT performRequest(QueryOperator operator, GenerativeTask generate, GroupBy groupBy); + + // Object queries ----------------------------------------------------------- + + /** + * Retrieve objects without applying a Vector Search or Keyword Search filter + * and run a generative task on the query results. + * + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT fetchObjects(Function> fn, + Function> generateFn) { + return fetchObjects(FetchObjects.of(fn), GenerativeTask.of(generateFn)); + } + + /** + * Retrieve objects without applying a Vector Search or Keyword Search filter + * and run a generative task on the query results. + * + * @param query FetchObjects query. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT fetchObjects(FetchObjects query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Retrieve objects without applying a Vector Search or Keyword Search filter + * and run a generative task on the query results. + * + * + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT fetchObjects(Function> fn, + Function> generateFn, + GroupBy groupBy) { + return fetchObjects(FetchObjects.of(fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Retrieve objects without applying a Vector Search or Keyword Search filter + * and run a generative task on the query results. + * + * @param query FetchObjects query. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT fetchObjects(FetchObjects query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // BM25 queries ------------------------------------------------------------- + + /** + * Run a generative task on the results of a keyword (BM25) search. + * + * @param query Query string. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT bm25(String query, + Function> generateFn) { + return bm25(Bm25.of(query), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a keyword (BM25) search. + * + * @param query Query string. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT bm25( + String query, + Function> fn, + Function> generateFn) { + return bm25(Bm25.of(query, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a keyword (BM25) search. + * + * @param query BM25 query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT bm25(Bm25 query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a keyword (BM25) search. + * + * @param query Query string. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT bm25(String query, + Function> generateFn, + GroupBy groupBy) { + return bm25(Bm25.of(query), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a keyword (BM25) search. + * + * @param query Query string. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT bm25(String query, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return bm25(Bm25.of(query, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a keyword (BM25) search. + * + * @param query BM25 query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT bm25(Bm25 query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // Hybrid queries ----------------------------------------------------------- + + /** + * Run a generative task on the results of a hybrid search. + * + * @param query Query string. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid(String query, + Function> generateFn) { + return hybrid(Hybrid.of(query), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param query Query string. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid( + String query, + Function> fn, + Function> generateFn) { + return hybrid(Hybrid.of(query, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param searchTarget Query target. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid( + Target searchTarget, + Function> generateFn) { + return hybrid(Hybrid.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param searchTarget Query target. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid( + Target searchTarget, + Function> fn, + Function> generateFn) { + return hybrid(Hybrid.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param query Hybrid query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT hybrid(Hybrid query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param query Query string. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT hybrid( + String query, + Function> generateFn, + GroupBy groupBy) { + return hybrid(Hybrid.of(query), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param query Query string. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT hybrid( + String query, + Function> generateFn, + Function> fn, GroupBy groupBy) { + return hybrid(Hybrid.of(query, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param searchTarget Query target. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT hybrid( + Target searchTarget, + Function> generateFn, + GroupBy groupBy) { + return hybrid(Hybrid.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param searchTarget Query target. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT hybrid( + Target searchTarget, + Function> generateFn, + Function> fn, + GroupBy groupBy) { + return hybrid(Hybrid.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a hybrid search. + * + * @param query Query string. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT hybrid(Hybrid query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearVector queries ------------------------------------------------------- + + /** + * Run a generative task on the results of a near vector search. + * + * @param vector Query vector. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(float[] vector, + Function> generateFn) { + return nearVector(Target.vector(vector), generateFn); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param vector Query vector. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(float[] vector, + Function> fn, + Function> generateFn) { + return nearVector(Target.vector(vector), fn, generateFn); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param searchTarget Target query vectors. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(NearVectorTarget searchTarget, + Function> generateFn) { + return nearVector(NearVector.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param searchTarget Target query vectors. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(NearVectorTarget searchTarget, + Function> fn, + Function> generateFn) { + return nearVector(NearVector.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param query Near vector query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVector(NearVector query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param vector Query vector. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVector(float[] vector, + Function> generateFn, + GroupBy groupBy) { + return nearVector(Target.vector(vector), generateFn, groupBy); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param vector Query vector. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVector(float[] vector, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearVector(Target.vector(vector), fn, generateFn, groupBy); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param searchTarget Target query vectors. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVector( + NearVectorTarget searchTarget, + Function> generateFn, + GroupBy groupBy) { + return nearVector(NearVector.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param searchTarget Target query vectors. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVector(NearVectorTarget searchTarget, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearVector(NearVector.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near vector search. + * + * @param query Near vector query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVector(NearVector query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearObject queries ------------------------------------------------------- + + /** + * Run a generative task on the results of a near object search. + * + * @param uuid Query object UUID. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearObject(String uuid, + Function> generateFn) { + return nearObject(NearObject.of(uuid), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near object search. + * + * @param uuid Query object UUID. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearObject(String uuid, + Function> fn, + Function> generateFn) { + return nearObject(NearObject.of(uuid, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near object search. + * + * @param query Near object query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearObject(NearObject query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near object search. + * + * @param uuid Query object UUID. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearObject(String uuid, + Function> generateFn, + GroupBy groupBy) { + return nearObject(NearObject.of(uuid), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near object search. + * + * @param uuid Query object UUID. + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearObject(String uuid, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearObject(NearObject.of(uuid, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near object search. + * + * @param query Near object query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearObject(NearObject query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearText queries --------------------------------------------------------- + + /** + * Run a generative task on the results of a near text search. + * + * @param text Query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(String text, + Function> fn, + Function> generateFn) { + return nearText(Target.text(List.of(text)), fn, generateFn); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param text Query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(List text, + Function> fn, + Function> generateFn) { + return nearText(Target.text(text), fn, generateFn); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param searchTarget Target query concepts. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(Target searchTarget, + Function> generateFn) { + return nearText(NearText.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param searchTarget Target query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(Target searchTarget, + Function> fn, + Function> generateFn) { + return nearText(NearText.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param query Near text query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearText(NearText query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param text Query concepts. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearText(String text, + Function> generateFn, + GroupBy groupBy) { + return nearText(Target.text(List.of(text)), generateFn, groupBy); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param text Query concepts. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearText(List text, + Function> generateFn, GroupBy groupBy) { + return nearText(Target.text(text), generateFn, groupBy); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param text Query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearText(String text, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearText(Target.text(List.of(text)), fn, generateFn, groupBy); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param text Query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearText(List text, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearText(Target.text(text), fn, generateFn, groupBy); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param searchTarget Target query concepts. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearText(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearText(NearText.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param searchTarget Target query concepts. + * @param fn Lambda expression for optional parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearText(Target searchTarget, + Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearText(NearText.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near text search. + * + * @param query Near text query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearText(NearText query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearImage queries -------------------------------------------------------- + + /** + * Run a generative task on the results of a near image search. + * + * @param image Query image (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(String image, + Function> generateFn) { + return nearImage(NearImage.of(image), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param image Query image (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(String image, Function> fn, + Function> generateFn) { + return nearImage(NearImage.of(image, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param searchTarget Query target (base64-encoded image). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(Target searchTarget, + Function> generateFn) { + return nearImage(NearImage.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param searchTarget Query target (base64-encoded image). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(Target searchTarget, Function> fn, + Function> generateFn) { + return nearImage(NearImage.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param query Near image query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImage(NearImage query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param image Query image (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearImage(String image, + Function> generateFn, GroupBy groupBy) { + return nearImage(NearImage.of(image), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param searchTarget Query target (base64-encoded image). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImage(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearImage(NearImage.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param searchTarget Query target (base64-encoded image). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearImage(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearImage(NearImage.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param image Query image (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImage(String image, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearImage(NearImage.of(image, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near image search. + * + * @param query Near image query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImage(NearImage query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearAudio queries -------------------------------------------------------- + + /** + * Run a generative task on the results of a near audio search. + * + * @param audio Query audio (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(String audio, + Function> generateFn) { + return nearAudio(NearAudio.of(audio), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param audio Query audio (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(String audio, Function> fn, + Function> generateFn) { + return nearAudio(NearAudio.of(audio, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param searchTarget Query target (base64-encoded audio). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(Target searchTarget, + Function> generateFn) { + return nearAudio(NearAudio.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param searchTarget Query target (base64-encoded audio). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(Target searchTarget, Function> fn, + Function> generateFn) { + return nearAudio(NearAudio.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param query Near audio query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearAudio(NearAudio query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param audio Query audio (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearAudio(String audio, + Function> generateFn, GroupBy groupBy) { + return nearAudio(NearAudio.of(audio), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param searchTarget Query target (base64-encoded audio). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearAudio(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearAudio(NearAudio.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param searchTarget Query target (base64-encoded audio). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearAudio(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearAudio(NearAudio.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param audio Query audio (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearAudio(String audio, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearAudio(NearAudio.of(audio, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near audio search. + * + * @param query Near audio query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearAudio(NearAudio query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearVideo queries -------------------------------------------------------- + + /** + * Run a generative task on the results of a near video search. + * + * @param video Query video (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(String video, + Function> generateFn) { + return nearVideo(NearVideo.of(video), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param video Query video (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(String video, Function> fn, + Function> generateFn) { + return nearVideo(NearVideo.of(video, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param searchTarget Query target (base64-encoded video). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(Target searchTarget, + Function> generateFn) { + return nearVideo(NearVideo.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param searchTarget Query target (base64-encoded video). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(Target searchTarget, Function> fn, + Function> generateFn) { + return nearVideo(NearVideo.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param query Near video query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearVideo(NearVideo query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param video Query video (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearVideo(String video, + Function> generateFn, GroupBy groupBy) { + return nearVideo(NearVideo.of(video), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param searchTarget Query target (base64-encoded video). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVideo(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearVideo(NearVideo.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param searchTarget Query target (base64-encoded video). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearVideo(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearVideo(NearVideo.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param video Query video (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVideo(String video, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearVideo(NearVideo.of(video, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near video search. + * + * @param query Near video query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearVideo(NearVideo query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearThermal queries -------------------------------------------------------- + + /** + * Run a generative task on the results of a near thermal search. + * + * @param thermal Query thermal (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(String thermal, + Function> generateFn) { + return nearThermal(NearThermal.of(thermal), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param thermal Query thermal (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(String thermal, Function> fn, + Function> generateFn) { + return nearThermal(NearThermal.of(thermal, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param searchTarget Query target (base64-encoded thermal). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(Target searchTarget, + Function> generateFn) { + return nearThermal(NearThermal.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param searchTarget Query target (base64-encoded thermal). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(Target searchTarget, Function> fn, + Function> generateFn) { + return nearThermal(NearThermal.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param query Near thermal query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearThermal(NearThermal query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param thermal Query thermal (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearThermal(String thermal, + Function> generateFn, GroupBy groupBy) { + return nearThermal(NearThermal.of(thermal), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param searchTarget Query target (base64-encoded thermal). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearThermal(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearThermal(NearThermal.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param searchTarget Query target (base64-encoded thermal). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearThermal(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearThermal(NearThermal.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param thermal Query thermal (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearThermal(String thermal, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearThermal(NearThermal.of(thermal, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near thermal search. + * + * @param query Near thermal query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearThermal(NearThermal query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearDepth queries -------------------------------------------------------- + + /** + * Run a generative task on the results of a near depth search. + * + * @param depth Query depth (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(String depth, + Function> generateFn) { + return nearDepth(NearDepth.of(depth), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param depth Query depth (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(String depth, Function> fn, + Function> generateFn) { + return nearDepth(NearDepth.of(depth, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param searchTarget Query target (base64-encoded depth). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(Target searchTarget, + Function> generateFn) { + return nearDepth(NearDepth.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param searchTarget Query target (base64-encoded depth). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(Target searchTarget, Function> fn, + Function> generateFn) { + return nearDepth(NearDepth.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param query Near depth query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearDepth(NearDepth query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param depth Query depth (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearDepth(String depth, + Function> generateFn, GroupBy groupBy) { + return nearDepth(NearDepth.of(depth), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param searchTarget Query target (base64-encoded depth). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearDepth(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearDepth(NearDepth.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param searchTarget Query target (base64-encoded depth). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearDepth(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearDepth(NearDepth.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param depth Query depth (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearDepth(String depth, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearDepth(NearDepth.of(depth, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near depth search. + * + * @param query Near depth query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearDepth(NearDepth query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } + + // NearImu queries -------------------------------------------------------- + + /** + * Run a generative task on the results of a near IMU search. + * + * @param imu Query IMU (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(String imu, + Function> generateFn) { + return nearImu(NearImu.of(imu), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param imu Query IMU (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(String imu, Function> fn, + Function> generateFn) { + return nearImu(NearImu.of(imu, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param searchTarget Query target (base64-encoded IMU). + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(Target searchTarget, + Function> generateFn) { + return nearImu(NearImu.of(searchTarget), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param searchTarget Query target (base64-encoded IMU). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(Target searchTarget, Function> fn, + Function> generateFn) { + return nearImu(NearImu.of(searchTarget, fn), GenerativeTask.of(generateFn)); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param query Near IMU query request. + * @param generate Generative task. + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public ResponseT nearImu(NearImu query, GenerativeTask generate) { + return performRequest(query, generate); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param imu Query IMU (base64-encoded). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearImu(String imu, + Function> generateFn, GroupBy groupBy) { + return nearImu(NearImu.of(imu), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param searchTarget Query target (base64-encoded IMU). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImu(Target searchTarget, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearImu(NearImu.of(searchTarget, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param searchTarget Query target (base64-encoded IMU). + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * + * @see GroupBy + * @see GenerativeResponseGrouped + * @throws WeaviateApiException in case the server returned with an + * error status code. + */ + public GroupedResponseT nearImu(Target searchTarget, + Function> generateFn, GroupBy groupBy) { + return nearImu(NearImu.of(searchTarget), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param imu Query IMU (base64-encoded). + * @param fn Lambda expression for optional search parameters. + * @param generateFn Lambda expression for generative task parameters. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImu(String imu, Function> fn, + Function> generateFn, + GroupBy groupBy) { + return nearImu(NearImu.of(imu, fn), GenerativeTask.of(generateFn), groupBy); + } + + /** + * Run a generative task on the results of a near IMU search. + * + * @param query Near IMU query request. + * @param generate Generative task. + * @param groupBy Group-by clause. + * @return Grouped query result. + * @throws WeaviateApiException in case the server returned with an + * error status code. + * + * @see GroupBy + * @see GenerativeResponseGrouped + */ + public GroupedResponseT nearImu(NearImu query, GenerativeTask generate, GroupBy groupBy) { + return performRequest(query, generate, groupBy); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java new file mode 100644 index 000000000..883ba5ed1 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/DynamicProvider.java @@ -0,0 +1,154 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.function.Function; + +import io.weaviate.client6.v1.api.collections.generative.AnthropicGenerative; +import io.weaviate.client6.v1.api.collections.generative.AnyscaleGenerative; +import io.weaviate.client6.v1.api.collections.generative.AwsGenerative; +import io.weaviate.client6.v1.api.collections.generative.AzureOpenAiGenerative; +import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; +import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; +import io.weaviate.client6.v1.api.collections.generative.FriendliaiGenerative; +import io.weaviate.client6.v1.api.collections.generative.GoogleGenerative; +import io.weaviate.client6.v1.api.collections.generative.MistralGenerative; +import io.weaviate.client6.v1.api.collections.generative.NvidiaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OllamaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OpenAiGenerative; +import io.weaviate.client6.v1.api.collections.generative.XaiGenerative; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public interface DynamicProvider { + void appendTo(WeaviateProtoGenerative.GenerativeProvider.Builder req); + + /** + * Configure {@code generative-anthropic} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider anthropic( + Function> fn) { + return AnthropicGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-anyscale} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider anyscale( + Function> fn) { + return AnyscaleGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-aws} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider aws( + Function> fn) { + return AwsGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-cohere} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider cohere( + Function> fn) { + return CohereGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-databricks} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider databricks( + Function> fn) { + return DatabricksGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-friendliai} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider friendliai( + Function> fn) { + return FriendliaiGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-palm} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider google( + Function> fn) { + return GoogleGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-mistral} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider mistral( + Function> fn) { + return MistralGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-nvidia} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider nvidia( + Function> fn) { + return NvidiaGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-ollama} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider ollama( + Function> fn) { + return OllamaGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-openai} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider openai( + Function> fn) { + return OpenAiGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-openai} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider azure( + Function> fn) { + return AzureOpenAiGenerative.Provider.of(fn); + } + + /** + * Configure {@code generative-xai} as a dynamic provider. + * + * @param fn Lambda expression for optional parameters. + */ + public static DynamicProvider xai( + Function> fn) { + return XaiGenerative.Provider.of(fn); + } + +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeDebug.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeDebug.java new file mode 100644 index 000000000..4b08cdc12 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeDebug.java @@ -0,0 +1,4 @@ +package io.weaviate.client6.v1.api.collections.generate; + +public record GenerativeDebug(String fullPrompt) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeObject.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeObject.java new file mode 100644 index 000000000..1767865e9 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeObject.java @@ -0,0 +1,12 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import io.weaviate.client6.v1.api.collections.query.QueryMetadata; + +public record GenerativeObject( + /** Object properties. */ + PropertiesT properties, + /** Object metadata. */ + QueryMetadata metadata, + /** Generative task output. */ + TaskOutput generated) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeRequest.java new file mode 100644 index 000000000..b09ce9f4e --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeRequest.java @@ -0,0 +1,43 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.query.GroupBy; +import io.weaviate.client6.v1.api.collections.query.QueryOperator; +import io.weaviate.client6.v1.api.collections.query.QueryRequest; +import io.weaviate.client6.v1.internal.grpc.Rpc; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record GenerativeRequest(QueryOperator operator, GenerativeTask generate, GroupBy groupBy) { + static Rpc, WeaviateProtoSearchGet.SearchReply> rpc( + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + return Rpc.of( + request -> { + var query = QueryRequest.marshal( + new QueryRequest(request.operator, request.groupBy), + collection, defaults); + var generative = WeaviateProtoGenerative.GenerativeSearch.newBuilder(); + request.generate.appendTo(generative); + var builder = query.toBuilder(); + builder.setGenerative(generative); + return builder.build(); + }, + reply -> GenerativeResponse.unmarshal(reply, collection), + () -> WeaviateBlockingStub::search, + () -> WeaviateFutureStub::search); + } + + static Rpc, WeaviateProtoSearchGet.SearchReply> grouped( + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var rpc = rpc(collection, defaults); + return Rpc.of( + request -> rpc.marshal(request), + reply -> GenerativeResponseGrouped.unmarshal(reply, collection, defaults), + () -> rpc.method(), () -> rpc.methodAsync()); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java new file mode 100644 index 000000000..82a75cefe --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponse.java @@ -0,0 +1,174 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.ArrayList; +import java.util.List; + +import io.weaviate.client6.v1.api.collections.generative.AnthropicGenerative; +import io.weaviate.client6.v1.api.collections.generative.AnyscaleGenerative; +import io.weaviate.client6.v1.api.collections.generative.AwsGenerative; +import io.weaviate.client6.v1.api.collections.generative.CohereGenerative; +import io.weaviate.client6.v1.api.collections.generative.DatabricksGenerative; +import io.weaviate.client6.v1.api.collections.generative.DummyGenerative; +import io.weaviate.client6.v1.api.collections.generative.FriendliaiGenerative; +import io.weaviate.client6.v1.api.collections.generative.GoogleGenerative; +import io.weaviate.client6.v1.api.collections.generative.MistralGenerative; +import io.weaviate.client6.v1.api.collections.generative.NvidiaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OllamaGenerative; +import io.weaviate.client6.v1.api.collections.generative.OpenAiGenerative; +import io.weaviate.client6.v1.api.collections.generative.ProviderMetadata; +import io.weaviate.client6.v1.api.collections.generative.XaiGenerative; +import io.weaviate.client6.v1.api.collections.query.QueryResponse; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record GenerativeResponse( + float took, + List> objects, + TaskOutput generated) { + static GenerativeResponse unmarshal( + WeaviateProtoSearchGet.SearchReply reply, + CollectionDescriptor collection) { + var objects = reply + .getResultsList() + .stream() + .map(result -> { + var object = QueryResponse.unmarshalResultObject( + result.getProperties(), result.getMetadata(), collection); + TaskOutput generative = null; + if (result.hasGenerative()) { + generative = GenerativeResponse.unmarshalTaskOutput(result.getGenerative()); + } + return new GenerativeObject<>( + object.properties(), + object.metadata(), + generative); + }) + .toList(); + + TaskOutput summary = null; + if (reply.hasGenerativeGroupedResults()) { + summary = GenerativeResponse.unmarshalTaskOutput(reply.getGenerativeGroupedResults()); + } + return new GenerativeResponse<>(reply.getTook(), objects, summary); + } + + static TaskOutput unmarshalTaskOutput(List values) { + if (values.isEmpty()) { + return null; + } + var generated = values.get(0); + + var metadata = generated.getMetadata(); + ProviderMetadata providerMetadata = null; + if (metadata.hasDummy()) { + providerMetadata = new DummyGenerative.Metadata(); + } else if (metadata.hasAws()) { + providerMetadata = new AwsGenerative.Metadata(); + } else if (metadata.hasAnthropic()) { + var anthropic = metadata.getAnthropic(); + var usage = anthropic.getUsage(); + providerMetadata = new AnthropicGenerative.Metadata(new AnthropicGenerative.Metadata.Usage( + usage.getInputTokens(), + usage.getOutputTokens())); + } else if (metadata.hasAnyscale()) { + providerMetadata = new AnyscaleGenerative.Metadata(); + } else if (metadata.hasCohere()) { + var cohere = metadata.getCohere(); + var apiVersion = cohere.getApiVersion(); + var billedUnits = cohere.getBilledUnits(); + var tokens = cohere.getTokens(); + providerMetadata = new CohereGenerative.Metadata( + new CohereGenerative.Metadata.ApiVersion( + apiVersion.hasVersion() ? apiVersion.getVersion() : null, + apiVersion.hasIsDeprecated() ? apiVersion.getIsDeprecated() : null, + apiVersion.hasIsExperimental() ? apiVersion.getIsExperimental() : null), + new CohereGenerative.Metadata.BilledUnits(billedUnits.getInputTokens(), + billedUnits.hasOutputTokens() ? billedUnits.getOutputTokens() : null, + billedUnits.hasSearchUnits() ? billedUnits.getSearchUnits() : null, + billedUnits.hasClassifications() ? billedUnits.getClassifications() : null), + new CohereGenerative.Metadata.Tokens( + tokens.hasInputTokens() ? tokens.getInputTokens() : null, + tokens.hasOutputTokens() ? tokens.getOutputTokens() : null), + new ArrayList<>(cohere.getWarnings().getValuesList())); + } else if (metadata.hasDatabricks()) { + var databricks = metadata.getDatabricks(); + var usage = databricks.getUsage(); + providerMetadata = new DatabricksGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasFriendliai()) { + var friendliai = metadata.getFriendliai(); + var usage = friendliai.getUsage(); + providerMetadata = new FriendliaiGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasMistral()) { + var mistral = metadata.getMistral(); + var usage = mistral.getUsage(); + providerMetadata = new MistralGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasNvidia()) { + var nvidia = metadata.getNvidia(); + var usage = nvidia.getUsage(); + providerMetadata = new NvidiaGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasOllama()) { + providerMetadata = new OllamaGenerative.Metadata(); + } else if (metadata.hasOpenai()) { + var openai = metadata.getOpenai(); + var usage = openai.getUsage(); + providerMetadata = new OpenAiGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } else if (metadata.hasGoogle()) { + var google = metadata.getGoogle(); + var tokens = google.getMetadata().getTokenMetadata(); + var usage = google.getUsageMetadata(); + providerMetadata = new GoogleGenerative.Metadata( + new GoogleGenerative.Metadata.TokenMetadata( + new GoogleGenerative.Metadata.TokenCount( + tokens.getInputTokenCount().hasTotalBillableCharacters() + ? tokens.getInputTokenCount().getTotalBillableCharacters() + : null, + tokens.getInputTokenCount().hasTotalTokens() + ? tokens.getInputTokenCount().getTotalTokens() + : null), + new GoogleGenerative.Metadata.TokenCount( + tokens.getOutputTokenCount().hasTotalBillableCharacters() + ? tokens.getOutputTokenCount().getTotalBillableCharacters() + : null, + tokens.getOutputTokenCount().hasTotalTokens() + ? tokens.getOutputTokenCount().getTotalTokens() + : null)), + new GoogleGenerative.Metadata.Usage( + usage.hasPromptTokenCount() ? usage.getPromptTokenCount() : null, + usage.hasCandidatesTokenCount() ? usage.getCandidatesTokenCount() : null, + usage.hasTotalTokenCount() ? usage.getTotalTokenCount() : null)); + } else if (metadata.hasXai()) { + var xai = metadata.getXai(); + var usage = xai.getUsage(); + providerMetadata = new XaiGenerative.Metadata(new ProviderMetadata.Usage( + usage.hasPromptTokens() ? usage.getPromptTokens() : null, + usage.hasCompletionTokens() ? usage.getCompletionTokens() : null, + usage.hasTotalTokens() ? usage.getTotalTokens() : null)); + } + + GenerativeDebug debug = null; + if (generated.getDebug() != null && generated.getDebug().getFullPrompt() != null) { + debug = new GenerativeDebug(generated.getDebug().getFullPrompt()); + } + return new TaskOutput(generated.getResult(), providerMetadata, debug); + } + + static TaskOutput unmarshalTaskOutput(WeaviateProtoGenerative.GenerativeResult result) { + return unmarshalTaskOutput(result.getValuesList()); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGroup.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGroup.java new file mode 100644 index 000000000..4a9620530 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGroup.java @@ -0,0 +1,26 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.List; + +import io.weaviate.client6.v1.api.collections.query.QueryObjectGrouped; + +public record GenerativeResponseGroup( + /** Group name. */ + String name, + /** + * The smallest distance value among all objects in the group, indicating the + * most similar object in that group to the query + */ + Float minDistance, + /** + * The largest distance value among all objects in the group, indicating the + * least similar object in that group to the query. + */ + Float maxDistance, + /** The size of the group. */ + long numberOfObjects, + /** Objects retrieved in the query. */ + List> objects, + /** Output of the summary task for this group. */ + TaskOutput generated) { +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGrouped.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGrouped.java new file mode 100644 index 000000000..673cddbde --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeResponseGrouped.java @@ -0,0 +1,77 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.query.QueryObjectGrouped; +import io.weaviate.client6.v1.api.collections.query.QueryResponse; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record GenerativeResponseGrouped( + /** Execution time of the request as measure by the server. */ + float took, + /** Objects returned by the associated query. */ + List> objects, + /** Grouped results with per-group generated output. */ + Map> groups, + /** Output of the summary group task. */ + TaskOutput generated) { + + static GenerativeResponseGrouped unmarshal( + WeaviateProtoSearchGet.SearchReply reply, + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var allObjects = new ArrayList>(); + var groups = reply.getGroupByResultsList().stream() + .map(group -> { + var groupName = group.getName(); + List> objects = group.getObjectsList().stream() + .map(object -> QueryResponse.unmarshalResultObject( + object.getProperties(), + object.getMetadata(), + collection)) + .map(object -> new QueryObjectGrouped<>( + object.properties(), + object.metadata(), + groupName)) + .toList(); + + allObjects.addAll(objects); + + TaskOutput generative = null; + if (group.hasGenerativeResult()) { + generative = GenerativeResponse.unmarshalTaskOutput(group.getGenerativeResult()); + } else if (group.hasGenerative()) { + // As of today the server continues to use the deprecated field in response. + generative = GenerativeResponse.unmarshalTaskOutput(List.of(group.getGenerative())); + } + + return new GenerativeResponseGroup<>( + groupName, + group.getMinDistance(), + group.getMaxDistance(), + group.getNumberOfObjects(), + objects, + generative); + }) + // Collectors.toMap() throws an NPE if either key or value in the map are null. + // In this specific case it is safe to use it, as the function in the map above + // always returns a QueryResponseGroup. + // The name of the group should not be null either, that's something we assume + // about the server's response. + .collect(Collectors.toMap(GenerativeResponseGroup::name, Function.identity())); + + TaskOutput summary = null; + if (reply.hasGenerativeGroupedResults()) { + summary = GenerativeResponse.unmarshalTaskOutput(reply.getGenerativeGroupedResults()); + } else if (reply.hasGenerativeGroupedResult()) { + summary = new TaskOutput(reply.getGenerativeGroupedResult(), null, null); + } + return new GenerativeResponseGrouped<>(reply.getTook(), allObjects, groups, summary); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java new file mode 100644 index 000000000..013697b66 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/GenerativeTask.java @@ -0,0 +1,190 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record GenerativeTask(Single single, Grouped grouped) { + public static GenerativeTask of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public GenerativeTask(Builder builder) { + this(builder.single, builder.grouped); + } + + public static class Builder implements ObjectBuilder { + private Single single; + private Grouped grouped; + + public Builder singlePrompt(String prompt) { + this.single = Single.of(prompt); + return this; + } + + public Builder singlePrompt(String prompt, Function> fn) { + this.single = Single.of(prompt, fn); + return this; + + } + + public Builder groupedTask(String prompt) { + this.grouped = Grouped.of(prompt); + return this; + } + + public Builder groupedTask(String prompt, Function> fn) { + this.grouped = Grouped.of(prompt, fn); + return this; + } + + @Override + public GenerativeTask build() { + return new GenerativeTask(this); + } + } + + void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + if (single != null) { + single.appendTo(req); + } + if (grouped != null) { + grouped.appendTo(req); + } + } + + public record Single(String prompt, boolean debug, List providers) { + public static Single of(String prompt) { + return of(prompt, ObjectBuilder.identity()); + } + + public static Single of(String prompt, Function> fn) { + return fn.apply(new Builder(prompt)).build(); + } + + public Single(Builder builder) { + this(builder.prompt, builder.debug, builder.providers); + } + + public static class Builder implements ObjectBuilder { + private final String prompt; + private final List providers = new ArrayList<>(); + private boolean debug = false; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder debug(boolean enable) { + this.debug = enable; + return this; + } + + public Builder dynamicProvider(DynamicProvider provider) { + providers.clear(); // Protobuf allows `repeated` but the server expects there to be 1. + providers.add(provider); + return this; + } + + @Override + public Single build() { + return new Single(this); + } + } + + public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + var ragProviders = providers.stream() + .map(provider -> { + var proto = WeaviateProtoGenerative.GenerativeProvider.newBuilder(); + provider.appendTo(proto); + return proto.build(); + }) + .toList(); + + req.setSingle( + WeaviateProtoGenerative.GenerativeSearch.Single.newBuilder() + .setPrompt(prompt) + .setDebug(debug) + .addAllQueries(ragProviders)); + } + } + + public record Grouped(String prompt, boolean debug, List properties, List providers) { + public static Grouped of(String prompt) { + return of(prompt, ObjectBuilder.identity()); + } + + public static Grouped of(String prompt, Function> fn) { + return fn.apply(new Builder(prompt)).build(); + } + + public Grouped(Builder builder) { + this(builder.prompt, builder.debug, builder.properties, builder.providers); + } + + public static class Builder implements ObjectBuilder { + private final String prompt; + private final List providers = new ArrayList<>(); + private final List properties = new ArrayList<>(); + private boolean debug = false; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder properties(String... properties) { + return properties(Arrays.asList(properties)); + } + + public Builder properties(List properties) { + this.properties.addAll(properties); + return this; + } + + public Builder dynamicProvider(DynamicProvider provider) { + providers.clear(); // Protobuf allows `repeated` but the server expects there to be 1. + providers.add(provider); + return this; + } + + public Builder debug(boolean enable) { + this.debug = enable; + return this; + } + + @Override + public Grouped build() { + return new Grouped(this); + } + } + + public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + var grouped = WeaviateProtoGenerative.GenerativeSearch.Grouped.newBuilder() + .setTask(prompt) + .setDebug(debug); + + if (properties != null && !properties.isEmpty()) { + grouped.setProperties( + WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(properties)); + + } + + var ragProviders = providers.stream() + .map(provider -> { + var proto = WeaviateProtoGenerative.GenerativeProvider.newBuilder(); + provider.appendTo(proto); + return proto.build(); + }) + .toList(); + grouped.addAllQueries(ragProviders); + + req.setGrouped(grouped); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java new file mode 100644 index 000000000..379e7ebc7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/TaskOutput.java @@ -0,0 +1,10 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import io.weaviate.client6.v1.api.collections.generative.ProviderMetadata; + +public record TaskOutput( + String text, + ProviderMetadata metadata, + GenerativeDebug debug) { + +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java new file mode 100644 index 000000000..391007b5a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClient.java @@ -0,0 +1,37 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.query.GroupBy; +import io.weaviate.client6.v1.api.collections.query.QueryOperator; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public class WeaviateGenerateClient + extends + AbstractGenerateClient, GenerativeResponseGrouped> { + + public WeaviateGenerateClient( + CollectionDescriptor collection, + GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { + super(collection, grpcTransport, defaults); + } + + /** Copy constructor that sets new defaults. */ + public WeaviateGenerateClient(WeaviateGenerateClient c, CollectionHandleDefaults defaults) { + super(c, defaults); + } + + @Override + protected final GenerativeResponse performRequest(QueryOperator operator, GenerativeTask generate) { + var request = new GenerativeRequest(operator, generate, null); + return this.grpcTransport.performRequest(request, GenerativeRequest.rpc(collection, defaults)); + } + + @Override + protected final GenerativeResponseGrouped performRequest(QueryOperator operator, GenerativeTask generate, + GroupBy groupBy) { + var request = new GenerativeRequest(operator, generate, groupBy); + return this.grpcTransport.performRequest(request, GenerativeRequest.grouped(collection, defaults)); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClientAsync.java b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClientAsync.java new file mode 100644 index 000000000..eff3866a7 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generate/WeaviateGenerateClientAsync.java @@ -0,0 +1,41 @@ +package io.weaviate.client6.v1.api.collections.generate; + +import java.util.concurrent.CompletableFuture; + +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.api.collections.query.GroupBy; +import io.weaviate.client6.v1.api.collections.query.QueryOperator; +import io.weaviate.client6.v1.internal.grpc.GrpcTransport; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public class WeaviateGenerateClientAsync + extends + AbstractGenerateClient>, CompletableFuture>> { + + public WeaviateGenerateClientAsync( + CollectionDescriptor collection, + GrpcTransport grpcTransport, + CollectionHandleDefaults defaults) { + super(collection, grpcTransport, defaults); + } + + /** Copy constructor that sets new defaults. */ + public WeaviateGenerateClientAsync(WeaviateGenerateClientAsync c, CollectionHandleDefaults defaults) { + super(c, defaults); + } + + @Override + protected final CompletableFuture> performRequest(QueryOperator operator, + GenerativeTask generate) { + var request = new GenerativeRequest(operator, generate, null); + return this.grpcTransport.performRequestAsync(request, GenerativeRequest.rpc(collection, defaults)); + } + + @Override + protected final CompletableFuture> performRequest(QueryOperator operator, + GenerativeTask generate, + GroupBy groupBy) { + var request = new GenerativeRequest(operator, generate, groupBy); + return this.grpcTransport.performRequestAsync(request, GenerativeRequest.grouped(collection, defaults)); + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java new file mode 100644 index 000000000..bb836d49f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnthropicGenerative.java @@ -0,0 +1,275 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record AnthropicGenerative( + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature, + @SerializedName("topK") Integer topK, + @SerializedName("topP") Float topP, + @SerializedName("stopSequences") List stopSequences) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.ANTHROPIC; + } + + @Override + public Object _self() { + return this; + } + + public static AnthropicGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static AnthropicGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public AnthropicGenerative(Builder builder) { + this( + builder.model, + builder.maxTokens, + builder.temperature, + builder.topK, + builder.topP, + builder.stopSequences); + } + + public static class Builder implements ObjectBuilder { + private Integer topK; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private final List stopSequences = new ArrayList<>(); + + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AnthropicGenerative build() { + return new AnthropicGenerative(this); + } + } + + public static record Metadata(Usage usage) implements ProviderMetadata { + public static record Usage(Long inputTokens, Long outputTokens) { + } + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer topK, + Float topP, + List stopSequences, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeAnthropic.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topK != null) { + provider.setTopK(topK); + } + if (topP != null) { + provider.setTopP(topP); + } + + if (stopSequences != null) { + provider.setStopSequences(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + if (images != null) { + provider.setImages(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(images)); + } + if (imageProperties != null) { + provider.setImageProperties(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(imageProperties)); + } + req.setAnthropic(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topK, + builder.topP, + builder.stopSequences, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer topK; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private final List stopSequences = new ArrayList<>(); + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AnthropicGenerative.Provider build() { + return new AnthropicGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java new file mode 100644 index 000000000..a2279e0a2 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AnyscaleGenerative.java @@ -0,0 +1,140 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record AnyscaleGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.ANYSCALE; + } + + @Override + public Object _self() { + return this; + } + + public static AnyscaleGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static AnyscaleGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public AnyscaleGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Float temperature; + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AnyscaleGenerative build() { + return new AnyscaleGenerative(this); + } + } + + public static record Metadata() implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + String model, + Float temperature) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeAnyscale.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + req.setAnyscale(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AnyscaleGenerative.Provider build() { + return new AnyscaleGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java new file mode 100644 index 000000000..1589b15db --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AwsGenerative.java @@ -0,0 +1,222 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record AwsGenerative( + @SerializedName("region") String region, + @SerializedName("service") String service, + @SerializedName("endpoint") String baseUrl, + @SerializedName("model") String model) implements Generative { + + @Override + public Generative.Kind _kind() { + return Generative.Kind.AWS; + } + + @Override + public Object _self() { + return this; + } + + public static AwsGenerative of(String region, String service) { + return of(region, service, ObjectBuilder.identity()); + } + + public static AwsGenerative of(String region, String service, Function> fn) { + return fn.apply(new Builder(region, service)).build(); + } + + public AwsGenerative(Builder builder) { + this( + builder.service, + builder.region, + builder.baseUrl, + builder.model); + } + + public static class Builder implements ObjectBuilder { + private final String region; + private final String service; + + public Builder(String service, String region) { + this.service = service; + this.region = region; + } + + private String baseUrl; + private String model; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + @Override + public AwsGenerative build() { + return new AwsGenerative(this); + } + } + + public static record Metadata() implements ProviderMetadata { + } + + public static record Provider( + String region, + String service, + String baseUrl, + String model, + String targetModel, + String targetModelVariant, + Float temperature, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeAWS.newBuilder(); + if (region != null) { + provider.setRegion(region); + } + if (service != null) { + provider.setService(service); + } + if (baseUrl != null) { + provider.setEndpoint(baseUrl); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (targetModel != null) { + provider.setTargetModel(targetModel); + } + if (targetModelVariant != null) { + provider.setTargetVariant(targetModelVariant); + } + if (images != null) { + provider.setImages(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(images)); + } + if (imageProperties != null) { + provider.setImageProperties(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(imageProperties)); + } + req.setAws(provider); + } + + public Provider(Builder builder) { + this( + builder.region, + builder.service, + builder.baseUrl, + builder.model, + builder.targetModel, + builder.targetModelVariant, + builder.temperature, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String region; + private String service; + private String baseUrl; + private String model; + private String targetModel; + private String targetModelVariant; + private Float temperature; + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + public Builder region(String region) { + this.region = region; + return this; + } + + public Builder service(String service) { + this.service = service; + return this; + } + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder targetModel(String targetModel) { + this.targetModel = targetModel; + return this; + } + + public Builder targetModelVariant(String targetModelVariant) { + this.targetModelVariant = targetModelVariant; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AwsGenerative.Provider build() { + return new AwsGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java new file mode 100644 index 000000000..94d0a3c0a --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/AzureOpenAiGenerative.java @@ -0,0 +1,320 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record AzureOpenAiGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("frequencyPenaltyProperty") Float frequencyPenalty, + @SerializedName("presencePenaltyProperty") Float presencePenalty, + @SerializedName("maxTokensProperty") Integer maxTokens, + @SerializedName("temperatureProperty") Float temperature, + @SerializedName("topPProperty") Float topP, + + @SerializedName("resourceName") String resourceName, + @SerializedName("deploymentId") String deploymentId) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.AZURE_OPENAI; + } + + @Override + public Object _self() { + return this; + } + + public static AzureOpenAiGenerative of(String resourceName, String deploymentId) { + return of(resourceName, deploymentId, ObjectBuilder.identity()); + } + + public static AzureOpenAiGenerative of(String resourceName, String deploymentId, + Function> fn) { + return fn.apply(new Builder(resourceName, deploymentId)).build(); + } + + public AzureOpenAiGenerative(Builder builder) { + this( + builder.baseUrl, + builder.frequencyPenalty, + builder.presencePenalty, + builder.maxTokens, + builder.temperature, + builder.topP, + builder.resourceName, + builder.deploymentId); + } + + public static class Builder implements ObjectBuilder { + private final String resourceName; + private final String deploymentId; + + private String baseUrl; + private Float frequencyPenalty; + private Float presencePenalty; + private Integer maxTokens; + private Float temperature; + private Float topP; + + public Builder(String resourceName, String deploymentId) { + this.resourceName = resourceName; + this.deploymentId = deploymentId; + } + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + @Override + public AzureOpenAiGenerative build() { + return new AzureOpenAiGenerative(this); + } + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer n, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + String apiVersion, + String resourceName, + String deploymentId, + List stopSequences, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeOpenAI.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (n != null) { + provider.setN(n); + } + if (topP != null) { + provider.setTopP(topP); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (apiVersion != null) { + provider.setApiVersion(apiVersion); + } + if (resourceName != null) { + provider.setResourceName(resourceName); + } + if (deploymentId != null) { + provider.setDeploymentId(deploymentId); + } + if (stopSequences != null) { + provider.setStop(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + provider.setIsAzure(true); + req.setOpenai(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.n, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.apiVersion, + builder.resourceName, + builder.deploymentId, + builder.stopSequences, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer n; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private String apiVersion; + private String resourceName; + private String deploymentId; + private final List stopSequences = new ArrayList<>(); + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder n(int n) { + this.n = n; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + + public Builder resourceName(String resourceName) { + this.resourceName = resourceName; + return this; + } + + public Builder deploymentId(String deploymentId) { + this.deploymentId = deploymentId; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public AzureOpenAiGenerative.Provider build() { + return new AzureOpenAiGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java index b95ffc601..9a3d8860e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/CohereGenerative.java @@ -8,15 +8,19 @@ import com.google.gson.annotations.SerializedName; import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; public record CohereGenerative( - @SerializedName("kProperty") String kProperty, + @SerializedName("baseURL") String baseUrl, + @SerializedName("kProperty") Integer topK, @SerializedName("model") String model, - @SerializedName("maxTokensProperty") Integer maxTokensProperty, + @SerializedName("maxTokensProperty") Integer maxTokens, + @SerializedName("temperatureProperty") Float temperature, @SerializedName("returnLikelihoodsProperty") String returnLikelihoodsProperty, - @SerializedName("stopSequencesProperty") List stopSequencesProperty, - @SerializedName("temperatureProperty") String temperatureProperty) implements Generative { + @SerializedName("stopSequencesProperty") List stopSequences) implements Generative { @Override public Kind _kind() { @@ -38,34 +42,45 @@ public static CohereGenerative of(Function { - private String kProperty; + private String baseUrl; + private Integer topK; private String model; - private Integer maxTokensProperty; + private Integer maxTokens; + private Float temperature; private String returnLikelihoodsProperty; - private List stopSequencesProperty = new ArrayList<>(); - private String temperatureProperty; + private List stopSequences = new ArrayList<>(); - public Builder kProperty(String kProperty) { - this.kProperty = kProperty; + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; return this; } + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Select generative model. */ public Builder model(String model) { this.model = model; return this; } - public Builder maxTokensProperty(int maxTokensProperty) { - this.maxTokensProperty = maxTokensProperty; + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; return this; } @@ -74,17 +89,27 @@ public Builder returnLikelihoodsProperty(String returnLikelihoodsProperty) { return this; } - public Builder stopSequencesProperty(String... stopSequencesProperty) { - return stopSequencesProperty(Arrays.asList(stopSequencesProperty)); + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); } - public Builder stopSequencesProperty(List stopSequencesProperty) { - this.stopSequencesProperty = stopSequencesProperty; + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; return this; } - public Builder temperatureProperty(String temperatureProperty) { - this.temperatureProperty = temperatureProperty; + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; return this; } @@ -93,4 +118,167 @@ public CohereGenerative build() { return new CohereGenerative(this); } } + + public static record Metadata(ApiVersion apiVersion, BilledUnits billedUnits, Tokens tokens, List warnings) + implements ProviderMetadata { + + public static record ApiVersion(String version, Boolean deprecated, Boolean experimental) { + } + + public static record BilledUnits(Double inputTokens, Double outputTokens, Double searchUnits, + Double classifications) { + } + + public static record Tokens(Double inputTokens, Double outputTokens) { + } + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer topK, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + List stopSequences) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeCohere.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topK != null) { + provider.setK(topK); + } + if (topP != null) { + provider.setP(topP); + } + + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + + if (stopSequences != null) { + provider.setStopSequences(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + req.setCohere(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topK, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.stopSequences); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer topK; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private final List stopSequences = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public CohereGenerative.Provider build() { + return new CohereGenerative.Provider(this); + } + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java new file mode 100644 index 000000000..df2b44f14 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DatabricksGenerative.java @@ -0,0 +1,265 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record DatabricksGenerative( + @SerializedName("endpoint") String baseUrl, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("topK") Integer topK, + @SerializedName("topP") Float topP, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.DATABRICKS; + } + + @Override + public Object _self() { + return this; + } + + public static DatabricksGenerative of(String baseURL) { + return of(baseURL, ObjectBuilder.identity()); + } + + public static DatabricksGenerative of(String baseURL, Function> fn) { + return fn.apply(new Builder(baseURL)).build(); + } + + public DatabricksGenerative(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.topK, + builder.topP, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private final String baseUrl; + + private Integer maxTokens; + private Integer topK; + private Float topP; + private Float temperature; + + public Builder(String baseUrl) { + this.baseUrl = baseUrl; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public DatabricksGenerative build() { + return new DatabricksGenerative(this); + } + } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer n, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + Boolean logProbs, + Integer topLogProbs, + List stopSequences) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeDatabricks.newBuilder(); + if (baseUrl != null) { + provider.setEndpoint(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (n != null) { + provider.setN(n); + } + if (topP != null) { + provider.setTopP(topP); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (logProbs != null) { + provider.setLogProbs(logProbs); + } + if (topLogProbs != null) { + provider.setTopLogProbs(topLogProbs); + } + if (stopSequences != null) { + provider.setStop(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + req.setDatabricks(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.n, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.logProbs, + builder.topLogProbs, + builder.stopSequences); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer n; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private Boolean logProbs; + private Integer topLogProbs; + private final List stopSequences = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder n(int n) { + this.n = n; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder logProbs(boolean logProbs) { + this.logProbs = logProbs; + return this; + } + + public Builder topLogProbs(int topLogProbs) { + this.topLogProbs = topLogProbs; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public DatabricksGenerative.Provider build() { + return new DatabricksGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java new file mode 100644 index 000000000..8f45163fe --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/DummyGenerative.java @@ -0,0 +1,18 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import io.weaviate.client6.v1.api.collections.Generative; + +public record DummyGenerative() implements Generative { + @Override + public Kind _kind() { + return Generative.Kind.DUMMY; + } + + @Override + public Object _self() { + return this; + } + + public static record Metadata() implements ProviderMetadata { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java new file mode 100644 index 000000000..d154dc8b3 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/FriendliaiGenerative.java @@ -0,0 +1,186 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record FriendliaiGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.FRIENDLIAI; + } + + @Override + public Object _self() { + return this; + } + + public static FriendliaiGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static FriendliaiGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public FriendliaiGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public FriendliaiGenerative build() { + return new FriendliaiGenerative(this); + } + } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer n, + Float topP) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeFriendliAI.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (n != null) { + provider.setN(n); + } + if (topP != null) { + provider.setTopP(topP); + } + req.setFriendliai(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.n, + builder.topP); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer n; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder n(int n) { + this.n = n; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public FriendliaiGenerative.Provider build() { + return new FriendliaiGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java new file mode 100644 index 000000000..084bab0b2 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/GoogleGenerative.java @@ -0,0 +1,327 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record GoogleGenerative( + @SerializedName("apiEndpoint") String baseUrl, + @SerializedName("modelId") String model, + @SerializedName("projectId") String projectId, + @SerializedName("maxOutputTokens") Integer maxTokens, + @SerializedName("topK") Integer topK, + @SerializedName("topP") Float topP, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.GOOGLE; + } + + @Override + public Object _self() { + return this; + } + + public static GoogleGenerative of(String projectId) { + return of(projectId, ObjectBuilder.identity()); + } + + public static GoogleGenerative of(String projectId, Function> fn) { + return fn.apply(new Builder(projectId)).build(); + } + + public GoogleGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.projectId, + builder.maxTokens, + builder.topK, + builder.topP, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private final String projectId; + + private String baseUrl; + private String model; + private Integer maxTokens; + private Integer topK; + private Float topP; + private Float temperature; + + public Builder(String projectId) { + this.projectId = projectId; + } + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Top K value for sampling. */ + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public GoogleGenerative build() { + return new GoogleGenerative(this); + } + } + + public static record Metadata(TokenMetadata tokens, Usage usage) implements ProviderMetadata { + + public static record TokenCount(Long totalBillableCharacters, Long totalTokens) { + } + + public static record TokenMetadata(TokenCount inputTokens, TokenCount outputTokens) { + } + + public static record Usage(Long promptTokenCount, Long candidatesTokenCount, Long totalTokenCount) { + } + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer topK, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + String projectId, + String endpointId, + String region, + List stopSequences, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeGoogle.newBuilder(); + if (baseUrl != null) { + provider.setApiEndpoint(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topK != null) { + provider.setTopK(topK); + } + if (topP != null) { + provider.setTopP(topP); + } + if (projectId != null) { + provider.setProjectId(projectId); + } + if (endpointId != null) { + provider.setEndpointId(endpointId); + } + if (region != null) { + provider.setRegion(region); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (stopSequences != null) { + provider.setStopSequences(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + req.setGoogle(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topK, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.projectId, + builder.endpointId, + builder.region, + builder.stopSequences, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer topK; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private String projectId; + private String endpointId; + private String region; + private final List stopSequences = new ArrayList<>(); + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder topK(int topK) { + this.topK = topK; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder endpointId(String endpointId) { + this.endpointId = endpointId; + return this; + } + + public Builder region(String region) { + this.region = region; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public GoogleGenerative.Provider build() { + return new GoogleGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java new file mode 100644 index 000000000..3f64cd06e --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/MistralGenerative.java @@ -0,0 +1,175 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record MistralGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.MISTRAL; + } + + @Override + public Object _self() { + return this; + } + + public static MistralGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static MistralGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public MistralGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public MistralGenerative build() { + return new MistralGenerative(this); + } + } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Float topP) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeMistral.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topP != null) { + provider.setTopP(topP); + } + req.setMistral(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topP); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public MistralGenerative.Provider build() { + return new MistralGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java new file mode 100644 index 000000000..6bc156a16 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/NvidiaGenerative.java @@ -0,0 +1,175 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record NvidiaGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.NVIDIA; + } + + @Override + public Object _self() { + return this; + } + + public static NvidiaGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static NvidiaGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public NvidiaGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public NvidiaGenerative build() { + return new NvidiaGenerative(this); + } + } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Float topP) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeNvidia.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topP != null) { + provider.setTopP(topP); + } + req.setNvidia(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topP); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public NvidiaGenerative.Provider build() { + return new NvidiaGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java new file mode 100644 index 000000000..89a356b8d --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OllamaGenerative.java @@ -0,0 +1,166 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record OllamaGenerative( + @SerializedName("apiEndpoint") String baseUrl, + @SerializedName("model") String model) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.OLLAMA; + } + + @Override + public Object _self() { + return this; + } + + public static OllamaGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static OllamaGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public OllamaGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + + /** Base URL of the generative model. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + @Override + public OllamaGenerative build() { + return new OllamaGenerative(this); + } + } + + public static record Metadata() implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + String model, + Float temperature, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeOllama.newBuilder(); + if (baseUrl != null) { + provider.setApiEndpoint(baseUrl); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (images != null) { + provider.setImages(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(images)); + } + if (imageProperties != null) { + provider.setImageProperties(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(imageProperties)); + } + req.setOllama(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.temperature, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Float temperature; + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public OllamaGenerative.Provider build() { + return new OllamaGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java new file mode 100644 index 000000000..0417aacda --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/OpenAiGenerative.java @@ -0,0 +1,286 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record OpenAiGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("frequencyPenaltyProperty") Float frequencyPenalty, + @SerializedName("presencePenaltyProperty") Float presencePenalty, + @SerializedName("maxTokensProperty") Integer maxTokens, + @SerializedName("temperatureProperty") Float temperature, + @SerializedName("topPProperty") Float topP, + + @SerializedName("model") String model) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.OPENAI; + } + + @Override + public Object _self() { + return this; + } + + public static OpenAiGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static OpenAiGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public OpenAiGenerative(Builder builder) { + this( + builder.baseUrl, + builder.frequencyPenalty, + builder.presencePenalty, + builder.maxTokens, + builder.temperature, + builder.topP, + builder.model); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Float frequencyPenalty; + private Float presencePenalty; + private Integer maxTokens; + private Float temperature; + private Float topP; + private String model; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + @Override + public OpenAiGenerative build() { + return new OpenAiGenerative(this); + } + } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Integer n, + Float topP, + Float frequencyPenalty, + Float presencePenalty, + List stopSequences, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeOpenAI.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (n != null) { + provider.setN(n); + } + if (topP != null) { + provider.setTopP(topP); + } + if (frequencyPenalty != null) { + provider.setFrequencyPenalty(frequencyPenalty); + } + if (presencePenalty != null) { + provider.setPresencePenalty(presencePenalty); + } + if (stopSequences != null) { + provider.setStop(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(stopSequences)); + } + provider.setIsAzure(false); + req.setOpenai(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.n, + builder.topP, + builder.frequencyPenalty, + builder.presencePenalty, + builder.stopSequences, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Integer n; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private Float frequencyPenalty; + private Float presencePenalty; + private final List stopSequences = new ArrayList<>(); + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + public Builder n(int n) { + this.n = n; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + public Builder frequencyPenalty(float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder presencePenalty(float presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(String... stopSequences) { + return stopSequences(Arrays.asList(stopSequences)); + } + + /** + * Set tokens which should signal the model to stop generating further output. + */ + public Builder stopSequences(List stopSequences) { + this.stopSequences.addAll(stopSequences); + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public OpenAiGenerative.Provider build() { + return new OpenAiGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java new file mode 100644 index 000000000..884dd6dd3 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/ProviderMetadata.java @@ -0,0 +1,9 @@ +package io.weaviate.client6.v1.api.collections.generative; + +public interface ProviderMetadata { + record Usage( + Long promptTokens, + Long completionTokens, + Long totalTokens) { + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java new file mode 100644 index 000000000..d736c658c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/generative/XaiGenerative.java @@ -0,0 +1,211 @@ +package io.weaviate.client6.v1.api.collections.generative; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Generative; +import io.weaviate.client6.v1.api.collections.generate.DynamicProvider; +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record XaiGenerative( + @SerializedName("baseURL") String baseUrl, + @SerializedName("model") String model, + @SerializedName("maxTokens") Integer maxTokens, + @SerializedName("temperature") Float temperature) implements Generative { + + @Override + public Kind _kind() { + return Generative.Kind.XAI; + } + + @Override + public Object _self() { + return this; + } + + public static XaiGenerative of() { + return of(ObjectBuilder.identity()); + } + + public static XaiGenerative of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public XaiGenerative(Builder builder) { + this( + builder.baseUrl, + builder.model, + builder.maxTokens, + builder.temperature); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private String model; + private Integer maxTokens; + private Float temperature; + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public XaiGenerative build() { + return new XaiGenerative(this); + } + } + + public static record Metadata(ProviderMetadata.Usage usage) implements ProviderMetadata { + } + + public static record Provider( + String baseUrl, + Integer maxTokens, + String model, + Float temperature, + Float topP, + List images, + List imageProperties) implements DynamicProvider { + + public static Provider of( + Function> fn) { + return fn.apply(new Builder()).build(); + } + + @Override + public void appendTo( + io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative.GenerativeProvider.Builder req) { + var provider = WeaviateProtoGenerative.GenerativeXAI.newBuilder(); + if (baseUrl != null) { + provider.setBaseUrl(baseUrl); + } + if (maxTokens != null) { + provider.setMaxTokens(maxTokens); + } + if (model != null) { + provider.setModel(model); + } + if (temperature != null) { + provider.setTemperature(temperature); + } + if (topP != null) { + provider.setTopP(topP); + } + if (images != null) { + provider.setImages(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(images)); + } + if (imageProperties != null) { + provider.setImageProperties(WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(imageProperties)); + } + req.setXai(provider); + } + + public Provider(Builder builder) { + this( + builder.baseUrl, + builder.maxTokens, + builder.model, + builder.temperature, + builder.topP, + builder.images, + builder.imageProperties); + } + + public static class Builder implements ObjectBuilder { + private String baseUrl; + private Float topP; + private String model; + private Integer maxTokens; + private Float temperature; + private final List images = new ArrayList<>(); + private final List imageProperties = new ArrayList<>(); + + /** Base URL of the generative provider. */ + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + /** Top P value for nucleus sampling. */ + public Builder topP(float topP) { + this.topP = topP; + return this; + } + + /** Select generative model. */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** Limit the number of tokens to generate in the response. */ + public Builder maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder images(String... images) { + return images(Arrays.asList(images)); + } + + public Builder images(List images) { + this.images.addAll(images); + return this; + } + + public Builder imageProperties(String... imageProperties) { + return imageProperties(Arrays.asList(imageProperties)); + } + + public Builder imageProperties(List imageProperties) { + this.imageProperties.addAll(imageProperties); + return this; + } + + /** + * Control the randomness of the model's output. + * Higher values make output more random. + */ + public Builder temperature(float temperature) { + this.temperature = temperature; + return this; + } + + @Override + public XaiGenerative.Provider build() { + return new XaiGenerative.Provider(this); + } + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java index 8774ef508..887bc53fc 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java @@ -398,7 +398,7 @@ public ResponseT nearVector(float[] vector, Function> fn) { - return performRequest(NearVector.of(searchTarget, fn)); + return nearVector(NearVector.of(searchTarget, fn)); } /** @@ -473,7 +473,7 @@ public GroupedResponseT nearVector(float[] vector, Function> fn, GroupBy groupBy) { - return performRequest(NearVector.of(searchTarget, fn), groupBy); + return nearVector(NearVector.of(searchTarget, fn), groupBy); } /** @@ -734,7 +734,7 @@ public GroupedResponseT nearText(String text, public GroupedResponseT nearText(List text, Function> fn, GroupBy groupBy) { - return nearText(Target.text(text), groupBy); + return nearText(Target.text(text), fn, groupBy); } /** diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java index 036ce1165..e7f433dea 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/BaseQueryOptions.java @@ -3,12 +3,14 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.function.Function; import org.apache.commons.lang3.StringUtils; import io.weaviate.client6.v1.api.collections.query.Metadata.MetadataField; import io.weaviate.client6.v1.internal.ObjectBuilder; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; public record BaseQueryOptions( @@ -18,6 +20,7 @@ public record BaseQueryOptions( String after, ConsistencyLevel consistencyLevel, Where where, + GenerativeSearch generativeSearch, List returnProperties, List returnReferences, List returnMetadata) { @@ -30,6 +33,7 @@ private BaseQueryOptions(Builder, T> builder.after, builder.consistencyLevel, builder.where, + builder.generativeSearch, builder.returnProperties, builder.returnReferences, builder.returnMetadata); @@ -44,6 +48,7 @@ public static abstract class Builder, T extends Ob private String after; private ConsistencyLevel consistencyLevel; private Where where; + private GenerativeSearch generativeSearch; private List returnProperties = new ArrayList<>(); private List returnReferences = new ArrayList<>(); private List returnMetadata = new ArrayList<>(); @@ -102,6 +107,17 @@ public final SELF consistencyLevel(ConsistencyLevel consistencyLevel) { return (SELF) this; } + /** + * Add arguments for generative query. + * Builders which support this parameter should make the method public. + * + * @param fn Lambda expression for optional parameters. + */ + protected SELF generate(Function> fn) { + this.generativeSearch = GenerativeSearch.of(fn); + return (SELF) this; + } + /** * Filter result set using traditional filtering operators: {@code eq}, * {@code gte}, {@code like}, etc. @@ -194,6 +210,12 @@ final void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) { req.setFilters(filter); } + if (generativeSearch != null) { + var generative = WeaviateProtoGenerative.GenerativeSearch.newBuilder(); + generativeSearch.appendTo(generative); + req.setGenerative(generative); + } + var metadata = WeaviateProtoSearchGet.MetadataRequest.newBuilder(); returnMetadata.forEach(m -> m.appendTo(metadata)); req.setMetadata(metadata); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/GenerativeSearch.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/GenerativeSearch.java new file mode 100644 index 000000000..359dd282c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/GenerativeSearch.java @@ -0,0 +1,157 @@ +package io.weaviate.client6.v1.api.collections.query; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBase; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoGenerative; + +public record GenerativeSearch(Single single, Grouped grouped) { + public static GenerativeSearch of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public GenerativeSearch(Builder builder) { + this(builder.single, builder.grouped); + } + + public static class Builder implements ObjectBuilder { + private Single single; + private Grouped grouped; + + public Builder singlePrompt(String prompt) { + this.single = Single.of(prompt); + return this; + } + + public Builder singlePrompt(String prompt, Function> fn) { + this.single = Single.of(prompt, fn); + return this; + + } + + public Builder groupedTask(String prompt) { + this.grouped = Grouped.of(prompt); + return this; + } + + public Builder groupedTask(String prompt, Function> fn) { + this.grouped = Grouped.of(prompt, fn); + return this; + } + + @Override + public GenerativeSearch build() { + return new GenerativeSearch(this); + } + } + + void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + if (single != null) { + single.appendTo(req); + } + if (grouped != null) { + grouped.appendTo(req); + } + } + + public record Single(String prompt, boolean debug) { + public static Single of(String prompt) { + return of(prompt, ObjectBuilder.identity()); + } + + public static Single of(String prompt, Function> fn) { + return fn.apply(new Builder(prompt)).build(); + } + + public Single(Builder builder) { + this(builder.prompt, builder.debug); + } + + public static class Builder implements ObjectBuilder { + private final String prompt; + private boolean debug = false; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder debug(boolean enable) { + this.debug = enable; + return this; + } + + @Override + public Single build() { + return new Single(this); + } + } + + public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + req.setSingle( + WeaviateProtoGenerative.GenerativeSearch.Single.newBuilder() + .setPrompt(prompt) + .setDebug(debug)); + } + } + + public record Grouped(String prompt, boolean debug, List properties) { + public static Grouped of(String prompt) { + return of(prompt, ObjectBuilder.identity()); + } + + public static Grouped of(String prompt, Function> fn) { + return fn.apply(new Builder(prompt)).build(); + } + + public Grouped(Builder builder) { + this(builder.prompt, builder.debug, builder.properties); + } + + public static class Builder implements ObjectBuilder { + private final String prompt; + private final List properties = new ArrayList<>(); + private boolean debug = false; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder properties(String... properties) { + return properties(Arrays.asList(properties)); + } + + public Builder properties(List properties) { + this.properties.addAll(properties); + return this; + } + + public Builder debug(boolean enable) { + this.debug = enable; + return this; + } + + @Override + public Grouped build() { + return new Grouped(this); + } + } + + public void appendTo(WeaviateProtoGenerative.GenerativeSearch.Builder req) { + var grouped = WeaviateProtoGenerative.GenerativeSearch.Grouped.newBuilder() + .setTask(prompt) + .setDebug(debug); + + if (properties != null && !properties.isEmpty()) { + grouped.setProperties( + WeaviateProtoBase.TextArray.newBuilder() + .addAllValues(properties)); + + } + req.setGrouped(grouped); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java index f35f3e824..f98014af1 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryObjectGrouped.java @@ -2,15 +2,15 @@ import io.weaviate.client6.v1.api.collections.WeaviateObject; -public record QueryObjectGrouped( +public record QueryObjectGrouped( /** Object properties. */ - T properties, + PropertiesT properties, /** Object metadata. */ QueryMetadata metadata, /** Name of the group that the object belongs to. */ String belongsToGroup) { - QueryObjectGrouped(WeaviateObject object, + QueryObjectGrouped(WeaviateObject object, String belongsToGroup) { this(object.properties(), object.metadata(), belongsToGroup); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java index a3844b4eb..a2fe89d11 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryOperator.java @@ -2,7 +2,7 @@ import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; -interface QueryOperator { +public interface QueryOperator { /** Append QueryOperator to the request message. */ void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req); } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java index 11d31b4e5..fcc553f9e 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java @@ -1,277 +1,57 @@ package io.weaviate.client6.v1.api.collections.query; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.UUID; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; - import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; -import io.weaviate.client6.v1.api.collections.ObjectMetadata; -import io.weaviate.client6.v1.api.collections.Vectors; -import io.weaviate.client6.v1.api.collections.WeaviateObject; -import io.weaviate.client6.v1.internal.DateUtil; -import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; import io.weaviate.client6.v1.internal.grpc.Rpc; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub; -import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoProperties; import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; -import io.weaviate.client6.v1.internal.orm.PropertiesBuilder; public record QueryRequest(QueryOperator operator, GroupBy groupBy) { - static Rpc, WeaviateProtoSearchGet.SearchReply> rpc( - CollectionDescriptor collection, + static Rpc, WeaviateProtoSearchGet.SearchReply> rpc( + CollectionDescriptor collection, CollectionHandleDefaults defaults) { return Rpc.of( - request -> { - var message = WeaviateProtoSearchGet.SearchRequest.newBuilder(); - message.setUses127Api(true); - message.setUses125Api(true); - message.setUses123Api(true); - message.setCollection(collection.collectionName()); - request.operator.appendTo(message); - - if (defaults.tenant() != null) { - message.setTenant(defaults.tenant()); - } - if (defaults.consistencyLevel() != null) { - defaults.consistencyLevel().appendTo(message); - } - - if (request.groupBy != null) { - request.groupBy.appendTo(message); - } - return message.build(); - }, - reply -> { - var objects = reply - .getResultsList() - .stream() - .map(obj -> QueryRequest.unmarshalResultObject( - obj.getProperties(), obj.getMetadata(), collection)) - .toList(); - return new QueryResponse<>(objects); - }, + request -> QueryRequest.marshal(request, collection, defaults), + reply -> QueryResponse.unmarshal(reply, collection), () -> WeaviateBlockingStub::search, () -> WeaviateFutureStub::search); } - static Rpc, WeaviateProtoSearchGet.SearchReply> grouped( - CollectionDescriptor collection, + public static WeaviateProtoSearchGet.SearchRequest marshal( + QueryRequest request, + CollectionDescriptor collection, CollectionHandleDefaults defaults) { - var rpc = rpc(collection, defaults); - return Rpc.of( - request -> rpc.marshal(request), - reply -> { - var allObjects = new ArrayList>(); - var groups = reply.getGroupByResultsList() - .stream().map(group -> { - var name = group.getName(); - List> objects = group.getObjectsList().stream() - .map(obj -> QueryRequest.unmarshalResultObject( - obj.getProperties(), - obj.getMetadata(), - collection)) - .map(obj -> new QueryObjectGrouped<>(obj, name)) - .toList(); - - allObjects.addAll(objects); - return new QueryResponseGroup<>( - name, - group.getMinDistance(), - group.getMaxDistance(), - group.getNumberOfObjects(), - objects); - }) - // Collectors.toMap() throws an NPE if either key or value in the map are null. - // In this specific case it is safe to use it, as the function in the map above - // always returns a QueryResponseGroup. - // The name of the group should not be null either, that's something we assume - // about the server's response. - .collect(Collectors.toMap(QueryResponseGroup::name, Function.identity())); + var message = WeaviateProtoSearchGet.SearchRequest.newBuilder(); + message.setUses127Api(true); + message.setUses125Api(true); + message.setUses123Api(true); + message.setCollection(collection.collectionName()); + request.operator.appendTo(message); - return new QueryResponseGrouped(allObjects, groups); - }, () -> rpc.method(), () -> rpc.methodAsync()); - } - - private static WeaviateObject unmarshalResultObject( - WeaviateProtoSearchGet.PropertiesResult propertiesResult, - WeaviateProtoSearchGet.MetadataResult metadataResult, - CollectionDescriptor collection) { - var object = unmarshalWithReferences(propertiesResult, metadataResult, collection); - var metadata = new QueryMetadata.Builder() - .uuid(object.metadata().uuid()) - .vectors(object.metadata().vectors()); - - if (metadataResult.getCreationTimeUnixPresent()) { - metadata.creationTimeUnix(metadataResult.getCreationTimeUnix()); - } - if (metadataResult.getLastUpdateTimeUnixPresent()) { - metadata.lastUpdateTimeUnix(metadataResult.getLastUpdateTimeUnix()); - } - if (metadataResult.getDistancePresent()) { - metadata.distance(metadataResult.getDistance()); - } - if (metadataResult.getCertaintyPresent()) { - metadata.certainty(metadataResult.getCertainty()); + if (defaults.tenant() != null) { + message.setTenant(defaults.tenant()); } - if (metadataResult.getScorePresent()) { - metadata.score(metadataResult.getScore()); + if (defaults.consistencyLevel() != null) { + defaults.consistencyLevel().appendTo(message); } - if (metadataResult.getExplainScorePresent()) { - metadata.explainScore(metadataResult.getExplainScore()); - } - return new WeaviateObject<>(collection.collectionName(), object.properties(), object.references(), - metadata.build()); - } - - private static WeaviateObject unmarshalWithReferences( - WeaviateProtoSearchGet.PropertiesResult propertiesResult, - WeaviateProtoSearchGet.MetadataResult metadataResult, - CollectionDescriptor descriptor) { - var properties = descriptor.propertiesBuilder(); - propertiesResult.getNonRefProps().getFieldsMap().entrySet().stream() - .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); - - // In case a reference is multi-target, there will be a separate - // "reference property" for each of the targets, so instead of - // `collect` we need to `reduce` the map, merging related references - // as we go. - // I.e. { "ref": A-1 } , { "ref": B-1 } => { "ref": [A-1, B-1] } - var referenceProperties = propertiesResult.getRefPropsList() - .stream().reduce( - new HashMap>(), - (map, ref) -> { - var refObjects = ref.getPropertiesList().stream() - .map(property -> { - var reference = unmarshalWithReferences( - property, property.getMetadata(), - // TODO: this should be possible to configure for ODM? - CollectionDescriptor.ofMap(property.getTargetCollection())); - return (Object) new WeaviateObject<>( - reference.collection(), - (Object) reference.properties(), - reference.references(), - reference.metadata()); - }) - .toList(); - - // Merge ObjectReferences by joining the underlying WeaviateObjects. - map.merge( - ref.getPropName(), - refObjects, - (left, right) -> { - var joined = Stream.concat( - left.stream(), - right.stream()).toList(); - return joined; - }); - return map; - }, - (left, right) -> { - left.putAll(right); - return left; - }); - - ObjectMetadata metadata = null; - if (metadataResult != null) { - var metadataBuilder = new ObjectMetadata.Builder() - .uuid(metadataResult.getId()); - var vectors = new Vectors[metadataResult.getVectorsList().size()]; - var i = 0; - for (final var vector : metadataResult.getVectorsList()) { - var vectorName = vector.getName(); - var vbytes = vector.getVectorBytes(); - switch (vector.getType()) { - case VECTOR_TYPE_SINGLE_FP32: - vectors[i++] = Vectors.of(vectorName, ByteStringUtil.decodeVectorSingle(vbytes)); - break; - case VECTOR_TYPE_MULTI_FP32: - vectors[i++] = Vectors.of(vectorName, ByteStringUtil.decodeVectorMulti(vbytes)); - break; - default: - continue; - } - } - metadataBuilder.vectors(vectors); - metadata = metadataBuilder.build(); + if (request.groupBy != null) { + request.groupBy.appendTo(message); } - var obj = new WeaviateObject.Builder() - .collection(descriptor.collectionName()) - .properties(properties.build()) - .references(referenceProperties) - .metadata(metadata); - return obj.build(); + return message.build(); } - private static void setProperty(String property, WeaviateProtoProperties.Value value, - PropertiesBuilder builder, CollectionDescriptor descriptor) { - if (value.hasNullValue()) { - builder.setNull(property); - } else if (value.hasTextValue()) { - builder.setText(property, value.getTextValue()); - } else if (value.hasBoolValue()) { - builder.setBoolean(property, value.getBoolValue()); - } else if (value.hasIntValue()) { - builder.setLong(property, value.getIntValue()); - } else if (value.hasNumberValue()) { - builder.setDouble(property, value.getNumberValue()); - } else if (value.hasBlobValue()) { - builder.setBlob(property, value.getBlobValue()); - } else if (value.hasDateValue()) { - builder.setOffsetDateTime(property, DateUtil.fromISO8601(value.getDateValue())); - } else if (value.hasUuidValue()) { - builder.setUuid(property, UUID.fromString(value.getUuidValue())); - } else if (value.hasListValue()) { - var list = value.getListValue(); - if (list.hasTextValues()) { - builder.setTextArray(property, list.getTextValues().getValuesList()); - } else if (list.hasIntValues()) { - var ints = Arrays.stream( - ByteStringUtil.decodeIntValues(list.getIntValues().getValues())) - .boxed().toList(); - builder.setLongArray(property, ints); - } else if (list.hasNumberValues()) { - var numbers = Arrays.stream( - ByteStringUtil.decodeNumberValues(list.getNumberValues().getValues())) - .boxed().toList(); - builder.setDoubleArray(property, numbers); - } else if (list.hasUuidValues()) { - var uuids = list.getUuidValues().getValuesList().stream() - .map(UUID::fromString).toList(); - builder.setUuidArray(property, uuids); - } else if (list.hasBoolValues()) { - builder.setBooleanArray(property, list.getBoolValues().getValuesList()); - } else if (list.hasDateValues()) { - var dates = list.getDateValues().getValuesList().stream() - .map(DateUtil::fromISO8601).toList(); - builder.setOffsetDateTimeArray(property, dates); - } else if (list.hasObjectValues()) { - List objects = list.getObjectValues().getValuesList().stream() - .map(object -> { - var properties = descriptor.propertiesBuilder(); - object.getFieldsMap().entrySet().stream() - .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); - return properties.build(); - }).toList(); - builder.setNestedObjectArray(property, objects); - } - } else if (value.hasObjectValue()) { - var object = value.getObjectValue(); - var properties = descriptor.propertiesBuilder(); - object.getFieldsMap().entrySet().stream() - .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); - builder.setNestedObject(property, properties.build()); - } else { - throw new IllegalArgumentException(property + " data type is not supported"); - } + static Rpc, WeaviateProtoSearchGet.SearchReply> grouped( + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var rpc = rpc(collection, defaults); + return Rpc.of( + request -> rpc.marshal(request), + reply -> QueryResponseGrouped.unmarshal(reply, collection, defaults), + () -> rpc.method(), () -> rpc.methodAsync()); } + } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java index b5d465f8c..1160b0255 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponse.java @@ -1,9 +1,206 @@ package io.weaviate.client6.v1.api.collections.query; +import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; +import io.weaviate.client6.v1.api.collections.ObjectMetadata; +import io.weaviate.client6.v1.api.collections.Vectors; import io.weaviate.client6.v1.api.collections.WeaviateObject; +import io.weaviate.client6.v1.internal.DateUtil; +import io.weaviate.client6.v1.internal.grpc.ByteStringUtil; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoProperties; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; +import io.weaviate.client6.v1.internal.orm.PropertiesBuilder; -public record QueryResponse( - List> objects) { +public record QueryResponse( + List> objects) { + + static QueryResponse unmarshal(WeaviateProtoSearchGet.SearchReply reply, + CollectionDescriptor collection) { + var objects = reply + .getResultsList() + .stream() + .map(obj -> QueryResponse.unmarshalResultObject( + obj.getProperties(), obj.getMetadata(), collection)) + .toList(); + return new QueryResponse<>(objects); + } + + public static WeaviateObject unmarshalResultObject( + WeaviateProtoSearchGet.PropertiesResult propertiesResult, + WeaviateProtoSearchGet.MetadataResult metadataResult, + CollectionDescriptor collection) { + var object = unmarshalWithReferences(propertiesResult, metadataResult, collection); + var metadata = new QueryMetadata.Builder() + .uuid(object.metadata().uuid()) + .vectors(object.metadata().vectors()); + + if (metadataResult.getCreationTimeUnixPresent()) { + metadata.creationTimeUnix(metadataResult.getCreationTimeUnix()); + } + if (metadataResult.getLastUpdateTimeUnixPresent()) { + metadata.lastUpdateTimeUnix(metadataResult.getLastUpdateTimeUnix()); + } + if (metadataResult.getDistancePresent()) { + metadata.distance(metadataResult.getDistance()); + } + if (metadataResult.getCertaintyPresent()) { + metadata.certainty(metadataResult.getCertainty()); + } + if (metadataResult.getScorePresent()) { + metadata.score(metadataResult.getScore()); + } + if (metadataResult.getExplainScorePresent()) { + metadata.explainScore(metadataResult.getExplainScore()); + } + return new WeaviateObject<>(collection.collectionName(), object.properties(), object.references(), + metadata.build()); + } + + static WeaviateObject unmarshalWithReferences( + WeaviateProtoSearchGet.PropertiesResult propertiesResult, + WeaviateProtoSearchGet.MetadataResult metadataResult, + CollectionDescriptor descriptor) { + var properties = descriptor.propertiesBuilder(); + propertiesResult.getNonRefProps().getFieldsMap().entrySet().stream() + .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); + + // In case a reference is multi-target, there will be a separate + // "reference property" for each of the targets, so instead of + // `collect` we need to `reduce` the map, merging related references + // as we go. + // I.e. { "ref": A-1 } , { "ref": B-1 } => { "ref": [A-1, B-1] } + var referenceProperties = propertiesResult.getRefPropsList() + .stream().reduce( + new HashMap>(), + (map, ref) -> { + var refObjects = ref.getPropertiesList().stream() + .map(property -> { + var reference = unmarshalWithReferences( + property, property.getMetadata(), + CollectionDescriptor.ofMap(property.getTargetCollection())); + return (Object) new WeaviateObject<>( + reference.collection(), + (Object) reference.properties(), + reference.references(), + reference.metadata()); + }) + .toList(); + + // Merge ObjectReferences by joining the underlying WeaviateObjects. + map.merge( + ref.getPropName(), + refObjects, + (left, right) -> { + var joined = Stream.concat( + left.stream(), + right.stream()).toList(); + return joined; + }); + return map; + }, + (left, right) -> { + left.putAll(right); + return left; + }); + + ObjectMetadata metadata = null; + if (metadataResult != null) { + var metadataBuilder = new ObjectMetadata.Builder() + .uuid(metadataResult.getId()); + + var vectors = new Vectors[metadataResult.getVectorsList().size()]; + var i = 0; + for (final var vector : metadataResult.getVectorsList()) { + var vectorName = vector.getName(); + var vbytes = vector.getVectorBytes(); + switch (vector.getType()) { + case VECTOR_TYPE_SINGLE_FP32: + vectors[i++] = Vectors.of(vectorName, ByteStringUtil.decodeVectorSingle(vbytes)); + break; + case VECTOR_TYPE_MULTI_FP32: + vectors[i++] = Vectors.of(vectorName, ByteStringUtil.decodeVectorMulti(vbytes)); + break; + default: + continue; + } + } + metadataBuilder.vectors(vectors); + metadata = metadataBuilder.build(); + } + + var obj = new WeaviateObject.Builder() + .collection(descriptor.collectionName()) + .properties(properties.build()) + .references(referenceProperties) + .metadata(metadata); + return obj.build(); + } + + static void setProperty(String property, WeaviateProtoProperties.Value value, + PropertiesBuilder builder, CollectionDescriptor descriptor) { + if (value.hasNullValue()) { + builder.setNull(property); + } else if (value.hasTextValue()) { + builder.setText(property, value.getTextValue()); + } else if (value.hasBoolValue()) { + builder.setBoolean(property, value.getBoolValue()); + } else if (value.hasIntValue()) { + builder.setLong(property, value.getIntValue()); + } else if (value.hasNumberValue()) { + builder.setDouble(property, value.getNumberValue()); + } else if (value.hasBlobValue()) { + builder.setBlob(property, value.getBlobValue()); + } else if (value.hasDateValue()) { + builder.setOffsetDateTime(property, DateUtil.fromISO8601(value.getDateValue())); + } else if (value.hasUuidValue()) { + builder.setUuid(property, UUID.fromString(value.getUuidValue())); + } else if (value.hasListValue()) { + var list = value.getListValue(); + if (list.hasTextValues()) { + builder.setTextArray(property, list.getTextValues().getValuesList()); + } else if (list.hasIntValues()) { + var ints = Arrays.stream( + ByteStringUtil.decodeIntValues(list.getIntValues().getValues())) + .boxed().toList(); + builder.setLongArray(property, ints); + } else if (list.hasNumberValues()) { + var numbers = Arrays.stream( + ByteStringUtil.decodeNumberValues(list.getNumberValues().getValues())) + .boxed().toList(); + builder.setDoubleArray(property, numbers); + } else if (list.hasUuidValues()) { + var uuids = list.getUuidValues().getValuesList().stream() + .map(UUID::fromString).toList(); + builder.setUuidArray(property, uuids); + } else if (list.hasBoolValues()) { + builder.setBooleanArray(property, list.getBoolValues().getValuesList()); + } else if (list.hasDateValues()) { + var dates = list.getDateValues().getValuesList().stream() + .map(DateUtil::fromISO8601).toList(); + builder.setOffsetDateTimeArray(property, dates); + } else if (list.hasObjectValues()) { + List objects = list.getObjectValues().getValuesList().stream() + .map(object -> { + var properties = descriptor.propertiesBuilder(); + object.getFieldsMap().entrySet().stream() + .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); + return properties.build(); + }).toList(); + builder.setNestedObjectArray(property, objects); + } + } else if (value.hasObjectValue()) { + var object = value.getObjectValue(); + var properties = descriptor.propertiesBuilder(); + object.getFieldsMap().entrySet().stream() + .forEach(entry -> setProperty(entry.getKey(), entry.getValue(), properties, descriptor)); + builder.setNestedObject(property, properties.build()); + } else { + throw new IllegalArgumentException(property + " data type is not supported"); + } + } } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponseGrouped.java b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponseGrouped.java index 9ee8442fa..587be10ac 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponseGrouped.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/query/QueryResponseGrouped.java @@ -1,11 +1,52 @@ package io.weaviate.client6.v1.api.collections.query; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; -public record QueryResponseGrouped( +import io.weaviate.client6.v1.api.collections.CollectionHandleDefaults; +import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet; +import io.weaviate.client6.v1.internal.orm.CollectionDescriptor; + +public record QueryResponseGrouped( /** All objects retrieved in the query. */ - List> objects, + List> objects, /** Grouped response objects. */ - Map> groups) { + Map> groups) { + + static QueryResponseGrouped unmarshal( + WeaviateProtoSearchGet.SearchReply reply, + CollectionDescriptor collection, + CollectionHandleDefaults defaults) { + var allObjects = new ArrayList>(); + var groups = reply.getGroupByResultsList() + .stream().map(group -> { + var name = group.getName(); + List> objects = group.getObjectsList().stream() + .map(obj -> QueryResponse.unmarshalResultObject( + obj.getProperties(), + obj.getMetadata(), + collection)) + .map(obj -> new QueryObjectGrouped<>(obj, name)) + .toList(); + + allObjects.addAll(objects); + return new QueryResponseGroup<>( + name, + group.getMinDistance(), + group.getMaxDistance(), + group.getNumberOfObjects(), + objects); + }) + // Collectors.toMap() throws an NPE if either key or value in the map are null. + // In this specific case it is safe to use it, as the function in the map above + // always returns a QueryResponseGroup. + // The name of the group should not be null either, that's something we assume + // about the server's response. + .collect(Collectors.toMap(QueryResponseGroup::name, Function.identity())); + + return new QueryResponseGrouped(allObjects, groups); + } } diff --git a/src/main/java/io/weaviate/client6/v1/internal/TaggedUnion.java b/src/main/java/io/weaviate/client6/v1/internal/TaggedUnion.java new file mode 100644 index 000000000..42b944294 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/internal/TaggedUnion.java @@ -0,0 +1,27 @@ +package io.weaviate.client6.v1.internal; + +public interface TaggedUnion, SelfT> { + KindT _kind(); + + SelfT _self(); + + /** Does the current instance have the kind? */ + default boolean _is(KindT kind) { + return _kind() == kind; + } + + /** Convert tagged union instance to one of its variants. */ + default > Value _as(KindT kind) { + return TaggedUnion.as(this, kind); + } + + /** Convert tagged union instance to one of its variants. */ + public static , Tag extends Enum, Value> Value as(Union union, Tag kind) { + if (union._is(kind)) { + @SuppressWarnings("unchecked") + Value value = (Value) union._self(); + return value; + } + throw new IllegalStateException("Cannot convert '%s' variant to '%s'".formatted(union._kind(), kind)); + } +} diff --git a/src/main/proto/v1/generative.proto b/src/main/proto/v1/generative.proto index 4e6e32525..fb3deda67 100644 --- a/src/main/proto/v1/generative.proto +++ b/src/main/proto/v1/generative.proto @@ -23,9 +23,9 @@ message GenerativeSearch { bool debug = 4; } - string single_response_prompt = 1 [deprecated = true]; - string grouped_response_task = 2 [deprecated = true]; - repeated string grouped_properties = 3 [deprecated = true]; + string single_response_prompt = 1 [ deprecated = true ]; + string grouped_response_task = 2 [ deprecated = true ]; + repeated string grouped_properties = 3 [ deprecated = true ]; Single single = 4; Grouped grouped = 5; } @@ -49,7 +49,7 @@ message GenerativeProvider { } } -message GenerativeAnthropic{ +message GenerativeAnthropic { optional string base_url = 1; optional int64 max_tokens = 2; optional string model = 3; @@ -61,13 +61,13 @@ message GenerativeAnthropic{ optional TextArray image_properties = 9; } -message GenerativeAnyscale{ +message GenerativeAnyscale { optional string base_url = 1; optional string model = 2; optional double temperature = 3; } -message GenerativeAWS{ +message GenerativeAWS { optional string model = 3; optional double temperature = 8; optional string service = 9; @@ -79,7 +79,7 @@ message GenerativeAWS{ optional TextArray image_properties = 15; } -message GenerativeCohere{ +message GenerativeCohere { optional string base_url = 1; optional double frequency_penalty = 2; optional int64 max_tokens = 3; @@ -91,10 +91,9 @@ message GenerativeCohere{ optional double temperature = 9; } -message GenerativeDummy{ -} +message GenerativeDummy {} -message GenerativeMistral{ +message GenerativeMistral { optional string base_url = 1; optional int64 max_tokens = 2; optional string model = 3; @@ -102,7 +101,7 @@ message GenerativeMistral{ optional double top_p = 5; } -message GenerativeOllama{ +message GenerativeOllama { optional string api_endpoint = 1; optional string model = 2; optional double temperature = 3; @@ -110,7 +109,7 @@ message GenerativeOllama{ optional TextArray image_properties = 5; } -message GenerativeOpenAI{ +message GenerativeOpenAI { optional double frequency_penalty = 1; optional int64 max_tokens = 2; optional string model = 3; @@ -128,7 +127,7 @@ message GenerativeOpenAI{ optional TextArray image_properties = 15; } -message GenerativeGoogle{ +message GenerativeGoogle { optional double frequency_penalty = 1; optional int64 max_tokens = 2; optional string model = 3; @@ -145,7 +144,7 @@ message GenerativeGoogle{ optional TextArray image_properties = 14; } -message GenerativeDatabricks{ +message GenerativeDatabricks { optional string endpoint = 1; optional string model = 2; optional double frequency_penalty = 3; @@ -159,7 +158,7 @@ message GenerativeDatabricks{ optional double top_p = 11; } -message GenerativeFriendliAI{ +message GenerativeFriendliAI { optional string base_url = 1; optional string model = 2; optional int64 max_tokens = 3; @@ -168,7 +167,7 @@ message GenerativeFriendliAI{ optional double top_p = 6; } -message GenerativeNvidia{ +message GenerativeNvidia { optional string base_url = 1; optional string model = 2; optional double temperature = 3; @@ -176,7 +175,7 @@ message GenerativeNvidia{ optional int64 max_tokens = 5; } -message GenerativeXAI{ +message GenerativeXAI { optional string base_url = 1; optional string model = 2; optional double temperature = 3; @@ -194,11 +193,9 @@ message GenerativeAnthropicMetadata { Usage usage = 1; } -message GenerativeAnyscaleMetadata { -} +message GenerativeAnyscaleMetadata {} -message GenerativeAWSMetadata { -} +message GenerativeAWSMetadata {} message GenerativeCohereMetadata { message ApiVersion { @@ -222,8 +219,7 @@ message GenerativeCohereMetadata { optional TextArray warnings = 4; } -message GenerativeDummyMetadata { -} +message GenerativeDummyMetadata {} message GenerativeMistralMetadata { message Usage { @@ -234,8 +230,7 @@ message GenerativeMistralMetadata { optional Usage usage = 1; } -message GenerativeOllamaMetadata { -} +message GenerativeOllamaMetadata {} message GenerativeOpenAIMetadata { message Usage { @@ -255,9 +250,7 @@ message GenerativeGoogleMetadata { optional TokenCount input_token_count = 1; optional TokenCount output_token_count = 2; } - message Metadata { - optional TokenMetadata token_metadata = 1; - } + message Metadata { optional TokenMetadata token_metadata = 1; } message UsageMetadata { optional int64 prompt_token_count = 1; optional int64 candidates_token_count = 2; @@ -327,10 +320,6 @@ message GenerativeReply { optional GenerativeMetadata metadata = 3; } -message GenerativeResult { - repeated GenerativeReply values = 1; -} +message GenerativeResult { repeated GenerativeReply values = 1; } -message GenerativeDebug { - optional string full_prompt = 1; -} +message GenerativeDebug { optional string full_prompt = 1; } diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index 24b79dc15..f6f15e0dd 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -426,30 +426,6 @@ public static Object[][] testCases() { """, }, - // Generative.CustomTypeAdapterFactory - { - Generative.class, - Generative.cohere(generate -> generate - .kProperty("k-property") - .maxTokensProperty(10) - .model("example-model") - .returnLikelihoodsProperty("likelihood") - .stopSequencesProperty("stop", "halt") - .temperatureProperty("celcius")), - """ - { - "generative-cohere": { - "kProperty": "k-property", - "maxTokensProperty": 10, - "model": "example-model", - "returnLikelihoodsProperty": "likelihood", - "stopSequencesProperty": ["stop", "halt"], - "temperatureProperty": "celcius" - } - } - """, - }, - // BatchReference.CustomTypeAdapterFactory { BatchReference.class, @@ -917,6 +893,268 @@ public static Object[][] testCases() { } """ }, + + // Generative.CustomTypeAdapterFactory + { + Generative.class, + Generative.anyscale(cfg -> cfg + .baseUrl("https://example.com") + .model("example-model") + .temperature(3f)), + """ + { + "generative-anyscale": { + "baseURL": "https://example.com", + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.anthropic(cfg -> cfg + .topK(1) + .maxTokens(2) + .temperature(3f) + .model("example-model") + .stopSequences("stop", "halt")), + """ + { + "generative-anthropic": { + "topK": 1, + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model", + "stopSequences": ["stop", "halt"] + } + } + """, + }, + { + Generative.class, + Generative.aws( + "aws-region", + "aws-service", + cfg -> cfg + .baseUrl("https://example.com") + .model("example-model")), + """ + { + "generative-aws": { + "endpoint": "https://example.com", + "model": "example-model", + "region": "aws-region", + "service": "aws-service" + } + } + """, + }, + { + Generative.class, + Generative.cohere(cfg -> cfg + .topK(1) + .maxTokens(2) + .temperature(3f) + .model("example-model") + .returnLikelihoodsProperty("likelihood") + .stopSequences("stop", "halt")), + """ + { + "generative-cohere": { + "kProperty": 1, + "maxTokensProperty": 2, + "temperatureProperty": 3.0, + "model": "example-model", + "returnLikelihoodsProperty": "likelihood", + "stopSequencesProperty": ["stop", "halt"] + } + } + """, + }, + { + Generative.class, + Generative.databricks( + "https://example.com", + cfg -> cfg + .topK(1) + .maxTokens(2) + .temperature(3f) + .topP(4f)), + """ + { + "generative-databricks": { + "endpoint": "https://example.com", + "topK": 1, + "maxTokens": 2, + "temperature": 3.0, + "topP": 4.0 + } + } + """, + }, + { + Generative.class, + Generative.friendliai(cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .model("example-model")), + """ + { + "generative-friendliai": { + "baseURL": "https://example.com", + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.mistral(cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .model("example-model")), + """ + { + "generative-mistral": { + "baseURL": "https://example.com", + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.nvidia(cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .model("example-model")), + """ + { + "generative-nvidia": { + "baseURL": "https://example.com", + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.google( + "google-project", + cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .topK(4) + .topP(5f) + .model("example-model")), + """ + { + "generative-palm": { + "apiEndpoint": "https://example.com", + "maxOutputTokens": 2, + "temperature": 3.0, + "topK": 4, + "topP": 5, + "projectId": "google-project", + "modelId": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.ollama(cfg -> cfg + .baseUrl("https://example.com") + .model("example-model")), + """ + { + "generative-ollama": { + "apiEndpoint": "https://example.com", + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.xai(cfg -> cfg + .baseUrl("https://example.com") + .maxTokens(2) + .temperature(3f) + .model("example-model")), + """ + { + "generative-xai": { + "baseURL": "https://example.com", + "maxTokens": 2, + "temperature": 3.0, + "model": "example-model" + } + } + """, + }, + { + Generative.class, + Generative.openai(cfg -> cfg + .baseUrl("https://example.com") + .frequencyPenalty(1f) + .presencePenalty(2f) + .temperature(3f) + .topP(4f) + .maxTokens(5) + .model("o3-mini")), + """ + { + "generative-openai": { + "baseURL": "https://example.com", + "frequencyPenaltyProperty": 1.0, + "presencePenaltyProperty": 2.0, + "temperatureProperty": 3.0, + "topPProperty": 4.0, + "maxTokensProperty": 5, + "model": "o3-mini" + } + } + """ + }, + { + Generative.class, + Generative.azure( + "azure-resource", + "azure-deployment", + cfg -> cfg + .baseUrl("https://example.com") + .frequencyPenalty(1f) + .presencePenalty(2f) + .temperature(3f) + .topP(4f) + .maxTokens(5)), + """ + { + "generative-openai": { + "baseURL": "https://example.com", + "frequencyPenaltyProperty": 1.0, + "presencePenaltyProperty": 2.0, + "temperatureProperty": 3.0, + "topPProperty": 4.0, + "maxTokensProperty": 5, + "resourceName": "azure-resource", + "deploymentId": "azure-deployment" + } + } + """ + }, }; }