Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 101 additions & 4 deletions internal/middleware/jqschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,41 @@ func applyJqSchema(ctx context.Context, jsonData interface{}) (string, error) {
func savePayload(baseDir, sessionID, queryID string, payload []byte) (string, error) {
// Create directory structure: {baseDir}/{sessionID}/{queryID}
dir := filepath.Join(baseDir, sessionID, queryID)

logger.LogDebug("payload", "Creating payload directory: baseDir=%s, session=%s, query=%s, fullPath=%s",
baseDir, sessionID, queryID, dir)

if err := os.MkdirAll(dir, 0700); err != nil {
logger.LogError("payload", "Failed to create payload directory: path=%s, error=%v", dir, err)
return "", fmt.Errorf("failed to create payload directory: %w", err)
}

logger.LogDebug("payload", "Successfully created payload directory: path=%s, permissions=0700", dir)

// Save payload to file with restrictive permissions (owner read/write only)
filePath := filepath.Join(dir, "payload.json")
payloadSize := len(payload)

logger.LogInfo("payload", "Writing large payload to filesystem: path=%s, size=%d bytes (%.2f KB, %.2f MB)",
filePath, payloadSize, float64(payloadSize)/1024, float64(payloadSize)/(1024*1024))

if err := os.WriteFile(filePath, payload, 0600); err != nil {
logger.LogError("payload", "Failed to write payload file: path=%s, size=%d bytes, error=%v",
Comment on lines +131 to +135
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These LogInfo messages are emitted for every wrapped tool call (the middleware saves all payloads), and the underlying file logger fsyncs on each log write. This combination can significantly increase I/O and latency in production. Consider downgrading most of these messages to debug (or sampling), and avoid wording like "large payload" unless there is an actual size threshold.

Copilot uses AI. Check for mistakes.
filePath, payloadSize, err)
return "", fmt.Errorf("failed to write payload file: %w", err)
}

logger.LogInfo("payload", "Successfully saved large payload to filesystem: path=%s, size=%d bytes, permissions=0600",
filePath, payloadSize)

// Verify file was written correctly
if stat, err := os.Stat(filePath); err != nil {
logger.LogWarn("payload", "Could not verify payload file after write: path=%s, error=%v", filePath, err)
} else {
logger.LogDebug("payload", "Payload file verified: path=%s, size=%d bytes, mode=%s",
filePath, stat.Size(), stat.Mode())
}

Comment on lines +143 to +150
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The post-write verification does an extra os.Stat() on the hot path. Combined with always-on payload writes, this adds filesystem overhead without affecting correctness. Consider guarding this behind debug logging (or removing it) to reduce unnecessary I/O.

Suggested change
// Verify file was written correctly
if stat, err := os.Stat(filePath); err != nil {
logger.LogWarn("payload", "Could not verify payload file after write: path=%s, error=%v", filePath, err)
} else {
logger.LogDebug("payload", "Payload file verified: path=%s, size=%d bytes, mode=%s",
filePath, stat.Size(), stat.Mode())
}

Copilot uses AI. Check for mistakes.
return filePath, nil
}

Expand All @@ -149,37 +174,66 @@ func WrapToolHandler(
}

logMiddleware.Printf("Processing tool call: tool=%s, queryID=%s, sessionID=%s", toolName, queryID, sessionID)
logger.LogDebug("payload", "Middleware processing tool call: tool=%s, queryID=%s, session=%s, baseDir=%s",
toolName, queryID, sessionID, baseDir)

// Call the original handler
result, data, err := handler(ctx, req, args)
if err != nil {
logMiddleware.Printf("Tool call failed: tool=%s, queryID=%s, sessionID=%s, error=%v", toolName, queryID, sessionID, err)
logger.LogDebug("payload", "Tool call failed, skipping payload storage: tool=%s, queryID=%s, error=%v",
toolName, queryID, err)
return result, data, err
}

// Only process successful results with data
if result == nil || result.IsError || data == nil {
logger.LogDebug("payload", "Skipping payload storage: tool=%s, queryID=%s, reason=%s",
toolName, queryID,
func() string {
if result == nil {
return "result is nil"
} else if result.IsError {
return "result indicates error"
} else {
return "no data returned"
}
}())
return result, data, err
}

// Marshal the response data to JSON
payloadJSON, marshalErr := json.Marshal(data)
if marshalErr != nil {
logMiddleware.Printf("Failed to marshal response: tool=%s, queryID=%s, error=%v", toolName, queryID, marshalErr)
logger.LogError("payload", "Failed to marshal response data to JSON: tool=%s, queryID=%s, error=%v",
toolName, queryID, marshalErr)
return result, data, err
}

payloadSize := len(payloadJSON)
logger.LogInfo("payload", "Response data marshaled to JSON: tool=%s, queryID=%s, size=%d bytes (%.2f KB, %.2f MB)",
toolName, queryID, payloadSize, float64(payloadSize)/1024, float64(payloadSize)/(1024*1024))

// Save the payload
logger.LogInfo("payload", "Starting payload storage to filesystem: tool=%s, queryID=%s, session=%s, baseDir=%s",
toolName, queryID, sessionID, baseDir)

filePath, saveErr := savePayload(baseDir, sessionID, queryID, payloadJSON)
if saveErr != nil {
logMiddleware.Printf("Failed to save payload: tool=%s, queryID=%s, sessionID=%s, error=%v", toolName, queryID, sessionID, saveErr)
logger.LogError("payload", "Failed to save payload to filesystem: tool=%s, queryID=%s, session=%s, error=%v",
toolName, queryID, sessionID, saveErr)
// Continue even if save fails - don't break the tool call
} else {
logMiddleware.Printf("Saved payload: tool=%s, queryID=%s, sessionID=%s, path=%s, size=%d bytes",
toolName, queryID, sessionID, filePath, len(payloadJSON))
logger.LogInfo("payload", "Payload storage completed successfully: tool=%s, queryID=%s, session=%s, path=%s, size=%d bytes",
toolName, queryID, sessionID, filePath, len(payloadJSON))
}

// Apply jq schema transformation
logger.LogDebug("payload", "Applying jq schema transformation: tool=%s, queryID=%s", toolName, queryID)
var schemaJSON string
if schemaErr := func() error {
// Unmarshal to interface{} for jq processing
Expand All @@ -196,17 +250,27 @@ func WrapToolHandler(
return nil
}(); schemaErr != nil {
logMiddleware.Printf("Failed to apply jq schema: tool=%s, queryID=%s, sessionID=%s, error=%v", toolName, queryID, sessionID, schemaErr)
logger.LogWarn("payload", "Failed to generate schema for payload: tool=%s, queryID=%s, error=%v",
toolName, queryID, schemaErr)
// Continue with original response if schema extraction fails
return result, data, err
}

logger.LogDebug("payload", "Schema transformation completed: tool=%s, queryID=%s, schemaSize=%d bytes",
toolName, queryID, len(schemaJSON))

// Build the transformed response: first 500 chars + schema
payloadStr := string(payloadJSON)
var preview string
if len(payloadStr) > 500 {
truncated := len(payloadStr) > 500
if truncated {
preview = payloadStr[:500] + "..."
logger.LogInfo("payload", "Payload truncated for preview: tool=%s, queryID=%s, originalSize=%d bytes, previewSize=500 bytes",
toolName, queryID, len(payloadStr))
} else {
preview = payloadStr
logger.LogDebug("payload", "Payload small enough for full preview: tool=%s, queryID=%s, size=%d bytes",
toolName, queryID, len(payloadStr))
}

// Create rewritten response
Expand All @@ -216,19 +280,52 @@ func WrapToolHandler(
"preview": preview,
"schema": schemaJSON,
"originalSize": len(payloadJSON),
"truncated": len(payloadStr) > 500,
"truncated": truncated,
}

logMiddleware.Printf("Rewritten response: tool=%s, queryID=%s, sessionID=%s, originalSize=%d, truncated=%v",
toolName, queryID, sessionID, len(payloadJSON), len(payloadStr) > 500)
toolName, queryID, sessionID, len(payloadJSON), truncated)
logger.LogInfo("payload", "Created metadata response for client: tool=%s, queryID=%s, session=%s, payloadPath=%s, originalSize=%d bytes, truncated=%v",
toolName, queryID, sessionID, filePath, len(payloadJSON), truncated)
Comment on lines 286 to +289
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If savePayload fails, filePath will be empty but the middleware still builds/returns client metadata and logs a payloadPath. That can mislead clients into trying to read a non-existent file. Consider only returning/advertising payloadPath when saveErr == nil (or explicitly include storage failure info in the metadata).

See below for a potential fix:

			"preview":      preview,
			"schema":       schemaJSON,
			"originalSize": len(payloadJSON),
			"truncated":    truncated,
		}

		// Only advertise payloadPath when we actually have a stored payload
		if filePath != "" {
			rewrittenResponse["payloadPath"] = filePath
		}

		logMiddleware.Printf("Rewritten response: tool=%s, queryID=%s, sessionID=%s, originalSize=%d, truncated=%v",
			toolName, queryID, sessionID, len(payloadJSON), truncated)
		if filePath != "" {
			logger.LogInfo("payload", "Created metadata response for client: tool=%s, queryID=%s, session=%s, payloadPath=%s, originalSize=%d bytes, truncated=%v",
				toolName, queryID, sessionID, filePath, len(payloadJSON), truncated)
		} else {
			logger.LogInfo("payload", "Created metadata response for client without payloadPath: tool=%s, queryID=%s, session=%s, originalSize=%d bytes, truncated=%v",
				toolName, queryID, sessionID, len(payloadJSON), truncated)
		}

Copilot uses AI. Check for mistakes.

// Parse the schema JSON string back to an object for cleaner display
var schemaObj interface{}
if err := json.Unmarshal([]byte(schemaJSON), &schemaObj); err == nil {
rewrittenResponse["schema"] = schemaObj
}

return result, rewrittenResponse, nil
// Marshal the rewritten response to JSON for the Content field
rewrittenJSON, marshalErr := json.Marshal(rewrittenResponse)
if marshalErr != nil {
logMiddleware.Printf("Failed to marshal rewritten response: tool=%s, queryID=%s, error=%v", toolName, queryID, marshalErr)
logger.LogError("payload", "Failed to marshal metadata response: tool=%s, queryID=%s, error=%v",
toolName, queryID, marshalErr)
// Fall back to original result if we can't marshal
return result, rewrittenResponse, nil
}

logger.LogDebug("payload", "Metadata response marshaled: tool=%s, queryID=%s, metadataSize=%d bytes",
toolName, queryID, len(rewrittenJSON))

// Create a new CallToolResult with the transformed content
// Replace the original content with our rewritten response
transformedResult := &sdk.CallToolResult{
Content: []sdk.Content{
&sdk.TextContent{
Text: string(rewrittenJSON),
},
},
IsError: result.IsError,
Meta: result.Meta,
}

logMiddleware.Printf("Transformed result with metadata: tool=%s, queryID=%s, sessionID=%s", toolName, queryID, sessionID)
logger.LogInfo("payload", "Returning transformed response to client: tool=%s, queryID=%s, session=%s, payloadPath=%s, clientReceivesMetadata=true",
toolName, queryID, sessionID, filePath)
logger.LogInfo("payload", "Client can access full payload at: %s (inside container: /workspace/mcp-payloads/%s/%s/payload.json)",
filePath, sessionID, queryID)
Comment on lines +325 to +326
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This log line hardcodes an agent/container path (/workspace/mcp-payloads/...) that depends on workflow-specific mounts and may be incorrect in most deployments. Consider removing the hardcoded path, or deriving/logging it from configuration so operational logs don’t mislead users about where the payload is accessible.

Copilot uses AI. Check for mistakes.

return transformedResult, rewrittenResponse, nil
}
}

Expand Down
59 changes: 58 additions & 1 deletion internal/middleware/jqschema_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,30 @@ func TestMiddlewareIntegration(t *testing.T) {
require.NotNil(t, result, "Result should not be nil")
assert.False(t, result.IsError, "Result should not indicate error")

// Verify response structure
// Verify the result Content field contains the transformed response
require.NotEmpty(t, result.Content, "Result should have Content")
textContent, ok := result.Content[0].(*sdk.TextContent)
require.True(t, ok, "Content should be TextContent")
require.NotEmpty(t, textContent.Text, "TextContent should have text")

// Parse the JSON from Content
var contentMap map[string]interface{}
err = json.Unmarshal([]byte(textContent.Text), &contentMap)
require.NoError(t, err, "Content should be valid JSON")

// Verify all required fields exist in Content
assert.Contains(t, contentMap, "queryID", "Content should contain queryID")
assert.Contains(t, contentMap, "payloadPath", "Content should contain payloadPath")
assert.Contains(t, contentMap, "preview", "Content should contain preview")
assert.Contains(t, contentMap, "schema", "Content should contain schema")
assert.Contains(t, contentMap, "originalSize", "Content should contain originalSize")
assert.Contains(t, contentMap, "truncated", "Content should contain truncated")

// Verify queryID format in Content
queryIDFromContent := contentMap["queryID"].(string)
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests use direct type assertions on values pulled from a map resulting from json.Unmarshal (e.g., contentMap["queryID"].(string)). If the key is missing or the type changes, the test will panic and produce a less actionable failure. Prefer require.Contains/require.True checks (or require.IsType) before asserting types/values.

Suggested change
queryIDFromContent := contentMap["queryID"].(string)
queryIDValue, ok := contentMap["queryID"]
require.True(t, ok, "Content queryID should be present")
queryIDFromContent, ok := queryIDValue.(string)
require.True(t, ok, "Content queryID should be a string")

Copilot uses AI. Check for mistakes.
assert.Len(t, queryIDFromContent, 32, "QueryID should be 32 hex characters")

// Verify response structure in data return value (for internal use)
dataMap, ok := data.(map[string]interface{})
require.True(t, ok, "Response should be a map")

Expand Down Expand Up @@ -180,6 +203,25 @@ func TestMiddlewareWithLargePayload(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, result)

// Verify Content field has transformed response
require.NotEmpty(t, result.Content, "Result should have Content")
textContent, ok := result.Content[0].(*sdk.TextContent)
require.True(t, ok, "Content should be TextContent")

var contentMap map[string]interface{}
err = json.Unmarshal([]byte(textContent.Text), &contentMap)
require.NoError(t, err, "Content should be valid JSON")

// Verify truncation in Content field
truncatedInContent := contentMap["truncated"].(bool)
previewInContent := contentMap["preview"].(string)

if truncatedInContent {
assert.True(t, len(previewInContent) <= 503, "Preview in Content should be truncated")
assert.Contains(t, previewInContent, "...", "Truncated preview in Content should end with ...")
}

// Also check data return value
dataMap := data.(map[string]interface{})

// Verify truncation occurred
Expand Down Expand Up @@ -223,9 +265,24 @@ func TestMiddlewareDirectoryCreation(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, result)

// Verify Content field
require.NotEmpty(t, result.Content, "Result should have Content")
textContent, ok := result.Content[0].(*sdk.TextContent)
require.True(t, ok, "Content should be TextContent")

var contentMap map[string]interface{}
err = json.Unmarshal([]byte(textContent.Text), &contentMap)
require.NoError(t, err, "Content should be valid JSON")

queryIDFromContent := contentMap["queryID"].(string)

// Also check data return value
dataMap := data.(map[string]interface{})
queryID := dataMap["queryID"].(string)

// Both should match
assert.Equal(t, queryID, queryIDFromContent, "QueryID should match in both data and Content")

// Verify directory structure with session ID
expectedDir := filepath.Join(baseDir, sessionID, queryID)
assert.DirExists(t, expectedDir, "Query directory should exist")
Expand Down
60 changes: 43 additions & 17 deletions internal/middleware/jqschema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,40 @@ func TestWrapToolHandler(t *testing.T) {
require.NotNil(t, result, "Result should not be nil")
assert.False(t, result.IsError, "Result should not be an error")

// Verify rewritten response structure
dataMap, ok := data.(map[string]interface{})
require.True(t, ok, "Data should be a map")

assert.Contains(t, dataMap, "queryID", "Response should contain queryID")
assert.Contains(t, dataMap, "payloadPath", "Response should contain payloadPath")
assert.Contains(t, dataMap, "preview", "Response should contain preview")
assert.Contains(t, dataMap, "schema", "Response should contain schema")
assert.Contains(t, dataMap, "originalSize", "Response should contain originalSize")
assert.Contains(t, dataMap, "truncated", "Response should contain truncated")
// Verify the result Content field contains the transformed response
require.NotEmpty(t, result.Content, "Result should have Content")
textContent, ok := result.Content[0].(*sdk.TextContent)
require.True(t, ok, "Content should be TextContent")
require.NotEmpty(t, textContent.Text, "TextContent should have text")

// Parse the JSON from Content
var contentMap map[string]interface{}
err = json.Unmarshal([]byte(textContent.Text), &contentMap)
require.NoError(t, err, "Content should be valid JSON")

// Verify transformed response in Content field
assert.Contains(t, contentMap, "queryID", "Content should contain queryID")
assert.Contains(t, contentMap, "payloadPath", "Content should contain payloadPath")
assert.Contains(t, contentMap, "preview", "Content should contain preview")
assert.Contains(t, contentMap, "schema", "Content should contain schema")
assert.Contains(t, contentMap, "originalSize", "Content should contain originalSize")
assert.Contains(t, contentMap, "truncated", "Content should contain truncated")

// Verify queryID is a valid hex string
queryID, ok := dataMap["queryID"].(string)
queryID, ok := contentMap["queryID"].(string)
require.True(t, ok, "queryID should be a string")
assert.NotEmpty(t, queryID, "queryID should not be empty")

// Verify schema is present
schema := dataMap["schema"]
schema := contentMap["schema"]
assert.NotNil(t, schema, "Schema should not be nil")

// Also verify rewritten response in data return value (for internal use)
dataMap, ok := data.(map[string]interface{})
require.True(t, ok, "Data should be a map")
assert.Contains(t, dataMap, "queryID", "Data should contain queryID")
assert.Contains(t, dataMap, "payloadPath", "Data should contain payloadPath")

// Clean up test directory
defer os.RemoveAll(filepath.Join("/tmp", "gh-awmg"))
}
Expand Down Expand Up @@ -216,14 +230,26 @@ func TestWrapToolHandler_LongPayload(t *testing.T) {
require.NoError(t, err, "Should not return error")
require.NotNil(t, result, "Result should not be nil")

dataMap, ok := data.(map[string]interface{})
require.True(t, ok, "Data should be a map")
// Verify Content field contains the transformed response
require.NotEmpty(t, result.Content, "Result should have Content")
textContent, ok := result.Content[0].(*sdk.TextContent)
require.True(t, ok, "Content should be TextContent")

// Parse the JSON from Content
var contentMap map[string]interface{}
err = json.Unmarshal([]byte(textContent.Text), &contentMap)
require.NoError(t, err, "Content should be valid JSON")

// Verify truncation
assert.True(t, dataMap["truncated"].(bool), "Should indicate truncation")
preview := dataMap["preview"].(string)
// Verify truncation in Content field
assert.True(t, contentMap["truncated"].(bool), "Should indicate truncation in Content")
preview := contentMap["preview"].(string)
assert.LessOrEqual(t, len(preview), 503, "Preview should be truncated to ~500 chars + '...'")
assert.True(t, strings.HasSuffix(preview, "..."), "Preview should end with '...'")

// Also verify in data return value
dataMap, ok := data.(map[string]interface{})
require.True(t, ok, "Data should be a map")
assert.True(t, dataMap["truncated"].(bool), "Should indicate truncation in data")
Comment on lines +244 to +252
Copy link

Copilot AI Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test directly type-asserts fields from contentMap (e.g., contentMap["truncated"].(bool), contentMap["preview"].(string)). If the JSON shape changes, the test will panic instead of failing with a clear assertion message. Consider asserting the keys exist and types are expected before using them.

Suggested change
assert.True(t, contentMap["truncated"].(bool), "Should indicate truncation in Content")
preview := contentMap["preview"].(string)
assert.LessOrEqual(t, len(preview), 503, "Preview should be truncated to ~500 chars + '...'")
assert.True(t, strings.HasSuffix(preview, "..."), "Preview should end with '...'")
// Also verify in data return value
dataMap, ok := data.(map[string]interface{})
require.True(t, ok, "Data should be a map")
assert.True(t, dataMap["truncated"].(bool), "Should indicate truncation in data")
truncatedVal, ok := contentMap["truncated"]
require.True(t, ok, "Content JSON should contain 'truncated' field")
truncated, ok := truncatedVal.(bool)
require.True(t, ok, "'truncated' field in Content JSON should be a bool")
assert.True(t, truncated, "Should indicate truncation in Content")
previewVal, ok := contentMap["preview"]
require.True(t, ok, "Content JSON should contain 'preview' field")
preview, ok := previewVal.(string)
require.True(t, ok, "'preview' field in Content JSON should be a string")
assert.LessOrEqual(t, len(preview), 503, "Preview should be truncated to ~500 chars + '...'")
assert.True(t, strings.HasSuffix(preview, "..."), "Preview should end with '...'")
// Also verify in data return value
dataMap, ok := data.(map[string]interface{})
require.True(t, ok, "Data should be a map")
dataTruncatedVal, ok := dataMap["truncated"]
require.True(t, ok, "Data JSON should contain 'truncated' field")
dataTruncated, ok := dataTruncatedVal.(bool)
require.True(t, ok, "'truncated' field in data JSON should be a bool")
assert.True(t, dataTruncated, "Should indicate truncation in data")

Copilot uses AI. Check for mistakes.
}

// TestPayloadStorage_SessionIsolation verifies that payloads are stored in session-specific directories
Expand Down
Loading