Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 26 additions & 45 deletions internal/guard/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,26 @@ func parseLabelAgentResponse(resultJSON []byte) (*LabelAgentResult, error) {
return &result, nil
}

// callWasmGuardFunction serialises WASM access, sets the backend reference, marshals
// inputData, logs the input, calls the named WASM export, and returns the raw result.
// All three public dispatch methods (LabelAgent, LabelResource, LabelResponse) share
// this preamble; keeping it in one place ensures locking and backend-update logic
// cannot drift between them.
Comment on lines +603 to +607
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This refactor moves some work in LabelAgent/LabelResource/LabelResponse outside the WASM mutex, but the surrounding documentation elsewhere in the file still states that public methods hold g.mu for their entire duration. Please update the related comments to reflect the new locking boundary so future changes don’t rely on outdated assumptions.

Suggested change
// callWasmGuardFunction serialises WASM access, sets the backend reference, marshals
// inputData, logs the input, calls the named WASM export, and returns the raw result.
// All three public dispatch methods (LabelAgent, LabelResource, LabelResponse) share
// this preamble; keeping it in one place ensures locking and backend-update logic
// cannot drift between them.
// callWasmGuardFunction acquires g.mu to serialise access to the WASM instance and
// to set the backend reference, then marshals inputData, logs the input, calls the
// named WASM export, and returns the raw result. The mutex is held only for the
// duration of this helper; callers may perform additional work before or after
// without holding g.mu. All three public dispatch methods (LabelAgent, LabelResource,
// LabelResponse) use this helper so that their locking and backend-update behaviour
// stays consistent.

Copilot uses AI. Check for mistakes.
func (g *WasmGuard) callWasmGuardFunction(ctx context.Context, funcName string, backend BackendCaller, inputData map[string]interface{}) ([]byte, error) {
g.mu.Lock()
defer g.mu.Unlock()

g.backend = backend

inputJSON, err := json.Marshal(inputData)
if err != nil {
return nil, fmt.Errorf("failed to marshal %s input: %w", funcName, err)
}
logWasm.Printf("%s input JSON (%d bytes): %s", funcName, len(inputJSON), string(inputJSON))
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

callWasmGuardFunction now logs the full marshaled input for every guard export, which means LabelResponse will start logging tool results (potentially large and sensitive) even though it previously didn’t. Consider making logging conditional per funcName (or logging only sizes / redacted summaries) to avoid leaking response data and inflating logs.

Suggested change
logWasm.Printf("%s input JSON (%d bytes): %s", funcName, len(inputJSON), string(inputJSON))
// Log only the function name and payload size to avoid leaking potentially large or sensitive data.
logWasm.Printf("%s input JSON (%d bytes)", funcName, len(inputJSON))

Copilot uses AI. Check for mistakes.

return g.callWasmFunction(ctx, funcName, inputJSON)
}

// LabelAgent calls the WASM module's label_agent function.
func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend BackendCaller, caps *difc.Capabilities) (*LabelAgentResult, error) {
logWasm.Printf("LabelAgent called: guard=%s", g.name)
Expand All @@ -608,13 +628,9 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend
return nil, fmt.Errorf("WASM guard does not export label_agent")
}

// Serialize access to the WASM module
g.mu.Lock()
defer g.mu.Unlock()

// Update backend caller for this request
g.backend = backend

// Normalisation and payload-build operate only on the caller-supplied `policy`
// argument and do not access any g.* fields, so they are safe to run outside
// the lock that callWasmGuardFunction acquires.
normalizedPolicy, err := normalizePolicyPayload(policy)
if err != nil {
logWasm.Printf("LabelAgent normalizePolicyPayload failed: guard=%s, error=%v", g.name, err)
Expand All @@ -634,14 +650,7 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend
return nil, err
}

inputJSON, err := json.Marshal(input)
if err != nil {
return nil, fmt.Errorf("failed to marshal label_agent input: %w", err)
}

logWasm.Printf("LabelAgent input JSON (%d bytes): %s", len(inputJSON), string(inputJSON))

resultJSON, err := g.callWasmFunction(ctx, "label_agent", inputJSON)
resultJSON, err := g.callWasmGuardFunction(ctx, "label_agent", backend, input)
if err != nil {
logWasm.Printf("LabelAgent callWasmFunction failed: guard=%s, error=%v", g.name, err)
return nil, err
Expand Down Expand Up @@ -673,13 +682,6 @@ func (g *WasmGuard) LabelAgent(ctx context.Context, policy interface{}, backend
func (g *WasmGuard) LabelResource(ctx context.Context, toolName string, args interface{}, backend BackendCaller, caps *difc.Capabilities) (*difc.LabeledResource, difc.OperationType, error) {
logWasm.Printf("LabelResource called: toolName=%s, args=%+v", toolName, args)

// Serialize access to the WASM module
g.mu.Lock()
defer g.mu.Unlock()

// Update backend caller for this request
g.backend = backend

// Prepare input
input := map[string]interface{}{
"tool_name": toolName,
Expand All @@ -689,15 +691,7 @@ func (g *WasmGuard) LabelResource(ctx context.Context, toolName string, args int
input["capabilities"] = caps
}

inputJSON, err := json.Marshal(input)
if err != nil {
return nil, difc.OperationWrite, fmt.Errorf("failed to marshal input: %w", err)
}

logWasm.Printf("LabelResource input JSON (%d bytes): %s", len(inputJSON), string(inputJSON))

// Call WASM function
resultJSON, err := g.callWasmFunction(ctx, "label_resource", inputJSON)
resultJSON, err := g.callWasmGuardFunction(ctx, "label_resource", backend, input)
if err != nil {
return nil, difc.OperationWrite, err
}
Expand All @@ -715,13 +709,6 @@ func (g *WasmGuard) LabelResource(ctx context.Context, toolName string, args int
func (g *WasmGuard) LabelResponse(ctx context.Context, toolName string, result interface{}, backend BackendCaller, caps *difc.Capabilities) (difc.LabeledData, error) {
logWasm.Printf("LabelResponse called: toolName=%s", toolName)

// Serialize access to the WASM module
g.mu.Lock()
defer g.mu.Unlock()

// Update backend caller for this request
g.backend = backend

// Prepare input
input := map[string]interface{}{
"tool_name": toolName,
Expand All @@ -738,13 +725,7 @@ func (g *WasmGuard) LabelResponse(ctx context.Context, toolName string, result i
input["capabilities"] = caps
}

inputJSON, err := json.Marshal(input)
if err != nil {
return nil, fmt.Errorf("failed to marshal input: %w", err)
}

// Call WASM function
resultJSON, err := g.callWasmFunction(ctx, "label_response", inputJSON)
resultJSON, err := g.callWasmGuardFunction(ctx, "label_response", backend, input)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc {
// Spec 7.1: Missing token returns 401
logger.LogErrorMd("auth", "Authentication failed: missing Authorization header, remote=%s, path=%s", r.RemoteAddr, r.URL.Path)
logRuntimeError("authentication_failed", "missing_auth_header", r, nil)
http.Error(w, "Unauthorized: missing Authorization header", http.StatusUnauthorized)
writeErrorResponse(w, http.StatusUnauthorized, "unauthorized", "missing Authorization header")
return
}

// Spec 7.1: Authorization header must contain API key directly (not Bearer scheme)
if authHeader != apiKey {
logger.LogErrorMd("auth", "Authentication failed: invalid API key, remote=%s, path=%s", r.RemoteAddr, r.URL.Path)
logRuntimeError("authentication_failed", "invalid_api_key", r, nil)
http.Error(w, "Unauthorized: invalid API key", http.StatusUnauthorized)
writeErrorResponse(w, http.StatusUnauthorized, "unauthorized", "invalid API key")
return
}

Expand Down
14 changes: 7 additions & 7 deletions internal/server/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,55 +33,55 @@ func TestAuthMiddleware(t *testing.T) {
authHeader: "",
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized: missing Authorization header",
expectErrorMessage: "missing Authorization header",
},
{
name: "InvalidAPIKey",
configuredAPIKey: "correct-key",
authHeader: "wrong-key",
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized: invalid API key",
expectErrorMessage: "invalid API key",
},
{
name: "EmptyAPIKeyWithEmptyHeader",
configuredAPIKey: "",
authHeader: "",
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized: missing Authorization header",
expectErrorMessage: "missing Authorization header",
},
{
name: "EmptyConfiguredKeyWithValidHeader",
configuredAPIKey: "",
authHeader: "some-key",
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized: invalid API key",
expectErrorMessage: "invalid API key",
},
{
name: "CaseSensitiveKey",
configuredAPIKey: "MyAPIKey",
authHeader: "myapikey",
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized: invalid API key",
expectErrorMessage: "invalid API key",
},
{
name: "WhitespaceNotTrimmed",
configuredAPIKey: "test-key",
authHeader: " test-key ",
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized: invalid API key",
expectErrorMessage: "invalid API key",
},
{
name: "BearerSchemeNotSupported",
configuredAPIKey: "test-key",
authHeader: "Bearer test-key",
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized: invalid API key",
expectErrorMessage: "invalid API key",
},
{
name: "LongAPIKey",
Expand Down
2 changes: 1 addition & 1 deletion internal/server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func handleClose(unifiedServer *UnifiedServer) http.Handler {
// Only accept POST requests
if r.Method != http.MethodPost {
logHandlers.Printf("Close request rejected: invalid method=%s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
writeErrorResponse(w, http.StatusMethodNotAllowed, "method_not_allowed", "method not allowed")
return
Comment on lines 41 to 45
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

handleClose now uses writeErrorResponse for the invalid-method branch, but the "already closed" branch still returns a different JSON shape (only an "error" field via writeJSONResponse). If the intent is a consistent {error, message} format for HTTP errors, switch that 410 response to writeErrorResponse (or align writeErrorResponse’s comment/expectations).

Copilot uses AI. Check for mistakes.
}

Expand Down
8 changes: 4 additions & 4 deletions internal/server/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,25 +86,25 @@ func TestHandleClose_MethodValidation(t *testing.T) {
name: "GET request returns 405",
method: http.MethodGet,
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "Method not allowed",
expectedBody: "method not allowed",
},
{
name: "PUT request returns 405",
method: http.MethodPut,
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "Method not allowed",
expectedBody: "method not allowed",
},
{
name: "DELETE request returns 405",
method: http.MethodDelete,
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "Method not allowed",
expectedBody: "method not allowed",
},
{
name: "PATCH request returns 405",
method: http.MethodPatch,
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "Method not allowed",
expectedBody: "method not allowed",
},
}

Expand Down
10 changes: 10 additions & 0 deletions internal/server/http_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ func writeJSONResponse(w http.ResponseWriter, statusCode int, body interface{})
httputil.WriteJSONResponse(w, statusCode, body)
}

// writeErrorResponse writes a JSON error response with a consistent shape.
// All HTTP error paths in the server package should use this helper to ensure
// clients always receive application/json rather than text/plain.
func writeErrorResponse(w http.ResponseWriter, statusCode int, code, message string) {
writeJSONResponse(w, statusCode, map[string]string{
"error": code,
"message": message,
})
}
Comment on lines +26 to +34
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new writeErrorResponse comment says all HTTP error paths in the server package should use this helper, but there are still error responses written via other mechanisms (e.g., JSON without a "message" field, or preformatted JSON constants). Either update remaining call sites to use writeErrorResponse, or soften the comment so it doesn’t over-promise a guarantee the package doesn’t yet meet.

Copilot uses AI. Check for mistakes.

// withResponseLogging wraps an http.Handler to log response bodies
func withResponseLogging(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
4 changes: 2 additions & 2 deletions internal/server/http_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ func TestWrapWithMiddleware(t *testing.T) {
shutdown: false,
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized",
expectErrorMessage: "unauthorized",
},
{
name: "WithAuth_MissingKey_Unauthorized",
Expand All @@ -511,7 +511,7 @@ func TestWrapWithMiddleware(t *testing.T) {
shutdown: false,
expectStatusCode: http.StatusUnauthorized,
expectNextCalled: false,
expectErrorMessage: "Unauthorized",
expectErrorMessage: "unauthorized",
},
{
name: "Shutdown_RejectsRequest",
Expand Down
2 changes: 1 addition & 1 deletion internal/server/routed.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap

// Return a cached filtered proxy server for this backend and session
// This ensures the same server instance is reused for all requests in a session
sessionID := r.Context().Value(SessionIDContextKey).(string)
sessionID := SessionIDFromContext(r.Context())
return serverCache.getOrCreate(backendID, sessionID, func() *sdk.Server {
return createFilteredServer(unifiedServer, backendID)
})
Expand Down
23 changes: 13 additions & 10 deletions internal/server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package server
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"time"
Expand All @@ -26,19 +25,23 @@ func NewSession(sessionID, token string) *Session {
}
}

// getSessionID extracts the MCP session ID from the context
func (us *UnifiedServer) getSessionID(ctx context.Context) string {
if sessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && sessionID != "" {
logSession.Printf("Extracted session ID from context: %s", auth.TruncateSessionID(sessionID))
return sessionID
// SessionIDFromContext returns the MCP session ID stored in ctx, or "default" if the
// context contains no session ID (or one of the wrong type). This is the canonical
// place in the server package that reads SessionIDContextKey directly.
func SessionIDFromContext(ctx context.Context) string {
if id, ok := ctx.Value(SessionIDContextKey).(string); ok && id != "" {
return id
}
// No session ID in context - this happens before the SDK assigns one
// For now, use "default" as a placeholder for single-client scenarios
// In production multi-agent scenarios, the SDK will provide session IDs after initialize
log.Printf("No session ID in context, using 'default' (this is normal before SDK session is established)")
return "default"
}

// getSessionID extracts the MCP session ID from the context
func (us *UnifiedServer) getSessionID(ctx context.Context) string {
sessionID := SessionIDFromContext(ctx)
logSession.Printf("Extracted session ID from context: %s", auth.TruncateSessionID(sessionID))
return sessionID
}

// ensureSessionDirectory creates the session subdirectory in the payload directory if it doesn't exist
func (us *UnifiedServer) ensureSessionDirectory(sessionID string) error {
sessionDir := filepath.Join(us.payloadDir, sessionID)
Expand Down
5 changes: 2 additions & 3 deletions internal/server/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,9 @@ func TestCreateHTTPServerForMCP_Close(t *testing.T) {
if tt.wantError {
switch tt.wantStatusCode {
case http.StatusMethodNotAllowed:
// http.Error writes plain text for 405
assert.Contains(t, w.Body.String(), "Method not allowed")
assert.Contains(t, w.Body.String(), "method not allowed")
case http.StatusUnauthorized:
assert.Contains(t, w.Body.String(), "Unauthorized")
assert.Contains(t, w.Body.String(), "unauthorized")
}
} else {
// Success response should be JSON
Expand Down
8 changes: 2 additions & 6 deletions internal/server/unified.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,9 @@ func (g *guardBackendCaller) CallTool(ctx context.Context, toolName string, args
// This bypasses DIFC checks since it's internal to the guard
log.Printf("[DIFC] Guard calling backend %s tool %s for metadata", g.serverID, toolName)

// Get or launch backend connection (use session-aware connection for stateful backends)
sessionID := g.ctx.Value(SessionIDContextKey)
if sessionID == nil {
sessionID = "default"
}
sessionID := SessionIDFromContext(g.ctx)

return executeBackendToolCall(g.ctx, g.server.launcher, g.serverID, sessionID.(string), toolName, args)
return executeBackendToolCall(g.ctx, g.server.launcher, g.serverID, sessionID, toolName, args)
Comment on lines 242 to +246
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

guardBackendCaller.CallTool takes a ctx parameter but ignores it and instead uses g.ctx for both session ID extraction and the backend call. This prevents guards from applying per-call cancellation/deadlines and is inconsistent with the BackendCaller interface. Prefer using the passed-in ctx for executeBackendToolCall (and derive sessionID from that ctx, or store the sessionID on the struct at construction).

See below for a potential fix:

	sessionID := SessionIDFromContext(ctx)

	return executeBackendToolCall(ctx, g.server.launcher, g.serverID, sessionID, toolName, args)

Copilot uses AI. Check for mistakes.
}

// newErrorCallToolResult creates a standard error CallToolResult with the error message
Expand Down
Loading