diff --git a/internal/config/config_stdin.go b/internal/config/config_stdin.go index 3b15db83..cbc90b84 100644 --- a/internal/config/config_stdin.go +++ b/internal/config/config_stdin.go @@ -32,16 +32,16 @@ type StdinConfig struct { // StdinGatewayConfig represents gateway configuration in stdin JSON format. // Uses pointers for optional fields to distinguish between unset and zero values. type StdinGatewayConfig struct { - Port *int `json:"port,omitempty"` - APIKey string `json:"apiKey,omitempty"` - Domain string `json:"domain,omitempty"` - StartupTimeout *int `json:"startupTimeout,omitempty"` - ToolTimeout *int `json:"toolTimeout,omitempty"` + Port *int `json:"port,omitempty"` + APIKey string `json:"apiKey,omitempty"` + Domain string `json:"domain,omitempty"` + StartupTimeout *int `json:"startupTimeout,omitempty"` + ToolTimeout *int `json:"toolTimeout,omitempty"` KeepaliveInterval *int `json:"keepaliveInterval,omitempty"` PayloadDir string `json:"payloadDir,omitempty"` PayloadSizeThreshold *int `json:"payloadSizeThreshold,omitempty"` TrustedBots []string `json:"trustedBots,omitempty"` - OpenTelemetry *StdinOpenTelemetryConfig `json:"opentelemetry,omitempty"` + OpenTelemetry *StdinOpenTelemetryConfig `json:"opentelemetry,omitempty"` } // StdinOpenTelemetryConfig represents the OpenTelemetry configuration in stdin JSON format (spec §4.1.3.6). diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index d2b99407..c7bac2d6 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -544,7 +544,14 @@ type paginatedPage[T any] struct { NextCursor string } +// paginateAllMaxPages is the maximum number of pages that paginateAll will fetch. +// This guards against misbehaving or adversarial backends that return an unbounded +// sequence of pages, which would otherwise consume unbounded memory and time. +const paginateAllMaxPages = 100 + // paginateAll collects all items across paginated SDK list calls. 
+// It returns an error if the backend returns more than paginateAllMaxPages pages, +// protecting against runaway backends. func paginateAll[T any]( serverID string, itemKind string, @@ -559,7 +566,10 @@ func paginateAll[T any]( logConn.Printf("list%s: received page of %d %s from serverID=%s", itemKind, len(first.Items), itemKind, serverID) cursor := first.NextCursor - for cursor != "" { + for pageCount := 1; cursor != ""; pageCount++ { + if pageCount >= paginateAllMaxPages { + return nil, fmt.Errorf("list%s: backend serverID=%s returned more than %d pages; aborting to prevent unbounded memory growth", itemKind, serverID, paginateAllMaxPages) + } page, err := fetch(cursor) if err != nil { return nil, err diff --git a/internal/mcp/connection_test.go b/internal/mcp/connection_test.go index aa537bca..31067699 100644 --- a/internal/mcp/connection_test.go +++ b/internal/mcp/connection_test.go @@ -977,3 +977,44 @@ data: {"jsonrpc":"2.0","id":2,"result":{}} assert.JSONEq(t, originalBody, string(resp.Error.Data)) }) } + +// TestPaginateAll tests the paginateAll generic helper. 
+func TestPaginateAll(t *testing.T) { + t.Run("single page with no cursor returns all items", func(t *testing.T) { + items, err := paginateAll("server1", "tools", func(cursor string) (paginatedPage[string], error) { + return paginatedPage[string]{Items: []string{"a", "b", "c"}, NextCursor: ""}, nil + }) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c"}, items) + }) + + t.Run("multiple pages are collected", func(t *testing.T) { + pages := []paginatedPage[string]{ + {Items: []string{"a"}, NextCursor: "page2"}, + {Items: []string{"b"}, NextCursor: "page3"}, + {Items: []string{"c"}, NextCursor: ""}, + } + call := 0 + items, err := paginateAll("server1", "tools", func(cursor string) (paginatedPage[string], error) { + page := pages[call] + call++ + return page, nil + }) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c"}, items) + }) + + t.Run("exceeding max pages returns error", func(t *testing.T) { + // Each call returns a cursor so the loop never ends naturally. + callCount := 0 + _, err := paginateAll("server1", "tools", func(cursor string) (paginatedPage[string], error) { + callCount++ + return paginatedPage[string]{Items: []string{"x"}, NextCursor: "next"}, nil + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "more than") + assert.Contains(t, err.Error(), "pages") + // Must stop at the page limit, not run forever. + assert.Equal(t, paginateAllMaxPages, callCount) + }) +} diff --git a/internal/server/tool_registry.go b/internal/server/tool_registry.go index 423d7eee..76e45ab2 100644 --- a/internal/server/tool_registry.go +++ b/internal/server/tool_registry.go @@ -29,8 +29,25 @@ type launchResult struct { // InputSchema from backends that use different JSON Schema versions (e.g., draft-07) without // validation errors, which is critical for clients to understand tool parameters. 
// -// The handler's third parameter (pre-validated input) is passed as nil since argument -// unmarshaling is handled inside the handler itself. +// # Three-argument handler convention +// +// Throughout this package, tool handlers use a three-argument form: +// +// func(ctx context.Context, req *sdk.CallToolRequest, state interface{}) (*sdk.CallToolResult, interface{}, error) +// +// This differs from the SDK's native two-argument form. The extra parameters serve +// two internal purposes: +// - state interface{}: reserved for the jq middleware pipeline (currently always nil at +// the call site; middleware may propagate state between pre- and post-processing steps). +// - second return value interface{}: carries intermediate data for the DIFC write-sink +// logger so it can record the raw backend result alongside the final tool result. +// +// The wrapper in this function adapts the three-argument form back to the SDK's two-argument +// form when registering with the SDK server. +// +// NOTE: The Server.AddTool method (used here) skips JSON Schema validation whereas the +// sdk.AddTool function validates the schema. This distinction relies on internal SDK +// behaviour and must be re-verified on every SDK upgrade. 
func registerToolWithoutValidation(server *sdk.Server, tool *sdk.Tool, handler func(context.Context, *sdk.CallToolRequest, interface{}) (*sdk.CallToolResult, interface{}, error)) { server.AddTool(tool, func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) { result, _, err := handler(ctx, req, nil) diff --git a/internal/testutil/mcptest/config.go b/internal/testutil/mcptest/config.go index 82c902e6..049feaf9 100644 --- a/internal/testutil/mcptest/config.go +++ b/internal/testutil/mcptest/config.go @@ -1,6 +1,10 @@ package mcptest -import sdk "github.com/modelcontextprotocol/go-sdk/mcp" +import ( + "log/slog" + + sdk "github.com/modelcontextprotocol/go-sdk/mcp" +) // ServerConfig defines the configuration for a test MCP server type ServerConfig struct { @@ -12,6 +16,10 @@ type ServerConfig struct { Tools []ToolConfig // Resources is the list of resources to expose Resources []ResourceConfig + // Logger is an optional slog.Logger for SDK diagnostics. + // When nil, defaults to the project debug logger (namespace "testutil:mcptest"), + // which surfaces SDK protocol errors when DEBUG=testutil:* is set. 
+ Logger *slog.Logger } // ToolConfig defines a tool for the test server diff --git a/internal/testutil/mcptest/server.go b/internal/testutil/mcptest/server.go index b3e5c51c..268183a9 100644 --- a/internal/testutil/mcptest/server.go +++ b/internal/testutil/mcptest/server.go @@ -5,6 +5,7 @@ import ( "fmt" "log" + "github.com/github/gh-aw-mcpg/internal/logger" "github.com/github/gh-aw-mcpg/internal/mcp" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -36,7 +37,14 @@ func (s *Server) Start() error { Version: s.config.Version, } - s.server = sdk.NewServer(impl, &sdk.ServerOptions{}) + sdkLogger := s.config.Logger + if sdkLogger == nil { + sdkLogger = logger.NewSlogLoggerWithHandler(logger.New("testutil:mcptest")) + } + + s.server = sdk.NewServer(impl, &sdk.ServerOptions{ + Logger: sdkLogger, + }) // Register tools for i, toolCfg := range s.config.Tools {