Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/config/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestParseExamples(t *testing.T) {
continue
}

model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model)
model, err := modelsStore.GetModel(model.Provider + "/" + model.Model)
require.NoError(t, err)
require.NotNil(t, model)
}
Expand Down
15 changes: 4 additions & 11 deletions pkg/config/model_alias.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package config

import (
"context"
"log/slog"
"strings"

Expand All @@ -17,13 +16,7 @@ import (
// either set directly on the model or inherited from a custom provider definition.
// This is necessary because external providers (like Azure Foundry) may use the alias
// names directly as deployment names rather than the pinned version names.
func ResolveModelAliases(ctx context.Context, cfg *latest.Config) {
store, err := modelsdev.NewStore()
if err != nil {
slog.Debug("Failed to create modelsdev store for alias resolution", "error", err)
return
}

func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
// Resolve model aliases in the models section
for name, modelCfg := range cfg.Models {
// Skip alias resolution for models with custom base_url (direct or via provider)
Expand All @@ -34,15 +27,15 @@ func ResolveModelAliases(ctx context.Context, cfg *latest.Config) {
continue
}

if resolved := store.ResolveModelAlias(ctx, modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model {
if resolved := store.ResolveModelAlias(modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model {
modelCfg.Model = resolved
cfg.Models[name] = modelCfg
}

// Resolve model aliases in routing rules
for i, rule := range modelCfg.Routing {
if provider, model, ok := strings.Cut(rule.Model, "/"); ok {
if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model {
if resolved := store.ResolveModelAlias(provider, model); resolved != model {
modelCfg.Routing[i].Model = provider + "/" + resolved
}
}
Expand All @@ -59,7 +52,7 @@ func ResolveModelAliases(ctx context.Context, cfg *latest.Config) {
var resolvedModels []string
for modelRef := range strings.SplitSeq(agent.Model, ",") {
if provider, model, ok := strings.Cut(modelRef, "/"); ok {
if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model {
if resolved := store.ResolveModelAlias(provider, model); resolved != model {
resolvedModels = append(resolvedModels, provider+"/"+resolved)
continue
}
Expand Down
12 changes: 2 additions & 10 deletions pkg/config/model_alias_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/modelsdev"
Expand All @@ -27,14 +26,7 @@ func TestResolveModelAliases(t *testing.T) {
},
}

store, err := modelsdev.NewStore(modelsdev.WithCacheDir(t.TempDir()))
require.NoError(t, err)
store.SetDatabaseForTesting(mockData)
t.Cleanup(func() {
store.SetDatabaseForTesting(nil)
})

ctx := t.Context()
store := modelsdev.NewDatabaseStore(mockData)

tests := []struct {
name string
Expand Down Expand Up @@ -245,7 +237,7 @@ func TestResolveModelAliases(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ResolveModelAliases(ctx, tt.cfg)
ResolveModelAliases(tt.cfg, store)
assert.Equal(t, tt.expected, tt.cfg)
})
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/model/provider/bedrock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro

// Detect prompt caching capability at init time for efficiency.
// Uses models.dev cache pricing as proxy for capability detection.
cachingSupported := detectCachingSupport(ctx, cfg.Model)
cachingSupported := detectCachingSupport(cfg.Model)

slog.Debug("Bedrock client created successfully",
"model", cfg.Model,
Expand All @@ -133,15 +133,15 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
// detectCachingSupport checks if a model supports prompt caching using models.dev data.
// Models with non-zero CacheRead/CacheWrite costs support prompt caching.
// Returns false on lookup failure (safe default for unsupported models).
func detectCachingSupport(ctx context.Context, model string) bool {
func detectCachingSupport(model string) bool {
store, err := modelsdev.NewStore()
if err != nil {
slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err)
return false
}

modelID := "amazon-bedrock/" + model
m, err := store.GetModel(ctx, modelID)
m, err := store.GetModel(modelID)
if err != nil {
slog.Debug("Bedrock prompt caching disabled: model not found in models.dev",
"model_id", modelID, "error", err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/model/provider/bedrock/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1249,23 +1249,23 @@ func TestDetectCachingSupport_SupportedModel(t *testing.T) {
t.Parallel()

// Uses real models.dev lookup to verify Claude models support caching
supported := detectCachingSupport(t.Context(), "anthropic.claude-3-5-sonnet-20241022-v2:0")
supported := detectCachingSupport("anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.True(t, supported)
}

func TestDetectCachingSupport_UnsupportedModel(t *testing.T) {
t.Parallel()

// Llama doesn't have cache pricing in models.dev
supported := detectCachingSupport(t.Context(), "meta.llama3-8b-instruct-v1:0")
supported := detectCachingSupport("meta.llama3-8b-instruct-v1:0")
assert.False(t, supported)
}

func TestDetectCachingSupport_UnknownModel(t *testing.T) {
t.Parallel()

// Unknown model should gracefully return false, not panic
supported := detectCachingSupport(t.Context(), "nonexistent.model.that.does.not.exist:v1")
supported := detectCachingSupport("nonexistent.model.that.does.not.exist:v1")
assert.False(t, supported)
}

Expand Down
Loading