Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ func (r *restBackendCaller) CallTool(ctx context.Context, toolName string, args
if pp, ok := argsMap["perPage"].(float64); ok {
perPage = fmt.Sprintf("%d", int(pp))
}
apiPath = fmt.Sprintf("/search/repositories?q=%s&per_page=%s", query, perPage)
apiPath = fmt.Sprintf("/search/repositories?q=%s&per_page=%s", url.QueryEscape(query), perPage)

case "get_collaborator_permission":
owner, _ := argsMap["owner"].(string)
Expand Down
45 changes: 28 additions & 17 deletions internal/server/routed.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync"
"time"

"github.com/github/gh-aw-mcpg/internal/auth"
"github.com/github/gh-aw-mcpg/internal/httputil"
"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/version"
Expand All @@ -31,12 +32,18 @@ func rejectIfShutdown(unifiedServer *UnifiedServer, next http.Handler, logNamesp
})
}

// filteredServerCacheMaxSize is the entry count at which filteredServerCache starts
// logging a capacity warning. It is a soft threshold, not a hard cap: getOrCreate never
// evicts entries that have not expired, so the cache may hold more entries than this
// until TTL eviction removes idle ones.
const filteredServerCacheMaxSize = 1000

// filteredServerCache caches filtered server instances per (backend, session) key.
// Entries are evicted after the configured TTL to prevent unbounded memory growth
// in long-running deployments with many sessions. maxSize is a soft threshold that
// only triggers a warning log; entries that have not expired are never evicted, so
// the number of cached entries can exceed it.
type filteredServerCache struct {
// servers maps "<backendID>/<sessionID>" keys to their cached entries.
servers map[string]*filteredServerEntry
// ttl is the idle duration after which an entry is evicted.
ttl time.Duration
// maxSize is the entry count at which a capacity warning is logged.
maxSize int
// NOTE(review): mu presumably guards servers; the locking code is not visible here — confirm.
mu sync.RWMutex
}

Expand All @@ -50,11 +57,13 @@ func newFilteredServerCache(ttl time.Duration) *filteredServerCache {
return &filteredServerCache{
servers: make(map[string]*filteredServerEntry),
ttl: ttl,
maxSize: filteredServerCacheMaxSize,
}
}

// getOrCreate returns a cached server or creates a new one.
// Expired entries are lazily evicted on each call.
// Expired entries are lazily evicted on each call. Reaching the maximum size only
// logs a warning; active (non-expired) entries are never evicted to make room.
func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator func() *sdk.Server) *sdk.Server {
key := fmt.Sprintf("%s/%s", backendID, sessionID)
now := time.Now()
Expand All @@ -65,7 +74,7 @@ func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator f
// Lazy eviction of expired entries
for k, entry := range c.servers {
if now.Sub(entry.lastUsed) > c.ttl {
logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", k, now.Sub(entry.lastUsed).Round(time.Second))
logRouted.Printf("[CACHE] Evicting expired server: key=%s (idle %s)", auth.TruncateSessionID(k), now.Sub(entry.lastUsed).Round(time.Second))
delete(c.servers, k)
}
}
Expand All @@ -75,7 +84,15 @@ func (c *filteredServerCache) getOrCreate(backendID, sessionID string, creator f
return entry.server
}

logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, sessionID)
// Safety bound: if at capacity after TTL eviction, log a warning but do not
// evict non-expired entries. Routed mode relies on reusing the same filtered
// server instance for a given (backend, session), and evicting an active entry
// would recreate that server mid-session, breaking StreamableHTTP semantics.
if len(c.servers) >= c.maxSize {
logRouted.Printf("[CACHE] Max size reached (%d), retaining active entries until TTL eviction", c.maxSize)
}

logRouted.Printf("[CACHE] Creating new filtered server: backend=%s, session=%s", backendID, auth.TruncateSessionID(sessionID))
server := creator()
c.servers[key] = &filteredServerEntry{server: server, lastUsed: now}
return server
Expand Down Expand Up @@ -172,22 +189,16 @@ func createFilteredServer(unifiedServer *UnifiedServer, backendID string) *sdk.S
continue
}

// Use Server.AddTool method (not sdk.AddTool function) to avoid schema validation
// This allows including InputSchema from backends using different JSON Schema versions
// Wrap the typed handler to match the simple ToolHandler signature
wrappedHandler := func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) {
// Call the unified server's handler directly
// This ensures we go through the same session and connection pool
log.Printf("[ROUTED] Calling unified handler for: %s", toolNameCopy)
result, _, err := handler(ctx, req, nil)
return result, err
}

server.AddTool(&sdk.Tool{
// Use registerToolWithoutValidation to bypass JSON Schema validation, allowing
// InputSchema from backends using different JSON Schema versions (e.g., draft-07).
registerToolWithoutValidation(server, &sdk.Tool{
Name: toolInfo.Name, // Without prefix for the client
Description: toolInfo.Description,
InputSchema: toolInfo.InputSchema, // Include schema for clients
}, wrappedHandler)
}, func(ctx context.Context, req *sdk.CallToolRequest, _ interface{}) (*sdk.CallToolResult, interface{}, error) {
log.Printf("[ROUTED] Calling unified handler for: %s", toolNameCopy)
return handler(ctx, req, nil)
})
}

return server
Expand Down
116 changes: 116 additions & 0 deletions internal/server/routed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -556,6 +557,121 @@ func TestCreateFilteredServer_EdgeCases(t *testing.T) {
})
}

// TestFilteredServerCache_MaxSize checks that maxSize behaves as a soft limit: while
// every cached entry is still within its TTL, a new session is added without evicting
// any existing one, so the cache grows past maxSize rather than disrupting sessions.
func TestFilteredServerCache_MaxSize(t *testing.T) {
	assert := assert.New(t)

	cache := newFilteredServerCache(time.Hour)
	cache.maxSize = 3 // Use a small max for the test

	// created counts how many times the cache had to build a fresh server.
	created := 0
	newServer := func() *sdk.Server {
		created++
		return sdk.NewServer(&sdk.Implementation{Name: "test", Version: "1.0"}, &sdk.ServerOptions{})
	}

	// Fill the cache to max capacity
	first := cache.getOrCreate("backend", "session1", newServer)
	second := cache.getOrCreate("backend", "session2", newServer)
	third := cache.getOrCreate("backend", "session3", newServer)
	assert.Equal(3, created, "Should have created 3 servers")
	assert.NotNil(first)
	assert.NotNil(second)
	assert.NotNil(third)
	assert.Equal(3, len(cache.servers), "Cache should have 3 entries")

	// Adding a fourth entry should be allowed (no LRU eviction of active sessions)
	fourth := cache.getOrCreate("backend", "session4", newServer)
	assert.Equal(4, created, "Should have created a 4th server")
	assert.NotNil(fourth)
	assert.Equal(4, len(cache.servers), "Cache should grow beyond maxSize for active sessions")

	// All sessions should still be present
	expected := []struct {
		key string
		msg string
	}{
		{"backend/session1", "session1 should still be cached"},
		{"backend/session2", "session2 should still be cached"},
		{"backend/session3", "session3 should still be cached"},
		{"backend/session4", "session4 should still be cached"},
	}
	for _, want := range expected {
		_, present := cache.servers[want.key]
		assert.True(present, want.msg)
	}
}

// TestFilteredServerCache_TTLEviction verifies that an entry idle for longer than the
// TTL is removed by the lazy eviction scan on the next getOrCreate call, and that the
// newly requested session is cached in its place.
func TestFilteredServerCache_TTLEviction(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	ttl := time.Hour
	cache := newFilteredServerCache(ttl)

	callCount := 0
	creator := func() *sdk.Server {
		callCount++
		return sdk.NewServer(&sdk.Implementation{Name: "test", Version: "1.0"}, &sdk.ServerOptions{})
	}

	// Add an entry
	cache.getOrCreate("backend", "session1", creator)
	assert.Equal(1, callCount)
	assert.Equal(1, len(cache.servers))

	// Backdate the entry rather than sleeping past the TTL: this keeps the test fast
	// and independent of scheduler timing, so it cannot flake on a loaded CI machine.
	entry, ok := cache.servers["backend/session1"]
	require.True(ok, "session1 should be cached before it expires")
	entry.lastUsed = time.Now().Add(-2 * ttl)

	// Next call should evict the expired entry and create a new one
	cache.getOrCreate("backend", "session2", creator)
	assert.Equal(2, callCount, "Should have created a new server after TTL eviction")

	// session1 should have been evicted during the lazy eviction scan
	_, session1Exists := cache.servers["backend/session1"]
	assert.False(session1Exists, "Expired session1 should have been evicted")

	// Only the freshly created session2 entry should remain.
	_, session2Exists := cache.servers["backend/session2"]
	assert.True(session2Exists, "session2 should be cached")
	assert.Equal(1, len(cache.servers), "Cache should hold only the non-expired entry")
}

// TestRegisterToolWithoutValidation registers a tool through
// registerToolWithoutValidation, then calls it from a client connected over in-memory
// transports and checks that the supplied handler ran and the result is not an error.
func TestRegisterToolWithoutValidation(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	srv := sdk.NewServer(&sdk.Implementation{Name: "test", Version: "1.0"}, &sdk.ServerOptions{})

	// invoked records whether the registered handler was reached.
	invoked := false
	toolHandler := func(ctx context.Context, req *sdk.CallToolRequest, state interface{}) (*sdk.CallToolResult, interface{}, error) {
		invoked = true
		return &sdk.CallToolResult{IsError: false}, nil, nil
	}

	tool := &sdk.Tool{
		Name:        "test_tool",
		Description: "A test tool",
		InputSchema: map[string]interface{}{
			"type":       "object",
			"properties": map[string]interface{}{},
		},
	}
	registerToolWithoutValidation(srv, tool, toolHandler)

	// Use in-memory transports to connect a client to the server and invoke the tool
	serverSide, clientSide := sdk.NewInMemoryTransports()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The server runs until ctx is cancelled; its exit error is irrelevant to this test.
	go func() {
		_ = srv.Run(ctx, serverSide)
	}()

	cli := sdk.NewClient(&sdk.Implementation{Name: "test-client", Version: "1.0"}, &sdk.ClientOptions{})
	session, err := cli.Connect(ctx, clientSide, nil)
	require.NoError(err)
	defer session.Close()

	callResult, err := session.CallTool(ctx, &sdk.CallToolParams{Name: "test_tool"})
	require.NoError(err)
	assert.False(callResult.IsError)
	assert.True(invoked, "Handler should have been called")
}

// TestCreateHTTPServerForRoutedMode_OAuth tests OAuth discovery endpoint in routed mode
func TestCreateHTTPServerForRoutedMode_OAuth(t *testing.T) {
tests := []struct {
Expand Down
34 changes: 16 additions & 18 deletions internal/server/tool_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ type launchResult struct {
duration time.Duration
}

// registerToolWithoutValidation adds tool to server through the Server.AddTool method
// rather than the sdk.AddTool function, so that JSON Schema validation is bypassed.
// Backends may publish an InputSchema written against a different JSON Schema version
// (e.g., draft-07); registering this way keeps that schema visible to clients without
// validation errors.
//
// handler uses the typed three-argument signature. It is always invoked with nil as its
// third argument (the pre-validated input), because argument unmarshaling is handled
// inside the handler itself, and its second return value is discarded.
func registerToolWithoutValidation(server *sdk.Server, tool *sdk.Tool, handler func(context.Context, *sdk.CallToolRequest, interface{}) (*sdk.CallToolResult, interface{}, error)) {
	// Adapt the typed handler to the two-result handler shape Server.AddTool expects.
	adapted := func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) {
		res, _, err := handler(ctx, req, nil)
		return res, err
	}
	server.AddTool(tool, adapted)
}

// registerAllTools fetches and registers tools from all backend servers
func (us *UnifiedServer) registerAllTools() error {
log.Println("Registering tools from all backends...")
Expand Down Expand Up @@ -235,28 +249,12 @@ func (us *UnifiedServer) registerToolsFromBackend(serverID string) error {
us.tools[prefixedName].Handler = finalHandler
us.toolsMu.Unlock()

// Register the tool with the SDK using the Server.AddTool method (not sdk.AddTool function)
// The method version does NOT perform schema validation, allowing us to include
// InputSchema from backends that use different JSON Schema versions (e.g., draft-07)
// without validation errors. This is critical for clients to understand tool parameters.
//
// We need to wrap our typed handler to match the simpler ToolHandler signature.
// The typed handler signature: func(context.Context, *CallToolRequest, interface{}) (*CallToolResult, interface{}, error)
// The simple handler signature: func(context.Context, *CallToolRequest) (*CallToolResult, error)
wrappedHandler := func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) {
// Call the final handler (which may include middleware wrapping)
// The third parameter would be the pre-unmarshaled/validated input if using sdk.AddTool,
// but we handle unmarshaling ourselves in the handler, so we pass nil
result, _, err := finalHandler(ctx, req, nil)
return result, err
}

us.server.AddTool(&sdk.Tool{
registerToolWithoutValidation(us.server, &sdk.Tool{
Name: prefixedName,
Description: toolDesc,
InputSchema: normalizedSchema, // Include the schema for clients to understand parameters
Annotations: tool.Annotations,
}, wrappedHandler)
}, finalHandler)

log.Printf("Registered tool: %s", logName)
}
Expand Down
22 changes: 10 additions & 12 deletions internal/testutil/mcptest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package mcptest

import (
"context"
"encoding/json"
"fmt"
"log"

"github.com/github/gh-aw-mcpg/internal/mcp"
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)

Expand Down Expand Up @@ -48,18 +48,16 @@ func (s *Server) Start() error {
Description: tool.Description,
InputSchema: tool.InputSchema,
}, func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) {
var args map[string]interface{}
if len(req.Params.Arguments) > 0 {
if err := json.Unmarshal(req.Params.Arguments, &args); err != nil {
return &sdk.CallToolResult{
IsError: true,
Content: []sdk.Content{
&sdk.TextContent{
Text: fmt.Sprintf("Failed to parse arguments: %v", err),
},
args, err := mcp.ParseToolArguments(req)
if err != nil {
return &sdk.CallToolResult{
IsError: true,
Content: []sdk.Content{
&sdk.TextContent{
Text: fmt.Sprintf("Failed to parse tool arguments: %v", err),
},
}, nil
}
},
}, nil
}

content, err := tool.Handler(args)
Expand Down
Loading