diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index f0acab5d..0c398ede 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -536,31 +536,56 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in return marshalToResponse(result) } -func (c *Connection) listTools() (*Response, error) { - if err := c.requireSession(); err != nil { - return nil, err - } - logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID) - // Fetch first page to determine initial capacity - first, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{}) +// paginatedPage holds a single page of results from a paginated SDK list call. +type paginatedPage[T any] struct { + Items []T + NextCursor string +} + +// paginateAll collects all items across paginated SDK list calls. +func paginateAll[T any]( + serverID string, + itemKind string, + fetch func(cursor string) (paginatedPage[T], error), +) ([]T, error) { + first, err := fetch("") if err != nil { return nil, err } - allTools := make([]*sdk.Tool, len(first.Tools), max(len(first.Tools), 1)) - copy(allTools, first.Tools) - logConn.Printf("listTools: received page of %d tools from serverID=%s", len(first.Tools), c.serverID) + all := make([]T, len(first.Items), max(len(first.Items), 1)) + copy(all, first.Items) + logConn.Printf("list%s: received page of %d %s from serverID=%s", itemKind, len(first.Items), itemKind, serverID) + cursor := first.NextCursor for cursor != "" { - result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor}) + page, err := fetch(cursor) if err != nil { return nil, err } - allTools = append(allTools, result.Tools...) - logConn.Printf("listTools: received page of %d tools (total so far: %d) from serverID=%s", len(result.Tools), len(allTools), c.serverID) - cursor = result.NextCursor + all = append(all, page.Items...) + logConn.Printf("list%s: received page of %d %s (total so far: %d) from serverID=%s", itemKind, len(page.Items), itemKind, len(all), serverID) + cursor = page.NextCursor } - logConn.Printf("listTools: received %d tools total from serverID=%s", len(allTools), c.serverID) - return marshalToResponse(&sdk.ListToolsResult{Tools: allTools}) + logConn.Printf("list%s: received %d %s total from serverID=%s", itemKind, len(all), itemKind, serverID) + return all, nil +} + +func (c *Connection) listTools() (*Response, error) { + if err := c.requireSession(); err != nil { + return nil, err + } + logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID) + tools, err := paginateAll(c.serverID, "tools", func(cursor string) (paginatedPage[*sdk.Tool], error) { + result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor}) + if err != nil { + return paginatedPage[*sdk.Tool]{}, err + } + return paginatedPage[*sdk.Tool]{Items: result.Tools, NextCursor: result.NextCursor}, nil + }) + if err != nil { + return nil, err + } + return marshalToResponse(&sdk.ListToolsResult{Tools: tools}) } func (c *Connection) callTool(params interface{}) (*Response, error) { @@ -583,26 +608,17 @@ func (c *Connection) listResources() (*Response, error) { return nil, err } logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID) - // Fetch first page to determine initial capacity - first, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{}) - if err != nil { - return nil, err - } - allResources := make([]*sdk.Resource, len(first.Resources), max(len(first.Resources), 1)) - copy(allResources, first.Resources) - logConn.Printf("listResources: received page of %d resources from serverID=%s", len(first.Resources), c.serverID) - cursor := first.NextCursor - for cursor != "" { + resources, err := paginateAll(c.serverID, "resources", func(cursor string) (paginatedPage[*sdk.Resource], error) { result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{Cursor: cursor}) if err != nil { - return nil, err + return paginatedPage[*sdk.Resource]{}, err } - allResources = append(allResources, result.Resources...) - logConn.Printf("listResources: received page of %d resources (total so far: %d) from serverID=%s", len(result.Resources), len(allResources), c.serverID) - cursor = result.NextCursor + return paginatedPage[*sdk.Resource]{Items: result.Resources, NextCursor: result.NextCursor}, nil + }) + if err != nil { + return nil, err } - logConn.Printf("listResources: received %d resources total from serverID=%s", len(allResources), c.serverID) - return marshalToResponse(&sdk.ListResourcesResult{Resources: allResources}) + return marshalToResponse(&sdk.ListResourcesResult{Resources: resources}) } func (c *Connection) readResource(params interface{}) (*Response, error) { @@ -622,26 +638,17 @@ func (c *Connection) listPrompts() (*Response, error) { return nil, err } logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID) - // Fetch first page to determine initial capacity - first, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{}) - if err != nil { - return nil, err - } - allPrompts := make([]*sdk.Prompt, len(first.Prompts), max(len(first.Prompts), 1)) - copy(allPrompts, first.Prompts) - logConn.Printf("listPrompts: received page of %d prompts from serverID=%s", len(first.Prompts), c.serverID) - cursor := first.NextCursor - for cursor != "" { + prompts, err := paginateAll(c.serverID, "prompts", func(cursor string) (paginatedPage[*sdk.Prompt], error) { result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{Cursor: cursor}) if err != nil { - return nil, err + return paginatedPage[*sdk.Prompt]{}, err } - allPrompts = append(allPrompts, result.Prompts...) - logConn.Printf("listPrompts: received page of %d prompts (total so far: %d) from serverID=%s", len(result.Prompts), len(allPrompts), c.serverID) - cursor = result.NextCursor + return paginatedPage[*sdk.Prompt]{Items: result.Prompts, NextCursor: result.NextCursor}, nil + }) + if err != nil { + return nil, err } - logConn.Printf("listPrompts: received %d prompts total from serverID=%s", len(allPrompts), c.serverID) - return marshalToResponse(&sdk.ListPromptsResult{Prompts: allPrompts}) + return marshalToResponse(&sdk.ListPromptsResult{Prompts: prompts}) } func (c *Connection) getPrompt(params interface{}) (*Response, error) { diff --git a/internal/mcp/http_transport.go b/internal/mcp/http_transport.go index 81803027..38a2638f 100644 --- a/internal/mcp/http_transport.go +++ b/internal/mcp/http_transport.go @@ -48,7 +48,9 @@ type httpRequestResult struct { Header http.Header } -// transportConnector is a function that creates an SDK transport for a given URL and HTTP client +// transportConnector is a function that creates an SDK transport for a given URL and HTTP client. +// The returned transport is owned by the SDK client session after Connect() succeeds; +// callers must not close it directly — it is cleaned up when the session is closed. type transportConnector func(url string, httpClient *http.Client) sdk.Transport // isHTTPConnectionError checks if an error is a network connection error. diff --git a/internal/mcp/tool_result.go b/internal/mcp/tool_result.go index 962a8025..9c612aae 100644 --- a/internal/mcp/tool_result.go +++ b/internal/mcp/tool_result.go @@ -8,15 +8,6 @@ import ( sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) -// resourceContents mirrors sdk.ResourceContents for JSON unmarshaling of -// embedded resource content items returned by backend MCP servers. -type resourceContents struct { - URI string `json:"uri"` - MIMEType string `json:"mimeType,omitempty"` - Text string `json:"text,omitempty"` - Blob []byte `json:"blob,omitempty"` -} - var logToolResult = logger.New("mcp:tool_result") // ConvertToCallToolResult converts backend result data to SDK CallToolResult format. @@ -70,11 +61,11 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) { // Parse the backend result structure (standard MCP CallToolResult format) var backendResult struct { Content []struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON) - MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type - Resource *resourceContents `json:"resource,omitempty"` // embedded resource + Type string `json:"type"` + Text string `json:"text,omitempty"` + Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON) + MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type + Resource *sdk.ResourceContents `json:"resource,omitempty"` // embedded resource } `json:"content"` IsError bool `json:"isError,omitempty"` } @@ -114,12 +105,7 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) { case "resource": if item.Resource != nil { content = append(content, &sdk.EmbeddedResource{ - Resource: &sdk.ResourceContents{ - URI: item.Resource.URI, - MIMEType: item.Resource.MIMEType, - Text: item.Resource.Text, - Blob: item.Resource.Blob, - }, + Resource: item.Resource, }) } else { logToolResult.Printf("Resource content item missing 'resource' field, skipping") diff --git a/internal/server/routed.go b/internal/server/routed.go index 86a4098c..d6008ee6 100644 --- a/internal/server/routed.go +++ b/internal/server/routed.go @@ -10,7 +10,6 @@ import ( "time" "github.com/github/gh-aw-mcpg/internal/logger" - "github.com/github/gh-aw-mcpg/internal/syncutil" "github.com/github/gh-aw-mcpg/internal/version" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -31,27 +30,53 @@ func rejectIfShutdown(unifiedServer *UnifiedServer, next http.Handler, logNamesp }) } -// filteredServerCache caches filtered server instances per (backend, session) key +// filteredServerCache caches filtered server instances per (backend, session) key. +// Entries are evicted after the configured TTL to prevent unbounded memory growth +// in long-running deployments with many sessions. type filteredServerCache struct { - servers map[string]*sdk.Server + servers map[string]*filteredServerEntry + ttl time.Duration mu sync.RWMutex } -// newFilteredServerCache creates a new server cache -func newFilteredServerCache() *filteredServerCache { +type filteredServerEntry struct { + server *sdk.Server + lastUsed time.Time +} + +// newFilteredServerCache creates a new server cache with the given entry TTL. +func newFilteredServerCache(ttl time.Duration) *filteredServerCache { return &filteredServerCache{ - servers: make(map[string]*sdk.Server), + servers: make(map[string]*filteredServerEntry), + ttl: ttl, } } -// getOrCreate returns a cached server or creates a new one +// getOrCreate returns a cached server or creates a new one. +// Expired entries are lazily evicted on each call. func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator func() *sdk.Server) *sdk.Server { key := fmt.Sprintf("%s/%s", backendID, sessionID) + now := time.Now() - server, _ := syncutil.GetOrCreate(&c.mu, c.servers, key, func() (*sdk.Server, error) { - logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID) - return creator(), nil - }) + c.mu.Lock() + defer c.mu.Unlock() + + // Lazy eviction of expired entries + for k, entry := range c.servers { + if now.Sub(entry.lastUsed) > c.ttl { + logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", k, now.Sub(entry.lastUsed).Round(time.Second)) + delete(c.servers, k) + } + } + + if entry, ok := c.servers[key]; ok { + entry.lastUsed = now + return entry.server + } + + logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID) + server := creator() + c.servers[key] = &filteredServerEntry{server: server, lastUsed: now} return server } @@ -71,8 +96,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap allBackends := unifiedServer.GetServerIDs() logRouted.Printf("Registering routes for %d backends: %v", len(allBackends), allBackends) - // Create server cache for session-aware server instances - serverCache := newFilteredServerCache() + // Create server cache for session-aware server instances. + // TTL matches the SDK SessionTimeout so cache entries expire with sessions. + routedSessionTimeout := 30 * time.Minute + serverCache := newFilteredServerCache(routedSessionTimeout) // Create a proxy for each backend server for _, serverID := range allBackends { @@ -95,7 +122,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap }, &sdk.StreamableHTTPOptions{ Stateless: false, Logger: logger.NewSlogLoggerWithHandler(logRouted), - SessionTimeout: 30 * time.Minute, + SessionTimeout: routedSessionTimeout, }) // Apply standard middleware stack (SDK logging → shutdown check → auth) diff --git a/internal/testutil/mcptest/server.go b/internal/testutil/mcptest/server.go index e7e17ad8..a8491918 100644 --- a/internal/testutil/mcptest/server.go +++ b/internal/testutil/mcptest/server.go @@ -36,7 +36,7 @@ func (s *Server) Start() error { Version: s.config.Version, } - s.server = sdk.NewServer(impl, nil) + s.server = sdk.NewServer(impl, &sdk.ServerOptions{}) // Register tools for i, toolCfg := range s.config.Tools {