diff --git a/internal/server/routed.go b/internal/server/routed.go
index fd24a144..7549b5d2 100644
--- a/internal/server/routed.go
+++ b/internal/server/routed.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"log"
 	"net/http"
+	"strings"
 	"sync"
 	"time"
 
@@ -18,6 +19,15 @@ import (
 
 var logRouted = logger.New("server:routed")
 
+func truncateCacheKeyForLog(key string) string {
+	backendID, sessionID, found := strings.Cut(key, "/")
+	if !found {
+		return key
+	}
+
+	return fmt.Sprintf("%s/%s", backendID, auth.TruncateSessionID(sessionID))
+}
+
 // rejectIfShutdown is a middleware that rejects requests with HTTP 503 when gateway is shutting down
 // Per spec 5.1.3: "Immediately reject any new RPC requests to /mcp/{server-name} endpoints with HTTP 503"
 // The logNamespace parameter is used to create a logger for debug output specific to the call site.
@@ -74,7 +84,7 @@ func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator f
 	// Lazy eviction of expired entries
 	for k, entry := range c.servers {
 		if now.Sub(entry.lastUsed) > c.ttl {
-			logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", auth.TruncateSessionID(k), now.Sub(entry.lastUsed).Round(time.Second))
+			logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", truncateCacheKeyForLog(k), now.Sub(entry.lastUsed).Round(time.Second))
 			delete(c.servers, k)
 		}
 	}
@@ -84,12 +94,24 @@ func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator f
 		return entry.server
 	}
 
-	// Safety bound: if at capacity after TTL eviction, log a warning but do not
-	// evict non-expired entries. Routed mode relies on reusing the same filtered
-	// server instance for a given (backend, session), and evicting an active entry
-	// would recreate that server mid-session, breaking StreamableHTTP semantics.
+	// When at capacity after TTL eviction, evict the least-recently-used entry
+	// to bound memory growth reliably. This may interrupt an active session for
+	// the evicted (backend, session) pair, but is preferable to unbounded growth.
 	if len(c.servers) >= c.maxSize {
-		logRouted.Printf("[CACHE] Max size reached (%d), retaining active entries until TTL eviction", c.maxSize)
+		lruKey := ""
+		var lruTime time.Time
+		first := true
+		for k, entry := range c.servers {
+			if first || entry.lastUsed.Before(lruTime) {
+				lruKey = k
+				lruTime = entry.lastUsed
+				first = false
+			}
+		}
+		if lruKey != "" {
+			logRouted.Printf("[CACHE] Max size reached (%d), evicting LRU entry: key=%s (idle %s)", c.maxSize, truncateCacheKeyForLog(lruKey), now.Sub(lruTime).Round(time.Second))
+			delete(c.servers, lruKey)
+		}
 	}
 
 	logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, auth.TruncateSessionID(sessionID))
diff --git a/internal/server/routed_test.go b/internal/server/routed_test.go
index 131cd979..c7a7ad4f 100644
--- a/internal/server/routed_test.go
+++ b/internal/server/routed_test.go
@@ -557,8 +557,8 @@ func TestCreateFilteredServer_EdgeCases(t *testing.T) {
 	})
 }
 
-// TestFilteredServerCache_MaxSize verifies that the cache allows growth beyond maxSize
-// when all entries are still active (non-expired), to avoid disrupting sessions.
+// TestFilteredServerCache_MaxSize verifies that when the cache is at capacity, the
+// least-recently-used entry is evicted to make room for a new entry.
 func TestFilteredServerCache_MaxSize(t *testing.T) {
 	assert := assert.New(t)
 
@@ -582,21 +582,72 @@ func TestFilteredServerCache_MaxSize(t *testing.T) {
 	assert.NotNil(s3)
 	assert.Equal(3, len(cache.servers), "Cache should have 3 entries")
 
-	// Adding a fourth entry should be allowed (no LRU eviction of active sessions)
+	// Manually set lastUsed to ensure deterministic LRU ordering:
+	// session1 is least recently used, session3 is most recently used.
+	now := time.Now()
+	cache.servers["backend/session1"].lastUsed = now.Add(-3 * time.Millisecond)
+	cache.servers["backend/session2"].lastUsed = now.Add(-2 * time.Millisecond)
+	cache.servers["backend/session3"].lastUsed = now.Add(-1 * time.Millisecond)
+
+	// Adding a fourth entry should evict the LRU entry (session1) to stay within maxSize
 	s4 := cache.getOrCreate("backend", "session4", creator)
 	assert.Equal(4, callCount, "Should have created a 4th server")
 	assert.NotNil(s4)
-	assert.Equal(4, len(cache.servers), "Cache should grow beyond maxSize for active sessions")
+	assert.Equal(3, len(cache.servers), "Cache should maintain maxSize by evicting the LRU entry")
 
-	// All sessions should still be present
+	// session1 (LRU) should have been evicted
 	_, session1Exists := cache.servers["backend/session1"]
-	assert.True(session1Exists, "session1 should still be cached")
+	assert.False(session1Exists, "session1 (LRU) should have been evicted to make room")
+
+	// session2, session3, session4 should still be present
 	_, session2Exists := cache.servers["backend/session2"]
 	assert.True(session2Exists, "session2 should still be cached")
 	_, session3Exists := cache.servers["backend/session3"]
 	assert.True(session3Exists, "session3 should still be cached")
 	_, session4Exists := cache.servers["backend/session4"]
-	assert.True(session4Exists, "session4 should still be cached")
+	assert.True(session4Exists, "session4 should be cached")
+}
+
+// TestTruncateCacheKeyForLog verifies that cache keys are properly truncated for logging.
+func TestTruncateCacheKeyForLog(t *testing.T) {
+	tests := []struct {
+		name     string
+		key      string
+		expected string
+	}{
+		{
+			name:     "standard key with backendID/sessionID",
+			key:      "github/abc123def456ghi789",
+			expected: "github/abc123de...",
+		},
+		{
+			name:     "key without slash returns as-is",
+			key:      "nodelimiter",
+			expected: "nodelimiter",
+		},
+		{
+			name:     "empty key",
+			key:      "",
+			expected: "",
+		},
+		{
+			name:     "key with short session",
+			key:      "backend/ab",
+			expected: "backend/ab",
+		},
+		{
+			name:     "key with multiple slashes truncates after first",
+			key:      "backend/session/extra",
+			expected: "backend/session/...",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := truncateCacheKeyForLog(tt.key)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
 }
 
 // TestFilteredServerCache_TTLEviction verifies that expired entries are evicted.
diff --git a/internal/testutil/mcptest/validator.go b/internal/testutil/mcptest/validator.go
index b37f30f9..ed32b968 100644
--- a/internal/testutil/mcptest/validator.go
+++ b/internal/testutil/mcptest/validator.go
@@ -4,9 +4,14 @@ import (
 	"context"
 	"fmt"
 
+	"github.com/github/gh-aw-mcpg/internal/logger"
 	sdk "github.com/modelcontextprotocol/go-sdk/mcp"
 )
 
+var logValidator = logger.New("testutil:validator")
+
+const validatorPaginationMaxPages = 1000
+
 // ValidatorClient is a client for validating MCP servers
 type ValidatorClient struct {
 	client *sdk.Client
@@ -19,7 +24,9 @@ func NewValidatorClient(ctx context.Context, transport sdk.Transport) (*Validato
 	client := sdk.NewClient(&sdk.Implementation{
 		Name:    "mcp-validator",
 		Version: "1.0.0",
-	}, &sdk.ClientOptions{})
+	}, &sdk.ClientOptions{
+		Logger: logger.NewSlogLoggerWithHandler(logValidator),
+	})
 
 	session, err := client.Connect(ctx, transport, nil)
 	if err != nil {
@@ -33,22 +40,65 @@ func NewValidatorClient(ctx context.Context, transport sdk.Transport) (*Validato
 	}, nil
 }
 
-// ListTools retrieves the list of tools from the connected MCP server
+// paginate collects all pages from a paginated MCP list call.
+// fetch is called with a cursor (empty string for the first page) and returns the items,
+// the next cursor (empty when done), and any error.
+func paginate[T any](ctx context.Context, fetch func(ctx context.Context, cursor string) ([]T, string, error)) ([]T, error) {
+	var all []T
+	var cursor string
+	seenCursors := make(map[string]struct{})
+	pages := 0
+	for {
+		pages++
+		if pages > validatorPaginationMaxPages {
+			return nil, fmt.Errorf("exceeded maximum pagination limit of %d pages", validatorPaginationMaxPages)
+		}
+
+		items, nextCursor, err := fetch(ctx, cursor)
+		if err != nil {
+			return nil, err
+		}
+		all = append(all, items...)
+		if nextCursor == "" {
+			break
+		}
+		if _, ok := seenCursors[nextCursor]; ok {
+			return nil, fmt.Errorf("detected repeated pagination cursor %q", nextCursor)
+		}
+		seenCursors[nextCursor] = struct{}{}
+		cursor = nextCursor
+	}
+	return all, nil
+}
+
+// ListTools retrieves the list of tools from the connected MCP server, including all paginated results.
 func (v *ValidatorClient) ListTools() ([]*sdk.Tool, error) {
-	result, err := v.session.ListTools(v.ctx, &sdk.ListToolsParams{})
+	tools, err := paginate(v.ctx, func(ctx context.Context, cursor string) ([]*sdk.Tool, string, error) {
+		result, err := v.session.ListTools(ctx, &sdk.ListToolsParams{Cursor: cursor})
+		if err != nil {
+			return nil, "", err
+		}
+		return result.Tools, result.NextCursor, nil
+	})
 	if err != nil {
 		return nil, fmt.Errorf("list tools: %w", err)
 	}
-	return result.Tools, nil
+	return tools, nil
 }
 
-// ListResources retrieves the list of resources from the connected MCP server
+// ListResources retrieves the list of resources from the connected MCP server, including all paginated results.
 func (v *ValidatorClient) ListResources() ([]*sdk.Resource, error) {
-	result, err := v.session.ListResources(v.ctx, &sdk.ListResourcesParams{})
+	resources, err := paginate(v.ctx, func(ctx context.Context, cursor string) ([]*sdk.Resource, string, error) {
+		result, err := v.session.ListResources(ctx, &sdk.ListResourcesParams{Cursor: cursor})
+		if err != nil {
+			return nil, "", err
+		}
+		return result.Resources, result.NextCursor, nil
+	})
 	if err != nil {
 		return nil, fmt.Errorf("list resources: %w", err)
 	}
-	return resources, nil
+	return resources, nil
 }
 
 // CallTool calls a tool on the MCP server