Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions internal/server/routed.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log"
"net/http"
"strings"
"sync"
"time"

Expand All @@ -18,6 +19,15 @@ import (

var logRouted = logger.New("server:routed")

func truncateCacheKeyForLog(key string) string {
backendID, sessionID, found := strings.Cut(key, "/")
if !found {
return key
}

return fmt.Sprintf("%s/%s", backendID, auth.TruncateSessionID(sessionID))
}

// rejectIfShutdown is a middleware that rejects requests with HTTP 503 when gateway is shutting down
// Per spec 5.1.3: "Immediately reject any new RPC requests to /mcp/{server-name} endpoints with HTTP 503"
// The logNamespace parameter is used to create a logger for debug output specific to the call site.
Expand Down Expand Up @@ -74,7 +84,7 @@ func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator f
// Lazy eviction of expired entries
for k, entry := range c.servers {
if now.Sub(entry.lastUsed) > c.ttl {
logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", auth.TruncateSessionID(k), now.Sub(entry.lastUsed).Round(time.Second))
logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", truncateCacheKeyForLog(k), now.Sub(entry.lastUsed).Round(time.Second))
delete(c.servers, k)
}
}
Expand All @@ -84,12 +94,24 @@ func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator f
return entry.server
}

// Safety bound: if at capacity after TTL eviction, log a warning but do not
// evict non-expired entries. Routed mode relies on reusing the same filtered
// server instance for a given (backend, session), and evicting an active entry
// would recreate that server mid-session, breaking StreamableHTTP semantics.
// When at capacity after TTL eviction, evict the least-recently-used entry
// to bound memory growth reliably. This may interrupt an active session for
// the evicted (backend, session) pair, but is preferable to unbounded growth.
if len(c.servers) >= c.maxSize {
logRouted.Printf("[CACHE] Max size reached (%d), retaining active entries until TTL eviction", c.maxSize)
lruKey := ""
var lruTime time.Time
first := true
for k, entry := range c.servers {
if first || entry.lastUsed.Before(lruTime) {
lruKey = k
lruTime = entry.lastUsed
first = false
}
}
if lruKey != "" {
Comment on lines +101 to +111
logRouted.Printf("[CACHE] Max size reached (%d), evicting LRU entry: key=%s (idle %s)", c.maxSize, truncateCacheKeyForLog(lruKey), now.Sub(lruTime).Round(time.Second))
delete(c.servers, lruKey)
}
}

logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, auth.TruncateSessionID(sessionID))
Expand Down
65 changes: 58 additions & 7 deletions internal/server/routed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,8 +557,8 @@ func TestCreateFilteredServer_EdgeCases(t *testing.T) {
})
}

// TestFilteredServerCache_MaxSize verifies that the cache allows growth beyond maxSize
// when all entries are still active (non-expired), to avoid disrupting sessions.
// TestFilteredServerCache_MaxSize verifies that when the cache is at capacity, the
// least-recently-used entry is evicted to make room for a new entry.
func TestFilteredServerCache_MaxSize(t *testing.T) {
assert := assert.New(t)

Expand All @@ -582,21 +582,72 @@ func TestFilteredServerCache_MaxSize(t *testing.T) {
assert.NotNil(s3)
assert.Equal(3, len(cache.servers), "Cache should have 3 entries")

// Adding a fourth entry should be allowed (no LRU eviction of active sessions)
// Manually set lastUsed to ensure deterministic LRU ordering:
// session1 is least recently used, session3 is most recently used.
now := time.Now()
cache.servers["backend/session1"].lastUsed = now.Add(-3 * time.Millisecond)
cache.servers["backend/session2"].lastUsed = now.Add(-2 * time.Millisecond)
cache.servers["backend/session3"].lastUsed = now.Add(-1 * time.Millisecond)

// Adding a fourth entry should evict the LRU entry (session1) to stay within maxSize
s4 := cache.getOrCreate("backend", "session4", creator)
assert.Equal(4, callCount, "Should have created a 4th server")
assert.NotNil(s4)
assert.Equal(4, len(cache.servers), "Cache should grow beyond maxSize for active sessions")
assert.Equal(3, len(cache.servers), "Cache should maintain maxSize by evicting the LRU entry")

// All sessions should still be present
// session1 (LRU) should have been evicted
_, session1Exists := cache.servers["backend/session1"]
assert.True(session1Exists, "session1 should still be cached")
assert.False(session1Exists, "session1 (LRU) should have been evicted to make room")

// session2, session3, session4 should still be present
_, session2Exists := cache.servers["backend/session2"]
assert.True(session2Exists, "session2 should still be cached")
_, session3Exists := cache.servers["backend/session3"]
assert.True(session3Exists, "session3 should still be cached")
_, session4Exists := cache.servers["backend/session4"]
assert.True(session4Exists, "session4 should still be cached")
assert.True(session4Exists, "session4 should be cached")
}

// TestTruncateCacheKeyForLog verifies that cache keys are properly truncated for logging.
func TestTruncateCacheKeyForLog(t *testing.T) {
tests := []struct {
name string
key string
expected string
}{
{
name: "standard key with backendID/sessionID",
key: "github/abc123def456ghi789",
expected: "github/abc123de...",
},
{
name: "key without slash returns as-is",
key: "nodelimiter",
expected: "nodelimiter",
},
{
name: "empty key",
key: "",
expected: "",
},
{
name: "key with short session",
key: "backend/ab",
expected: "backend/ab",
},
{
name: "key with multiple slashes truncates after first",
key: "backend/session/extra",
expected: "backend/session/...",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := truncateCacheKeyForLog(tt.key)
assert.Equal(t, tt.expected, result)
})
}
}

// TestFilteredServerCache_TTLEviction verifies that expired entries are evicted.
Expand Down
64 changes: 57 additions & 7 deletions internal/testutil/mcptest/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ import (
"context"
"fmt"

"github.com/github/gh-aw-mcpg/internal/logger"
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)

var logValidator = logger.New("testutil:validator")

const validatorPaginationMaxPages = 1000

// ValidatorClient is a client for validating MCP servers
type ValidatorClient struct {
client *sdk.Client
Expand All @@ -19,7 +24,9 @@ func NewValidatorClient(ctx context.Context, transport sdk.Transport) (*Validato
client := sdk.NewClient(&sdk.Implementation{
Name: "mcp-validator",
Version: "1.0.0",
}, &sdk.ClientOptions{})
}, &sdk.ClientOptions{
Logger: logger.NewSlogLoggerWithHandler(logValidator),
})

session, err := client.Connect(ctx, transport, nil)
if err != nil {
Expand All @@ -33,22 +40,65 @@ func NewValidatorClient(ctx context.Context, transport sdk.Transport) (*Validato
}, nil
}

// ListTools retrieves the list of tools from the connected MCP server
// paginate collects all pages from a paginated MCP list call.
// fetch is called with a cursor (empty string for the first page) and returns the items,
// the next cursor (empty when done), and any error.
func paginate[T any](ctx context.Context, fetch func(ctx context.Context, cursor string) ([]T, string, error)) ([]T, error) {
	var (
		collected []T
		cursor    string
	)
	// Track every non-empty cursor we have followed so a misbehaving server
	// that hands back an earlier cursor is detected instead of looping forever.
	visited := make(map[string]struct{})

	for page := 1; ; page++ {
		// Hard cap on pages as a second line of defense against runaway pagination.
		if page > validatorPaginationMaxPages {
			return nil, fmt.Errorf("exceeded maximum pagination limit of %d pages", validatorPaginationMaxPages)
		}

		items, next, err := fetch(ctx, cursor)
		if err != nil {
			return nil, err
		}
		collected = append(collected, items...)

		// An empty next-cursor signals the final page.
		if next == "" {
			return collected, nil
		}
		if _, dup := visited[next]; dup {
			return nil, fmt.Errorf("detected repeated pagination cursor %q", next)
		}
		visited[next] = struct{}{}
		cursor = next
	}
}

// ListTools retrieves the list of tools from the connected MCP server, including all paginated results.
//
// It follows the server's pagination cursors via paginate, so the returned
// slice contains every tool across all pages. Errors from any page are
// wrapped with a "list tools" context.
func (v *ValidatorClient) ListTools() ([]*sdk.Tool, error) {
	tools, err := paginate(v.ctx, func(ctx context.Context, cursor string) ([]*sdk.Tool, string, error) {
		result, err := v.session.ListTools(ctx, &sdk.ListToolsParams{Cursor: cursor})
		if err != nil {
			return nil, "", err
		}
		return result.Tools, result.NextCursor, nil
	})
	if err != nil {
		return nil, fmt.Errorf("list tools: %w", err)
	}
	return tools, nil
}

// ListResources retrieves the list of resources from the connected MCP server, including all paginated results.
//
// It follows the server's pagination cursors via paginate, so the returned
// slice contains every resource across all pages. Errors from any page are
// wrapped with a "list resources" context.
func (v *ValidatorClient) ListResources() ([]*sdk.Resource, error) {
	resources, err := paginate(v.ctx, func(ctx context.Context, cursor string) ([]*sdk.Resource, string, error) {
		result, err := v.session.ListResources(ctx, &sdk.ListResourcesParams{Cursor: cursor})
		if err != nil {
			return nil, "", err
		}
		return result.Resources, result.NextCursor, nil
	})
	if err != nil {
		return nil, fmt.Errorf("list resources: %w", err)
	}
	return resources, nil
}

// CallTool calls a tool on the MCP server
Expand Down
Loading