From 793b12037328e9d79a1a61540ac4aa46f915def5 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 13 Feb 2026 17:32:22 +0100 Subject: [PATCH 1/5] Make it easier to mock model providers Signed-off-by: David Gageot --- pkg/config/model_alias.go | 8 +------- pkg/config/model_alias_test.go | 10 ++-------- pkg/modelsdev/store.go | 14 ++++++++------ pkg/modelsdev/store_test.go | 8 +------- pkg/teamloader/teamloader.go | 7 ++++++- 5 files changed, 18 insertions(+), 29 deletions(-) diff --git a/pkg/config/model_alias.go b/pkg/config/model_alias.go index 93baa9865..e1bdfd64e 100644 --- a/pkg/config/model_alias.go +++ b/pkg/config/model_alias.go @@ -17,13 +17,7 @@ import ( // either set directly on the model or inherited from a custom provider definition. // This is necessary because external providers (like Azure Foundry) may use the alias // names directly as deployment names rather than the pinned version names. -func ResolveModelAliases(ctx context.Context, cfg *latest.Config) { - store, err := modelsdev.NewStore() - if err != nil { - slog.Debug("Failed to create modelsdev store for alias resolution", "error", err) - return - } - +func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsdev.Store) { // Resolve model aliases in the models section for name, modelCfg := range cfg.Models { // Skip alias resolution for models with custom base_url (direct or via provider) diff --git a/pkg/config/model_alias_test.go b/pkg/config/model_alias_test.go index 944ff3810..8215f27f3 100644 --- a/pkg/config/model_alias_test.go +++ b/pkg/config/model_alias_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/modelsdev" @@ -27,12 +26,7 @@ func TestResolveModelAliases(t *testing.T) { }, } - store, err := modelsdev.NewStore(modelsdev.WithCacheDir(t.TempDir())) - require.NoError(t, err) - store.SetDatabaseForTesting(mockData) - t.Cleanup(func() { - store.SetDatabaseForTesting(nil) - }) + store := modelsdev.NewDatabaseStore(mockData) ctx := t.Context() @@ -245,7 +239,7 @@ func TestResolveModelAliases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ResolveModelAliases(ctx, tt.cfg) + ResolveModelAliases(ctx, tt.cfg, store) assert.Equal(t, tt.expected, tt.cfg) }) } diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index b8475d12f..b3f617262 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -272,12 +272,14 @@ func (s *Store) isCacheValid(cached *CachedData) bool { return time.Since(cached.LastRefresh) < s.refreshInterval } -// SetDatabaseForTesting sets the in-memory database cache for testing purposes. -// This method should only be used in tests. -func (s *Store) SetDatabaseForTesting(db *Database) { - s.dbCacheMu.Lock() - defer s.dbCacheMu.Unlock() - s.dbCache = db +// NewDatabaseStore creates a Store pre-populated with the given database. +// The returned store serves data entirely from memory and never fetches +// from the network or touches the filesystem, making it suitable for +// tests and any scenario where the provider data is already known. +func NewDatabaseStore(db *Database) *Store { + return &Store{ + dbCache: db, + } } // datePattern matches date suffixes like -20251101, -2024-11-20, etc. diff --git a/pkg/modelsdev/store_test.go b/pkg/modelsdev/store_test.go index 9fbb46355..c091b987c 100644 --- a/pkg/modelsdev/store_test.go +++ b/pkg/modelsdev/store_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestResolveModelAlias(t *testing.T) { @@ -36,12 +35,7 @@ func TestResolveModelAlias(t *testing.T) { }, } - store, err := NewStore(WithCacheDir(t.TempDir())) - require.NoError(t, err) - store.SetDatabaseForTesting(mockData) - t.Cleanup(func() { - store.SetDatabaseForTesting(nil) - }) + store := NewDatabaseStore(mockData) ctx := t.Context() diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 7828209ef..8efb2846c 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -112,7 +112,12 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c // Resolve model aliases (e.g., "claude-sonnet-4-5" -> "claude-sonnet-4-5-20250929") // This ensures the sidebar and other UI elements show the actual model being used. - config.ResolveModelAliases(ctx, cfg) + modelsStore, err := modelsdev.NewStore() + if err != nil { + slog.Debug("Failed to create modelsdev store for alias resolution", "error", err) + } else { + config.ResolveModelAliases(ctx, cfg, modelsStore) + } // Apply model overrides from CLI flags before checking required env vars if err := config.ApplyModelOverrides(cfg, loadOpts.modelOverrides); err != nil { From eadb6c26091ae88b6b0911d1f661ba034d317bc8 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 13 Feb 2026 17:43:37 +0100 Subject: [PATCH 2/5] Simpler Models.dev Store Signed-off-by: David Gageot --- pkg/modelsdev/store.go | 182 ++++++++++++++++------------------------- 1 file changed, 69 insertions(+), 113 deletions(-) diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index b3f617262..74595c7c9 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -18,57 +18,33 @@ import ( const ( ModelsDevAPIURL = "https://models.dev/api.json" CacheFileName = "models_dev.json" + refreshInterval = 24 * time.Hour ) // ModelAliases maps alias model IDs to their actual model IDs // TODO(krissetto): Add aliases here if needed, removed if unused var ModelAliases = map[string]string{} -// Store manages the models.dev data with local caching +// Store manages access to the models.dev data. +// The database is loaded lazily on first access and cached for the +// lifetime of the Store. All methods are safe for concurrent use. type Store struct { - cacheDir string - client *http.Client - refreshInterval time.Duration - - // In-memory cache for database to avoid repeated disk reads - dbCache *Database - dbCacheMu sync.RWMutex -} - -type Opt func(*Store) - -func WithRefreshInterval(refreshInterval time.Duration) Opt { - return func(s *Store) { - s.refreshInterval = refreshInterval - } + db func() (*Database, error) } -func WithCacheDir(cacheDir string) Opt { - return func(s *Store) { - s.cacheDir = cacheDir - } -} - -// defaultStore is a cached singleton store instance for repeated access -var defaultStore = sync.OnceValues(func() (*Store, error) { - return newStoreInternal() -}) +// defaultStore is a cached singleton store instance for repeated access. +var defaultStore = sync.OnceValues(newStoreInternal) // NewStore returns the cached default store instance. -// This is efficient for repeated calls as it reuses the same store. -// For custom configuration, use NewStoreWithOptions. -func NewStore(opts ...Opt) (*Store, error) { - if len(opts) > 0 { - return newStoreInternal(opts...) - } +// The underlying database is fetched lazily on first access +// from a local cache file or the models.dev API. +func NewStore() (*Store, error) { return defaultStore() } -// newStoreInternal creates a new models.dev store instance -func newStoreInternal(opts ...Opt) (*Store, error) { - s := &Store{ - refreshInterval: 24 * time.Hour, - } +// newStoreInternal creates a new models.dev store that loads data +// from the filesystem cache or the network on first access. +func newStoreInternal() (*Store, error) { homeDir, err := os.UserHomeDir() if err != nil { return nil, fmt.Errorf("failed to get user home directory: %w", err) @@ -78,70 +54,32 @@ func newStoreInternal(opts ...Opt) (*Store, error) { if err := os.MkdirAll(cacheDir, 0o755); err != nil { return nil, fmt.Errorf("failed to create cache directory: %w", err) } - s.cacheDir = cacheDir - for _, opt := range opts { - opt(s) - } - s.client = &http.Client{ - Timeout: 30 * time.Second, - } + cacheFile := filepath.Join(cacheDir, CacheFileName) - return s, nil + return &Store{ + db: sync.OnceValues(func() (*Database, error) { + return loadDatabase(cacheFile) + }), + }, nil } -// GetDatabase returns the models.dev database, fetching from cache or API as needed. -// Results are cached in memory to avoid repeated disk reads within the same process. -func (s *Store) GetDatabase(ctx context.Context) (*Database, error) { - // Check in-memory cache first - s.dbCacheMu.RLock() - if s.dbCache != nil { - db := s.dbCache - s.dbCacheMu.RUnlock() - return db, nil - } - s.dbCacheMu.RUnlock() - - // Need to load from disk or network - s.dbCacheMu.Lock() - defer s.dbCacheMu.Unlock() - - // Double-check after acquiring write lock - if s.dbCache != nil { - return s.dbCache, nil - } - - cacheFile := filepath.Join(s.cacheDir, CacheFileName) - - // Try to load from cache first - cached, err := s.loadFromCache(cacheFile) - if err == nil && s.isCacheValid(cached) { - s.dbCache = &cached.Database - return s.dbCache, nil - } - - // Cache is invalid or doesn't exist, fetch from API - database, err := s.fetchFromAPI(ctx) - if err != nil { - // If API fetch fails, but we have cached data, use it - if cached != nil { - s.dbCache = &cached.Database - return s.dbCache, nil - } - return nil, fmt.Errorf("failed to fetch from API and no cached data available: %w", err) - } - - // Save to cache - if err := s.saveToCache(cacheFile, database); err != nil { - // Log the error but don't fail the request - slog.Warn("Warning: failed to save to cache", "error", err) +// NewDatabaseStore creates a Store pre-populated with the given database. +// The returned store serves data entirely from memory and never fetches +// from the network or touches the filesystem, making it suitable for +// tests and any scenario where the provider data is already known. +func NewDatabaseStore(db *Database) *Store { + return &Store{ + db: func() (*Database, error) { return db, nil }, } +} - s.dbCache = database - return s.dbCache, nil +// GetDatabase returns the models.dev database, fetching from cache or API as needed. +func (s *Store) GetDatabase(ctx context.Context) (*Database, error) { + return s.db() } -// GetProvider returns a specific provider by ID +// GetProvider returns a specific provider by ID. func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) { db, err := s.GetDatabase(ctx) if err != nil { @@ -156,7 +94,7 @@ func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, return &provider, nil } -// GetModel returns a specific model by provider ID and model ID +// GetModel returns a specific model by provider ID and model ID. func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { // Check if the ID is an alias and resolve it if actualID, isAlias := ModelAliases[id]; isAlias { @@ -200,13 +138,45 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { return &model, nil } -func (s *Store) fetchFromAPI(ctx context.Context) (*Database, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, ModelsDevAPIURL, http.NoBody) +// loadDatabase loads the database from the local cache file or +// falls back to fetching from the models.dev API. +func loadDatabase(cacheFile string) (*Database, error) { + // Try to load from cache first + cached, err := loadFromCache(cacheFile) + if err == nil && time.Since(cached.LastRefresh) < refreshInterval { + return &cached.Database, nil + } + + // Cache is invalid or doesn't exist, fetch from API + database, fetchErr := fetchFromAPI() + if fetchErr != nil { + // If API fetch fails, but we have cached data, use it + if cached != nil { + return &cached.Database, nil + } + return nil, fmt.Errorf("failed to fetch from API and no cached data available: %w", fetchErr) + } + + // Save to cache + if err := saveToCache(cacheFile, database); err != nil { + // Log the error but don't fail the request + slog.Warn("Warning: failed to save to cache", "error", err) + } + + return database, nil +} + +func fetchFromAPI() (*Database, error) { + client := &http.Client{ + Timeout: 30 * time.Second, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ModelsDevAPIURL, http.NoBody) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - resp, err := s.client.Do(req) + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("failed to fetch from API: %w", err) } @@ -234,7 +204,7 @@ func (s *Store) fetchFromAPI(ctx context.Context) (*Database, error) { return database, nil } -func (s *Store) loadFromCache(cacheFile string) (*CachedData, error) { +func loadFromCache(cacheFile string) (*CachedData, error) { data, err := os.ReadFile(cacheFile) if err != nil { return nil, fmt.Errorf("failed to read cache file: %w", err) @@ -248,7 +218,7 @@ func (s *Store) loadFromCache(cacheFile string) (*CachedData, error) { return &cached, nil } -func (s *Store) saveToCache(cacheFile string, database *Database) error { +func saveToCache(cacheFile string, database *Database) error { now := time.Now() cached := CachedData{ Database: *database, @@ -268,20 +238,6 @@ func (s *Store) saveToCache(cacheFile string, database *Database) error { return nil } -func (s *Store) isCacheValid(cached *CachedData) bool { - return time.Since(cached.LastRefresh) < s.refreshInterval -} - -// NewDatabaseStore creates a Store pre-populated with the given database. -// The returned store serves data entirely from memory and never fetches -// from the network or touches the filesystem, making it suitable for -// tests and any scenario where the provider data is already known. -func NewDatabaseStore(db *Database) *Store { - return &Store{ - dbCache: db, - } -} - // datePattern matches date suffixes like -20251101, -2024-11-20, etc. var datePattern = regexp.MustCompile(`-\d{4}-?\d{2}-?\d{2}$`) From f9ec9114860d51eb3fee434ca67e8398cb9adb75 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 13 Feb 2026 17:47:24 +0100 Subject: [PATCH 3/5] Remove unused code Signed-off-by: David Gageot --- pkg/modelsdev/store.go | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index 74595c7c9..366589849 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -21,10 +21,6 @@ const ( refreshInterval = 24 * time.Hour ) -// ModelAliases maps alias model IDs to their actual model IDs -// TODO(krissetto): Add aliases here if needed, removed if unused -var ModelAliases = map[string]string{} - // Store manages access to the models.dev data. // The database is loaded lazily on first access and cached for the // lifetime of the Store. All methods are safe for concurrent use. @@ -96,11 +92,6 @@ func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, // GetModel returns a specific model by provider ID and model ID. func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { - // Check if the ID is an alias and resolve it - if actualID, isAlias := ModelAliases[id]; isAlias { - id = actualID - } - parts := strings.SplitN(id, "/", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid model ID: %q", id) @@ -250,15 +241,6 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str return modelName } - // Check if there's a manual alias mapping first - fullID := providerID + "/" + modelName - if resolved, ok := ModelAliases[fullID]; ok { - if _, m, ok := strings.Cut(resolved, "/"); ok { - return m - } - return resolved - } - // If the model already has a date suffix, it's already pinned if datePattern.MatchString(modelName) { return modelName From e3eb7133c93c7d9c1952afb9357ba492b93dda11 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 13 Feb 2026 17:56:13 +0100 Subject: [PATCH 4/5] Don't pass unused context Signed-off-by: David Gageot --- pkg/config/examples_test.go | 2 +- pkg/config/model_alias.go | 9 ++++----- pkg/config/model_alias_test.go | 4 +--- pkg/model/provider/bedrock/client.go | 6 +++--- pkg/model/provider/bedrock/client_test.go | 6 +++--- pkg/modelsdev/store.go | 18 +++++++++--------- pkg/modelsdev/store_test.go | 4 +--- pkg/rag/strategy/semantic_embeddings.go | 6 +++--- pkg/rag/strategy/vector_store.go | 8 ++++---- pkg/runtime/model_switcher.go | 2 +- pkg/runtime/model_switcher_test.go | 2 +- pkg/runtime/runtime.go | 4 ++-- pkg/runtime/runtime_test.go | 4 ++-- pkg/teamloader/teamloader.go | 6 +++--- pkg/tui/commands/commands.go | 2 +- pkg/tui/components/sidebar/sidebar.go | 3 +-- pkg/tui/handlers.go | 2 +- 17 files changed, 41 insertions(+), 47 deletions(-) diff --git a/pkg/config/examples_test.go b/pkg/config/examples_test.go index 7f1aeef3b..eebc4a9fb 100644 --- a/pkg/config/examples_test.go +++ b/pkg/config/examples_test.go @@ -70,7 +70,7 @@ func TestParseExamples(t *testing.T) { continue } - model, err := modelsStore.GetModel(t.Context(), model.Provider+"/"+model.Model) + model, err := modelsStore.GetModel(model.Provider + "/" + model.Model) require.NoError(t, err) require.NotNil(t, model) } diff --git a/pkg/config/model_alias.go b/pkg/config/model_alias.go index e1bdfd64e..e322bd86c 100644 --- a/pkg/config/model_alias.go +++ b/pkg/config/model_alias.go @@ -1,7 +1,6 @@ package config import ( - "context" "log/slog" "strings" @@ -17,7 +16,7 @@ import ( // either set directly on the model or inherited from a custom provider definition. // This is necessary because external providers (like Azure Foundry) may use the alias // names directly as deployment names rather than the pinned version names. -func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsdev.Store) { +func ResolveModelAliases(cfg *latest.Config, store *modelsdev.Store) { // Resolve model aliases in the models section for name, modelCfg := range cfg.Models { // Skip alias resolution for models with custom base_url (direct or via provider) @@ -28,7 +27,7 @@ func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsd continue } - if resolved := store.ResolveModelAlias(ctx, modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model { + if resolved := store.ResolveModelAlias(modelCfg.Provider, modelCfg.Model); resolved != modelCfg.Model { modelCfg.Model = resolved cfg.Models[name] = modelCfg } @@ -36,7 +35,7 @@ func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsd // Resolve model aliases in routing rules for i, rule := range modelCfg.Routing { if provider, model, ok := strings.Cut(rule.Model, "/"); ok { - if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model { + if resolved := store.ResolveModelAlias(provider, model); resolved != model { modelCfg.Routing[i].Model = provider + "/" + resolved } } @@ -53,7 +52,7 @@ func ResolveModelAliases(ctx context.Context, cfg *latest.Config, store *modelsd var resolvedModels []string for modelRef := range strings.SplitSeq(agent.Model, ",") { if provider, model, ok := strings.Cut(modelRef, "/"); ok { - if resolved := store.ResolveModelAlias(ctx, provider, model); resolved != model { + if resolved := store.ResolveModelAlias(provider, model); resolved != model { resolvedModels = append(resolvedModels, provider+"/"+resolved) continue } diff --git a/pkg/config/model_alias_test.go b/pkg/config/model_alias_test.go index 8215f27f3..e5c7fe3ee 100644 --- a/pkg/config/model_alias_test.go +++ b/pkg/config/model_alias_test.go @@ -28,8 +28,6 @@ func TestResolveModelAliases(t *testing.T) { store := modelsdev.NewDatabaseStore(mockData) - ctx := t.Context() - tests := []struct { name string cfg *latest.Config @@ -239,7 +237,7 @@ func TestResolveModelAliases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ResolveModelAliases(ctx, tt.cfg, store) + ResolveModelAliases(tt.cfg, store) assert.Equal(t, tt.expected, tt.cfg) }) } diff --git a/pkg/model/provider/bedrock/client.go b/pkg/model/provider/bedrock/client.go index e8b11287a..a02cf549d 100644 --- a/pkg/model/provider/bedrock/client.go +++ b/pkg/model/provider/bedrock/client.go @@ -112,7 +112,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // Detect prompt caching capability at init time for efficiency. // Uses models.dev cache pricing as proxy for capability detection. - cachingSupported := detectCachingSupport(ctx, cfg.Model) + cachingSupported := detectCachingSupport(cfg.Model) slog.Debug("Bedrock client created successfully", "model", cfg.Model, @@ -133,7 +133,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro // detectCachingSupport checks if a model supports prompt caching using models.dev data. // Models with non-zero CacheRead/CacheWrite costs support prompt caching. // Returns false on lookup failure (safe default for unsupported models). -func detectCachingSupport(ctx context.Context, model string) bool { +func detectCachingSupport(model string) bool { store, err := modelsdev.NewStore() if err != nil { slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err) @@ -141,7 +141,7 @@ func detectCachingSupport(ctx context.Context, model string) bool { } modelID := "amazon-bedrock/" + model - m, err := store.GetModel(ctx, modelID) + m, err := store.GetModel(modelID) if err != nil { slog.Debug("Bedrock prompt caching disabled: model not found in models.dev", "model_id", modelID, "error", err) diff --git a/pkg/model/provider/bedrock/client_test.go b/pkg/model/provider/bedrock/client_test.go index 6d02e2ee4..4e63ee13b 100644 --- a/pkg/model/provider/bedrock/client_test.go +++ b/pkg/model/provider/bedrock/client_test.go @@ -1249,7 +1249,7 @@ func TestDetectCachingSupport_SupportedModel(t *testing.T) { t.Parallel() // Uses real models.dev lookup to verify Claude models support caching - supported := detectCachingSupport(t.Context(), "anthropic.claude-3-5-sonnet-20241022-v2:0") + supported := detectCachingSupport("anthropic.claude-3-5-sonnet-20241022-v2:0") assert.True(t, supported) } @@ -1257,7 +1257,7 @@ func TestDetectCachingSupport_UnsupportedModel(t *testing.T) { t.Parallel() // Llama doesn't have cache pricing in models.dev - supported := detectCachingSupport(t.Context(), "meta.llama3-8b-instruct-v1:0") + supported := detectCachingSupport("meta.llama3-8b-instruct-v1:0") assert.False(t, supported) } @@ -1265,7 +1265,7 @@ func TestDetectCachingSupport_UnknownModel(t *testing.T) { t.Parallel() // Unknown model should gracefully return false, not panic - supported := detectCachingSupport(t.Context(), "nonexistent.model.that.does.not.exist:v1") + supported := detectCachingSupport("nonexistent.model.that.does.not.exist:v1") assert.False(t, supported) } diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index 366589849..8bf1104d2 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -71,13 +71,13 @@ func NewDatabaseStore(db *Database) *Store { } // GetDatabase returns the models.dev database, fetching from cache or API as needed. -func (s *Store) GetDatabase(ctx context.Context) (*Database, error) { +func (s *Store) GetDatabase() (*Database, error) { return s.db() } // GetProvider returns a specific provider by ID. -func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, error) { - db, err := s.GetDatabase(ctx) +func (s *Store) GetProvider(providerID string) (*Provider, error) { + db, err := s.GetDatabase() if err != nil { return nil, err } @@ -91,7 +91,7 @@ func (s *Store) GetProvider(ctx context.Context, providerID string) (*Provider, } // GetModel returns a specific model by provider ID and model ID. -func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { +func (s *Store) GetModel(id string) (*Model, error) { parts := strings.SplitN(id, "/", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid model ID: %q", id) @@ -99,7 +99,7 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) { providerID := parts[0] modelID := parts[1] - provider, err := s.GetProvider(ctx, providerID) + provider, err := s.GetProvider(providerID) if err != nil { return nil, err } @@ -236,7 +236,7 @@ var datePattern = regexp.MustCompile(`-\d{4}-?\d{2}-?\d{2}$`) // For example, ("anthropic", "claude-sonnet-4-5") might resolve to "claude-sonnet-4-5-20250929". // If the model is not an alias (already pinned or unknown), the original model name is returned. // This method uses the models.dev database to find the corresponding pinned version. -func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName string) string { +func (s *Store) ResolveModelAlias(providerID, modelName string) string { if providerID == "" || modelName == "" { return modelName } @@ -247,7 +247,7 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str } // Get the provider from the database - provider, err := s.GetProvider(ctx, providerID) + provider, err := s.GetProvider(providerID) if err != nil { return modelName } @@ -296,7 +296,7 @@ func isBedrockRegionPrefix(prefix string) bool { // - If modelID is empty or not in "provider/model" format, returns true (fail-open) // - If models.dev lookup fails for any reason, returns true (fail-open) // - If lookup succeeds, returns the model's Reasoning field value -func ModelSupportsReasoning(ctx context.Context, modelID string) bool { +func ModelSupportsReasoning(modelID string) bool { // Fail-open for empty model ID if modelID == "" { return true @@ -314,7 +314,7 @@ func ModelSupportsReasoning(ctx context.Context, modelID string) bool { return true } - model, err := store.GetModel(ctx, modelID) + model, err := store.GetModel(modelID) if err != nil { slog.Debug("Failed to lookup model in models.dev, assuming reasoning supported to allow user choice", "model_id", modelID, "error", err) return true diff --git a/pkg/modelsdev/store_test.go b/pkg/modelsdev/store_test.go index c091b987c..bfe43465a 100644 --- a/pkg/modelsdev/store_test.go +++ b/pkg/modelsdev/store_test.go @@ -37,8 +37,6 @@ func TestResolveModelAlias(t *testing.T) { store := NewDatabaseStore(mockData) - ctx := t.Context() - tests := []struct { name string provider string @@ -59,7 +57,7 @@ func TestResolveModelAlias(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := store.ResolveModelAlias(ctx, tt.provider, tt.model) + result := store.ResolveModelAlias(tt.provider, tt.model) assert.Equal(t, tt.expected, result) }) } diff --git a/pkg/rag/strategy/semantic_embeddings.go b/pkg/rag/strategy/semantic_embeddings.go index 03bd9287b..1d69913e2 100644 --- a/pkg/rag/strategy/semantic_embeddings.go +++ b/pkg/rag/strategy/semantic_embeddings.go @@ -173,7 +173,7 @@ func NewSemanticEmbeddingsFromConfig(ctx context.Context, cfg latest.RAGStrategy return } - cost := calculateSemanticUsageCost(ctx, embeddingCfg.ModelsStore, chatModelID, usage) + cost := calculateSemanticUsageCost(embeddingCfg.ModelsStore, chatModelID, usage) store.RecordUsage(totalTokens, cost) } @@ -501,12 +501,12 @@ func humanizeMetadataKey(key string) string { } // calculateSemanticUsageCost calculates cost for semantic LLM usage. -func calculateSemanticUsageCost(ctx context.Context, modelsStore modelStore, modelID string, usage *chat.Usage) float64 { +func calculateSemanticUsageCost(modelsStore modelStore, modelID string, usage *chat.Usage) float64 { if usage == nil || modelsStore == nil || modelID == "" || strings.HasPrefix(modelID, "dmr/") { return 0 } - model, err := modelsStore.GetModel(ctx, modelID) + model, err := modelsStore.GetModel(modelID) if err != nil { slog.Debug("Failed to get semantic model pricing from models.dev, cost will be 0", "model_id", modelID, diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index d4cbc777e..80afff926 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -87,7 +87,7 @@ type VectorStore struct { } type modelStore interface { - GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) + GetModel(modelID string) (*modelsdev.Model, error) } // EmbeddingInputBuilder builds the string that will be sent to the embedding model @@ -150,7 +150,7 @@ func NewVectorStore(cfg VectorStoreConfig) *VectorStore { // Set usage handler to calculate cost from models.dev and emit events with CUMULATIVE totals // This matches how chat completions calculate cost in runtime.go cfg.Embedder.SetUsageHandler(func(tokens int64, _ float64) { - cost := s.calculateCost(context.Background(), tokens) + cost := s.calculateCost(tokens) s.recordUsage(tokens, cost) }) @@ -169,12 +169,12 @@ func (s *VectorStore) SetEmbeddingInputBuilder(builder EmbeddingInputBuilder) { } // calculateCost calculates embedding cost using models.dev pricing -func (s *VectorStore) calculateCost(ctx context.Context, tokens int64) float64 { +func (s *VectorStore) calculateCost(tokens int64) float64 { if s.modelsStore == nil || strings.HasPrefix(s.modelID, "dmr/") { return 0 } - model, err := s.modelsStore.GetModel(ctx, s.modelID) + model, err := s.modelsStore.GetModel(s.modelID) if err != nil { slog.Debug("Failed to get model pricing from models.dev, cost will be 0", "model_id", s.modelID, diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index cb44969db..d31c1429f 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -460,7 +460,7 @@ func (r *LocalRuntime) createProviderFromConfig(ctx context.Context, cfg *latest if cfg.MaxTokens != nil { opts = append(opts, options.WithMaxTokens(*cfg.MaxTokens)) } else if r.modelsStore != nil { - m, err := r.modelsStore.GetModel(ctx, cfg.Provider+"/"+cfg.Model) + m, err := r.modelsStore.GetModel(cfg.Provider + "/" + cfg.Model) if err == nil && m != nil { opts = append(opts, options.WithMaxTokens(m.Limit.Output)) } diff --git a/pkg/runtime/model_switcher_test.go b/pkg/runtime/model_switcher_test.go index d376d338c..56e1a93fc 100644 --- a/pkg/runtime/model_switcher_test.go +++ b/pkg/runtime/model_switcher_test.go @@ -26,7 +26,7 @@ type mockCatalogStore struct { db *modelsdev.Database } -func (m *mockCatalogStore) GetModel(_ context.Context, _ string) (*modelsdev.Model, error) { +func (m *mockCatalogStore) GetModel(_ string) (*modelsdev.Model, error) { return nil, nil } diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index 299793d64..52d27f977 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -171,7 +171,7 @@ type CurrentAgentInfo struct { } type ModelStore interface { - GetModel(ctx context.Context, modelID string) (*modelsdev.Model, error) + GetModel(modelID string) (*modelsdev.Model, error) } // RAGInitializer is implemented by runtimes that support background RAG initialization. @@ -987,7 +987,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c modelID := model.ID() slog.Debug("Using agent", "agent", a.Name(), "model", modelID) slog.Debug("Getting model definition", "model_id", modelID) - m, err := r.modelsStore.GetModel(ctx, modelID) + m, err := r.modelsStore.GetModel(modelID) if err != nil { slog.Debug("Failed to get model definition", "error", err) } diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 344b987d3..947d8e47d 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -173,7 +173,7 @@ func (m *mockProviderWithError) MaxTokens() int { return 0 } type mockModelStore struct{} -func (m mockModelStore) GetModel(context.Context, string) (*modelsdev.Model, error) { +func (m mockModelStore) GetModel(string) (*modelsdev.Model, error) { return nil, nil } @@ -675,7 +675,7 @@ func (p *queueProvider) MaxTokens() int { return 0 } type mockModelStoreWithLimit struct{ limit int } -func (m mockModelStoreWithLimit) GetModel(context.Context, string) (*modelsdev.Model, error) { +func (m mockModelStoreWithLimit) GetModel(string) (*modelsdev.Model, error) { return &modelsdev.Model{Limit: modelsdev.Limit{Context: m.limit}, Cost: &modelsdev.Cost{}}, nil } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 8efb2846c..bcc97dcb9 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -116,7 +116,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c if err != nil { slog.Debug("Failed to create modelsdev store for alias resolution", "error", err) } else { - config.ResolveModelAliases(ctx, cfg, modelsStore) + config.ResolveModelAliases(cfg, modelsStore) } // Apply model overrides from CLI flags before checking required env vars @@ -318,7 +318,7 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC if err != nil { return nil, false, err } - m, err := modelsStore.GetModel(ctx, modelCfg.Provider+"/"+modelCfg.Model) + m, err := modelsStore.GetModel(modelCfg.Provider + "/" + modelCfg.Model) if err == nil { maxTokens = &m.Limit.Output } @@ -381,7 +381,7 @@ func getFallbackModelsForAgent(ctx context.Context, cfg *latest.Config, a *lates if err != nil { return nil, err } - m, err := modelsStore.GetModel(ctx, modelCfg.Provider+"/"+modelCfg.Model) + m, err := modelsStore.GetModel(modelCfg.Provider + "/" + modelCfg.Model) if err == nil { maxTokens = &m.Limit.Output } diff --git a/pkg/tui/commands/commands.go b/pkg/tui/commands/commands.go index 34a978aed..9834acccd 100644 --- a/pkg/tui/commands/commands.go +++ b/pkg/tui/commands/commands.go @@ -262,7 +262,7 @@ func BuildCommandCategories(ctx context.Context, application *app.App) []Categor // Check if the current model supports reasoning; hide /think if not currentModel := application.CurrentAgentModel() - if !modelsdev.ModelSupportsReasoning(ctx, currentModel) { + if !modelsdev.ModelSupportsReasoning(currentModel) { filtered := make([]Item, 0, len(sessionCommands)) for _, cmd := range sessionCommands { if cmd.ID != "session.think" { diff --git a/pkg/tui/components/sidebar/sidebar.go b/pkg/tui/components/sidebar/sidebar.go index 338beb5aa..77cd0dbce 100644 --- a/pkg/tui/components/sidebar/sidebar.go +++ b/pkg/tui/components/sidebar/sidebar.go @@ -1,7 +1,6 @@ package sidebar import ( - "context" "fmt" "log/slog" "maps" @@ -253,7 +252,7 @@ func (m *model) SetAgentInfo(agentName, modelID, description string) { m.currentAgent = agentName m.agentModel = modelID m.agentDescription = description - m.reasoningSupported = modelsdev.ModelSupportsReasoning(context.Background(), modelID) + m.reasoningSupported = modelsdev.ModelSupportsReasoning(modelID) // Update the provider and model in availableAgents for the current agent. // This is important when fallback models from different providers are used. diff --git a/pkg/tui/handlers.go b/pkg/tui/handlers.go index 941f289ea..0d050b43c 100644 --- a/pkg/tui/handlers.go +++ b/pkg/tui/handlers.go @@ -330,7 +330,7 @@ func (a *appModel) handleToggleYolo() (tea.Model, tea.Cmd) { func (a *appModel) handleToggleThinking() (tea.Model, tea.Cmd) { // Check if the current model supports reasoning currentModel := a.application.CurrentAgentModel() - if !modelsdev.ModelSupportsReasoning(context.Background(), currentModel) { + if !modelsdev.ModelSupportsReasoning(currentModel) { return a, notification.InfoCmd("Thinking/reasoning is not supported for the current model") } From 6efaee319b6c322e7791c5a81e72e69fb69bd4cf Mon Sep 17 00:00:00 2001 From: David Gageot Date: Fri, 13 Feb 2026 17:58:50 +0100 Subject: [PATCH 5/5] Faster fetch Signed-off-by: David Gageot --- pkg/modelsdev/store.go | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/pkg/modelsdev/store.go b/pkg/modelsdev/store.go index 8bf1104d2..ccbf68619 100644 --- a/pkg/modelsdev/store.go +++ b/pkg/modelsdev/store.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "io" "log/slog" "net/http" "os" @@ -158,16 +157,12 @@ func loadDatabase(cacheFile string) (*Database, error) { } func fetchFromAPI() (*Database, error) { - client := &http.Client{ - Timeout: 30 * time.Second, - } - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, ModelsDevAPIURL, http.NoBody) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - resp, err := client.Do(req) + resp, err := (&http.Client{Timeout: 30 * time.Second}).Do(req) if err != nil { return nil, fmt.Errorf("failed to fetch from API: %w", err) } @@ -177,33 +172,27 @@ func fetchFromAPI() (*Database, error) { return nil, fmt.Errorf("API returned status %d", resp.StatusCode) } - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - var providers map[string]Provider - if err := json.Unmarshal(body, &providers); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) + if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) } - database := &Database{ + return &Database{ Providers: providers, UpdatedAt: time.Now(), - } - - return database, nil + }, nil } func loadFromCache(cacheFile string) (*CachedData, error) { - data, err := os.ReadFile(cacheFile) + f, err := os.Open(cacheFile) if err != nil { - return nil, fmt.Errorf("failed to read cache file: %w", err) + return nil, fmt.Errorf("failed to open cache file: %w", err) } + defer f.Close() var cached CachedData - if err := json.Unmarshal(data, &cached); err != nil { - return nil, fmt.Errorf("failed to unmarshal cached data: %w", err) + if err := json.NewDecoder(f).Decode(&cached); err != nil { + return nil, fmt.Errorf("failed to decode cached data: %w", err) } return &cached, nil