diff --git a/README.md b/README.md index 1697b98bb..2bc551076 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ agents: models: claude: provider: anthropic - model: claude-sonnet-4-0 + model: claude-sonnet-4-5 max_tokens: 64000 ``` @@ -425,7 +425,7 @@ these three providers in order based on the first api key it finds in your environment. ```sh -export ANTHROPIC_API_KEY=your_api_key_here # first choice. default model claude-sonnet-4-0 +export ANTHROPIC_API_KEY=your_api_key_here # first choice. default model claude-sonnet-4-5 export OPENAI_API_KEY=your_api_key_here # if anthropic key not set. default model gpt-5-mini export GOOGLE_API_KEY=your_api_key_here # if anthropic and openai keys are not set. default model gemini-2.5-flash ``` diff --git a/cmd/root/flags.go b/cmd/root/flags.go index 696f67532..8a88e1ec2 100644 --- a/cmd/root/flags.go +++ b/cmd/root/flags.go @@ -10,12 +10,14 @@ import ( "github.com/spf13/cobra" "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/userconfig" ) const ( flagModelsGateway = "models-gateway" envModelsGateway = "CAGENT_MODELS_GATEWAY" + envDefaultModel = "CAGENT_DEFAULT_MODEL" ) func addRuntimeConfigFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) { @@ -63,17 +65,29 @@ func addGatewayFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) { persistentPreRunE := cmd.PersistentPreRunE cmd.PersistentPreRunE = func(_ *cobra.Command, args []string) error { + userCfg, err := loadUserConfig() + if err != nil { + slog.Warn("Failed to load user config", "error", err) + userCfg = &userconfig.Config{} + } + // Precedence: CLI flag > environment variable > user config if runConfig.ModelsGateway == "" { if gateway := os.Getenv(envModelsGateway); gateway != "" { runConfig.ModelsGateway = gateway - } else if userCfg, err := loadUserConfig(); err == nil && userCfg.ModelsGateway != "" { + } else if userCfg.ModelsGateway != "" { runConfig.ModelsGateway = 
userCfg.ModelsGateway } } runConfig.ModelsGateway = canonize(runConfig.ModelsGateway) + // Precedence for default model: environment variable > user config + if model := os.Getenv(envDefaultModel); model != "" { + runConfig.DefaultModel = parseModelShorthand(model) + } else if userCfg.DefaultModel != nil { + runConfig.DefaultModel = &userCfg.DefaultModel.ModelConfig + } + if err := setupWorkingDirectory(runConfig.WorkingDir); err != nil { return err } @@ -88,3 +102,14 @@ func addGatewayFlags(cmd *cobra.Command, runConfig *config.RuntimeConfig) { return nil } } + +// parseModelShorthand parses "provider/model" into a ModelConfig +func parseModelShorthand(s string) *latest.ModelConfig { + if idx := strings.Index(s, "/"); idx > 0 && idx < len(s)-1 { + return &latest.ModelConfig{ + Provider: s[:idx], + Model: s[idx+1:], + } + } + return nil +} diff --git a/cmd/root/flags_test.go b/cmd/root/flags_test.go index b2a9c330e..e405079c2 100644 --- a/cmd/root/flags_test.go +++ b/cmd/root/flags_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/docker/cagent/pkg/config" + "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/userconfig" ) @@ -152,3 +153,82 @@ func TestCanonize(t *testing.T) { }) } } + +func TestDefaultModelLogic(t *testing.T) { + tests := []struct { + name string + env string + userConfig *userconfig.Config + expectedProvider string + expectedModel string + }{ + { + name: "env", + env: "openai/gpt-4o", + expectedProvider: "openai", + expectedModel: "gpt-4o", + }, + { + name: "user_config", + userConfig: &userconfig.Config{ + DefaultModel: &latest.FlexibleModelConfig{ + ModelConfig: latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"}, + }, + }, + expectedProvider: "google", + expectedModel: "gemini-2.5-flash", + }, + { + name: "env_overrides_user_config", + env: "openai/gpt-4o", + userConfig: &userconfig.Config{ + DefaultModel: &latest.FlexibleModelConfig{ + ModelConfig: 
latest.ModelConfig{Provider: "google", Model: "gemini-2.5-flash"}, + }, + }, + expectedProvider: "openai", + expectedModel: "gpt-4o", + }, + { + name: "empty_when_not_set", + expectedProvider: "", + expectedModel: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("CAGENT_DEFAULT_MODEL", tt.env) + + // Mock user config loader + original := loadUserConfig + loadUserConfig = func() (*userconfig.Config, error) { + if tt.userConfig != nil { + return tt.userConfig, nil + } + return &userconfig.Config{}, nil + } + t.Cleanup(func() { loadUserConfig = original }) + + cmd := &cobra.Command{ + RunE: func(*cobra.Command, []string) error { + return nil + }, + } + runConfig := config.RuntimeConfig{} + addGatewayFlags(cmd, &runConfig) + + cmd.SetArgs(nil) + err := cmd.Execute() + + require.NoError(t, err) + if tt.expectedProvider == "" && tt.expectedModel == "" { + assert.Nil(t, runConfig.DefaultModel) + } else { + require.NotNil(t, runConfig.DefaultModel) + assert.Equal(t, tt.expectedProvider, runConfig.DefaultModel.Provider) + assert.Equal(t, tt.expectedModel, runConfig.DefaultModel.Model) + } + }) + } +} diff --git a/pkg/config/auto.go b/pkg/config/auto.go index 99d21a31c..5b206aa95 100644 --- a/pkg/config/auto.go +++ b/pkg/config/auto.go @@ -52,7 +52,7 @@ To fix this, you can: var DefaultModels = map[string]string{ "openai": "gpt-5-mini", - "anthropic": "claude-sonnet-4-0", + "anthropic": "claude-sonnet-4-5", "google": "gemini-2.5-flash", "dmr": "ai/qwen3:latest", "mistral": "mistral-small-latest", @@ -82,7 +82,16 @@ func AvailableProviders(ctx context.Context, modelsGateway string, env environme return providers } -func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider) latest.ModelConfig { +func AutoModelConfig(ctx context.Context, modelsGateway string, env environment.Provider, defaultModel *latest.ModelConfig) latest.ModelConfig { + // If user specified a default model config, use it (with 
defaults for unset fields) + if defaultModel != nil && defaultModel.Provider != "" && defaultModel.Model != "" { + result := *defaultModel + if result.MaxTokens == nil { + result.MaxTokens = PreferredMaxTokens(result.Provider) + } + return result + } + availableProviders := AvailableProviders(ctx, modelsGateway, env) firstAvailable := availableProviders[0] diff --git a/pkg/config/auto_test.go b/pkg/config/auto_test.go index 4f3fba94e..3c98c07c0 100644 --- a/pkg/config/auto_test.go +++ b/pkg/config/auto_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/docker/cagent/pkg/config/latest" ) type mockEnvProvider struct { @@ -175,7 +177,7 @@ func TestAutoModelConfig(t *testing.T) { "ANTHROPIC_API_KEY": "test-key", }, expectedProvider: "anthropic", - expectedModel: "claude-sonnet-4-0", + expectedModel: "claude-sonnet-4-5", expectedMaxTokens: 32000, }, { @@ -217,7 +219,7 @@ func TestAutoModelConfig(t *testing.T) { envVars: map[string]string{}, gateway: "gateway:8080", expectedProvider: "anthropic", - expectedModel: "claude-sonnet-4-0", + expectedModel: "claude-sonnet-4-5", expectedMaxTokens: 32000, }, } @@ -226,7 +228,7 @@ func TestAutoModelConfig(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - modelConfig := AutoModelConfig(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars}) + modelConfig := AutoModelConfig(t.Context(), tt.gateway, &mockEnvProvider{envVars: tt.envVars}, nil) assert.Equal(t, tt.expectedProvider, modelConfig.Provider) assert.Equal(t, tt.expectedModel, modelConfig.Model) @@ -295,7 +297,7 @@ func TestDefaultModels(t *testing.T) { // Test specific model values assert.Equal(t, "gpt-5-mini", DefaultModels["openai"]) - assert.Equal(t, "claude-sonnet-4-0", DefaultModels["anthropic"]) + assert.Equal(t, "claude-sonnet-4-5", DefaultModels["anthropic"]) assert.Equal(t, "gemini-2.5-flash", DefaultModels["google"]) assert.Equal(t, "ai/qwen3:latest", DefaultModels["dmr"]) assert.Equal(t, 
"mistral-small-latest", DefaultModels["mistral"]) @@ -326,7 +328,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) { envVars["MISTRAL_API_KEY"] = "test-key" } - modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: envVars}) + modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: envVars}, nil) // Verify the returned model matches the DefaultModels entry expectedModel := DefaultModels[provider] @@ -339,7 +341,7 @@ func TestAutoModelConfig_IntegrationWithDefaultModels(t *testing.T) { t.Run("dmr", func(t *testing.T) { t.Parallel() - modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}) + modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, nil) assert.Equal(t, "dmr", modelConfig.Provider) assert.Equal(t, DefaultModels["dmr"], modelConfig.Model) @@ -399,3 +401,100 @@ func TestAvailableProviders_PrecedenceOrder(t *testing.T) { providers = AvailableProviders(t.Context(), "", env) assert.Equal(t, "dmr", providers[0]) } + +func TestAutoModelConfig_UserDefaultModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + defaultModel *latest.ModelConfig + envVars map[string]string + expectedProvider string + expectedModel string + expectedMaxTokens int64 + }{ + { + name: "user default model overrides auto detection", + defaultModel: &latest.ModelConfig{Provider: "openai", Model: "gpt-4o"}, + envVars: map[string]string{"ANTHROPIC_API_KEY": "test-key"}, + expectedProvider: "openai", + expectedModel: "gpt-4o", + expectedMaxTokens: 32000, + }, + { + name: "user default model with dmr provider", + defaultModel: &latest.ModelConfig{Provider: "dmr", Model: "ai/llama3.2"}, + envVars: map[string]string{"OPENAI_API_KEY": "test-key"}, + expectedProvider: "dmr", + expectedModel: "ai/llama3.2", + expectedMaxTokens: 16000, + }, + { + name: "user default model with anthropic provider", + defaultModel: 
&latest.ModelConfig{Provider: "anthropic", Model: "claude-sonnet-4-0"}, + envVars: map[string]string{}, + expectedProvider: "anthropic", + expectedModel: "claude-sonnet-4-0", + expectedMaxTokens: 32000, + }, + { + name: "nil default model falls back to auto detection", + defaultModel: nil, + envVars: map[string]string{"GOOGLE_API_KEY": "test-key"}, + expectedProvider: "google", + expectedModel: "gemini-2.5-flash", + expectedMaxTokens: 32000, + }, + { + name: "empty provider falls back to auto detection", + defaultModel: &latest.ModelConfig{Provider: "", Model: "model-only"}, + envVars: map[string]string{"MISTRAL_API_KEY": "test-key"}, + expectedProvider: "mistral", + expectedModel: "mistral-small-latest", + expectedMaxTokens: 32000, + }, + { + name: "empty model falls back to auto detection", + defaultModel: &latest.ModelConfig{Provider: "openai", Model: ""}, + envVars: map[string]string{"ANTHROPIC_API_KEY": "test-key"}, + expectedProvider: "anthropic", + expectedModel: "claude-sonnet-4-5", + expectedMaxTokens: 32000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: tt.envVars}, tt.defaultModel) + + assert.Equal(t, tt.expectedProvider, modelConfig.Provider) + assert.Equal(t, tt.expectedModel, modelConfig.Model) + assert.Equal(t, tt.expectedMaxTokens, *modelConfig.MaxTokens) + }) + } +} + +func TestAutoModelConfig_UserDefaultModelWithOptions(t *testing.T) { + t.Parallel() + + // Test that user-provided options like max_tokens, thinking_budget are preserved + customMaxTokens := int64(64000) + thinkingBudget := &latest.ThinkingBudget{Tokens: 10000} + + defaultModel := &latest.ModelConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-5", + MaxTokens: &customMaxTokens, + ThinkingBudget: thinkingBudget, + } + + modelConfig := AutoModelConfig(t.Context(), "", &mockEnvProvider{envVars: map[string]string{}}, defaultModel) + + assert.Equal(t, 
"anthropic", modelConfig.Provider) + assert.Equal(t, "claude-sonnet-4-5", modelConfig.Model) + assert.Equal(t, int64(64000), *modelConfig.MaxTokens) + assert.NotNil(t, modelConfig.ThinkingBudget) + assert.Equal(t, 10000, modelConfig.ThinkingBudget.Tokens) +} diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index aab024847..0ec762d33 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -3,6 +3,7 @@ package latest import ( "encoding/json" "fmt" + "strings" "github.com/goccy/go-yaml" @@ -159,6 +160,61 @@ type ModelConfig struct { Routing []RoutingRule `json:"routing,omitempty"` } +// FlexibleModelConfig wraps ModelConfig to support both shorthand and full syntax. +// It can be unmarshaled from either: +// - A shorthand string: "provider/model" (e.g., "anthropic/claude-sonnet-4-5") +// - A full model definition with all options +type FlexibleModelConfig struct { + ModelConfig +} + +// UnmarshalYAML implements custom unmarshaling for flexible model config +func (f *FlexibleModelConfig) UnmarshalYAML(unmarshal func(any) error) error { + // Try string shorthand first + var shorthand string + if err := unmarshal(&shorthand); err == nil && shorthand != "" { + provider, model, ok := strings.Cut(shorthand, "/") + if !ok || provider == "" || model == "" { + return fmt.Errorf("invalid model shorthand %q: expected format 'provider/model'", shorthand) + } + f.Provider = provider + f.Model = model + return nil + } + + // Try full model config + var cfg ModelConfig + if err := unmarshal(&cfg); err != nil { + return err + } + f.ModelConfig = cfg + return nil +} + +// MarshalYAML outputs shorthand format if only provider/model are set +func (f FlexibleModelConfig) MarshalYAML() ([]byte, error) { + if f.isShorthandOnly() { + return yaml.Marshal(f.Provider + "/" + f.Model) + } + return yaml.Marshal(f.ModelConfig) +} + +// isShorthandOnly returns true if only provider and model are set +func (f *FlexibleModelConfig) isShorthandOnly() bool { + 
return f.Temperature == nil && + f.MaxTokens == nil && + f.TopP == nil && + f.FrequencyPenalty == nil && + f.PresencePenalty == nil && + f.BaseURL == "" && + f.ParallelToolCalls == nil && + f.TokenKey == "" && + len(f.ProviderOpts) == 0 && + f.TrackUsage == nil && + f.ThinkingBudget == nil && + len(f.Routing) == 0 +} + // RoutingRule defines a single routing rule for model selection. // Each rule maps example phrases to a target model. type RoutingRule struct { diff --git a/pkg/config/runtime.go b/pkg/config/runtime.go index dd64bcbb4..3d0b2c29e 100644 --- a/pkg/config/runtime.go +++ b/pkg/config/runtime.go @@ -4,6 +4,7 @@ import ( "log/slog" "sync" + "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/environment" ) @@ -18,6 +19,7 @@ type RuntimeConfig struct { type Config struct { EnvFiles []string ModelsGateway string + DefaultModel *latest.ModelConfig GlobalCodeMode bool WorkingDir string } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 06a6bd982..7e5461603 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -132,7 +132,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c agentsByName := make(map[string]*agent.Agent) autoModel := sync.OnceValue(func() latest.ModelConfig { - return config.AutoModelConfig(ctx, runConfig.ModelsGateway, env) + return config.AutoModelConfig(ctx, runConfig.ModelsGateway, env, runConfig.DefaultModel) }) expander := js.NewJsExpander(env) diff --git a/pkg/userconfig/userconfig.go b/pkg/userconfig/userconfig.go index 7c320f508..cfcd11a71 100644 --- a/pkg/userconfig/userconfig.go +++ b/pkg/userconfig/userconfig.go @@ -16,6 +16,7 @@ import ( "github.com/goccy/go-yaml" "github.com/natefinch/atomic" + "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/paths" ) @@ -67,6 +68,9 @@ type Config struct { Version string `yaml:"version,omitempty"` // ModelsGateway is the default models gateway URL ModelsGateway 
string `yaml:"models_gateway,omitempty"` + // DefaultModel is the default model to use when model is set to "auto". + // Supports both shorthand ("provider/model") and full model definition. + DefaultModel *latest.FlexibleModelConfig `yaml:"default_model,omitempty"` // Aliases maps alias names to alias configurations Aliases map[string]*Alias `yaml:"aliases,omitempty"` // Settings contains global user settings diff --git a/pkg/userconfig/userconfig_test.go b/pkg/userconfig/userconfig_test.go index 8efddeaf9..5ea2d777f 100644 --- a/pkg/userconfig/userconfig_test.go +++ b/pkg/userconfig/userconfig_test.go @@ -5,8 +5,11 @@ import ( "path/filepath" "testing" + "github.com/goccy/go-yaml" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/config/latest" ) func TestConfig_Empty(t *testing.T) { @@ -625,3 +628,157 @@ func TestConfig_CredentialHelper_Empty(t *testing.T) { assert.Nil(t, config.CredentialHelper) } + +func TestDefaultModelConfig_Shorthand(t *testing.T) { + t.Parallel() + + yamlContent := `default_model: anthropic/claude-sonnet-4-5` + + var config Config + err := yaml.Unmarshal([]byte(yamlContent), &config) + require.NoError(t, err) + + require.NotNil(t, config.DefaultModel) + assert.Equal(t, "anthropic", config.DefaultModel.Provider) + assert.Equal(t, "claude-sonnet-4-5", config.DefaultModel.Model) + assert.Nil(t, config.DefaultModel.MaxTokens) +} + +func TestDefaultModelConfig_FullDefinition(t *testing.T) { + t.Parallel() + + yamlContent := `default_model: + provider: anthropic + model: claude-sonnet-4-5 + max_tokens: 64000 + thinking_budget: 10000` + + var config Config + err := yaml.Unmarshal([]byte(yamlContent), &config) + require.NoError(t, err) + + require.NotNil(t, config.DefaultModel) + assert.Equal(t, "anthropic", config.DefaultModel.Provider) + assert.Equal(t, "claude-sonnet-4-5", config.DefaultModel.Model) + require.NotNil(t, config.DefaultModel.MaxTokens) + assert.Equal(t, int64(64000), 
*config.DefaultModel.MaxTokens) + require.NotNil(t, config.DefaultModel.ThinkingBudget) + assert.Equal(t, 10000, config.DefaultModel.ThinkingBudget.Tokens) +} + +func TestDefaultModelConfig_FullDefinitionWithEffort(t *testing.T) { + t.Parallel() + + yamlContent := `default_model: + provider: openai + model: o1 + thinking_budget: high` + + var config Config + err := yaml.Unmarshal([]byte(yamlContent), &config) + require.NoError(t, err) + + require.NotNil(t, config.DefaultModel) + assert.Equal(t, "openai", config.DefaultModel.Provider) + assert.Equal(t, "o1", config.DefaultModel.Model) + require.NotNil(t, config.DefaultModel.ThinkingBudget) + assert.Equal(t, "high", config.DefaultModel.ThinkingBudget.Effort) +} + +func TestDefaultModelConfig_Marshal_ShorthandOutput(t *testing.T) { + t.Parallel() + + config := &latest.FlexibleModelConfig{ + ModelConfig: latest.ModelConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-5", + }, + } + + data, err := yaml.Marshal(config) + require.NoError(t, err) + + // Should output shorthand format when only provider/model are set + assert.Equal(t, "anthropic/claude-sonnet-4-5\n", string(data)) +} + +func TestDefaultModelConfig_Marshal_FullOutput(t *testing.T) { + t.Parallel() + + maxTokens := int64(64000) + config := &latest.FlexibleModelConfig{ + ModelConfig: latest.ModelConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-5", + MaxTokens: &maxTokens, + }, + } + + data, err := yaml.Marshal(config) + require.NoError(t, err) + + // Should output full format when extra options are set + assert.Contains(t, string(data), "provider:") + assert.Contains(t, string(data), "model:") + assert.Contains(t, string(data), "max_tokens:") +} + +func TestDefaultModelConfig_InvalidShorthand(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yaml string + wantErr bool + }{ + {"no slash", "default_model: anthropic", true}, + {"empty provider", "default_model: /model", true}, + {"empty model", "default_model: provider/", 
true}, + {"valid shorthand", "default_model: anthropic/claude-sonnet-4-5", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var config Config + err := yaml.Unmarshal([]byte(tt.yaml), &config) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestConfig_DefaultModel_SaveAndLoad(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + + maxTokens := int64(64000) + config := &Config{ + DefaultModel: &latest.FlexibleModelConfig{ + ModelConfig: latest.ModelConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-5", + MaxTokens: &maxTokens, + ThinkingBudget: &latest.ThinkingBudget{Tokens: 10000}, + }, + }, + } + + require.NoError(t, config.saveTo(configFile)) + + loaded, err := loadFrom(configFile, "") + require.NoError(t, err) + + require.NotNil(t, loaded.DefaultModel) + assert.Equal(t, "anthropic", loaded.DefaultModel.Provider) + assert.Equal(t, "claude-sonnet-4-5", loaded.DefaultModel.Model) + require.NotNil(t, loaded.DefaultModel.MaxTokens) + assert.Equal(t, int64(64000), *loaded.DefaultModel.MaxTokens) + require.NotNil(t, loaded.DefaultModel.ThinkingBudget) + assert.Equal(t, 10000, loaded.DefaultModel.ThinkingBudget.Tokens) +}