Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 55 additions & 48 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,31 +536,56 @@ func callParamMethod[P any](c *Connection, rawParams interface{}, fn func(P) (in
return marshalToResponse(result)
}

func (c *Connection) listTools() (*Response, error) {
if err := c.requireSession(); err != nil {
return nil, err
}
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
// Fetch first page to determine initial capacity
first, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{})
// paginatedPage holds a single page of results from a paginated SDK list call.
type paginatedPage[T any] struct {
Items []T
NextCursor string
}

// paginateAll collects all items across paginated SDK list calls.
func paginateAll[T any](
serverID string,
itemKind string,
fetch func(cursor string) (paginatedPage[T], error),
) ([]T, error) {
first, err := fetch("")
if err != nil {
return nil, err
}
allTools := make([]*sdk.Tool, len(first.Tools), max(len(first.Tools), 1))
copy(allTools, first.Tools)
logConn.Printf("listTools: received page of %d tools from serverID=%s", len(first.Tools), c.serverID)
all := make([]T, len(first.Items), max(len(first.Items), 1))
copy(all, first.Items)
logConn.Printf("list%s: received page of %d %s from serverID=%s", itemKind, len(first.Items), itemKind, serverID)

cursor := first.NextCursor
for cursor != "" {
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor})
page, err := fetch(cursor)
if err != nil {
return nil, err
}
allTools = append(allTools, result.Tools...)
logConn.Printf("listTools: received page of %d tools (total so far: %d) from serverID=%s", len(result.Tools), len(allTools), c.serverID)
cursor = result.NextCursor
all = append(all, page.Items...)
logConn.Printf("list%s: received page of %d %s (total so far: %d) from serverID=%s", itemKind, len(page.Items), itemKind, len(all), serverID)
cursor = page.NextCursor
}
logConn.Printf("listTools: received %d tools total from serverID=%s", len(allTools), c.serverID)
return marshalToResponse(&sdk.ListToolsResult{Tools: allTools})
logConn.Printf("list%s: received %d %s total from serverID=%s", itemKind, len(all), itemKind, serverID)
return all, nil
Comment on lines +555 to +570
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paginateAll() changes the log prefix from the prior per-method prefixes (e.g., "listTools:"/"listResources:") to "list%s:" with a lowercase plural kind (e.g., "listtools:"). This contradicts the PR description of preserving identical logging behavior and can break log filtering/alerts keyed on the existing prefixes. Consider passing an explicit log prefix ("listTools"/"listResources"/"listPrompts") into paginateAll, or logging through a callback supplied by the caller so the emitted messages remain unchanged.

Copilot uses AI. Check for mistakes.
}

func (c *Connection) listTools() (*Response, error) {
if err := c.requireSession(); err != nil {
return nil, err
}
logConn.Printf("listTools: requesting tool list from backend serverID=%s", c.serverID)
tools, err := paginateAll(c.serverID, "tools", func(cursor string) (paginatedPage[*sdk.Tool], error) {
result, err := c.getSDKSession().ListTools(c.ctx, &sdk.ListToolsParams{Cursor: cursor})
if err != nil {
return paginatedPage[*sdk.Tool]{}, err
}
return paginatedPage[*sdk.Tool]{Items: result.Tools, NextCursor: result.NextCursor}, nil
})
if err != nil {
return nil, err
}
return marshalToResponse(&sdk.ListToolsResult{Tools: tools})
}

func (c *Connection) callTool(params interface{}) (*Response, error) {
Expand All @@ -583,26 +608,17 @@ func (c *Connection) listResources() (*Response, error) {
return nil, err
}
logConn.Printf("listResources: requesting resource list from backend serverID=%s", c.serverID)
// Fetch first page to determine initial capacity
first, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{})
if err != nil {
return nil, err
}
allResources := make([]*sdk.Resource, len(first.Resources), max(len(first.Resources), 1))
copy(allResources, first.Resources)
logConn.Printf("listResources: received page of %d resources from serverID=%s", len(first.Resources), c.serverID)
cursor := first.NextCursor
for cursor != "" {
resources, err := paginateAll(c.serverID, "resources", func(cursor string) (paginatedPage[*sdk.Resource], error) {
result, err := c.getSDKSession().ListResources(c.ctx, &sdk.ListResourcesParams{Cursor: cursor})
if err != nil {
return nil, err
return paginatedPage[*sdk.Resource]{}, err
}
allResources = append(allResources, result.Resources...)
logConn.Printf("listResources: received page of %d resources (total so far: %d) from serverID=%s", len(result.Resources), len(allResources), c.serverID)
cursor = result.NextCursor
return paginatedPage[*sdk.Resource]{Items: result.Resources, NextCursor: result.NextCursor}, nil
})
if err != nil {
return nil, err
}
logConn.Printf("listResources: received %d resources total from serverID=%s", len(allResources), c.serverID)
return marshalToResponse(&sdk.ListResourcesResult{Resources: allResources})
return marshalToResponse(&sdk.ListResourcesResult{Resources: resources})
}

func (c *Connection) readResource(params interface{}) (*Response, error) {
Expand All @@ -622,26 +638,17 @@ func (c *Connection) listPrompts() (*Response, error) {
return nil, err
}
logConn.Printf("listPrompts: requesting prompt list from backend serverID=%s", c.serverID)
// Fetch first page to determine initial capacity
first, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{})
if err != nil {
return nil, err
}
allPrompts := make([]*sdk.Prompt, len(first.Prompts), max(len(first.Prompts), 1))
copy(allPrompts, first.Prompts)
logConn.Printf("listPrompts: received page of %d prompts from serverID=%s", len(first.Prompts), c.serverID)
cursor := first.NextCursor
for cursor != "" {
prompts, err := paginateAll(c.serverID, "prompts", func(cursor string) (paginatedPage[*sdk.Prompt], error) {
result, err := c.getSDKSession().ListPrompts(c.ctx, &sdk.ListPromptsParams{Cursor: cursor})
if err != nil {
return nil, err
return paginatedPage[*sdk.Prompt]{}, err
}
allPrompts = append(allPrompts, result.Prompts...)
logConn.Printf("listPrompts: received page of %d prompts (total so far: %d) from serverID=%s", len(result.Prompts), len(allPrompts), c.serverID)
cursor = result.NextCursor
return paginatedPage[*sdk.Prompt]{Items: result.Prompts, NextCursor: result.NextCursor}, nil
})
if err != nil {
return nil, err
}
logConn.Printf("listPrompts: received %d prompts total from serverID=%s", len(allPrompts), c.serverID)
return marshalToResponse(&sdk.ListPromptsResult{Prompts: allPrompts})
return marshalToResponse(&sdk.ListPromptsResult{Prompts: prompts})
}

func (c *Connection) getPrompt(params interface{}) (*Response, error) {
Expand Down
4 changes: 3 additions & 1 deletion internal/mcp/http_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ type httpRequestResult struct {
Header http.Header
}

// transportConnector is a function that creates an SDK transport for a given URL and HTTP client
// transportConnector is a function that creates an SDK transport for a given URL and HTTP client.
// The returned transport is owned by the SDK client session after Connect() succeeds;
// callers must not close it directly — it is cleaned up when the session is closed.
type transportConnector func(url string, httpClient *http.Client) sdk.Transport

// isHTTPConnectionError checks if an error is a network connection error.
Expand Down
26 changes: 6 additions & 20 deletions internal/mcp/tool_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@ import (
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)

// resourceContents mirrors sdk.ResourceContents for JSON unmarshaling of
// embedded resource content items returned by backend MCP servers.
type resourceContents struct {
URI string `json:"uri"`
MIMEType string `json:"mimeType,omitempty"`
Text string `json:"text,omitempty"`
Blob []byte `json:"blob,omitempty"`
}

var logToolResult = logger.New("mcp:tool_result")

// ConvertToCallToolResult converts backend result data to SDK CallToolResult format.
Expand Down Expand Up @@ -70,11 +61,11 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) {
// Parse the backend result structure (standard MCP CallToolResult format)
var backendResult struct {
Content []struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON)
MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type
Resource *resourceContents `json:"resource,omitempty"` // embedded resource
Type string `json:"type"`
Text string `json:"text,omitempty"`
Data []byte `json:"data,omitempty"` // image/audio binary data (automatically decoded from base64 JSON)
MIMEType string `json:"mimeType,omitempty"` // image/audio MIME type
Resource *sdk.ResourceContents `json:"resource,omitempty"` // embedded resource
} `json:"content"`
IsError bool `json:"isError,omitempty"`
}
Expand Down Expand Up @@ -114,12 +105,7 @@ func ConvertToCallToolResult(data interface{}) (*sdk.CallToolResult, error) {
case "resource":
if item.Resource != nil {
content = append(content, &sdk.EmbeddedResource{
Resource: &sdk.ResourceContents{
URI: item.Resource.URI,
MIMEType: item.Resource.MIMEType,
Text: item.Resource.Text,
Blob: item.Resource.Blob,
},
Resource: item.Resource,
})
} else {
logToolResult.Printf("Resource content item missing 'resource' field, skipping")
Expand Down
55 changes: 41 additions & 14 deletions internal/server/routed.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"time"

"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/syncutil"
"github.com/github/gh-aw-mcpg/internal/version"
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)
Expand All @@ -31,27 +30,53 @@ func rejectIfShutdown(unifiedServer *UnifiedServer, next http.Handler, logNamesp
})
}

// filteredServerCache caches filtered server instances per (backend, session) key
// filteredServerCache caches filtered server instances per (backend, session) key.
// Entries are evicted after the configured TTL to prevent unbounded memory growth
// in long-running deployments with many sessions.
type filteredServerCache struct {
servers map[string]*sdk.Server
servers map[string]*filteredServerEntry
ttl time.Duration
mu sync.RWMutex
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filteredServerCache declares mu sync.RWMutex but getOrCreate() always takes the write lock (Lock) and never uses RLock/RUnlock. If there’s no read-only fast path, consider switching this to sync.Mutex to better reflect usage and avoid confusion about intended concurrency behavior.

Suggested change
mu sync.RWMutex
mu sync.Mutex

Copilot uses AI. Check for mistakes.
}

// newFilteredServerCache creates a new server cache
func newFilteredServerCache() *filteredServerCache {
type filteredServerEntry struct {
server *sdk.Server
lastUsed time.Time
}

// newFilteredServerCache creates a new server cache with the given entry TTL.
func newFilteredServerCache(ttl time.Duration) *filteredServerCache {
return &filteredServerCache{
servers: make(map[string]*sdk.Server),
servers: make(map[string]*filteredServerEntry),
ttl: ttl,
}
}

// getOrCreate returns a cached server or creates a new one
// getOrCreate returns a cached server or creates a new one.
// Expired entries are lazily evicted on each call.
func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator func() *sdk.Server) *sdk.Server {
key := fmt.Sprintf("%s/%s", backendID, sessionID)
now := time.Now()

server, _ := syncutil.GetOrCreate(&c.mu, c.servers, key, func() (*sdk.Server, error) {
logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID)
return creator(), nil
})
c.mu.Lock()
defer c.mu.Unlock()

// Lazy eviction of expired entries
for k, entry := range c.servers {
if now.Sub(entry.lastUsed) > c.ttl {
logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", k, now.Sub(entry.lastUsed).Round(time.Second))
delete(c.servers, k)
}
}
Comment on lines +61 to +70
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getOrCreate() now scans and potentially logs/evicts every cache entry on every call (the for k, entry := range c.servers loop). Since this path runs per request, the O(n) eviction can become a CPU and lock-contention hotspot as the cache grows. Consider reducing eviction frequency (e.g., only evict when now passes a nextEvictAt, evict a bounded number of entries per call, or run periodic cleanup in a background goroutine).

Copilot uses AI. Check for mistakes.

if entry, ok := c.servers[key]; ok {
entry.lastUsed = now
return entry.server
}

logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID)
server := creator()
c.servers[key] = &filteredServerEntry{server: server, lastUsed: now}
Comment on lines +65 to +79
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filteredServerCache logs raw session identifiers ("session=%s" and the full "key=%s" which includes sessionID). In this codebase session IDs are API keys and should be truncated before logging to avoid secret leakage (see internal/auth/header.go TruncateSessionID and existing usage in internal/server/session.go). Update these log lines to use auth.TruncateSessionID(sessionID) (and avoid embedding the full sessionID inside the logged key).

Copilot uses AI. Check for mistakes.
return server
Comment on lines +55 to 80
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new TTL-based eviction behavior in filteredServerCache (lastUsed tracking + lazy eviction) isn’t covered by tests. Since internal/server has routed_test.go, it would be good to add a focused unit test for getOrCreate() verifying that entries older than ttl are evicted and recreated, while recently-used entries are retained (using a controllable clock or very short ttl).

Copilot uses AI. Check for mistakes.
}

Expand All @@ -71,8 +96,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
allBackends := unifiedServer.GetServerIDs()
logRouted.Printf("Registering routes for %d backends: %v", len(allBackends), allBackends)

// Create server cache for session-aware server instances
serverCache := newFilteredServerCache()
// Create server cache for session-aware server instances.
// TTL matches the SDK SessionTimeout so cache entries expire with sessions.
routedSessionTimeout := 30 * time.Minute
serverCache := newFilteredServerCache(routedSessionTimeout)

// Create a proxy for each backend server
for _, serverID := range allBackends {
Expand All @@ -95,7 +122,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
}, &sdk.StreamableHTTPOptions{
Stateless: false,
Logger: logger.NewSlogLoggerWithHandler(logRouted),
SessionTimeout: 30 * time.Minute,
SessionTimeout: routedSessionTimeout,
})

// Apply standard middleware stack (SDK logging → shutdown check → auth)
Expand Down
2 changes: 1 addition & 1 deletion internal/testutil/mcptest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (s *Server) Start() error {
Version: s.config.Version,
}

s.server = sdk.NewServer(impl, nil)
s.server = sdk.NewServer(impl, &sdk.ServerOptions{})

// Register tools
for i, toolCfg := range s.config.Tools {
Expand Down
Loading