From 500df4a102d46dc2feac2125a5bed69a497ce854 Mon Sep 17 00:00:00 2001 From: Jean-Laurent de Morlhon Date: Fri, 6 Feb 2026 10:30:50 +0100 Subject: [PATCH] /attach use file upload instead of embedding in the context Signed-off-by: Jean-Laurent de Morlhon --- pkg/app/app.go | 83 +++- pkg/chat/chat.go | 57 ++- pkg/cli/runner.go | 71 +-- pkg/model/provider/anthropic/beta_client.go | 20 +- .../provider/anthropic/beta_client_test.go | 9 +- .../provider/anthropic/beta_converter.go | 196 +++++--- .../provider/anthropic/beta_converter_test.go | 15 +- pkg/model/provider/anthropic/client.go | 173 +++++--- pkg/model/provider/anthropic/client_test.go | 45 +- pkg/model/provider/anthropic/files.go | 420 ++++++++++++++++++ pkg/model/provider/anthropic/files_test.go | 239 ++++++++++ 11 files changed, 1139 insertions(+), 189 deletions(-) create mode 100644 pkg/model/provider/anthropic/files.go create mode 100644 pkg/model/provider/anthropic/files_test.go diff --git a/pkg/app/app.go b/pkg/app/app.go index c59999ed7..817f0b577 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "log/slog" + "os" "os/exec" + "path/filepath" "slices" "strings" "sync/atomic" @@ -243,20 +245,91 @@ func (a *App) Run(ctx context.Context, cancel context.CancelFunc, message string go func() { if len(attachments) > 0 { + // Strip attachment placeholders from the message text + // Placeholders are in the format @/path/to/file + cleanMessage := message + for placeholder := range attachments { + cleanMessage = strings.ReplaceAll(cleanMessage, placeholder, "") + } + cleanMessage = strings.TrimSpace(cleanMessage) + if cleanMessage == "" { + cleanMessage = "Please analyze this attached file." + } + multiContent := []chat.MessagePart{ { Type: chat.MessagePartTypeText, - Text: message, + Text: cleanMessage, }, } - for key, dataURL := range attachments { + // Attachments are keyed by @filepath placeholder + // Extract the file path and add as file attachment for provider upload. + // Note: There is an inherent TOCTOU race between this validation and when + // the provider reads the file during upload. This validation catches common + // cases (deleted files, wrong paths) but files could still change before upload. + for placeholder := range attachments { + filePath := strings.TrimPrefix(placeholder, "@") + if filePath == "" { + slog.Debug("skipping attachment with empty file path", "placeholder", placeholder) + continue + } + + // Convert to absolute path to ensure consistency with provider upload code + // and prevent issues if working directory changes between validation and upload + absPath, err := filepath.Abs(filePath) + if err != nil { + slog.Warn("skipping attachment: invalid path", "path", filePath, "error", err) + a.events <- runtime.Warning(fmt.Sprintf("Skipped attachment %s: invalid path", filePath), "") + continue + } + + fi, err := os.Stat(absPath) + if err != nil { + var reason string + switch { + case os.IsNotExist(err): + reason = "file does not exist" + case os.IsPermission(err): + reason = "permission denied" + default: + reason = fmt.Sprintf("cannot access file: %v", err) + } + slog.Warn("skipping attachment", "path", absPath, "reason", reason) + a.events <- runtime.Warning(fmt.Sprintf("Skipped attachment %s: %s", filePath, reason), "") + continue + } + + if !fi.Mode().IsRegular() { + slog.Warn("skipping attachment: not a regular file", "path", absPath, "mode", fi.Mode().String()) + a.events <- runtime.Warning(fmt.Sprintf("Skipped attachment %s: not a regular file", filePath), "") + continue + } + + const maxAttachmentSize = 100 * 1024 * 1024 // 100MB + if fi.Size() > maxAttachmentSize { + slog.Warn("skipping attachment: file too large", "path", absPath, "size", fi.Size(), "max", maxAttachmentSize) + a.events <- runtime.Warning(fmt.Sprintf("Skipped attachment %s: file too large (max 100MB)", filePath), "") + continue + } + + mimeType := chat.DetectMimeType(absPath) + if !chat.IsSupportedMimeType(mimeType) { + slog.Warn("skipping attachment: unsupported file type", "path", absPath, "mime_type", mimeType) + a.events <- runtime.Warning(fmt.Sprintf("Skipped attachment %s: unsupported file type (supported: images, pdf, txt, md)", filePath), "") + continue + } + multiContent = append(multiContent, chat.MessagePart{ - Type: chat.MessagePartTypeText, - Text: fmt.Sprintf("Contents of %s: %s", key, dataURL), + Type: chat.MessagePartTypeFile, + File: &chat.MessageFile{ + Path: absPath, + MimeType: mimeType, + }, }) } - a.session.AddMessage(session.UserMessage(message, multiContent...)) + + a.session.AddMessage(session.UserMessage(cleanMessage, multiContent...)) } else { a.session.AddMessage(session.UserMessage(message)) } diff --git a/pkg/chat/chat.go b/pkg/chat/chat.go index 7e6088206..882295d55 100644 --- a/pkg/chat/chat.go +++ b/pkg/chat/chat.go @@ -1,6 +1,11 @@ package chat -import "github.com/docker/cagent/pkg/tools" +import ( + "path/filepath" + "strings" + + "github.com/docker/cagent/pkg/tools" +) type MessageRole string @@ -16,6 +21,7 @@ type MessagePartType string const ( MessagePartTypeText MessagePartType = "text" MessagePartTypeImageURL MessagePartType = "image_url" + MessagePartTypeFile MessagePartType = "file" ) type ImageURLDetail string @@ -74,10 +80,18 @@ type Message struct { CacheControl bool `json:"cache_control,omitempty"` } +// MessageFile represents a file attachment that can be uploaded to a provider's file storage. +type MessageFile struct { + Path string `json:"path,omitempty"` // Local file path (used for upload) + FileID string `json:"file_id,omitempty"` // Provider-specific file ID (after upload) + MimeType string `json:"mime_type,omitempty"` // MIME type of the file +} + type MessagePart struct { Type MessagePartType `json:"type,omitempty"` Text string `json:"text,omitempty"` ImageURL *MessageImageURL `json:"image_url,omitempty"` + File *MessageFile `json:"file,omitempty"` } // FinishReason represents the reason why the model finished generating a response @@ -145,3 +159,44 @@ type MessageStream interface { // Close closes the stream Close() } + +// DetectMimeType returns the MIME type for a file based on its extension. +// This is the canonical implementation used across all packages for consistency. +// Note: Only returns MIME types that are supported for file attachments. +// Unsupported extensions return "application/octet-stream". +func DetectMimeType(filePath string) string { + ext := strings.ToLower(filepath.Ext(filePath)) + switch ext { + // Images + case ".jpg", ".jpeg": + return "image/jpeg" + case ".png": + return "image/png" + case ".gif": + return "image/gif" + case ".webp": + return "image/webp" + // Documents + case ".pdf": + return "application/pdf" + case ".txt", ".json", ".csv": + return "text/plain" + case ".md", ".markdown": + return "text/markdown" + default: + return "application/octet-stream" + } +} + +// IsSupportedMimeType returns true if the MIME type is supported for file attachments. +// Supported types include images (jpeg, png, gif, webp) and documents (pdf, text, markdown). +func IsSupportedMimeType(mimeType string) bool { + switch mimeType { + case "image/jpeg", "image/png", "image/gif", "image/webp": + return true + case "application/pdf", "text/plain", "text/markdown": + return true + default: + return false + } +} diff --git a/pkg/cli/runner.go b/pkg/cli/runner.go index 76fb64da7..5a15782ac 100644 --- a/pkg/cli/runner.go +++ b/pkg/cli/runner.go @@ -3,7 +3,6 @@ package cli import ( "cmp" "context" - "encoding/base64" "encoding/json" "fmt" "io" @@ -316,78 +315,46 @@ func ParseAttachCommand(userInput string) (messageText, attachPath string) { return messageText, attachPath } -// CreateUserMessageWithAttachment creates a user message with optional image attachment +// CreateUserMessageWithAttachment creates a user message with optional file attachment. +// The attachment is stored as a file reference (path + MIME type) rather than base64-encoded +// content. The actual upload to the provider's file storage happens at request time. func CreateUserMessageWithAttachment(userContent, attachmentPath string) *session.Message { if attachmentPath == "" { return session.UserMessage(userContent) } - // Convert file to data URL - dataURL, err := fileToDataURL(attachmentPath) + // Validate file exists + absPath, err := filepath.Abs(attachmentPath) if err != nil { - slog.Warn("Failed to attach file", "path", attachmentPath, "error", err) + slog.Warn("Failed to get absolute path for attachment", "path", attachmentPath, "error", err) return session.UserMessage(userContent) } + if _, err := os.Stat(absPath); os.IsNotExist(err) { + slog.Warn("Attachment file does not exist", "path", absPath) + return session.UserMessage(userContent) + } + + // Determine MIME type + mimeType := chat.DetectMimeType(absPath) + // Ensure we have some text content when attaching a file textContent := cmp.Or(strings.TrimSpace(userContent), "Please analyze this attached file.") - // Create message with multi-content including text and image + // Create message with multi-content including text and file reference multiContent := []chat.MessagePart{ { Type: chat.MessagePartTypeText, Text: textContent, }, { - Type: chat.MessagePartTypeImageURL, - ImageURL: &chat.MessageImageURL{ - URL: dataURL, - Detail: chat.ImageURLDetailAuto, + Type: chat.MessagePartTypeFile, + File: &chat.MessageFile{ + Path: absPath, + MimeType: mimeType, }, }, } return session.UserMessage("", multiContent...) } - -// fileToDataURL converts a file to a data URL -func fileToDataURL(filePath string) (string, error) { - // Check if file exists - if _, err := os.Stat(filePath); os.IsNotExist(err) { - return "", fmt.Errorf("file does not exist: %s", filePath) - } - - // Read file content - fileBytes, err := os.ReadFile(filePath) - if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) - } - - // Determine MIME type based on file extension - ext := strings.ToLower(filepath.Ext(filePath)) - var mimeType string - switch ext { - case ".jpg", ".jpeg": - mimeType = "image/jpeg" - case ".png": - mimeType = "image/png" - case ".gif": - mimeType = "image/gif" - case ".webp": - mimeType = "image/webp" - case ".bmp": - mimeType = "image/bmp" - case ".svg": - mimeType = "image/svg+xml" - default: - return "", fmt.Errorf("unsupported image format: %s", ext) - } - - // Encode to base64 - encoded := base64.StdEncoding.EncodeToString(fileBytes) - - // Create data URL - dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, encoded) - - return dataURL, nil -} diff --git a/pkg/model/provider/anthropic/beta_client.go b/pkg/model/provider/anthropic/beta_client.go index d3c61cd03..b99124341 100644 --- a/pkg/model/provider/anthropic/beta_client.go +++ b/pkg/model/provider/anthropic/beta_client.go @@ -38,7 +38,11 @@ func (c *Client) createBetaStream( return nil, err } - converted := convertBetaMessages(messages) + converted, err := c.convertBetaMessages(ctx, messages) + if err != nil { + slog.Error("Failed to convert messages for Anthropic Beta request", "error", err) + return nil, err + } if err := validateAnthropicSequencingBeta(converted); err != nil { slog.Warn("Invalid message sequencing for Anthropic Beta API detected, attempting self-repair", "error", err) converted = repairAnthropicSequencingBeta(converted) @@ -50,13 +54,25 @@ func (c *Client) createBetaStream( sys := extractBetaSystemBlocks(messages) + // Check if messages contain file attachments to include the files-api beta header + needsFilesAPI := hasFileAttachments(messages) + + betas := []anthropic.AnthropicBeta{ + anthropic.AnthropicBetaInterleavedThinking2025_05_14, + "fine-grained-tool-streaming-2025-05-14", + } + if needsFilesAPI { + betas = append(betas, filesAPIBeta) + slog.Debug("Anthropic Beta API: Including files-api beta header for file attachments") + } + params := anthropic.BetaMessageNewParams{ Model: anthropic.Model(c.ModelConfig.Model), MaxTokens: maxTokens, System: sys, Messages: converted, Tools: allTools, - Betas: []anthropic.AnthropicBeta{anthropic.AnthropicBetaInterleavedThinking2025_05_14, "fine-grained-tool-streaming-2025-05-14"}, + Betas: betas, } // Apply structured output configuration diff --git a/pkg/model/provider/anthropic/beta_client_test.go b/pkg/model/provider/anthropic/beta_client_test.go index 0db312b4c..ea04e1f1d 100644 --- a/pkg/model/provider/anthropic/beta_client_test.go +++ b/pkg/model/provider/anthropic/beta_client_test.go @@ -262,7 +262,8 @@ func TestConvertBetaMessages_UserMessage(t *testing.T) { }, } - converted := convertBetaMessages(msgs) + converted, err := testClient().convertBetaMessages(t.Context(), msgs) + require.NoError(t, err) require.Len(t, converted, 1) assert.Equal(t, anthropic.BetaMessageParamRoleUser, converted[0].Role) @@ -282,7 +283,8 @@ func TestConvertBetaMessages_SkipsSystemMessages(t *testing.T) { }, } - converted := convertBetaMessages(msgs) + converted, err := testClient().convertBetaMessages(t.Context(), msgs) + require.NoError(t, err) require.Len(t, converted, 1) assert.Equal(t, anthropic.BetaMessageParamRoleUser, converted[0].Role) @@ -297,7 +299,8 @@ func TestConvertBetaMessages_AssistantMessage(t *testing.T) { }, } - converted := convertBetaMessages(msgs) + converted, err := testClient().convertBetaMessages(t.Context(), msgs) + require.NoError(t, err) require.Len(t, converted, 1) assert.Equal(t, anthropic.BetaMessageParamRoleAssistant, converted[0].Role) diff --git a/pkg/model/provider/anthropic/beta_converter.go b/pkg/model/provider/anthropic/beta_converter.go index f70212a4c..1844dbf5c 100644 --- a/pkg/model/provider/anthropic/beta_converter.go +++ b/pkg/model/provider/anthropic/beta_converter.go @@ -1,7 +1,9 @@ package anthropic import ( + "context" "encoding/json" + "fmt" "strings" "github.com/anthropics/anthropic-sdk-go" @@ -18,7 +20,7 @@ import ( // // Important: Anthropic API requires that all tool_result blocks corresponding to tool_use // blocks from the same assistant message MUST be grouped into a single user message. -func convertBetaMessages(messages []chat.Message) []anthropic.BetaMessageParam { +func (c *Client) convertBetaMessages(ctx context.Context, messages []chat.Message) ([]anthropic.BetaMessageParam, error) { var betaMessages []anthropic.BetaMessageParam for i := 0; i < len(messages); i++ { @@ -28,60 +30,11 @@ func convertBetaMessages(messages []chat.Message) []anthropic.BetaMessageParam { continue } if msg.Role == chat.MessageRoleUser { - // Handle user messages (including images and tool results) + // Handle user messages (including images, files, and tool results) if len(msg.MultiContent) > 0 { - contentBlocks := make([]anthropic.BetaContentBlockParamUnion, 0, len(msg.MultiContent)) - for _, part := range msg.MultiContent { - if part.Type == chat.MessagePartTypeText { - if txt := strings.TrimSpace(part.Text); txt != "" { - contentBlocks = append(contentBlocks, anthropic.BetaContentBlockParamUnion{ - OfText: &anthropic.BetaTextBlockParam{Text: txt}, - }) - } - } else if part.Type == chat.MessagePartTypeImageURL && part.ImageURL != nil { - if strings.HasPrefix(part.ImageURL.URL, "data:") { - parts := strings.SplitN(part.ImageURL.URL, ",", 2) - if len(parts) == 2 { - mediaTypePart := parts[0] - base64Data := parts[1] - var mediaType string - switch { - case strings.Contains(mediaTypePart, "image/jpeg"): - mediaType = "image/jpeg" - case strings.Contains(mediaTypePart, "image/png"): - mediaType = "image/png" - case strings.Contains(mediaTypePart, "image/gif"): - mediaType = "image/gif" - case strings.Contains(mediaTypePart, "image/webp"): - mediaType = "image/webp" - default: - mediaType = "image/jpeg" - } - // Use SDK types directly for better performance (avoids JSON round trip) - contentBlocks = append(contentBlocks, anthropic.BetaContentBlockParamUnion{ - OfImage: &anthropic.BetaImageBlockParam{ - Source: anthropic.BetaImageBlockParamSourceUnion{ - OfBase64: &anthropic.BetaBase64ImageSourceParam{ - Data: base64Data, - MediaType: anthropic.BetaBase64ImageSourceMediaType(mediaType), - }, - }, - }, - }) - } - } else if strings.HasPrefix(part.ImageURL.URL, "http://") || strings.HasPrefix(part.ImageURL.URL, "https://") { - // Support URL-based images - Anthropic can fetch images directly from URLs - contentBlocks = append(contentBlocks, anthropic.BetaContentBlockParamUnion{ - OfImage: &anthropic.BetaImageBlockParam{ - Source: anthropic.BetaImageBlockParamSourceUnion{ - OfURL: &anthropic.BetaURLImageSourceParam{ - URL: part.ImageURL.URL, - }, - }, - }, - }) - } - } + contentBlocks, err := c.convertBetaUserMultiContent(ctx, msg.MultiContent) + if err != nil { + return nil, err } if len(contentBlocks) > 0 { betaMessages = append(betaMessages, anthropic.BetaMessageParam{ @@ -189,7 +142,140 @@ func convertBetaMessages(messages []chat.Message) []anthropic.BetaMessageParam { // Add ephemeral cache to last 2 messages' last content block applyBetaMessageCacheControl(betaMessages) - return betaMessages + return betaMessages, nil +} + +// convertBetaUserMultiContent converts user message multi-content parts to Beta API content blocks. +// It handles text, images (base64 and URL), and file uploads via the Files API. +func (c *Client) convertBetaUserMultiContent(ctx context.Context, parts []chat.MessagePart) ([]anthropic.BetaContentBlockParamUnion, error) { + contentBlocks := make([]anthropic.BetaContentBlockParamUnion, 0, len(parts)) + + for _, part := range parts { + switch part.Type { + case chat.MessagePartTypeText: + if txt := strings.TrimSpace(part.Text); txt != "" { + contentBlocks = append(contentBlocks, anthropic.BetaContentBlockParamUnion{ + OfText: &anthropic.BetaTextBlockParam{Text: txt}, + }) + } + + case chat.MessagePartTypeImageURL: + if part.ImageURL == nil { + continue + } + // Handle base64 data URLs (legacy format) + if strings.HasPrefix(part.ImageURL.URL, "data:") { + urlParts := strings.SplitN(part.ImageURL.URL, ",", 2) + if len(urlParts) == 2 { + mediaTypePart := urlParts[0] + base64Data := urlParts[1] + + var mediaType string + switch { + case strings.Contains(mediaTypePart, "image/jpeg"): + mediaType = "image/jpeg" + case strings.Contains(mediaTypePart, "image/png"): + mediaType = "image/png" + case strings.Contains(mediaTypePart, "image/gif"): + mediaType = "image/gif" + case strings.Contains(mediaTypePart, "image/webp"): + mediaType = "image/webp" + default: + mediaType = "image/jpeg" + } + + contentBlocks = append(contentBlocks, anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfBase64: &anthropic.BetaBase64ImageSourceParam{ + Data: base64Data, + MediaType: anthropic.BetaBase64ImageSourceMediaType(mediaType), + }, + }, + }, + }) + } + } else if strings.HasPrefix(part.ImageURL.URL, "http://") || strings.HasPrefix(part.ImageURL.URL, "https://") { + // URL-based images + contentBlocks = append(contentBlocks, anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfURL: &anthropic.BetaURLImageSourceParam{ + URL: part.ImageURL.URL, + }, + }, + }, + }) + } + + case chat.MessagePartTypeFile: + if part.File == nil { + continue + } + + switch { + case part.File.Path != "" && part.File.FileID == "": + // Upload the file if we have a path but no file ID + if c.fileManager == nil { + return nil, fmt.Errorf("%w: cannot upload file %s", ErrFileManagerNotInitialized, part.File.Path) + } + + uploaded, err := c.fileManager.GetOrUpload(ctx, part.File.Path) + if err != nil { + return nil, fmt.Errorf("failed to upload file %s: %w", part.File.Path, err) + } + + block, err := createBetaFileContentBlock(uploaded.FileID, uploaded.MimeType) + if err != nil { + return nil, err + } + contentBlocks = append(contentBlocks, block) + + case part.File.FileID != "": + // File already uploaded, use the ID directly + block, err := createBetaFileContentBlock(part.File.FileID, part.File.MimeType) + if err != nil { + return nil, err + } + contentBlocks = append(contentBlocks, block) + + default: + // File part has neither path nor file ID - this is invalid + return nil, fmt.Errorf("invalid file attachment: neither path nor file_id provided") + } + } + } + + return contentBlocks, nil +} + +// createBetaFileContentBlock creates the appropriate Beta API content block for a file. +func createBetaFileContentBlock(fileID, mimeType string) (anthropic.BetaContentBlockParamUnion, error) { + if IsImageMime(mimeType) { + return anthropic.BetaContentBlockParamUnion{ + OfImage: &anthropic.BetaImageBlockParam{ + Source: anthropic.BetaImageBlockParamSourceUnion{ + OfFile: &anthropic.BetaFileImageSourceParam{ + FileID: fileID, + }, + }, + }, + }, nil + } + + if IsDocumentMime(mimeType) { + return anthropic.BetaContentBlockParamUnion{ + OfDocument: &anthropic.BetaRequestDocumentBlockParam{ + Source: anthropic.BetaRequestDocumentBlockSourceUnionParam{ + OfFile: &anthropic.BetaFileDocumentSourceParam{ + FileID: fileID, + }, + }, + }, + }, nil + } + + return anthropic.BetaContentBlockParamUnion{}, fmt.Errorf("%w: %s", ErrUnsupportedFileType, mimeType) } // extractBetaSystemBlocks extracts system messages for Beta API format diff --git a/pkg/model/provider/anthropic/beta_converter_test.go b/pkg/model/provider/anthropic/beta_converter_test.go index 687b37eea..c5aaadd4d 100644 --- a/pkg/model/provider/anthropic/beta_converter_test.go +++ b/pkg/model/provider/anthropic/beta_converter_test.go @@ -60,7 +60,8 @@ func TestConvertBetaMessages_MergesConsecutiveToolMessages(t *testing.T) { } // Convert to Beta format - betaMessages := convertBetaMessages(messages) + betaMessages, err := testClient().convertBetaMessages(t.Context(), messages) + require.NoError(t, err) require.Len(t, betaMessages, 4, "Should have 4 messages after conversion") @@ -83,7 +84,7 @@ func TestConvertBetaMessages_MergesConsecutiveToolMessages(t *testing.T) { assert.Contains(t, toolResultIDs, "tool_call_2") // Most importantly: validate that the sequence is valid for Anthropic API - err := validateAnthropicSequencingBeta(betaMessages) + err = validateAnthropicSequencingBeta(betaMessages) require.NoError(t, err, "Converted messages should pass Anthropic sequencing validation") } @@ -119,11 +120,12 @@ func TestConvertBetaMessages_SingleToolMessage(t *testing.T) { }, } - betaMessages := convertBetaMessages(messages) + betaMessages, err := testClient().convertBetaMessages(t.Context(), messages) + require.NoError(t, err) require.Len(t, betaMessages, 4) // Validate sequence - err := validateAnthropicSequencingBeta(betaMessages) + err = validateAnthropicSequencingBeta(betaMessages) require.NoError(t, err) } @@ -179,9 +181,10 @@ func TestConvertBetaMessages_NonConsecutiveToolMessages(t *testing.T) { }, } - betaMessages := convertBetaMessages(messages) + betaMessages, err := testClient().convertBetaMessages(t.Context(), messages) + require.NoError(t, err) // Validate the entire sequence - err := validateAnthropicSequencingBeta(betaMessages) + err = validateAnthropicSequencingBeta(betaMessages) require.NoError(t, err, "Messages with non-consecutive tool calls should still validate") } diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 084338136..b4b1f5e5a 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -31,6 +31,7 @@ type Client struct { base.Config clientFn func(context.Context) (anthropic.Client, error) lastHTTPResponse *http.Response + fileManager *FileManager } func (c *Client) getResponseTrailer() http.Header { @@ -203,9 +204,25 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro slog.Debug("Anthropic client created successfully", "model", cfg.Model) + // Initialize FileManager for file uploads + anthropicClient.fileManager = NewFileManager(anthropicClient.clientFn) + return anthropicClient, nil } +// hasFileAttachments checks if any messages contain file attachments. +// This is used to determine if we need to use the Beta API (Files API is Beta-only). +func hasFileAttachments(messages []chat.Message) bool { + for i := range messages { + for _, part := range messages[i].MultiContent { + if part.Type == chat.MessagePartTypeFile && part.File != nil { + return true + } + } + } + return false +} + // CreateChatCompletionStream creates a streaming chat completion request func (c *Client) CreateChatCompletionStream( ctx context.Context, @@ -236,9 +253,10 @@ func (c *Client) CreateChatCompletionStream( // Use Beta API when: // 1. Interleaved thinking is enabled, or - // 2. Structured output is configured + // 2. Structured output is configured, or + // 3. Messages contain file attachments (Files API is Beta-only) // Note: Structured outputs require beta header support (only available on BetaMessageNewParams) - if c.interleavedThinkingEnabled() || c.ModelOptions.StructuredOutput() != nil { + if c.interleavedThinkingEnabled() || c.ModelOptions.StructuredOutput() != nil || hasFileAttachments(messages) { return c.createBetaStream(ctx, client, messages, requestTools, maxTokens) } @@ -248,7 +266,11 @@ func (c *Client) CreateChatCompletionStream( return nil, err } - converted := convertMessages(messages) + converted, err := c.convertMessages(ctx, messages) + if err != nil { + slog.Error("Failed to convert messages for Anthropic request", "error", err) + return nil, err + } // Preflight validation to ensure tool_use/tool_result sequencing is valid if err := validateAnthropicSequencing(converted); err != nil { slog.Warn("Invalid message sequencing for Anthropic detected, attempting self-repair", "error", err) @@ -344,7 +366,7 @@ func (c *Client) CreateChatCompletionStream( return ad, nil } -func convertMessages(messages []chat.Message) []anthropic.MessageParam { +func (c *Client) convertMessages(ctx context.Context, messages []chat.Message) ([]anthropic.MessageParam, error) { var anthropicMessages []anthropic.MessageParam // Track whether the last appended assistant message included tool_use blocks // so we can ensure the immediate next message is the grouped tool_result user message. @@ -357,53 +379,11 @@ func convertMessages(messages []chat.Message) []anthropic.MessageParam { continue } if msg.Role == chat.MessageRoleUser { - // Handle MultiContent for user messages (including images) + // Handle MultiContent for user messages (including images and files) if len(msg.MultiContent) > 0 { - contentBlocks := make([]anthropic.ContentBlockParamUnion, 0, len(msg.MultiContent)) - for _, part := range msg.MultiContent { - if part.Type == chat.MessagePartTypeText { - if txt := strings.TrimSpace(part.Text); txt != "" { - contentBlocks = append(contentBlocks, anthropic.NewTextBlock(txt)) - } - } else if part.Type == chat.MessagePartTypeImageURL && part.ImageURL != nil { - // Anthropic expects base64 image data - // Extract base64 data from data URL - if strings.HasPrefix(part.ImageURL.URL, "data:") { - parts := strings.SplitN(part.ImageURL.URL, ",", 2) - if len(parts) == 2 { - // Extract media type from data URL - mediaTypePart := parts[0] - base64Data := parts[1] - - var mediaType string - switch { - case strings.Contains(mediaTypePart, "image/jpeg"): - mediaType = "image/jpeg" - case strings.Contains(mediaTypePart, "image/png"): - mediaType = "image/png" - case strings.Contains(mediaTypePart, "image/gif"): - mediaType = "image/gif" - case strings.Contains(mediaTypePart, "image/webp"): - mediaType = "image/webp" - default: - // Default to jpeg if not recognized - mediaType = "image/jpeg" - } - - // Use SDK helper with proper typed source for better performance - // (avoids JSON marshal/unmarshal round trip) - contentBlocks = append(contentBlocks, anthropic.NewImageBlock(anthropic.Base64ImageSourceParam{ - Data: base64Data, - MediaType: anthropic.Base64ImageSourceMediaType(mediaType), - })) - } - } else if strings.HasPrefix(part.ImageURL.URL, "http://") || strings.HasPrefix(part.ImageURL.URL, "https://") { - // Support URL-based images - Anthropic can fetch images directly from URLs - contentBlocks = append(contentBlocks, anthropic.NewImageBlock(anthropic.URLImageSourceParam{ - URL: part.ImageURL.URL, - })) - } - } + contentBlocks, err := c.convertUserMultiContent(ctx, msg.MultiContent) + if err != nil { + return nil, err } if len(contentBlocks) > 0 { anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...)) @@ -499,7 +479,84 @@ func convertMessages(messages []chat.Message) []anthropic.MessageParam { // Add ephemeral cache to last 2 messages' last content block applyMessageCacheControl(anthropicMessages) - return anthropicMessages + return anthropicMessages, nil +} + +// convertUserMultiContent converts user message multi-content parts to Anthropic content blocks. +// It handles text and images (base64 and URL). File uploads are NOT supported in the non-Beta API +// and will return an error - callers should use hasFileAttachments() to route to the Beta API. +func (c *Client) convertUserMultiContent(_ context.Context, parts []chat.MessagePart) ([]anthropic.ContentBlockParamUnion, error) { + contentBlocks := make([]anthropic.ContentBlockParamUnion, 0, len(parts)) + + for _, part := range parts { + switch part.Type { + case chat.MessagePartTypeText: + if txt := strings.TrimSpace(part.Text); txt != "" { + contentBlocks = append(contentBlocks, anthropic.NewTextBlock(txt)) + } + + case chat.MessagePartTypeImageURL: + if part.ImageURL == nil { + continue + } + // Handle base64 data URLs (legacy format) + if strings.HasPrefix(part.ImageURL.URL, "data:") { + urlParts := strings.SplitN(part.ImageURL.URL, ",", 2) + if len(urlParts) == 2 { + mediaTypePart := urlParts[0] + base64Data := urlParts[1] + + var mediaType string + switch { + case strings.Contains(mediaTypePart, "image/jpeg"): + mediaType = "image/jpeg" + case strings.Contains(mediaTypePart, "image/png"): + mediaType = "image/png" + case strings.Contains(mediaTypePart, "image/gif"): + mediaType = "image/gif" + case strings.Contains(mediaTypePart, "image/webp"): + mediaType = "image/webp" + default: + mediaType = "image/jpeg" + } + + contentBlocks = append(contentBlocks, anthropic.NewImageBlock(anthropic.Base64ImageSourceParam{ + Data: base64Data, + MediaType: anthropic.Base64ImageSourceMediaType(mediaType), + })) + } + } else if strings.HasPrefix(part.ImageURL.URL, "http://") || strings.HasPrefix(part.ImageURL.URL, "https://") { + // URL-based images + contentBlocks = append(contentBlocks, anthropic.NewImageBlock(anthropic.URLImageSourceParam{ + URL: part.ImageURL.URL, + })) + } + + case chat.MessagePartTypeFile: + if part.File == nil { + continue + } + + // File uploads require the Beta API - this code path should not be reached + // if hasFileAttachments() correctly routes to createBetaStream(). + // Return a clear error if we somehow get here. + return nil, fmt.Errorf("file attachments require the Beta API; use hasFileAttachments() to route correctly (path=%q, file_id=%q)", + part.File.Path, part.File.FileID) + } + } + + return contentBlocks, nil +} + +// createFileContentBlock creates the appropriate content block for a file based on its MIME type. +// Note: File uploads via the Files API require the Beta API. This function supports images +// (which have OfFile in the Beta API only) and documents. For non-Beta API usage with files, +// the caller should handle the conversion differently or use base64 encoding. +func createFileContentBlock(fileID, mimeType string) (anthropic.ContentBlockParamUnion, error) { + // The standard (non-Beta) API doesn't support file references in ImageBlockParamSourceUnion + // or DocumentBlockParamSourceUnion. Files API is Beta-only. + // For now, we return an error directing users to use the Beta API path. + return anthropic.ContentBlockParamUnion{}, fmt.Errorf("file uploads require the Beta API; file_id=%s, mime_type=%s", fileID, mimeType) } // applyMessageCacheControl adds ephemeral cache control to the last content block @@ -597,6 +654,20 @@ func (c *Client) ID() string { return c.ModelConfig.Provider + "/" + c.ModelConfig.Model } +// CleanupFiles removes all files uploaded during this session from Anthropic's storage. +func (c *Client) CleanupFiles(ctx context.Context) error { + if c.fileManager == nil { + return nil + } + return c.fileManager.CleanupAll(ctx) +} + +// FileManager returns the file manager for this client, allowing external cleanup. +// Returns nil if file uploads are not supported or not initialized. +func (c *Client) FileManager() *FileManager { + return c.fileManager +} + // validateAnthropicSequencing verifies that for every assistant message that includes // one or more tool_use blocks, the immediately following message is a user message // that includes tool_result blocks for all those tool_use IDs (grouped into that single message). diff --git a/pkg/model/provider/anthropic/client_test.go b/pkg/model/provider/anthropic/client_test.go index 03dcb8098..8e2f56b78 100644 --- a/pkg/model/provider/anthropic/client_test.go +++ b/pkg/model/provider/anthropic/client_test.go @@ -16,13 +16,19 @@ import ( "github.com/docker/cagent/pkg/tools" ) +// testClient creates a minimal Client for testing convertMessages. +func testClient() *Client { + return &Client{} +} + func TestConvertMessages_SkipEmptySystemText(t *testing.T) { msgs := []chat.Message{{ Role: chat.MessageRoleSystem, Content: " \n\t ", }} - out := convertMessages(msgs) + out, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) assert.Empty(t, out) } @@ -32,7 +38,8 @@ func TestConvertMessages_SkipEmptyUserText_NoMultiContent(t *testing.T) { Content: " \n\t ", }} - out := convertMessages(msgs) + out, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) assert.Empty(t, out) } @@ -45,7 +52,8 @@ func TestConvertMessages_UserMultiContent_SkipEmptyText_KeepImage(t *testing.T) }, }} - out := convertMessages(msgs) + out, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) require.Len(t, out, 1) b, err := json.Marshal(out[0]) @@ -71,7 +79,8 @@ func TestConvertMessages_SkipEmptyAssistantText_NoToolCalls(t *testing.T) { Content: " \t\n ", }} - out := convertMessages(msgs) + out, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) assert.Empty(t, out) } @@ -84,7 +93,8 @@ func TestConvertMessages_AssistantToolCalls_NoText_IncludesToolUse(t *testing.T) }, }} - out := convertMessages(msgs) + out, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) require.Len(t, out, 1) b, err := json.Marshal(out[0]) @@ -112,7 +122,8 @@ func TestSystemMessages_AreExtractedAndNotInMessageList(t *testing.T) { assert.Equal(t, "system rules here", strings.TrimSpace(sys[0].Text)) // System role messages must not appear in the anthropic messages list - out := convertMessages(msgs) + out, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) assert.Len(t, out, 1) } @@ -128,7 +139,8 @@ func TestSystemMessages_MultipleExtractedAndExcludedFromMessageList(t *testing.T assert.Equal(t, "sys A", strings.TrimSpace(sys[0].Text)) assert.Equal(t, "sys B", strings.TrimSpace(sys[1].Text)) - out := convertMessages(msgs) + out, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) assert.Len(t, out, 1) } @@ -148,7 +160,8 @@ func TestSystemMessages_InterspersedExtractedAndExcluded(t *testing.T) { assert.Equal(t, "S2", strings.TrimSpace(sys[1].Text)) // Converted messages must exclude system roles and preserve order of others - out := convertMessages(msgs) + out, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) require.Len(t, out, 3) expectedRoles := []string{"user", "assistant", "user"} for i, expected := range expectedRoles { @@ -173,8 +186,9 @@ func TestSequencingRepair_Standard(t *testing.T) { {Role: chat.MessageRoleUser, Content: "continue"}, } - converted := convertMessages(msgs) - err := validateAnthropicSequencing(converted) + converted, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) + err = validateAnthropicSequencing(converted) require.Error(t, err) repaired := repairAnthropicSequencing(converted) @@ -195,8 +209,9 @@ func TestSequencingRepair_Beta(t *testing.T) { {Role: chat.MessageRoleUser, Content: "continue"}, } - converted := convertBetaMessages(msgs) - err := validateAnthropicSequencingBeta(converted) + converted, err := testClient().convertBetaMessages(t.Context(), msgs) + require.NoError(t, err) + err = validateAnthropicSequencingBeta(converted) require.Error(t, err) repaired := repairAnthropicSequencingBeta(converted) @@ -212,7 +227,8 @@ func TestConvertMessages_DropOrphanToolResults_NoPrecedingToolUse(t *testing.T) {Role: chat.MessageRoleUser, Content: "continue"}, } - converted := convertMessages(msgs) + converted, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) // Expect only the two user text messages to appear require.Len(t, converted, 2) @@ -246,7 +262,8 @@ func TestConvertMessages_GroupToolResults_AfterAssistantToolUse(t *testing.T) { {Role: chat.MessageRoleUser, Content: "ok"}, } - converted := convertMessages(msgs) + converted, err := testClient().convertMessages(t.Context(), msgs) + require.NoError(t, err) // Expect: user(start), assistant(tool_use), user(grouped tool_result), user(ok) require.Len(t, converted, 4) diff --git a/pkg/model/provider/anthropic/files.go b/pkg/model/provider/anthropic/files.go new file mode 100644 index 000000000..609585a46 --- /dev/null +++ b/pkg/model/provider/anthropic/files.go @@ -0,0 +1,420 @@ +package anthropic + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "sync" + "time" + + "github.com/anthropics/anthropic-sdk-go" + + "github.com/docker/cagent/pkg/chat" +) + +const ( + // filesAPIBeta is the beta header value required for the Files API. + filesAPIBeta = "files-api-2025-04-14" + + // defaultFileTTL is the default time-to-live for uploaded files. + defaultFileTTL = 24 * time.Hour +) + +// UploadedFile represents a file that has been uploaded to Anthropic. +type UploadedFile struct { + FileID string + Filename string + MimeType string + SizeBytes int64 + UploadedAt time.Time + LocalPath string + ContentHash string +} + +// inFlightUpload tracks an upload in progress to prevent duplicate concurrent uploads. +type inFlightUpload struct { + done chan struct{} + result *UploadedFile + err error +} + +// cacheKey creates a composite key for deduplication that includes both content hash and MIME type. +// This prevents issues where identical content with different extensions would share cached uploads. +// Uses a null byte as delimiter since it cannot appear in either SHA256 hex strings or MIME types. +func cacheKey(contentHash, mimeType string) string { + return contentHash + "\x00" + mimeType +} + +// FileManager manages file uploads to Anthropic's Files API. +// It provides deduplication, caching, and TTL-based cleanup. +// Thread-safe for concurrent use. +type FileManager struct { + clientFn func(context.Context) (anthropic.Client, error) + + mu sync.RWMutex + uploads map[string]*UploadedFile // cache key (hash:mime) → uploaded file + paths map[string]string // local path → cache key + inFlight map[string]*inFlightUpload // cache key → in-progress upload +} + +// NewFileManager creates a new FileManager with the given client factory. +func NewFileManager(clientFn func(context.Context) (anthropic.Client, error)) *FileManager { + return &FileManager{ + clientFn: clientFn, + uploads: make(map[string]*UploadedFile), + paths: make(map[string]string), + inFlight: make(map[string]*inFlightUpload), + } +} + +// GetOrUpload returns an existing upload for the file or uploads it if not cached. +// Files are deduplicated by content hash AND MIME type, so identical files with +// different extensions will be uploaded separately. +// Concurrent calls for the same file will wait for a single upload to complete. +func (fm *FileManager) GetOrUpload(ctx context.Context, filePath string) (*UploadedFile, error) { + absPath, err := filepath.Abs(filePath) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path: %w", err) + } + + // Determine MIME type early - needed for cache key + mimeType := chat.DetectMimeType(absPath) + + // Check if we already have this path cached + fm.mu.RLock() + if key, ok := fm.paths[absPath]; ok { + if upload, ok := fm.uploads[key]; ok { + fm.mu.RUnlock() + return upload, nil + } + } + fm.mu.RUnlock() + + // Open file once and compute hash while reading for upload preparation + // This validates the file exists and is readable + file, err := os.Open(absPath) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + stat, err := file.Stat() + if err != nil { + return nil, fmt.Errorf("failed to stat file: %w", err) + } + + // Compute hash by reading file content + h := sha256.New() + if _, err := io.Copy(h, file); err != nil { + return nil, fmt.Errorf("failed to hash file: %w", err) + } + hash := hex.EncodeToString(h.Sum(nil)) + + // Create cache key from hash + MIME type + key := cacheKey(hash, mimeType) + + // Try to get from cache or join an in-flight upload + fm.mu.Lock() + + // Double-check cache after acquiring write lock + if upload, ok := fm.uploads[key]; ok { + fm.paths[absPath] = key + fm.mu.Unlock() + return upload, nil + } + + // Check if there's an in-flight upload for this key + if flight, ok := fm.inFlight[key]; ok { + fm.mu.Unlock() + // Wait for the in-flight upload to complete + select { + case <-flight.done: + // Check context after waking up + if ctx.Err() != nil { + return nil, ctx.Err() + } + if flight.err != nil { + return nil, flight.err + } + // Cache the path mapping + fm.mu.Lock() + fm.paths[absPath] = key + fm.mu.Unlock() + return flight.result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + // Start a new upload - register it as in-flight + flight := &inFlightUpload{ + done: make(chan struct{}), + } + fm.inFlight[key] = flight + fm.mu.Unlock() + + // Perform the upload (outside the lock) + // File needs to be re-opened since we consumed it for hashing + var upload *UploadedFile + func() { + defer func() { + fm.mu.Lock() + flight.result = upload + flight.err = err + close(flight.done) + delete(fm.inFlight, key) + + // Cache successful uploads regardless of context cancellation. + // The file is already on Anthropic's servers and should be reusable. + if err == nil && upload != nil { + fm.uploads[key] = upload + fm.paths[absPath] = key + } + fm.mu.Unlock() + }() + + upload, err = fm.upload(ctx, absPath, hash, mimeType, stat.Size()) + }() + + // If context was cancelled but upload succeeded, still return the upload. + // The file is already on Anthropic's servers and cached for reuse. + if err == nil && upload != nil { + return upload, nil + } + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + return upload, err +} + +// upload performs the actual file upload to Anthropic. +func (fm *FileManager) upload(ctx context.Context, filePath, contentHash, mimeType string, fileSize int64) (*UploadedFile, error) { + client, err := fm.clientFn(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get client: %w", err) + } + + file, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + filename := filepath.Base(filePath) + + slog.Debug("Uploading file to Anthropic Files API", + "filename", filename, + "mime_type", mimeType, + "size", fileSize) + + // Use the SDK's File helper to create the upload + params := anthropic.BetaFileUploadParams{ + File: anthropic.File(file, filename, mimeType), + Betas: []anthropic.AnthropicBeta{filesAPIBeta}, + } + + result, err := client.Beta.Files.Upload(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to upload file: %w", err) + } + + upload := &UploadedFile{ + FileID: result.ID, + Filename: result.Filename, + MimeType: result.MimeType, + SizeBytes: result.SizeBytes, + UploadedAt: time.Now(), + LocalPath: filePath, + ContentHash: contentHash, + } + + slog.Info("File uploaded to Anthropic", + "file_id", upload.FileID, + "filename", upload.Filename, + "size", upload.SizeBytes) + + return upload, nil +} + +// Delete removes a file from Anthropic's storage. +func (fm *FileManager) Delete(ctx context.Context, fileID string) error { + client, err := fm.clientFn(ctx) + if err != nil { + return fmt.Errorf("failed to get client: %w", err) + } + + params := anthropic.BetaFileDeleteParams{ + Betas: []anthropic.AnthropicBeta{filesAPIBeta}, + } + + _, err = client.Beta.Files.Delete(ctx, fileID, params) + if err != nil { + return fmt.Errorf("failed to delete file: %w", err) + } + + slog.Debug("Deleted file from Anthropic", "file_id", fileID) + return nil +} + +// Cleanup removes files older than the specified TTL from both Anthropic and the cache. +func (fm *FileManager) Cleanup(ctx context.Context, ttl time.Duration) error { + if ttl == 0 { + ttl = defaultFileTTL + } + + cutoff := time.Now().Add(-ttl) + + fm.mu.Lock() + defer fm.mu.Unlock() + + // Collect keys to delete first to avoid modifying map during iteration + var keysToDelete []string + var errs []error + + for key, upload := range fm.uploads { + if upload.UploadedAt.Before(cutoff) { + if err := fm.deleteUnlocked(ctx, upload.FileID); err != nil { + slog.Warn("Failed to delete expired file", "file_id", upload.FileID, "error", err) + errs = append(errs, err) + continue + } + keysToDelete = append(keysToDelete, key) + } + } + + // Now delete from maps + for _, key := range keysToDelete { + delete(fm.uploads, key) + } + + // Collect paths to delete + var pathsToDelete []string + for path, k := range fm.paths { + for _, deletedKey := range keysToDelete { + if k == deletedKey { + pathsToDelete = append(pathsToDelete, path) + break + } + } + } + + for _, path := range pathsToDelete { + delete(fm.paths, path) + } + + if len(errs) > 0 { + return fmt.Errorf("failed to delete %d files during cleanup", len(errs)) + } + return nil +} + +// CleanupAll removes all cached files from Anthropic. +func (fm *FileManager) CleanupAll(ctx context.Context) error { + fm.mu.Lock() + defer fm.mu.Unlock() + + // Collect keys to delete first to avoid modifying map during iteration + var keysToDelete []string + var errs []error + + for key, upload := range fm.uploads { + if err := fm.deleteUnlocked(ctx, upload.FileID); err != nil { + slog.Warn("Failed to delete file during cleanup", "file_id", upload.FileID, "error", err) + errs = append(errs, err) + continue + } + keysToDelete = append(keysToDelete, key) + } + + // Now delete from map + for _, key := range keysToDelete { + delete(fm.uploads, key) + } + + // Clear path mappings + fm.paths = make(map[string]string) + + if len(errs) > 0 { + return fmt.Errorf("failed to delete %d files during cleanup", len(errs)) + } + return nil +} + +// deleteUnlocked deletes a file without acquiring the lock (caller must hold lock). +func (fm *FileManager) deleteUnlocked(ctx context.Context, fileID string) error { + client, err := fm.clientFn(ctx) + if err != nil { + return err + } + + params := anthropic.BetaFileDeleteParams{ + Betas: []anthropic.AnthropicBeta{filesAPIBeta}, + } + + _, err = client.Beta.Files.Delete(ctx, fileID, params) + return err +} + +// CachedCount returns the number of files currently cached. +func (fm *FileManager) CachedCount() int { + fm.mu.RLock() + defer fm.mu.RUnlock() + return len(fm.uploads) +} + +// hashFile computes the SHA256 hash of a file's contents. +// Note: This function is only used for testing and legacy code paths. +// The main GetOrUpload path computes the hash inline to avoid opening the file twice. +func hashFile(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + h := sha256.New() + if _, err := io.Copy(h, file); err != nil { + return "", err + } + + return hex.EncodeToString(h.Sum(nil)), nil +} + +// IsImageMime returns true if the MIME type is an image type supported by Anthropic. +func IsImageMime(mimeType string) bool { + switch mimeType { + case "image/jpeg", "image/png", "image/gif", "image/webp": + return true + default: + return false + } +} + +// IsDocumentMime returns true if the MIME type is a document type supported by Anthropic. +func IsDocumentMime(mimeType string) bool { + switch mimeType { + case "application/pdf", "text/plain", "text/markdown": + return true + default: + return false + } +} + +// IsSupportedMime returns true if the MIME type is supported by Anthropic's Files API. +func IsSupportedMime(mimeType string) bool { + return chat.IsSupportedMimeType(mimeType) +} + +// ErrUnsupportedFileType is returned when a file type is not supported by the Files API. +var ErrUnsupportedFileType = errors.New("unsupported file type for Anthropic Files API") + +// ErrFileManagerNotInitialized is returned when file operations are attempted without a FileManager. +var ErrFileManagerNotInitialized = errors.New("file manager not initialized") diff --git a/pkg/model/provider/anthropic/files_test.go b/pkg/model/provider/anthropic/files_test.go new file mode 100644 index 000000000..8c0ba84b5 --- /dev/null +++ b/pkg/model/provider/anthropic/files_test.go @@ -0,0 +1,239 @@ +package anthropic + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/cagent/pkg/chat" +) + +func TestDetectMimeType(t *testing.T) { + tests := []struct { + path string + expected string + }{ + {"image.jpg", "image/jpeg"}, + {"image.jpeg", "image/jpeg"}, + {"image.png", "image/png"}, + {"image.gif", "image/gif"}, + {"image.webp", "image/webp"}, + {"document.pdf", "application/pdf"}, + {"readme.txt", "text/plain"}, + {"readme.md", "text/markdown"}, + {"readme.markdown", "text/markdown"}, + // json and csv are treated as text/plain for provider compatibility + {"data.json", "text/plain"}, + {"data.csv", "text/plain"}, + {"unknown.xyz", "application/octet-stream"}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + result := chat.DetectMimeType(tt.path) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsImageMime(t *testing.T) { + tests := []struct { + mimeType string + expected bool + }{ + {"image/jpeg", true}, + {"image/png", true}, + {"image/gif", true}, + {"image/webp", true}, + {"application/pdf", false}, + {"text/plain", false}, + {"application/octet-stream", false}, + } + + for _, tt := range tests { + t.Run(tt.mimeType, func(t *testing.T) { + result := IsImageMime(tt.mimeType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsDocumentMime(t *testing.T) { + tests := []struct { + mimeType string + expected bool + }{ + {"application/pdf", true}, + {"text/plain", true}, + {"text/markdown", true}, + {"image/jpeg", false}, + {"image/png", false}, + {"application/octet-stream", false}, + } + + for _, tt := range tests { + t.Run(tt.mimeType, func(t *testing.T) { + result := IsDocumentMime(tt.mimeType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsSupportedMime(t *testing.T) { + tests := []struct { + mimeType string + expected bool + }{ + {"image/jpeg", true}, + {"image/png", true}, + {"image/gif", true}, + {"image/webp", true}, + {"application/pdf", true}, + {"text/plain", true}, + {"text/markdown", true}, + {"application/json", false}, + {"application/octet-stream", false}, + } + + for _, tt := range tests { + t.Run(tt.mimeType, func(t *testing.T) { + result := IsSupportedMime(tt.mimeType) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHashFile(t *testing.T) { + // Create a temporary file + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + content := []byte("test content for hashing") + err := os.WriteFile(testFile, content, 0o644) + require.NoError(t, err) + + // Hash should be consistent for same content + hash1, err := hashFile(testFile) + require.NoError(t, err) + assert.NotEmpty(t, hash1) + + hash2, err := hashFile(testFile) + require.NoError(t, err) + assert.Equal(t, hash1, hash2) + + // Different content should produce different hash + testFile2 := filepath.Join(tmpDir, "test2.txt") + err = os.WriteFile(testFile2, []byte("different content"), 0o644) + require.NoError(t, err) + + hash3, err := hashFile(testFile2) + require.NoError(t, err) + assert.NotEqual(t, hash1, hash3) +} + +func TestHashFile_NotFound(t *testing.T) { + _, err := hashFile("/nonexistent/path/to/file.txt") + assert.Error(t, err) +} + +func TestNewFileManager(t *testing.T) { + fm := NewFileManager(nil) + require.NotNil(t, fm) + assert.Equal(t, 0, fm.CachedCount()) +} + +func TestUploadedFile_TTL(t *testing.T) { + old := &UploadedFile{ + FileID: "file_old", + UploadedAt: time.Now().Add(-25 * time.Hour), + } + recent := &UploadedFile{ + FileID: "file_recent", + UploadedAt: time.Now().Add(-1 * time.Hour), + } + + cutoff := time.Now().Add(-24 * time.Hour) + + assert.True(t, old.UploadedAt.Before(cutoff), "old file should be before cutoff") + assert.False(t, recent.UploadedAt.Before(cutoff), "recent file should not be before cutoff") +} + +func TestFileManager_Deduplication(t *testing.T) { + // This test verifies the deduplication logic structure + // Actual upload testing would require mocking the Anthropic client + + fm := NewFileManager(nil) + require.NotNil(t, fm) + + // Manually populate the cache to test deduplication logic + testHash := "abc123" + testFile := &UploadedFile{ + FileID: "file_test", + Filename: "test.png", + MimeType: "image/png", + ContentHash: testHash, + UploadedAt: time.Now(), + LocalPath: "/path/to/test.png", + } + + fm.mu.Lock() + fm.uploads[testHash] = testFile + fm.paths["/path/to/test.png"] = testHash + fm.mu.Unlock() + + // Check that the file is cached + assert.Equal(t, 1, fm.CachedCount()) + + // Verify the path mapping exists + fm.mu.RLock() + hash, ok := fm.paths["/path/to/test.png"] + fm.mu.RUnlock() + assert.True(t, ok) + assert.Equal(t, testHash, hash) + + // Verify the upload exists + fm.mu.RLock() + upload, ok := fm.uploads[testHash] + fm.mu.RUnlock() + assert.True(t, ok) + assert.Equal(t, "file_test", upload.FileID) +} + +func TestCreateFileContentBlock_NotSupported(t *testing.T) { + // Standard API doesn't support file references - Files API is Beta-only + _, err := createFileContentBlock("file_123", "image/png") + require.Error(t, err) + assert.Contains(t, err.Error(), "Beta API") +} + +func TestCreateBetaFileContentBlock_Image(t *testing.T) { + block, err := createBetaFileContentBlock("file_beta_123", "image/jpeg") + require.NoError(t, err) + assert.NotNil(t, block.OfImage) + assert.Nil(t, block.OfDocument) + assert.Equal(t, "file_beta_123", block.OfImage.Source.OfFile.FileID) +} + +func TestCreateBetaFileContentBlock_Document(t *testing.T) { + block, err := createBetaFileContentBlock("file_beta_456", "application/pdf") + require.NoError(t, err) + assert.NotNil(t, block.OfDocument) + assert.Nil(t, block.OfImage) + assert.Equal(t, "file_beta_456", block.OfDocument.Source.OfFile.FileID) +} + +func TestCreateBetaFileContentBlock_Unsupported(t *testing.T) { + _, err := createBetaFileContentBlock("file_beta_000", "video/mp4") + require.Error(t, err) + assert.ErrorIs(t, err, ErrUnsupportedFileType) +} + +func TestFileManager_CleanupAll_Empty(t *testing.T) { + fm := NewFileManager(nil) + err := fm.CleanupAll(t.Context()) + require.NoError(t, err) + assert.Equal(t, 0, fm.CachedCount()) +}