diff --git a/agent-schema.json b/agent-schema.json index 1f2c0210b..fd2bbc76b 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -552,7 +552,7 @@ }, "provider_opts": { "type": "object", - "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).", + "description": "Provider-specific options. Sampling parameters: top_k (integer, supported by anthropic, google, amazon-bedrock, and custom OpenAI-compatible providers like vLLM/Ollama), repetition_penalty, min_p and typical_p (float, forwarded to custom OpenAI-compatible providers), seed (integer, set as the native seed field on OpenAI chat completion requests). Infrastructure options: dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).", "additionalProperties": true }, "track_usage": { diff --git a/examples/sampling-opts.yaml b/examples/sampling-opts.yaml new file mode 100644 index 000000000..da7e1a04d --- /dev/null +++ b/examples/sampling-opts.yaml @@ -0,0 +1,22 @@ +#!/usr/bin/env docker agent run + +# This example shows how to use provider_opts to pass sampling parameters +# like top_k, repetition_penalty and min_p through the openai provider. 
+ +agents: + root: + model: gpt + description: "Assistant with custom sampling parameters" + instruction: | + You are a helpful assistant running on a local model with tuned sampling parameters. + +models: + gpt: + provider: openai + model: gpt-4o + temperature: 0.7 + top_p: 0.9 + provider_opts: + top_k: 40 + repetition_penalty: 1.15 + min_p: 0.05 diff --git a/pkg/model/provider/anthropic/beta_client.go b/pkg/model/provider/anthropic/beta_client.go index f1b754d8b..249b7ab8a 100644 --- a/pkg/model/provider/anthropic/beta_client.go +++ b/pkg/model/provider/anthropic/beta_client.go @@ -13,6 +13,7 @@ import ( "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/model/provider/providerutil" "github.com/docker/docker-agent/pkg/rag/prompts" "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/tools" @@ -115,6 +116,12 @@ func (c *Client) createBetaStream( "max_tokens", maxTokens, "message_count", len(params.Messages)) + // Forward top_k from provider_opts (Anthropic natively supports it) + if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok { + params.TopK = param.NewOpt(topK) + slog.Debug("Anthropic Beta provider_opts: set top_k", "value", topK) + } + stream := client.Beta.Messages.NewStreaming(ctx, params) trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage ad := c.newBetaStreamAdapter(stream, trackUsage) @@ -293,6 +300,12 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc params.TopP = param.NewOpt(*c.ModelConfig.TopP) } + // Forward top_k from provider_opts (Anthropic natively supports it) + if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok { + params.TopK = param.NewOpt(topK) + slog.Debug("Anthropic Beta provider_opts: set top_k", "value", topK) + } + // Use streaming API to avoid timeout errors for operations that may 
take longer than 10 minutes stream := client.Beta.Messages.NewStreaming(ctx, params) diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index b171e9563..971e4ab2c 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -24,6 +24,7 @@ import ( "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/model/provider/providerutil" "github.com/docker/docker-agent/pkg/tools" ) @@ -337,6 +338,12 @@ func (c *Client) CreateChatCompletionStream( slog.Debug("Anthropic extended thinking enabled, ignoring temperature/top_p settings") } + // Forward top_k from provider_opts (Anthropic natively supports it) + if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok { + params.TopK = param.NewOpt(topK) + slog.Debug("Anthropic provider_opts: set top_k", "value", topK) + } + if len(requestTools) > 0 { slog.Debug("Adding tools to Anthropic request", "tool_count", len(requestTools)) } diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index a29fd7363..ffed161b6 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -20,6 +20,7 @@ import ( "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/model/provider/providerutil" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/tools" ) @@ -244,7 +245,7 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools } // Set inference configuration (temp/topP are suppressed when thinking is on). 
- input.InferenceConfig = c.buildInferenceConfig(additionalFields != nil) + input.InferenceConfig = c.buildInferenceConfig(c.isThinkingEnabled()) // Convert and set tools if len(requestTools) > 0 { @@ -278,59 +279,100 @@ func (c *Client) buildInferenceConfig(thinkingEnabled bool) *types.InferenceConf } func (c *Client) interleavedThinkingEnabled() bool { - return getProviderOpt[bool](c.ModelConfig.ProviderOpts, "interleaved_thinking") -} - -func (c *Client) promptCachingEnabled() bool { - if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") { - return false + // Default to true, matching the documented schema behavior. + v, ok := c.ModelConfig.ProviderOpts["interleaved_thinking"] + if !ok { + return true } - return c.cachingSupported + b, ok := v.(bool) + if !ok { + slog.Warn("Bedrock provider_opts type mismatch", + "key", "interleaved_thinking", + "expected_type", "bool", + "actual_type", fmt.Sprintf("%T", v), + "value", v) + return true + } + return b } -// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode. -func (c *Client) buildAdditionalModelRequestFields() document.Interface { +// isThinkingEnabled returns true if a valid thinking budget is configured. +// It mirrors the validation in buildAdditionalModelRequestFields but without +// side effects (no logging), so it can safely be used to gate inference config. 
+func (c *Client) isThinkingEnabled() bool { if c.ModelConfig.ThinkingBudget == nil { - return nil + return false } tokens := c.ModelConfig.ThinkingBudget.Tokens if t, ok := c.ModelConfig.ThinkingBudget.EffortTokens(); ok { tokens = t } - if tokens <= 0 { - return nil - } - - // Validate minimum (Claude requires at least 1024 tokens for thinking) if tokens < 1024 { - slog.Warn("Bedrock thinking_budget below minimum (1024), ignoring", - "tokens", tokens) - return nil + return false } - - // Validate against max_tokens if c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) { - slog.Warn("Bedrock thinking_budget must be less than max_tokens, ignoring", - "thinking_budget", tokens, - "max_tokens", *c.ModelConfig.MaxTokens) - return nil + return false } + return true +} - slog.Debug("Bedrock request using thinking_budget", "budget_tokens", tokens) +func (c *Client) promptCachingEnabled() bool { + if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") { + return false + } + return c.cachingSupported +} - fields := map[string]any{ - "thinking": map[string]any{ - "type": "enabled", - "budget_tokens": tokens, - }, +// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode +// and forwards supported sampling parameters from provider_opts (e.g. top_k). 
+func (c *Client) buildAdditionalModelRequestFields() document.Interface { + fields := map[string]any{} + + // Forward top_k from provider_opts unless thinking is on (top_k is incompatible with extended thinking) + if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok && !c.isThinkingEnabled() { + fields["top_k"] = topK + slog.Debug("Bedrock provider_opts: set top_k", "value", topK) } - // Add anthropic_beta field for interleaved thinking - if c.interleavedThinkingEnabled() { - fields["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"} - slog.Debug("Bedrock request using interleaved thinking beta") + // Configure thinking budget if present and valid + if budget := c.ModelConfig.ThinkingBudget; budget != nil { + tokens := budget.Tokens + if t, ok := budget.EffortTokens(); ok { + tokens = t + } + + valid := tokens > 0 + if valid && tokens < 1024 { + slog.Warn("Bedrock thinking_budget below minimum (1024), ignoring", "tokens", tokens) + valid = false + } + if valid && c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) { + slog.Warn("Bedrock thinking_budget must be less than max_tokens, ignoring", + "thinking_budget", tokens, + "max_tokens", *c.ModelConfig.MaxTokens) + valid = false + } + + if valid { + slog.Debug("Bedrock request using thinking_budget", "budget_tokens", tokens) + fields["thinking"] = map[string]any{ + "type": "enabled", + "budget_tokens": tokens, + } + + if c.interleavedThinkingEnabled() { + fields["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"} + slog.Debug("Bedrock request using interleaved thinking beta") + } else { + slog.Debug("Bedrock interleaved_thinking disabled via provider_opts; " + + "thinking budget is sent without the interleaved-thinking anthropic_beta flag") + } + } } + if len(fields) == 0 { + return nil + } return document.NewLazyDocument(fields) } diff --git a/pkg/model/provider/bedrock/client_test.go b/pkg/model/provider/bedrock/client_test.go index 9c7e0e2bf..7f91a19b0 
100644 --- a/pkg/model/provider/bedrock/client_test.go +++ b/pkg/model/provider/bedrock/client_test.go @@ -854,7 +854,7 @@ func TestInterleavedThinkingEnabled_NotSet(t *testing.T) { }, } - assert.False(t, client.interleavedThinkingEnabled()) + assert.True(t, client.interleavedThinkingEnabled()) } func TestInterleavedThinkingEnabled_NilProviderOpts(t *testing.T) { @@ -870,7 +870,7 @@ func TestInterleavedThinkingEnabled_NilProviderOpts(t *testing.T) { }, } - assert.False(t, client.interleavedThinkingEnabled()) + assert.True(t, client.interleavedThinkingEnabled()) } func TestBuildAdditionalModelRequestFields_WithInterleavedThinking(t *testing.T) { diff --git a/pkg/model/provider/gemini/client.go b/pkg/model/provider/gemini/client.go index 0178e5d74..3cfb62922 100644 --- a/pkg/model/provider/gemini/client.go +++ b/pkg/model/provider/gemini/client.go @@ -21,6 +21,7 @@ import ( "github.com/docker/docker-agent/pkg/httpclient" "github.com/docker/docker-agent/pkg/model/provider/base" "github.com/docker/docker-agent/pkg/model/provider/options" + "github.com/docker/docker-agent/pkg/model/provider/providerutil" "github.com/docker/docker-agent/pkg/rag/prompts" "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/tools" @@ -352,6 +353,12 @@ func (c *Client) buildConfig() *genai.GenerateContentConfig { config.PresencePenalty = new(float32(*c.ModelConfig.PresencePenalty)) } + // Forward top_k from provider_opts (Gemini natively supports it) + if topK, ok := providerutil.GetProviderOptFloat64(c.ModelConfig.ProviderOpts, "top_k"); ok { + config.TopK = new(float32(topK)) + slog.Debug("Gemini provider_opts: set top_k", "value", topK) + } + // Apply thinking configuration for Gemini models. 
// See https://ai.google.dev/gemini-api/docs/thinking if c.ModelOptions.NoThinking() { diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 8042a295b..19a99b1d4 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -312,6 +312,11 @@ func (c *Client) CreateChatCompletionStream( return nil, err } + // Forward sampling-related provider_opts as extra body fields. + // This allows custom/OpenAI-compatible providers (vLLM, Ollama, etc.) + // to receive parameters like top_k, repetition_penalty, etc. + applySamplingProviderOpts(¶ms, c.ModelConfig.ProviderOpts) + stream := client.Chat.Completions.NewStreaming(ctx, params) slog.Debug("OpenAI chat completion stream created successfully", "model", c.ModelConfig.Model) @@ -842,6 +847,8 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc }, } + applySamplingProviderOpts(¶ms, c.ModelConfig.ProviderOpts) + resp, err := client.Chat.Completions.New(ctx, params) if err != nil { slog.Error("OpenAI rerank request failed", "error", err) diff --git a/pkg/model/provider/openai/sampling_opts.go b/pkg/model/provider/openai/sampling_opts.go new file mode 100644 index 000000000..8734234fc --- /dev/null +++ b/pkg/model/provider/openai/sampling_opts.go @@ -0,0 +1,46 @@ +package openai + +import ( + "log/slog" + + oai "github.com/openai/openai-go/v3" + + "github.com/docker/docker-agent/pkg/model/provider/providerutil" +) + +// applySamplingProviderOpts forwards sampling-related provider_opts as extra +// body fields on the OpenAI ChatCompletionNewParams. This enables custom +// OpenAI-compatible providers (vLLM, Ollama, llama.cpp, etc.) to receive +// parameters like top_k, repetition_penalty, min_p, etc. that the native +// OpenAI API does not support but these backends do. 
+func applySamplingProviderOpts(params *oai.ChatCompletionNewParams, opts map[string]any) { + if len(opts) == 0 { + return + } + + extras := make(map[string]any) + + for _, key := range providerutil.SamplingProviderOptsKeys() { + if key == "seed" { + // seed is a native ChatCompletionNewParams field (int64), + // so set it directly rather than as an extra field. + if v, ok := providerutil.GetProviderOptInt64(opts, key); ok { + params.Seed = oai.Int(v) + slog.Debug("OpenAI provider_opts: set seed", "value", v) + } + continue + } + + // GetProviderOptFloat64 accepts int, int64, float32 and float64, i.e. + // every type GetProviderOptInt64 accepts, so a single lookup is enough + // (a fallback integer lookup after it could never match). + if v, ok := providerutil.GetProviderOptFloat64(opts, key); ok { + extras[key] = v + slog.Debug("OpenAI provider_opts: forwarding sampling param", "key", key, "value", v) + } + } + + if len(extras) > 0 { + params.SetExtraFields(extras) + } +} diff --git a/pkg/model/provider/openai/sampling_opts_test.go b/pkg/model/provider/openai/sampling_opts_test.go new file mode 100644 index 000000000..87d55afb8 --- /dev/null +++ b/pkg/model/provider/openai/sampling_opts_test.go @@ -0,0 +1,75 @@ +package openai + +import ( + "encoding/json" + "testing" + + oai "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplySamplingProviderOpts(t *testing.T) { + tests := []struct { + name string + opts map[string]any + wantKeys []string // keys expected in JSON output + }{ + { + name: "nil opts", + opts: nil, + }, + { + name: "empty opts", + opts: map[string]any{}, + }, + { + name: "top_k forwarded", + opts: map[string]any{"top_k": 40}, + wantKeys: []string{"top_k"}, + }, + { + name: "repetition_penalty forwarded", + opts: map[string]any{"repetition_penalty": 1.15}, + wantKeys: []string{"repetition_penalty"}, + }, + { + name: "multiple sampling opts", + opts: map[string]any{"top_k": 50, "repetition_penalty": 
1.1, "min_p": 0.05}, + wantKeys: []string{"top_k", "repetition_penalty", "min_p"}, + }, + { + name: "non-sampling opts ignored", + opts: map[string]any{"api_type": "openai_chatcompletions", "transport": "websocket"}, + }, + { + name: "seed set natively", + opts: map[string]any{"seed": 42}, + wantKeys: []string{"seed"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + params := oai.ChatCompletionNewParams{ + Model: "test-model", + } + applySamplingProviderOpts(¶ms, tt.opts) + + // Marshal to JSON and check for expected keys + data, err := json.Marshal(params) + require.NoError(t, err) + + var m map[string]any + require.NoError(t, json.Unmarshal(data, &m)) + + for _, key := range tt.wantKeys { + assert.Contains(t, m, key, "expected key %q in JSON output", key) + } + + // Non-sampling keys should never appear + assert.NotContains(t, m, "api_type") + assert.NotContains(t, m, "transport") + }) + } +} diff --git a/pkg/model/provider/providerutil/provider_opts.go b/pkg/model/provider/providerutil/provider_opts.go new file mode 100644 index 000000000..8235fdd71 --- /dev/null +++ b/pkg/model/provider/providerutil/provider_opts.go @@ -0,0 +1,86 @@ +package providerutil + +import ( + "fmt" + "log/slog" + "math" +) + +// GetProviderOptFloat64 extracts a float64 value from provider opts. +// YAML may parse numbers as float64 or int, so this handles both. 
+func GetProviderOptFloat64(opts map[string]any, key string) (float64, bool) { + if opts == nil { + return 0, false + } + v, ok := opts[key] + if !ok { + return 0, false + } + switch n := v.(type) { + case float64: + return n, true + case float32: + return float64(n), true + case int: + return float64(n), true + case int64: + return float64(n), true + default: + slog.Debug("provider_opts type mismatch, ignoring", + "key", key, + "expected_type", "numeric", + "actual_type", fmt.Sprintf("%T", v), + "value", v) + return 0, false + } +} + +// GetProviderOptInt64 extracts an int64 value from provider opts. +// YAML may parse numbers as float64 or int, so this handles both; floats must be whole and strictly below 2^63. +func GetProviderOptInt64(opts map[string]any, key string) (int64, bool) { + if opts == nil { + return 0, false + } + v, ok := opts[key] + if !ok { + return 0, false + } + switch n := v.(type) { + case int: + return int64(n), true + case int64: + return n, true + case float64: + if n == math.Trunc(n) && n >= math.MinInt64 && n < math.MaxInt64 { + return int64(n), true + } + slog.Debug("provider_opts: float64 value is not a valid integer", + "key", key, "value", v) + return 0, false + default: + slog.Debug("provider_opts type mismatch, ignoring", + "key", key, + "expected_type", "integer", + "actual_type", fmt.Sprintf("%T", v), + "value", v) + return 0, false + } +} + +// samplingProviderOptsKeys lists the provider_opts keys that are +// treated as sampling parameters and forwarded to provider APIs. +// Provider-specific infrastructure keys (api_type, transport, region, etc.) +// are NOT included here. +var samplingProviderOptsKeys = []string{ + "top_k", + "repetition_penalty", + "seed", + "min_p", + "typical_p", +} + +// SamplingProviderOptsKeys returns the list of provider_opts keys that are +// treated as sampling parameters and forwarded to provider APIs. +func SamplingProviderOptsKeys() []string { + return append([]string(nil), samplingProviderOptsKeys...) 
+} diff --git a/pkg/model/provider/providerutil/provider_opts_test.go b/pkg/model/provider/providerutil/provider_opts_test.go new file mode 100644 index 000000000..bbd27d9c8 --- /dev/null +++ b/pkg/model/provider/providerutil/provider_opts_test.go @@ -0,0 +1,62 @@ +package providerutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetProviderOptFloat64(t *testing.T) { + tests := []struct { + name string + opts map[string]any + key string + want float64 + wantOK bool + }{ + {"nil opts", nil, "top_k", 0, false}, + {"missing key", map[string]any{}, "top_k", 0, false}, + {"float64 value", map[string]any{"top_k": 40.0}, "top_k", 40.0, true}, + {"int value", map[string]any{"top_k": 40}, "top_k", 40.0, true}, + {"int64 value", map[string]any{"top_k": int64(40)}, "top_k", 40.0, true}, + {"float32 value", map[string]any{"top_k": float32(40.5)}, "top_k", float64(float32(40.5)), true}, + {"string value", map[string]any{"top_k": "40"}, "top_k", 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := GetProviderOptFloat64(tt.opts, tt.key) + assert.Equal(t, tt.wantOK, ok) + if ok { + assert.InDelta(t, tt.want, got, 0.001) + } + }) + } +} + +func TestGetProviderOptInt64(t *testing.T) { + tests := []struct { + name string + opts map[string]any + key string + want int64 + wantOK bool + }{ + {"nil opts", nil, "seed", 0, false}, + {"missing key", map[string]any{}, "seed", 0, false}, + {"int value", map[string]any{"seed": 42}, "seed", 42, true}, + {"int64 value", map[string]any{"seed": int64(42)}, "seed", 42, true}, + {"float64 whole number", map[string]any{"seed": 42.0}, "seed", 42, true}, + {"float64 fractional", map[string]any{"seed": 42.5}, "seed", 0, false}, + {"string value", map[string]any{"seed": "42"}, "seed", 0, false}, + {"float64 overflow", map[string]any{"seed": 1e19}, "seed", 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := GetProviderOptInt64(tt.opts, 
tt.key) + assert.Equal(t, tt.wantOK, ok) + if ok { + assert.Equal(t, tt.want, got) + } + }) + } +}