From 630cfdebb7cb5fdcd95f52156950c4fea210a302 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 3 Dec 2025 20:47:03 +0800 Subject: [PATCH] fix: concat agentic messages --- components/agentic/callback_extra.go | 4 +- components/agentic/option.go | 7 +- components/agentic/option_test.go | 2 +- components/model/callback_extra.go | 6 +- compose/tools_node_agentic.go | 7 +- compose/tools_node_agentic_test.go | 37 +- schema/agentic_message.go | 1018 ++++++++++------- schema/agentic_message_test.go | 552 ++++++--- .../claude/{content_block.go => extension.go} | 68 +- schema/claude/extension_test.go | 190 +++ schema/claude/response_meta.go | 22 - .../gemini/{response_meta.go => extension.go} | 37 +- schema/gemini/extension_test.go | 79 ++ schema/message.go | 10 +- schema/openai/consts.go | 68 ++ schema/openai/content_block.go | 75 -- schema/openai/extension.go | 204 ++++ schema/openai/extension_test.go | 193 ++++ schema/openai/response_meta.go | 40 - schema/tool.go | 20 + 20 files changed, 1896 insertions(+), 743 deletions(-) rename schema/claude/{content_block.go => extension.go} (54%) create mode 100644 schema/claude/extension_test.go delete mode 100644 schema/claude/response_meta.go rename schema/gemini/{response_meta.go => extension.go} (81%) create mode 100644 schema/gemini/extension_test.go delete mode 100644 schema/openai/content_block.go create mode 100644 schema/openai/extension.go create mode 100644 schema/openai/extension_test.go delete mode 100644 schema/openai/response_meta.go diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index 389408d33..e2f0f51df 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -26,9 +26,9 @@ type Config struct { // Model is the model name. Model string // Temperature is the temperature, which controls the randomness of the model. - Temperature float32 + Temperature float64 // TopP is the top p, which controls the diversity of the model. - TopP float32 + TopP float64 } // CallbackInput is the input for the model callback. diff --git a/components/agentic/option.go b/components/agentic/option.go index ac117ddb4..d8873442a 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -30,8 +30,10 @@ type Options struct { TopP *float64 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. + // ToolChoice controls how the model call the tools. ToolChoice *schema.ToolChoice + // AllowedTools is a list of allowed tools the model may call. + AllowedTools []*schema.AllowedTool } // Option is the call option for ChatModel component. @@ -81,10 +83,11 @@ func WithTools(tools []*schema.ToolInfo) Option { } // WithToolChoice is the option to set tool choice for the model. -func WithToolChoice(toolChoice schema.ToolChoice) Option { +func WithToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { return Option{ apply: func(opts *Options) { opts.ToolChoice = &toolChoice + opts.AllowedTools = allowedTools }, } } diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go index d349f35ac..2c5bac652 100644 --- a/components/agentic/option_test.go +++ b/components/agentic/option_test.go @@ -29,7 +29,7 @@ func TestCommon(t *testing.T) { WithTools([]*schema.ToolInfo{{Name: "test"}}), WithModel("test"), WithTemperature(0.1), - WithToolChoice(schema.ToolChoiceAllowed), + WithToolChoice(schema.ToolChoiceAllowed, []*schema.AllowedTool{{FunctionToolName: "test"}}...), WithTopP(0.1), ) assert.Len(t, o.Tools, 1) diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index 8334c0569..f00fa31cd 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -29,17 +29,17 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails // CompletionTokens is the number of completion tokens. CompletionTokens int + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails // TotalTokens is the total number of tokens. TotalTokens int - // CompletionTokensDetails is breakdown of completion tokens. - CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { // ReasoningTokens tokens generated by the model for reasoning. // This is currently supported by OpenAI, Gemini, ARK and Qwen chat models. // For other models, this field will be 0. - ReasoningTokens int `json:"reasoning_tokens,omitempty"` + ReasoningTokens int } type PromptTokenDetails struct { diff --git a/compose/tools_node_agentic.go b/compose/tools_node_agentic.go index 38c5c89de..96aef7b72 100644 --- a/compose/tools_node_agentic.go +++ b/compose/tools_node_agentic.go @@ -70,6 +70,7 @@ func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Messa Name: block.FunctionToolCall.Name, Arguments: block.FunctionToolCall.Arguments, }, + Extra: block.Extra, }) } return &schema.Message{ @@ -87,8 +88,8 @@ func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessa CallID: m.ToolCallID, Name: m.ToolName, Result: m.Content, - Extra: m.Extra, }, + Extra: m.Extra, }) } return []*schema.AgenticMessage{{ @@ -110,9 +111,9 @@ func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Mess CallID: m.ToolCallID, Name: m.ToolName, Result: m.Content, - Extra: m.Extra, }, - StreamMeta: &schema.StreamMeta{Index: int64(i)}, + StreamingMeta: &schema.StreamingMeta{Index: i}, + Extra: m.Extra, }) } return []*schema.AgenticMessage{{ diff --git a/compose/tools_node_agentic_test.go b/compose/tools_node_agentic_test.go index dcd3177a9..4641dd8ae 100644 --- a/compose/tools_node_agentic_test.go +++ b/compose/tools_node_agentic_test.go @@ -20,6 +20,7 @@ import ( "io" "testing" + "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" @@ -155,13 +156,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { nil, }, { + nil, { Role: schema.Tool, - Content: "content1-2", + Content: "content2-2", ToolName: "name2", ToolCallID: "2", }, - nil, nil, + nil, }, { nil, nil, @@ -172,16 +174,6 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { ToolCallID: "3", }, }, - { - nil, - { - Role: schema.Tool, - Content: "content2-2", - ToolName: "name2", - ToolCallID: "2", - }, - nil, - }, { nil, nil, { @@ -204,7 +196,11 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { } result, err := schema.ConcatAgenticMessagesArray(chunks) assert.NoError(t, err) - assert.Equal(t, []*schema.AgenticMessage{ + + actualStr, err := sonic.MarshalString(result) + assert.NoError(t, err) + + expected := []*schema.AgenticMessage{ { Role: schema.AgenticRoleTypeUser, ContentBlocks: []*schema.ContentBlock{ @@ -213,10 +209,8 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { FunctionToolResult: &schema.FunctionToolResult{ CallID: "1", Name: "name1", - Result: "content1-1content1-2", - Extra: map[string]interface{}{}, + Result: "content1-1", }, - StreamMeta: &schema.StreamMeta{Index: 0}, }, { Type: schema.ContentBlockTypeFunctionToolResult, @@ -224,9 +218,7 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { CallID: "2", Name: "name2", Result: "content2-1content2-2", - Extra: map[string]interface{}{}, }, - StreamMeta: &schema.StreamMeta{Index: 1}, }, { Type: schema.ContentBlockTypeFunctionToolResult, @@ -234,11 +226,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { CallID: "3", Name: "name3", Result: "content3-1content3-2", - Extra: map[string]interface{}{}, }, - StreamMeta: &schema.StreamMeta{Index: 2}, }, }, }, - }, result) + } + + expectedStr, err := sonic.MarshalString(expected) + assert.NoError(t, err) + + assert.Equal(t, expectedStr, actualStr) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 2139201ec..504799c89 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "reflect" + "sort" "strings" "github.com/cloudwego/eino/schema/claude" @@ -82,9 +83,9 @@ type AgenticResponseMeta struct { Extension any } -type StreamMeta struct { +type StreamingMeta struct { // Index specifies the index position of this block in the final response. - Index int64 + Index int } type ContentBlock struct { @@ -123,14 +124,12 @@ type ContentBlock struct { // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. MCPToolApprovalResponse *MCPToolApprovalResponse - StreamMeta *StreamMeta + StreamingMeta *StreamingMeta + Extra map[string]any } type UserInputText struct { Text string - - // Extra stores additional information. - Extra map[string]any } type UserInputImage struct { @@ -138,27 +137,18 @@ type UserInputImage struct { Base64Data string MIMEType string Detail ImageURLDetail - - // Extra stores additional information. - Extra map[string]any } type UserInputAudio struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type UserInputVideo struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type UserInputFile struct { @@ -166,9 +156,6 @@ type UserInputFile struct { Name string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenText struct { @@ -177,51 +164,37 @@ type AssistantGenText struct { OpenAIExtension *openai.AssistantGenTextExtension ClaudeExtension *claude.AssistantGenTextExtension Extension any - - // Extra stores additional information. - Extra map[string]any } type AssistantGenImage struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenAudio struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenVideo struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type Reasoning struct { // Summary is the reasoning content summary. Summary []*ReasoningSummary + // EncryptedContent is the encrypted reasoning content. EncryptedContent string - - // Extra stores additional information. - Extra map[string]any } type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. - Index int64 + Index int Text string } @@ -229,39 +202,37 @@ type ReasoningSummary struct { type FunctionToolCall struct { // CallID is the unique identifier for the tool call. CallID string + // Name specifies the function tool invoked. Name string + // Arguments is the JSON string arguments for the function tool call. Arguments string - - // Extra stores additional information - Extra map[string]any } type FunctionToolResult struct { // CallID is the unique identifier for the tool call. CallID string + // Name specifies the function tool invoked. Name string + // Result is the function tool result returned by the user Result string - - // Extra stores additional information. - Extra map[string]any } type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). Name string + // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. CallID string + // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. Arguments any - // Extra stores additional information. - Extra map[string]any } type ServerToolResult struct { @@ -276,41 +247,40 @@ type ServerToolResult struct { // Result refers to the raw output generated by the server-side tool, // supplied by the component implementer. Result any - - // Extra stores additional information. - Extra map[string]any } type MCPToolCall struct { // ServerLabel is the MCP server label used to identify it in tool calls ServerLabel string - // ApprovalRequestID is the unique ID of the approval request. + + // ApprovalRequestID is the approval request ID. ApprovalRequestID string + // CallID is the unique ID of the tool call. CallID string + // Name is the name of the tool to run. Name string + // Arguments is the JSON string arguments for the tool call. Arguments string - - // Extra stores additional information. - Extra map[string]any } type MCPToolResult struct { // ServerLabel is the MCP server label used to identify it in tool calls ServerLabel string + // CallID is the unique ID of the tool call. CallID string + // Name is the name of the tool to run. Name string + // Result is the JSON string with the tool result. Result string + // Error returned when the server fails to run the tool. Error *MCPToolCallError - - // Extra stores additional information. - Extra map[string]any } type MCPToolCallError struct { @@ -321,49 +291,49 @@ type MCPToolCallError struct { type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string + // Tools is the list of tools available on the server. Tools []*MCPListToolsItem + // Error returned when the server fails to list tools. Error string - - // Extra stores additional information. - Extra map[string]any } type MCPListToolsItem struct { // Name is the name of the tool. Name string + // Description is the description of the tool. Description string - // InputSchema is the JSON schema that describes the tool input. + + // InputSchema is the JSON schema that describes the tool input parameters. InputSchema *jsonschema.Schema } type MCPToolApprovalRequest struct { // ID is the approval request ID. ID string + // Name is the name of the tool to run. Name string + // Arguments is the JSON string arguments for the tool call. Arguments string + // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string - - // Extra stores additional information. - Extra map[string]any } type MCPToolApprovalResponse struct { // ApprovalRequestID is the approval request ID being responded to. ApprovalRequestID string + // Approve indicates whether the request is approved. Approve bool + // Reason is the rationale for the decision. // Optional. Reason string - - // Extra stores additional information. - Extra map[string]any } // DeveloperAgenticMessage represents a message with AgenticRoleType "developer". @@ -404,8 +374,32 @@ func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessa } } -func NewContentBlock(block any) *ContentBlock { - switch b := block.(type) { +type contentBlockVariant interface { + Reasoning | userInputVariant | assistantGenVariant | functionToolCallVariant | serverToolCallVariant | mcpToolCallVariant +} + +type userInputVariant interface { + UserInputText | UserInputImage | UserInputAudio | UserInputVideo | UserInputFile +} + +type assistantGenVariant interface { + AssistantGenText | AssistantGenImage | AssistantGenAudio | AssistantGenVideo +} + +type functionToolCallVariant interface { + FunctionToolCall | FunctionToolResult +} + +type serverToolCallVariant interface { + ServerToolCall | ServerToolResult +} + +type mcpToolCallVariant interface { + MCPToolCall | MCPToolResult | MCPListToolsResult | MCPToolApprovalRequest | MCPToolApprovalResponse +} + +func NewContentBlock[T contentBlockVariant](content *T) *ContentBlock { + switch b := any(content).(type) { case *Reasoning: return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} case *UserInputText: @@ -449,6 +443,12 @@ func NewContentBlock(block any) *ContentBlock { } } +func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta) *ContentBlock { + block := NewContentBlock(content) + block.StreamingMeta = meta + return block +} + // AgenticMessagesTemplate is the interface for messages template. // It's used to render a template to a list of agentic messages. // e.g. @@ -689,10 +689,11 @@ func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, err func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { var ( - role AgenticRoleType - blocksList [][]*ContentBlock - blocks []*ContentBlock - metas []*AgenticResponseMeta + role AgenticRoleType + blocks []*ContentBlock + metas []*AgenticResponseMeta + blockIndices []int + indexToBlocks = map[int][]*ContentBlock{} ) if len(msgs) == 1 { @@ -713,9 +714,12 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { } for _, block := range msg.ContentBlocks { - if block.StreamMeta == nil { + if block == nil { + continue + } + if block.StreamingMeta == nil { // Non-streaming block - if len(blocksList) > 0 { + if len(blockIndices) > 0 { // Cannot mix streaming and non-streaming blocks return nil, fmt.Errorf("found non-streaming block after streaming blocks") } @@ -728,8 +732,12 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { return nil, fmt.Errorf("found streaming block after non-streaming blocks") } // Collect streaming block by index - blocksList = expandSlice(int(block.StreamMeta.Index), blocksList) - blocksList[block.StreamMeta.Index] = append(blocksList[block.StreamMeta.Index], block) + if blocks_, ok := indexToBlocks[block.StreamingMeta.Index]; ok { + indexToBlocks[block.StreamingMeta.Index] = append(blocks_, block) + } else { + blockIndices = append(blockIndices, block.StreamingMeta.Index) + indexToBlocks[block.StreamingMeta.Index] = []*ContentBlock{block} + } } } @@ -743,219 +751,253 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) } - if len(blocksList) > 0 { + if len(blockIndices) > 0 { // All blocks are streaming, concat each group by index - blocks = make([]*ContentBlock, len(blocksList)) - for i, bs := range blocksList { - if len(bs) == 0 { - continue - } - b, err := concatAgenticContentBlocks(bs) + indexToBlock := map[int]*ContentBlock{} + for idx, bs := range indexToBlocks { + b, err := concatChunksOfSameContentBlock(bs) if err != nil { - return nil, fmt.Errorf("failed to concat content blocks at index %d: %w", i, err) + return nil, err } - blocks[i] = b + indexToBlock[idx] = b } - } - - for i := 0; i < len(blocks); i++ { - if blocks[i] == nil { - blocks = append(blocks[:i], blocks[i+1:]...) + blocks = make([]*ContentBlock, 0, len(blockIndices)) + sort.Slice(blockIndices, func(i, j int) bool { + return blockIndices[i] < blockIndices[j] + }) + for _, idx := range blockIndices { + blocks = append(blocks, indexToBlock[idx]) } } return &AgenticMessage{ - ResponseMeta: meta, Role: role, + ResponseMeta: meta, ContentBlocks: blocks, }, nil } -func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (*AgenticResponseMeta, error) { +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (ret *AgenticResponseMeta, err error) { if len(metas) == 0 { return nil, nil } - ret := &AgenticResponseMeta{ - TokenUsage: &TokenUsage{}, - OpenAIExtension: nil, - ClaudeExtension: nil, - GeminiExtension: nil, - Extension: nil, - } + + openaiExtensions := make([]*openai.ResponseMetaExtension, 0, len(metas)) + claudeExtensions := make([]*claude.ResponseMetaExtension, 0, len(metas)) + geminiExtensions := make([]*gemini.ResponseMetaExtension, 0, len(metas)) + tokenUsages := make([]*TokenUsage, 0, len(metas)) + + var ( + extType reflect.Type + extensions reflect.Value + ) + for _, meta := range metas { - ret.Extension = meta.Extension - ret.OpenAIExtension = meta.OpenAIExtension - ret.ClaudeExtension = meta.ClaudeExtension - ret.GeminiExtension = meta.GeminiExtension if meta.TokenUsage != nil { - ret.TokenUsage.CompletionTokens += meta.TokenUsage.CompletionTokens - ret.TokenUsage.CompletionTokenDetails.ReasoningTokens += meta.TokenUsage.CompletionTokenDetails.ReasoningTokens - ret.TokenUsage.PromptTokens += meta.TokenUsage.PromptTokens - ret.TokenUsage.PromptTokenDetails.CachedTokens += meta.TokenUsage.PromptTokenDetails.CachedTokens - ret.TokenUsage.TotalTokens += meta.TokenUsage.TotalTokens + tokenUsages = append(tokenUsages, meta.TokenUsage) + } + + var isConsistent bool + + if meta.Extension != nil { + extType, isConsistent = validateExtensionType(extType, meta.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(metas)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(meta.Extension)) + } + + if meta.OpenAIExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.OpenAIExtension)) + } + openaiExtensions = append(openaiExtensions, meta.OpenAIExtension) + } + + if meta.ClaudeExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.ClaudeExtension)) + } + claudeExtensions = append(claudeExtensions, meta.ClaudeExtension) + } + + if meta.GeminiExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.GeminiExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.GeminiExtension)) + } + geminiExtensions = append(geminiExtensions, meta.GeminiExtension) + } + } + + ret = &AgenticResponseMeta{ + TokenUsage: concatTokenUsage(tokenUsages), + } + + if extensions.IsValid() && !extensions.IsZero() { + extension, err := internal.ConcatSliceValue(extensions) + if err != nil { + return nil, fmt.Errorf("failed to concat extensions: %w", err) } + ret.Extension = extension.Interface() } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatResponseMetaExtensions(openaiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat openai extensions: %w", err) + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatResponseMetaExtensions(claudeExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat claude extensions: %w", err) + } + } + + if len(geminiExtensions) > 0 { + ret.GeminiExtension, err = gemini.ConcatResponseMetaExtensions(geminiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat gemini extensions: %w", err) + } + } + return ret, nil } -func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { +func concatTokenUsage(usages []*TokenUsage) *TokenUsage { + if len(usages) == 0 { + return nil + } + + ret := &TokenUsage{} + + for _, usage := range usages { + if usage == nil { + continue + } + ret.CompletionTokens += usage.CompletionTokens + ret.CompletionTokensDetails.ReasoningTokens += usage.CompletionTokensDetails.ReasoningTokens + ret.PromptTokens += usage.PromptTokens + ret.PromptTokenDetails.CachedTokens += usage.PromptTokenDetails.CachedTokens + ret.TotalTokens += usage.TotalTokens + } + + return ret +} + +func concatChunksOfSameContentBlock(blocks []*ContentBlock) (*ContentBlock, error) { if len(blocks) == 0 { return nil, fmt.Errorf("no content blocks to concat") } + blockType := blocks[0].Type - index := blocks[0].StreamMeta.Index + switch blockType { case ContentBlockTypeReasoning: - return concatContentBlockHelper(blocks, blockType, "reasoning", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *Reasoning { return b.Reasoning }, - concatReasoning, - func(r *Reasoning) *ContentBlock { - return &ContentBlock{Type: blockType, Reasoning: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatReasoning) case ContentBlockTypeUserInputText: - return concatContentBlockHelper(blocks, blockType, "user input text", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputText { return b.UserInputText }, - concatUserInputText, - func(t *UserInputText) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputText: t, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputTexts) case ContentBlockTypeUserInputImage: - return concatContentBlockHelper(blocks, blockType, "user input image", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, - concatUserInputImage, - func(i *UserInputImage) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputImage: i, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputImages) case ContentBlockTypeUserInputAudio: - return concatContentBlockHelper(blocks, blockType, "user input audio", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, - concatUserInputAudio, - func(a *UserInputAudio) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputAudio: a, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputAudios) case ContentBlockTypeUserInputVideo: - return concatContentBlockHelper(blocks, blockType, "user input video", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputVideo { return b.UserInputVideo }, - concatUserInputVideo, - func(v *UserInputVideo) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputVideo: v, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputVideos) case ContentBlockTypeUserInputFile: - return concatContentBlockHelper(blocks, blockType, "user input file", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, - concatUserInputFile, - func(f *UserInputFile) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputFile: f, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputFiles) case ContentBlockTypeAssistantGenText: - return concatContentBlockHelper(blocks, blockType, "assistant gen text", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, - concatAssistantGenText, - func(t *AssistantGenText) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenText: t, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenTexts) case ContentBlockTypeAssistantGenImage: - return concatContentBlockHelper(blocks, blockType, "assistant gen image", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, - concatAssistantGenImage, - func(i *AssistantGenImage) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenImage: i, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenImages) case ContentBlockTypeAssistantGenAudio: - return concatContentBlockHelper(blocks, blockType, "assistant gen audio", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, - concatAssistantGenAudio, - func(a *AssistantGenAudio) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenAudio: a, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenAudios) case ContentBlockTypeAssistantGenVideo: - return concatContentBlockHelper(blocks, blockType, "assistant gen video", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, - concatAssistantGenVideo, - func(v *AssistantGenVideo) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenVideo: v, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenVideos) case ContentBlockTypeFunctionToolCall: - return concatContentBlockHelper(blocks, blockType, "function tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, - concatFunctionToolCall, - func(c *FunctionToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, FunctionToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatFunctionToolCalls) case ContentBlockTypeFunctionToolResult: - return concatContentBlockHelper(blocks, blockType, "function tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, - concatFunctionToolResult, - func(r *FunctionToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, FunctionToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatFunctionToolResults) case ContentBlockTypeServerToolCall: - return concatContentBlockHelper(blocks, blockType, "server tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, - concatServerToolCall, - func(c *ServerToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, ServerToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatServerToolCalls) case ContentBlockTypeServerToolResult: - return concatContentBlockHelper(blocks, blockType, "server tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult }, - concatServerToolResult, - func(r *ServerToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, ServerToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatServerToolResults) case ContentBlockTypeMCPToolCall: - return concatContentBlockHelper(blocks, blockType, "MCP tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, - concatMCPToolCall, - func(c *MCPToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolCalls) case ContentBlockTypeMCPToolResult: - return concatContentBlockHelper(blocks, blockType, "MCP tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, - concatMCPToolResult, - func(r *MCPToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolResults) case ContentBlockTypeMCPListToolsResult: - return concatContentBlockHelper(blocks, blockType, "MCP list tools", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, - concatMCPListToolsResult, - func(r *MCPListToolsResult) *ContentBlock { - return &ContentBlock{Type: blockType, MCPListToolsResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPListToolsResults) case ContentBlockTypeMCPToolApprovalRequest: - return concatContentBlockHelper(blocks, blockType, "MCP tool approval request", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, - concatMCPToolApprovalRequest, - func(r *MCPToolApprovalRequest) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolApprovalRequest: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolApprovalRequests) case ContentBlockTypeMCPToolApprovalResponse: - return concatContentBlockHelper(blocks, blockType, "MCP tool approval response", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, - concatMCPToolApprovalResponse, - func(r *MCPToolApprovalResponse) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolApprovalResponse: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolApprovalResponses) default: return nil, fmt.Errorf("unknown content block type: %s", blockType) @@ -964,21 +1006,19 @@ func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { // concatContentBlockHelper is a generic helper function that reduces code duplication // for concatenating content blocks of a specific type. -func concatContentBlockHelper[T any]( +func concatContentBlockHelper[T contentBlockVariant]( blocks []*ContentBlock, expectedType ContentBlockType, - typeName string, getter func(*ContentBlock) *T, concatFunc func([]*T) (*T, error), - constructor func(*T) *ContentBlock, ) (*ContentBlock, error) { items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { if block.Type != expectedType { - return nil, fmt.Errorf("expected %s block, got %s", typeName, block.Type) + return nil, fmt.Errorf("content block type mismatch: expected '%s', but got '%s'", expectedType, block.Type) } item := getter(block) if item == nil { - return nil, fmt.Errorf("%s content is nil", typeName) + return nil, fmt.Errorf("'%s' content is nil", expectedType) } return item, nil }) @@ -988,10 +1028,28 @@ func concatContentBlockHelper[T any]( concatenated, err := concatFunc(items) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to concat '%s' content blocks: %w", expectedType, err) } - return constructor(concatenated), nil + extras := make([]map[string]any, 0, len(blocks)) + for _, block := range blocks { + if len(block.Extra) > 0 { + extras = append(extras, block.Extra) + } + } + + var extra map[string]any + if len(extras) > 0 { + extra, err = internal.ConcatItems(extras) + if err != nil { + return nil, fmt.Errorf("failed to concat content block extras: %w", err) + } + } + + block := NewContentBlock(concatenated) + block.Extra = extra + + return block, nil } func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { @@ -1006,43 +1064,14 @@ func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter return ret, nil } -// Concatenation strategies for different content block types: -// -// String concatenation (incremental streaming): -// - Reasoning: Summary texts are concatenated, grouped by Index if present -// - UserInputText: Text fields are concatenated -// - AssistantGenText: Text fields are concatenated, annotations/citations are merged -// - FunctionToolCall: Arguments (JSON strings) are concatenated incrementally -// - FunctionToolResult: Result strings are concatenated -// - ServerToolCall: Arguments are merged (last non-nil value for any type) -// - ServerToolResult: Results are merged using internal.ConcatItems -// - MCPToolCall: Arguments (JSON strings) are concatenated incrementally -// - MCPToolResult: Result strings are concatenated -// - MCPListToolsResult: Tools arrays are merged -// - MCPToolApprovalRequest: Arguments are concatenated -// -// Take last block (non-streaming content): -// - UserInputImage, UserInputAudio, UserInputVideo, UserInputFile: Return last block -// - AssistantGenImage, AssistantGenAudio, AssistantGenVideo: Return last block -// - MCPToolApprovalResponse: Return last block -// - func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { if len(reasons) == 0 { return nil, fmt.Errorf("no reasoning found") } - if len(reasons) == 1 { - return reasons[0], nil - } - ret := &Reasoning{ - Summary: make([]*ReasoningSummary, 0), - EncryptedContent: "", - Extra: make(map[string]any), - } + ret := &Reasoning{} - // Collect all summaries from all reasons - allSummaries := make([]*ReasoningSummary, 0) + var allSummaries []*ReasoningSummary for _, r := range reasons { if r == nil { continue @@ -1051,157 +1080,269 @@ func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { if r.EncryptedContent != "" { ret.EncryptedContent += r.EncryptedContent } - for k, v := range r.Extra { - ret.Extra[k] = v - } } - // Group by Index and concatenate Text for same Index - // Use dynamic array that expands as needed - var summaryArray []*ReasoningSummary + var ( + indices []int + indexToSummary = map[int]*ReasoningSummary{} + ) + for _, s := range allSummaries { - idx := s.Index - // Expand array if needed - summaryArray = expandSlice(int(idx), summaryArray) - if summaryArray[idx] == nil { - // Create new entry with a copy of Index - summaryArray[idx] = &ReasoningSummary{ - Index: idx, - Text: s.Text, - } - } else { - // Concatenate text for same index - summaryArray[idx].Text += s.Text + if s == nil { + continue + } + if indexToSummary[s.Index] == nil { + indexToSummary[s.Index] = &ReasoningSummary{} + indices = append(indices, s.Index) } + indexToSummary[s.Index].Text += s.Text } - // Convert array to slice, filtering out nil entries - ret.Summary = make([]*ReasoningSummary, 0, len(summaryArray)) - for _, summary := range summaryArray { - if summary != nil { - ret.Summary = append(ret.Summary, summary) - } + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Summary = make([]*ReasoningSummary, 0, len(indices)) + for _, idx := range indices { + ret.Summary = append(ret.Summary, indexToSummary[idx]) } return ret, nil } -func concatUserInputText(texts []*UserInputText) (*UserInputText, error) { +func concatUserInputTexts(texts []*UserInputText) (*UserInputText, error) { if len(texts) == 0 { return nil, fmt.Errorf("no user input text found") } if len(texts) == 1 { return texts[0], nil } - - ret := &UserInputText{ - Text: "", - Extra: make(map[string]any), - } - - for _, t := range texts { - if t == nil { - continue - } - ret.Text += t.Text - for k, v := range t.Extra { - ret.Extra[k] = v - } - } - - return ret, nil + return nil, fmt.Errorf("cannot concat multiple user input texts") } -func concatUserInputImage(images []*UserInputImage) (*UserInputImage, error) { +func concatUserInputImages(images []*UserInputImage) (*UserInputImage, error) { if len(images) == 0 { return nil, fmt.Errorf("no user input image found") } - return images[len(images)-1], nil + if len(images) == 1 { + return images[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input images") } -func concatUserInputAudio(audios []*UserInputAudio) (*UserInputAudio, error) { +func concatUserInputAudios(audios []*UserInputAudio) (*UserInputAudio, error) { if len(audios) == 0 { return nil, fmt.Errorf("no user input audio found") } - return audios[len(audios)-1], nil + if len(audios) == 1 { + return audios[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input audios") } -func concatUserInputVideo(videos []*UserInputVideo) (*UserInputVideo, error) { +func concatUserInputVideos(videos []*UserInputVideo) (*UserInputVideo, error) { if len(videos) == 0 { return nil, fmt.Errorf("no user input video found") } - return videos[len(videos)-1], nil + if len(videos) == 1 { + return videos[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input videos") } -func concatUserInputFile(files []*UserInputFile) (*UserInputFile, error) { +func concatUserInputFiles(files []*UserInputFile) (*UserInputFile, error) { if len(files) == 0 { return nil, fmt.Errorf("no user input file found") } - return files[len(files)-1], nil + if len(files) == 1 { + return files[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input files") } -func concatAssistantGenText(texts []*AssistantGenText) (*AssistantGenText, error) { +func concatAssistantGenTexts(texts []*AssistantGenText) (ret *AssistantGenText, err error) { if len(texts) == 0 { - return nil, fmt.Errorf("no assistant gen text found") + return nil, fmt.Errorf("no assistant generated text found") } if len(texts) == 1 { return texts[0], nil } - ret := &AssistantGenText{ - Text: "", - OpenAIExtension: nil, - ClaudeExtension: nil, - Extra: make(map[string]any), - } + ret = &AssistantGenText{} + + openaiExtensions := make([]*openai.AssistantGenTextExtension, 0, len(texts)) + claudeExtensions := make([]*claude.AssistantGenTextExtension, 0, len(texts)) + + var ( + extType reflect.Type + extensions reflect.Value + ) for _, t := range texts { if t == nil { continue } + ret.Text += t.Text + + var isConsistent bool + + if t.Extension != nil { + extType, isConsistent = validateExtensionType(extType, t.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(texts)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(t.Extension)) + } + if t.OpenAIExtension != nil { - if ret.OpenAIExtension == nil { - ret.OpenAIExtension = &openai.AssistantGenTextExtension{} + extType, isConsistent = validateExtensionType(extType, t.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.OpenAIExtension)) } - ret.OpenAIExtension.Annotations = append(ret.OpenAIExtension.Annotations, t.OpenAIExtension.Annotations...) + openaiExtensions = append(openaiExtensions, t.OpenAIExtension) } + if t.ClaudeExtension != nil { - if ret.ClaudeExtension == nil { - ret.ClaudeExtension = &claude.AssistantGenTextExtension{} + extType, isConsistent = validateExtensionType(extType, t.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.ClaudeExtension)) } - ret.ClaudeExtension.Citations = append(ret.ClaudeExtension.Citations, t.ClaudeExtension.Citations...) + claudeExtensions = append(claudeExtensions, t.ClaudeExtension) } - for k, v := range t.Extra { - ret.Extra[k] = v + } + + if extensions.IsValid() && !extensions.IsZero() { + ret.Extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, err + } + ret.Extension = extensions.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatAssistantGenTextExtensions(openaiExtensions) + if err != nil { + return nil, err + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatAssistantGenTextExtensions(claudeExtensions) + if err != nil { + return nil, err } } return ret, nil } -func concatAssistantGenImage(images []*AssistantGenImage) (*AssistantGenImage, error) { +func concatAssistantGenImages(images []*AssistantGenImage) (*AssistantGenImage, error) { if len(images) == 0 { return nil, fmt.Errorf("no assistant gen image found") } - return images[len(images)-1], nil + if len(images) == 1 { + return images[0], nil + } + + ret := &AssistantGenImage{} + + for _, img := range images { + if img == nil { + continue + } + + ret.Base64Data += img.Base64Data + + if ret.URL == "" { + ret.URL = img.URL + } else if img.URL != "" && ret.URL != img.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated image chunks: '%s' vs '%s'", ret.URL, img.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = img.MIMEType + } else if img.MIMEType != "" && ret.MIMEType != img.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated image chunks: '%s' vs '%s'", ret.MIMEType, img.MIMEType) + } + } + + return ret, nil } -func concatAssistantGenAudio(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { +func concatAssistantGenAudios(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { if len(audios) == 0 { return nil, fmt.Errorf("no assistant gen audio found") } - return audios[len(audios)-1], nil + if len(audios) == 1 { + return audios[0], nil + } + + ret := &AssistantGenAudio{} + + for _, audio := range audios { + if audio == nil { + continue + } + + ret.Base64Data += audio.Base64Data + + if ret.URL == "" { + ret.URL = audio.URL + } else if audio.URL != "" && ret.URL != audio.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated audio chunks: '%s' vs '%s'", ret.URL, audio.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = audio.MIMEType + } else if audio.MIMEType != "" && ret.MIMEType != audio.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated audio chunks: '%s' vs '%s'", ret.MIMEType, audio.MIMEType) + } + } + + return ret, nil } -func concatAssistantGenVideo(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { +func concatAssistantGenVideos(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { if len(videos) == 0 { return nil, fmt.Errorf("no assistant gen video found") } - return videos[len(videos)-1], nil + if len(videos) == 1 { + return videos[0], nil + } + + ret := &AssistantGenVideo{} + + for _, video := range videos { + if video == nil { + continue + } + + ret.Base64Data += video.Base64Data + + if ret.URL == "" { + ret.URL = video.URL + } else if video.URL != "" && ret.URL != video.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated video chunks: '%s' vs '%s'", ret.URL, video.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = video.MIMEType + } else if video.MIMEType != "" && ret.MIMEType != video.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated video chunks: '%s' vs '%s'", ret.MIMEType, video.MIMEType) + } + } + + return ret, nil } -func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error) { +func concatFunctionToolCalls(calls []*FunctionToolCall) (*FunctionToolCall, error) { if len(calls) == 0 { return nil, fmt.Errorf("no function tool call found") } @@ -1209,31 +1350,32 @@ func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error return calls[0], nil } - // For tool calls, arguments are typically built incrementally during streaming - ret := &FunctionToolCall{ - Extra: make(map[string]any), - } + ret := &FunctionToolCall{} for _, c := range calls { if c == nil { continue } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool call, but got '%s'", ret.CallID, c.CallID) } + if ret.Name == "" { ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool call, but got '%s'", ret.Name, c.Name) } + ret.Arguments += c.Arguments - for k, v := range c.Extra { - ret.Extra[k] = v - } } return ret, nil } -func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResult, error) { +func concatFunctionToolResults(results []*FunctionToolResult) (*FunctionToolResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no function tool result found") } @@ -1241,30 +1383,32 @@ func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResul return results[0], nil } - ret := &FunctionToolResult{ - Extra: make(map[string]any), - } + ret := &FunctionToolResult{} for _, r := range results { if r == nil { continue } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool result, but got '%s'", ret.CallID, r.CallID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool result, but got '%s'", ret.Name, r.Name) } + ret.Result += r.Result - for k, v := range r.Extra { - ret.Extra[k] = v - } } return ret, nil } -func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { +func concatServerToolCalls(calls []*ServerToolCall) (ret *ServerToolCall, err error) { if len(calls) == 0 { return nil, fmt.Errorf("no server tool call found") } @@ -1272,33 +1416,54 @@ func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { return calls[0], nil } - // ServerToolCall Arguments is of type any; merge strategy uses the last non-nil value - ret := &ServerToolCall{ - Extra: make(map[string]any), - } + ret = &ServerToolCall{} + + var ( + argsType reflect.Type + argsChunks reflect.Value + ) for _, c := range calls { if c == nil { continue } - if ret.Name == "" { - ret.Name = c.Name - } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool call, but got '%s'", ret.CallID, c.CallID) } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool call, but got '%s'", ret.Name, c.Name) + } + if c.Arguments != nil { - ret.Arguments = c.Arguments + argsType_ := reflect.TypeOf(c.Arguments) + if argsType == nil { + argsType = argsType_ + argsChunks = reflect.MakeSlice(reflect.SliceOf(argsType), 0, len(calls)) + } else if argsType != argsType_ { + return nil, fmt.Errorf("expected type '%s' for server tool call arguments, but got '%s'", argsType, argsType_) + } + argsChunks = reflect.Append(argsChunks, reflect.ValueOf(c.Arguments)) } - for k, v := range c.Extra { - ret.Extra[k] = v + } + + if argsChunks.IsValid() && !argsChunks.IsZero() { + arguments, err := internal.ConcatSliceValue(argsChunks) + if err != nil { + return nil, err } + ret.Arguments = arguments.Interface() } return ret, nil } -func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, error) { +func concatServerToolResults(results []*ServerToolResult) (ret *ServerToolResult, err error) { if len(results) == 0 { return nil, fmt.Errorf("no server tool result found") } @@ -1306,45 +1471,54 @@ func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, err return results[0], nil } - // ServerToolResult Result is of type any; merge strategy uses the last non-nil value - ret := &ServerToolResult{ - Extra: make(map[string]any), - } + ret = &ServerToolResult{} + + var ( + resType reflect.Type + resChunks reflect.Value + ) - tZeroResult := reflect.TypeOf(results[0].Result) - data := reflect.MakeSlice(reflect.SliceOf(tZeroResult), 0, 0) for _, r := range results { if r == nil { continue } - if ret.Name == "" { - ret.Name = r.Name - } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool result, but got '%s'", ret.CallID, r.CallID) } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool result, but got '%s'", ret.Name, r.Name) + } + if r.Result != nil { - vResult := reflect.ValueOf(r.Result) - if tZeroResult != vResult.Type() { - return nil, fmt.Errorf("tool result types are different: %v %v", tZeroResult, vResult.Type()) + resType_ := reflect.TypeOf(r.Result) + if resType == nil { + resType = resType_ + resChunks = reflect.MakeSlice(reflect.SliceOf(resType), 0, len(results)) + } else if resType != resType_ { + return nil, fmt.Errorf("expected type '%s' for server tool result, but got '%s'", resType, resType_) } - data = reflect.Append(data, vResult) - } - for k, v := range r.Extra { - ret.Extra[k] = v + resChunks = reflect.Append(resChunks, reflect.ValueOf(r.Result)) } } - d, err := internal.ConcatSliceValue(data) - if err != nil { - return nil, fmt.Errorf("failed to concat server tool result: %v", err) + if resChunks.IsValid() && !resChunks.IsZero() { + result, err := internal.ConcatSliceValue(resChunks) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = result.Interface() } - ret.Result = d return ret, nil } -func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { +func concatMCPToolCalls(calls []*MCPToolCall) (*MCPToolCall, error) { if len(calls) == 0 { return nil, fmt.Errorf("no mcp tool call found") } @@ -1352,36 +1526,38 @@ func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { return calls[0], nil } - ret := &MCPToolCall{ - Extra: make(map[string]any), - } + ret := &MCPToolCall{} for _, c := range calls { if c == nil { continue } + + ret.Arguments += c.Arguments + if ret.ServerLabel == "" { ret.ServerLabel = c.ServerLabel + } else if c.ServerLabel != "" && c.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool call, but got '%s'", ret.ServerLabel, c.ServerLabel) } - if ret.ApprovalRequestID == "" { - ret.ApprovalRequestID = c.ApprovalRequestID - } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool call, but got '%s'", ret.CallID, c.CallID) } + if ret.Name == "" { ret.Name = c.Name - } - ret.Arguments += c.Arguments - for k, v := range c.Extra { - ret.Extra[k] = v + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool call, but got '%s'", ret.Name, c.Name) } } return ret, nil } -func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { +func concatMCPToolResults(results []*MCPToolResult) (*MCPToolResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no mcp tool result found") } @@ -1389,33 +1565,44 @@ func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { return results[0], nil } - ret := &MCPToolResult{ - Extra: make(map[string]any), - } + ret := &MCPToolResult{} for _, r := range results { if r == nil { continue } + + if r.Result != "" { + ret.Result = r.Result + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool result, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool result, but got '%s'", ret.CallID, r.CallID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool result, but got '%s'", ret.Name, r.Name) } - ret.Result += r.Result + if r.Error != nil { - ret.Error = r.Error // Use the last error - } - for k, v := range r.Extra { - ret.Extra[k] = v + ret.Error = r.Error } } return ret, nil } -func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResult, error) { +func concatMCPListToolsResults(results []*MCPListToolsResult) (*MCPListToolsResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no mcp list tools result found") } @@ -1423,31 +1610,30 @@ func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResul return results[0], nil } - ret := &MCPListToolsResult{ - Tools: make([]*MCPListToolsItem, 0), - Extra: make(map[string]any), - } + ret := &MCPListToolsResult{} for _, r := range results { if r == nil { continue } - if ret.ServerLabel == "" { - ret.ServerLabel = r.ServerLabel - } + ret.Tools = append(ret.Tools, r.Tools...) + if r.Error != "" { - ret.Error = r.Error // Use the last error + ret.Error = r.Error } - for k, v := range r.Extra { - ret.Extra[k] = v + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp list tools result, but got '%s'", ret.ServerLabel, r.ServerLabel) } } return ret, nil } -func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { +func concatMCPToolApprovalRequests(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { if len(requests) == 0 { return nil, fmt.Errorf("no mcp tool approval request found") } @@ -1455,48 +1641,45 @@ func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolA return requests[0], nil } - ret := &MCPToolApprovalRequest{ - Extra: make(map[string]any), - } + ret := &MCPToolApprovalRequest{} for _, r := range requests { if r == nil { continue } + + ret.Arguments += r.Arguments + if ret.ID == "" { ret.ID = r.ID + } else if r.ID != "" && r.ID != ret.ID { + return nil, fmt.Errorf("expected request ID '%s' for mcp tool approval request, but got '%s'", ret.ID, r.ID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool approval request, but got '%s'", ret.Name, r.Name) } - ret.Arguments += r.Arguments + if ret.ServerLabel == "" { ret.ServerLabel = r.ServerLabel - } - for k, v := range r.Extra { - ret.Extra[k] = v + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool approval request, but got '%s'", ret.ServerLabel, r.ServerLabel) } } return ret, nil } -func concatMCPToolApprovalResponse(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { +func concatMCPToolApprovalResponses(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { if len(responses) == 0 { return nil, fmt.Errorf("no mcp tool approval response found") } if len(responses) == 1 { return responses[0], nil } - - return responses[len(responses)-1], nil -} - -func expandSlice[T any](idx int, s []T) []T { - if len(s) > idx { - return s - } - return append(s, make([]T, idx-len(s)+1)...) + return nil, fmt.Errorf("cannot concat multiple mcp tool approval responses") } func (m *AgenticMessage) String() string { @@ -1603,8 +1786,8 @@ func (b *ContentBlock) String() string { } } - if b.StreamMeta != nil { - sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamMeta.Index)) + if b.StreamingMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamingMeta.Index)) } return sb.String() @@ -1705,9 +1888,6 @@ func (m *MCPToolCall) String() string { sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) - if m.ApprovalRequestID != "" { - sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) - } return sb.String() } @@ -1792,3 +1972,17 @@ func formatMediaString(url, base64Data string, mimeType string, detail any) stri } return sb.String() } + +func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, bool) { + if actual == nil { + return expected, true + } + actualType := reflect.TypeOf(actual) + if expected == nil { + return actualType, true + } + if expected != actualType { + return expected, false + } + return expected, true +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 0cafcd9ff..016aa5c4e 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -18,6 +18,7 @@ package schema import ( "context" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -75,7 +76,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Hello ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -87,7 +88,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "World!", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -112,7 +113,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 0, Text: "First "}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -126,7 +127,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 0, Text: "Second"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -137,7 +138,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Len(t, result.ContentBlocks, 1) assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, int64(0), result.ContentBlocks[0].Reasoning.Summary[0].Index) + assert.Equal(t, 0, result.ContentBlocks[0].Reasoning.Summary[0].Index) }) t.Run("concat reasoning with index", func(t *testing.T) { @@ -153,7 +154,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 1, Text: "Part2-"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -168,7 +169,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 1, Text: "Part4"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -185,26 +186,26 @@ func TestConcatAgenticMessages(t *testing.T) { t.Run("concat user input text", func(t *testing.T) { msgs := []*AgenticMessage{ { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputText, - UserInputText: &UserInputText{ + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ Text: "Hello ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputText, - UserInputText: &UserInputText{ + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ Text: "World!", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -213,35 +214,35 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Equal(t, "Hello World!", result.ContentBlocks[0].UserInputText.Text) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) }) - t.Run("concat user input image", func(t *testing.T) { - url1 := "https://example.com/image1.jpg" - url2 := "https://example.com/image2.jpg" + t.Run("concat assistant gen image", func(t *testing.T) { + base1 := "1" + base2 := "2" msgs := []*AgenticMessage{ { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputImage, - UserInputImage: &UserInputImage{ - URL: url1, + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputImage, - UserInputImage: &UserInputImage{ - URL: url2, + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -250,11 +251,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last image - assert.Equal(t, url2, result.ContentBlocks[0].UserInputImage.URL) + assert.Equal(t, "12", result.ContentBlocks[0].AssistantGenImage.Base64Data) }) - t.Run("concat user input audio", func(t *testing.T) { + t.Run("concat user input audio - should error", func(t *testing.T) { url1 := "https://example.com/audio1.mp3" url2 := "https://example.com/audio2.mp3" @@ -267,7 +267,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputAudio: &UserInputAudio{ URL: url1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -279,20 +279,18 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputAudio: &UserInputAudio{ URL: url2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last audio - assert.Equal(t, url2, result.ContentBlocks[0].UserInputAudio.URL) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input audios") }) - t.Run("concat user input video", func(t *testing.T) { + t.Run("concat user input video - should error", func(t *testing.T) { url1 := "https://example.com/video1.mp4" url2 := "https://example.com/video2.mp4" @@ -305,7 +303,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputVideo: &UserInputVideo{ URL: url1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -317,17 +315,15 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputVideo: &UserInputVideo{ URL: url2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last video - assert.Equal(t, url2, result.ContentBlocks[0].UserInputVideo.URL) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input videos") }) t.Run("concat assistant gen text", func(t *testing.T) { @@ -340,7 +336,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Generated ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -352,7 +348,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Text", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -365,9 +361,6 @@ func TestConcatAgenticMessages(t *testing.T) { }) t.Run("concat assistant gen image", func(t *testing.T) { - url1 := "https://example.com/gen_image1.jpg" - url2 := "https://example.com/gen_image2.jpg" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -375,9 +368,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{ - URL: url1, + Base64Data: "part1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -387,9 +380,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{ - URL: url2, + Base64Data: "part2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -398,14 +391,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last image - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenImage.URL) + assert.Equal(t, "part1part2", result.ContentBlocks[0].AssistantGenImage.Base64Data) }) t.Run("concat assistant gen audio", func(t *testing.T) { - url1 := "https://example.com/gen_audio1.mp3" - url2 := "https://example.com/gen_audio2.mp3" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -413,9 +402,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{ - URL: url1, + Base64Data: "audio1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -425,9 +414,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{ - URL: url2, + Base64Data: "audio2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -436,14 +425,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last audio - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenAudio.URL) + assert.Equal(t, "audio1audio2", result.ContentBlocks[0].AssistantGenAudio.Base64Data) }) t.Run("concat assistant gen video", func(t *testing.T) { - url1 := "https://example.com/gen_video1.mp4" - url2 := "https://example.com/gen_video2.mp4" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -451,9 +436,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{ - URL: url1, + Base64Data: "video1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -463,9 +448,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{ - URL: url2, + Base64Data: "video2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -474,8 +459,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last video - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenVideo.URL) + assert.Equal(t, "video1video2", result.ContentBlocks[0].AssistantGenVideo.Base64Data) }) t.Run("concat function tool call", func(t *testing.T) { @@ -490,7 +474,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "get_weather", Arguments: `{"location`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -502,7 +486,7 @@ func TestConcatAgenticMessages(t *testing.T) { FunctionToolCall: &FunctionToolCall{ Arguments: `":"NYC"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -528,7 +512,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "get_weather", Result: `{"temp`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -540,7 +524,7 @@ func TestConcatAgenticMessages(t *testing.T) { FunctionToolResult: &FunctionToolResult{ Result: `":72}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -565,7 +549,7 @@ func TestConcatAgenticMessages(t *testing.T) { CallID: "server_call_1", Name: "server_func", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -577,7 +561,7 @@ func TestConcatAgenticMessages(t *testing.T) { ServerToolCall: &ServerToolCall{ Arguments: map[string]any{"key": "value"}, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -603,7 +587,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "server_func", Result: "result1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -611,11 +595,9 @@ func TestConcatAgenticMessages(t *testing.T) { Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeServerToolResult, - ServerToolResult: &ServerToolResult{ - Result: "result2", - }, - StreamMeta: &StreamMeta{Index: 0}, + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -626,6 +608,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Len(t, result.ContentBlocks, 1) assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + assert.Equal(t, "result1", result.ContentBlocks[0].ServerToolResult.Result) }) t.Run("concat mcp tool call", func(t *testing.T) { @@ -641,7 +624,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "mcp_func", Arguments: `{"arg`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -653,7 +636,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolCall: &MCPToolCall{ Arguments: `":123}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -676,11 +659,12 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{ - CallID: "mcp_call_1", - Name: "mcp_func", - Result: `{"res`, + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `First`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -690,9 +674,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{ - Result: `ult":true}`, + Result: `Second`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -701,9 +685,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolResult.ServerLabel) assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) - assert.Equal(t, `{"result":true}`, result.ContentBlocks[0].MCPToolResult.Result) + assert.Equal(t, `Second`, result.ContentBlocks[0].MCPToolResult.Result) }) t.Run("concat mcp list tools", func(t *testing.T) { @@ -719,7 +704,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Name: "tool1"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -733,7 +718,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Name: "tool2"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -759,7 +744,7 @@ func TestConcatAgenticMessages(t *testing.T) { ServerLabel: "mcp-server", Arguments: `{"request`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -771,7 +756,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolApprovalRequest: &MCPToolApprovalRequest{ Arguments: `":1}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -786,7 +771,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Equal(t, `{"request":1}`, result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) }) - t.Run("concat mcp tool approval response", func(t *testing.T) { + t.Run("concat mcp tool approval response - should error", func(t *testing.T) { response1 := &MCPToolApprovalResponse{ ApprovalRequestID: "approval_1", Approve: false, @@ -803,7 +788,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: response1, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -813,17 +798,15 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: response2, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last response - assert.Equal(t, response2, result.ContentBlocks[0].MCPToolApprovalResponse) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple mcp tool approval responses") }) t.Run("concat response meta", func(t *testing.T) { @@ -865,7 +848,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Hello", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -877,7 +860,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "World", }, - // No StreamMeta - non-streaming + // No StreamingMeta - non-streaming }, }, }, @@ -901,7 +884,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "list_files", Arguments: `{"path`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -913,7 +896,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolCall: &MCPToolCall{ Arguments: `":"/tmp"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -927,7 +910,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) }) - t.Run("concat user input text", func(t *testing.T) { + t.Run("concat user input text - should error", func(t *testing.T) { msgs := []*AgenticMessage{ { Role: AgenticRoleTypeUser, @@ -937,7 +920,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputText: &UserInputText{ Text: "What is ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -949,16 +932,15 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputText: &UserInputText{ Text: "the weather?", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - assert.Equal(t, "What is the weather?", result.ContentBlocks[0].UserInputText.Text) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input texts") }) t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { @@ -971,14 +953,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Index0-", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "Index2-", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -990,14 +972,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Part2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "Part2", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1020,7 +1002,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Text ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolCall, @@ -1029,7 +1011,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "func1", Arguments: `{"a`, }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, }, }, @@ -1041,14 +1023,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Content", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: &FunctionToolCall{ Arguments: `":1}`, }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, }, }, @@ -1073,21 +1055,21 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "A", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "B", }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "C", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1099,21 +1081,21 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "2", }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "3", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1276,7 +1258,7 @@ func TestAgenticMessageString(t *testing.T) { Name: "get_current_weather", Arguments: `{"location":"New York City","unit":"fahrenheit"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolResult, @@ -1289,11 +1271,10 @@ func TestAgenticMessageString(t *testing.T) { { Type: ContentBlockTypeMCPToolCall, MCPToolCall: &MCPToolCall{ - ServerLabel: "weather-mcp-server", - CallID: "mcp_forecast_456", - Name: "get_7day_forecast", - Arguments: `{"city":"New York","days":7}`, - ApprovalRequestID: "approval_req_789", + ServerLabel: "weather-mcp-server", + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, }, }, { @@ -1363,7 +1344,6 @@ content_blocks: call_id: mcp_forecast_456 name: get_7day_forecast arguments: {"city":"New York","days":7} - approval_request_id: approval_req_789 [7] type: mcp_tool_result call_id: mcp_forecast_456 name: get_7day_forecast @@ -1378,4 +1358,294 @@ content_blocks: response_meta: token_usage: prompt=250, completion=180, total=430 `, output) + + t.Run("full fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "http://audio.com", + Base64Data: "audio_data", + MIMEType: "audio/mp3", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "http://video.com", + Base64Data: "video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "http://file.com", + Name: "file.txt", + Base64Data: "file_data", + MIMEType: "text/plain", + }, + }, + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: "http://gen_image.com", + Base64Data: "gen_image_data", + MIMEType: "image/png", + }, + }, + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: "http://gen_audio.com", + Base64Data: "gen_audio_data", + MIMEType: "audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, + }, + } + + s := msg.String() + assert.Contains(t, s, "role: system") + assert.Contains(t, s, "type: user_input_audio") + assert.Contains(t, s, "http://audio.com") + assert.Contains(t, s, "type: user_input_video") + assert.Contains(t, s, "http://video.com") + assert.Contains(t, s, "type: user_input_file") + assert.Contains(t, s, "file.txt") + assert.Contains(t, s, "type: assistant_gen_image") + assert.Contains(t, s, "http://gen_image.com") + assert.Contains(t, s, "type: assistant_gen_audio") + assert.Contains(t, s, "http://gen_audio.com") + assert.Contains(t, s, "type: assistant_gen_video") + assert.Contains(t, s, "http://gen_video.com") + assert.Contains(t, s, "type: server_tool_call") + assert.Contains(t, s, "server_tool") + assert.Contains(t, s, "map[a:1]") + assert.Contains(t, s, "type: server_tool_result") + assert.Contains(t, s, "map[success:true]") + assert.Contains(t, s, "type: mcp_tool_approval_request") + assert.Contains(t, s, "req_1") + assert.Contains(t, s, "type: mcp_tool_approval_response") + assert.Contains(t, s, "looks good") + }) + + t.Run("nil/empty fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: &UserInputAudio{}}, // empty + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: &UserInputVideo{}}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: &UserInputFile{}}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{}}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{}}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{}}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: &ServerToolCall{Name: "t"}}, // No CallID + {Type: ContentBlockTypeServerToolResult, ServerToolResult: &ServerToolResult{Name: "t"}}, // No CallID + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{Name: "t"}}, // No Error + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: &MCPListToolsResult{}}, // No Error + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: &MCPToolApprovalResponse{Approve: false}}, // No Reason + nil, // Nil block in slice + }, + } + + s := msg.String() + assert.Contains(t, s, "type: user_input_audio") + assert.NotContains(t, s, "mime_type:") + assert.Contains(t, s, "type: server_tool_call") + }) + + t.Run("nil content struct in block", func(t *testing.T) { + // Test cases where the specific content struct is nil but type is set + // This shouldn't crash and should just print type + msg := &AgenticMessage{ + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeReasoning, Reasoning: nil}, + {Type: ContentBlockTypeUserInputText, UserInputText: nil}, + {Type: ContentBlockTypeUserInputImage, UserInputImage: nil}, + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: nil}, + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: nil}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: nil}, + {Type: ContentBlockTypeAssistantGenText, AssistantGenText: nil}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: nil}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: nil}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: nil}, + {Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: nil}, + {Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: nil}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: nil}, + {Type: ContentBlockTypeServerToolResult, ServerToolResult: nil}, + {Type: ContentBlockTypeMCPToolCall, MCPToolCall: nil}, + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: nil}, + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: nil}, + {Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: nil}, + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: nil}, + }, + } + s := msg.String() + assert.Contains(t, s, "type: reasoning") + // ensure no panic and basic output present + }) +} + +func TestDeveloperAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := DeveloperAgenticMessage("developer") + assert.Equal(t, AgenticRoleTypeDeveloper, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "developer", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestSystemAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := SystemAgenticMessage("system") + assert.Equal(t, AgenticRoleTypeSystem, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "system", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestUserAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := UserAgenticMessage("user") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "user", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestFunctionToolResultAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := FunctionToolResultAgenticMessage("call_1", "tool_name", "result_str") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, ContentBlockTypeFunctionToolResult, msg.ContentBlocks[0].Type) + assert.Equal(t, "call_1", msg.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "tool_name", msg.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, "result_str", msg.ContentBlocks[0].FunctionToolResult.Result) + }) +} + +func TestNewContentBlock(t *testing.T) { + cbType := reflect.TypeOf(ContentBlock{}) + for i := 0; i < cbType.NumField(); i++ { + field := cbType.Field(i) + + // Skip non-content fields + if field.Name == "Type" || field.Name == "Extra" || field.Name == "StreamingMeta" { + continue + } + + t.Run(field.Name, func(t *testing.T) { + // Ensure field is a pointer + assert.Equal(t, reflect.Ptr, field.Type.Kind(), "Field %s should be a pointer", field.Name) + + // Create a new instance of the field's type + // field.Type is *T, so Elem() is T. reflect.New(T) returns *T. + elemType := field.Type.Elem() + inputVal := reflect.New(elemType) + input := inputVal.Interface() + + // Call NewContentBlock (generic) via type switch + var block *ContentBlock + switch v := input.(type) { + case *Reasoning: + block = NewContentBlock(v) + case *UserInputText: + block = NewContentBlock(v) + case *UserInputImage: + block = NewContentBlock(v) + case *UserInputAudio: + block = NewContentBlock(v) + case *UserInputVideo: + block = NewContentBlock(v) + case *UserInputFile: + block = NewContentBlock(v) + case *AssistantGenText: + block = NewContentBlock(v) + case *AssistantGenImage: + block = NewContentBlock(v) + case *AssistantGenAudio: + block = NewContentBlock(v) + case *AssistantGenVideo: + block = NewContentBlock(v) + case *FunctionToolCall: + block = NewContentBlock(v) + case *FunctionToolResult: + block = NewContentBlock(v) + case *ServerToolCall: + block = NewContentBlock(v) + case *ServerToolResult: + block = NewContentBlock(v) + case *MCPToolCall: + block = NewContentBlock(v) + case *MCPToolResult: + block = NewContentBlock(v) + case *MCPListToolsResult: + block = NewContentBlock(v) + case *MCPToolApprovalRequest: + block = NewContentBlock(v) + case *MCPToolApprovalResponse: + block = NewContentBlock(v) + default: + t.Fatalf("unsupported ContentBlock field type: %T", input) + } + + // Assertions + assert.NotNil(t, block, "NewContentBlock should return non-nil for type %T", input) + + // Check if the corresponding field in block is set equals to input + blockVal := reflect.ValueOf(block).Elem() + fieldVal := blockVal.FieldByName(field.Name) + assert.True(t, fieldVal.IsValid(), "Field %s not found in result", field.Name) + assert.Equal(t, input, fieldVal.Interface(), "Field %s should match input", field.Name) + + // Check Type is set + typeVal := blockVal.FieldByName("Type") + assert.NotEmpty(t, typeVal.String(), "Type should be set for %s", field.Name) + }) + } } diff --git a/schema/claude/content_block.go b/schema/claude/extension.go similarity index 54% rename from schema/claude/content_block.go rename to schema/claude/extension.go index 0c43d1045..2bd7422ad 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/extension.go @@ -16,6 +16,15 @@ package claude +import ( + "fmt" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + StopReason string `json:"stop_reason,omitempty"` +} + type AssistantGenTextExtension struct { Citations []*TextCitation `json:"citations,omitempty"` } @@ -33,30 +42,30 @@ type CitationCharLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartCharIndex int64 `json:"start_char_index,omitempty"` - EndCharIndex int64 `json:"end_char_index,omitempty"` + StartCharIndex int `json:"start_char_index,omitempty"` + EndCharIndex int `json:"end_char_index,omitempty"` } type CitationPageLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartPageNumber int64 `json:"start_page_number,omitempty"` - EndPageNumber int64 `json:"end_page_number,omitempty"` + StartPageNumber int `json:"start_page_number,omitempty"` + EndPageNumber int `json:"end_page_number,omitempty"` } type CitationContentBlockLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartBlockIndex int64 `json:"start_block_index,omitempty"` - EndBlockIndex int64 `json:"end_block_index,omitempty"` + StartBlockIndex int `json:"start_block_index,omitempty"` + EndBlockIndex int `json:"end_block_index,omitempty"` } type CitationWebSearchResultLocation struct { @@ -67,3 +76,44 @@ type CitationWebSearchResultLocation struct { EncryptedIndex string `json:"encrypted_index,omitempty"` } + +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &AssistantGenTextExtension{ + Citations: make([]*TextCitation, 0, len(chunks)), + } + + for _, ext := range chunks { + ret.Citations = append(ret.Citations, ext.Citations...) + } + + return ret, nil +} + +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.StopReason != "" { + ret.StopReason = ext.StopReason + } + } + + return ret, nil +} diff --git a/schema/claude/extension_test.go b/schema/claude/extension_test.go new file mode 100644 index 000000000..474fe740b --- /dev/null +++ b/schema/claude/extension_test.go @@ -0,0 +1,190 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("multiple extensions - concatenates all citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "citation 1", + DocumentIndex: 0, + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "citation 2", + StartPageNumber: 1, + EndPageNumber: 2, + }, + }, + { + Type: "web_search_result_location", + WebSearchResultLocation: &CitationWebSearchResultLocation{ + CitedText: "citation 3", + URL: "https://example.com", + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "content_block_location", + ContentBlockLocation: &CitationContentBlockLocation{ + CitedText: "citation 4", + StartBlockIndex: 0, + EndBlockIndex: 5, + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 4) + assert.Equal(t, "citation 1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "citation 2", result.Citations[1].PageLocation.CitedText) + assert.Equal(t, "citation 3", result.Citations[2].WebSearchResultLocation.CitedText) + assert.Equal(t, "citation 4", result.Citations[3].ContentBlockLocation.CitedText) + }) + + t.Run("mixed empty and non-empty citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + {Citations: nil}, + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "text1", + }, + }, + }, + }, + {Citations: []*TextCitation{}}, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "text2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 2) + assert.Equal(t, "text1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "text2", result.Citations[1].PageLocation.CitedText) + }) + + t.Run("streaming scenario - citations arrive in chunks", func(t *testing.T) { + // Simulates streaming where citations arrive progressively + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk1"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk2"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk3"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 3) + assert.Equal(t, "chunk1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "chunk2", result.Citations[1].CharLocation.CitedText) + assert.Equal(t, "chunk3", result.Citations[2].CharLocation.CitedText) + }) +} + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + { + ID: "msg_1", + StopReason: "stop_1", + }, + { + ID: "msg_2", + StopReason: "", + }, + { + ID: "", + StopReason: "stop_3", + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_2", result.ID) // Last non-empty ID + assert.Equal(t, "stop_3", result.StopReason) // Last non-empty StopReason + }) + + t.Run("all empty fields", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "", StopReason: ""}, + {ID: "", StopReason: ""}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "", result.ID) + assert.Equal(t, "", result.StopReason) + }) + + t.Run("streaming scenario - ID in first chunk, StopReason in last", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "msg_stream_123", StopReason: ""}, + {ID: "", StopReason: ""}, + {ID: "", StopReason: "end_turn"}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_stream_123", result.ID) + assert.Equal(t, "end_turn", result.StopReason) + }) +} diff --git a/schema/claude/response_meta.go b/schema/claude/response_meta.go deleted file mode 100644 index 9f60dd713..000000000 --- a/schema/claude/response_meta.go +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package claude - -type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - StopReason string `json:"stop_reason,omitempty"` -} diff --git a/schema/gemini/response_meta.go b/schema/gemini/extension.go similarity index 81% rename from schema/gemini/response_meta.go rename to schema/gemini/extension.go index a5b3f626c..dc7e8a24a 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/extension.go @@ -16,6 +16,10 @@ package gemini +import ( + "fmt" +) + type ResponseMetaExtension struct { ID string `json:"id,omitempty"` FinishReason string `json:"finish_reason,omitempty"` @@ -56,7 +60,7 @@ type GroundingSupport struct { // A list of indices (into 'grounding_chunk') specifying the citations associated with // the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], // grounding_chunk[4] are the retrieved content attributed to the claim. - GroundingChunkIndices []int32 `json:"grounding_chunk_indices,omitempty"` + GroundingChunkIndices []int `json:"grounding_chunk_indices,omitempty"` // Segment of the content this support belongs to. Segment *Segment `json:"segment,omitempty"` } @@ -65,12 +69,12 @@ type GroundingSupport struct { type Segment struct { // Output only. End index in the given Part, measured in bytes. Offset from the start // of the Part, exclusive, starting at zero. - EndIndex int32 `json:"end_index,omitempty"` + EndIndex int `json:"end_index,omitempty"` // Output only. The index of a Part object within its parent Content object. - PartIndex int32 `json:"part_index,omitempty"` + PartIndex int `json:"part_index,omitempty"` // Output only. Start index in the given Part, measured in bytes. Offset from the start // of the Part, inclusive, starting at zero. - StartIndex int32 `json:"start_index,omitempty"` + StartIndex int `json:"start_index,omitempty"` // Output only. The text corresponding to the segment from the response. Text string `json:"text,omitempty"` } @@ -82,3 +86,28 @@ type SearchEntryPoint struct { // Optional. Base64 encoded JSON representing array of tuple. SDKBlob []byte `json:"sdk_blob,omitempty"` } + +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.FinishReason != "" { + ret.FinishReason = ext.FinishReason + } + if ext.GroundingMeta != nil { + ret.GroundingMeta = ext.GroundingMeta + } + } + + return ret, nil +} diff --git a/schema/gemini/extension_test.go b/schema/gemini/extension_test.go new file mode 100644 index 000000000..56f390aa8 --- /dev/null +++ b/schema/gemini/extension_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gemini + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + meta1 := &GroundingMetadata{WebSearchQueries: []string{"query1"}} + meta2 := &GroundingMetadata{WebSearchQueries: []string{"query2"}} + + exts := []*ResponseMetaExtension{ + { + ID: "resp_1", + FinishReason: "STOP", + GroundingMeta: meta1, + }, + { + ID: "resp_2", + FinishReason: "", + GroundingMeta: nil, + }, + { + ID: "", + FinishReason: "MAX_TOKENS", + GroundingMeta: meta2, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "resp_2", result.ID) + assert.Equal(t, "MAX_TOKENS", result.FinishReason) + assert.Equal(t, meta2, result.GroundingMeta) + }) + + t.Run("streaming scenario", func(t *testing.T) { + meta := &GroundingMetadata{ + GroundingChunks: []*GroundingChunk{ + { + Web: &GroundingChunkWeb{ + Title: "Example", + URI: "https://example.com", + }, + }, + }, + } + + exts := []*ResponseMetaExtension{ + {ID: "stream_123", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "STOP", GroundingMeta: meta}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "stream_123", result.ID) + assert.Equal(t, "STOP", result.FinishReason) + assert.Equal(t, meta, result.GroundingMeta) + }) +} diff --git a/schema/message.go b/schema/message.go index 3c890b0b5..8884984fc 100644 --- a/schema/message.go +++ b/schema/message.go @@ -499,12 +499,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails `json:"prompt_token_details"` // CompletionTokens is the number of completion tokens. CompletionTokens int `json:"completion_tokens"` - // CompletionTokenDetails is a breakdown of the completion tokens. - CompletionTokenDetails CompletionTokenDetails `json:"completion_token_details"` + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails `json:"completion_tokens_details"` // TotalTokens is the total number of tokens. TotalTokens int `json:"total_tokens"` - // CompletionTokensDetails is breakdown of completion tokens. - CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { @@ -514,10 +512,6 @@ type CompletionTokensDetails struct { ReasoningTokens int `json:"reasoning_tokens,omitempty"` } -type CompletionTokenDetails struct { - ReasoningTokens int `json:"reasoning_tokens"` -} - type PromptTokenDetails struct { // Cached tokens present in the prompt. CachedTokens int `json:"cached_tokens"` diff --git a/schema/openai/consts.go b/schema/openai/consts.go index 321ee2a9e..9cea69efd 100644 --- a/schema/openai/consts.go +++ b/schema/openai/consts.go @@ -24,3 +24,71 @@ const ( TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" TextAnnotationTypeFilePath TextAnnotationType = "file_path" ) + +type ReasoningEffort string + +const ( + ReasoningEffortMinimal ReasoningEffort = "minimal" + ReasoningEffortLow ReasoningEffort = "low" + ReasoningEffortMedium ReasoningEffort = "medium" + ReasoningEffortHigh ReasoningEffort = "high" +) + +type ReasoningSummary string + +const ( + ReasoningSummaryAuto ReasoningSummary = "auto" + ReasoningSummaryConcise ReasoningSummary = "concise" + ReasoningSummaryDetailed ReasoningSummary = "detailed" +) + +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierScale ServiceTier = "scale" + ServiceTierPriority ServiceTier = "priority" +) + +type PromptCacheRetention string + +const ( + PromptCacheRetentionInMemory PromptCacheRetention = "in-memory" + PromptCacheRetention24h PromptCacheRetention = "24h" +) + +type ResponseStatus string + +const ( + ResponseStatusCompleted ResponseStatus = "completed" + ResponseStatusFailed ResponseStatus = "failed" + ResponseStatusInProgress ResponseStatus = "in_progress" + ResponseStatusCancelled ResponseStatus = "cancelled" + ResponseStatusQueued ResponseStatus = "queued" + ResponseStatusIncomplete ResponseStatus = "incomplete" +) + +type ResponseErrorCode string + +const ( + ResponseErrorCodeServerError ResponseErrorCode = "server_error" + ResponseErrorCodeRateLimitExceeded ResponseErrorCode = "rate_limit_exceeded" + ResponseErrorCodeInvalidPrompt ResponseErrorCode = "invalid_prompt" + ResponseErrorCodeVectorStoreTimeout ResponseErrorCode = "vector_store_timeout" + ResponseErrorCodeInvalidImage ResponseErrorCode = "invalid_image" + ResponseErrorCodeInvalidImageFormat ResponseErrorCode = "invalid_image_format" + ResponseErrorCodeInvalidBase64Image ResponseErrorCode = "invalid_base64_image" + ResponseErrorCodeInvalidImageURL ResponseErrorCode = "invalid_image_url" + ResponseErrorCodeImageTooLarge ResponseErrorCode = "image_too_large" + ResponseErrorCodeImageTooSmall ResponseErrorCode = "image_too_small" + ResponseErrorCodeImageParseError ResponseErrorCode = "image_parse_error" + ResponseErrorCodeImageContentPolicyViolation ResponseErrorCode = "image_content_policy_violation" + ResponseErrorCodeInvalidImageMode ResponseErrorCode = "invalid_image_mode" + ResponseErrorCodeImageFileTooLarge ResponseErrorCode = "image_file_too_large" + ResponseErrorCodeUnsupportedImageMediaType ResponseErrorCode = "unsupported_image_media_type" + ResponseErrorCodeEmptyImageFile ResponseErrorCode = "empty_image_file" + ResponseErrorCodeFailedToDownloadImage ResponseErrorCode = "failed_to_download_image" + ResponseErrorCodeImageFileNotFound ResponseErrorCode = "image_file_not_found" +) diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go deleted file mode 100644 index 5d92be8f7..000000000 --- a/schema/openai/content_block.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package openai - -type AssistantGenTextExtension struct { - Annotations []*TextAnnotation `json:"annotations,omitempty"` -} - -type TextAnnotation struct { - Type TextAnnotationType `json:"type,omitempty"` - - FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` - URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` - ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` - FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` -} - -type TextAnnotationFileCitation struct { - // The ID of the file. - FileID string `json:"file_id,omitempty"` - // The filename of the file cited. - Filename string `json:"filename,omitempty"` - - // The index of the file in the list of files. - Index int64 `json:"index,omitempty"` -} - -type TextAnnotationURLCitation struct { - // The title of the web resource. - Title string `json:"title,omitempty"` - // The URL of the web resource. - URL string `json:"url,omitempty"` - - // The index of the first character of the URL citation in the message. - StartIndex int64 `json:"start_index,omitempty"` - // The index of the last character of the URL citation in the message. - EndIndex int64 `json:"end_index,omitempty"` -} - -type TextAnnotationContainerFileCitation struct { - // The ID of the container file. - ContainerID string `json:"container_id,omitempty"` - - // The ID of the file. - FileID string `json:"file_id,omitempty"` - // The filename of the container file cited. - Filename string `json:"filename,omitempty"` - - // The index of the first character of the container file citation in the message. - StartIndex int64 `json:"start_index,omitempty"` - // The index of the last character of the container file citation in the message. - EndIndex int64 `json:"end_index,omitempty"` -} - -type TextAnnotationFilePath struct { - // The ID of the file. - FileID string `json:"file_id,omitempty"` - - // The index of the file in the list of files. - Index int64 `json:"index,omitempty"` -} diff --git a/schema/openai/extension.go b/schema/openai/extension.go new file mode 100644 index 000000000..01d64b238 --- /dev/null +++ b/schema/openai/extension.go @@ -0,0 +1,204 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "fmt" + "sort" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + Status ResponseStatus `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` + CreatedAt int64 `json:"created_at,omitempty"` + PromptCacheRetention PromptCacheRetention `json:"prompt_cache_retention,omitempty"` +} + +type AssistantGenTextExtension struct { + Refusal *OutputRefusal `json:"refusal,omitempty"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` +} + +type ResponseError struct { + Code ResponseErrorCode `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +type IncompleteDetails struct { + Reason string `json:"reason,omitempty"` +} + +type Reasoning struct { + Effort ReasoningEffort `json:"effort,omitempty"` + Summary ReasoningSummary `json:"summary,omitempty"` +} + +type OutputRefusal struct { + Reason string `json:"reason,omitempty"` +} + +type TextAnnotation struct { + Index int `json:"index,omitempty"` + + Type TextAnnotationType `json:"type,omitempty"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the file cited. + Filename string `json:"filename,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +type TextAnnotationURLCitation struct { + // The title of the web resource. + Title string `json:"title,omitempty"` + // The URL of the web resource. + URL string `json:"url,omitempty"` + + // The index of the first character of the URL citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the URL citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationContainerFileCitation struct { + // The ID of the container file. + ContainerID string `json:"container_id,omitempty"` + + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the container file cited. + Filename string `json:"filename,omitempty"` + + // The index of the first character of the container file citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the container file citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationFilePath struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + + ret := &AssistantGenTextExtension{} + + var allAnnotations []*TextAnnotation + for _, ext := range chunks { + allAnnotations = append(allAnnotations, ext.Annotations...) + } + + var ( + indices []int + indexToAnnotation = map[int]*TextAnnotation{} + ) + + for _, an := range allAnnotations { + if an == nil { + continue + } + if indexToAnnotation[an.Index] == nil { + indexToAnnotation[an.Index] = an + indices = append(indices, an.Index) + } else { + return nil, fmt.Errorf("duplicate annotation index %d", an.Index) + } + } + + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Annotations = make([]*TextAnnotation, 0, len(indices)) + for _, idx := range indices { + an := *indexToAnnotation[idx] + an.Index = 0 // clear index + ret.Annotations = append(ret.Annotations, &an) + } + + for _, ext := range chunks { + if ext.Refusal == nil { + continue + } + if ret.Refusal == nil { + ret.Refusal = ext.Refusal + } else { + ret.Refusal.Reason += ext.Refusal.Reason + } + } + + return ret, nil +} + +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.Status != "" { + ret.Status = ext.Status + } + if ext.Error != nil { + ret.Error = ext.Error + } + if ext.IncompleteDetails != nil { + ret.IncompleteDetails = ext.IncompleteDetails + } + if ext.PreviousResponseID != "" { + ret.PreviousResponseID = ext.PreviousResponseID + } + if ext.Reasoning != nil { + ret.Reasoning = ext.Reasoning + } + if ext.ServiceTier != "" { + ret.ServiceTier = ext.ServiceTier + } + } + + return ret, nil +} diff --git a/schema/openai/extension_test.go b/schema/openai/extension_test.go new file mode 100644 index 000000000..640982fdf --- /dev/null +++ b/schema/openai/extension_test.go @@ -0,0 +1,193 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + err1 := &ResponseError{Code: "err1", Message: "msg1"} + incomplete := &IncompleteDetails{Reason: "max_tokens"} + + exts := []*ResponseMetaExtension{ + { + ID: "id_1", + Status: "in_progress", + Error: err1, + IncompleteDetails: nil, + }, + { + ID: "id_2", + Status: "", + Error: nil, + IncompleteDetails: nil, + }, + { + ID: "", + Status: "completed", + Error: nil, + IncompleteDetails: incomplete, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "id_2", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + assert.Equal(t, err1, result.Error) + assert.Equal(t, incomplete, result.IncompleteDetails) + }) + + t.Run("streaming scenario", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "chatcmpl_stream", Status: "", Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("in_progress"), Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("completed"), Error: nil, IncompleteDetails: nil}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "chatcmpl_stream", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + }) +} + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("single extension with annotations", func(t *testing.T) { + ext := &AssistantGenTextExtension{ + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_123", + Filename: "doc.pdf", + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext}) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 1) + assert.Equal(t, "file_123", result.Annotations[0].FileCitation.FileID) + }) + + t.Run("multiple extensions - merges annotations by index", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_1", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 2, + Type: "url_citation", + URLCitation: &TextAnnotationURLCitation{ + URL: "https://example.com", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 1, + Type: "file_path", + FilePath: &TextAnnotationFilePath{ + FileID: "file_2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "file_1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "file_2", result.Annotations[1].FilePath.FileID) + assert.Equal(t, "https://example.com", result.Annotations[2].URLCitation.URL) + }) + + t.Run("streaming scenario - annotations arrive in chunks", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "f1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 1, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "url1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 2, Type: "file_path", FilePath: &TextAnnotationFilePath{FileID: "f2"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "f1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "url1", result.Annotations[1].URLCitation.URL) + assert.Equal(t, "f2", result.Annotations[2].FilePath.FileID) + }) + + t.Run("multiple extensions - concatenates refusal reason", func(t *testing.T) { + ext1 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "A"}} + ext2 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "B"}} + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext1, ext2}) + assert.NoError(t, err) + assert.NotNil(t, result.Refusal) + assert.Equal(t, "AB", result.Refusal.Reason) + }) + + t.Run("duplicate index - error occurrence", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "first"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "second"}}, + }, + }, + } + + _, err := ConcatAssistantGenTextExtensions(exts) + assert.Error(t, err) + }) +} diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go deleted file mode 100644 index e1933065b..000000000 --- a/schema/openai/response_meta.go +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package openai - -type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - Status string `json:"status,omitempty"` - Error *ResponseError `json:"error,omitempty"` - StreamError *StreamResponseError `json:"stream_error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` -} - -type ResponseError struct { - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -type StreamResponseError struct { - Code string - Message string - Param string -} - -type IncompleteDetails struct { - Reason string `json:"reason,omitempty"` -} diff --git a/schema/tool.go b/schema/tool.go index 44dc0e994..bb3d98c13 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -54,6 +54,26 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +type AllowedTool struct { + // FunctionToolName is the name of the function tool. + FunctionToolName string + + MCPTool *AllowedMCPTool + + ServerTool *AllowedServerTool +} +type AllowedMCPTool struct { + // ServerLabel is the label of the MCP server. + ServerLabel string + // The name of the MCP tool. + Name string +} + +type AllowedServerTool struct { + // The name of the server tool. + Name string +} + // ToolInfo is the information of a tool. type ToolInfo struct { // The unique name of the tool that clearly communicates its purpose.