From 7e70592780882b466a4c752a2ae41b10a4536142 Mon Sep 17 00:00:00 2001 From: Landon Cox Date: Tue, 31 Mar 2026 17:32:25 -0700 Subject: [PATCH] refactor: go-sdk usage improvements from module review Address findings from Go Fan report on modelcontextprotocol/go-sdk: 1. Extract generic paginateAll() helper in connection.go to deduplicate identical cursor-loop pagination across listTools, listResources, and listPrompts (~45 lines of boilerplate removed). 2. Eliminate resourceContents intermediate type in tool_result.go by using sdk.ResourceContents directly for JSON unmarshaling, removing field-by-field copy in the resource content conversion. 3. Pass explicit &sdk.ServerOptions{} instead of nil in mcptest/server.go to guard against future SDK changes that might not accept nil options. 4. Add TTL-based eviction to filteredServerCache in routed.go to prevent unbounded memory growth. Cache entries now expire after the session timeout (30min), evicted lazily on each getOrCreate call. 5. Add transport ownership documentation to transportConnector type clarifying that the SDK session owns the transport after Connect(). Closes #2911 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- internal/mcp/connection.go | 103 +++++++++++++++------------- internal/mcp/http_transport.go | 4 +- internal/mcp/tool_result.go | 26 ++----- internal/server/routed.go | 55 +++++++++++---- internal/testutil/mcptest/server.go | 2 +- 5 files changed, 106 insertions(+), 84 deletions(-) diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index f0acab5d..0c398ede 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -536,31 +536,56 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in return marshalToResponse(result) } -func (c *Connection) listTools() (*Response, error) { - if err := c.requireSession(); err != nil { - return nil, err - } - logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID) - // Fetch first page to determine initial capacity - first, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{}) +// paginatedPage holds a single page of results from a paginated SDK list call. +type paginatedPage[T any] struct { + Items []T + NextCursor string +} + +// paginateAll collects all items across paginated SDK list calls. +func paginateAll[T any]( + serverID string, + itemKind string, + fetch func(cursor string) (paginatedPage[T], error), +) ([]T, error) { + first, err := fetch("") if err != nil { return nil, err } - allTools := make([]*sdk.Tool, len(first.Tools), max(len(first.Tools), 1)) - copy(allTools, first.Tools) - logConn.Printf("listTools: received page of %d tools from serverID=%s", len(first.Tools), c.serverID) + all := make([]T, len(first.Items), max(len(first.Items), 1)) + copy(all, first.Items) + logConn.Printf("list%s: received page of %d %s from serverID=%s", itemKind, len(first.Items), itemKind, serverID) + cursor := first.NextCursor for cursor != "" { - result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor}) + page, err := fetch(cursor) if err != nil { return nil, err } - allTools = append(allTools, result.Tools...) - logConn.Printf("listTools: received page of %d tools (total so far: %d) from serverID=%s", len(result.Tools), len(allTools), c.serverID) - cursor = result.NextCursor + all = append(all, page.Items...) + logConn.Printf("list%s: received page of %d %s (total so far: %d) from serverID=%s", itemKind, len(page.Items), itemKind, len(all), serverID) + cursor = page.NextCursor } - logConn.Printf("listTools: received %d tools total from serverID=%s", len(allTools), c.serverID) - return marshalToResponse(&sdk.ListToolsResult{Tools: allTools}) + logConn.Printf("list%s: received %d %s total from serverID=%s", itemKind, len(all), itemKind, serverID) + return all, nil +} + +func (c *Connection) listTools() (*Response, error) { + if err := c.requireSession(); err != nil { + return nil, err + } + logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID) + tools, err := paginateAll(c.serverID, "tools", func(cursor string) (paginatedPage[*sdk.Tool], error) { + result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor}) + if err != nil { + return paginatedPage[*sdk.Tool]{}, err + } + return paginatedPage[*sdk.Tool]{Items: result.Tools, NextCursor: result.NextCursor}, nil + }) + if err != nil { + return nil, err + } + return marshalToResponse(&sdk.ListToolsResult{Tools: tools}) } func (c *Connection) callTool(params interface{}) (*Response, error) { @@ -583,26 +608,17 @@ func (c *Connection) listResources() (*Response, error) { return nil, err } logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID) - // Fetch first page to determine initial capacity - first, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{}) - if err != nil { - return nil, err - } - allResources := make([]*sdk.Resource, len(first.Resources), max(len(first.Resources), 1)) - copy(allResources, first.Resources) - logConn.Printf("listResources: received page of %d resources from serverID=%s", len(first.Resources), c.serverID) - cursor := first.NextCursor - for cursor != "" { + resources, err := paginateAll(c.serverID, "resources", func(cursor string) (paginatedPage[*sdk.Resource], error) { result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{Cursor: cursor}) if err != nil { - return nil, err + return paginatedPage[*sdk.Resource]{}, err } - allResources = append(allResources, result.Resources...) - logConn.Printf("listResources: received page of %d resources (total so far: %d) from serverID=%s", len(result.Resources), len(allResources), c.serverID) - cursor = result.NextCursor + return paginatedPage[*sdk.Resource]{Items: result.Resources, NextCursor: result.NextCursor}, nil + }) + if err != nil { + return nil, err } - logConn.Printf("listResources: received %d resources total from serverID=%s", len(allResources), c.serverID) - return marshalToResponse(&sdk.ListResourcesResult{Resources: allResources}) + return marshalToResponse(&sdk.ListResourcesResult{Resources: resources}) } func (c *Connection) readResource(params interface{}) (*Response, error) { @@ -622,26 +638,17 @@ func (c *Connection) listPrompts() (*Response, error) { return nil, err } logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID) - // Fetch first page to determine initial capacity - first, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{}) - if err != nil { - return nil, err - } - allPrompts := make([]*sdk.Prompt, len(first.Prompts), max(len(first.Prompts), 1)) - copy(allPrompts, first.Prompts) - logConn.Printf("listPrompts: received page of %d prompts from serverID=%s", len(first.Prompts), c.serverID) - cursor := first.NextCursor - for cursor != "" { + prompts, err := paginateAll(c.serverID, "prompts", func(cursor string) (paginatedPage[*sdk.Prompt], error) { result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{Cursor: cursor}) if err != nil { - return nil, err + return paginatedPage[*sdk.Prompt]{}, err } - allPrompts = append(allPrompts, result.Prompts...) - logConn.Printf("listPrompts: received page of %d prompts (total so far: %d) from serverID=%s", len(result.Prompts), len(allPrompts), c.serverID) - cursor = result.NextCursor + return paginatedPage[*sdk.Prompt]{Items: result.Prompts, NextCursor: result.NextCursor}, nil + }) + if err != nil { + return nil, err } - logConn.Printf("listPrompts: received %d prompts total from serverID=%s", len(allPrompts), c.serverID) - return marshalToResponse(&sdk.ListPromptsResult{Prompts: allPrompts}) + return marshalToResponse(&sdk.ListPromptsResult{Prompts: prompts}) } func (c *Connection) getPrompt(params interface{}) (*Response, error) { diff --git a/internal/mcp/http_transport.go b/internal/mcp/http_transport.go index 81803027..38a2638f 100644 --- a/internal/mcp/http_transport.go +++ b/internal/mcp/http_transport.go @@ -48,7 +48,9 @@ type httpRequestResult struct { Header http.Header } -// transportConnector is a function that creates an SDK transport for a given URL and HTTP client +// transportConnector is a function that creates an SDK transport for a given URL and HTTP client. +// The returned transport is owned by the SDK client session after Connect() succeeds; +// callers must not close it directly — it is cleaned up when the session is closed. type transportConnector func(url string, httpClient *http.Client) sdk.Transport // isHTTPConnectionError checks if an error is a network connection error. diff --git a/internal/mcp/tool_result.go b/internal/mcp/tool_result.go index 962a8025..9c612aae 100644 --- a/internal/mcp/tool_result.go +++ b/internal/mcp/tool_result.go @@ -8,15 +8,6 @@ import ( sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) -// resourceContents mirrors sdk.ResourceContents for JSON unmarshaling of -// embedded resource content items returned by backend MCP servers. -type resourceContents struct { - URI string `json:"uri"` - MIMEType string `json:"mimeType,omitempty"` - Text string `json:"text,omitempty"` - Blob []byte `json:"blob,omitempty"` -} - var logToolResult = logger.New("mcp:tool_result") // ConvertToCallToolResult converts backend result data to SDK CallToolResult format. @@ -70,11 +61,11 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) { // Parse the backend result structure (standard MCP CallToolResult format) var backendResult struct { Content []struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON) - MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type - Resource *resourceContents `json:"resource,omitempty"` // embedded resource + Type string `json:"type"` + Text string `json:"text,omitempty"` + Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON) + MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type + Resource *sdk.ResourceContents `json:"resource,omitempty"` // embedded resource } `json:"content"` IsError bool `json:"isError,omitempty"` } @@ -114,12 +105,7 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) { case "resource": if item.Resource != nil { content = append(content, &sdk.EmbeddedResource{ - Resource: &sdk.ResourceContents{ - URI: item.Resource.URI, - MIMEType: item.Resource.MIMEType, - Text: item.Resource.Text, - Blob: item.Resource.Blob, - }, + Resource: item.Resource, }) } else { logToolResult.Printf("Resource content item missing 'resource' field, skipping") diff --git a/internal/server/routed.go b/internal/server/routed.go index 86a4098c..d6008ee6 100644 --- a/internal/server/routed.go +++ b/internal/server/routed.go @@ -10,7 +10,6 @@ import ( "time" "github.com/github/gh-aw-mcpg/internal/logger" - "github.com/github/gh-aw-mcpg/internal/syncutil" "github.com/github/gh-aw-mcpg/internal/version" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -31,27 +30,53 @@ func rejectIfShutdown(unifiedServer *UnifiedServer, next http.Handler, logNamesp }) } -// filteredServerCache caches filtered server instances per (backend, session) key +// filteredServerCache caches filtered server instances per (backend, session) key. +// Entries are evicted after the configured TTL to prevent unbounded memory growth +// in long-running deployments with many sessions. type filteredServerCache struct { - servers map[string]*sdk.Server + servers map[string]*filteredServerEntry + ttl time.Duration mu sync.RWMutex } -// newFilteredServerCache creates a new server cache -func newFilteredServerCache() *filteredServerCache { +type filteredServerEntry struct { + server *sdk.Server + lastUsed time.Time +} + +// newFilteredServerCache creates a new server cache with the given entry TTL. +func newFilteredServerCache(ttl time.Duration) *filteredServerCache { return &filteredServerCache{ - servers: make(map[string]*sdk.Server), + servers: make(map[string]*filteredServerEntry), + ttl: ttl, } } -// getOrCreate returns a cached server or creates a new one +// getOrCreate returns a cached server or creates a new one. +// Expired entries are lazily evicted on each call. func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator func() *sdk.Server) *sdk.Server { key := fmt.Sprintf("%s/%s", backendID, sessionID) + now := time.Now() - server, _ := syncutil.GetOrCreate(&c.mu, c.servers, key, func() (*sdk.Server, error) { - logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID) - return creator(), nil - }) + c.mu.Lock() + defer c.mu.Unlock() + + // Lazy eviction of expired entries + for k, entry := range c.servers { + if now.Sub(entry.lastUsed) > c.ttl { + logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", k, now.Sub(entry.lastUsed).Round(time.Second)) + delete(c.servers, k) + } + } + + if entry, ok := c.servers[key]; ok { + entry.lastUsed = now + return entry.server + } + + logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID) + server := creator() + c.servers[key] = &filteredServerEntry{server: server, lastUsed: now} return server } @@ -71,8 +96,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap allBackends := unifiedServer.GetServerIDs() logRouted.Printf("Registering routes for %d backends: %v", len(allBackends), allBackends) - // Create server cache for session-aware server instances - serverCache := newFilteredServerCache() + // Create server cache for session-aware server instances. + // TTL matches the SDK SessionTimeout so cache entries expire with sessions. + routedSessionTimeout := 30 * time.Minute + serverCache := newFilteredServerCache(routedSessionTimeout) // Create a proxy for each backend server for _, serverID := range allBackends { @@ -95,7 +122,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap }, &sdk.StreamableHTTPOptions{ Stateless: false, Logger: logger.NewSlogLoggerWithHandler(logRouted), - SessionTimeout: 30 * time.Minute, + SessionTimeout: routedSessionTimeout, }) // Apply standard middleware stack (SDK logging → shutdown check → auth) diff --git a/internal/testutil/mcptest/server.go b/internal/testutil/mcptest/server.go index e7e17ad8..a8491918 100644 --- a/internal/testutil/mcptest/server.go +++ b/internal/testutil/mcptest/server.go @@ -36,7 +36,7 @@ func (s *Server) Start() error { Version: s.config.Version, } - s.server = sdk.NewServer(impl, nil) + s.server = sdk.NewServer(impl, &sdk.ServerOptions{}) // Register tools for i, toolCfg := range s.config.Tools {