Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/config/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestParseExamples(t *testing.T) {
continue
}

model, err := modelsStore.GetModel(model.Provider + "/" + model.Model)
model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model)
require.NoError(t, err)
require.NotNil(t, model)
}
Expand Down
9 changes: 5 additions & 4 deletions pkg/config/model_alias.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"context"
"log/slog"
"strings"

Expand All @@ -16,7 +17,7 @@ import (
// either set directly on the model or inherited from a custom provider definition.
// This is necessary because external providers (like Azure Foundry) may use the alias
// names directly as deployment names rather than the pinned version names.
func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsdev.Store) {
// Resolve model aliases in the models section
for name, modelCfg := range cfg.Models {
// Skip alias resolution for models with custom base_url (direct or via provider)
Expand All @@ -27,15 +28,15 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
continue
}

if resolved := store.ResolveModelAlias(modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model {
if resolved := store.ResolveModelAlias(ctx, modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model {
modelCfg.Model = resolved
cfg.Models[name] = modelCfg
}

// Resolve model aliases in routing rules
for i, rule := range modelCfg.Routing {
if provider, model, ok := strings.Cut(rule.Model, "/"); ok {
if resolved := store.ResolveModelAlias(provider, model); resolved != model {
if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model {
modelCfg.Routing[i].Model = provider + "/" + resolved
}
}
Expand All @@ -52,7 +53,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) {
var resolvedModels []string
for modelRef := range strings.SplitSeq(agent.Model, ",") {
if provider, model, ok := strings.Cut(modelRef, "/"); ok {
if resolved := store.ResolveModelAlias(provider, model); resolved != model {
if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model {
resolvedModels = append(resolvedModels, provider+"/"+resolved)
continue
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/config/model_alias_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func TestResolveModelAliases(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ResolveModelAliases(tt.cfg, store)
ResolveModelAliases(t.Context(), tt.cfg, store)
assert.Equal(t, tt.expected, tt.cfg)
})
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/model/provider/bedrock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro

// Detect prompt caching capability at init time for efficiency.
// Uses models.dev cache pricing as proxy for capability detection.
cachingSupported := detectCachingSupport(cfg.Model)
cachingSupported := detectCachingSupport(ctx, cfg.Model)

slog.Debug("Bedrock client created successfully",
"model", cfg.Model,
Expand All @@ -133,15 +133,15 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
// detectCachingSupport checks if a model supports prompt caching using models.dev data.
// Models with non-zero CacheRead/CacheWrite costs support prompt caching.
// Returns false on lookup failure (safe default for unsupported models).
func detectCachingSupport(model string) bool {
func detectCachingSupport(ctx context.Context, model string) bool {
store, err := modelsdev.NewStore()
if err != nil {
slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err)
return false
}

modelID := "amazon-bedrock/" + model
m, err := store.GetModel(modelID)
m, err := store.GetModel(ctx, modelID)
if err != nil {
slog.Debug("Bedrock prompt caching disabled: model not found in models.dev",
"model_id", modelID, "error", err)
Expand Down
6 changes: 3 additions & 3 deletions pkg/model/provider/bedrock/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1249,23 +1249,23 @@ func TestDetectCachingSupport_SupportedModel(t *testing.T) {
t.Parallel()

// Uses real models.dev lookup to verify Claude models support caching
supported := detectCachingSupport("anthropic.claude-3-5-sonnet-20241022-v2:0")
supported := detectCachingSupport(t.Context(), "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.True(t, supported)
}

func TestDetectCachingSupport_UnsupportedModel(t *testing.T) {
t.Parallel()

// Llama doesn't have cache pricing in models.dev
supported := detectCachingSupport("meta.llama3-8b-instruct-v1:0")
supported := detectCachingSupport(t.Context(), "meta.llama3-8b-instruct-v1:0")
assert.False(t, supported)
}

func TestDetectCachingSupport_UnknownModel(t *testing.T) {
t.Parallel()

// Unknown model should gracefully return false, not panic
supported := detectCachingSupport("nonexistent.model.that.does.not.exist:v1")
supported := detectCachingSupport(t.Context(), "nonexistent.model.that.does.not.exist:v1")
assert.False(t, supported)
}

Expand Down
72 changes: 35 additions & 37 deletions pkg/modelsdev/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,16 @@ const (
)

// Store manages access to the models.dev data.
// The database is loaded lazily on first access and cached for the
// lifetime of the Store. All methods are safe for concurrent use.
// All methods are safe for concurrent use.
type Store struct {
db func() (*Database, error)
cacheFile string
mu sync.Mutex
db *Database
}

// defaultStore is a cached singleton store instance for repeated access.
var defaultStore = sync.OnceValues(newStoreInternal)

// NewStore returns the cached default store instance.
// The underlying database is fetched lazily on first access
// from a local cache file or the models.dev API.
// NewStore creates a new models.dev store.
// The database is loaded on first access via GetDatabase.
func NewStore() (*Store, error) {
return defaultStore()
}

// newStoreInternal creates a new models.dev store that loads data
// from the filesystem cache or the network on first access.
func newStoreInternal() (*Store, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get user home directory: %w", err)
Expand All @@ -50,12 +41,8 @@ func newStoreInternal() (*Store, error) {
return nil, fmt.Errorf("failed to create cache directory: %w", err)
}

cacheFile := filepath.Join(cacheDir, CacheFileName)

return &Store{
db: sync.OnceValues(func() (*Database, error) {
return loadDatabase(cacheFile)
}),
cacheFile: filepath.Join(cacheDir, CacheFileName),
}, nil
}

Expand All @@ -64,19 +51,30 @@ func newStoreInternal() (*Store, error) {
// from the network or touches the filesystem, making it suitable for
// tests and any scenario where the provider data is already known.
func NewDatabaseStore(db *Database) *Store {
return &Store{
db: func() (*Database, error) { return db, nil },
}
return &Store{db: db}
}

// GetDatabase returns the models.dev database, fetching from cache or API as needed.
func (s *Store) GetDatabase() (*Database, error) {
return s.db()
func (s *Store) GetDatabase(ctx context.Context) (*Database, error) {
s.mu.Lock()
defer s.mu.Unlock()

if s.db != nil {
return s.db, nil
}

db, err := loadDatabase(ctx, s.cacheFile)
if err != nil {
return nil, err
}

s.db = db
return db, nil
}

// GetProvider returns a specific provider by ID.
func (s *Store) GetProvider(providerID string) (*Provider, error) {
db, err := s.db()
func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) {
db, err := s.GetDatabase(ctx)
if err != nil {
return nil, err
}
Expand All @@ -90,15 +88,15 @@ func (s *Store) GetProvider(providerID string) (*Provider, error) {
}

// GetModel returns a specific model by provider ID and model ID.
func (s *Store) GetModel(id string) (*Model, error) {
func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
parts := strings.SplitN(id, "/", 2)
if len(parts) != 2 {
return nil, fmt.Errorf("invalid model ID: %q", id)
}
providerID := parts[0]
modelID := parts[1]

provider, err := s.GetProvider(providerID)
provider, err := s.GetProvider(ctx, providerID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -130,15 +128,15 @@ func (s *Store) GetModel(id string) (*Model, error) {

// loadDatabase loads the database from the local cache file or
// falls back to fetching from the models.dev API.
func loadDatabase(cacheFile string) (*Database, error) {
func loadDatabase(ctx context.Context, cacheFile string) (*Database, error) {
// Try to load from cache first
cached, err := loadFromCache(cacheFile)
if err == nil && time.Since(cached.LastRefresh) < refreshInterval {
return &cached.Database, nil
}

// Cache is invalid or doesn't exist, fetch from API
database, fetchErr := fetchFromAPI()
database, fetchErr := fetchFromAPI(ctx)
if fetchErr != nil {
// If API fetch fails, but we have cached data, use it
if cached != nil {
Expand All @@ -156,8 +154,8 @@ func loadDatabase(cacheFile string) (*Database, error) {
return database, nil
}

func fetchFromAPI() (*Database, error) {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ModelsDevAPIURL, http.NoBody)
func fetchFromAPI(ctx context.Context) (*Database, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ModelsDevAPIURL, http.NoBody)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
Expand Down Expand Up @@ -225,7 +223,7 @@ var datePattern = regexp.MustCompile(`-\d{4}-?\d{2}-?\d{2}$`)
// For example, ("anthropic", "claude-sonnet-4-5") might resolve to "claude-sonnet-4-5-20250929".
// If the model is not an alias (already pinned or unknown), the original model name is returned.
// This method uses the models.dev database to find the corresponding pinned version.
func (s *Store) ResolveModelAlias(providerID, modelName string) string {
func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName string) string {
if providerID == "" || modelName == "" {
return modelName
}
Expand All @@ -236,7 +234,7 @@ func (s *Store) ResolveModelAlias(providerID, modelName string) string {
}

// Get the provider from the database
provider, err := s.GetProvider(providerID)
provider, err := s.GetProvider(ctx, providerID)
if err != nil {
return modelName
}
Expand Down Expand Up @@ -285,7 +283,7 @@ func isBedrockRegionPrefix(prefix string) bool {
// - If modelID is empty or not in "provider/model" format, returns true (fail-open)
// - If models.dev lookup fails for any reason, returns true (fail-open)
// - If lookup succeeds, returns the model's Reasoning field value
func ModelSupportsReasoning(modelID string) bool {
func ModelSupportsReasoning(ctx context.Context, modelID string) bool {
// Fail-open for empty model ID
if modelID == "" {
return true
Expand All @@ -303,7 +301,7 @@ func ModelSupportsReasoning(modelID string) bool {
return true
}

model, err := store.GetModel(modelID)
model, err := store.GetModel(ctx, modelID)
if err != nil {
slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err)
return true
Expand Down
2 changes: 1 addition & 1 deletion pkg/modelsdev/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestResolveModelAlias(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := store.ResolveModelAlias(tt.provider, tt.model)
result := store.ResolveModelAlias(t.Context(), tt.provider, tt.model)
assert.Equal(t, tt.expected, result)
})
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/rag/strategy/semantic_embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ func calculateSemanticUsageCost(modelsStore modelStore, modelID string, usage *c
return 0
}

model, err := modelsStore.GetModel(modelID)
model, err := modelsStore.GetModel(context.Background(), modelID)
if err != nil {
slog.Debug("Failed to get semantic model pricing from models.dev, cost will be 0",
"model_id", modelID,
Expand Down
4 changes: 2 additions & 2 deletions pkg/rag/strategy/vector_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ type VectorStore struct {
}

type modelStore interface {
GetModel(modelID string) (*modelsdev.Model, error)
GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error)
}

// EmbeddingInputBuilder builds the string that will be sent to the embedding model
Expand Down Expand Up @@ -174,7 +174,7 @@ func (s *VectorStore) calculateCost(tokens int64) float64 {
return 0
}

model, err := s.modelsStore.GetModel(s.modelID)
model, err := s.modelsStore.GetModel(context.Background(), s.modelID)
if err != nil {
slog.Debug("Failed to get model pricing from models.dev, cost will be 0",
"model_id", s.modelID,
Expand Down
4 changes: 2 additions & 2 deletions pkg/runtime/model_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice {
// buildCatalogChoices builds ModelChoice entries from the models.dev catalog,
// filtered by supported providers and available credentials.
func (r *LocalRuntime) buildCatalogChoices(ctx context.Context) []ModelChoice {
db, err := r.modelsStore.GetDatabase()
db, err := r.modelsStore.GetDatabase(ctx)
if err != nil {
slog.Debug("Failed to get models.dev database for catalog", "error", err)
return nil
Expand Down Expand Up @@ -446,7 +446,7 @@ func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest
if cfg.MaxTokens != nil {
opts = append(opts, options.WithMaxTokens(*cfg.MaxTokens))
} else if r.modelsStore != nil {
m, err := r.modelsStore.GetModel(cfg.Provider + "/" + cfg.Model)
m, err := r.modelsStore.GetModel(ctx, cfg.Provider+"/"+cfg.Model)
if err == nil && m != nil {
opts = append(opts, options.WithMaxTokens(m.Limit.Output))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/model_switcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type mockCatalogStore struct {
db *modelsdev.Database
}

func (m *mockCatalogStore) GetDatabase() (*modelsdev.Database, error) {
func (m *mockCatalogStore) GetDatabase(_ context.Context) (*modelsdev.Database, error) {
return m.db, nil
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ type CurrentAgentInfo struct {
}

type ModelStore interface {
GetModel(modelID string) (*modelsdev.Model, error)
GetDatabase() (*modelsdev.Database, error)
GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error)
GetDatabase(ctx context.Context) (*modelsdev.Database, error)
}

// RAGInitializer is implemented by runtimes that support background RAG initialization.
Expand Down Expand Up @@ -988,7 +988,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
modelID := model.ID()
slog.Debug("Using agent", "agent", a.Name(), "model", modelID)
slog.Debug("Getting model definition", "model_id", modelID)
m, err := r.modelsStore.GetModel(modelID)
m, err := r.modelsStore.GetModel(ctx, modelID)
if err != nil {
slog.Debug("Failed to get model definition", "error", err)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ type mockModelStore struct {
ModelStore
}

func (m mockModelStore) GetModel(string) (*modelsdev.Model, error) {
func (m mockModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) {
return nil, nil
}

Expand Down Expand Up @@ -681,7 +681,7 @@ type mockModelStoreWithLimit struct {
limit int
}

func (m mockModelStoreWithLimit) GetModel(string) (*modelsdev.Model, error) {
func (m mockModelStoreWithLimit) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) {
return &modelsdev.Model{Limit: modelsdev.Limit{Context: m.limit}, Cost: &modelsdev.Cost{}}, nil
}

Expand Down
Loading