2 changes: 1 addition & 1 deletion agent-schema.json
@@ -552,7 +552,7 @@
     },
     "provider_opts": {
       "type": "object",
-      "description": "Provider-specific options. dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
+      "description": "Provider-specific options. Sampling parameters: top_k (integer, supported by anthropic, google, amazon-bedrock, and custom OpenAI-compatible providers like vLLM/Ollama), repetition_penalty (float, forwarded to custom OpenAI-compatible providers), min_p (float, forwarded to custom providers), seed (integer, forwarded to OpenAI). Infrastructure options: dmr: runtime_flags. anthropic/amazon-bedrock (Claude): interleaved_thinking (boolean, default true). openai: transport ('sse' or 'websocket') to choose between SSE and WebSocket streaming for the Responses API. openai/anthropic/google: rerank_prompt (string) to fully override the system prompt used for RAG reranking (advanced - prefer using results.reranking.criteria for domain-specific guidance).",
       "additionalProperties": true
     },
     "track_usage": {
22 changes: 22 additions & 0 deletions examples/sampling-opts.yaml
@@ -0,0 +1,22 @@
#!/usr/bin/env docker agent run

# This example shows how to use provider_opts to pass sampling parameters
# such as top_k, repetition_penalty, and min_p through to the model provider.

agents:
  root:
    model: gpt
    description: "Assistant with custom sampling parameters"
    instruction: |
      You are a helpful assistant running on a model with tuned sampling parameters.

models:
  gpt:
    provider: openai
    model: gpt-4o
    temperature: 0.7
    top_p: 0.9
    provider_opts:
      top_k: 40
      repetition_penalty: 1.15
      min_p: 0.05
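
With the configuration above, temperature and top_p travel as native OpenAI fields, while the provider_opts entries are forwarded by the new applySamplingProviderOpts helper (see pkg/model/provider/openai/sampling_opts.go below). A rough sketch of the resulting chat-completions body, assuming an OpenAI-compatible backend; illustrative only, since the exact serialization is up to the SDK:

// Illustrative request-body fragment for the example config above.
const bodySketch = `{
  "model": "gpt-4o",
  "temperature": 0.7,
  "top_p": 0.9,
  "top_k": 40,
  "repetition_penalty": 1.15,
  "min_p": 0.05
}`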
13 changes: 13 additions & 0 deletions pkg/model/provider/anthropic/beta_client.go
@@ -13,6 +13,7 @@ import (
 	"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
 
 	"github.com/docker/docker-agent/pkg/chat"
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
 	"github.com/docker/docker-agent/pkg/rag/prompts"
 	"github.com/docker/docker-agent/pkg/rag/types"
 	"github.com/docker/docker-agent/pkg/tools"
@@ -115,6 +116,12 @@ func (c *Client) createBetaStream(
 		"max_tokens", maxTokens,
 		"message_count", len(params.Messages))
 
+	// Forward top_k from provider_opts (Anthropic natively supports it)
+	if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		params.TopK = param.NewOpt(topK)
+		slog.Debug("Anthropic Beta provider_opts: set top_k", "value", topK)
+	}
+
 	stream := client.Beta.Messages.NewStreaming(ctx, params)
 	trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
 	ad := c.newBetaStreamAdapter(stream, trackUsage)
@@ -293,6 +300,12 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc
 		params.TopP = param.NewOpt(*c.ModelConfig.TopP)
 	}
 
+	// Forward top_k from provider_opts (Anthropic natively supports it)
+	if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		params.TopK = param.NewOpt(topK)
+		slog.Debug("Anthropic Beta provider_opts: set top_k", "value", topK)
+	}
+
 	// Use streaming API to avoid timeout errors for operations that may take longer than 10 minutes
 	stream := client.Beta.Messages.NewStreaming(ctx, params)
 
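
The providerutil helpers used above are referenced but not included in this diff. A plausible sketch of GetProviderOptInt64, assuming provider_opts values arrive as YAML/JSON-decoded any values (the name is real; the coercion rules here are an assumption):

package providerutil

// GetProviderOptInt64 reads an integer-valued option from provider_opts,
// tolerating the float64 values that JSON decoding commonly produces.
// Sketch only: the actual implementation is not part of this diff.
func GetProviderOptInt64(opts map[string]any, key string) (int64, bool) {
	v, ok := opts[key]
	if !ok {
		return 0, false
	}
	switch n := v.(type) {
	case int:
		return int64(n), true
	case int64:
		return n, true
	case float64:
		if n == float64(int64(n)) { // accept 40.0, reject 40.5
			return int64(n), true
		}
	}
	return 0, false
}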
7 changes: 7 additions & 0 deletions pkg/model/provider/anthropic/client.go
@@ -24,6 +24,7 @@ import (
 	"github.com/docker/docker-agent/pkg/httpclient"
 	"github.com/docker/docker-agent/pkg/model/provider/base"
 	"github.com/docker/docker-agent/pkg/model/provider/options"
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
 	"github.com/docker/docker-agent/pkg/tools"
 )

@@ -337,6 +338,12 @@ func (c *Client) CreateChatCompletionStream(
 		slog.Debug("Anthropic extended thinking enabled, ignoring temperature/top_p settings")
 	}
 
+	// Forward top_k from provider_opts (Anthropic natively supports it)
+	if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		params.TopK = param.NewOpt(topK)
+		slog.Debug("Anthropic provider_opts: set top_k", "value", topK)
+	}
+
 	if len(requestTools) > 0 {
 		slog.Debug("Adding tools to Anthropic request", "tool_count", len(requestTools))
 	}
112 changes: 77 additions & 35 deletions pkg/model/provider/bedrock/client.go
@@ -20,6 +20,7 @@ import (
 	"github.com/docker/docker-agent/pkg/environment"
 	"github.com/docker/docker-agent/pkg/model/provider/base"
 	"github.com/docker/docker-agent/pkg/model/provider/options"
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
 	"github.com/docker/docker-agent/pkg/modelsdev"
 	"github.com/docker/docker-agent/pkg/tools"
 )
@@ -244,7 +245,7 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools
 	}
 
 	// Set inference configuration (temp/topP are suppressed when thinking is on).
-	input.InferenceConfig = c.buildInferenceConfig(additionalFields != nil)
+	input.InferenceConfig = c.buildInferenceConfig(c.isThinkingEnabled())
 
 	// Convert and set tools
 	if len(requestTools) > 0 {
@@ -278,59 +279,100 @@ func (c *Client) buildInferenceConfig(thinkingEnabled bool) *types.InferenceConf
 }
 
 func (c *Client) interleavedThinkingEnabled() bool {
-	return getProviderOpt[bool](c.ModelConfig.ProviderOpts, "interleaved_thinking")
-}
-
-func (c *Client) promptCachingEnabled() bool {
-	if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") {
-		return false
+	// Default to true, matching the documented schema behavior.
+	v, ok := c.ModelConfig.ProviderOpts["interleaved_thinking"]
+	if !ok {
+		return true
 	}
-	return c.cachingSupported
+	b, ok := v.(bool)
+	if !ok {
+		slog.Warn("Bedrock provider_opts type mismatch",
+			"key", "interleaved_thinking",
+			"expected_type", "bool",
+			"actual_type", fmt.Sprintf("%T", v),
+			"value", v)
+		return true
+	}
+	return b
 }
 
-// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode.
-func (c *Client) buildAdditionalModelRequestFields() document.Interface {
+// isThinkingEnabled returns true if a valid thinking budget is configured.
+// It mirrors the validation in buildAdditionalModelRequestFields but without
+// side effects (no logging), so it can safely be used to gate inference config.
+func (c *Client) isThinkingEnabled() bool {
 	if c.ModelConfig.ThinkingBudget == nil {
-		return nil
+		return false
 	}
 	tokens := c.ModelConfig.ThinkingBudget.Tokens
 	if t, ok := c.ModelConfig.ThinkingBudget.EffortTokens(); ok {
 		tokens = t
 	}
 	if tokens <= 0 {
-		return nil
-	}
-
-	// Validate minimum (Claude requires at least 1024 tokens for thinking)
-	if tokens < 1024 {
-		slog.Warn("Bedrock thinking_budget below minimum (1024), ignoring",
-			"tokens", tokens)
-		return nil
+		return false
 	}
-
-	// Validate against max_tokens
-	if c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) {
-		slog.Warn("Bedrock thinking_budget must be less than max_tokens, ignoring",
-			"thinking_budget", tokens,
-			"max_tokens", *c.ModelConfig.MaxTokens)
-		return nil
+	if tokens < 1024 {
+		return false
 	}
-
-	slog.Debug("Bedrock request using thinking_budget", "budget_tokens", tokens)
+	if c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) {
+		return false
+	}
+	return true
+}
 
-	fields := map[string]any{
-		"thinking": map[string]any{
-			"type":          "enabled",
-			"budget_tokens": tokens,
-		},
+func (c *Client) promptCachingEnabled() bool {
+	if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") {
+		return false
 	}
+	return c.cachingSupported
+}
 
-	// Add anthropic_beta field for interleaved thinking
-	if c.interleavedThinkingEnabled() {
-		fields["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"}
-		slog.Debug("Bedrock request using interleaved thinking beta")
+// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode
+// and forwards supported sampling parameters from provider_opts (e.g. top_k).
+func (c *Client) buildAdditionalModelRequestFields() document.Interface {
+	fields := map[string]any{}
+
+	// Forward top_k from provider_opts (Anthropic on Bedrock supports it)
+	if topK, ok := providerutil.GetProviderOptInt64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		fields["top_k"] = topK
+		slog.Debug("Bedrock provider_opts: set top_k", "value", topK)
+	}
+
+	// Configure thinking budget if present and valid
+	if budget := c.ModelConfig.ThinkingBudget; budget != nil {
+		tokens := budget.Tokens
+		if t, ok := budget.EffortTokens(); ok {
+			tokens = t
+		}
+
+		valid := tokens > 0
+		if valid && tokens < 1024 {
+			slog.Warn("Bedrock thinking_budget below minimum (1024), ignoring", "tokens", tokens)
+			valid = false
+		}
+		if valid && c.ModelConfig.MaxTokens != nil && tokens >= int(*c.ModelConfig.MaxTokens) {
+			slog.Warn("Bedrock thinking_budget must be less than max_tokens, ignoring",
+				"thinking_budget", tokens,
+				"max_tokens", *c.ModelConfig.MaxTokens)
+			valid = false
+		}
+
+		if valid {
+			slog.Debug("Bedrock request using thinking_budget", "budget_tokens", tokens)
+			fields["thinking"] = map[string]any{
+				"type":          "enabled",
+				"budget_tokens": tokens,
+			}
+
+			if c.interleavedThinkingEnabled() {
+				fields["anthropic_beta"] = []string{"interleaved-thinking-2025-05-14"}
+				slog.Debug("Bedrock request using interleaved thinking beta")
+			} else {
+				slog.Warn("Bedrock thinking_budget is set but interleaved_thinking is explicitly disabled; " +
					"the anthropic_beta header will not be sent, which may cause the thinking budget to be ignored")
+			}
+		}
+	}
+
+	if len(fields) == 0 {
+		return nil
+	}
 	return document.NewLazyDocument(fields)
 }

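To make the combined behavior concrete: for provider_opts {top_k: 40}, a 2048-token thinking budget, and interleaved_thinking left at its new default of true, buildAdditionalModelRequestFields produces a document equivalent to the following map (values illustrative):

fields := map[string]any{
	"top_k": int64(40),
	"thinking": map[string]any{
		"type":          "enabled",
		"budget_tokens": 2048,
	},
	"anthropic_beta": []string{"interleaved-thinking-2025-05-14"},
}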
4 changes: 2 additions & 2 deletions pkg/model/provider/bedrock/client_test.go
@@ -854,7 +854,7 @@ func TestInterleavedThinkingEnabled_NotSet(t *testing.T) {
 		},
 	}
 
-	assert.False(t, client.interleavedThinkingEnabled())
+	assert.True(t, client.interleavedThinkingEnabled())
 }
 
 func TestInterleavedThinkingEnabled_NilProviderOpts(t *testing.T) {
@@ -870,7 +870,7 @@ func TestInterleavedThinkingEnabled_NilProviderOpts(t *testing.T) {
 		},
 	}
 
-	assert.False(t, client.interleavedThinkingEnabled())
+	assert.True(t, client.interleavedThinkingEnabled())
 }
 
 func TestBuildAdditionalModelRequestFields_WithInterleavedThinking(t *testing.T) {
7 changes: 7 additions & 0 deletions pkg/model/provider/gemini/client.go
@@ -21,6 +21,7 @@ import (
 	"github.com/docker/docker-agent/pkg/httpclient"
 	"github.com/docker/docker-agent/pkg/model/provider/base"
 	"github.com/docker/docker-agent/pkg/model/provider/options"
+	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
 	"github.com/docker/docker-agent/pkg/rag/prompts"
 	"github.com/docker/docker-agent/pkg/rag/types"
 	"github.com/docker/docker-agent/pkg/tools"
@@ -352,6 +353,12 @@ func (c *Client) buildConfig() *genai.GenerateContentConfig {
 		config.PresencePenalty = new(float32(*c.ModelConfig.PresencePenalty))
 	}
 
+	// Forward top_k from provider_opts (Gemini natively supports it)
+	if topK, ok := providerutil.GetProviderOptFloat64(c.ModelConfig.ProviderOpts, "top_k"); ok {
+		config.TopK = new(float32(topK))
+		slog.Debug("Gemini provider_opts: set top_k", "value", topK)
+	}
+
 	// Apply thinking configuration for Gemini models.
 	// See https://ai.google.dev/gemini-api/docs/thinking
 	if c.ModelOptions.NoThinking() {
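Note the type difference: Gemini's TopK is a float32 pointer, so the float helper is used here, and it must accept YAML integers like top_k: 40 to be useful. A sketch of GetProviderOptFloat64 consistent with that usage (the coercion rules are an assumption, not the actual source):

package providerutil

// GetProviderOptFloat64 reads a numeric option from provider_opts, accepting
// both float and integer config values. Sketch only; not part of this diff.
func GetProviderOptFloat64(opts map[string]any, key string) (float64, bool) {
	switch n := opts[key].(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int:
		return float64(n), true
	case int64:
		return float64(n), true
	}
	return 0, false
}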
7 changes: 7 additions & 0 deletions pkg/model/provider/openai/client.go
@@ -312,6 +312,11 @@ func (c *Client) CreateChatCompletionStream(
 		return nil, err
 	}
 
+	// Forward sampling-related provider_opts as extra body fields.
+	// This allows custom/OpenAI-compatible providers (vLLM, Ollama, etc.)
+	// to receive parameters like top_k, repetition_penalty, etc.
+	applySamplingProviderOpts(&params, c.ModelConfig.ProviderOpts)
+
 	stream := client.Chat.Completions.NewStreaming(ctx, params)
 
 	slog.Debug("OpenAI chat completion stream created successfully", "model", c.ModelConfig.Model)
@@ -842,6 +847,8 @@ func (c *Client) Rerank(ctx context.Context, query string, documents []types.Doc
 		},
 	}
 
+	applySamplingProviderOpts(&params, c.ModelConfig.ProviderOpts)
+
 	resp, err := client.Chat.Completions.New(ctx, params)
 	if err != nil {
 		slog.Error("OpenAI rerank request failed", "error", err)
46 changes: 46 additions & 0 deletions pkg/model/provider/openai/sampling_opts.go
@@ -0,0 +1,46 @@
package openai

import (
	"log/slog"

	oai "github.com/openai/openai-go/v3"

	"github.com/docker/docker-agent/pkg/model/provider/providerutil"
)

// applySamplingProviderOpts forwards sampling-related provider_opts as extra
// body fields on the OpenAI ChatCompletionNewParams. This enables custom
// OpenAI-compatible providers (vLLM, Ollama, llama.cpp, etc.) to receive
// parameters like top_k, repetition_penalty, min_p, etc. that the native
// OpenAI API does not support but these backends do.
func applySamplingProviderOpts(params *oai.ChatCompletionNewParams, opts map[string]any) {
	if len(opts) == 0 {
		return
	}

	extras := make(map[string]any)

	for _, key := range providerutil.SamplingProviderOptsKeys() {
		if key == "seed" {
			// seed is a native ChatCompletionNewParams field (int64),
			// so set it directly rather than as an extra field.
			if v, ok := providerutil.GetProviderOptInt64(opts, key); ok {
				params.Seed = oai.Int(v)
				slog.Debug("OpenAI provider_opts: set seed", "value", v)
			}
			continue
		}

		if v, ok := providerutil.GetProviderOptFloat64(opts, key); ok {
			extras[key] = v
			slog.Debug("OpenAI provider_opts: forwarding sampling param", "key", key, "value", v)
		} else if vi, ok := providerutil.GetProviderOptInt64(opts, key); ok {
			extras[key] = vi
			slog.Debug("OpenAI provider_opts: forwarding sampling param", "key", key, "value", vi)
		}
	}

	if len(extras) > 0 {
		params.SetExtraFields(extras)
	}
}
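
SamplingProviderOptsKeys is the remaining providerutil symbol this file relies on that the diff does not show. A sketch inferred from the updated schema description (the real list may differ):

package providerutil

// SamplingProviderOptsKeys lists the provider_opts keys treated as sampling
// parameters; "seed" is special-cased by the OpenAI client above.
// Inferred from the schema description in this PR, not the actual source.
func SamplingProviderOptsKeys() []string {
	return []string{"top_k", "repetition_penalty", "min_p", "seed"}
}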