Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ agents:
models:
claude:
provider: anthropic
model: claude-sonnet-4-0
model: claude-sonnet-4-5
max_tokens: 64000
```

Expand Down Expand Up @@ -425,7 +425,7 @@ these three providers in order based on the first api key it finds in your
environment.

```sh
export ANTHROPIC_API_KEY=your_api_key_here # first choice. default model claude-sonnet-4-0
export ANTHROPIC_API_KEY=your_api_key_here # first choice. default model claude-sonnet-4-5
export OPENAI_API_KEY=your_api_key_here # if anthropic key not set. default model gpt-5-mini
export GOOGLE_API_KEY=your_api_key_here # if anthropic and openai keys are not set. default model gemini-2.5-flash
```
Expand Down
29 changes: 27 additions & 2 deletions cmd/root/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import (
"github.com/spf13/cobra"

"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/userconfig"
)

const (
flagModelsGateway = "models-gateway"
envModelsGateway = "CAGENT_MODELS_GATEWAY"
envDefaultModel = "CAGENT_DEFAULT_MODEL"
)

func addRuntimeConfigFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) {
Expand Down Expand Up @@ -63,17 +65,29 @@ func addGatewayFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) {

persistentPreRunE := cmd.PersistentPreRunE
cmd.PersistentPreRunE = func(_ *cobra.Command, args []string) error {
userCfg, err := loadUserConfig()
if err != nil {
slog.Warn("Failed to load user config", "error", err)
userCfg = &userconfig.Config{}
}

// Precedence: CLI flag > environment variable > user config
if runConfig.ModelsGateway == "" {
if gateway := os.Getenv(envModelsGateway); gateway != "" {
runConfig.ModelsGateway = gateway
} else if userCfg, err := loadUserConfig(); err == nil && userCfg.ModelsGateway != "" {
} else if userCfg.ModelsGateway != "" {
runConfig.ModelsGateway = userCfg.ModelsGateway
}
}

runConfig.ModelsGateway = canonize(runConfig.ModelsGateway)

// Precedence for default model: environment variable > user config
if model := os.Getenv(envDefaultModel); model != "" {
runConfig.DefaultModel = parseModelShorthand(model)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Silent failure when parsing invalid environment variable

When CAGENT_DEFAULT_MODEL is set to an invalid format (e.g., "anthropic" without /model, or "/model" without provider), parseModelShorthand returns nil silently with no warning or error. This means:

  1. The user's explicit configuration is ignored without any feedback
  2. The system falls back to auto-detection or user config without the user knowing
  3. Users may waste time debugging why their environment variable isn't working

Recommendation: Add a warning log when the environment variable is set but parsing fails:

if model := os.Getenv(envDefaultModel); model != "" {
    runConfig.DefaultModel = parseModelShorthand(model)
    if runConfig.DefaultModel == nil {
        slog.Warn("Invalid CAGENT_DEFAULT_MODEL format, expected 'provider/model'", "value", model)
    }
} else if userCfg.DefaultModel != nil {
    runConfig.DefaultModel = &userCfg.DefaultModel.ModelConfig
}

This gives users immediate feedback when they've misconfigured the environment variable.

} else if userCfg.DefaultModel != nil {
runConfig.DefaultModel = &userCfg.DefaultModel.ModelConfig
}

if err := setupWorkingDirectory(runConfig.WorkingDir); err != nil {
return err
}
Expand All @@ -88,3 +102,14 @@ func addGatewayFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) {
return nil
}
}

// parseModelShorthand parses a "provider/model" shorthand string into a
// ModelConfig. Both the provider (before the first '/') and the model
// (after it) must be non-empty; otherwise nil is returned.
func parseModelShorthand(s string) *latest.ModelConfig {
	provider, model, found := strings.Cut(s, "/")
	if !found || provider == "" || model == "" {
		return nil
	}
	return &latest.ModelConfig{
		Provider: provider,
		Model:    model,
	}
}
80 changes: 80 additions & 0 deletions cmd/root/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/docker/cagent/pkg/config"
"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/userconfig"
)

Expand Down Expand Up @@ -152,3 +153,82 @@ func TestCanonize(t *testing.T) {
})
}
}

// TestDefaultModelLogic verifies the default-model precedence wired up by
// addGatewayFlags: the CAGENT_DEFAULT_MODEL environment variable takes
// priority over the user config, and DefaultModel stays nil when neither
// source provides a value.
func TestDefaultModelLogic(t *testing.T) {
	type testCase struct {
		name             string
		env              string
		userConfig       *userconfig.Config
		expectedProvider string
		expectedModel    string
	}

	cases := []testCase{
		{
			name:             "env",
			env:              "openai/gpt-4o",
			expectedProvider: "openai",
			expectedModel:    "gpt-4o",
		},
		{
			name: "user_config",
			userConfig: &userconfig.Config{
				DefaultModel: &latest.FlexibleModelConfig{
					ModelConfig: latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"},
				},
			},
			expectedProvider: "google",
			expectedModel:    "gemini-2.5-flash",
		},
		{
			name: "env_overrides_user_config",
			env:  "openai/gpt-4o",
			userConfig: &userconfig.Config{
				DefaultModel: &latest.FlexibleModelConfig{
					ModelConfig: latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"},
				},
			},
			expectedProvider: "openai",
			expectedModel:    "gpt-4o",
		},
		{
			name: "empty_when_not_set",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Setenv("CAGENT_DEFAULT_MODEL", tc.env)

			// Stub out the user-config loader for the duration of this subtest.
			saved := loadUserConfig
			loadUserConfig = func() (*userconfig.Config, error) {
				if tc.userConfig == nil {
					return &userconfig.Config{}, nil
				}
				return tc.userConfig, nil
			}
			t.Cleanup(func() { loadUserConfig = saved })

			cmd := &cobra.Command{
				RunE: func(*cobra.Command, []string) error { return nil },
			}
			var runConfig config.RuntimeConfig
			addGatewayFlags(cmd, &runConfig)

			cmd.SetArgs(nil)
			require.NoError(t, cmd.Execute())

			if tc.expectedProvider == "" && tc.expectedModel == "" {
				assert.Nil(t, runConfig.DefaultModel)
				return
			}
			require.NotNil(t, runConfig.DefaultModel)
			assert.Equal(t, tc.expectedProvider, runConfig.DefaultModel.Provider)
			assert.Equal(t, tc.expectedModel, runConfig.DefaultModel.Model)
		})
	}
}
13 changes: 11 additions & 2 deletions pkg/config/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ To fix this, you can:

var DefaultModels = map[string]string{
"openai": "gpt-5-mini",
"anthropic": "claude-sonnet-4-0",
"anthropic": "claude-sonnet-4-5",
"google": "gemini-2.5-flash",
"dmr": "ai/qwen3:latest",
"mistral": "mistral-small-latest",
Expand Down Expand Up @@ -82,7 +82,16 @@ func AvailableProviders(ctx context.Context, modelsGateway string, env environme
return providers
}

func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider) latest.ModelConfig {
func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider, defaultModel *latest.ModelConfig) latest.ModelConfig {
// If user specified a default model config, use it (with defaults for unset fields)
if defaultModel != nil && defaultModel.Provider != "" && defaultModel.Model != "" {
result := *defaultModel
if result.MaxTokens == nil {
result.MaxTokens = PreferredMaxTokens(result.Provider)
}
return result
}

availableProviders := AvailableProviders(ctx, modelsGateway, env)
firstAvailable := availableProviders[0]

Expand Down
111 changes: 105 additions & 6 deletions pkg/config/auto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/docker/cagent/pkg/config/latest"
)

type mockEnvProvider struct {
Expand Down Expand Up @@ -175,7 +177,7 @@ func TestAutoModelConfig(t *testing.T) {
"ANTHROPIC_API_KEY": "test-key",
},
expectedProvider: "anthropic",
expectedModel: "claude-sonnet-4-0",
expectedModel: "claude-sonnet-4-5",
expectedMaxTokens: 32000,
},
{
Expand Down Expand Up @@ -217,7 +219,7 @@ func TestAutoModelConfig(t *testing.T) {
envVars: map[string]string{},
gateway: "gateway:8080",
expectedProvider: "anthropic",
expectedModel: "claude-sonnet-4-0",
expectedModel: "claude-sonnet-4-5",
expectedMaxTokens: 32000,
},
}
Expand All @@ -226,7 +228,7 @@ func TestAutoModelConfig(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

modelConfig := AutoModelConfig(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars})
modelConfig := AutoModelConfig(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars}, nil)

assert.Equal(t, tt.expectedProvider, modelConfig.Provider)
assert.Equal(t, tt.expectedModel, modelConfig.Model)
Expand Down Expand Up @@ -295,7 +297,7 @@ func TestDefaultModels(t *testing.T) {

// Test specific model values
assert.Equal(t, "gpt-5-mini", DefaultModels["openai"])
assert.Equal(t, "claude-sonnet-4-0", DefaultModels["anthropic"])
assert.Equal(t, "claude-sonnet-4-5", DefaultModels["anthropic"])
assert.Equal(t, "gemini-2.5-flash", DefaultModels["google"])
assert.Equal(t, "ai/qwen3:latest", DefaultModels["dmr"])
assert.Equal(t, "mistral-small-latest", DefaultModels["mistral"])
Expand Down Expand Up @@ -326,7 +328,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) {
envVars["MISTRAL_API_KEY"] = "test-key"
}

modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: envVars})
modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: envVars}, nil)

// Verify the returned model matches the DefaultModels entry
expectedModel := DefaultModels[provider]
Expand All @@ -339,7 +341,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) {
t.Run("dmr", func(t *testing.T) {
t.Parallel()

modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}})
modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, nil)

assert.Equal(t, "dmr", modelConfig.Provider)
assert.Equal(t, DefaultModels["dmr"], modelConfig.Model)
Expand Down Expand Up @@ -399,3 +401,100 @@ func TestAvailableProviders_PrecedenceOrder(t *testing.T) {
providers = AvailableProviders(t.Context(), "", env)
assert.Equal(t, "dmr", providers[0])
}

// TestAutoModelConfig_UserDefaultModel verifies that a user-supplied default
// model takes priority over environment-based auto-detection, and that a nil
// or partially-specified default (missing provider or model) falls back to
// auto-detection.
func TestAutoModelConfig_UserDefaultModel(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name          string
		defaultModel  *latest.ModelConfig
		envVars       map[string]string
		wantProvider  string
		wantModel     string
		wantMaxTokens int64
	}

	cases := []testCase{
		{
			name:          "user default model overrides auto detection",
			defaultModel:  &latest.ModelConfig{Provider: "openai", Model: "gpt-4o"},
			envVars:       map[string]string{"ANTHROPIC_API_KEY": "test-key"},
			wantProvider:  "openai",
			wantModel:     "gpt-4o",
			wantMaxTokens: 32000,
		},
		{
			name:          "user default model with dmr provider",
			defaultModel:  &latest.ModelConfig{Provider: "dmr", Model: "ai/llama3.2"},
			envVars:       map[string]string{"OPENAI_API_KEY": "test-key"},
			wantProvider:  "dmr",
			wantModel:     "ai/llama3.2",
			wantMaxTokens: 16000,
		},
		{
			name:          "user default model with anthropic provider",
			defaultModel:  &latest.ModelConfig{Provider: "anthropic", Model: "claude-sonnet-4-0"},
			envVars:       map[string]string{},
			wantProvider:  "anthropic",
			wantModel:     "claude-sonnet-4-0",
			wantMaxTokens: 32000,
		},
		{
			name:          "nil default model falls back to auto detection",
			defaultModel:  nil,
			envVars:       map[string]string{"GOOGLE_API_KEY": "test-key"},
			wantProvider:  "google",
			wantModel:     "gemini-2.5-flash",
			wantMaxTokens: 32000,
		},
		{
			name:          "empty provider falls back to auto detection",
			defaultModel:  &latest.ModelConfig{Provider: "", Model: "model-only"},
			envVars:       map[string]string{"MISTRAL_API_KEY": "test-key"},
			wantProvider:  "mistral",
			wantModel:     "mistral-small-latest",
			wantMaxTokens: 32000,
		},
		{
			name:          "empty model falls back to auto detection",
			defaultModel:  &latest.ModelConfig{Provider: "openai", Model: ""},
			envVars:       map[string]string{"ANTHROPIC_API_KEY": "test-key"},
			wantProvider:  "anthropic",
			wantModel:     "claude-sonnet-4-5",
			wantMaxTokens: 32000,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: tc.envVars}, tc.defaultModel)

			assert.Equal(t, tc.wantProvider, got.Provider)
			assert.Equal(t, tc.wantModel, got.Model)
			assert.Equal(t, tc.wantMaxTokens, *got.MaxTokens)
		})
	}
}

// TestAutoModelConfig_UserDefaultModelWithOptions verifies that optional
// settings on a user-supplied default model (max_tokens, thinking_budget)
// are passed through AutoModelConfig unchanged rather than being overwritten
// by provider defaults.
func TestAutoModelConfig_UserDefaultModelWithOptions(t *testing.T) {
	t.Parallel()

	maxTokens := int64(64000)
	userModel := &latest.ModelConfig{
		Provider:       "anthropic",
		Model:          "claude-sonnet-4-5",
		MaxTokens:      &maxTokens,
		ThinkingBudget: &latest.ThinkingBudget{Tokens: 10000},
	}

	got := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, userModel)

	assert.Equal(t, "anthropic", got.Provider)
	assert.Equal(t, "claude-sonnet-4-5", got.Model)
	assert.Equal(t, int64(64000), *got.MaxTokens)
	assert.NotNil(t, got.ThinkingBudget)
	assert.Equal(t, 10000, got.ThinkingBudget.Tokens)
}
Loading