Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions internal/config/config_stdin.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ type StdinConfig struct {
// StdinGatewayConfig represents gateway configuration in stdin JSON format.
// Uses pointers for optional fields to distinguish between unset and zero values.
type StdinGatewayConfig struct {
Port *int `json:"port,omitempty"`
APIKey string `json:"apiKey,omitempty"`
Domain string `json:"domain,omitempty"`
StartupTimeout *int `json:"startupTimeout,omitempty"`
ToolTimeout *int `json:"toolTimeout,omitempty"`
Port *int `json:"port,omitempty"`
APIKey string `json:"apiKey,omitempty"`
Domain string `json:"domain,omitempty"`
StartupTimeout *int `json:"startupTimeout,omitempty"`
ToolTimeout *int `json:"toolTimeout,omitempty"`
KeepaliveInterval *int `json:"keepaliveInterval,omitempty"`
PayloadDir string `json:"payloadDir,omitempty"`
PayloadSizeThreshold *int `json:"payloadSizeThreshold,omitempty"`
TrustedBots []string `json:"trustedBots,omitempty"`
OpenTelemetry *StdinOpenTelemetryConfig `json:"opentelemetry,omitempty"`
OpenTelemetry *StdinOpenTelemetryConfig `json:"opentelemetry,omitempty"`
}

// StdinOpenTelemetryConfig represents the OpenTelemetry configuration in stdin JSON format (spec §4.1.3.6).
Expand Down
12 changes: 11 additions & 1 deletion internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,14 @@ type paginatedPage[T any] struct {
NextCursor string
}

// paginateAllMaxPages is the maximum number of pages that paginateAll will fetch.
// This guards against misbehaving or adversarial backends that return an unbounded
// sequence of pages, which would otherwise consume unbounded memory and time.
const paginateAllMaxPages = 100

// paginateAll collects all items across paginated SDK list calls.
// It returns an error if the backend returns more than paginateAllMaxPages pages,
// protecting against runaway backends.
func paginateAll[T any](
serverID string,
itemKind string,
Expand All @@ -559,7 +566,10 @@ func paginateAll[T any](
logConn.Printf("list%s: received page of %d %s from serverID=%s", itemKind, len(first.Items), itemKind, serverID)

cursor := first.NextCursor
for cursor != "" {
for pageCount := 1; cursor != ""; pageCount++ {
if pageCount >= paginateAllMaxPages {
return nil, fmt.Errorf("list%s: backend serverID=%s returned more than %d pages; aborting to prevent unbounded memory growth", itemKind, serverID, paginateAllMaxPages)
}
page, err := fetch(cursor)
if err != nil {
return nil, err
Expand Down
41 changes: 41 additions & 0 deletions internal/mcp/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -977,3 +977,44 @@ data: {"jsonrpc":"2.0","id":2,"result":{}}
assert.JSONEq(t, originalBody, string(resp.Error.Data))
})
}

// TestPaginateAll tests the paginateAll generic helper.
func TestPaginateAll(t *testing.T) {
t.Run("single page with no cursor returns all items", func(t *testing.T) {
items, err := paginateAll("server1", "tools", func(cursor string) (paginatedPage[string], error) {
return paginatedPage[string]{Items: []string{"a", "b", "c"}, NextCursor: ""}, nil
})
require.NoError(t, err)
assert.Equal(t, []string{"a", "b", "c"}, items)
})

t.Run("multiple pages are collected", func(t *testing.T) {
pages := []paginatedPage[string]{
{Items: []string{"a"}, NextCursor: "page2"},
{Items: []string{"b"}, NextCursor: "page3"},
{Items: []string{"c"}, NextCursor: ""},
}
call := 0
items, err := paginateAll("server1", "tools", func(cursor string) (paginatedPage[string], error) {
page := pages[call]
call++
return page, nil
})
require.NoError(t, err)
assert.Equal(t, []string{"a", "b", "c"}, items)
})

t.Run("exceeding max pages returns error", func(t *testing.T) {
// Each call returns a cursor so the loop never ends naturally.
callCount := 0
_, err := paginateAll("server1", "tools", func(cursor string) (paginatedPage[string], error) {
callCount++
return paginatedPage[string]{Items: []string{"x"}, NextCursor: "next"}, nil
})
require.Error(t, err)
assert.Contains(t, err.Error(), "more than")
assert.Contains(t, err.Error(), "pages")
// Must stop at the page limit, not run forever.
assert.Equal(t, paginateAllMaxPages, callCount)
})
}
21 changes: 19 additions & 2 deletions internal/server/tool_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,25 @@ type launchResult struct {
// InputSchema from backends that use different JSON Schema versions (e.g., draft-07) without
// validation errors, which is critical for clients to understand tool parameters.
//
// The handler's third parameter (pre-validated input) is passed as nil since argument
// unmarshaling is handled inside the handler itself.
// # Three-argument handler convention
//
// Throughout this package, tool handlers use a three-argument form:
//
// func(ctx context.Context, req *sdk.CallToolRequest, state interface{}) (*sdk.CallToolResult, interface{}, error)
//
// This differs from the SDK's native two-argument form. The extra parameters serve
// two internal purposes:
// - state interface{}: reserved for the jq middleware pipeline (currently always nil at
// the call site; middleware may propagate state between pre- and post-processing steps).
// - second return value interface{}: carries intermediate data for the DIFC write-sink
// logger so it can record the raw backend result alongside the final tool result.
//
// The wrapper in this function adapts the three-argument form back to the SDK's two-argument
// form when registering with the SDK server.
//
// NOTE: The Server.AddTool method (used here) skips JSON Schema validation whereas the
// sdk.AddTool function validates the schema. This distinction relies on internal SDK
// behaviour and must be re-verified on every SDK upgrade.
func registerToolWithoutValidation(server *sdk.Server, tool *sdk.Tool, handler func(context.Context, *sdk.CallToolRequest, interface{}) (*sdk.CallToolResult, interface{}, error)) {
server.AddTool(tool, func(ctx context.Context, req *sdk.CallToolRequest) (*sdk.CallToolResult, error) {
result, _, err := handler(ctx, req, nil)
Expand Down
10 changes: 9 additions & 1 deletion internal/testutil/mcptest/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package mcptest

import sdk "github.com/modelcontextprotocol/go-sdk/mcp"
import (
"log/slog"

sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)

// ServerConfig defines the configuration for a test MCP server
type ServerConfig struct {
Expand All @@ -12,6 +16,10 @@ type ServerConfig struct {
Tools []ToolConfig
// Resources is the list of resources to expose
Resources []ResourceConfig
// Logger is an optional slog.Logger for SDK diagnostics.
// When nil, defaults to the project debug logger (namespace "testutil:mcptest"),
// which surfaces SDK protocol errors when DEBUG=testutil:* is set.
Logger *slog.Logger
}

// ToolConfig defines a tool for the test server
Expand Down
10 changes: 9 additions & 1 deletion internal/testutil/mcptest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log"

"github.com/github/gh-aw-mcpg/internal/logger"
"github.com/github/gh-aw-mcpg/internal/mcp"
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)
Expand Down Expand Up @@ -36,7 +37,14 @@ func (s *Server) Start() error {
Version: s.config.Version,
}

s.server = sdk.NewServer(impl, &sdk.ServerOptions{})
sdkLogger := s.config.Logger
if sdkLogger == nil {
sdkLogger = logger.NewSlogLoggerWithHandler(logger.New("testutil:mcptest"))
}

s.server = sdk.NewServer(impl, &sdk.ServerOptions{
Logger: sdkLogger,
})
Comment on lines +45 to +47
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logger.Discard() returns a slog.Logger backed by io.Discard, so SDK logs are still suppressed. If the intent here is to make protocol/SDK errors visible during failing tests, use a logger that writes to os.Stderr (or reuse logger.NewSlogLoggerWithHandler(logger.New("testutil:mcptest"))), or make the logger configurable so tests can opt into output when needed.

Copilot uses AI. Check for mistakes.

// Register tools
for i, toolCfg := range s.config.Tools {
Expand Down
Loading