diff --git a/pkg/config/examples_test.go b/pkg/config/examples_test.go index 070b9f515..f20b583f2 100644 --- a/pkg/config/examples_test.go +++ b/pkg/config/examples_test.go @@ -70,7 +70,7 @@ func TestParseExamples(t *testing.T) { continue } - model, err := modelsStore.GetModel(model.Provider + "/" + model.Model) + model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model) require.NoError(t, err) require.NotNil(t, model) } diff --git a/pkg/config/model_alias.go b/pkg/config/model_alias.go index e322bd86c..e1bdfd64e 100644 --- a/pkg/config/model_alias.go +++ b/pkg/config/model_alias.go @@ -1,6 +1,7 @@ package config import ( + "context" "log/slog" "strings" @@ -16,7 +17,7 @@ import ( // either set directly on the model or inherited from a custom provider definition. // This is necessary because external providers (like Azure Foundry) may use the alias // names directly as deployment names rather than the pinned version names. -func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { +func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsdev.Store) { // Resolve model aliases in the models section for name, modelCfg := range cfg.Models { // Skip alias resolution for models with custom base_url (direct or via provider) @@ -27,7 +28,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { continue } - if resolved := store.ResolveModelAlias(modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model { + if resolved := store.ResolveModelAlias(ctx, modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model { modelCfg.Model = resolved cfg.Models[name] = modelCfg } @@ -35,7 +36,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { // Resolve model aliases in routing rules for i, rule := range modelCfg.Routing { if provider, model, ok := strings.Cut(rule.Model, "/"); ok { - if resolved := store.ResolveModelAlias(provider, model); resolved != model { + if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model { modelCfg.Routing[i].Model = provider + "/" + resolved } } @@ -52,7 +53,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { var resolvedModels []string for modelRef := range strings.SplitSeq(agent.Model, ",") { if provider, model, ok := strings.Cut(modelRef, "/"); ok { - if resolved := store.ResolveModelAlias(provider, model); resolved != model { + if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model { resolvedModels = append(resolvedModels, provider+"/"+resolved) continue } diff --git a/pkg/config/model_alias_test.go b/pkg/config/model_alias_test.go index e5c7fe3ee..d2b1ee102 100644 --- a/pkg/config/model_alias_test.go +++ b/pkg/config/model_alias_test.go @@ -237,7 +237,7 @@ func TestResolveModelAliases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ResolveModelAliases(tt.cfg, store) + ResolveModelAliases(t.Context(), tt.cfg, store) assert.Equal(t, tt.expected, tt.cfg) }) } diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index a02cf549d..e8b11287a 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -112,7 +112,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // Detect prompt caching capability at init time for efficiency. // Uses models.dev cache pricing as proxy for capability detection. - cachingSupported := detectCachingSupport(cfg.Model) + cachingSupported := detectCachingSupport(ctx, cfg.Model) slog.Debug("Bedrock client created successfully", "model", cfg.Model, @@ -133,7 +133,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // detectCachingSupport checks if a model supports prompt caching using models.dev data. // Models with non-zero CacheRead/CacheWrite costs support prompt caching. // Returns false on lookup failure (safe default for unsupported models). -func detectCachingSupport(model string) bool { +func detectCachingSupport(ctx context.Context, model string) bool { store, err := modelsdev.NewStore() if err != nil { slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err) @@ -141,7 +141,7 @@ func detectCachingSupport(model string) bool { } modelID := "amazon-bedrock/" + model - m, err := store.GetModel(modelID) + m, err := store.GetModel(ctx, modelID) if err != nil { slog.Debug("Bedrock prompt caching disabled: model not found in models.dev", "model_id", modelID, "error", err) diff --git a/pkg/model/provider/bedrock/client_test.go b/pkg/model/provider/bedrock/client_test.go index 4e63ee13b..6d02e2ee4 100644 --- a/pkg/model/provider/bedrock/client_test.go +++ b/pkg/model/provider/bedrock/client_test.go @@ -1249,7 +1249,7 @@ func TestDetectCachingSupport_SupportedModel(t *testing.T) { t.Parallel() // Uses real models.dev lookup to verify Claude models support caching - supported := detectCachingSupport("anthropic.claude-3-5-sonnet-20241022-v2:0") + supported := detectCachingSupport(t.Context(), "anthropic.claude-3-5-sonnet-20241022-v2:0") assert.True(t, supported) } @@ -1257,7 +1257,7 @@ func TestDetectCachingSupport_UnsupportedModel(t *testing.T) { t.Parallel() // Llama doesn't have cache pricing in models.dev - supported := detectCachingSupport("meta.llama3-8b-instruct-v1:0") + supported := detectCachingSupport(t.Context(), "meta.llama3-8b-instruct-v1:0") assert.False(t, supported) } @@ -1265,7 +1265,7 @@ func TestDetectCachingSupport_UnknownModel(t *testing.T) { t.Parallel() // Unknown model should gracefully return false, not panic - supported := detectCachingSupport("nonexistent.model.that.does.not.exist:v1") + supported := detectCachingSupport(t.Context(), "nonexistent.model.that.does.not.exist:v1") assert.False(t, supported) } diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index ff996df58..b92534575 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -21,25 +21,16 @@ const ( ) // Store manages access to the models.dev data. -// The database is loaded lazily on first access and cached for the -// lifetime of the Store. All methods are safe for concurrent use. +// All methods are safe for concurrent use. type Store struct { - db func() (*Database, error) + cacheFile string + mu sync.Mutex + db *Database } -// defaultStore is a cached singleton store instance for repeated access. -var defaultStore = sync.OnceValues(newStoreInternal) - -// NewStore returns the cached default store instance. -// The underlying database is fetched lazily on first access -// from a local cache file or the models.dev API. +// NewStore creates a new models.dev store. +// The database is loaded on first access via GetDatabase. func NewStore() (*Store, error) { - return defaultStore() -} - -// newStoreInternal creates a new models.dev store that loads data -// from the filesystem cache or the network on first access. -func newStoreInternal() (*Store, error) { homeDir, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("failed to get user home directory: %w", err) @@ -50,12 +41,8 @@ func newStoreInternal() (*Store, error) { return nil, fmt.Errorf("failed to create cache directory: %w", err) } - cacheFile := filepath.Join(cacheDir, CacheFileName) - return &Store{ - db: sync.OnceValues(func() (*Database, error) { - return loadDatabase(cacheFile) - }), + cacheFile: filepath.Join(cacheDir, CacheFileName), }, nil } @@ -64,19 +51,30 @@ func newStoreInternal() (*Store, error) { // from the network or touches the filesystem, making it suitable for // tests and any scenario where the provider data is already known. func NewDatabaseStore(db *Database) *Store { - return &Store{ - db: func() (*Database, error) { return db, nil }, - } + return &Store{db: db} } // GetDatabase returns the models.dev database, fetching from cache or API as needed. -func (s *Store) GetDatabase() (*Database, error) { - return s.db() +func (s *Store) GetDatabase(ctx context.Context) (*Database, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.db != nil { + return s.db, nil + } + + db, err := loadDatabase(ctx, s.cacheFile) + if err != nil { + return nil, err + } + + s.db = db + return db, nil } // GetProvider returns a specific provider by ID. -func (s *Store) GetProvider(providerID string) (*Provider, error) { - db, err := s.db() +func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) { + db, err := s.GetDatabase(ctx) if err != nil { return nil, err } @@ -90,7 +88,7 @@ func (s *Store) GetProvider(providerID string) (*Provider, error) { } // GetModel returns a specific model by provider ID and model ID. -func (s *Store) GetModel(id string) (*Model, error) { +func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { parts := strings.SplitN(id, "/", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid model ID: %q", id) @@ -98,7 +96,7 @@ func (s *Store) GetModel(id string) (*Model, error) { providerID := parts[0] modelID := parts[1] - provider, err := s.GetProvider(providerID) + provider, err := s.GetProvider(ctx, providerID) if err != nil { return nil, err } @@ -130,7 +128,7 @@ func (s *Store) GetModel(id string) (*Model, error) { // loadDatabase loads the database from the local cache file or // falls back to fetching from the models.dev API. -func loadDatabase(cacheFile string) (*Database, error) { +func loadDatabase(ctx context.Context, cacheFile string) (*Database, error) { // Try to load from cache first cached, err := loadFromCache(cacheFile) if err == nil && time.Since(cached.LastRefresh) < refreshInterval { @@ -138,7 +136,7 @@ func loadDatabase(cacheFile string) (*Database, error) { } // Cache is invalid or doesn't exist, fetch from API - database, fetchErr := fetchFromAPI() + database, fetchErr := fetchFromAPI(ctx) if fetchErr != nil { // If API fetch fails, but we have cached data, use it if cached != nil { @@ -156,8 +154,8 @@ func loadDatabase(cacheFile string) (*Database, error) { return database, nil } -func fetchFromAPI() (*Database, error) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ModelsDevAPIURL, http.NoBody) +func fetchFromAPI(ctx context.Context) (*Database, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ModelsDevAPIURL, http.NoBody) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -225,7 +223,7 @@ var datePattern = regexp.MustCompile(`-\d{4}-?\d{2}-?\d{2}$`) // For example, ("anthropic", "claude-sonnet-4-5") might resolve to "claude-sonnet-4-5-20250929". // If the model is not an alias (already pinned or unknown), the original model name is returned. // This method uses the models.dev database to find the corresponding pinned version. -func (s *Store) ResolveModelAlias(providerID, modelName string) string { +func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName string) string { if providerID == "" || modelName == "" { return modelName } @@ -236,7 +234,7 @@ func (s *Store) ResolveModelAlias(providerID, modelName string) string { } // Get the provider from the database - provider, err := s.GetProvider(providerID) + provider, err := s.GetProvider(ctx, providerID) if err != nil { return modelName } @@ -285,7 +283,7 @@ func isBedrockRegionPrefix(prefix string) bool { // - If modelID is empty or not in "provider/model" format, returns true (fail-open) // - If models.dev lookup fails for any reason, returns true (fail-open) // - If lookup succeeds, returns the model's Reasoning field value -func ModelSupportsReasoning(modelID string) bool { +func ModelSupportsReasoning(ctx context.Context, modelID string) bool { // Fail-open for empty model ID if modelID == "" { return true @@ -303,7 +301,7 @@ func ModelSupportsReasoning(modelID string) bool { return true } - model, err := store.GetModel(modelID) + model, err := store.GetModel(ctx, modelID) if err != nil { slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err) return true diff --git a/pkg/modelsdev/store_test.go b/pkg/modelsdev/store_test.go index bfe43465a..a6742db04 100644 --- a/pkg/modelsdev/store_test.go +++ b/pkg/modelsdev/store_test.go @@ -57,7 +57,7 @@ func TestResolveModelAlias(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := store.ResolveModelAlias(tt.provider, tt.model) + result := store.ResolveModelAlias(t.Context(), tt.provider, tt.model) assert.Equal(t, tt.expected, result) }) } diff --git a/pkg/rag/strategy/semantic_embeddings.go b/pkg/rag/strategy/semantic_embeddings.go index 1d69913e2..b22e46bf4 100644 --- a/pkg/rag/strategy/semantic_embeddings.go +++ b/pkg/rag/strategy/semantic_embeddings.go @@ -506,7 +506,7 @@ func calculateSemanticUsageCost(modelsStore modelStore, modelID string, usage *c return 0 } - model, err := modelsStore.GetModel(modelID) + model, err := modelsStore.GetModel(context.Background(), modelID) if err != nil { slog.Debug("Failed to get semantic model pricing from models.dev, cost will be 0", "model_id", modelID, diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index 80afff926..19d6a4c94 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -87,7 +87,7 @@ type VectorStore struct { } type modelStore interface { - GetModel(modelID string) (*modelsdev.Model, error) + GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) } // EmbeddingInputBuilder builds the string that will be sent to the embedding model @@ -174,7 +174,7 @@ func (s *VectorStore) calculateCost(tokens int64) float64 { return 0 } - model, err := s.modelsStore.GetModel(s.modelID) + model, err := s.modelsStore.GetModel(context.Background(), s.modelID) if err != nil { slog.Debug("Failed to get model pricing from models.dev, cost will be 0", "model_id", s.modelID, diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 23cbe3aaf..9b49c5a0d 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -286,7 +286,7 @@ func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice { // buildCatalogChoices builds ModelChoice entries from the models.dev catalog, // filtered by supported providers and available credentials. func (r *LocalRuntime) buildCatalogChoices(ctx context.Context) []ModelChoice { - db, err := r.modelsStore.GetDatabase() + db, err := r.modelsStore.GetDatabase(ctx) if err != nil { slog.Debug("Failed to get models.dev database for catalog", "error", err) return nil @@ -446,7 +446,7 @@ func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest if cfg.MaxTokens != nil { opts = append(opts, options.WithMaxTokens(*cfg.MaxTokens)) } else if r.modelsStore != nil { - m, err := r.modelsStore.GetModel(cfg.Provider + "/" + cfg.Model) + m, err := r.modelsStore.GetModel(ctx, cfg.Provider+"/"+cfg.Model) if err == nil && m != nil { opts = append(opts, options.WithMaxTokens(m.Limit.Output)) } diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index a4c973007..ddded0ea4 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -27,7 +27,7 @@ type mockCatalogStore struct { db *modelsdev.Database } -func (m *mockCatalogStore) GetDatabase() (*modelsdev.Database, error) { +func (m *mockCatalogStore) GetDatabase(_ context.Context) (*modelsdev.Database, error) { return m.db, nil } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index dbfa9128d..dc792b297 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -171,8 +171,8 @@ type CurrentAgentInfo struct { } type ModelStore interface { - GetModel(modelID string) (*modelsdev.Model, error) - GetDatabase() (*modelsdev.Database, error) + GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) + GetDatabase(ctx context.Context) (*modelsdev.Database, error) } // RAGInitializer is implemented by runtimes that support background RAG initialization. @@ -988,7 +988,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c modelID := model.ID() slog.Debug("Using agent", "agent", a.Name(), "model", modelID) slog.Debug("Getting model definition", "model_id", modelID) - m, err := r.modelsStore.GetModel(modelID) + m, err := r.modelsStore.GetModel(ctx, modelID) if err != nil { slog.Debug("Failed to get model definition", "error", err) } diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index a6b3f7ddb..7779045b4 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -175,7 +175,7 @@ type mockModelStore struct { ModelStore } -func (m mockModelStore) GetModel(string) (*modelsdev.Model, error) { +func (m mockModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { return nil, nil } @@ -681,7 +681,7 @@ type mockModelStoreWithLimit struct { limit int } -func (m mockModelStoreWithLimit) GetModel(string) (*modelsdev.Model, error) { +func (m mockModelStoreWithLimit) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { return &modelsdev.Model{Limit: modelsdev.Limit{Context: m.limit}, Cost: &modelsdev.Cost{}}, nil } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index bcc97dcb9..8efb2846c 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -116,7 +116,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c if err != nil { slog.Debug("Failed to create modelsdev store for alias resolution", "error", err) } else { - config.ResolveModelAliases(cfg, modelsStore) + config.ResolveModelAliases(ctx, cfg, modelsStore) } // Apply model overrides from CLI flags before checking required env vars @@ -318,7 +318,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC if err != nil { return nil, false, err } - m, err := modelsStore.GetModel(modelCfg.Provider + "/" + modelCfg.Model) + m, err := modelsStore.GetModel(ctx, modelCfg.Provider+"/"+modelCfg.Model) if err == nil { maxTokens = &m.Limit.Output } @@ -381,7 +381,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates if err != nil { return nil, err } - m, err := modelsStore.GetModel(modelCfg.Provider + "/" + modelCfg.Model) + m, err := modelsStore.GetModel(ctx, modelCfg.Provider+"/"+modelCfg.Model) if err == nil { maxTokens = &m.Limit.Output } diff --git a/pkg/tui/commands/commands.go b/pkg/tui/commands/commands.go index 9834acccd..34a978aed 100644 --- a/pkg/tui/commands/commands.go +++ b/pkg/tui/commands/commands.go @@ -262,7 +262,7 @@ func BuildCommandCategories(ctx context.Context, application *app.App) []Categor // Check if the current model supports reasoning; hide /think if not currentModel := application.CurrentAgentModel() - if !modelsdev.ModelSupportsReasoning(currentModel) { + if !modelsdev.ModelSupportsReasoning(ctx, currentModel) { filtered := make([]Item, 0, len(sessionCommands)) for _, cmd := range sessionCommands { if cmd.ID != "session.think" { diff --git a/pkg/tui/components/sidebar/sidebar.go b/pkg/tui/components/sidebar/sidebar.go index 77cd0dbce..72f1b0617 100644 --- a/pkg/tui/components/sidebar/sidebar.go +++ b/pkg/tui/components/sidebar/sidebar.go @@ -1,6 +1,7 @@ package sidebar import ( + "context" "fmt" "log/slog" "maps" @@ -252,7 +253,9 @@ func (m *model) SetAgentInfo(agentName, modelID, description string) { m.currentAgent = agentName m.agentModel = modelID m.agentDescription = description - m.reasoningSupported = modelsdev.ModelSupportsReasoning(modelID) + // TODO: this can block for up to 30s on the first call if the cache is cold, + // which freezes the TUI. Move to an async command. + m.reasoningSupported = modelsdev.ModelSupportsReasoning(context.TODO(), modelID) // Update the provider and model in availableAgents for the current agent. // This is important when fallback models from different providers are used. diff --git a/pkg/tui/handlers.go b/pkg/tui/handlers.go index 0d050b43c..515688921 100644 --- a/pkg/tui/handlers.go +++ b/pkg/tui/handlers.go @@ -330,7 +330,9 @@ func (a *appModel) handleToggleYolo() (tea.Model, tea.Cmd) { func (a *appModel) handleToggleThinking() (tea.Model, tea.Cmd) { // Check if the current model supports reasoning currentModel := a.application.CurrentAgentModel() - if !modelsdev.ModelSupportsReasoning(currentModel) { + // TODO: this can block for up to 30s on the first call if the cache is cold, + // which freezes the TUI. Move to an async command. + if !modelsdev.ModelSupportsReasoning(context.TODO(), currentModel) { return a, notification.InfoCmd("Thinking/reasoning is not supported for the current model") }