diff --git a/.changeset/pink-deers-switch.md b/.changeset/pink-deers-switch.md new file mode 100644 index 000000000000..84d524aeb623 --- /dev/null +++ b/.changeset/pink-deers-switch.md @@ -0,0 +1,6 @@ +--- +'@ai-sdk/google-vertex': patch +'@ai-sdk/google': patch +--- + +feat: add provider option schemas for vertex imagegen and google genai diff --git a/content/providers/01-ai-sdk-providers/15-google-generative-ai.mdx b/content/providers/01-ai-sdk-providers/15-google-generative-ai.mdx index 980fb3ccdc4d..2fba02438350 100644 --- a/content/providers/01-ai-sdk-providers/15-google-generative-ai.mdx +++ b/content/providers/01-ai-sdk-providers/15-google-generative-ai.mdx @@ -80,7 +80,7 @@ const model = google('gemini-1.5-pro-latest'); e.g. `tunedModels/my-model`. -Google Generative AI models support also some model specific settings that are not part of the [standard call settings](/docs/ai-sdk-core/settings). +Google Generative AI also supports some model-specific settings that are not part of the [standard call settings](/docs/ai-sdk-core/settings). You can pass them as an options argument: ```ts @@ -132,6 +132,29 @@ The following optional settings are available for Google Generative AI models: - `BLOCK_ONLY_HIGH` - `BLOCK_NONE` +Further configuration can be done using Google Generative AI provider options. You can validate the provider options using the `GoogleGenerativeAIProviderOptions` type. + +```ts +import { google } from '@ai-sdk/google'; +import { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'; +import { generateText } from 'ai'; + +const { text } = await generateText({ + model: google('gemini-1.5-pro-latest'), + providerOptions: { + google: { + responseModalities: ['TEXT', 'IMAGE'], + } satisfies GoogleGenerativeAIProviderOptions, + }, + // ... +}); +``` + +The following provider options are available: + +- **responseModalities** _string[]_ + The modalities to use for the response. The following modalities are supported: `TEXT`, `IMAGE`. 
When not defined or empty, the model defaults to returning only text. + You can use Google Generative AI language models to generate text with the `generateText` function: ```ts diff --git a/content/providers/01-ai-sdk-providers/16-google-vertex.mdx b/content/providers/01-ai-sdk-providers/16-google-vertex.mdx index b326449fd6fe..b65b0f6b0819 100644 --- a/content/providers/01-ai-sdk-providers/16-google-vertex.mdx +++ b/content/providers/01-ai-sdk-providers/16-google-vertex.mdx @@ -599,6 +599,41 @@ const { image } = await generateImage({ }); ``` +Further configuration can be done using Google Vertex provider options. You can validate the provider options using the `GoogleVertexImageProviderOptions` type. + +```ts +import { vertex } from '@ai-sdk/google-vertex'; +import { GoogleVertexImageProviderOptions } from '@ai-sdk/google-vertex'; +import { generateImage } from 'ai'; + +const { image } = await generateImage({ + model: vertex.image('imagen-3.0-generate-001'), + providerOptions: { + vertex: { + negativePrompt: 'pixelated, blurry, low-quality', + } satisfies GoogleVertexImageProviderOptions, + }, + // ... +}); +``` + +The following provider options are available: + +- **negativePrompt** _string_ + A description of what to discourage in the generated images. + +- **personGeneration** `allow_adult` | `allow_all` | `dont_allow` + Whether to allow person generation. Defaults to `allow_adult`. + +- **safetySetting** `block_low_and_above` | `block_medium_and_above` | `block_only_high` | `block_none` + Whether to block unsafe content. Defaults to `block_medium_and_above`. + +- **addWatermark** _boolean_ + Whether to add an invisible watermark to the generated images. Defaults to `true`. + +- **storageUri** _string_ + Cloud Storage URI to store the generated images. + Imagen models do not support the `size` parameter. Use the `aspectRatio` parameter instead. 
diff --git a/examples/ai-core/src/generate-image/google-vertex.ts b/examples/ai-core/src/generate-image/google-vertex.ts index 4635fcda81dd..d21043e1005b 100644 --- a/examples/ai-core/src/generate-image/google-vertex.ts +++ b/examples/ai-core/src/generate-image/google-vertex.ts @@ -1,7 +1,10 @@ -import { vertex } from '@ai-sdk/google-vertex'; +import { + GoogleVertexImageProviderOptions, + vertex, +} from '@ai-sdk/google-vertex'; import { experimental_generateImage as generateImage } from 'ai'; -import { presentImages } from '../lib/present-image'; import 'dotenv/config'; +import { presentImages } from '../lib/present-image'; async function main() { const { image } = await generateImage({ @@ -10,9 +13,8 @@ async function main() { aspectRatio: '1:1', providerOptions: { vertex: { - // https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#parameter_list addWatermark: false, - }, + } satisfies GoogleVertexImageProviderOptions, }, }); diff --git a/examples/ai-core/src/generate-text/google-image.ts b/examples/ai-core/src/generate-text/google-image.ts index d7c66219ecdf..f5f32d72134a 100644 --- a/examples/ai-core/src/generate-text/google-image.ts +++ b/examples/ai-core/src/generate-text/google-image.ts @@ -1,4 +1,4 @@ -import { google } from '@ai-sdk/google'; +import { google, GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'; import { generateText } from 'ai'; import 'dotenv/config'; import fs from 'node:fs'; @@ -15,6 +15,11 @@ async function main() { ], }, ], + providerOptions: { + google: { + responseModalities: ['TEXT', 'IMAGE'], + } satisfies GoogleGenerativeAIProviderOptions, + }, }); console.log(result.text); diff --git a/packages/google-vertex/src/google-vertex-image-model.test.ts b/packages/google-vertex/src/google-vertex-image-model.test.ts index 0a7dab1c6fe9..c9be8cb4839a 100644 --- a/packages/google-vertex/src/google-vertex-image-model.test.ts +++ b/packages/google-vertex/src/google-vertex-image-model.test.ts @@ -38,27 +38,6 
@@ describe('GoogleVertexImageModel', () => { }; } - it('should pass the correct parameters', async () => { - prepareJsonResponse(); - - await model.doGenerate({ - prompt, - n: 2, - size: undefined, - aspectRatio: undefined, - seed: undefined, - providerOptions: { vertex: { aspectRatio: '1:1' } }, - }); - - expect(await server.calls[0].requestBody).toStrictEqual({ - instances: [{ prompt }], - parameters: { - sampleCount: 2, - aspectRatio: '1:1', - }, - }); - }); - it('should pass headers', async () => { prepareJsonResponse(); @@ -143,13 +122,9 @@ describe('GoogleVertexImageModel', () => { prompt: 'test prompt', n: 1, size: undefined, - aspectRatio: undefined, + aspectRatio: '16:9', seed: undefined, - providerOptions: { - vertex: { - aspectRatio: '16:9', - }, - }, + providerOptions: {}, }); expect(await server.calls[0].requestBody).toStrictEqual({ @@ -214,7 +189,7 @@ describe('GoogleVertexImageModel', () => { seed: 42, providerOptions: { vertex: { - temperature: 0.8, + addWatermark: false, }, }, }); @@ -225,7 +200,7 @@ describe('GoogleVertexImageModel', () => { sampleCount: 1, aspectRatio: '1:1', seed: 42, - temperature: 0.8, + addWatermark: false, }, }); }); @@ -302,7 +277,7 @@ describe('GoogleVertexImageModel', () => { const result = await model.doGenerate({ prompt, - n: 1, + n: 2, size: undefined, aspectRatio: undefined, seed: undefined, @@ -319,5 +294,36 @@ describe('GoogleVertexImageModel', () => { ); expect(result.response.modelId).toBe('imagen-3.0-generate-001'); }); + + it('should only pass valid provider options', async () => { + prepareJsonResponse(); + + await model.doGenerate({ + prompt, + n: 2, + size: undefined, + aspectRatio: '16:9', + seed: undefined, + providerOptions: { + vertex: { + addWatermark: false, + negativePrompt: 'negative prompt', + personGeneration: 'allow_all', + foo: 'bar', + }, + }, + }); + + expect(await server.calls[0].requestBody).toStrictEqual({ + instances: [{ prompt }], + parameters: { + sampleCount: 2, + addWatermark: false, + 
negativePrompt: 'negative prompt', + personGeneration: 'allow_all', + aspectRatio: '16:9', + }, + }); + }); }); }); diff --git a/packages/google-vertex/src/google-vertex-image-model.ts b/packages/google-vertex/src/google-vertex-image-model.ts index c94d5ddc32d5..635b50df5445 100644 --- a/packages/google-vertex/src/google-vertex-image-model.ts +++ b/packages/google-vertex/src/google-vertex-image-model.ts @@ -3,6 +3,7 @@ import { Resolvable, combineHeaders, createJsonResponseHandler, + parseProviderOptions, postJsonToApi, resolve, } from '@ai-sdk/provider-utils'; @@ -65,13 +66,19 @@ export class GoogleVertexImageModel implements ImageModelV1 { }); } + const vertexImageOptions = parseProviderOptions({ + provider: 'vertex', + providerOptions, + schema: vertexImageProviderOptionsSchema, + }); + const body = { instances: [{ prompt }], parameters: { sampleCount: n, ...(aspectRatio != null ? { aspectRatio } : {}), ...(seed != null ? { seed } : {}), - ...(providerOptions.vertex ?? {}), + ...(vertexImageOptions ?? 
{}), }, }; @@ -108,3 +115,23 @@ export class GoogleVertexImageModel implements ImageModelV1 { const vertexImageResponseSchema = z.object({ predictions: z.array(z.object({ bytesBase64Encoded: z.string() })).nullish(), }); + +const vertexImageProviderOptionsSchema = z.object({ + negativePrompt: z.string().nullish(), + personGeneration: z + .enum(['dont_allow', 'allow_adult', 'allow_all']) + .nullish(), + safetySetting: z + .enum([ + 'block_low_and_above', + 'block_medium_and_above', + 'block_only_high', + 'block_none', + ]) + .nullish(), + addWatermark: z.boolean().nullish(), + storageUri: z.string().nullish(), +}); +export type GoogleVertexImageProviderOptions = z.infer< + typeof vertexImageProviderOptionsSchema +>; diff --git a/packages/google-vertex/src/index.ts b/packages/google-vertex/src/index.ts index bf1e0fad1033..e84f8e06d72f 100644 --- a/packages/google-vertex/src/index.ts +++ b/packages/google-vertex/src/index.ts @@ -1,3 +1,4 @@ +export type { GoogleVertexImageProviderOptions } from './google-vertex-image-model'; export { createVertex, vertex } from './google-vertex-provider-node'; export type { GoogleVertexProvider, diff --git a/packages/google/src/google-generative-ai-language-model.test.ts b/packages/google/src/google-generative-ai-language-model.test.ts index 4a10794a8c2a..a860afb73ad9 100644 --- a/packages/google/src/google-generative-ai-language-model.test.ts +++ b/packages/google/src/google-generative-ai-language-model.test.ts @@ -414,6 +414,39 @@ describe('doGenerate', () => { }); }); + it('should only pass valid provider options', async () => { + prepareJsonResponse({}); + + await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: [ + { role: 'system', content: 'test system instruction' }, + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ], + seed: 123, + temperature: 0.5, + providerMetadata: { + google: { foo: 'bar', responseModalities: ['TEXT', 'IMAGE'] }, + }, + }); + + expect(await 
server.calls[0].requestBody).toStrictEqual({ + contents: [ + { + role: 'user', + parts: [{ text: 'Hello' }], + }, + ], + systemInstruction: { parts: [{ text: 'test system instruction' }] }, + generationConfig: { + seed: 123, + temperature: 0.5, + responseModalities: ['TEXT', 'IMAGE'], + }, + }); + }); + it('should pass tools and toolChoice', async () => { prepareJsonResponse({}); @@ -1869,4 +1902,29 @@ describe('doStream', () => { 'tool-calls', ); }); + + it('should only pass valid provider options', async () => { + prepareStreamResponse({ content: [''] }); + + await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + providerMetadata: { + google: { foo: 'bar', responseModalities: ['TEXT', 'IMAGE'] }, + }, + }); + + expect(await server.calls[0].requestBody).toMatchObject({ + contents: [ + { + role: 'user', + parts: [{ text: 'Hello' }], + }, + ], + generationConfig: { + responseModalities: ['TEXT', 'IMAGE'], + }, + }); + }); }); diff --git a/packages/google/src/google-generative-ai-language-model.ts b/packages/google/src/google-generative-ai-language-model.ts index 8f2a9bf1c356..6482e97d64c9 100644 --- a/packages/google/src/google-generative-ai-language-model.ts +++ b/packages/google/src/google-generative-ai-language-model.ts @@ -88,9 +88,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV1 { const googleOptions = parseProviderOptions({ provider: 'google', providerOptions: providerMetadata, - schema: z.object({ - responseModalities: z.array(z.enum(['TEXT', 'IMAGE'])).nullish(), - }), + schema: googleGenerativeAIProviderOptionsSchema, }); const generationConfig = { @@ -623,3 +621,10 @@ const chunkSchema = z.object({ }) .nullish(), }); + +const googleGenerativeAIProviderOptionsSchema = z.object({ + responseModalities: z.array(z.enum(['TEXT', 'IMAGE'])).nullish(), +}); +export type GoogleGenerativeAIProviderOptions = z.infer< + typeof googleGenerativeAIProviderOptionsSchema +>; diff --git 
a/packages/google/src/index.ts b/packages/google/src/index.ts index bf9c08c46346..8220c6c80ba0 100644 --- a/packages/google/src/index.ts +++ b/packages/google/src/index.ts @@ -1,6 +1,7 @@ -export { createGoogleGenerativeAI, google } from './google-provider'; export type { GoogleErrorData } from './google-error'; +export type { GoogleGenerativeAIProviderOptions } from './google-generative-ai-language-model'; export type { GoogleGenerativeAIProviderMetadata } from './google-generative-ai-prompt'; +export { createGoogleGenerativeAI, google } from './google-provider'; export type { GoogleGenerativeAIProvider, GoogleGenerativeAIProviderSettings,