From 753dffeafa299dfa6690907aff55294a3d8ab361 Mon Sep 17 00:00:00 2001 From: Djordje Lukic Date: Mon, 16 Feb 2026 10:44:15 +0100 Subject: [PATCH] Thread context.Context through modelsdev store API Replace the lazy-loaded closure-based Store with a simple struct that holds a cache file path and loads the database on first GetDatabase call. Pass caller-provided context to the HTTP request in fetchFromAPI instead of using context.Background(). NewStore() remains context-free (no network I/O; it only resolves and creates the cache directory). All Store methods (GetDatabase, GetProvider, GetModel, ResolveModelAlias) and ModelSupportsReasoning accept context for the network call. Updated all callsites, using context.Background() in the RAG cost calculations and context.TODO() in TUI callbacks, where no caller context is available. Assisted-By: cagent --- pkg/config/examples_test.go | 2 +- pkg/config/model_alias.go | 9 +-- pkg/config/model_alias_test.go | 2 +- pkg/model/provider/bedrock/client.go | 6 +- pkg/model/provider/bedrock/client_test.go | 6 +- pkg/modelsdev/store.go | 72 +++++++++++------------ pkg/modelsdev/store_test.go | 2 +- pkg/rag/strategy/semantic_embeddings.go | 2 +- pkg/rag/strategy/vector_store.go | 4 +- pkg/runtime/model_switcher.go | 4 +- pkg/runtime/model_switcher_test.go | 2 +- pkg/runtime/runtime.go | 6 +- pkg/runtime/runtime_test.go | 4 +- pkg/teamloader/teamloader.go | 6 +- pkg/tui/commands/commands.go | 2 +- pkg/tui/components/sidebar/sidebar.go | 5 +- pkg/tui/handlers.go | 4 +- 17 files changed, 71 insertions(+), 67 deletions(-) diff --git a/pkg/config/examples_test.go b/pkg/config/examples_test.go index 070b9f515..f20b583f2 100644 --- a/pkg/config/examples_test.go +++ b/pkg/config/examples_test.go @@ -70,7 +70,7 @@ func TestParseExamples(t *testing.T) { continue } - model, err := modelsStore.GetModel(model.Provider + "/" + model.Model) + model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model) require.NoError(t, err) require.NotNil(t, model) } diff --git a/pkg/config/model_alias.go b/pkg/config/model_alias.go index 
e322bd86c..e1bdfd64e 100644 --- a/pkg/config/model_alias.go +++ b/pkg/config/model_alias.go @@ -1,6 +1,7 @@ package config import ( + "context" "log/slog" "strings" @@ -16,7 +17,7 @@ import ( // either set directly on the model or inherited from a custom provider definition. // This is necessary because external providers (like Azure Foundry) may use the alias // names directly as deployment names rather than the pinned version names. -func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { +func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsdev.Store) { // Resolve model aliases in the models section for name, modelCfg := range cfg.Models { // Skip alias resolution for models with custom base_url (direct or via provider) @@ -27,7 +28,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { continue } - if resolved := store.ResolveModelAlias(modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model { + if resolved := store.ResolveModelAlias(ctx, modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model { modelCfg.Model = resolved cfg.Models[name] = modelCfg } @@ -35,7 +36,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { // Resolve model aliases in routing rules for i, rule := range modelCfg.Routing { if provider, model, ok := strings.Cut(rule.Model, "/"); ok { - if resolved := store.ResolveModelAlias(provider, model); resolved != model { + if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model { modelCfg.Routing[i].Model = provider + "/" + resolved } } @@ -52,7 +53,7 @@ func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { var resolvedModels []string for modelRef := range strings.SplitSeq(agent.Model, ",") { if provider, model, ok := strings.Cut(modelRef, "/"); ok { - if resolved := store.ResolveModelAlias(provider, model); resolved != model { + if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model { 
resolvedModels = append(resolvedModels, provider+"/"+resolved) continue } diff --git a/pkg/config/model_alias_test.go b/pkg/config/model_alias_test.go index e5c7fe3ee..d2b1ee102 100644 --- a/pkg/config/model_alias_test.go +++ b/pkg/config/model_alias_test.go @@ -237,7 +237,7 @@ func TestResolveModelAliases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ResolveModelAliases(tt.cfg, store) + ResolveModelAliases(t.Context(), tt.cfg, store) assert.Equal(t, tt.expected, tt.cfg) }) } diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index a02cf549d..e8b11287a 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -112,7 +112,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // Detect prompt caching capability at init time for efficiency. // Uses models.dev cache pricing as proxy for capability detection. - cachingSupported := detectCachingSupport(cfg.Model) + cachingSupported := detectCachingSupport(ctx, cfg.Model) slog.Debug("Bedrock client created successfully", "model", cfg.Model, @@ -133,7 +133,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // detectCachingSupport checks if a model supports prompt caching using models.dev data. // Models with non-zero CacheRead/CacheWrite costs support prompt caching. // Returns false on lookup failure (safe default for unsupported models). 
-func detectCachingSupport(model string) bool { +func detectCachingSupport(ctx context.Context, model string) bool { store, err := modelsdev.NewStore() if err != nil { slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err) @@ -141,7 +141,7 @@ func detectCachingSupport(model string) bool { } modelID := "amazon-bedrock/" + model - m, err := store.GetModel(modelID) + m, err := store.GetModel(ctx, modelID) if err != nil { slog.Debug("Bedrock prompt caching disabled: model not found in models.dev", "model_id", modelID, "error", err) diff --git a/pkg/model/provider/bedrock/client_test.go b/pkg/model/provider/bedrock/client_test.go index 4e63ee13b..6d02e2ee4 100644 --- a/pkg/model/provider/bedrock/client_test.go +++ b/pkg/model/provider/bedrock/client_test.go @@ -1249,7 +1249,7 @@ func TestDetectCachingSupport_SupportedModel(t *testing.T) { t.Parallel() // Uses real models.dev lookup to verify Claude models support caching - supported := detectCachingSupport("anthropic.claude-3-5-sonnet-20241022-v2:0") + supported := detectCachingSupport(t.Context(), "anthropic.claude-3-5-sonnet-20241022-v2:0") assert.True(t, supported) } @@ -1257,7 +1257,7 @@ func TestDetectCachingSupport_UnsupportedModel(t *testing.T) { t.Parallel() // Llama doesn't have cache pricing in models.dev - supported := detectCachingSupport("meta.llama3-8b-instruct-v1:0") + supported := detectCachingSupport(t.Context(), "meta.llama3-8b-instruct-v1:0") assert.False(t, supported) } @@ -1265,7 +1265,7 @@ func TestDetectCachingSupport_UnknownModel(t *testing.T) { t.Parallel() // Unknown model should gracefully return false, not panic - supported := detectCachingSupport("nonexistent.model.that.does.not.exist:v1") + supported := detectCachingSupport(t.Context(), "nonexistent.model.that.does.not.exist:v1") assert.False(t, supported) } diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index ff996df58..b92534575 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go 
@@ -21,25 +21,16 @@ const ( ) // Store manages access to the models.dev data. -// The database is loaded lazily on first access and cached for the -// lifetime of the Store. All methods are safe for concurrent use. +// All methods are safe for concurrent use. type Store struct { - db func() (*Database, error) + cacheFile string + mu sync.Mutex + db *Database } -// defaultStore is a cached singleton store instance for repeated access. -var defaultStore = sync.OnceValues(newStoreInternal) - -// NewStore returns the cached default store instance. -// The underlying database is fetched lazily on first access -// from a local cache file or the models.dev API. +// NewStore creates a new models.dev store. +// The database is loaded on first access via GetDatabase. func NewStore() (*Store, error) { - return defaultStore() -} - -// newStoreInternal creates a new models.dev store that loads data -// from the filesystem cache or the network on first access. -func newStoreInternal() (*Store, error) { homeDir, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("failed to get user home directory: %w", err) @@ -50,12 +41,8 @@ func newStoreInternal() (*Store, error) { return nil, fmt.Errorf("failed to create cache directory: %w", err) } - cacheFile := filepath.Join(cacheDir, CacheFileName) - return &Store{ - db: sync.OnceValues(func() (*Database, error) { - return loadDatabase(cacheFile) - }), + cacheFile: filepath.Join(cacheDir, CacheFileName), }, nil } @@ -64,19 +51,30 @@ func newStoreInternal() (*Store, error) { // from the network or touches the filesystem, making it suitable for // tests and any scenario where the provider data is already known. func NewDatabaseStore(db *Database) *Store { - return &Store{ - db: func() (*Database, error) { return db, nil }, - } + return &Store{db: db} } // GetDatabase returns the models.dev database, fetching from cache or API as needed. 
-func (s *Store) GetDatabase() (*Database, error) { - return s.db() +func (s *Store) GetDatabase(ctx context.Context) (*Database, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.db != nil { + return s.db, nil + } + + db, err := loadDatabase(ctx, s.cacheFile) + if err != nil { + return nil, err + } + + s.db = db + return db, nil } // GetProvider returns a specific provider by ID. -func (s *Store) GetProvider(providerID string) (*Provider, error) { - db, err := s.db() +func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) { + db, err := s.GetDatabase(ctx) if err != nil { return nil, err } @@ -90,7 +88,7 @@ func (s *Store) GetProvider(providerID string) (*Provider, error) { } // GetModel returns a specific model by provider ID and model ID. -func (s *Store) GetModel(id string) (*Model, error) { +func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { parts := strings.SplitN(id, "/", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid model ID: %q", id) @@ -98,7 +96,7 @@ func (s *Store) GetModel(id string) (*Model, error) { providerID := parts[0] modelID := parts[1] - provider, err := s.GetProvider(providerID) + provider, err := s.GetProvider(ctx, providerID) if err != nil { return nil, err } @@ -130,7 +128,7 @@ func (s *Store) GetModel(id string) (*Model, error) { // loadDatabase loads the database from the local cache file or // falls back to fetching from the models.dev API. 
-func loadDatabase(cacheFile string) (*Database, error) { +func loadDatabase(ctx context.Context, cacheFile string) (*Database, error) { // Try to load from cache first cached, err := loadFromCache(cacheFile) if err == nil && time.Since(cached.LastRefresh) < refreshInterval { @@ -138,7 +136,7 @@ func loadDatabase(cacheFile string) (*Database, error) { } // Cache is invalid or doesn't exist, fetch from API - database, fetchErr := fetchFromAPI() + database, fetchErr := fetchFromAPI(ctx) if fetchErr != nil { // If API fetch fails, but we have cached data, use it if cached != nil { @@ -156,8 +154,8 @@ func loadDatabase(cacheFile string) (*Database, error) { return database, nil } -func fetchFromAPI() (*Database, error) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ModelsDevAPIURL, http.NoBody) +func fetchFromAPI(ctx context.Context) (*Database, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ModelsDevAPIURL, http.NoBody) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } @@ -225,7 +223,7 @@ var datePattern = regexp.MustCompile(`-\d{4}-?\d{2}-?\d{2}$`) // For example, ("anthropic", "claude-sonnet-4-5") might resolve to "claude-sonnet-4-5-20250929". // If the model is not an alias (already pinned or unknown), the original model name is returned. // This method uses the models.dev database to find the corresponding pinned version. 
-func (s *Store) ResolveModelAlias(providerID, modelName string) string { +func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName string) string { if providerID == "" || modelName == "" { return modelName } @@ -236,7 +234,7 @@ func (s *Store) ResolveModelAlias(providerID, modelName string) string { } // Get the provider from the database - provider, err := s.GetProvider(providerID) + provider, err := s.GetProvider(ctx, providerID) if err != nil { return modelName } @@ -285,7 +283,7 @@ func isBedrockRegionPrefix(prefix string) bool { // - If modelID is empty or not in "provider/model" format, returns true (fail-open) // - If models.dev lookup fails for any reason, returns true (fail-open) // - If lookup succeeds, returns the model's Reasoning field value -func ModelSupportsReasoning(modelID string) bool { +func ModelSupportsReasoning(ctx context.Context, modelID string) bool { // Fail-open for empty model ID if modelID == "" { return true @@ -303,7 +301,7 @@ func ModelSupportsReasoning(modelID string) bool { return true } - model, err := store.GetModel(modelID) + model, err := store.GetModel(ctx, modelID) if err != nil { slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err) return true diff --git a/pkg/modelsdev/store_test.go b/pkg/modelsdev/store_test.go index bfe43465a..a6742db04 100644 --- a/pkg/modelsdev/store_test.go +++ b/pkg/modelsdev/store_test.go @@ -57,7 +57,7 @@ func TestResolveModelAlias(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := store.ResolveModelAlias(tt.provider, tt.model) + result := store.ResolveModelAlias(t.Context(), tt.provider, tt.model) assert.Equal(t, tt.expected, result) }) } diff --git a/pkg/rag/strategy/semantic_embeddings.go b/pkg/rag/strategy/semantic_embeddings.go index 1d69913e2..b22e46bf4 100644 --- a/pkg/rag/strategy/semantic_embeddings.go +++ b/pkg/rag/strategy/semantic_embeddings.go 
@@ -506,7 +506,7 @@ func calculateSemanticUsageCost(modelsStore modelStore, modelID string, usage *c return 0 } - model, err := modelsStore.GetModel(modelID) + model, err := modelsStore.GetModel(context.Background(), modelID) if err != nil { slog.Debug("Failed to get semantic model pricing from models.dev, cost will be 0", "model_id", modelID, diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index 80afff926..19d6a4c94 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -87,7 +87,7 @@ type VectorStore struct { } type modelStore interface { - GetModel(modelID string) (*modelsdev.Model, error) + GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) } // EmbeddingInputBuilder builds the string that will be sent to the embedding model @@ -174,7 +174,7 @@ func (s *VectorStore) calculateCost(tokens int64) float64 { return 0 } - model, err := s.modelsStore.GetModel(s.modelID) + model, err := s.modelsStore.GetModel(context.Background(), s.modelID) if err != nil { slog.Debug("Failed to get model pricing from models.dev, cost will be 0", "model_id", s.modelID, diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 23cbe3aaf..9b49c5a0d 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -286,7 +286,7 @@ func (r *LocalRuntime) AvailableModels(ctx context.Context) []ModelChoice { // buildCatalogChoices builds ModelChoice entries from the models.dev catalog, // filtered by supported providers and available credentials. 
func (r *LocalRuntime) buildCatalogChoices(ctx context.Context) []ModelChoice { - db, err := r.modelsStore.GetDatabase() + db, err := r.modelsStore.GetDatabase(ctx) if err != nil { slog.Debug("Failed to get models.dev database for catalog", "error", err) return nil @@ -446,7 +446,7 @@ func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest if cfg.MaxTokens != nil { opts = append(opts, options.WithMaxTokens(*cfg.MaxTokens)) } else if r.modelsStore != nil { - m, err := r.modelsStore.GetModel(cfg.Provider + "/" + cfg.Model) + m, err := r.modelsStore.GetModel(ctx, cfg.Provider+"/"+cfg.Model) if err == nil && m != nil { opts = append(opts, options.WithMaxTokens(m.Limit.Output)) } diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index a4c973007..ddded0ea4 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -27,7 +27,7 @@ type mockCatalogStore struct { db *modelsdev.Database } -func (m *mockCatalogStore) GetDatabase() (*modelsdev.Database, error) { +func (m *mockCatalogStore) GetDatabase(_ context.Context) (*modelsdev.Database, error) { return m.db, nil } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index dbfa9128d..dc792b297 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -171,8 +171,8 @@ type CurrentAgentInfo struct { } type ModelStore interface { - GetModel(modelID string) (*modelsdev.Model, error) - GetDatabase() (*modelsdev.Database, error) + GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) + GetDatabase(ctx context.Context) (*modelsdev.Database, error) } // RAGInitializer is implemented by runtimes that support background RAG initialization. 
@@ -988,7 +988,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c modelID := model.ID() slog.Debug("Using agent", "agent", a.Name(), "model", modelID) slog.Debug("Getting model definition", "model_id", modelID) - m, err := r.modelsStore.GetModel(modelID) + m, err := r.modelsStore.GetModel(ctx, modelID) if err != nil { slog.Debug("Failed to get model definition", "error", err) } diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index a6b3f7ddb..7779045b4 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -175,7 +175,7 @@ type mockModelStore struct { ModelStore } -func (m mockModelStore) GetModel(string) (*modelsdev.Model, error) { +func (m mockModelStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { return nil, nil } @@ -681,7 +681,7 @@ type mockModelStoreWithLimit struct { limit int } -func (m mockModelStoreWithLimit) GetModel(string) (*modelsdev.Model, error) { +func (m mockModelStoreWithLimit) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { return &modelsdev.Model{Limit: modelsdev.Limit{Context: m.limit}, Cost: &modelsdev.Cost{}}, nil } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index bcc97dcb9..8efb2846c 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -116,7 +116,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c if err != nil { slog.Debug("Failed to create modelsdev store for alias resolution", "error", err) } else { - config.ResolveModelAliases(cfg, modelsStore) + config.ResolveModelAliases(ctx, cfg, modelsStore) } // Apply model overrides from CLI flags before checking required env vars @@ -318,7 +318,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC if err != nil { return nil, false, err } - m, err := modelsStore.GetModel(modelCfg.Provider + "/" + modelCfg.Model) + m, err := modelsStore.GetModel(ctx, 
modelCfg.Provider+"/"+modelCfg.Model) if err == nil { maxTokens = &m.Limit.Output } @@ -381,7 +381,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates if err != nil { return nil, err } - m, err := modelsStore.GetModel(modelCfg.Provider + "/" + modelCfg.Model) + m, err := modelsStore.GetModel(ctx, modelCfg.Provider+"/"+modelCfg.Model) if err == nil { maxTokens = &m.Limit.Output } diff --git a/pkg/tui/commands/commands.go b/pkg/tui/commands/commands.go index 9834acccd..34a978aed 100644 --- a/pkg/tui/commands/commands.go +++ b/pkg/tui/commands/commands.go @@ -262,7 +262,7 @@ func BuildCommandCategories(ctx context.Context, application *app.App) []Categor // Check if the current model supports reasoning; hide /think if not currentModel := application.CurrentAgentModel() - if !modelsdev.ModelSupportsReasoning(currentModel) { + if !modelsdev.ModelSupportsReasoning(ctx, currentModel) { filtered := make([]Item, 0, len(sessionCommands)) for _, cmd := range sessionCommands { if cmd.ID != "session.think" { diff --git a/pkg/tui/components/sidebar/sidebar.go b/pkg/tui/components/sidebar/sidebar.go index 77cd0dbce..72f1b0617 100644 --- a/pkg/tui/components/sidebar/sidebar.go +++ b/pkg/tui/components/sidebar/sidebar.go @@ -1,6 +1,7 @@ package sidebar import ( + "context" "fmt" "log/slog" "maps" @@ -252,7 +253,9 @@ func (m *model) SetAgentInfo(agentName, modelID, description string) { m.currentAgent = agentName m.agentModel = modelID m.agentDescription = description - m.reasoningSupported = modelsdev.ModelSupportsReasoning(modelID) + // TODO: this can block for up to 30s on the first call if the cache is cold, + // which freezes the TUI. Move to an async command. + m.reasoningSupported = modelsdev.ModelSupportsReasoning(context.TODO(), modelID) // Update the provider and model in availableAgents for the current agent. // This is important when fallback models from different providers are used. 
diff --git a/pkg/tui/handlers.go b/pkg/tui/handlers.go index 0d050b43c..515688921 100644 --- a/pkg/tui/handlers.go +++ b/pkg/tui/handlers.go @@ -330,7 +330,9 @@ func (a *appModel) handleToggleYolo() (tea.Model, tea.Cmd) { func (a *appModel) handleToggleThinking() (tea.Model, tea.Cmd) { // Check if the current model supports reasoning currentModel := a.application.CurrentAgentModel() - if !modelsdev.ModelSupportsReasoning(currentModel) { + // TODO: this can block for up to 30s on the first call if the cache is cold, + // which freezes the TUI. Move to an async command. + if !modelsdev.ModelSupportsReasoning(context.TODO(), currentModel) { return a, notification.InfoCmd("Thinking/reasoning is not supported for the current model") }