From 96c2f7b746a6aea7ddfabc439c013443a2286c96 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:32:21 +0000 Subject: [PATCH 1/3] Initial plan From afd4a8451007e1915bf15eb99b86c299bd53c9b6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:44:07 +0000 Subject: [PATCH 2/3] fix: eliminate duplicate-code patterns across server and guard packages - Add SessionIDFromContext helper to session.go (canonical context extraction) - Fix latent panic in routed.go:91 (unsafe bare type assertion) - Fix partially-safe nil-then-cast in unified.go:245-250 - Simplify UnifiedServer.getSessionID to delegate to SessionIDFromContext - Add writeErrorResponse helper to http_helpers.go for consistent JSON errors - Replace http.Error calls in auth.go and handlers.go with writeErrorResponse - Update auth_test.go, handlers_test.go, http_helpers_test.go, transport_test.go to expect JSON error bodies instead of plain text - Extract callWasmGuardFunction helper in wasm.go (shared lock/backend/marshal/call) - Refactor LabelAgent, LabelResource, LabelResponse to use the new helper Closes #2826, #2827, #2828 (sub-issues of #2825) Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/25cbb04b-116d-40aa-aea0-f59a319da6d5 Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com> --- internal/guard/wasm.go | 68 ++++++++++------------------ internal/server/auth.go | 4 +- internal/server/auth_test.go | 14 +++--- internal/server/handlers.go | 2 +- internal/server/handlers_test.go | 8 ++-- internal/server/http_helpers.go | 10 ++++ internal/server/http_helpers_test.go | 4 +- internal/server/routed.go | 2 +- internal/server/session.go | 25 +++++++--- internal/server/transport_test.go | 5 +- internal/server/unified.go | 8 +--- 11 files changed, 72 insertions(+), 78 deletions(-) diff --git a/internal/guard/wasm.go b/internal/guard/wasm.go index b91fe4c1..1f6089eb 100644 --- a/internal/guard/wasm.go +++ b/internal/guard/wasm.go @@ -600,6 +600,26 @@ func parseLabelAgentResponse(resultJSON []byte) (*LabelAgentResult, error) { return &result, nil } +// callWasmGuardFunction serialises WASM access, sets the backend reference, marshals +// inputData, logs the input, calls the named WASM export, and returns the raw result. +// All three public dispatch methods (LabelAgent, LabelResource, LabelResponse) share +// this preamble; keeping it in one place ensures locking and backend-update logic +// cannot drift between them. +func (g *WasmGuard) callWasmGuardFunction(ctx context.Context, funcName string, backend BackendCaller, inputData map[string]interface{}) ([]byte, error) { + g.mu.Lock() + defer g.mu.Unlock() + + g.backend = backend + + inputJSON, err := json.Marshal(inputData) + if err != nil { + return nil, fmt.Errorf("failed to marshal %s input: %w", funcName, err) + } + logWasm.Printf("%s input JSON (%d bytes): %s", funcName, len(inputJSON), string(inputJSON)) + + return g.callWasmFunction(ctx, funcName, inputJSON) +} + // LabelAgent calls the WASM module's label_agent function. func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend BackendCaller, caps *difc.Capabilities) (*LabelAgentResult, error) { logWasm.Printf("LabelAgent called: guard=%s", g.name) @@ -608,13 +628,6 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend return nil, fmt.Errorf("WASM guard does not export label_agent") } - // Serialize access to the WASM module - g.mu.Lock() - defer g.mu.Unlock() - - // Update backend caller for this request - g.backend = backend - normalizedPolicy, err := normalizePolicyPayload(policy) if err != nil { logWasm.Printf("LabelAgent normalizePolicyPayload failed: guard=%s, error=%v", g.name, err) @@ -634,14 +647,7 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend return nil, err } - inputJSON, err := json.Marshal(input) - if err != nil { - return nil, fmt.Errorf("failed to marshal label_agent input: %w", err) - } - - logWasm.Printf("LabelAgent input JSON (%d bytes): %s", len(inputJSON), string(inputJSON)) - - resultJSON, err := g.callWasmFunction(ctx, "label_agent", inputJSON) + resultJSON, err := g.callWasmGuardFunction(ctx, "label_agent", backend, input) if err != nil { logWasm.Printf("LabelAgent callWasmFunction failed: guard=%s, error=%v", g.name, err) return nil, err @@ -673,13 +679,6 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend func (g *WasmGuard) LabelResource(ctx context.Context, toolName string, args interface{}, backend BackendCaller, caps *difc.Capabilities) (*difc.LabeledResource, difc.OperationType, error) { logWasm.Printf("LabelResource called: toolName=%s, args=%+v", toolName, args) - // Serialize access to the WASM module - g.mu.Lock() - defer g.mu.Unlock() - - // Update backend caller for this request - g.backend = backend - // Prepare input input := map[string]interface{}{ "tool_name": toolName, @@ -689,15 +688,7 @@ func (g *WasmGuard) LabelResource(ctx context.Context, toolName string, args int input["capabilities"] = caps } - inputJSON, err := json.Marshal(input) - if err != nil { - return nil, difc.OperationWrite, fmt.Errorf("failed to marshal input: %w", err) - } - - logWasm.Printf("LabelResource input JSON (%d bytes): %s", len(inputJSON), string(inputJSON)) - - // Call WASM function - resultJSON, err := g.callWasmFunction(ctx, "label_resource", inputJSON) + resultJSON, err := g.callWasmGuardFunction(ctx, "label_resource", backend, input) if err != nil { return nil, difc.OperationWrite, err } @@ -715,13 +706,6 @@ func (g *WasmGuard) LabelResource(ctx context.Context, toolName string, args int func (g *WasmGuard) LabelResponse(ctx context.Context, toolName string, result interface{}, backend BackendCaller, caps *difc.Capabilities) (difc.LabeledData, error) { logWasm.Printf("LabelResponse called: toolName=%s", toolName) - // Serialize access to the WASM module - g.mu.Lock() - defer g.mu.Unlock() - - // Update backend caller for this request - g.backend = backend - // Prepare input input := map[string]interface{}{ "tool_name": toolName, @@ -738,13 +722,7 @@ func (g *WasmGuard) LabelResponse(ctx context.Context, toolName string, result i input["capabilities"] = caps } - inputJSON, err := json.Marshal(input) - if err != nil { - return nil, fmt.Errorf("failed to marshal input: %w", err) - } - - // Call WASM function - resultJSON, err := g.callWasmFunction(ctx, "label_response", inputJSON) + resultJSON, err := g.callWasmGuardFunction(ctx, "label_response", backend, input) if err != nil { return nil, err } diff --git a/internal/server/auth.go b/internal/server/auth.go index fd98c72c..49a723df 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -48,7 +48,7 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc { // Spec 7.1: Missing token returns 401 logger.LogErrorMd("auth", "Authentication failed: missing Authorization header, remote=%s, path=%s", r.RemoteAddr, r.URL.Path) logRuntimeError("authentication_failed", "missing_auth_header", r, nil) - http.Error(w, "Unauthorized: missing Authorization header", http.StatusUnauthorized) + writeErrorResponse(w, http.StatusUnauthorized, "unauthorized", "missing Authorization header") return } @@ -56,7 +56,7 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc { if authHeader != apiKey { logger.LogErrorMd("auth", "Authentication failed: invalid API key, remote=%s, path=%s", r.RemoteAddr, r.URL.Path) logRuntimeError("authentication_failed", "invalid_api_key", r, nil) - http.Error(w, "Unauthorized: invalid API key", http.StatusUnauthorized) + writeErrorResponse(w, http.StatusUnauthorized, "unauthorized", "invalid API key") return } diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go index 7480698c..47a62b42 100644 --- a/internal/server/auth_test.go +++ b/internal/server/auth_test.go @@ -33,7 +33,7 @@ func TestAuthMiddleware(t *testing.T) { authHeader: "", expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized: missing Authorization header", + expectErrorMessage: "missing Authorization header", }, { name: "InvalidAPIKey", @@ -41,7 +41,7 @@ func TestAuthMiddleware(t *testing.T) { authHeader: "wrong-key", expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized: invalid API key", + expectErrorMessage: "invalid API key", }, { name: "EmptyAPIKeyWithEmptyHeader", @@ -49,7 +49,7 @@ func TestAuthMiddleware(t *testing.T) { authHeader: "", expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized: missing Authorization header", + expectErrorMessage: "missing Authorization header", }, { name: "EmptyConfiguredKeyWithValidHeader", @@ -57,7 +57,7 @@ func TestAuthMiddleware(t *testing.T) { authHeader: "some-key", expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized: invalid API key", + expectErrorMessage: "invalid API key", }, { name: "CaseSensitiveKey", @@ -65,7 +65,7 @@ func TestAuthMiddleware(t *testing.T) { authHeader: "myapikey", expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized: invalid API key", + expectErrorMessage: "invalid API key", }, { name: "WhitespaceNotTrimmed", @@ -73,7 +73,7 @@ func TestAuthMiddleware(t *testing.T) { authHeader: " test-key ", expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized: invalid API key", + expectErrorMessage: "invalid API key", }, { name: "BearerSchemeNotSupported", @@ -81,7 +81,7 @@ func TestAuthMiddleware(t *testing.T) { authHeader: "Bearer test-key", expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized: invalid API key", + expectErrorMessage: "invalid API key", }, { name: "LongAPIKey", diff --git a/internal/server/handlers.go b/internal/server/handlers.go index 77dd91d2..766ca0cf 100644 --- a/internal/server/handlers.go +++ b/internal/server/handlers.go @@ -41,7 +41,7 @@ func handleClose(unifiedServer *UnifiedServer) http.Handler { // Only accept POST requests if r.Method != http.MethodPost { logHandlers.Printf("Close request rejected: invalid method=%s", r.Method) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "method not allowed") return } diff --git a/internal/server/handlers_test.go b/internal/server/handlers_test.go index 95886541..15423366 100644 --- a/internal/server/handlers_test.go +++ b/internal/server/handlers_test.go @@ -86,25 +86,25 @@ func TestHandleClose_MethodValidation(t *testing.T) { name: "GET request returns 405", method: http.MethodGet, expectedStatus: http.StatusMethodNotAllowed, - expectedBody: "Method not allowed", + expectedBody: "method not allowed", }, { name: "PUT request returns 405", method: http.MethodPut, expectedStatus: http.StatusMethodNotAllowed, - expectedBody: "Method not allowed", + expectedBody: "method not allowed", }, { name: "DELETE request returns 405", method: http.MethodDelete, expectedStatus: http.StatusMethodNotAllowed, - expectedBody: "Method not allowed", + expectedBody: "method not allowed", }, { name: "PATCH request returns 405", method: http.MethodPatch, expectedStatus: http.StatusMethodNotAllowed, - expectedBody: "Method not allowed", + expectedBody: "method not allowed", }, } diff --git a/internal/server/http_helpers.go b/internal/server/http_helpers.go index 2b8d2945..2ff2d637 100644 --- a/internal/server/http_helpers.go +++ b/internal/server/http_helpers.go @@ -23,6 +23,16 @@ func writeJSONResponse(w http.ResponseWriter, statusCode int, body interface{}) httputil.WriteJSONResponse(w, statusCode, body) } +// writeErrorResponse writes a JSON error response with a consistent shape. +// All HTTP error paths in the server package should use this helper to ensure +// clients always receive application/json rather than text/plain. +func writeErrorResponse(w http.ResponseWriter, statusCode int, code, message string) { + writeJSONResponse(w, statusCode, map[string]string{ + "error": code, + "message": message, + }) +} + // withResponseLogging wraps an http.Handler to log response bodies func withResponseLogging(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/server/http_helpers_test.go b/internal/server/http_helpers_test.go index 5100d1cb..03711155 100644 --- a/internal/server/http_helpers_test.go +++ b/internal/server/http_helpers_test.go @@ -502,7 +502,7 @@ func TestWrapWithMiddleware(t *testing.T) { shutdown: false, expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized", + expectErrorMessage: "unauthorized", }, { name: "WithAuth_MissingKey_Unauthorized", @@ -511,7 +511,7 @@ func TestWrapWithMiddleware(t *testing.T) { shutdown: false, expectStatusCode: http.StatusUnauthorized, expectNextCalled: false, - expectErrorMessage: "Unauthorized", + expectErrorMessage: "unauthorized", }, { name: "Shutdown_RejectsRequest", diff --git a/internal/server/routed.go b/internal/server/routed.go index 6a30388c..86a4098c 100644 --- a/internal/server/routed.go +++ b/internal/server/routed.go @@ -88,7 +88,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap // Return a cached filtered proxy server for this backend and session // This ensures the same server instance is reused for all requests in a session - sessionID := r.Context().Value(SessionIDContextKey).(string) + sessionID := SessionIDFromContext(r.Context()) return serverCache.getOrCreate(backendID, sessionID, func() *sdk.Server { return createFilteredServer(unifiedServer, backendID) }) diff --git a/internal/server/session.go b/internal/server/session.go index 51bcb113..1e22f569 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -26,17 +26,28 @@ func NewSession(sessionID, token string) *Session { } } +// SessionIDFromContext returns the MCP session ID stored in ctx, or "default" if the +// context contains no session ID (or one of the wrong type). This is the canonical +// place in the server package that reads SessionIDContextKey directly. +func SessionIDFromContext(ctx context.Context) string { + if id, ok := ctx.Value(SessionIDContextKey).(string); ok && id != "" { + return id + } + return "default" +} + // getSessionID extracts the MCP session ID from the context func (us *UnifiedServer) getSessionID(ctx context.Context) string { - if sessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && sessionID != "" { + sessionID := SessionIDFromContext(ctx) + if sessionID != "default" { logSession.Printf("Extracted session ID from context: %s", auth.TruncateSessionID(sessionID)) - return sessionID + } else { + // No session ID in context - this happens before the SDK assigns one + // For now, use "default" as a placeholder for single-client scenarios + // In production multi-agent scenarios, the SDK will provide session IDs after initialize + log.Printf("No session ID in context, using 'default' (this is normal before SDK session is established)") } - // No session ID in context - this happens before the SDK assigns one - // For now, use "default" as a placeholder for single-client scenarios - // In production multi-agent scenarios, the SDK will provide session IDs after initialize - log.Printf("No session ID in context, using 'default' (this is normal before SDK session is established)") - return "default" + return sessionID } // ensureSessionDirectory creates the session subdirectory in the payload directory if it doesn't exist diff --git a/internal/server/transport_test.go b/internal/server/transport_test.go index 2c4b29e6..0ca0b133 100644 --- a/internal/server/transport_test.go +++ b/internal/server/transport_test.go @@ -343,10 +343,9 @@ func TestCreateHTTPServerForMCP_Close(t *testing.T) { if tt.wantError { switch tt.wantStatusCode { case http.StatusMethodNotAllowed: - // http.Error writes plain text for 405 - assert.Contains(t, w.Body.String(), "Method not allowed") + assert.Contains(t, w.Body.String(), "method not allowed") case http.StatusUnauthorized: - assert.Contains(t, w.Body.String(), "Unauthorized") + assert.Contains(t, w.Body.String(), "unauthorized") } } else { // Success response should be JSON diff --git a/internal/server/unified.go b/internal/server/unified.go index 380f5b78..21c55ead 100644 --- a/internal/server/unified.go +++ b/internal/server/unified.go @@ -241,13 +241,9 @@ func (g *guardBackendCaller) CallTool(ctx context.Context, toolName string, args // This bypasses DIFC checks since it's internal to the guard log.Printf("[DIFC] Guard calling backend %s tool %s for metadata", g.serverID, toolName) - // Get or launch backend connection (use session-aware connection for stateful backends) - sessionID := g.ctx.Value(SessionIDContextKey) - if sessionID == nil { - sessionID = "default" - } + sessionID := SessionIDFromContext(g.ctx) - return executeBackendToolCall(g.ctx, g.server.launcher, g.serverID, sessionID.(string), toolName, args) + return executeBackendToolCall(g.ctx, g.server.launcher, g.serverID, sessionID, toolName, args) } // newErrorCallToolResult creates a standard error CallToolResult with the error message From 3a055e0c8cb84b77eb314cdbfee32952285e2f99 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:46:53 +0000 Subject: [PATCH 3/3] refactor: address code review feedback - Remove SessionIDFromContext != default coupling in getSessionID - Drop now-unused log import from session.go - Add comment explaining normalization-before-lock safety in LabelAgent Agent-Logs-Url: https://github.com/github/gh-aw-mcpg/sessions/25cbb04b-116d-40aa-aea0-f59a319da6d5 Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com> --- internal/guard/wasm.go | 3 +++ internal/server/session.go | 10 +--------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/internal/guard/wasm.go b/internal/guard/wasm.go index 1f6089eb..d36795ae 100644 --- a/internal/guard/wasm.go +++ b/internal/guard/wasm.go @@ -628,6 +628,9 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend return nil, fmt.Errorf("WASM guard does not export label_agent") } + // Normalisation and payload-build operate only on the caller-supplied `policy` + // argument and do not access any g.* fields, so they are safe to run outside + // the lock that callWasmGuardFunction acquires. normalizedPolicy, err := normalizePolicyPayload(policy) if err != nil { logWasm.Printf("LabelAgent normalizePolicyPayload failed: guard=%s, error=%v", g.name, err) diff --git a/internal/server/session.go b/internal/server/session.go index 1e22f569..15abf7f4 100644 --- a/internal/server/session.go +++ b/internal/server/session.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "log" "os" "path/filepath" "time" @@ -39,14 +38,7 @@ func SessionIDFromContext(ctx context.Context) string { // getSessionID extracts the MCP session ID from the context func (us *UnifiedServer) getSessionID(ctx context.Context) string { sessionID := SessionIDFromContext(ctx) - if sessionID != "default" { - logSession.Printf("Extracted session ID from context: %s", auth.TruncateSessionID(sessionID)) - } else { - // No session ID in context - this happens before the SDK assigns one - // For now, use "default" as a placeholder for single-client scenarios - // In production multi-agent scenarios, the SDK will provide session IDs after initialize - log.Printf("No session ID in context, using 'default' (this is normal before SDK session is established)") - } + logSession.Printf("Extracted session ID from context: %s", auth.TruncateSessionID(sessionID)) return sessionID }