diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 581e1619..eea9a2fe 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -324,7 +324,7 @@ func (r *restBackendCaller) CallTool(ctx context.Context, toolName string, args if pp, ok := argsMap["perPage"].(float64); ok { perPage = fmt.Sprintf("%d", int(pp)) } - apiPath = fmt.Sprintf("/search/repositories?q=%s&per_page=%s", query, perPage) + apiPath = fmt.Sprintf("/search/repositories?q=%s&per_page=%s", url.QueryEscape(query), perPage) case "get_collaborator_permission": owner, _ := argsMap["owner"].(string) diff --git a/internal/server/routed.go b/internal/server/routed.go index e95e4a04..fd24a144 100644 --- a/internal/server/routed.go +++ b/internal/server/routed.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/github/gh-aw-mcpg/internal/auth" "github.com/github/gh-aw-mcpg/internal/httputil" "github.com/github/gh-aw-mcpg/internal/logger" "github.com/github/gh-aw-mcpg/internal/version" @@ -31,12 +32,18 @@ func rejectIfShutdown(unifiedServer *UnifiedServer, next http.Handler, logNamesp }) } +// filteredServerCacheMaxSize is the maximum number of entries the filteredServerCache +// will hold. When the cache is full, the least-recently-used entry is evicted to make room. +const filteredServerCacheMaxSize = 1000 + // filteredServerCache caches filtered server instances per (backend, session) key. // Entries are evicted after the configured TTL to prevent unbounded memory growth -// in long-running deployments with many sessions. +// in long-running deployments with many sessions. A max-size cap provides an additional +// safety guard against an unbounded number of unique sessions. type filteredServerCache struct { servers map[string]*filteredServerEntry ttl time.Duration + maxSize int mu sync.RWMutex } @@ -50,11 +57,13 @@ func newFilteredServerCache(ttl time.Duration) *filteredServerCache { return &filteredServerCache{ servers: make(map[string]*filteredServerEntry), ttl: ttl, + maxSize: filteredServerCacheMaxSize, } } // getOrCreate returns a cached server or creates a new one. -// Expired entries are lazily evicted on each call. +// Expired entries are lazily evicted on each call. When the cache has reached its +// maximum size, the least-recently-used entry is evicted to make room. func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator func() *sdk.Server) *sdk.Server { key := fmt.Sprintf("%s/%s", backendID, sessionID) now := time.Now() @@ -65,7 +74,7 @@ func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator f // Lazy eviction of expired entries for k, entry := range c.servers { if now.Sub(entry.lastUsed) > c.ttl { - logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", k, now.Sub(entry.lastUsed).Round(time.Second)) + logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", auth.TruncateSessionID(k), now.Sub(entry.lastUsed).Round(time.Second)) delete(c.servers, k) } } @@ -75,7 +84,15 @@ func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator f return entry.server } - logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID) + // Safety bound: if at capacity after TTL eviction, log a warning but do not + // evict non-expired entries. Routed mode relies on reusing the same filtered + // server instance for a given (backend, session), and evicting an active entry + // would recreate that server mid-session, breaking StreamableHTTP semantics. + if len(c.servers) >= c.maxSize { + logRouted.Printf("[CACHE] Max size reached (%d), retaining active entries until TTL eviction", c.maxSize) + } + + logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, auth.TruncateSessionID(sessionID)) server := creator() c.servers[key] = &filteredServerEntry{server: server, lastUsed: now} return server @@ -172,22 +189,16 @@ func createFilteredServer(unifiedServer *UnifiedServer, backendID string) *sdk.S continue } - // Use Server.AddTool method (not sdk.AddTool function) to avoid schema validation - // This allows including InputSchema from backends using different JSON Schema versions - // Wrap the typed handler to match the simple ToolHandler signature - wrappedHandler := func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) { - // Call the unified server's handler directly - // This ensures we go through the same session and connection pool - log.Printf("[ROUTED] Calling unified handler for: %s", toolNameCopy) - result, _, err := handler(ctx, req, nil) - return result, err - } - - server.AddTool(&sdk.Tool{ + // Use registerToolWithoutValidation to bypass JSON Schema validation, allowing + // InputSchema from backends using different JSON Schema versions (e.g., draft-07). + registerToolWithoutValidation(server, &sdk.Tool{ Name: toolInfo.Name, // Without prefix for the client Description: toolInfo.Description, InputSchema: toolInfo.InputSchema, // Include schema for clients - }, wrappedHandler) + }, func(ctx context.Context, req *sdk.CallToolRequest, _ interface{}) (*sdk.CallToolResult, interface{}, error) { + log.Printf("[ROUTED] Calling unified handler for: %s", toolNameCopy) + return handler(ctx, req, nil) + }) } return server diff --git a/internal/server/routed_test.go b/internal/server/routed_test.go index 8f61dac7..131cd979 100644 --- a/internal/server/routed_test.go +++ b/internal/server/routed_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -556,6 +557,121 @@ func TestCreateFilteredServer_EdgeCases(t *testing.T) { }) } +// TestFilteredServerCache_MaxSize verifies that the cache allows growth beyond maxSize +// when all entries are still active (non-expired), to avoid disrupting sessions. +func TestFilteredServerCache_MaxSize(t *testing.T) { + assert := assert.New(t) + + ttl := time.Hour + cache := newFilteredServerCache(ttl) + cache.maxSize = 3 // Use a small max for the test + + callCount := 0 + creator := func() *sdk.Server { + callCount++ + return sdk.NewServer(&sdk.Implementation{Name: "test", Version: "1.0"}, &sdk.ServerOptions{}) + } + + // Fill the cache to max capacity + s1 := cache.getOrCreate("backend", "session1", creator) + s2 := cache.getOrCreate("backend", "session2", creator) + s3 := cache.getOrCreate("backend", "session3", creator) + assert.Equal(3, callCount, "Should have created 3 servers") + assert.NotNil(s1) + assert.NotNil(s2) + assert.NotNil(s3) + assert.Equal(3, len(cache.servers), "Cache should have 3 entries") + + // Adding a fourth entry should be allowed (no LRU eviction of active sessions) + s4 := cache.getOrCreate("backend", "session4", creator) + assert.Equal(4, callCount, "Should have created a 4th server") + assert.NotNil(s4) + assert.Equal(4, len(cache.servers), "Cache should grow beyond maxSize for active sessions") + + // All sessions should still be present + _, session1Exists := cache.servers["backend/session1"] + assert.True(session1Exists, "session1 should still be cached") + _, session2Exists := cache.servers["backend/session2"] + assert.True(session2Exists, "session2 should still be cached") + _, session3Exists := cache.servers["backend/session3"] + assert.True(session3Exists, "session3 should still be cached") + _, session4Exists := cache.servers["backend/session4"] + assert.True(session4Exists, "session4 should still be cached") +} + +// TestFilteredServerCache_TTLEviction verifies that expired entries are evicted. +func TestFilteredServerCache_TTLEviction(t *testing.T) { + assert := assert.New(t) + + ttl := 100 * time.Millisecond + cache := newFilteredServerCache(ttl) + + callCount := 0 + creator := func() *sdk.Server { + callCount++ + return sdk.NewServer(&sdk.Implementation{Name: "test", Version: "1.0"}, &sdk.ServerOptions{}) + } + + // Add an entry + cache.getOrCreate("backend", "session1", creator) + assert.Equal(1, callCount) + assert.Equal(1, len(cache.servers)) + + // Wait for TTL to expire (use generous margin to avoid CI flakiness) + time.Sleep(200 * time.Millisecond) + + // Next call should evict the expired entry and create a new one + cache.getOrCreate("backend", "session2", creator) + assert.Equal(2, callCount, "Should have created a new server after TTL eviction") + + // session1 should have been evicted during the lazy eviction scan + _, session1Exists := cache.servers["backend/session1"] + assert.False(session1Exists, "Expired session1 should have been evicted") +} + +// TestRegisterToolWithoutValidation verifies that tools are registered on the server +// and that the wrapped handler forwards calls correctly via in-memory transport. +func TestRegisterToolWithoutValidation(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + server := sdk.NewServer(&sdk.Implementation{Name: "test", Version: "1.0"}, &sdk.ServerOptions{}) + + var handlerCalled bool + handler := func(ctx context.Context, req *sdk.CallToolRequest, state interface{}) (*sdk.CallToolResult, interface{}, error) { + handlerCalled = true + return &sdk.CallToolResult{IsError: false}, nil, nil + } + + registerToolWithoutValidation(server, &sdk.Tool{ + Name: "test_tool", + Description: "A test tool", + InputSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + }, + }, handler) + + // Use in-memory transports to connect a client to the server and invoke the tool + serverTransport, clientTransport := sdk.NewInMemoryTransports() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go func() { + _ = server.Run(ctx, serverTransport) + }() + + client := sdk.NewClient(&sdk.Implementation{Name: "test-client", Version: "1.0"}, &sdk.ClientOptions{}) + clientSession, err := client.Connect(ctx, clientTransport, nil) + require.NoError(err) + defer clientSession.Close() + + result, err := clientSession.CallTool(ctx, &sdk.CallToolParams{Name: "test_tool"}) + require.NoError(err) + assert.False(result.IsError) + assert.True(handlerCalled, "Handler should have been called") +} + // TestCreateHTTPServerForRoutedMode_OAuth tests OAuth discovery endpoint in routed mode func TestCreateHTTPServerForRoutedMode_OAuth(t *testing.T) { tests := []struct { diff --git a/internal/server/tool_registry.go b/internal/server/tool_registry.go index 3ba14eda..423d7eee 100644 --- a/internal/server/tool_registry.go +++ b/internal/server/tool_registry.go @@ -24,6 +24,20 @@ type launchResult struct { duration time.Duration } +// registerToolWithoutValidation registers a tool with the SDK server using the Server.AddTool +// method (not the sdk.AddTool function) to bypass JSON Schema validation. This allows including +// InputSchema from backends that use different JSON Schema versions (e.g., draft-07) without +// validation errors, which is critical for clients to understand tool parameters. +// +// The handler's third parameter (pre-validated input) is passed as nil since argument +// unmarshaling is handled inside the handler itself. +func registerToolWithoutValidation(server *sdk.Server, tool *sdk.Tool, handler func(context.Context, *sdk.CallToolRequest, interface{}) (*sdk.CallToolResult, interface{}, error)) { + server.AddTool(tool, func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) { + result, _, err := handler(ctx, req, nil) + return result, err + }) +} + // registerAllTools fetches and registers tools from all backend servers func (us *UnifiedServer) registerAllTools() error { log.Println("Registering tools from all backends...") @@ -235,28 +249,12 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error { us.tools[prefixedName].Handler = finalHandler us.toolsMu.Unlock() - // Register the tool with the SDK using the Server.AddTool method (not sdk.AddTool function) - // The method version does NOT perform schema validation, allowing us to include - // InputSchema from backends that use different JSON Schema versions (e.g., draft-07) - // without validation errors. This is critical for clients to understand tool parameters. - // - // We need to wrap our typed handler to match the simpler ToolHandler signature. - // The typed handler signature: func(context.Context, *CallToolRequest, interface{}) (*CallToolResult, interface{}, error) - // The simple handler signature: func(context.Context, *CallToolRequest) (*CallToolResult, error) - wrappedHandler := func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) { - // Call the final handler (which may include middleware wrapping) - // The third parameter would be the pre-unmarshaled/validated input if using sdk.AddTool, - // but we handle unmarshaling ourselves in the handler, so we pass nil - result, _, err := finalHandler(ctx, req, nil) - return result, err - } - - us.server.AddTool(&sdk.Tool{ + registerToolWithoutValidation(us.server, &sdk.Tool{ Name: prefixedName, Description: toolDesc, InputSchema: normalizedSchema, // Include the schema for clients to understand parameters Annotations: tool.Annotations, - }, wrappedHandler) + }, finalHandler) log.Printf("Registered tool: %s", logName) } diff --git a/internal/testutil/mcptest/server.go b/internal/testutil/mcptest/server.go index a8491918..b3e5c51c 100644 --- a/internal/testutil/mcptest/server.go +++ b/internal/testutil/mcptest/server.go @@ -2,10 +2,10 @@ package mcptest import ( "context" - "encoding/json" "fmt" "log" + "github.com/github/gh-aw-mcpg/internal/mcp" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -48,18 +48,16 @@ func (s *Server) Start() error { Description: tool.Description, InputSchema: tool.InputSchema, }, func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) { - var args map[string]interface{} - if len(req.Params.Arguments) > 0 { - if err := json.Unmarshal(req.Params.Arguments, &args); err != nil { - return &sdk.CallToolResult{ - IsError: true, - Content: []sdk.Content{ - &sdk.TextContent{ - Text: fmt.Sprintf("Failed to parse arguments: %v", err), - }, + args, err := mcp.ParseToolArguments(req) + if err != nil { + return &sdk.CallToolResult{ + IsError: true, + Content: []sdk.Content{ + &sdk.TextContent{ + Text: fmt.Sprintf("Failed to parse tool arguments: %v", err), }, - }, nil - } + }, + }, nil } content, err := tool.Handler(args)