diff --git a/internal/middleware/jqschema.go b/internal/middleware/jqschema.go index fa8d5f46..7c22410f 100644 --- a/internal/middleware/jqschema.go +++ b/internal/middleware/jqschema.go @@ -113,16 +113,41 @@ func applyJqSchema(ctx context.Context, jsonData interface{}) (string, error) { func savePayload(baseDir, sessionID, queryID string, payload []byte) (string, error) { // Create directory structure: {baseDir}/{sessionID}/{queryID} dir := filepath.Join(baseDir, sessionID, queryID) + + logger.LogDebug("payload", "Creating payload directory: baseDir=%s, session=%s, query=%s, fullPath=%s", + baseDir, sessionID, queryID, dir) + if err := os.MkdirAll(dir, 0700); err != nil { + logger.LogError("payload", "Failed to create payload directory: path=%s, error=%v", dir, err) return "", fmt.Errorf("failed to create payload directory: %w", err) } + logger.LogDebug("payload", "Successfully created payload directory: path=%s, permissions=0700", dir) + // Save payload to file with restrictive permissions (owner read/write only) filePath := filepath.Join(dir, "payload.json") + payloadSize := len(payload) + + logger.LogInfo("payload", "Writing large payload to filesystem: path=%s, size=%d bytes (%.2f KB, %.2f MB)", + filePath, payloadSize, float64(payloadSize)/1024, float64(payloadSize)/(1024*1024)) + if err := os.WriteFile(filePath, payload, 0600); err != nil { + logger.LogError("payload", "Failed to write payload file: path=%s, size=%d bytes, error=%v", + filePath, payloadSize, err) return "", fmt.Errorf("failed to write payload file: %w", err) } + logger.LogInfo("payload", "Successfully saved large payload to filesystem: path=%s, size=%d bytes, permissions=0600", + filePath, payloadSize) + + // Verify file was written correctly + if stat, err := os.Stat(filePath); err != nil { + logger.LogWarn("payload", "Could not verify payload file after write: path=%s, error=%v", filePath, err) + } else { + logger.LogDebug("payload", "Payload file verified: path=%s, size=%d bytes, mode=%s", + filePath, stat.Size(), stat.Mode()) + } + return filePath, nil } @@ -149,16 +174,31 @@ func WrapToolHandler( } logMiddleware.Printf("Processing tool call: tool=%s, queryID=%s, sessionID=%s", toolName, queryID, sessionID) + logger.LogDebug("payload", "Middleware processing tool call: tool=%s, queryID=%s, session=%s, baseDir=%s", + toolName, queryID, sessionID, baseDir) // Call the original handler result, data, err := handler(ctx, req, args) if err != nil { logMiddleware.Printf("Tool call failed: tool=%s, queryID=%s, sessionID=%s, error=%v", toolName, queryID, sessionID, err) + logger.LogDebug("payload", "Tool call failed, skipping payload storage: tool=%s, queryID=%s, error=%v", + toolName, queryID, err) return result, data, err } // Only process successful results with data if result == nil || result.IsError || data == nil { + logger.LogDebug("payload", "Skipping payload storage: tool=%s, queryID=%s, reason=%s", + toolName, queryID, + func() string { + if result == nil { + return "result is nil" + } else if result.IsError { + return "result indicates error" + } else { + return "no data returned" + } + }()) return result, data, err } @@ -166,20 +206,34 @@ func WrapToolHandler( payloadJSON, marshalErr := json.Marshal(data) if marshalErr != nil { logMiddleware.Printf("Failed to marshal response: tool=%s, queryID=%s, error=%v", toolName, queryID, marshalErr) + logger.LogError("payload", "Failed to marshal response data to JSON: tool=%s, queryID=%s, error=%v", + toolName, queryID, marshalErr) return result, data, err } + payloadSize := len(payloadJSON) + logger.LogInfo("payload", "Response data marshaled to JSON: tool=%s, queryID=%s, size=%d bytes (%.2f KB, %.2f MB)", + toolName, queryID, payloadSize, float64(payloadSize)/1024, float64(payloadSize)/(1024*1024)) + // Save the payload + logger.LogInfo("payload", "Starting payload storage to filesystem: tool=%s, queryID=%s, session=%s, baseDir=%s", + toolName, queryID, sessionID, baseDir) + filePath, saveErr := savePayload(baseDir, sessionID, queryID, payloadJSON) if saveErr != nil { logMiddleware.Printf("Failed to save payload: tool=%s, queryID=%s, sessionID=%s, error=%v", toolName, queryID, sessionID, saveErr) + logger.LogError("payload", "Failed to save payload to filesystem: tool=%s, queryID=%s, session=%s, error=%v", + toolName, queryID, sessionID, saveErr) // Continue even if save fails - don't break the tool call } else { logMiddleware.Printf("Saved payload: tool=%s, queryID=%s, sessionID=%s, path=%s, size=%d bytes", toolName, queryID, sessionID, filePath, len(payloadJSON)) + logger.LogInfo("payload", "Payload storage completed successfully: tool=%s, queryID=%s, session=%s, path=%s, size=%d bytes", + toolName, queryID, sessionID, filePath, len(payloadJSON)) } // Apply jq schema transformation + logger.LogDebug("payload", "Applying jq schema transformation: tool=%s, queryID=%s", toolName, queryID) var schemaJSON string if schemaErr := func() error { // Unmarshal to interface{} for jq processing @@ -196,17 +250,27 @@ func WrapToolHandler( return nil }(); schemaErr != nil { logMiddleware.Printf("Failed to apply jq schema: tool=%s, queryID=%s, sessionID=%s, error=%v", toolName, queryID, sessionID, schemaErr) + logger.LogWarn("payload", "Failed to generate schema for payload: tool=%s, queryID=%s, error=%v", + toolName, queryID, schemaErr) // Continue with original response if schema extraction fails return result, data, err } + logger.LogDebug("payload", "Schema transformation completed: tool=%s, queryID=%s, schemaSize=%d bytes", + toolName, queryID, len(schemaJSON)) + // Build the transformed response: first 500 chars + schema payloadStr := string(payloadJSON) var preview string - if len(payloadStr) > 500 { + truncated := len(payloadStr) > 500 + if truncated { preview = payloadStr[:500] + "..." + logger.LogInfo("payload", "Payload truncated for preview: tool=%s, queryID=%s, originalSize=%d bytes, previewSize=500 bytes", + toolName, queryID, len(payloadStr)) } else { preview = payloadStr + logger.LogDebug("payload", "Payload small enough for full preview: tool=%s, queryID=%s, size=%d bytes", + toolName, queryID, len(payloadStr)) } // Create rewritten response @@ -216,11 +280,13 @@ func WrapToolHandler( "preview": preview, "schema": schemaJSON, "originalSize": len(payloadJSON), - "truncated": len(payloadStr) > 500, + "truncated": truncated, } logMiddleware.Printf("Rewritten response: tool=%s, queryID=%s, sessionID=%s, originalSize=%d, truncated=%v", - toolName, queryID, sessionID, len(payloadJSON), len(payloadStr) > 500) + toolName, queryID, sessionID, len(payloadJSON), truncated) + logger.LogInfo("payload", "Created metadata response for client: tool=%s, queryID=%s, session=%s, payloadPath=%s, originalSize=%d bytes, truncated=%v", + toolName, queryID, sessionID, filePath, len(payloadJSON), truncated) // Parse the schema JSON string back to an object for cleaner display var schemaObj interface{} @@ -228,7 +294,38 @@ func WrapToolHandler( rewrittenResponse["schema"] = schemaObj } - return result, rewrittenResponse, nil + // Marshal the rewritten response to JSON for the Content field + rewrittenJSON, marshalErr := json.Marshal(rewrittenResponse) + if marshalErr != nil { + logMiddleware.Printf("Failed to marshal rewritten response: tool=%s, queryID=%s, error=%v", toolName, queryID, marshalErr) + logger.LogError("payload", "Failed to marshal metadata response: tool=%s, queryID=%s, error=%v", + toolName, queryID, marshalErr) + // Fall back to original result if we can't marshal + return result, rewrittenResponse, nil + } + + logger.LogDebug("payload", "Metadata response marshaled: tool=%s, queryID=%s, metadataSize=%d bytes", + toolName, queryID, len(rewrittenJSON)) + + // Create a new CallToolResult with the transformed content + // Replace the original content with our rewritten response + transformedResult := &sdk.CallToolResult{ + Content: []sdk.Content{ + &sdk.TextContent{ + Text: string(rewrittenJSON), + }, + }, + IsError: result.IsError, + Meta: result.Meta, + } + + logMiddleware.Printf("Transformed result with metadata: tool=%s, queryID=%s, sessionID=%s", toolName, queryID, sessionID) + logger.LogInfo("payload", "Returning transformed response to client: tool=%s, queryID=%s, session=%s, payloadPath=%s, clientReceivesMetadata=true", + toolName, queryID, sessionID, filePath) + logger.LogInfo("payload", "Client can access full payload at: %s (inside container: /workspace/mcp-payloads/%s/%s/payload.json)", + filePath, sessionID, queryID) + + return transformedResult, rewrittenResponse, nil } } diff --git a/internal/middleware/jqschema_integration_test.go b/internal/middleware/jqschema_integration_test.go index 34c28532..689b203d 100644 --- a/internal/middleware/jqschema_integration_test.go +++ b/internal/middleware/jqschema_integration_test.go @@ -64,7 +64,30 @@ func TestMiddlewareIntegration(t *testing.T) { require.NotNil(t, result, "Result should not be nil") assert.False(t, result.IsError, "Result should not indicate error") - // Verify response structure + // Verify the result Content field contains the transformed response + require.NotEmpty(t, result.Content, "Result should have Content") + textContent, ok := result.Content[0].(*sdk.TextContent) + require.True(t, ok, "Content should be TextContent") + require.NotEmpty(t, textContent.Text, "TextContent should have text") + + // Parse the JSON from Content + var contentMap map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &contentMap) + require.NoError(t, err, "Content should be valid JSON") + + // Verify all required fields exist in Content + assert.Contains(t, contentMap, "queryID", "Content should contain queryID") + assert.Contains(t, contentMap, "payloadPath", "Content should contain payloadPath") + assert.Contains(t, contentMap, "preview", "Content should contain preview") + assert.Contains(t, contentMap, "schema", "Content should contain schema") + assert.Contains(t, contentMap, "originalSize", "Content should contain originalSize") + assert.Contains(t, contentMap, "truncated", "Content should contain truncated") + + // Verify queryID format in Content + queryIDFromContent := contentMap["queryID"].(string) + assert.Len(t, queryIDFromContent, 32, "QueryID should be 32 hex characters") + + // Verify response structure in data return value (for internal use) dataMap, ok := data.(map[string]interface{}) require.True(t, ok, "Response should be a map") @@ -180,6 +203,25 @@ func TestMiddlewareWithLargePayload(t *testing.T) { require.NoError(t, err) require.NotNil(t, result) + // Verify Content field has transformed response + require.NotEmpty(t, result.Content, "Result should have Content") + textContent, ok := result.Content[0].(*sdk.TextContent) + require.True(t, ok, "Content should be TextContent") + + var contentMap map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &contentMap) + require.NoError(t, err, "Content should be valid JSON") + + // Verify truncation in Content field + truncatedInContent := contentMap["truncated"].(bool) + previewInContent := contentMap["preview"].(string) + + if truncatedInContent { + assert.True(t, len(previewInContent) <= 503, "Preview in Content should be truncated") + assert.Contains(t, previewInContent, "...", "Truncated preview in Content should end with ...") + } + + // Also check data return value dataMap := data.(map[string]interface{}) // Verify truncation occurred @@ -223,9 +265,24 @@ func TestMiddlewareDirectoryCreation(t *testing.T) { require.NoError(t, err) require.NotNil(t, result) + // Verify Content field + require.NotEmpty(t, result.Content, "Result should have Content") + textContent, ok := result.Content[0].(*sdk.TextContent) + require.True(t, ok, "Content should be TextContent") + + var contentMap map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &contentMap) + require.NoError(t, err, "Content should be valid JSON") + + queryIDFromContent := contentMap["queryID"].(string) + + // Also check data return value dataMap := data.(map[string]interface{}) queryID := dataMap["queryID"].(string) + // Both should match + assert.Equal(t, queryID, queryIDFromContent, "QueryID should match in both data and Content") + // Verify directory structure with session ID expectedDir := filepath.Join(baseDir, sessionID, queryID) assert.DirExists(t, expectedDir, "Query directory should exist") diff --git a/internal/middleware/jqschema_test.go b/internal/middleware/jqschema_test.go index c4143157..c531443a 100644 --- a/internal/middleware/jqschema_test.go +++ b/internal/middleware/jqschema_test.go @@ -140,26 +140,40 @@ func TestWrapToolHandler(t *testing.T) { require.NotNil(t, result, "Result should not be nil") assert.False(t, result.IsError, "Result should not be an error") - // Verify rewritten response structure - dataMap, ok := data.(map[string]interface{}) - require.True(t, ok, "Data should be a map") - - assert.Contains(t, dataMap, "queryID", "Response should contain queryID") - assert.Contains(t, dataMap, "payloadPath", "Response should contain payloadPath") - assert.Contains(t, dataMap, "preview", "Response should contain preview") - assert.Contains(t, dataMap, "schema", "Response should contain schema") - assert.Contains(t, dataMap, "originalSize", "Response should contain originalSize") - assert.Contains(t, dataMap, "truncated", "Response should contain truncated") + // Verify the result Content field contains the transformed response + require.NotEmpty(t, result.Content, "Result should have Content") + textContent, ok := result.Content[0].(*sdk.TextContent) + require.True(t, ok, "Content should be TextContent") + require.NotEmpty(t, textContent.Text, "TextContent should have text") + + // Parse the JSON from Content + var contentMap map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &contentMap) + require.NoError(t, err, "Content should be valid JSON") + + // Verify transformed response in Content field + assert.Contains(t, contentMap, "queryID", "Content should contain queryID") + assert.Contains(t, contentMap, "payloadPath", "Content should contain payloadPath") + assert.Contains(t, contentMap, "preview", "Content should contain preview") + assert.Contains(t, contentMap, "schema", "Content should contain schema") + assert.Contains(t, contentMap, "originalSize", "Content should contain originalSize") + assert.Contains(t, contentMap, "truncated", "Content should contain truncated") // Verify queryID is a valid hex string - queryID, ok := dataMap["queryID"].(string) + queryID, ok := contentMap["queryID"].(string) require.True(t, ok, "queryID should be a string") assert.NotEmpty(t, queryID, "queryID should not be empty") // Verify schema is present - schema := dataMap["schema"] + schema := contentMap["schema"] assert.NotNil(t, schema, "Schema should not be nil") + // Also verify rewritten response in data return value (for internal use) + dataMap, ok := data.(map[string]interface{}) + require.True(t, ok, "Data should be a map") + assert.Contains(t, dataMap, "queryID", "Data should contain queryID") + assert.Contains(t, dataMap, "payloadPath", "Data should contain payloadPath") + // Clean up test directory defer os.RemoveAll(filepath.Join("/tmp", "gh-awmg")) } @@ -216,14 +230,26 @@ func TestWrapToolHandler_LongPayload(t *testing.T) { require.NoError(t, err, "Should not return error") require.NotNil(t, result, "Result should not be nil") - dataMap, ok := data.(map[string]interface{}) - require.True(t, ok, "Data should be a map") + // Verify Content field contains the transformed response + require.NotEmpty(t, result.Content, "Result should have Content") + textContent, ok := result.Content[0].(*sdk.TextContent) + require.True(t, ok, "Content should be TextContent") + + // Parse the JSON from Content + var contentMap map[string]interface{} + err = json.Unmarshal([]byte(textContent.Text), &contentMap) + require.NoError(t, err, "Content should be valid JSON") - // Verify truncation - assert.True(t, dataMap["truncated"].(bool), "Should indicate truncation") - preview := dataMap["preview"].(string) + // Verify truncation in Content field + assert.True(t, contentMap["truncated"].(bool), "Should indicate truncation in Content") + preview := contentMap["preview"].(string) assert.LessOrEqual(t, len(preview), 503, "Preview should be truncated to ~500 chars + '...'") assert.True(t, strings.HasSuffix(preview, "..."), "Preview should end with '...'") + + // Also verify in data return value + dataMap, ok := data.(map[string]interface{}) + require.True(t, ok, "Data should be a map") + assert.True(t, dataMap["truncated"].(bool), "Should indicate truncation in data") } // TestPayloadStorage_SessionIsolation verifies that payloads are stored in session-specific directories