From 884793466b011bfd997be3caf48aa8eb6d547ca0 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 16 Oct 2025 16:36:12 +0800 Subject: [PATCH 01/28] feat: define AgenticModel component interface --- components/agency/callback_extra.go | 101 ++++++++ components/agency/interface.go | 29 +++ components/agency/option.go | 77 ++++++ components/types.go | 1 + go.mod | 3 +- go.sum | 7 +- schema/agentic_message.go | 374 ++++++++++++++++++++++++++++ schema/anthropic/citation.go | 49 ++++ schema/anthropic/types.go | 10 + schema/google/candidate_meta.go | 66 +++++ schema/message.go | 4 +- schema/openai/annotation.go | 55 ++++ schema/openai/types.go | 10 + 13 files changed, 780 insertions(+), 6 deletions(-) create mode 100644 components/agency/callback_extra.go create mode 100644 components/agency/interface.go create mode 100644 components/agency/option.go create mode 100644 schema/agentic_message.go create mode 100644 schema/anthropic/citation.go create mode 100644 schema/anthropic/types.go create mode 100644 schema/google/candidate_meta.go create mode 100644 schema/openai/annotation.go create mode 100644 schema/openai/types.go diff --git a/components/agency/callback_extra.go b/components/agency/callback_extra.go new file mode 100644 index 000000000..984756d1a --- /dev/null +++ b/components/agency/callback_extra.go @@ -0,0 +1,101 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package agency + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// TokenUsageMeta is the token usage for the model. +type TokenUsageMeta struct { + InputTokens int64 `json:"input_tokens"` + InputTokensDetails InputTokensUsageDetails `json:"input_tokens_details"` + OutputTokens int64 `json:"output_tokens"` + OutputTokensDetails OutputTokensUsageDetails `json:"output_tokens_details"` + TotalTokens int64 `json:"total_tokens"` +} + +type InputTokensUsageDetails struct { + CachedTokens int64 `json:"cached_tokens"` +} + +type OutputTokensUsageDetails struct { + ReasoningTokens int64 `json:"reasoning_tokens"` +} + +// Config is the config for the model. +type Config struct { + // Model is the model name. + Model string + // Temperature is the temperature, which controls the randomness of the model. + Temperature float32 + // TopP is the top p, which controls the diversity of the model. + TopP float32 +} + +// CallbackInput is the input for the model callback. +type CallbackInput struct { + // Responses is the responses to be sent to the model. + Responses []*schema.AgenticMessage + // Tools is the tools to be used in the model. + Tools []*schema.ToolInfo + // Config is the config for the model. + Config *Config + // Extra is the extra information for the callback. + Extra map[string]any +} + +// CallbackOutput is the output for the model callback. +type CallbackOutput struct { + // Response is the response generated by the model. + Response *schema.AgenticMessage + // Config is the config for the model. + Config *Config + // Usage is the token usage of this request. + Usage *TokenUsageMeta + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvCallbackInput converts the callback input to the model callback input. 
+func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { + switch t := src.(type) { + case *CallbackInput: // when callback is triggered within component implementation, the input is usually already a typed *model.CallbackInput + return t + case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage + return &CallbackInput{ + Responses: t, + } + default: + return nil + } +} + +// ConvCallbackOutput converts the callback output to the model callback output. +func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { + switch t := src.(type) { + case *CallbackOutput: // when callback is triggered within component implementation, the output is usually already a typed *model.CallbackOutput + return t + case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage + return &CallbackOutput{ + Response: t, + } + default: + return nil + } +} diff --git a/components/agency/interface.go b/components/agency/interface.go new file mode 100644 index 000000000..e33d6a933 --- /dev/null +++ b/components/agency/interface.go @@ -0,0 +1,29 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package agency + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +type AgenticModel interface { + Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) + Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + WithTools(tools []*schema.ToolInfo) (AgenticModel, error) +} diff --git a/components/agency/option.go b/components/agency/option.go new file mode 100644 index 000000000..17028f6e9 --- /dev/null +++ b/components/agency/option.go @@ -0,0 +1,77 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agency + +// Options is the common options for the model. +type Options struct { +} + +// Option is the call option for ChatModel component. +type Option struct { + apply func(opts *Options) + + implSpecificOptFn any +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. +func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { + return Option{ + implSpecificOptFn: optFn, + } +} + +// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values. 
+func GetCommonOptions(base *Options, opts ...Option) *Options { + if base == nil { + base = &Options{} + } + + for i := range opts { + opt := opts[i] + if opt.apply != nil { + opt.apply(base) + } + } + + return base +} + +// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values. +// e.g. +// +// myOption := &MyOption{ +// Field1: "default_value", +// } +// +// myOption := model.GetImplSpecificOptions(myOption, opts...) +func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { + if base == nil { + base = new(T) + } + + for i := range opts { + opt := opts[i] + if opt.implSpecificOptFn != nil { + optFn, ok := opt.implSpecificOptFn.(func(*T)) + if ok { + optFn(base) + } + } + } + + return base +} diff --git a/components/types.go b/components/types.go index a546ae59f..a23d82a68 100644 --- a/components/types.go +++ b/components/types.go @@ -68,6 +68,7 @@ const ( ComponentOfPrompt Component = "ChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" + ComponentOfAgenticModel Component = "AgenticModel" // ComponentOfEmbedding identifies embedding components. ComponentOfEmbedding Component = "Embedding" // ComponentOfIndexer identifies indexer components. 
diff --git a/go.mod b/go.mod index cfa6957cc..0b87a6cab 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/yargevad/filepathx v1.0.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/term v0.28.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index a80d6399b..5813766b2 100644 --- a/go.sum +++ b/go.sum @@ -117,9 +117,10 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/schema/agentic_message.go b/schema/agentic_message.go new file mode 100644 index 000000000..e386ed044 --- /dev/null +++ b/schema/agentic_message.go @@ -0,0 +1,374 @@ +/* + * Copyright 2025 
CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "github.com/cloudwego/eino/schema/anthropic" + "github.com/cloudwego/eino/schema/google" + "github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" +) + +type ContentBlockType string + +const ( + ContentBlockTypeReasoning ContentBlockType = "reasoning" + ContentBlockTypeUserInputText ContentBlockType = "user_input_text" + ContentBlockTypeUserInputImage ContentBlockType = "user_input_image" + ContentBlockTypeUserInputAudio ContentBlockType = "user_input_audio" + ContentBlockTypeUserInputVideo ContentBlockType = "user_input_video" + ContentBlockTypeUserInputFile ContentBlockType = "user_input_file" + ContentBlockTypeAssistantGenText ContentBlockType = "assistant_gen_text" + ContentBlockTypeAssistantGenImage ContentBlockType = "assistant_gen_image" + ContentBlockTypeAssistantGenAudio ContentBlockType = "assistant_gen_audio" + ContentBlockTypeAssistantGenVideo ContentBlockType = "assistant_gen_video" + ContentBlockTypeFunctionToolCall ContentBlockType = "function_tool_call" + ContentBlockTypeFunctionToolResult ContentBlockType = "function_tool_result" + ContentBlockTypeServerToolCall ContentBlockType = "server_tool_call" + ContentBlockTypeServerToolResult ContentBlockType = "server_tool_result" + ContentBlockTypeMCPToolCall ContentBlockType = "mcp_tool_call" + ContentBlockTypeMCPToolResult ContentBlockType = "mcp_tool_result" + 
ContentBlockTypeMCPListTools ContentBlockType = "mcp_list_tools" + ContentBlockTypeMCPToolApprovalRequest ContentBlockType = "mcp_tool_approval_request" + ContentBlockTypeMCPToolApprovalResponse ContentBlockType = "mcp_tool_approval_response" +) + +type AgenticRoleType string + +const ( + AgenticRoleTypeDeveloper AgenticRoleType = "developer" + AgenticRoleTypeSystem AgenticRoleType = "system" + AgenticRoleTypeUser AgenticRoleType = "user" + AgenticRoleTypeAssistant AgenticRoleType = "assistant" +) + +type AgenticMessage struct { + ResponseMeta *AgenticResponseMeta + + Role AgenticRoleType + ContentBlocks []*ContentBlock + + Extra map[string]any +} + +type AgenticResponseMeta struct { + Status *string + FinishReason string + + TokenUsage *TokenUsage + + GoogleAdditionalMeta *google.CandidateMeta +} + +type StreamMeta struct { + // Index is used for streaming to identify the chunk of the block for concatenation. + Index *int + // Streaming phase of the content block. + Phase StreamPhase +} + +type ContentBlock struct { + Type ContentBlockType + + Reasoning *Reasoning + + UserInputText *UserInputText + UserInputImage *UserInputImage + UserInputAudio *UserInputAudio + UserInputVideo *UserInputVideo + UserInputFile *UserInputFile + + AssistantGenText *AssistantGenText + AssistantGenImage *AssistantGenImage + AssistantGenAudio *AssistantGenAudio + AssistantGenVideo *AssistantGenVideo + + // FunctionToolCall holds invocation details for a user-defined tool. + FunctionToolCall *FunctionToolCall + // FunctionToolResult is the result from a user-defined tool call. + FunctionToolResult *FunctionToolResult + // ServerToolCall holds invocation details for a provider built-in tool run on the model server. + ServerToolCall *ServerToolCall + // ServerToolResult is the result from a provider built-in tool run on the model server. + ServerToolResult *ServerToolResult + + // MCPToolCall holds invocation details for an MCP tool managed by the model server. 
+ MCPToolCall *MCPToolCall + // MCPToolResult is the result from an MCP tool managed by the model server. + MCPToolResult *MCPToolResult + // MCPListToolsResult lists available MCP tools reported by the model server. + MCPListToolsResult *MCPListToolsResult + // MCPToolApprovalRequest requests user approval for an MCP tool call when required. + MCPToolApprovalRequest *MCPToolApprovalRequest + // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. + MCPToolApprovalResponse *MCPToolApprovalResponse + + StreamMeta *StreamMeta +} + +type StreamPhase string + +const ( + StreamPhaseStart StreamPhase = "start" + StreamPhaseDelta StreamPhase = "delta" + StreamPhaseStop StreamPhase = "stop" +) + +type UserInputText struct { + Text string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputImage struct { + URL *string + Base64Data *string + MIMEType string + Detail ImageURLDetail + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputAudio struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputVideo struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type UserInputFile struct { + URL *string + Name *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenText struct { + Text string + + OpenAIAnnotations []*openai.TextAnnotation + AnthropicCitations []*anthropic.TextCitation + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenImage struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type AssistantGenAudio struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. 
+ Extra map[string]any +} + +type AssistantGenVideo struct { + URL *string + Base64Data *string + MIMEType string + + // Extra stores additional information. + Extra map[string]any +} + +type Reasoning struct { + // Summary is the reasoning content summary. + Summary []*ReasoningSummary + + // EncryptedContent is the encrypted reasoning content. + EncryptedContent string + + // Extra stores additional information. + Extra map[string]any +} + +type ReasoningSummary struct { + // Index specifies the ReasoningSummary chunk to be concatenated during streaming. + Index *int + + Text string +} + +type FunctionToolCall struct { + // CallID is the unique identifier for the tool call. + CallID string + + // Name specifies the function tool invoked. + Name string + + // Arguments is the JSON string arguments for the function tool call. + Arguments string + + // Extra stores additional information + Extra map[string]any +} + +type FunctionToolResult struct { + // CallID is the unique identifier for the tool call. + CallID string + + // Name specifies the function tool invoked. + Name string + + // Result is the function tool result returned by the user + Result string + + // Extra stores additional information. + Extra map[string]any +} + +type ServerToolCall struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string + + // Arguments are the raw inputs to the server-side tool, + // supplied by the component implementer. + Arguments any + + // Extra stores additional information. + Extra map[string]any +} + +type ServerToolResult struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string + + // CallID is the unique identifier for the tool call. 
+ // Empty if not provided by the model server. + CallID string + + // Result refers to the raw output generated by the server-side tool, + // supplied by the component implementer. + Result any + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolCall struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string + // ApprovalRequestID is the unique ID of the approval request. + ApprovalRequestID string + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. + Name string + // Arguments is the JSON string arguments for the tool call. + Arguments string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolResult struct { + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. + Name string + // Result is the JSON string with the tool result. + Result string + // Error returned when the server fails to run the tool. + Error *MCPToolCallError + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolCallError struct { + Code int64 + Error string +} + +type MCPListToolsResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string + // Tools is the list of tools available on the server. + Tools []MCPListToolsItem + // Error returned when the server fails to list tools. + Error string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPListToolsItem struct { + // Name is the name of the tool. + Name string + // Description is the description of the tool. + Description string + // InputSchema is the JSON schema that describes the tool input. + InputSchema *jsonschema.Schema +} + +type MCPToolApprovalRequest struct { + // CallID is the unique ID of the tool call. + CallID string + // Name is the name of the tool to run. 
+ Name string + // Arguments is the JSON string arguments for the tool call. + Arguments string + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string + + // Extra stores additional information. + Extra map[string]any +} + +type MCPToolApprovalResponse struct { + // ApprovalRequestID is the approval request ID being responded to. + ApprovalRequestID string + // Approve indicates whether the request is approved. + Approve bool + // Reason is the rationale for the decision. + // Optional. + Reason string + + // Extra stores additional information. + Extra map[string]any +} diff --git a/schema/anthropic/citation.go b/schema/anthropic/citation.go new file mode 100644 index 000000000..24a4c5aa6 --- /dev/null +++ b/schema/anthropic/citation.go @@ -0,0 +1,49 @@ +package anthropic + +type TextCitation struct { + Type TextCitationType `json:"type"` + + CharLocation *CitationCharLocation `json:"char_location,omitempty"` + PageLocation *CitationPageLocation `json:"page_location,omitempty"` + ContentBlockLocation *CitationContentBlockLocation `json:"content_block_location,omitempty"` + WebSearchResultLocation *CitationWebSearchResultLocation `json:"web_search_result_location,omitempty"` +} + +type CitationCharLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartCharIndex int64 `json:"start_char_index"` + EndCharIndex int64 `json:"end_char_index"` +} + +type CitationPageLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartPageNumber int64 `json:"start_page_number"` + EndPageNumber int64 `json:"end_page_number"` +} + +type CitationContentBlockLocation struct { + CitedText string `json:"cited_text"` + + DocumentTitle string `json:"document_title"` + DocumentIndex int64 `json:"document_index"` + + StartBlockIndex int64 
`json:"start_block_index"` + EndBlockIndex int64 `json:"end_block_index"` +} + +type CitationWebSearchResultLocation struct { + CitedText string `json:"cited_text"` + + Title string `json:"title"` + URL string `json:"url"` + + EncryptedIndex string `json:"encrypted_index"` +} diff --git a/schema/anthropic/types.go b/schema/anthropic/types.go new file mode 100644 index 000000000..fbc85475d --- /dev/null +++ b/schema/anthropic/types.go @@ -0,0 +1,10 @@ +package anthropic + +type TextCitationType string + +const ( + TextCitationTypeCharLocation TextCitationType = "char_location" + TextCitationTypePageLocation TextCitationType = "page_location" + TextCitationTypeContentBlockLocation TextCitationType = "content_block_location" + TextCitationTypeWebSearchResultLocation TextCitationType = "web_search_result_location" +) diff --git a/schema/google/candidate_meta.go b/schema/google/candidate_meta.go new file mode 100644 index 000000000..aead31c2e --- /dev/null +++ b/schema/google/candidate_meta.go @@ -0,0 +1,66 @@ +package google + +type CandidateMeta struct { + GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` +} + +type GroundingMetadata struct { + // List of supporting references retrieved from specified grounding source. + GroundingChunks []*GroundingChunk `json:"grounding_chunks,omitempty"` + // Optional. List of grounding support. + GroundingSupports []*GroundingSupport `json:"grounding_supports,omitempty"` + // Optional. Google search entry for the following-up web searches. + SearchEntryPoint *SearchEntryPoint `json:"search_entry_point,omitempty"` + // Optional. Web search queries for the following-up web search. + WebSearchQueries []string `json:"web_search_queries,omitempty"` +} + +type GroundingChunk struct { + // Grounding chunk from the web. + Web *GroundingChunkWeb `json:"web,omitempty"` +} + +// Chunk from the web. +type GroundingChunkWeb struct { + // Domain of the (original) URI. This field is not supported in Gemini API. 
+ Domain string `json:"domain,omitempty"` + // Title of the chunk. + Title string `json:"title,omitempty"` + // URI reference of the chunk. + URI string `json:"uri,omitempty"` +} + +type GroundingSupport struct { + // Confidence score of the support references. Ranges from 0 to 1. 1 is the most confident. + // For Gemini 2.0 and before, this list must have the same size as the grounding_chunk_indices. + // For Gemini 2.5 and after, this list will be empty and should be ignored. + ConfidenceScores []float32 `json:"confidence_scores,omitempty"` + // A list of indices (into 'grounding_chunk') specifying the citations associated with + // the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], + // grounding_chunk[4] are the retrieved content attributed to the claim. + GroundingChunkIndices []int32 `json:"grounding_chunk_indices,omitempty"` + // Segment of the content this support belongs to. + Segment *Segment `json:"segment,omitempty"` +} + +// Segment of the content. +type Segment struct { + // Output only. End index in the given Part, measured in bytes. Offset from the start + // of the Part, exclusive, starting at zero. + EndIndex int32 `json:"end_index,omitempty"` + // Output only. The index of a Part object within its parent Content object. + PartIndex int32 `json:"part_index,omitempty"` + // Output only. Start index in the given Part, measured in bytes. Offset from the start + // of the Part, inclusive, starting at zero. + StartIndex int32 `json:"start_index,omitempty"` + // Output only. The text corresponding to the segment from the response. + Text string `json:"text,omitempty"` +} + +// Google search entry point. +type SearchEntryPoint struct { + // Optional. Web content snippet that can be embedded in a web page or an app webview. + RenderedContent string `json:"rendered_content,omitempty"` + // Optional. Base64 encoded JSON representing array of tuple. 
+ SDKBlob []byte `json:"sdk_blob,omitempty"` +} diff --git a/schema/message.go b/schema/message.go index 127a520a1..5eb864870 100644 --- a/schema/message.go +++ b/schema/message.go @@ -689,10 +689,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails `json:"prompt_token_details"` // CompletionTokens is the number of completion tokens. CompletionTokens int `json:"completion_tokens"` + // CompletionTokenDetails is a breakdown of the completion tokens. + CompletionTokenDetails CompletionTokensDetails `json:"completion_token_details"` // TotalTokens is the total number of tokens. TotalTokens int `json:"total_tokens"` - // CompletionTokensDetails is breakdown of completion tokens. - CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { diff --git a/schema/openai/annotation.go b/schema/openai/annotation.go new file mode 100644 index 000000000..a834e8072 --- /dev/null +++ b/schema/openai/annotation.go @@ -0,0 +1,55 @@ +package openai + +type TextAnnotation struct { + Type TextAnnotationType `json:"type"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id"` + // The filename of the file cited. + Filename string `json:"filename"` + + // The index of the file in the list of files. + Index int64 `json:"index"` +} + +type TextAnnotationURLCitation struct { + // The title of the web resource. + Title string `json:"title"` + // The URL of the web resource. + URL string `json:"url"` + + // The index of the first character of the URL citation in the message. 
+ StartIndex int64 `json:"start_index"` + // The index of the last character of the URL citation in the message. + EndIndex int64 `json:"end_index"` +} + +type TextAnnotationContainerFileCitation struct { + // The ID of the container file. + ContainerID string `json:"container_id"` + + // The ID of the file. + FileID string `json:"file_id"` + // The filename of the container file cited. + Filename string `json:"filename"` + + // The index of the first character of the container file citation in the message. + StartIndex int64 `json:"start_index"` + // The index of the last character of the container file citation in the message. + EndIndex int64 `json:"end_index"` +} + +type TextAnnotationFilePath struct { + // The ID of the file. + FileID string `json:"file_id"` + + // The index of the file in the list of files. + Index int64 `json:"index"` +} diff --git a/schema/openai/types.go b/schema/openai/types.go new file mode 100644 index 000000000..60cee4361 --- /dev/null +++ b/schema/openai/types.go @@ -0,0 +1,10 @@ +package openai + +type TextAnnotationType string + +const ( + TextAnnotationTypeFileCitation TextAnnotationType = "file_citation" + TextAnnotationTypeURLCitation TextAnnotationType = "url_citation" + TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" + TextAnnotationTypeFilePath TextAnnotationType = "file_path" +) From 0aef88268590eaeef27373717d213502cafb0bee Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 25 Nov 2025 11:35:01 +0800 Subject: [PATCH 02/28] feat: change the index in StreamMeta to a non-pointer (#573) --- schema/agentic_message.go | 14 ++------------ schema/anthropic/citation.go | 16 ++++++++++++++++ schema/anthropic/types.go | 16 ++++++++++++++++ schema/google/candidate_meta.go | 16 ++++++++++++++++ schema/openai/annotation.go | 16 ++++++++++++++++ schema/openai/types.go | 16 ++++++++++++++++ 6 files changed, 82 insertions(+), 12 deletions(-) diff --git a/schema/agentic_message.go 
b/schema/agentic_message.go index e386ed044..84a933c9e 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -75,10 +75,8 @@ type AgenticResponseMeta struct { } type StreamMeta struct { - // Index is used for streaming to identify the chunk of the block for concatenation. - Index *int - // Streaming phase of the content block. - Phase StreamPhase + // Index is the index position of this block in the final response. + Index int } type ContentBlock struct { @@ -120,14 +118,6 @@ type ContentBlock struct { StreamMeta *StreamMeta } -type StreamPhase string - -const ( - StreamPhaseStart StreamPhase = "start" - StreamPhaseDelta StreamPhase = "delta" - StreamPhaseStop StreamPhase = "stop" -) - type UserInputText struct { Text string diff --git a/schema/anthropic/citation.go b/schema/anthropic/citation.go index 24a4c5aa6..064477688 100644 --- a/schema/anthropic/citation.go +++ b/schema/anthropic/citation.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package anthropic type TextCitation struct { diff --git a/schema/anthropic/types.go b/schema/anthropic/types.go index fbc85475d..cc8b1f877 100644 --- a/schema/anthropic/types.go +++ b/schema/anthropic/types.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package anthropic type TextCitationType string diff --git a/schema/google/candidate_meta.go b/schema/google/candidate_meta.go index aead31c2e..8cd324254 100644 --- a/schema/google/candidate_meta.go +++ b/schema/google/candidate_meta.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package google type CandidateMeta struct { diff --git a/schema/openai/annotation.go b/schema/openai/annotation.go index a834e8072..ad4b6b91f 100644 --- a/schema/openai/annotation.go +++ b/schema/openai/annotation.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package openai type TextAnnotation struct { diff --git a/schema/openai/types.go b/schema/openai/types.go index 60cee4361..321ee2a9e 100644 --- a/schema/openai/types.go +++ b/schema/openai/types.go @@ -1,3 +1,19 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package openai type TextAnnotationType string From 389683828ce940351ee548c4d146b0d3b1d1e5ef Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 25 Nov 2025 21:34:26 +0800 Subject: [PATCH 03/28] feat: improve AgenticResponseMeta definition (#575) --- schema/agentic_message.go | 30 +++++++---------- schema/{anthropic => claude}/citation.go | 2 +- schema/claude/messages_meta.go | 22 +++++++++++++ schema/{anthropic => claude}/types.go | 2 +- .../response_meta.go} | 6 ++-- schema/openai/response_meta.go | 33 +++++++++++++++++++ 6 files changed, 72 insertions(+), 23 deletions(-) rename schema/{anthropic => claude}/citation.go (99%) create mode 100644 schema/claude/messages_meta.go rename schema/{anthropic => claude}/types.go (98%) rename schema/{google/candidate_meta.go => gemini/response_meta.go} (95%) create mode 100644 schema/openai/response_meta.go diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 84a933c9e..0d2b01047 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -17,8 +17,8 @@ package schema import ( - "github.com/cloudwego/eino/schema/anthropic" - "github.com/cloudwego/eino/schema/google" + "github.com/cloudwego/eino/schema/claude" + "github.com/cloudwego/eino/schema/gemini" "github.com/cloudwego/eino/schema/openai" "github.com/eino-contrib/jsonschema" ) @@ -66,16 +66,16 @@ type AgenticMessage struct { } type AgenticResponseMeta struct { - Status *string - FinishReason string - TokenUsage *TokenUsage - GoogleAdditionalMeta *google.CandidateMeta + OpenAIExtensions *openai.ResponseMeta + GeminiExtensions *gemini.ResponseMeta + ClaudeExtensions *claude.MessageMeta + Extensions any } type StreamMeta struct { - // Index is the index position of this block in the final response. + // Index specifies the index position of this block in the final response. 
Index int } @@ -166,8 +166,8 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIAnnotations []*openai.TextAnnotation - AnthropicCitations []*anthropic.TextCitation + OpenAIAnnotations []*openai.TextAnnotation + ClaudeCitations []*claude.TextCitation // Extra stores additional information. Extra map[string]any @@ -203,7 +203,6 @@ type AssistantGenVideo struct { type Reasoning struct { // Summary is the reasoning content summary. Summary []*ReasoningSummary - // EncryptedContent is the encrypted reasoning content. EncryptedContent string @@ -212,8 +211,8 @@ type Reasoning struct { } type ReasoningSummary struct { - // Index specifies the ReasoningSummary chunk to be concatenated during streaming. - Index *int + // Index specifies the index position of this summary in the final Reasoning. + Index int Text string } @@ -221,10 +220,8 @@ type ReasoningSummary struct { type FunctionToolCall struct { // CallID is the unique identifier for the tool call. CallID string - // Name specifies the function tool invoked. Name string - // Arguments is the JSON string arguments for the function tool call. Arguments string @@ -235,10 +232,8 @@ type FunctionToolCall struct { type FunctionToolResult struct { // CallID is the unique identifier for the tool call. CallID string - // Name specifies the function tool invoked. Name string - // Result is the function tool result returned by the user Result string @@ -250,15 +245,12 @@ type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). Name string - // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. CallID string - // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. Arguments any - // Extra stores additional information. 
Extra map[string]any } diff --git a/schema/anthropic/citation.go b/schema/claude/citation.go similarity index 99% rename from schema/anthropic/citation.go rename to schema/claude/citation.go index 064477688..b5092d3bc 100644 --- a/schema/anthropic/citation.go +++ b/schema/claude/citation.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package anthropic +package claude type TextCitation struct { Type TextCitationType `json:"type"` diff --git a/schema/claude/messages_meta.go b/schema/claude/messages_meta.go new file mode 100644 index 000000000..a72dded2a --- /dev/null +++ b/schema/claude/messages_meta.go @@ -0,0 +1,22 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +type MessageMeta struct { + ID string `json:"id"` + StopReason string `json:"stop_reason"` +} diff --git a/schema/anthropic/types.go b/schema/claude/types.go similarity index 98% rename from schema/anthropic/types.go rename to schema/claude/types.go index cc8b1f877..cbf8784f6 100644 --- a/schema/anthropic/types.go +++ b/schema/claude/types.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package anthropic +package claude type TextCitationType string diff --git a/schema/google/candidate_meta.go b/schema/gemini/response_meta.go similarity index 95% rename from schema/google/candidate_meta.go rename to schema/gemini/response_meta.go index 8cd324254..3bc590a72 100644 --- a/schema/google/candidate_meta.go +++ b/schema/gemini/response_meta.go @@ -14,9 +14,11 @@ * limitations under the License. */ -package google +package gemini -type CandidateMeta struct { +type ResponseMeta struct { + ID string `json:"id"` + FinishReason string `json:"finish_reason"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go new file mode 100644 index 000000000..1b184073e --- /dev/null +++ b/schema/openai/response_meta.go @@ -0,0 +1,33 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package openai + +type ResponseMeta struct { + ID string `json:"id"` + Status string `json:"status"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` +} + +type ResponseError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type IncompleteDetails struct { + Reason string `json:"reason"` +} From 320dcc1fabda8099814dfdd3a598afaa37b774c4 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 16:08:32 +0800 Subject: [PATCH 04/28] feat: improve AssistantGenText definition (#577) --- schema/agentic_message.go | 5 +++-- schema/claude/{types.go => consts.go} | 0 schema/claude/{citation.go => content_block.go} | 4 ++++ schema/claude/{messages_meta.go => message_meta.go} | 0 schema/openai/{types.go => consts.go} | 0 schema/openai/{annotation.go => content_block.go} | 5 +++++ 6 files changed, 12 insertions(+), 2 deletions(-) rename schema/claude/{types.go => consts.go} (100%) rename schema/claude/{citation.go => content_block.go} (96%) rename schema/claude/{messages_meta.go => message_meta.go} (100%) rename schema/openai/{types.go => consts.go} (100%) rename schema/openai/{annotation.go => content_block.go} (94%) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 0d2b01047..09ab28e40 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -166,8 +166,9 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIAnnotations []*openai.TextAnnotation - ClaudeCitations []*claude.TextCitation + OpenAIExtensions *openai.OutputText + ClaudeExtensions *claude.TextBlock + Extensions any // Extra stores additional information. 
Extra map[string]any diff --git a/schema/claude/types.go b/schema/claude/consts.go similarity index 100% rename from schema/claude/types.go rename to schema/claude/consts.go diff --git a/schema/claude/citation.go b/schema/claude/content_block.go similarity index 96% rename from schema/claude/citation.go rename to schema/claude/content_block.go index b5092d3bc..ba297126e 100644 --- a/schema/claude/citation.go +++ b/schema/claude/content_block.go @@ -16,6 +16,10 @@ package claude +type TextBlock struct { + Citations []*TextCitation `json:"citations"` +} + type TextCitation struct { Type TextCitationType `json:"type"` diff --git a/schema/claude/messages_meta.go b/schema/claude/message_meta.go similarity index 100% rename from schema/claude/messages_meta.go rename to schema/claude/message_meta.go diff --git a/schema/openai/types.go b/schema/openai/consts.go similarity index 100% rename from schema/openai/types.go rename to schema/openai/consts.go diff --git a/schema/openai/annotation.go b/schema/openai/content_block.go similarity index 94% rename from schema/openai/annotation.go rename to schema/openai/content_block.go index ad4b6b91f..5135964b7 100644 --- a/schema/openai/annotation.go +++ b/schema/openai/content_block.go @@ -16,6 +16,11 @@ package openai +type OutputText struct { + ItemID string `json:"item_id"` + Annotations []*TextAnnotation `json:"annotations"` +} + type TextAnnotation struct { Type TextAnnotationType `json:"type"` From d9c9b13201216df164ab09b525375cfd2ae37a26 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 17:52:22 +0800 Subject: [PATCH 05/28] feat: improve extension type name (#578) --- schema/agentic_message.go | 14 +++++++------- schema/claude/content_block.go | 2 +- .../claude/{message_meta.go => response_meta.go} | 2 +- schema/gemini/response_meta.go | 2 +- schema/openai/content_block.go | 2 +- schema/openai/response_meta.go | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) rename schema/claude/{message_meta.go => 
response_meta.go} (95%) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 09ab28e40..3530038e2 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -68,10 +68,10 @@ type AgenticMessage struct { type AgenticResponseMeta struct { TokenUsage *TokenUsage - OpenAIExtensions *openai.ResponseMeta - GeminiExtensions *gemini.ResponseMeta - ClaudeExtensions *claude.MessageMeta - Extensions any + OpenAIExtension *openai.ResponseMetaExtension + GeminiExtension *gemini.ResponseMetaExtension + ClaudeExtension *claude.ResponseMetaExtension + Extension any } type StreamMeta struct { @@ -166,9 +166,9 @@ type UserInputFile struct { type AssistantGenText struct { Text string - OpenAIExtensions *openai.OutputText - ClaudeExtensions *claude.TextBlock - Extensions any + OpenAIExtension *openai.AssistantGenTextExtension + ClaudeExtension *claude.AssistantGenTextExtension + Extension any // Extra stores additional information. Extra map[string]any diff --git a/schema/claude/content_block.go b/schema/claude/content_block.go index ba297126e..4421db807 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/content_block.go @@ -16,7 +16,7 @@ package claude -type TextBlock struct { +type AssistantGenTextExtension struct { Citations []*TextCitation `json:"citations"` } diff --git a/schema/claude/message_meta.go b/schema/claude/response_meta.go similarity index 95% rename from schema/claude/message_meta.go rename to schema/claude/response_meta.go index a72dded2a..7d9dbe740 100644 --- a/schema/claude/message_meta.go +++ b/schema/claude/response_meta.go @@ -16,7 +16,7 @@ package claude -type MessageMeta struct { +type ResponseMetaExtension struct { ID string `json:"id"` StopReason string `json:"stop_reason"` } diff --git a/schema/gemini/response_meta.go b/schema/gemini/response_meta.go index 3bc590a72..bb4af92c9 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/response_meta.go @@ -16,7 +16,7 @@ package gemini -type ResponseMeta 
struct { +type ResponseMetaExtension struct { ID string `json:"id"` FinishReason string `json:"finish_reason"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index 5135964b7..dfa83109a 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -16,7 +16,7 @@ package openai -type OutputText struct { +type AssistantGenTextExtension struct { ItemID string `json:"item_id"` Annotations []*TextAnnotation `json:"annotations"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 1b184073e..809fbb1a5 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -16,7 +16,7 @@ package openai -type ResponseMeta struct { +type ResponseMetaExtension struct { ID string `json:"id"` Status string `json:"status"` Error *ResponseError `json:"error,omitempty"` From e6fe981d028dec30a704fd1cdd633d1e134cbcdf Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 20:23:14 +0800 Subject: [PATCH 06/28] feat: modify package name (#579) --- components/{agency => agentic}/callback_extra.go | 2 +- components/{agency => agentic}/interface.go | 6 +++--- components/{agency => agentic}/option.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) rename components/{agency => agentic}/callback_extra.go (99%) rename components/{agency => agentic}/interface.go (89%) rename components/{agency => agentic}/option.go (99%) diff --git a/components/agency/callback_extra.go b/components/agentic/callback_extra.go similarity index 99% rename from components/agency/callback_extra.go rename to components/agentic/callback_extra.go index 984756d1a..f824750f9 100644 --- a/components/agency/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package agency +package agentic import ( "github.com/cloudwego/eino/callbacks" diff --git a/components/agency/interface.go b/components/agentic/interface.go similarity index 89% rename from components/agency/interface.go rename to components/agentic/interface.go index e33d6a933..e9960d332 100644 --- a/components/agency/interface.go +++ b/components/agentic/interface.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package agency +package agentic import ( "context" @@ -22,8 +22,8 @@ import ( "github.com/cloudwego/eino/schema" ) -type AgenticModel interface { +type Model interface { Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) - WithTools(tools []*schema.ToolInfo) (AgenticModel, error) + WithTools(tools []*schema.ToolInfo) (Model, error) } diff --git a/components/agency/option.go b/components/agentic/option.go similarity index 99% rename from components/agency/option.go rename to components/agentic/option.go index 17028f6e9..b000d6893 100644 --- a/components/agency/option.go +++ b/components/agentic/option.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package agency +package agentic // Options is the common options for the model. 
type Options struct { From 27404845f2185c6907889c6e5eb73eaac60a8b44 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 26 Nov 2025 21:06:29 +0800 Subject: [PATCH 07/28] feat: remove TokenUsage definition in CallbackOutput (#580) --- components/agentic/callback_extra.go | 31 ++++++---------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index f824750f9..f35b37779 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -21,23 +21,6 @@ import ( "github.com/cloudwego/eino/schema" ) -// TokenUsageMeta is the token usage for the model. -type TokenUsageMeta struct { - InputTokens int64 `json:"input_tokens"` - InputTokensDetails InputTokensUsageDetails `json:"input_tokens_details"` - OutputTokens int64 `json:"output_tokens"` - OutputTokensDetails OutputTokensUsageDetails `json:"output_tokens_details"` - TotalTokens int64 `json:"total_tokens"` -} - -type InputTokensUsageDetails struct { - CachedTokens int64 `json:"cached_tokens"` -} - -type OutputTokensUsageDetails struct { - ReasoningTokens int64 `json:"reasoning_tokens"` -} - // Config is the config for the model. type Config struct { // Model is the model name. @@ -50,8 +33,8 @@ type Config struct { // CallbackInput is the input for the model callback. type CallbackInput struct { - // Responses is the responses to be sent to the model. - Responses []*schema.AgenticMessage + // Messages is the messages to be sent to the model. + Messages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // Config is the config for the model. @@ -62,12 +45,10 @@ type CallbackInput struct { // CallbackOutput is the output for the model callback. type CallbackOutput struct { - // Response is the response generated by the model. - Response *schema.AgenticMessage + // Message is the message generated by the model. 
+ Message *schema.AgenticMessage // Config is the config for the model. Config *Config - // Usage is the token usage of this request. - Usage *TokenUsageMeta // Extra is the extra information for the callback. Extra map[string]any } @@ -79,7 +60,7 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return t case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage return &CallbackInput{ - Responses: t, + Messages: t, } default: return nil @@ -93,7 +74,7 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return t case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage return &CallbackOutput{ - Response: t, + Message: t, } default: return nil From 5780c79d2dd316513b8eb9504647b5a4af170086 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 11:51:09 +0800 Subject: [PATCH 08/28] feat: add helper functions for AgenticMessage (#582) --- components/agentic/option.go | 62 +++++++++++++++++ schema/agentic_message.go | 119 ++++++++++++++++++++++++++++----- schema/openai/content_block.go | 1 - 3 files changed, 164 insertions(+), 18 deletions(-) diff --git a/components/agentic/option.go b/components/agentic/option.go index b000d6893..dae6139be 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -16,8 +16,22 @@ package agentic +import ( + "github.com/cloudwego/eino/schema" +) + // Options is the common options for the model. type Options struct { + // Temperature is the temperature for the model, which controls the randomness of the model. + Temperature *float32 + // Model is the model name. + Model *string + // TopP is the top p for the model, which controls the diversity of the model. 
+ TopP *float32 + // Tools is a list of tools the model may call. + Tools []*schema.ToolInfo + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice } // Option is the call option for ChatModel component. @@ -27,6 +41,54 @@ type Option struct { implSpecificOptFn any } +// WithTemperature is the option to set the temperature for the model. +func WithTemperature(temperature float32) Option { + return Option{ + apply: func(opts *Options) { + opts.Temperature = &temperature + }, + } +} + +// WithModel is the option to set the model name. +func WithModel(name string) Option { + return Option{ + apply: func(opts *Options) { + opts.Model = &name + }, + } +} + +// WithTopP is the option to set the top p for the model. +func WithTopP(topP float32) Option { + return Option{ + apply: func(opts *Options) { + opts.TopP = &topP + }, + } +} + +// WithTools is the option to set tools for the model. +func WithTools(tools []*schema.ToolInfo) Option { + if tools == nil { + tools = []*schema.ToolInfo{} + } + return Option{ + apply: func(opts *Options) { + opts.Tools = tools + }, + } +} + +// WithToolChoice is the option to set tool choice for the model. +func WithToolChoice(toolChoice schema.ToolChoice) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolChoice = &toolChoice + }, + } +} + // WrapImplSpecificOptFn is the option to wrap the implementation specific option function. 
func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { return Option{ diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 3530038e2..3953fac7c 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -126,8 +126,8 @@ type UserInputText struct { } type UserInputImage struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string Detail ImageURLDetail @@ -136,8 +136,8 @@ type UserInputImage struct { } type UserInputAudio struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -145,8 +145,8 @@ type UserInputAudio struct { } type UserInputVideo struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -154,9 +154,9 @@ type UserInputVideo struct { } type UserInputFile struct { - URL *string - Name *string - Base64Data *string + URL string + Name string + Base64Data string MIMEType string // Extra stores additional information. @@ -175,8 +175,8 @@ type AssistantGenText struct { } type AssistantGenImage struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -184,8 +184,8 @@ type AssistantGenImage struct { } type AssistantGenAudio struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -193,8 +193,8 @@ type AssistantGenAudio struct { } type AssistantGenVideo struct { - URL *string - Base64Data *string + URL string + Base64Data string MIMEType string // Extra stores additional information. @@ -290,6 +290,8 @@ type MCPToolCall struct { } type MCPToolResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string // CallID is the unique ID of the tool call. CallID string // Name is the name of the tool to run. 
@@ -304,7 +306,7 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code int64 + Code *int64 Error string } @@ -312,7 +314,7 @@ type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string // Tools is the list of tools available on the server. - Tools []MCPListToolsItem + Tools []*MCPListToolsItem // Error returned when the server fails to list tools. Error string @@ -355,3 +357,86 @@ type MCPToolApprovalResponse struct { // Extra stores additional information. Extra map[string]any } + +// DeveloperAgenticMessage represents a message with AgenticRoleType "developer". +func DeveloperAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeDeveloper, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// SystemAgenticMessage represents a message with AgenticRoleType "system". +func SystemAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// UserAgenticMessage represents a message with AgenticRoleType "user". +func UserAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// FunctionToolResultAgenticMessage represents a function tool result message with AgenticRoleType "user". 
+func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + NewContentBlock(&FunctionToolResult{ + CallID: callID, + Name: name, + Result: result, + }), + }, + } +} + +func NewContentBlock(block any) *ContentBlock { + switch b := block.(type) { + case *Reasoning: + return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} + case *UserInputText: + return &ContentBlock{Type: ContentBlockTypeUserInputText, UserInputText: b} + case *UserInputImage: + return &ContentBlock{Type: ContentBlockTypeUserInputImage, UserInputImage: b} + case *UserInputAudio: + return &ContentBlock{Type: ContentBlockTypeUserInputAudio, UserInputAudio: b} + case *UserInputVideo: + return &ContentBlock{Type: ContentBlockTypeUserInputVideo, UserInputVideo: b} + case *UserInputFile: + return &ContentBlock{Type: ContentBlockTypeUserInputFile, UserInputFile: b} + case *AssistantGenText: + return &ContentBlock{Type: ContentBlockTypeAssistantGenText, AssistantGenText: b} + case *AssistantGenImage: + return &ContentBlock{Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: b} + case *AssistantGenAudio: + return &ContentBlock{Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: b} + case *AssistantGenVideo: + return &ContentBlock{Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: b} + case *FunctionToolCall: + return &ContentBlock{Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: b} + case *FunctionToolResult: + return &ContentBlock{Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: b} + case *ServerToolCall: + return &ContentBlock{Type: ContentBlockTypeServerToolCall, ServerToolCall: b} + case *ServerToolResult: + return &ContentBlock{Type: ContentBlockTypeServerToolResult, ServerToolResult: b} + case *MCPToolCall: + return &ContentBlock{Type: ContentBlockTypeMCPToolCall, MCPToolCall: b} + case *MCPToolResult: + return 
&ContentBlock{Type: ContentBlockTypeMCPToolResult, MCPToolResult: b} + case *MCPListToolsResult: + return &ContentBlock{Type: ContentBlockTypeMCPListTools, MCPListToolsResult: b} + case *MCPToolApprovalRequest: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: b} + case *MCPToolApprovalResponse: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: b} + default: + return nil + } +} diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index dfa83109a..b0408e310 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -17,7 +17,6 @@ package openai type AssistantGenTextExtension struct { - ItemID string `json:"item_id"` Annotations []*TextAnnotation `json:"annotations"` } From 06efb85386662dd1770d9ba78d13b155960d1a4f Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 12:08:12 +0800 Subject: [PATCH 09/28] feat: improve MCPToolCallError definition (#592) --- schema/agentic_message.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 3953fac7c..44fc37bb3 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -306,8 +306,8 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code *int64 - Error string + Code *int64 + Message string } type MCPListToolsResult struct { From da12054570be5c3121565a8e45ddace67c41ba69 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 14:15:54 +0800 Subject: [PATCH 10/28] feat: improve Options definition (#593) --- components/agentic/option.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/components/agentic/option.go b/components/agentic/option.go index dae6139be..ac117ddb4 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -23,11 +23,11 @@ import ( // Options is the common options for the model. 
type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. - Temperature *float32 + Temperature *float64 // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. - TopP *float32 + TopP *float64 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo // ToolChoice controls which tool is called by the model. @@ -42,7 +42,7 @@ type Option struct { } // WithTemperature is the option to set the temperature for the model. -func WithTemperature(temperature float32) Option { +func WithTemperature(temperature float64) Option { return Option{ apply: func(opts *Options) { opts.Temperature = &temperature @@ -60,7 +60,7 @@ func WithModel(name string) Option { } // WithTopP is the option to set the top p for the model. -func WithTopP(topP float32) Option { +func WithTopP(topP float64) Option { return Option{ apply: func(opts *Options) { opts.TopP = &topP From dc403dd46870eb305b743d97dbfa305f04a3d6ee Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 14:41:51 +0800 Subject: [PATCH 11/28] feat: add CallbackInput definition for CallbackInput (#594) --- components/agentic/callback_extra.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index f35b37779..389408d33 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -37,6 +37,8 @@ type CallbackInput struct { Messages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice // Config is the config for the model. Config *Config // Extra is the extra information for the callback. 
From 28e8af0e179d4bab00c662af8fd03cb540cdc4f8 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 1 Dec 2025 17:01:51 +0800 Subject: [PATCH 12/28] feat: define 'omitempty' flag in json tag (#595) --- schema/agentic_message.go | 4 ++-- schema/claude/content_block.go | 42 +++++++++++++++++----------------- schema/claude/response_meta.go | 4 ++-- schema/gemini/response_meta.go | 4 ++-- schema/openai/content_block.go | 32 +++++++++++++------------- schema/openai/response_meta.go | 10 ++++---- 6 files changed, 48 insertions(+), 48 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 44fc37bb3..367debd97 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -76,7 +76,7 @@ type AgenticResponseMeta struct { type StreamMeta struct { // Index specifies the index position of this block in the final response. - Index int + Index int64 } type ContentBlock struct { @@ -213,7 +213,7 @@ type Reasoning struct { type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. 
- Index int + Index int64 Text string } diff --git a/schema/claude/content_block.go b/schema/claude/content_block.go index 4421db807..0c43d1045 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/content_block.go @@ -17,11 +17,11 @@ package claude type AssistantGenTextExtension struct { - Citations []*TextCitation `json:"citations"` + Citations []*TextCitation `json:"citations,omitempty"` } type TextCitation struct { - Type TextCitationType `json:"type"` + Type TextCitationType `json:"type,omitempty"` CharLocation *CitationCharLocation `json:"char_location,omitempty"` PageLocation *CitationPageLocation `json:"page_location,omitempty"` @@ -30,40 +30,40 @@ type TextCitation struct { } type CitationCharLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 `json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartCharIndex int64 `json:"start_char_index"` - EndCharIndex int64 `json:"end_char_index"` + StartCharIndex int64 `json:"start_char_index,omitempty"` + EndCharIndex int64 `json:"end_char_index,omitempty"` } type CitationPageLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 `json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartPageNumber int64 `json:"start_page_number"` - EndPageNumber int64 `json:"end_page_number"` + StartPageNumber int64 `json:"start_page_number,omitempty"` + EndPageNumber int64 `json:"end_page_number,omitempty"` } type CitationContentBlockLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - DocumentTitle string `json:"document_title"` - DocumentIndex int64 
`json:"document_index"` + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int64 `json:"document_index,omitempty"` - StartBlockIndex int64 `json:"start_block_index"` - EndBlockIndex int64 `json:"end_block_index"` + StartBlockIndex int64 `json:"start_block_index,omitempty"` + EndBlockIndex int64 `json:"end_block_index,omitempty"` } type CitationWebSearchResultLocation struct { - CitedText string `json:"cited_text"` + CitedText string `json:"cited_text,omitempty"` - Title string `json:"title"` - URL string `json:"url"` + Title string `json:"title,omitempty"` + URL string `json:"url,omitempty"` - EncryptedIndex string `json:"encrypted_index"` + EncryptedIndex string `json:"encrypted_index,omitempty"` } diff --git a/schema/claude/response_meta.go b/schema/claude/response_meta.go index 7d9dbe740..9f60dd713 100644 --- a/schema/claude/response_meta.go +++ b/schema/claude/response_meta.go @@ -17,6 +17,6 @@ package claude type ResponseMetaExtension struct { - ID string `json:"id"` - StopReason string `json:"stop_reason"` + ID string `json:"id,omitempty"` + StopReason string `json:"stop_reason,omitempty"` } diff --git a/schema/gemini/response_meta.go b/schema/gemini/response_meta.go index bb4af92c9..a5b3f626c 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/response_meta.go @@ -17,8 +17,8 @@ package gemini type ResponseMetaExtension struct { - ID string `json:"id"` - FinishReason string `json:"finish_reason"` + ID string `json:"id,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` } diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go index b0408e310..5d92be8f7 100644 --- a/schema/openai/content_block.go +++ b/schema/openai/content_block.go @@ -17,11 +17,11 @@ package openai type AssistantGenTextExtension struct { - Annotations []*TextAnnotation `json:"annotations"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` } 
type TextAnnotation struct { - Type TextAnnotationType `json:"type"` + Type TextAnnotationType `json:"type,omitempty"` FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` @@ -31,45 +31,45 @@ type TextAnnotation struct { type TextAnnotationFileCitation struct { // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The filename of the file cited. - Filename string `json:"filename"` + Filename string `json:"filename,omitempty"` // The index of the file in the list of files. - Index int64 `json:"index"` + Index int64 `json:"index,omitempty"` } type TextAnnotationURLCitation struct { // The title of the web resource. - Title string `json:"title"` + Title string `json:"title,omitempty"` // The URL of the web resource. - URL string `json:"url"` + URL string `json:"url,omitempty"` // The index of the first character of the URL citation in the message. - StartIndex int64 `json:"start_index"` + StartIndex int64 `json:"start_index,omitempty"` // The index of the last character of the URL citation in the message. - EndIndex int64 `json:"end_index"` + EndIndex int64 `json:"end_index,omitempty"` } type TextAnnotationContainerFileCitation struct { // The ID of the container file. - ContainerID string `json:"container_id"` + ContainerID string `json:"container_id,omitempty"` // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The filename of the container file cited. - Filename string `json:"filename"` + Filename string `json:"filename,omitempty"` // The index of the first character of the container file citation in the message. - StartIndex int64 `json:"start_index"` + StartIndex int64 `json:"start_index,omitempty"` // The index of the last character of the container file citation in the message. 
- EndIndex int64 `json:"end_index"` + EndIndex int64 `json:"end_index,omitempty"` } type TextAnnotationFilePath struct { // The ID of the file. - FileID string `json:"file_id"` + FileID string `json:"file_id,omitempty"` // The index of the file in the list of files. - Index int64 `json:"index"` + Index int64 `json:"index,omitempty"` } diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 809fbb1a5..90e884173 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -17,17 +17,17 @@ package openai type ResponseMetaExtension struct { - ID string `json:"id"` - Status string `json:"status"` + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` Error *ResponseError `json:"error,omitempty"` IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` } type ResponseError struct { - Code string `json:"code"` - Message string `json:"message"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` } type IncompleteDetails struct { - Reason string `json:"reason"` + Reason string `json:"reason,omitempty"` } From 1d76bc1a8b37d78ca815db0aa3e78de60f338851 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 2 Dec 2025 19:09:01 +0800 Subject: [PATCH 13/28] fix: MCPToolApprovalRequest definition (#600) --- schema/agentic_message.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 367debd97..93dd817ca 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -42,7 +42,7 @@ const ( ContentBlockTypeServerToolResult ContentBlockType = "server_tool_result" ContentBlockTypeMCPToolCall ContentBlockType = "mcp_tool_call" ContentBlockTypeMCPToolResult ContentBlockType = "mcp_tool_result" - ContentBlockTypeMCPListTools ContentBlockType = "mcp_list_tools" + ContentBlockTypeMCPListToolsResult ContentBlockType = "mcp_list_tools_result" ContentBlockTypeMCPToolApprovalRequest 
ContentBlockType = "mcp_tool_approval_request" ContentBlockTypeMCPToolApprovalResponse ContentBlockType = "mcp_tool_approval_response" ) @@ -332,8 +332,8 @@ type MCPListToolsItem struct { } type MCPToolApprovalRequest struct { - // CallID is the unique ID of the tool call. - CallID string + // ID is the approval request ID. + ID string // Name is the name of the tool to run. Name string // Arguments is the JSON string arguments for the tool call. @@ -431,7 +431,7 @@ func NewContentBlock(block any) *ContentBlock { case *MCPToolResult: return &ContentBlock{Type: ContentBlockTypeMCPToolResult, MCPToolResult: b} case *MCPListToolsResult: - return &ContentBlock{Type: ContentBlockTypeMCPListTools, MCPListToolsResult: b} + return &ContentBlock{Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: b} case *MCPToolApprovalRequest: return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: b} case *MCPToolApprovalResponse: From 52e8871e053ceffbe9f6b367326032423dd6eb91 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Wed, 3 Dec 2025 15:24:00 +0800 Subject: [PATCH 14/28] feat: define StreamResponseError for openai (#601) --- schema/openai/response_meta.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go index 90e884173..e1933065b 100644 --- a/schema/openai/response_meta.go +++ b/schema/openai/response_meta.go @@ -17,10 +17,11 @@ package openai type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - Status string `json:"status,omitempty"` - Error *ResponseError `json:"error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + StreamError *StreamResponseError `json:"stream_error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` } type 
ResponseError struct { @@ -28,6 +29,12 @@ type ResponseError struct { Message string `json:"message,omitempty"` } +type StreamResponseError struct { + Code string + Message string + Param string +} + type IncompleteDetails struct { Reason string `json:"reason,omitempty"` } From feb9e1dd6d26f4637bbb36ae870dc8871700d00d Mon Sep 17 00:00:00 2001 From: Megumin Date: Wed, 3 Dec 2025 17:22:51 +0800 Subject: [PATCH 15/28] feat: support agentic message concat (#576) feat(agentic_model): - format print - support agentic chat template - support to compose agentic model & agentic tools node - support agentic tool node - support agentic message concat --- components/agentic/callback_extra_test.go | 35 + components/agentic/option_test.go | 79 + components/prompt/callback_extra.go | 38 + components/prompt/chat_template_agentic.go | 84 + .../prompt/chat_template_agentic_test.go | 111 ++ components/prompt/interface.go | 5 + components/types.go | 1 + compose/chain.go | 43 +- compose/chain_branch.go | 49 +- compose/chain_parallel.go | 43 + compose/component_to_graph_node.go | 33 + compose/graph.go | 40 +- compose/tools_node_agentic.go | 125 ++ compose/tools_node_agentic_test.go | 244 +++ compose/types.go | 15 +- internal/concat.go | 6 +- schema/agentic_message.go | 1352 ++++++++++++++++ schema/agentic_message_test.go | 1381 +++++++++++++++++ schema/message.go | 69 +- 19 files changed, 3707 insertions(+), 46 deletions(-) create mode 100644 components/agentic/callback_extra_test.go create mode 100644 components/agentic/option_test.go create mode 100644 components/prompt/chat_template_agentic.go create mode 100644 components/prompt/chat_template_agentic_test.go create mode 100644 compose/tools_node_agentic.go create mode 100644 compose/tools_node_agentic_test.go create mode 100644 schema/agentic_message_test.go diff --git a/components/agentic/callback_extra_test.go b/components/agentic/callback_extra_test.go new file mode 100644 index 000000000..a77da6cd2 --- /dev/null +++
b/components/agentic/callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package agentic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvModel(t *testing.T) { + assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput([]*schema.AgenticMessage{})) + assert.Nil(t, ConvCallbackInput("asd")) + + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&schema.AgenticMessage{})) + assert.Nil(t, ConvCallbackOutput("asd")) +} diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go new file mode 100644 index 000000000..d349f35ac --- /dev/null +++ b/components/agentic/option_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package agentic + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestCommon(t *testing.T) { + o := GetCommonOptions(nil, + WithTools([]*schema.ToolInfo{{Name: "test"}}), + WithModel("test"), + WithTemperature(0.1), + WithToolChoice(schema.ToolChoiceAllowed), + WithTopP(0.1), + ) + assert.Len(t, o.Tools, 1) + assert.Equal(t, "test", o.Tools[0].Name) + assert.Equal(t, "test", *o.Model) + assert.Equal(t, float64(0.1), *o.Temperature) + assert.Equal(t, schema.ToolChoiceAllowed, *o.ToolChoice) + assert.Equal(t, float64(0.1), *o.TopP) +} + +func TestImplSpecificOpts(t *testing.T) { + type implSpecificOptions struct { + conf string + index int + } + + withConf := func(conf string) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.conf = conf + } + } + + withIndex := func(index int) func(o *implSpecificOptions) { + return func(o *implSpecificOptions) { + o.index = index + } + } + + documentOption1 := WrapImplSpecificOptFn(withConf("test_conf")) + documentOption2 := WrapImplSpecificOptFn(withIndex(1)) + + implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) + + assert.Equal(t, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }, implSpecificOpts) + documentOption1 = WrapImplSpecificOptFn(withConf("test_conf")) + documentOption2 = WrapImplSpecificOptFn(withIndex(1)) + + implSpecificOpts = GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) + + assert.Equal(t, &implSpecificOptions{ + conf: "test_conf", + index: 1, + }, implSpecificOpts) +} diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 324a418f3..ff5c3a8ff 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -21,6 +21,44 @@ import ( "github.com/cloudwego/eino/schema" ) +type AgenticCallbackInput struct { + Variables map[string]any + Templates 
[]schema.AgenticMessagesTemplate + Extra map[string]any +} + +type AgenticCallbackOutput struct { + Result []*schema.AgenticMessage + Templates []schema.AgenticMessagesTemplate + Extra map[string]any +} + +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + return t + case map[string]any: + return &AgenticCallbackInput{ + Variables: t, + } + default: + return nil + } +} + +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + return t + case []*schema.AgenticMessage: + return &AgenticCallbackOutput{ + Result: t, + } + default: + return nil + } +} + // CallbackInput is the input for the callback. type CallbackInput struct { // Variables is the variables for the callback. diff --git a/components/prompt/chat_template_agentic.go b/components/prompt/chat_template_agentic.go new file mode 100644 index 000000000..937d46f26 --- /dev/null +++ b/components/prompt/chat_template_agentic.go @@ -0,0 +1,84 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/schema" +) + +// FromAgenticMessages creates a new DefaultAgenticChatTemplate from the given templates and format type. +// eg. 
+// +// template := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// // in chain, or graph +// chain := compose.NewChain[map[string]any, []*schema.AgenticMessage]() +// chain.AppendAgenticChatTemplate(template) +func FromAgenticMessages(formatType schema.FormatType, templates ...schema.AgenticMessagesTemplate) *DefaultAgenticChatTemplate { + return &DefaultAgenticChatTemplate{ + templates: templates, + formatType: formatType, + } +} + +type DefaultAgenticChatTemplate struct { + templates []schema.AgenticMessagesTemplate + formatType schema.FormatType +} + +func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { + ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) + ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ + Variables: vs, + Templates: t.templates, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + } + }() + + result = make([]*schema.AgenticMessage, 0, len(t.templates)) + for _, template := range t.templates { + msgs, err := template.Format(ctx, vs, t.formatType) + if err != nil { + return nil, err + } + + result = append(result, msgs...) + } + + _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ + Result: result, + Templates: t.templates, + }) + + return result, nil +} + +// GetType returns the type of the chat template (Default). +func (t *DefaultAgenticChatTemplate) GetType() string { + return "Default" +} + +// IsCallbacksEnabled checks if the callbacks are enabled for the chat template. 
+func (t *DefaultAgenticChatTemplate) IsCallbacksEnabled() bool { + return true +} diff --git a/components/prompt/chat_template_agentic_test.go b/components/prompt/chat_template_agentic_test.go new file mode 100644 index 000000000..aaa7d6405 --- /dev/null +++ b/components/prompt/chat_template_agentic_test.go @@ -0,0 +1,111 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAgenticFormat(t *testing.T) { + pyFmtTestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{context}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + jinja2TestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{context}}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + goFmtTestTemplate := []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: 
schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{.context}}"}}, + }, + }, + schema.AgenticMessagesPlaceholder("chat_history", true), + } + testValues := map[string]any{ + "context": "it's beautiful day", + "chat_history": []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, + }, + }, + }, + } + expected := []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "it's beautiful day"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, + }, + }, + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, + }, + }, + } + + // FString + chatTemplate := FromAgenticMessages(schema.FString, pyFmtTestTemplate...) + msgs, err := chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) + + // Jinja2 + chatTemplate = FromAgenticMessages(schema.Jinja2, jinja2TestTemplate...) + msgs, err = chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) + + // GoTemplate + chatTemplate = FromAgenticMessages(schema.GoTemplate, goFmtTestTemplate...) 
+ msgs, err = chatTemplate.Format(context.Background(), testValues) + assert.Nil(t, err) + assert.Equal(t, expected, msgs) +} diff --git a/components/prompt/interface.go b/components/prompt/interface.go index eac695eda..7ffe7216a 100644 --- a/components/prompt/interface.go +++ b/components/prompt/interface.go @@ -23,6 +23,7 @@ import ( ) var _ ChatTemplate = &DefaultChatTemplate{} +var _ AgenticChatTemplate = &DefaultAgenticChatTemplate{} // ChatTemplate formats a variables map into a list of messages for a ChatModel. // @@ -42,3 +43,7 @@ var _ ChatTemplate = &DefaultChatTemplate{} type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } + +type AgenticChatTemplate interface { + Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.AgenticMessage, error) +} diff --git a/components/types.go b/components/types.go index a23d82a68..2ba088e93 100644 --- a/components/types.go +++ b/components/types.go @@ -66,6 +66,7 @@ type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" + ComponentOfAgenticPrompt Component = "AgenticChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" ComponentOfAgenticModel Component = "AgenticModel" diff --git a/compose/chain.go b/compose/chain.go index 5e4a8e1c0..8484e8767 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -22,6 +22,7 @@ import ( "fmt" "reflect" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -174,6 +175,18 @@ func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAdd return c } +// AppendAgenticModel add a agentic.Model node to the chain. +// e.g. 
+// +// model, err := openai.NewAgenticModel(ctx, config) +// if err != nil {...} +// chain.AppendAgenticModel(model) +func (c *Chain[I, O]) AppendAgenticModel(node agentic.Model, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticModelNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendChatTemplate add a ChatTemplate node to the chain. // eg. // @@ -189,11 +202,23 @@ func (c *Chain[I, O]) AppendChatTemplate(node prompt.ChatTemplate, opts ...Graph return c } +// AppendAgenticChatTemplate add a prompt.AgenticChatTemplate node to the chain. +// eg. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// chain.AppendAgenticChatTemplate(chatTemplate) +func (c *Chain[I, O]) AppendAgenticChatTemplate(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticChatTemplateNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendToolsNode add a ToolsNode node to the chain. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, // }) // // chain.AppendToolsNode(toolsNode) @@ -203,6 +228,20 @@ func (c *Chain[I, O]) AppendToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) return c } +// AppendAgenticToolsNode add a AgenticToolsNode node to the chain. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, +// }) +// +// chain.AppendAgenticToolsNode(toolsNode) +func (c *Chain[I, O]) AppendAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticToolsNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendDocumentTransformer add a DocumentTransformer node to the chain. // e.g. 
// diff --git a/compose/chain_branch.go b/compose/chain_branch.go index ec3a433af..004dbfac3 100644 --- a/compose/chain_branch.go +++ b/compose/chain_branch.go @@ -20,6 +20,7 @@ import ( "context" "fmt" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -146,6 +147,22 @@ func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts . return cb.addNode(key, gNode, options) } +// AddAgenticModel adds a agentic.Model node to the branch. +// eg. +// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o-mini", +// }) +// cb.AddAgenticModel("agentic_model_key_1", model1) +// cb.AddAgenticModel("agentic_model_key_2", model2) +func (cb *ChainBranch) AddAgenticModel(key string, node agentic.Model, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticModelNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddChatTemplate adds a ChatTemplate node to the branch. // eg. // @@ -167,11 +184,26 @@ func (cb *ChainBranch) AddChatTemplate(key string, node prompt.ChatTemplate, opt return cb.addNode(key, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate node to the branch. +// eg. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_01", chatTemplate) +// +// chatTemplate2, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_02", chatTemplate2) +func (cb *ChainBranch) AddAgenticChatTemplate(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticChatTemplateNode(node, opts...) 
+ return cb.addNode(key, gNode, options) +} + // AddToolsNode adds a ToolsNode to the branch. // eg. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, // }) // // cb.AddToolsNode("tools_node_key", toolsNode) @@ -180,6 +212,19 @@ func (cb *ChainBranch) AddToolsNode(key string, node *ToolsNode, opts ...GraphAd return cb.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a AgenticToolsNode to the branch. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, +// }) +// +// cb.AddAgenticToolsNode("tools_node_key", toolsNode) +func (cb *ChainBranch) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticToolsNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddLambda adds a Lambda node to the branch. // eg. // diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go index 64cdf2db1..128ed4a26 100644 --- a/compose/chain_parallel.go +++ b/compose/chain_parallel.go @@ -19,6 +19,7 @@ package compose import ( "fmt" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -70,6 +71,24 @@ func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts return p.addNode(outputKey, gNode, options) } +// AddAgenticModel adds a agentic.Model to the parallel. +// eg. 
+// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// p.AddAgenticModel("output_key1", model1) +// p.AddAgenticModel("output_key2", model2) +func (p *Parallel) AddAgenticModel(outputKey string, node agentic.Model, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticModelNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddChatTemplate adds a chat template to the parallel. // eg. // @@ -84,6 +103,17 @@ func (p *Parallel) AddChatTemplate(outputKey string, node prompt.ChatTemplate, o return p.addNode(outputKey, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate to the parallel. +// eg. +// +// chatTemplate01, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// p.AddAgenticChatTemplate("output_key01", chatTemplate01) +func (p *Parallel) AddAgenticChatTemplate(outputKey string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticChatTemplateNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddToolsNode adds a tools node to the parallel. // eg. // @@ -97,6 +127,19 @@ func (p *Parallel) AddToolsNode(outputKey string, node *ToolsNode, opts ...Graph return p.addNode(outputKey, gNode, options) } +// AddAgenticToolsNode adds a tools node to the parallel. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, +// }) +// +// p.AddAgenticToolsNode("output_key01", toolsNode) +func (p *Parallel) AddAgenticToolsNode(outputKey string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticToolsNode(node, append(opts, WithOutputKey(outputKey))...) 
+ return p.addNode(outputKey, gNode, options) +} + // AddLambda adds a lambda node to the parallel. // eg. // diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go index ab4694f1a..e64ce4f19 100644 --- a/compose/component_to_graph_node.go +++ b/compose/component_to_graph_node.go @@ -18,6 +18,7 @@ package compose import ( "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -101,6 +102,17 @@ func toChatModelNode(node model.BaseChatModel, opts ...GraphAddNodeOpt) (*graphN opts...) } +func toAgenticModelNode(node agentic.Model, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticModel, + node.Generate, + node.Stream, + nil, nil, + opts..., + ) +} + func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -112,6 +124,16 @@ func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*gra opts...) } +func toAgenticChatTemplateNode(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticPrompt, + node.Format, + nil, nil, nil, + opts..., + ) +} + func toDocumentTransformerNode(node document.Transformer, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -134,6 +156,17 @@ func toToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAd opts...) 
} +func toAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + ComponentOfAgenticToolsNode, + node.Invoke, + node.Stream, + nil, nil, + opts..., + ) +} + func toLambdaNode(node *Lambda, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { info, options := getNodeInfo(opts...) diff --git a/compose/graph.go b/compose/graph.go index 9370665f0..877b8fb42 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -23,6 +23,7 @@ import ( "reflect" "strings" + "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -352,6 +353,19 @@ func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...G return g.addNode(key, gNode, options) } +// AddAgenticModelNode add node that implements agentic.Model. +// e.g. +// +// model, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// graph.AddAgenticModelNode("agentic_model_node_key", model) +func (g *graph) AddAgenticModelNode(key string, node agentic.Model, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticModelNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddChatTemplateNode add node that implements prompt.ChatTemplate. // e.g. // @@ -366,10 +380,21 @@ func (g *graph) AddChatTemplateNode(key string, node prompt.ChatTemplate, opts . return g.addNode(key, gNode, options) } -// AddToolsNode adds a node that implements tools.ToolsNode. +// AddAgenticChatTemplateNode add node that implements prompt.AgenticChatTemplate. +// e.g. 
+// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// graph.AddAgenticChatTemplateNode("chat_template_node_key", chatTemplate) +func (g *graph) AddAgenticChatTemplateNode(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return g.addNode(key, gNode, options) +} + +// AddToolsNode adds a node that implements ToolsNode. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{}) +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{}) // // graph.AddToolsNode("tools_node_key", toolsNode) func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) error { @@ -377,6 +402,17 @@ func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOp return g.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a node that implements AgenticToolsNode. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{}) +// +// graph.AddAgenticToolsNode("tools_node_key", toolsNode) +func (g *graph) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticToolsNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddDocumentTransformerNode adds a node that implements document.Transformer. // e.g. // diff --git a/compose/tools_node_agentic.go b/compose/tools_node_agentic.go new file mode 100644 index 000000000..38c5c89de --- /dev/null +++ b/compose/tools_node_agentic.go @@ -0,0 +1,125 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// NewAgenticToolsNode creates a new AgenticToolsNode. +// e.g. +// +// conf := &ToolsNodeConfig{ +// Tools: []tool.BaseTool{invokableTool1, streamableTool2}, +// } +// toolsNode, err := NewAgenticToolsNode(ctx, conf) +func NewAgenticToolsNode(ctx context.Context, conf *ToolsNodeConfig) (*AgenticToolsNode, error) { + tn, err := NewToolNode(ctx, conf) + if err != nil { + return nil, err + } + return &AgenticToolsNode{inner: tn}, nil +} + +type AgenticToolsNode struct { + inner *ToolsNode +} + +func (a *AgenticToolsNode) Invoke(ctx context.Context, input *schema.AgenticMessage, opts ...ToolsNodeOption) ([]*schema.AgenticMessage, error) { + result, err := a.inner.Invoke(ctx, agenticMessageToToolCallMessage(input), opts...) + if err != nil { + return nil, err + } + return toolMessageToAgenticMessage(result), nil +} + +func (a *AgenticToolsNode) Stream(ctx context.Context, input *schema.AgenticMessage, + opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.AgenticMessage], error) { + result, err := a.inner.Stream(ctx, agenticMessageToToolCallMessage(input), opts...) 
+ if err != nil { + return nil, err + } + return streamToolMessageToAgenticMessage(result), nil +} + +func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Message { + var tc []schema.ToolCall + for _, block := range input.ContentBlocks { + if block.Type != schema.ContentBlockTypeFunctionToolCall || block.FunctionToolCall == nil { + continue + } + tc = append(tc, schema.ToolCall{ + ID: block.FunctionToolCall.CallID, + Function: schema.FunctionCall{ + Name: block.FunctionToolCall.Name, + Arguments: block.FunctionToolCall.Arguments, + }, + }) + } + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: tc, + } +} + +func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessage { + var results []*schema.ContentBlock + for _, m := range input { + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + Extra: m.Extra, + }, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }} +} + +func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Message]) *schema.StreamReader[[]*schema.AgenticMessage] { + return schema.StreamReaderWithConvert(input, func(t []*schema.Message) ([]*schema.AgenticMessage, error) { + var results []*schema.ContentBlock + for i, m := range t { + if m == nil { + continue + } + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + Extra: m.Extra, + }, + StreamMeta: &schema.StreamMeta{Index: int64(i)}, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }}, nil + }) +} + +func (a *AgenticToolsNode) GetType() string { return "" } diff --git 
a/compose/tools_node_agentic_test.go b/compose/tools_node_agentic_test.go new file mode 100644 index 000000000..dcd3177a9 --- /dev/null +++ b/compose/tools_node_agentic_test.go @@ -0,0 +1,244 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAgenticMessageToToolCallMessage(t *testing.T) { + input := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "1", + Name: "name1", + Arguments: "arg1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "2", + Name: "name2", + Arguments: "arg2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "3", + Name: "name3", + Arguments: "arg3", + }, + }, + }, + } + ret := agenticMessageToToolCallMessage(input) + assert.Equal(t, schema.Assistant, ret.Role) + assert.Equal(t, []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "name1", + Arguments: "arg1", + }, + }, + { + ID: "2", + Function: schema.FunctionCall{ + Name: "name2", + Arguments: "arg2", + }, + }, + { + ID: "3", + Function: schema.FunctionCall{ + Name: "name3", + Arguments: "arg3", + }, + }, + }, 
ret.ToolCalls) +} + +func TestToolMessageToAgenticMessage(t *testing.T) { + input := []*schema.Message{ + { + Role: schema.Tool, + Content: "content1", + ToolCallID: "1", + ToolName: "name1", + }, + { + Role: schema.Tool, + Content: "content2", + ToolCallID: "2", + ToolName: "name2", + }, + { + Role: schema.Tool, + Content: "content3", + ToolCallID: "3", + ToolName: "name3", + }, + } + ret := toolMessageToAgenticMessage(input) + assert.Equal(t, 1, len(ret)) + assert.Equal(t, schema.AgenticRoleTypeUser, ret[0].Role) + assert.Equal(t, []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3", + }, + }, + }, ret[0].ContentBlocks) +} + +func TestStreamToolMessageToAgenticMessage(t *testing.T) { + input := schema.StreamReaderFromArray([][]*schema.Message{ + { + { + Role: schema.Tool, + Content: "content1-1", + ToolName: "name1", + ToolCallID: "1", + }, + nil, nil, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-1", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + { + Role: schema.Tool, + Content: "content1-2", + ToolName: "name2", + ToolCallID: "2", + }, + nil, nil, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-1", + ToolName: "name3", + ToolCallID: "3", + }, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-2", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-2", + ToolName: "name3", + ToolCallID: "3", + }, + }, + }) + ret := streamToolMessageToAgenticMessage(input) + var chunks 
[][]*schema.AgenticMessage + for { + chunk, err := ret.Recv() + if err == io.EOF { + break + } + assert.NoError(t, err) + chunks = append(chunks, chunk) + } + result, err := schema.ConcatAgenticMessagesArray(chunks) + assert.NoError(t, err) + assert.Equal(t, []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1-1content1-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 0}, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2-1content2-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 1}, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3-1content3-2", + Extra: map[string]interface{}{}, + }, + StreamMeta: &schema.StreamMeta{Index: 2}, + }, + }, + }, + }, result) +} diff --git a/compose/types.go b/compose/types.go index 13d925df2..54f8e2be3 100644 --- a/compose/types.go +++ b/compose/types.go @@ -25,13 +25,14 @@ type component = components.Component // built-in component types in graph node. // it represents the type of the most primitive executable object provided by the user. 
const ( - ComponentOfUnknown component = "Unknown" - ComponentOfGraph component = "Graph" - ComponentOfWorkflow component = "Workflow" - ComponentOfChain component = "Chain" - ComponentOfPassthrough component = "Passthrough" - ComponentOfToolsNode component = "ToolsNode" - ComponentOfLambda component = "Lambda" + ComponentOfUnknown component = "Unknown" + ComponentOfGraph component = "Graph" + ComponentOfWorkflow component = "Workflow" + ComponentOfChain component = "Chain" + ComponentOfPassthrough component = "Passthrough" + ComponentOfToolsNode component = "ToolsNode" + ComponentOfAgenticToolsNode component = "AgenticToolsNode" + ComponentOfLambda component = "Lambda" ) // NodeTriggerMode controls the triggering mode of graph nodes. diff --git a/internal/concat.go b/internal/concat.go index 2681322ab..fd9b8abc5 100644 --- a/internal/concat.go +++ b/internal/concat.go @@ -99,7 +99,7 @@ func ConcatItems[T any](items []T) (T, error) { if typ.Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -158,7 +158,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { if v.Type().Elem().Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -171,7 +171,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { return ret, nil } -func concatSliceValue(val reflect.Value) (reflect.Value, error) { +func ConcatSliceValue(val reflect.Value) (reflect.Value, error) { elmType := val.Type().Elem() if val.Len() == 1 { diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 93dd817ca..2139201ec 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -17,9 +17,17 @@ package schema import ( + "context" + "fmt" + "reflect" + "strings" + "github.com/cloudwego/eino/schema/claude" "github.com/cloudwego/eino/schema/gemini" + + "github.com/cloudwego/eino/internal" 
"github.com/cloudwego/eino/schema/openai" + "github.com/eino-contrib/jsonschema" ) @@ -440,3 +448,1347 @@ func NewContentBlock(block any) *ContentBlock { return nil } } + +// AgenticMessagesTemplate is the interface for messages template. +// It's used to render a template to a list of agentic messages. +// e.g. +// +// chatTemplate := prompt.FromAgenticMessages( +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +type AgenticMessagesTemplate interface { + Format(ctx context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) +} + +var _ AgenticMessagesTemplate = &AgenticMessage{} +var _ AgenticMessagesTemplate = AgenticMessagesPlaceholder("", false) + +type agenticMessagesPlaceholder struct { + key string + optional bool +} + +// AgenticMessagesPlaceholder can render a placeholder to a list of agentic messages in params. +// e.g. 
+// +// placeholder := AgenticMessagesPlaceholder("history", false) +// params := map[string]any{ +// "history": []*schema.AgenticMessage{ +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// }, +// } +// chatTemplate := chatTpl := prompt.FromMessages( +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +func AgenticMessagesPlaceholder(key string, optional bool) AgenticMessagesTemplate { + return &agenticMessagesPlaceholder{ + key: key, + optional: optional, + } +} + +func (p *agenticMessagesPlaceholder) Format(_ context.Context, vs map[string]any, _ FormatType) ([]*AgenticMessage, error) { + v, ok := vs[p.key] + if !ok { + if p.optional { + return []*AgenticMessage{}, nil + } + + return nil, fmt.Errorf("message placeholder format: %s not found", p.key) + } + + msgs, ok := v.([]*AgenticMessage) + if !ok { + return nil, fmt.Errorf("only agentic messages can be used to format message placeholder, key: %v, actual type: %v", p.key, reflect.TypeOf(v)) + } + + return msgs, nil +} + +// Format returns the agentic messages after rendering by the given formatType. +// It formats only the user input fields (UserInputText, UserInputImage, UserInputAudio, UserInputVideo, UserInputFile). +// e.g. 
+// +// msg := &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeUser, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "hello {name}"}}, +// }, +// } +// msgs, err := msg.Format(ctx, map[string]any{"name": "eino"}, schema.FString) +// // msgs[0].ContentBlocks[0].UserInputText.Text will be "hello eino" +func (m *AgenticMessage) Format(_ context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) { + copied := *m + + if len(m.ContentBlocks) > 0 { + copiedBlocks := make([]*ContentBlock, len(m.ContentBlocks)) + for i, block := range m.ContentBlocks { + if block == nil { + copiedBlocks[i] = nil + continue + } + + copiedBlock := *block + var err error + + switch block.Type { + case ContentBlockTypeUserInputText: + if block.UserInputText != nil { + copiedBlock.UserInputText, err = formatUserInputText(block.UserInputText, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputImage: + if block.UserInputImage != nil { + copiedBlock.UserInputImage, err = formatUserInputImage(block.UserInputImage, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputAudio: + if block.UserInputAudio != nil { + copiedBlock.UserInputAudio, err = formatUserInputAudio(block.UserInputAudio, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputVideo: + if block.UserInputVideo != nil { + copiedBlock.UserInputVideo, err = formatUserInputVideo(block.UserInputVideo, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputFile: + if block.UserInputFile != nil { + copiedBlock.UserInputFile, err = formatUserInputFile(block.UserInputFile, vs, formatType) + if err != nil { + return nil, err + } + } + } + + copiedBlocks[i] = &copiedBlock + } + copied.ContentBlocks = copiedBlocks + } + + return []*AgenticMessage{&copied}, nil +} + +func 
formatUserInputText(uit *UserInputText, vs map[string]any, formatType FormatType) (*UserInputText, error) { + text, err := formatContent(uit.Text, vs, formatType) + if err != nil { + return nil, err + } + copied := *uit + copied.Text = text + return &copied, nil +} + +func formatUserInputImage(uii *UserInputImage, vs map[string]any, formatType FormatType) (*UserInputImage, error) { + copied := *uii + if uii.URL != "" { + url, err := formatContent(uii.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uii.Base64Data != "" { + base64data, err := formatContent(uii.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputAudio(uia *UserInputAudio, vs map[string]any, formatType FormatType) (*UserInputAudio, error) { + copied := *uia + if uia.URL != "" { + url, err := formatContent(uia.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uia.Base64Data != "" { + base64data, err := formatContent(uia.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputVideo(uiv *UserInputVideo, vs map[string]any, formatType FormatType) (*UserInputVideo, error) { + copied := *uiv + if uiv.URL != "" { + url, err := formatContent(uiv.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uiv.Base64Data != "" { + base64data, err := formatContent(uiv.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputFile(uif *UserInputFile, vs map[string]any, formatType FormatType) (*UserInputFile, error) { + copied := *uif + if uif.URL != "" { + url, err := formatContent(uif.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uif.Name != "" { + name, err := 
formatContent(uif.Name, vs, formatType) + if err != nil { + return nil, err + } + copied.Name = name + } + if uif.Base64Data != "" { + base64data, err := formatContent(uif.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, error) { + return buildConcatGenericArray[AgenticMessage](ConcatAgenticMessages)(mas) +} + +func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { + var ( + role AgenticRoleType + blocksList [][]*ContentBlock + blocks []*ContentBlock + metas []*AgenticResponseMeta + ) + + if len(msgs) == 1 { + return msgs[0], nil + } + + for idx, msg := range msgs { + if msg == nil { + return nil, fmt.Errorf("message at index %d is nil", idx) + } + + if msg.Role != "" { + if role == "" { + role = msg.Role + } else if role != msg.Role { + return nil, fmt.Errorf("cannot concat messages with different roles: got '%s' and '%s'", role, msg.Role) + } + } + + for _, block := range msg.ContentBlocks { + if block.StreamMeta == nil { + // Non-streaming block + if len(blocksList) > 0 { + // Cannot mix streaming and non-streaming blocks + return nil, fmt.Errorf("found non-streaming block after streaming blocks") + } + // Collect non-streaming block + blocks = append(blocks, block) + } else { + // Streaming block + if len(blocks) > 0 { + // Cannot mix non-streaming and streaming blocks + return nil, fmt.Errorf("found streaming block after non-streaming blocks") + } + // Collect streaming block by index + blocksList = expandSlice(int(block.StreamMeta.Index), blocksList) + blocksList[block.StreamMeta.Index] = append(blocksList[block.StreamMeta.Index], block) + } + } + + if msg.ResponseMeta != nil { + metas = append(metas, msg.ResponseMeta) + } + } + + meta, err := concatAgenticResponseMeta(metas) + if err != nil { + return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) + } + + if 
len(blocksList) > 0 { + // All blocks are streaming, concat each group by index + blocks = make([]*ContentBlock, len(blocksList)) + for i, bs := range blocksList { + if len(bs) == 0 { + continue + } + b, err := concatAgenticContentBlocks(bs) + if err != nil { + return nil, fmt.Errorf("failed to concat content blocks at index %d: %w", i, err) + } + blocks[i] = b + } + } + + for i := 0; i < len(blocks); i++ { + if blocks[i] == nil { + blocks = append(blocks[:i], blocks[i+1:]...) + } + } + + return &AgenticMessage{ + ResponseMeta: meta, + Role: role, + ContentBlocks: blocks, + }, nil +} + +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (*AgenticResponseMeta, error) { + if len(metas) == 0 { + return nil, nil + } + ret := &AgenticResponseMeta{ + TokenUsage: &TokenUsage{}, + OpenAIExtension: nil, + ClaudeExtension: nil, + GeminiExtension: nil, + Extension: nil, + } + for _, meta := range metas { + ret.Extension = meta.Extension + ret.OpenAIExtension = meta.OpenAIExtension + ret.ClaudeExtension = meta.ClaudeExtension + ret.GeminiExtension = meta.GeminiExtension + if meta.TokenUsage != nil { + ret.TokenUsage.CompletionTokens += meta.TokenUsage.CompletionTokens + ret.TokenUsage.CompletionTokenDetails.ReasoningTokens += meta.TokenUsage.CompletionTokenDetails.ReasoningTokens + ret.TokenUsage.PromptTokens += meta.TokenUsage.PromptTokens + ret.TokenUsage.PromptTokenDetails.CachedTokens += meta.TokenUsage.PromptTokenDetails.CachedTokens + ret.TokenUsage.TotalTokens += meta.TokenUsage.TotalTokens + } + } + return ret, nil +} + +func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { + if len(blocks) == 0 { + return nil, fmt.Errorf("no content blocks to concat") + } + blockType := blocks[0].Type + index := blocks[0].StreamMeta.Index + switch blockType { + case ContentBlockTypeReasoning: + return concatContentBlockHelper(blocks, blockType, "reasoning", + func(b *ContentBlock) *Reasoning { return b.Reasoning }, + concatReasoning, + func(r 
*Reasoning) *ContentBlock { + return &ContentBlock{Type: blockType, Reasoning: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputText: + return concatContentBlockHelper(blocks, blockType, "user input text", + func(b *ContentBlock) *UserInputText { return b.UserInputText }, + concatUserInputText, + func(t *UserInputText) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputText: t, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputImage: + return concatContentBlockHelper(blocks, blockType, "user input image", + func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, + concatUserInputImage, + func(i *UserInputImage) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputImage: i, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputAudio: + return concatContentBlockHelper(blocks, blockType, "user input audio", + func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, + concatUserInputAudio, + func(a *UserInputAudio) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputAudio: a, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputVideo: + return concatContentBlockHelper(blocks, blockType, "user input video", + func(b *ContentBlock) *UserInputVideo { return b.UserInputVideo }, + concatUserInputVideo, + func(v *UserInputVideo) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputVideo: v, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeUserInputFile: + return concatContentBlockHelper(blocks, blockType, "user input file", + func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, + concatUserInputFile, + func(f *UserInputFile) *ContentBlock { + return &ContentBlock{Type: blockType, UserInputFile: f, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenText: + return concatContentBlockHelper(blocks, blockType, "assistant gen text", + func(b 
*ContentBlock) *AssistantGenText { return b.AssistantGenText }, + concatAssistantGenText, + func(t *AssistantGenText) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenText: t, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenImage: + return concatContentBlockHelper(blocks, blockType, "assistant gen image", + func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, + concatAssistantGenImage, + func(i *AssistantGenImage) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenImage: i, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenAudio: + return concatContentBlockHelper(blocks, blockType, "assistant gen audio", + func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, + concatAssistantGenAudio, + func(a *AssistantGenAudio) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenAudio: a, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeAssistantGenVideo: + return concatContentBlockHelper(blocks, blockType, "assistant gen video", + func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, + concatAssistantGenVideo, + func(v *AssistantGenVideo) *ContentBlock { + return &ContentBlock{Type: blockType, AssistantGenVideo: v, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeFunctionToolCall: + return concatContentBlockHelper(blocks, blockType, "function tool call", + func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, + concatFunctionToolCall, + func(c *FunctionToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, FunctionToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeFunctionToolResult: + return concatContentBlockHelper(blocks, blockType, "function tool result", + func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, + concatFunctionToolResult, + func(r *FunctionToolResult) *ContentBlock { + 
return &ContentBlock{Type: blockType, FunctionToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeServerToolCall: + return concatContentBlockHelper(blocks, blockType, "server tool call", + func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, + concatServerToolCall, + func(c *ServerToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, ServerToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeServerToolResult: + return concatContentBlockHelper(blocks, blockType, "server tool result", + func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult }, + concatServerToolResult, + func(r *ServerToolResult) *ContentBlock { + return &ContentBlock{Type: blockType, ServerToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolCall: + return concatContentBlockHelper(blocks, blockType, "MCP tool call", + func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, + concatMCPToolCall, + func(c *MCPToolCall) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolCall: c, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolResult: + return concatContentBlockHelper(blocks, blockType, "MCP tool result", + func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, + concatMCPToolResult, + func(r *MCPToolResult) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPListToolsResult: + return concatContentBlockHelper(blocks, blockType, "MCP list tools", + func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, + concatMCPListToolsResult, + func(r *MCPListToolsResult) *ContentBlock { + return &ContentBlock{Type: blockType, MCPListToolsResult: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolApprovalRequest: + return concatContentBlockHelper(blocks, blockType, "MCP tool approval 
request", + func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, + concatMCPToolApprovalRequest, + func(r *MCPToolApprovalRequest) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolApprovalRequest: r, StreamMeta: &StreamMeta{Index: index}} + }) + + case ContentBlockTypeMCPToolApprovalResponse: + return concatContentBlockHelper(blocks, blockType, "MCP tool approval response", + func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, + concatMCPToolApprovalResponse, + func(r *MCPToolApprovalResponse) *ContentBlock { + return &ContentBlock{Type: blockType, MCPToolApprovalResponse: r, StreamMeta: &StreamMeta{Index: index}} + }) + + default: + return nil, fmt.Errorf("unknown content block type: %s", blockType) + } +} + +// concatContentBlockHelper is a generic helper function that reduces code duplication +// for concatenating content blocks of a specific type. +func concatContentBlockHelper[T any]( + blocks []*ContentBlock, + expectedType ContentBlockType, + typeName string, + getter func(*ContentBlock) *T, + concatFunc func([]*T) (*T, error), + constructor func(*T) *ContentBlock, +) (*ContentBlock, error) { + items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { + if block.Type != expectedType { + return nil, fmt.Errorf("expected %s block, got %s", typeName, block.Type) + } + item := getter(block) + if item == nil { + return nil, fmt.Errorf("%s content is nil", typeName) + } + return item, nil + }) + if err != nil { + return nil, err + } + + concatenated, err := concatFunc(items) + if err != nil { + return nil, err + } + + return constructor(concatenated), nil +} + +func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { + ret := make([]T, 0, len(blocks)) + for _, block := range blocks { + t, err := checkAndGetter(block) + if err != nil { + return nil, err + } + ret = append(ret, t) + 
} + return ret, nil +} + +// Concatenation strategies for different content block types: +// +// String concatenation (incremental streaming): +// - Reasoning: Summary texts are concatenated, grouped by Index if present +// - UserInputText: Text fields are concatenated +// - AssistantGenText: Text fields are concatenated, annotations/citations are merged +// - FunctionToolCall: Arguments (JSON strings) are concatenated incrementally +// - FunctionToolResult: Result strings are concatenated +// - ServerToolCall: Arguments are merged (last non-nil value for any type) +// - ServerToolResult: Results are merged using internal.ConcatItems +// - MCPToolCall: Arguments (JSON strings) are concatenated incrementally +// - MCPToolResult: Result strings are concatenated +// - MCPListToolsResult: Tools arrays are merged +// - MCPToolApprovalRequest: Arguments are concatenated +// +// Take last block (non-streaming content): +// - UserInputImage, UserInputAudio, UserInputVideo, UserInputFile: Return last block +// - AssistantGenImage, AssistantGenAudio, AssistantGenVideo: Return last block +// - MCPToolApprovalResponse: Return last block +// + +func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { + if len(reasons) == 0 { + return nil, fmt.Errorf("no reasoning found") + } + if len(reasons) == 1 { + return reasons[0], nil + } + + ret := &Reasoning{ + Summary: make([]*ReasoningSummary, 0), + EncryptedContent: "", + Extra: make(map[string]any), + } + + // Collect all summaries from all reasons + allSummaries := make([]*ReasoningSummary, 0) + for _, r := range reasons { + if r == nil { + continue + } + allSummaries = append(allSummaries, r.Summary...) 
+ if r.EncryptedContent != "" { + ret.EncryptedContent += r.EncryptedContent + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + // Group by Index and concatenate Text for same Index + // Use dynamic array that expands as needed + var summaryArray []*ReasoningSummary + for _, s := range allSummaries { + idx := s.Index + // Expand array if needed + summaryArray = expandSlice(int(idx), summaryArray) + if summaryArray[idx] == nil { + // Create new entry with a copy of Index + summaryArray[idx] = &ReasoningSummary{ + Index: idx, + Text: s.Text, + } + } else { + // Concatenate text for same index + summaryArray[idx].Text += s.Text + } + } + + // Convert array to slice, filtering out nil entries + ret.Summary = make([]*ReasoningSummary, 0, len(summaryArray)) + for _, summary := range summaryArray { + if summary != nil { + ret.Summary = append(ret.Summary, summary) + } + } + + return ret, nil +} + +func concatUserInputText(texts []*UserInputText) (*UserInputText, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no user input text found") + } + if len(texts) == 1 { + return texts[0], nil + } + + ret := &UserInputText{ + Text: "", + Extra: make(map[string]any), + } + + for _, t := range texts { + if t == nil { + continue + } + ret.Text += t.Text + for k, v := range t.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatUserInputImage(images []*UserInputImage) (*UserInputImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no user input image found") + } + return images[len(images)-1], nil +} + +func concatUserInputAudio(audios []*UserInputAudio) (*UserInputAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no user input audio found") + } + return audios[len(audios)-1], nil +} + +func concatUserInputVideo(videos []*UserInputVideo) (*UserInputVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no user input video found") + } + return videos[len(videos)-1], nil +} + +func 
concatUserInputFile(files []*UserInputFile) (*UserInputFile, error) { + if len(files) == 0 { + return nil, fmt.Errorf("no user input file found") + } + return files[len(files)-1], nil +} + +func concatAssistantGenText(texts []*AssistantGenText) (*AssistantGenText, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no assistant gen text found") + } + if len(texts) == 1 { + return texts[0], nil + } + + ret := &AssistantGenText{ + Text: "", + OpenAIExtension: nil, + ClaudeExtension: nil, + Extra: make(map[string]any), + } + + for _, t := range texts { + if t == nil { + continue + } + ret.Text += t.Text + if t.OpenAIExtension != nil { + if ret.OpenAIExtension == nil { + ret.OpenAIExtension = &openai.AssistantGenTextExtension{} + } + ret.OpenAIExtension.Annotations = append(ret.OpenAIExtension.Annotations, t.OpenAIExtension.Annotations...) + } + if t.ClaudeExtension != nil { + if ret.ClaudeExtension == nil { + ret.ClaudeExtension = &claude.AssistantGenTextExtension{} + } + ret.ClaudeExtension.Citations = append(ret.ClaudeExtension.Citations, t.ClaudeExtension.Citations...) 
+ } + for k, v := range t.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatAssistantGenImage(images []*AssistantGenImage) (*AssistantGenImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no assistant gen image found") + } + return images[len(images)-1], nil +} + +func concatAssistantGenAudio(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no assistant gen audio found") + } + return audios[len(audios)-1], nil +} + +func concatAssistantGenVideo(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no assistant gen video found") + } + return videos[len(videos)-1], nil +} + +func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no function tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + // For tool calls, arguments are typically built incrementally during streaming + ret := &FunctionToolCall{ + Extra: make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if ret.Name == "" { + ret.Name = c.Name + } + ret.Arguments += c.Arguments + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no function tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &FunctionToolResult{ + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Result += r.Result + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatServerToolCall(calls []*ServerToolCall) 
(*ServerToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no server tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + // ServerToolCall Arguments is of type any; merge strategy uses the last non-nil value + ret := &ServerToolCall{ + Extra: make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.Name == "" { + ret.Name = c.Name + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if c.Arguments != nil { + ret.Arguments = c.Arguments + } + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no server tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + // ServerToolResult Result is of type any; merge strategy uses the last non-nil value + ret := &ServerToolResult{ + Extra: make(map[string]any), + } + + tZeroResult := reflect.TypeOf(results[0].Result) + data := reflect.MakeSlice(reflect.SliceOf(tZeroResult), 0, 0) + for _, r := range results { + if r == nil { + continue + } + if ret.Name == "" { + ret.Name = r.Name + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if r.Result != nil { + vResult := reflect.ValueOf(r.Result) + if tZeroResult != vResult.Type() { + return nil, fmt.Errorf("tool result types are different: %v %v", tZeroResult, vResult.Type()) + } + data = reflect.Append(data, vResult) + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + d, err := internal.ConcatSliceValue(data) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = d + + return ret, nil +} + +func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no mcp tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret := &MCPToolCall{ + Extra: 
make(map[string]any), + } + + for _, c := range calls { + if c == nil { + continue + } + if ret.ServerLabel == "" { + ret.ServerLabel = c.ServerLabel + } + if ret.ApprovalRequestID == "" { + ret.ApprovalRequestID = c.ApprovalRequestID + } + if ret.CallID == "" { + ret.CallID = c.CallID + } + if ret.Name == "" { + ret.Name = c.Name + } + ret.Arguments += c.Arguments + for k, v := range c.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPToolResult{ + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.CallID == "" { + ret.CallID = r.CallID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Result += r.Result + if r.Error != nil { + ret.Error = r.Error // Use the last error + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp list tools result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPListToolsResult{ + Tools: make([]*MCPListToolsItem, 0), + Extra: make(map[string]any), + } + + for _, r := range results { + if r == nil { + continue + } + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } + ret.Tools = append(ret.Tools, r.Tools...) 
+ if r.Error != "" { + ret.Error = r.Error // Use the last error + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { + if len(requests) == 0 { + return nil, fmt.Errorf("no mcp tool approval request found") + } + if len(requests) == 1 { + return requests[0], nil + } + + ret := &MCPToolApprovalRequest{ + Extra: make(map[string]any), + } + + for _, r := range requests { + if r == nil { + continue + } + if ret.ID == "" { + ret.ID = r.ID + } + if ret.Name == "" { + ret.Name = r.Name + } + ret.Arguments += r.Arguments + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } + for k, v := range r.Extra { + ret.Extra[k] = v + } + } + + return ret, nil +} + +func concatMCPToolApprovalResponse(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { + if len(responses) == 0 { + return nil, fmt.Errorf("no mcp tool approval response found") + } + if len(responses) == 1 { + return responses[0], nil + } + + return responses[len(responses)-1], nil +} + +func expandSlice[T any](idx int, s []T) []T { + if len(s) > idx { + return s + } + return append(s, make([]T, idx-len(s)+1)...) 
+} + +func (m *AgenticMessage) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("role: %s\n", m.Role)) + + if len(m.ContentBlocks) > 0 { + sb.WriteString("content_blocks:\n") + for i, block := range m.ContentBlocks { + if block == nil { + continue + } + sb.WriteString(fmt.Sprintf(" [%d] %s", i, block.String())) + } + } + + if m.ResponseMeta != nil { + sb.WriteString(m.ResponseMeta.String()) + } + + return sb.String() +} + +func (b *ContentBlock) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) + + switch b.Type { + case ContentBlockTypeReasoning: + if b.Reasoning != nil { + sb.WriteString(b.Reasoning.String()) + } + case ContentBlockTypeUserInputText: + if b.UserInputText != nil { + sb.WriteString(b.UserInputText.String()) + } + case ContentBlockTypeUserInputImage: + if b.UserInputImage != nil { + sb.WriteString(b.UserInputImage.String()) + } + case ContentBlockTypeUserInputAudio: + if b.UserInputAudio != nil { + sb.WriteString(b.UserInputAudio.String()) + } + case ContentBlockTypeUserInputVideo: + if b.UserInputVideo != nil { + sb.WriteString(b.UserInputVideo.String()) + } + case ContentBlockTypeUserInputFile: + if b.UserInputFile != nil { + sb.WriteString(b.UserInputFile.String()) + } + case ContentBlockTypeAssistantGenText: + if b.AssistantGenText != nil { + sb.WriteString(b.AssistantGenText.String()) + } + case ContentBlockTypeAssistantGenImage: + if b.AssistantGenImage != nil { + sb.WriteString(b.AssistantGenImage.String()) + } + case ContentBlockTypeAssistantGenAudio: + if b.AssistantGenAudio != nil { + sb.WriteString(b.AssistantGenAudio.String()) + } + case ContentBlockTypeAssistantGenVideo: + if b.AssistantGenVideo != nil { + sb.WriteString(b.AssistantGenVideo.String()) + } + case ContentBlockTypeFunctionToolCall: + if b.FunctionToolCall != nil { + sb.WriteString(b.FunctionToolCall.String()) + } + case ContentBlockTypeFunctionToolResult: + if b.FunctionToolResult != nil { + 
sb.WriteString(b.FunctionToolResult.String()) + } + case ContentBlockTypeServerToolCall: + if b.ServerToolCall != nil { + sb.WriteString(b.ServerToolCall.String()) + } + case ContentBlockTypeServerToolResult: + if b.ServerToolResult != nil { + sb.WriteString(b.ServerToolResult.String()) + } + case ContentBlockTypeMCPToolCall: + if b.MCPToolCall != nil { + sb.WriteString(b.MCPToolCall.String()) + } + case ContentBlockTypeMCPToolResult: + if b.MCPToolResult != nil { + sb.WriteString(b.MCPToolResult.String()) + } + case ContentBlockTypeMCPListToolsResult: + if b.MCPListToolsResult != nil { + sb.WriteString(b.MCPListToolsResult.String()) + } + case ContentBlockTypeMCPToolApprovalRequest: + if b.MCPToolApprovalRequest != nil { + sb.WriteString(b.MCPToolApprovalRequest.String()) + } + case ContentBlockTypeMCPToolApprovalResponse: + if b.MCPToolApprovalResponse != nil { + sb.WriteString(b.MCPToolApprovalResponse.String()) + } + } + + if b.StreamMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamMeta.Index)) + } + + return sb.String() +} + +func (r *Reasoning) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) + for _, s := range r.Summary { + sb.WriteString(fmt.Sprintf(" [%d] %s\n", s.Index, s.Text)) + } + if r.EncryptedContent != "" { + sb.WriteString(fmt.Sprintf(" encrypted_content: %s\n", truncateString(r.EncryptedContent, 50))) + } + return sb.String() +} + +func (u *UserInputText) String() string { + return fmt.Sprintf(" text: %s\n", u.Text) +} + +func (u *UserInputImage) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) +} + +func (u *UserInputAudio) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +func (u *UserInputVideo) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +func (u *UserInputFile) String() string { + sb := &strings.Builder{} + if u.Name != "" { + 
sb.WriteString(fmt.Sprintf(" name: %s\n", u.Name)) + } + sb.WriteString(formatMediaString(u.URL, u.Base64Data, u.MIMEType, "")) + return sb.String() +} + +func (a *AssistantGenText) String() string { + return fmt.Sprintf(" text: %s\n", a.Text) +} + +func (a *AssistantGenImage) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (a *AssistantGenAudio) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (a *AssistantGenVideo) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +func (f *FunctionToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", f.Arguments)) + return sb.String() +} + +func (f *FunctionToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", f.Result)) + return sb.String() +} + +func (s *ServerToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" arguments: %v\n", s.Arguments)) + return sb.String() +} + +func (s *ServerToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" result: %v\n", s.Result)) + return sb.String() +} + +func (m *MCPToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + 
sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + if m.ApprovalRequestID != "" { + sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) + } + return sb.String() +} + +func (m *MCPToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", m.Result)) + if m.Error != nil { + sb.WriteString(fmt.Sprintf(" error: [%d] %s\n", *m.Error.Code, m.Error.Message)) + } + return sb.String() +} + +func (m *MCPListToolsResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" tools: %d items\n", len(m.Tools))) + for _, tool := range m.Tools { + sb.WriteString(fmt.Sprintf(" - %s: %s\n", tool.Name, tool.Description)) + } + if m.Error != "" { + sb.WriteString(fmt.Sprintf(" error: %s\n", m.Error)) + } + return sb.String() +} + +func (m *MCPToolApprovalRequest) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" id: %s\n", m.ID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + return sb.String() +} + +func (m *MCPToolApprovalResponse) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) + sb.WriteString(fmt.Sprintf(" approve: %v\n", m.Approve)) + if m.Reason != "" { + sb.WriteString(fmt.Sprintf(" reason: %s\n", m.Reason)) + } + return sb.String() +} + +func (a *AgenticResponseMeta) String() string { + sb := &strings.Builder{} + sb.WriteString("response_meta:\n") + if a.TokenUsage != nil { + sb.WriteString(fmt.Sprintf(" token_usage: prompt=%d, completion=%d, total=%d\n", + a.TokenUsage.PromptTokens, + a.TokenUsage.CompletionTokens, + a.TokenUsage.TotalTokens)) + } + return 
sb.String() +} + +// truncateString truncates a string to maxLen characters, adding "..." if truncated +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +// formatMediaString formats URL, Base64Data, MIMEType and Detail for media content +func formatMediaString(url, base64Data string, mimeType string, detail any) string { + sb := &strings.Builder{} + if url != "" { + sb.WriteString(fmt.Sprintf(" url: %s\n", truncateString(url, 100))) + } + if base64Data != "" { + // Only show first few characters of base64 data + sb.WriteString(fmt.Sprintf(" base64_data: %s... (%d bytes)\n", truncateString(base64Data, 20), len(base64Data))) + } + if mimeType != "" { + sb.WriteString(fmt.Sprintf(" mime_type: %s\n", mimeType)) + } + if detail != nil && detail != "" { + sb.WriteString(fmt.Sprintf(" detail: %v\n", detail)) + } + return sb.String() +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go new file mode 100644 index 000000000..0cafcd9ff --- /dev/null +++ b/schema/agentic_message_test.go @@ -0,0 +1,1381 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package schema + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAgenticMessages(t *testing.T) { + t.Run("single message", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + }, + }, + } + + result, err := ConcatAgenticMessages([]*AgenticMessage{msg}) + assert.NoError(t, err) + assert.Equal(t, msg, result) + }) + + t.Run("nil message in stream", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeAssistant}, + nil, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "message at index 1 is nil") + }) + + t.Run("different roles", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeUser}, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat messages with different roles") + }) + + t.Run("concat text blocks", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World!", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Equal(t, AgenticRoleTypeAssistant, result.Role) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat reasoning with nil index", func(t 
*testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "First "}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Second"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) + assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) + assert.Equal(t, int64(0), result.ContentBlocks[0].Reasoning.Summary[0].Index) + }) + + t.Run("concat reasoning with index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Part1-"}, + {Index: 1, Text: "Part2-"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "Part3"}, + {Index: 1, Text: "Part4"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 2) + assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Summary[0].Text) + assert.Equal(t, "Part2-Part4", result.ContentBlocks[0].Reasoning.Summary[1].Text) + }) + + t.Run("concat user input text", func(t *testing.T) { + msgs := 
[]*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "Hello ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "World!", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].UserInputText.Text) + }) + + t.Run("concat user input image", func(t *testing.T) { + url1 := "https://example.com/image1.jpg" + url2 := "https://example.com/image2.jpg" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last image + assert.Equal(t, url2, result.ContentBlocks[0].UserInputImage.URL) + }) + + t.Run("concat user input audio", func(t *testing.T) { + url1 := "https://example.com/audio1.mp3" + url2 := "https://example.com/audio2.mp3" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + 
UserInputAudio: &UserInputAudio{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last audio + assert.Equal(t, url2, result.ContentBlocks[0].UserInputAudio.URL) + }) + + t.Run("concat user input video", func(t *testing.T) { + url1 := "https://example.com/video1.mp4" + url2 := "https://example.com/video2.mp4" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last video + assert.Equal(t, url2, result.ContentBlocks[0].UserInputVideo.URL) + }) + + t.Run("concat assistant gen text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Generated ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Generated Text", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat assistant gen image", func(t *testing.T) { + url1 := 
"https://example.com/gen_image1.jpg" + url2 := "https://example.com/gen_image2.jpg" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last image + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenImage.URL) + }) + + t.Run("concat assistant gen audio", func(t *testing.T) { + url1 := "https://example.com/gen_audio1.mp3" + url2 := "https://example.com/gen_audio2.mp3" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last audio + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenAudio.URL) + }) + + t.Run("concat assistant gen video", func(t *testing.T) { + url1 := "https://example.com/gen_video1.mp4" + url2 := "https://example.com/gen_video2.mp4" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: 
&AssistantGenVideo{ + URL: url1, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: url2, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last video + assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenVideo.URL) + }) + + t.Run("concat function tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_123", + Name: "get_weather", + Arguments: `{"location`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":"NYC"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolCall.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolCall.Name) + assert.Equal(t, `{"location":"NYC"}`, result.ContentBlocks[0].FunctionToolCall.Arguments) + }) + + t.Run("concat function tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_123", + Name: "get_weather", + Result: `{"temp`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: 
ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + Result: `":72}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, `{"temp":72}`, result.ContentBlocks[0].FunctionToolResult.Result) + }) + + t.Run("concat server tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + CallID: "server_call_1", + Name: "server_func", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Arguments: map[string]any{"key": "value"}, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolCall.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolCall.Name) + assert.NotNil(t, result.ContentBlocks[0].ServerToolCall.Arguments) + }) + + t.Run("concat server tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + CallID: "server_call_1", + Name: "server_func", + Result: "result1", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: 
&ServerToolResult{ + Result: "result2", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + }) + + t.Run("concat mcp tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Arguments: `{"arg`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":123}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolCall.Name) + assert.Equal(t, `{"arg":123}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat mcp tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `{"res`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + Result: `ult":true}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + 
} + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) + assert.Equal(t, `{"result":true}`, result.ContentBlocks[0].MCPToolResult.Result) + }) + + t.Run("concat mcp list tools", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "tool1"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + Tools: []*MCPListToolsItem{ + {Name: "tool2"}, + }, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPListToolsResult.ServerLabel) + assert.Len(t, result.ContentBlocks[0].MCPListToolsResult.Tools, 2) + }) + + t.Run("concat mcp tool approval request", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "approval_1", + Name: "approval_func", + ServerLabel: "mcp-server", + Arguments: `{"request`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + Arguments: `":1}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := 
ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "approval_1", result.ContentBlocks[0].MCPToolApprovalRequest.ID) + assert.Equal(t, "approval_func", result.ContentBlocks[0].MCPToolApprovalRequest.Name) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolApprovalRequest.ServerLabel) + assert.Equal(t, `{"request":1}`, result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) + }) + + t.Run("concat mcp tool approval response", func(t *testing.T) { + response1 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: false, + } + response2 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: true, + } + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response1, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response2, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + // Should take the last response + assert.Equal(t, response2, result.ContentBlocks[0].MCPToolApprovalResponse) + }) + + t.Run("concat response meta", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 5, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 15, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.NotNil(t, result.ResponseMeta) + assert.Equal(t, 20, 
result.ResponseMeta.TokenUsage.CompletionTokens) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.PromptTokens) + }) + + t.Run("mixed streaming and non-streaming blocks error", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World", + }, + // No StreamMeta - non-streaming + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "found non-streaming block after streaming blocks") + }) + + t.Run("concat MCP tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "call_456", + Name: "list_files", + Arguments: `{"path`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":"/tmp"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "call_456", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat user input text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: 
ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What is ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "the weather?", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "What is the weather?", result.ContentBlocks[0].UserInputText.Text) + }) + + t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index0-", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index2-", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Index0-Part2", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "Index2-Part2", result.ContentBlocks[1].AssistantGenText.Text) + }) + + t.Run("multiple stream indexes - mixed content types", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: 
"Text ", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_1", + Name: "func1", + Arguments: `{"a`, + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Content", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":1}`, + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Text Content", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "call_1", result.ContentBlocks[1].FunctionToolCall.CallID) + assert.Equal(t, "func1", result.ContentBlocks[1].FunctionToolCall.Name) + assert.Equal(t, `{"a":1}`, result.ContentBlocks[1].FunctionToolCall.Arguments) + }) + + t.Run("multiple stream indexes - three indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "A", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "B", + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "C", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "1", + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + 
Text: "2", + }, + StreamMeta: &StreamMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "3", + }, + StreamMeta: &StreamMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 3) + assert.Equal(t, "A1", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "B2", result.ContentBlocks[1].AssistantGenText.Text) + assert.Equal(t, "C3", result.ContentBlocks[2].AssistantGenText.Text) + }) +} + +func TestAgenticMessageFormat(t *testing.T) { + m := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "{a}"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "{b}", + Base64Data: "{c}", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "{d}", + Base64Data: "{e}", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "{f}", + Base64Data: "{g}", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "{h}", + Base64Data: "{i}", + }, + }, + }, + } + + result, err := m.Format(context.Background(), map[string]any{ + "a": "1", "b": "2", "c": "3", "d": "4", "e": "5", "f": "6", "g": "7", "h": "8", "i": "9", + }, FString) + assert.NoError(t, err) + assert.Equal(t, []*AgenticMessage{{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "1"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "2", + Base64Data: "3", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "4", + Base64Data: "5", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: 
"6", + Base64Data: "7", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "8", + Base64Data: "9", + }, + }, + }, + }}, result) +} + +func TestAgenticPlaceholderFormat(t *testing.T) { + ctx := context.Background() + ph := AgenticMessagesPlaceholder("a", false) + + result, err := ph.Format(ctx, map[string]any{ + "a": []*AgenticMessage{{Role: AgenticRoleTypeUser}, {Role: AgenticRoleTypeUser}}, + }, FString) + assert.NoError(t, err) + assert.Equal(t, 2, len(result)) + + ph = AgenticMessagesPlaceholder("a", true) + + result, err = ph.Format(ctx, map[string]any{}, FString) + assert.NoError(t, err) + assert.Equal(t, 0, len(result)) +} + +func ptrOf[T any](v T) *T { + return &v +} + +func TestAgenticMessageString(t *testing.T) { + longBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What's the weather like in New York City today?", + }, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "https://example.com/weather-map.jpg", + Base64Data: longBase64, + MIMEType: "image/jpeg", + Detail: ImageURLDetailHigh, + }, + }, + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Summary: []*ReasoningSummary{ + {Index: 0, Text: "First, I need to identify the location (New York City) from the user's query."}, + {Index: 1, Text: "Then, I should call the weather API to get current conditions."}, + {Index: 2, Text: "Finally, I'll format the response in a user-friendly way with temperature and conditions."}, + }, + EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", + }, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "I'll check the current weather in New York 
City for you.", + }, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_weather_123", + Name: "get_current_weather", + Arguments: `{"location":"New York City","unit":"fahrenheit"}`, + }, + StreamMeta: &StreamMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_weather_123", + Name: "get_current_weather", + Result: `{"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8}`, + }, + }, + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "weather-mcp-server", + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, + ApprovalRequestID: "approval_req_789", + }, + }, + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Result: `{"status":"partial","days_available":3}`, + Error: &MCPToolCallError{ + Code: ptrOf[int64](503), + Message: "Service temporarily unavailable for full 7-day forecast", + }, + }, + }, + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "weather-mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "get_current_weather", Description: "Get current weather conditions for a location"}, + {Name: "get_7day_forecast", Description: "Get 7-day weather forecast"}, + {Name: "get_weather_alerts", Description: "Get active weather alerts and warnings"}, + }, + }, + }, + }, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 250, + CompletionTokens: 180, + TotalTokens: 430, + }, + }, + } + + // Print the formatted output + output := msg.String() + + assert.Equal(t, `role: assistant +content_blocks: + [0] type: user_input_text + text: What's the weather like in New York City today? 
+ [1] type: user_input_image + url: https://example.com/weather-map.jpg + base64_data: iVBORw0KGgoAAAANSUhE...... (96 bytes) + mime_type: image/jpeg + detail: high + [2] type: reasoning + summary: 3 items + [0] First, I need to identify the location (New York City) from the user's query. + [1] Then, I should call the weather API to get current conditions. + [2] Finally, I'll format the response in a user-friendly way with temperature and conditions. + encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... + [3] type: assistant_gen_text + text: I'll check the current weather in New York City for you. + [4] type: function_tool_call + call_id: call_weather_123 + name: get_current_weather + arguments: {"location":"New York City","unit":"fahrenheit"} + stream_index: 0 + [5] type: function_tool_result + call_id: call_weather_123 + name: get_current_weather + result: {"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8} + [6] type: mcp_tool_call + server_label: weather-mcp-server + call_id: mcp_forecast_456 + name: get_7day_forecast + arguments: {"city":"New York","days":7} + approval_request_id: approval_req_789 + [7] type: mcp_tool_result + call_id: mcp_forecast_456 + name: get_7day_forecast + result: {"status":"partial","days_available":3} + error: [503] Service temporarily unavailable for full 7-day forecast + [8] type: mcp_list_tools_result + server_label: weather-mcp-server + tools: 3 items + - get_current_weather: Get current weather conditions for a location + - get_7day_forecast: Get 7-day weather forecast + - get_weather_alerts: Get active weather alerts and warnings +response_meta: + token_usage: prompt=250, completion=180, total=430 +`, output) +} diff --git a/schema/message.go b/schema/message.go index 5eb864870..bc1cc184d 100644 --- a/schema/message.go +++ b/schema/message.go @@ -40,47 +40,56 @@ func init() { internal.RegisterStreamChunkConcatFunc(ConcatMessages) internal.RegisterStreamChunkConcatFunc(ConcatMessageArray) + 
internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessages) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessagesArray) + internal.RegisterStreamChunkConcatFunc(ConcatToolResults) } -// ConcatMessageArray merges aligned slices of messages into a single slice, -// concatenating messages at the same index across the input arrays. -func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { - arrayLen := len(mas[0]) +func buildConcatGenericArray[T any](f func([]*T) (*T, error)) func([][]*T) ([]*T, error) { + return func(mas [][]*T) ([]*T, error) { + arrayLen := len(mas[0]) - ret := make([]*Message, arrayLen) - slicesToConcat := make([][]*Message, arrayLen) + ret := make([]*T, arrayLen) + slicesToConcat := make([][]*T, arrayLen) - for _, ma := range mas { - if len(ma) != arrayLen { - return nil, fmt.Errorf("unexpected array length. "+ - "Got %d, expected %d", len(ma), arrayLen) - } + for _, ma := range mas { + if len(ma) != arrayLen { + return nil, fmt.Errorf("unexpected array length. "+ + "Got %d, expected %d", len(ma), arrayLen) + } - for i := 0; i < arrayLen; i++ { - m := ma[i] - if m != nil { - slicesToConcat[i] = append(slicesToConcat[i], m) + for i := 0; i < arrayLen; i++ { + m := ma[i] + if m != nil { + slicesToConcat[i] = append(slicesToConcat[i], m) + } } } - } - for i, slice := range slicesToConcat { - if len(slice) == 0 { - ret[i] = nil - } else if len(slice) == 1 { - ret[i] = slice[0] - } else { - cm, err := ConcatMessages(slice) - if err != nil { - return nil, err - } + for i, slice := range slicesToConcat { + if len(slice) == 0 { + ret[i] = nil + } else if len(slice) == 1 { + ret[i] = slice[0] + } else { + cm, err := f(slice) + if err != nil { + return nil, err + } - ret[i] = cm + ret[i] = cm + } } + + return ret, nil } +} - return ret, nil +// ConcatMessageArray merges aligned slices of messages into a single slice, +// concatenating messages at the same index across the input arrays. 
+func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { + return buildConcatGenericArray[Message](ConcatMessages)(mas) } // FormatType used by MessageTemplate.Format @@ -716,7 +725,7 @@ var _ MessagesTemplate = MessagesPlaceholder("", false) // e.g. // // chatTemplate := prompt.FromMessages( -// schema.SystemMessage("you are eino helper"), +// schema.SystemMessage("you are an eino helper"), // schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params // ) // msgs, err := chatTemplate.Format(ctx, params) @@ -734,7 +743,7 @@ type messagesPlaceholder struct { // // placeholder := MessagesPlaceholder("history", false) // params := map[string]any{ -// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great freamwork to build llm apps"}}, +// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great framework to build llm apps"}}, // "query": "how to use eino?", // } // chatTemplate := chatTpl := prompt.FromMessages( From 16fd48fb3071d7bf1cc549909ac1622de1098562 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 6 Jan 2026 16:48:56 +0800 Subject: [PATCH 16/28] fix: concat agentic messages (#604) --- components/agentic/callback_extra.go | 5 +- components/agentic/option.go | 7 +- components/agentic/option_test.go | 2 +- components/model/callback_extra.go | 6 +- components/prompt/callback_extra.go | 2 + components/types.go | 2 + compose/tools_node_agentic.go | 7 +- compose/tools_node_agentic_test.go | 37 +- schema/agentic_message.go | 1052 ++++++++++------- schema/agentic_message_test.go | 552 ++++++--- schema/claude/consts.go | 1 + .../claude/{content_block.go => extension.go} | 70 +- schema/claude/extension_test.go | 190 +++ schema/claude/response_meta.go | 22 - .../gemini/{response_meta.go => extension.go} | 43 +- schema/gemini/extension_test.go | 79 ++ schema/message.go | 4 +- schema/openai/consts.go | 69 
++ schema/openai/content_block.go | 75 -- schema/openai/extension.go | 206 ++++ schema/openai/extension_test.go | 193 +++ schema/openai/response_meta.go | 40 - schema/tool.go | 21 + 23 files changed, 1942 insertions(+), 743 deletions(-) rename schema/claude/{content_block.go => extension.go} (51%) create mode 100644 schema/claude/extension_test.go delete mode 100644 schema/claude/response_meta.go rename schema/gemini/{response_meta.go => extension.go} (76%) create mode 100644 schema/gemini/extension_test.go delete mode 100644 schema/openai/content_block.go create mode 100644 schema/openai/extension.go create mode 100644 schema/openai/extension_test.go delete mode 100644 schema/openai/response_meta.go diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go index 389408d33..2c5a656fa 100644 --- a/components/agentic/callback_extra.go +++ b/components/agentic/callback_extra.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package agentic defines callback payloads and configuration types for agentic models. package agentic import ( @@ -26,9 +27,9 @@ type Config struct { // Model is the model name. Model string // Temperature is the temperature, which controls the randomness of the model. - Temperature float32 + Temperature float64 // TopP is the top p, which controls the diversity of the model. - TopP float32 + TopP float64 } // CallbackInput is the input for the model callback. diff --git a/components/agentic/option.go b/components/agentic/option.go index ac117ddb4..d8873442a 100644 --- a/components/agentic/option.go +++ b/components/agentic/option.go @@ -30,8 +30,10 @@ type Options struct { TopP *float64 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. + // ToolChoice controls how the model call the tools. ToolChoice *schema.ToolChoice + // AllowedTools is a list of allowed tools the model may call. 
+ AllowedTools []*schema.AllowedTool } // Option is the call option for ChatModel component. @@ -81,10 +83,11 @@ func WithTools(tools []*schema.ToolInfo) Option { } // WithToolChoice is the option to set tool choice for the model. -func WithToolChoice(toolChoice schema.ToolChoice) Option { +func WithToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { return Option{ apply: func(opts *Options) { opts.ToolChoice = &toolChoice + opts.AllowedTools = allowedTools }, } } diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go index d349f35ac..2c5bac652 100644 --- a/components/agentic/option_test.go +++ b/components/agentic/option_test.go @@ -29,7 +29,7 @@ func TestCommon(t *testing.T) { WithTools([]*schema.ToolInfo{{Name: "test"}}), WithModel("test"), WithTemperature(0.1), - WithToolChoice(schema.ToolChoiceAllowed), + WithToolChoice(schema.ToolChoiceAllowed, []*schema.AllowedTool{{FunctionToolName: "test"}}...), WithTopP(0.1), ) assert.Len(t, o.Tools, 1) diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index 8591c4373..2767e2e5e 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -29,17 +29,17 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails // CompletionTokens is the number of completion tokens. CompletionTokens int + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails // TotalTokens is the total number of tokens. TotalTokens int - // CompletionTokensDetails is breakdown of completion tokens. - CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { // ReasoningTokens tokens generated by the model for reasoning. // This is currently supported by OpenAI, Gemini, ARK and Qwen chat models. // For other models, this field will be 0. 
- ReasoningTokens int `json:"reasoning_tokens,omitempty"` + ReasoningTokens int } // PromptTokenDetails provides a breakdown of prompt token usage. diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index ff5c3a8ff..3be780543 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -33,6 +33,7 @@ type AgenticCallbackOutput struct { Extra map[string]any } +// ConvAgenticCallbackInput converts the callback input to the agentic callback input. func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { switch t := src.(type) { case *AgenticCallbackInput: @@ -46,6 +47,7 @@ func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput } } +// ConvAgenticCallbackOutput converts the callback output to the agentic callback output. func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { switch t := src.(type) { case *AgenticCallbackOutput: diff --git a/components/types.go b/components/types.go index 2ba088e93..2b0ad8f0e 100644 --- a/components/types.go +++ b/components/types.go @@ -66,9 +66,11 @@ type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" + // ComponentOfAgenticPrompt identifies agentic template components. ComponentOfAgenticPrompt Component = "AgenticChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" + // ComponentOfAgenticModel identifies agentic model components. ComponentOfAgenticModel Component = "AgenticModel" // ComponentOfEmbedding identifies embedding components. 
ComponentOfEmbedding Component = "Embedding" diff --git a/compose/tools_node_agentic.go b/compose/tools_node_agentic.go index 38c5c89de..96aef7b72 100644 --- a/compose/tools_node_agentic.go +++ b/compose/tools_node_agentic.go @@ -70,6 +70,7 @@ func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Messa Name: block.FunctionToolCall.Name, Arguments: block.FunctionToolCall.Arguments, }, + Extra: block.Extra, }) } return &schema.Message{ @@ -87,8 +88,8 @@ func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessa CallID: m.ToolCallID, Name: m.ToolName, Result: m.Content, - Extra: m.Extra, }, + Extra: m.Extra, }) } return []*schema.AgenticMessage{{ @@ -110,9 +111,9 @@ func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Mess CallID: m.ToolCallID, Name: m.ToolName, Result: m.Content, - Extra: m.Extra, }, - StreamMeta: &schema.StreamMeta{Index: int64(i)}, + StreamingMeta: &schema.StreamingMeta{Index: i}, + Extra: m.Extra, }) } return []*schema.AgenticMessage{{ diff --git a/compose/tools_node_agentic_test.go b/compose/tools_node_agentic_test.go index dcd3177a9..4641dd8ae 100644 --- a/compose/tools_node_agentic_test.go +++ b/compose/tools_node_agentic_test.go @@ -20,6 +20,7 @@ import ( "io" "testing" + "github.com/bytedance/sonic" "github.com/stretchr/testify/assert" "github.com/cloudwego/eino/schema" @@ -155,13 +156,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { nil, }, { + nil, { Role: schema.Tool, - Content: "content1-2", + Content: "content2-2", ToolName: "name2", ToolCallID: "2", }, - nil, nil, + nil, }, { nil, nil, @@ -172,16 +174,6 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { ToolCallID: "3", }, }, - { - nil, - { - Role: schema.Tool, - Content: "content2-2", - ToolName: "name2", - ToolCallID: "2", - }, - nil, - }, { nil, nil, { @@ -204,7 +196,11 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { } result, err := schema.ConcatAgenticMessagesArray(chunks) 
assert.NoError(t, err) - assert.Equal(t, []*schema.AgenticMessage{ + + actualStr, err := sonic.MarshalString(result) + assert.NoError(t, err) + + expected := []*schema.AgenticMessage{ { Role: schema.AgenticRoleTypeUser, ContentBlocks: []*schema.ContentBlock{ @@ -213,10 +209,8 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { FunctionToolResult: &schema.FunctionToolResult{ CallID: "1", Name: "name1", - Result: "content1-1content1-2", - Extra: map[string]interface{}{}, + Result: "content1-1", }, - StreamMeta: &schema.StreamMeta{Index: 0}, }, { Type: schema.ContentBlockTypeFunctionToolResult, @@ -224,9 +218,7 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { CallID: "2", Name: "name2", Result: "content2-1content2-2", - Extra: map[string]interface{}{}, }, - StreamMeta: &schema.StreamMeta{Index: 1}, }, { Type: schema.ContentBlockTypeFunctionToolResult, @@ -234,11 +226,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) { CallID: "3", Name: "name3", Result: "content3-1content3-2", - Extra: map[string]interface{}{}, }, - StreamMeta: &schema.StreamMeta{Index: 2}, }, }, }, - }, result) + } + + expectedStr, err := sonic.MarshalString(expected) + assert.NoError(t, err) + + assert.Equal(t, expectedStr, actualStr) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 2139201ec..b2225b2c7 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -20,15 +20,15 @@ import ( "context" "fmt" "reflect" + "sort" "strings" - "github.com/cloudwego/eino/schema/claude" - "github.com/cloudwego/eino/schema/gemini" + "github.com/eino-contrib/jsonschema" "github.com/cloudwego/eino/internal" + "github.com/cloudwego/eino/schema/claude" + "github.com/cloudwego/eino/schema/gemini" "github.com/cloudwego/eino/schema/openai" - - "github.com/eino-contrib/jsonschema" ) type ContentBlockType string @@ -82,9 +82,9 @@ type AgenticResponseMeta struct { Extension any } -type StreamMeta struct { +type StreamingMeta struct { // Index 
specifies the index position of this block in the final response. - Index int64 + Index int } type ContentBlock struct { @@ -123,14 +123,12 @@ type ContentBlock struct { // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. MCPToolApprovalResponse *MCPToolApprovalResponse - StreamMeta *StreamMeta + StreamingMeta *StreamingMeta + Extra map[string]any } type UserInputText struct { Text string - - // Extra stores additional information. - Extra map[string]any } type UserInputImage struct { @@ -138,27 +136,18 @@ type UserInputImage struct { Base64Data string MIMEType string Detail ImageURLDetail - - // Extra stores additional information. - Extra map[string]any } type UserInputAudio struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type UserInputVideo struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type UserInputFile struct { @@ -166,9 +155,6 @@ type UserInputFile struct { Name string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenText struct { @@ -177,51 +163,37 @@ type AssistantGenText struct { OpenAIExtension *openai.AssistantGenTextExtension ClaudeExtension *claude.AssistantGenTextExtension Extension any - - // Extra stores additional information. - Extra map[string]any } type AssistantGenImage struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenAudio struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type AssistantGenVideo struct { URL string Base64Data string MIMEType string - - // Extra stores additional information. - Extra map[string]any } type Reasoning struct { // Summary is the reasoning content summary. 
Summary []*ReasoningSummary + // EncryptedContent is the encrypted reasoning content. EncryptedContent string - - // Extra stores additional information. - Extra map[string]any } type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. - Index int64 + Index int Text string } @@ -229,39 +201,37 @@ type ReasoningSummary struct { type FunctionToolCall struct { // CallID is the unique identifier for the tool call. CallID string + // Name specifies the function tool invoked. Name string + // Arguments is the JSON string arguments for the function tool call. Arguments string - - // Extra stores additional information - Extra map[string]any } type FunctionToolResult struct { // CallID is the unique identifier for the tool call. CallID string + // Name specifies the function tool invoked. Name string + // Result is the function tool result returned by the user Result string - - // Extra stores additional information. - Extra map[string]any } type ServerToolCall struct { // Name specifies the server-side tool invoked. // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). Name string + // CallID is the unique identifier for the tool call. // Empty if not provided by the model server. CallID string + // Arguments are the raw inputs to the server-side tool, // supplied by the component implementer. Arguments any - // Extra stores additional information. - Extra map[string]any } type ServerToolResult struct { @@ -276,41 +246,40 @@ type ServerToolResult struct { // Result refers to the raw output generated by the server-side tool, // supplied by the component implementer. Result any - - // Extra stores additional information. - Extra map[string]any } type MCPToolCall struct { // ServerLabel is the MCP server label used to identify it in tool calls ServerLabel string - // ApprovalRequestID is the unique ID of the approval request. + + // ApprovalRequestID is the approval request ID. 
ApprovalRequestID string + // CallID is the unique ID of the tool call. CallID string + // Name is the name of the tool to run. Name string + // Arguments is the JSON string arguments for the tool call. Arguments string - - // Extra stores additional information. - Extra map[string]any } type MCPToolResult struct { // ServerLabel is the MCP server label used to identify it in tool calls ServerLabel string + // CallID is the unique ID of the tool call. CallID string + // Name is the name of the tool to run. Name string + // Result is the JSON string with the tool result. Result string + // Error returned when the server fails to run the tool. Error *MCPToolCallError - - // Extra stores additional information. - Extra map[string]any } type MCPToolCallError struct { @@ -321,49 +290,49 @@ type MCPToolCallError struct { type MCPListToolsResult struct { // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string + // Tools is the list of tools available on the server. Tools []*MCPListToolsItem + // Error returned when the server fails to list tools. Error string - - // Extra stores additional information. - Extra map[string]any } type MCPListToolsItem struct { // Name is the name of the tool. Name string + // Description is the description of the tool. Description string - // InputSchema is the JSON schema that describes the tool input. + + // InputSchema is the JSON schema that describes the tool input parameters. InputSchema *jsonschema.Schema } type MCPToolApprovalRequest struct { // ID is the approval request ID. ID string + // Name is the name of the tool to run. Name string + // Arguments is the JSON string arguments for the tool call. Arguments string + // ServerLabel is the MCP server label used to identify it in tool calls. ServerLabel string - - // Extra stores additional information. - Extra map[string]any } type MCPToolApprovalResponse struct { // ApprovalRequestID is the approval request ID being responded to. 
ApprovalRequestID string + // Approve indicates whether the request is approved. Approve bool + // Reason is the rationale for the decision. // Optional. Reason string - - // Extra stores additional information. - Extra map[string]any } // DeveloperAgenticMessage represents a message with AgenticRoleType "developer". @@ -404,8 +373,33 @@ func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessa } } -func NewContentBlock(block any) *ContentBlock { - switch b := block.(type) { +type contentBlockVariant interface { + Reasoning | userInputVariant | assistantGenVariant | functionToolCallVariant | serverToolCallVariant | mcpToolCallVariant +} + +type userInputVariant interface { + UserInputText | UserInputImage | UserInputAudio | UserInputVideo | UserInputFile +} + +type assistantGenVariant interface { + AssistantGenText | AssistantGenImage | AssistantGenAudio | AssistantGenVideo +} + +type functionToolCallVariant interface { + FunctionToolCall | FunctionToolResult +} + +type serverToolCallVariant interface { + ServerToolCall | ServerToolResult +} + +type mcpToolCallVariant interface { + MCPToolCall | MCPToolResult | MCPListToolsResult | MCPToolApprovalRequest | MCPToolApprovalResponse +} + +// NewContentBlock creates a new ContentBlock with the given content. +func NewContentBlock[T contentBlockVariant](content *T) *ContentBlock { + switch b := any(content).(type) { case *Reasoning: return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} case *UserInputText: @@ -449,6 +443,13 @@ func NewContentBlock(block any) *ContentBlock { } } +// NewContentBlockChunk creates a new ContentBlock with the given content and streaming metadata. +func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta) *ContentBlock { + block := NewContentBlock(content) + block.StreamingMeta = meta + return block +} + // AgenticMessagesTemplate is the interface for messages template. 
// It's used to render a template to a list of agentic messages. // e.g. @@ -683,16 +684,19 @@ func formatUserInputFile(uif *UserInputFile, vs map[string]any, formatType Forma return &copied, nil } +// ConcatAgenticMessagesArray concatenates multiple streams of AgenticMessage into a single slice of AgenticMessage. func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, error) { return buildConcatGenericArray[AgenticMessage](ConcatAgenticMessages)(mas) } +// ConcatAgenticMessages concatenates a list of AgenticMessage chunks into a single AgenticMessage. func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { var ( - role AgenticRoleType - blocksList [][]*ContentBlock - blocks []*ContentBlock - metas []*AgenticResponseMeta + role AgenticRoleType + blocks []*ContentBlock + metas []*AgenticResponseMeta + blockIndices []int + indexToBlocks = map[int][]*ContentBlock{} ) if len(msgs) == 1 { @@ -713,9 +717,12 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { } for _, block := range msg.ContentBlocks { - if block.StreamMeta == nil { + if block == nil { + continue + } + if block.StreamingMeta == nil { // Non-streaming block - if len(blocksList) > 0 { + if len(blockIndices) > 0 { // Cannot mix streaming and non-streaming blocks return nil, fmt.Errorf("found non-streaming block after streaming blocks") } @@ -728,8 +735,12 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { return nil, fmt.Errorf("found streaming block after non-streaming blocks") } // Collect streaming block by index - blocksList = expandSlice(int(block.StreamMeta.Index), blocksList) - blocksList[block.StreamMeta.Index] = append(blocksList[block.StreamMeta.Index], block) + if blocks_, ok := indexToBlocks[block.StreamingMeta.Index]; ok { + indexToBlocks[block.StreamingMeta.Index] = append(blocks_, block) + } else { + blockIndices = append(blockIndices, block.StreamingMeta.Index) + 
indexToBlocks[block.StreamingMeta.Index] = []*ContentBlock{block} + } } } @@ -743,219 +754,254 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) } - if len(blocksList) > 0 { + if len(blockIndices) > 0 { // All blocks are streaming, concat each group by index - blocks = make([]*ContentBlock, len(blocksList)) - for i, bs := range blocksList { - if len(bs) == 0 { - continue - } - b, err := concatAgenticContentBlocks(bs) + indexToBlock := map[int]*ContentBlock{} + for idx, bs := range indexToBlocks { + b, err := concatChunksOfSameContentBlock(bs) if err != nil { - return nil, fmt.Errorf("failed to concat content blocks at index %d: %w", i, err) + return nil, err } - blocks[i] = b + indexToBlock[idx] = b } - } - - for i := 0; i < len(blocks); i++ { - if blocks[i] == nil { - blocks = append(blocks[:i], blocks[i+1:]...) + blocks = make([]*ContentBlock, 0, len(blockIndices)) + sort.Slice(blockIndices, func(i, j int) bool { + return blockIndices[i] < blockIndices[j] + }) + for _, idx := range blockIndices { + blocks = append(blocks, indexToBlock[idx]) } } return &AgenticMessage{ - ResponseMeta: meta, Role: role, + ResponseMeta: meta, ContentBlocks: blocks, }, nil } -func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (*AgenticResponseMeta, error) { +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (ret *AgenticResponseMeta, err error) { if len(metas) == 0 { return nil, nil } - ret := &AgenticResponseMeta{ - TokenUsage: &TokenUsage{}, - OpenAIExtension: nil, - ClaudeExtension: nil, - GeminiExtension: nil, - Extension: nil, - } + + openaiExtensions := make([]*openai.ResponseMetaExtension, 0, len(metas)) + claudeExtensions := make([]*claude.ResponseMetaExtension, 0, len(metas)) + geminiExtensions := make([]*gemini.ResponseMetaExtension, 0, len(metas)) + tokenUsages := make([]*TokenUsage, 0, len(metas)) + + var ( + extType reflect.Type + extensions 
reflect.Value + ) + for _, meta := range metas { - ret.Extension = meta.Extension - ret.OpenAIExtension = meta.OpenAIExtension - ret.ClaudeExtension = meta.ClaudeExtension - ret.GeminiExtension = meta.GeminiExtension if meta.TokenUsage != nil { - ret.TokenUsage.CompletionTokens += meta.TokenUsage.CompletionTokens - ret.TokenUsage.CompletionTokenDetails.ReasoningTokens += meta.TokenUsage.CompletionTokenDetails.ReasoningTokens - ret.TokenUsage.PromptTokens += meta.TokenUsage.PromptTokens - ret.TokenUsage.PromptTokenDetails.CachedTokens += meta.TokenUsage.PromptTokenDetails.CachedTokens - ret.TokenUsage.TotalTokens += meta.TokenUsage.TotalTokens + tokenUsages = append(tokenUsages, meta.TokenUsage) + } + + var isConsistent bool + + if meta.Extension != nil { + extType, isConsistent = validateExtensionType(extType, meta.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(metas)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(meta.Extension)) + } + + if meta.OpenAIExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.OpenAIExtension)) + } + openaiExtensions = append(openaiExtensions, meta.OpenAIExtension) + } + + if meta.ClaudeExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.ClaudeExtension)) + } + claudeExtensions = append(claudeExtensions, meta.ClaudeExtension) + } + + if meta.GeminiExtension != nil { + extType, isConsistent = validateExtensionType(extType, 
meta.GeminiExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.GeminiExtension)) + } + geminiExtensions = append(geminiExtensions, meta.GeminiExtension) + } + } + + ret = &AgenticResponseMeta{ + TokenUsage: concatTokenUsage(tokenUsages), + } + + if extensions.IsValid() && !extensions.IsZero() { + var extension reflect.Value + extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, fmt.Errorf("failed to concat extensions: %w", err) + } + ret.Extension = extension.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatResponseMetaExtensions(openaiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat openai extensions: %w", err) + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatResponseMetaExtensions(claudeExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat claude extensions: %w", err) + } + } + + if len(geminiExtensions) > 0 { + ret.GeminiExtension, err = gemini.ConcatResponseMetaExtensions(geminiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat gemini extensions: %w", err) } } + return ret, nil } -func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { +func concatTokenUsage(usages []*TokenUsage) *TokenUsage { + if len(usages) == 0 { + return nil + } + + ret := &TokenUsage{} + + for _, usage := range usages { + if usage == nil { + continue + } + ret.CompletionTokens += usage.CompletionTokens + ret.CompletionTokensDetails.ReasoningTokens += usage.CompletionTokensDetails.ReasoningTokens + ret.PromptTokens += usage.PromptTokens + ret.PromptTokenDetails.CachedTokens += usage.PromptTokenDetails.CachedTokens + ret.TotalTokens += usage.TotalTokens + } + + return ret +} + +func concatChunksOfSameContentBlock(blocks []*ContentBlock) (*ContentBlock, error) { if len(blocks) == 0 { 
return nil, fmt.Errorf("no content blocks to concat") } + blockType := blocks[0].Type - index := blocks[0].StreamMeta.Index + switch blockType { case ContentBlockTypeReasoning: - return concatContentBlockHelper(blocks, blockType, "reasoning", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *Reasoning { return b.Reasoning }, - concatReasoning, - func(r *Reasoning) *ContentBlock { - return &ContentBlock{Type: blockType, Reasoning: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatReasoning) case ContentBlockTypeUserInputText: - return concatContentBlockHelper(blocks, blockType, "user input text", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputText { return b.UserInputText }, - concatUserInputText, - func(t *UserInputText) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputText: t, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputTexts) case ContentBlockTypeUserInputImage: - return concatContentBlockHelper(blocks, blockType, "user input image", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, - concatUserInputImage, - func(i *UserInputImage) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputImage: i, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputImages) case ContentBlockTypeUserInputAudio: - return concatContentBlockHelper(blocks, blockType, "user input audio", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, - concatUserInputAudio, - func(a *UserInputAudio) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputAudio: a, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputAudios) case ContentBlockTypeUserInputVideo: - return concatContentBlockHelper(blocks, blockType, "user input video", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputVideo { return 
b.UserInputVideo }, - concatUserInputVideo, - func(v *UserInputVideo) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputVideo: v, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputVideos) case ContentBlockTypeUserInputFile: - return concatContentBlockHelper(blocks, blockType, "user input file", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, - concatUserInputFile, - func(f *UserInputFile) *ContentBlock { - return &ContentBlock{Type: blockType, UserInputFile: f, StreamMeta: &StreamMeta{Index: index}} - }) + concatUserInputFiles) case ContentBlockTypeAssistantGenText: - return concatContentBlockHelper(blocks, blockType, "assistant gen text", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, - concatAssistantGenText, - func(t *AssistantGenText) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenText: t, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenTexts) case ContentBlockTypeAssistantGenImage: - return concatContentBlockHelper(blocks, blockType, "assistant gen image", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, - concatAssistantGenImage, - func(i *AssistantGenImage) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenImage: i, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenImages) case ContentBlockTypeAssistantGenAudio: - return concatContentBlockHelper(blocks, blockType, "assistant gen audio", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, - concatAssistantGenAudio, - func(a *AssistantGenAudio) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenAudio: a, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenAudios) case 
ContentBlockTypeAssistantGenVideo: - return concatContentBlockHelper(blocks, blockType, "assistant gen video", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, - concatAssistantGenVideo, - func(v *AssistantGenVideo) *ContentBlock { - return &ContentBlock{Type: blockType, AssistantGenVideo: v, StreamMeta: &StreamMeta{Index: index}} - }) + concatAssistantGenVideos) case ContentBlockTypeFunctionToolCall: - return concatContentBlockHelper(blocks, blockType, "function tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, - concatFunctionToolCall, - func(c *FunctionToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, FunctionToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatFunctionToolCalls) case ContentBlockTypeFunctionToolResult: - return concatContentBlockHelper(blocks, blockType, "function tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, - concatFunctionToolResult, - func(r *FunctionToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, FunctionToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatFunctionToolResults) case ContentBlockTypeServerToolCall: - return concatContentBlockHelper(blocks, blockType, "server tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, - concatServerToolCall, - func(c *ServerToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, ServerToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatServerToolCalls) case ContentBlockTypeServerToolResult: - return concatContentBlockHelper(blocks, blockType, "server tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult 
}, - concatServerToolResult, - func(r *ServerToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, ServerToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatServerToolResults) case ContentBlockTypeMCPToolCall: - return concatContentBlockHelper(blocks, blockType, "MCP tool call", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, - concatMCPToolCall, - func(c *MCPToolCall) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolCall: c, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolCalls) case ContentBlockTypeMCPToolResult: - return concatContentBlockHelper(blocks, blockType, "MCP tool result", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, - concatMCPToolResult, - func(r *MCPToolResult) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolResults) case ContentBlockTypeMCPListToolsResult: - return concatContentBlockHelper(blocks, blockType, "MCP list tools", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, - concatMCPListToolsResult, - func(r *MCPListToolsResult) *ContentBlock { - return &ContentBlock{Type: blockType, MCPListToolsResult: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPListToolsResults) case ContentBlockTypeMCPToolApprovalRequest: - return concatContentBlockHelper(blocks, blockType, "MCP tool approval request", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, - concatMCPToolApprovalRequest, - func(r *MCPToolApprovalRequest) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolApprovalRequest: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolApprovalRequests) case 
ContentBlockTypeMCPToolApprovalResponse: - return concatContentBlockHelper(blocks, blockType, "MCP tool approval response", + return concatContentBlockHelper(blocks, blockType, func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, - concatMCPToolApprovalResponse, - func(r *MCPToolApprovalResponse) *ContentBlock { - return &ContentBlock{Type: blockType, MCPToolApprovalResponse: r, StreamMeta: &StreamMeta{Index: index}} - }) + concatMCPToolApprovalResponses) default: return nil, fmt.Errorf("unknown content block type: %s", blockType) @@ -964,21 +1010,19 @@ func concatAgenticContentBlocks(blocks []*ContentBlock) (*ContentBlock, error) { // concatContentBlockHelper is a generic helper function that reduces code duplication // for concatenating content blocks of a specific type. -func concatContentBlockHelper[T any]( +func concatContentBlockHelper[T contentBlockVariant]( blocks []*ContentBlock, expectedType ContentBlockType, - typeName string, getter func(*ContentBlock) *T, concatFunc func([]*T) (*T, error), - constructor func(*T) *ContentBlock, ) (*ContentBlock, error) { items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { if block.Type != expectedType { - return nil, fmt.Errorf("expected %s block, got %s", typeName, block.Type) + return nil, fmt.Errorf("content block type mismatch: expected '%s', but got '%s'", expectedType, block.Type) } item := getter(block) if item == nil { - return nil, fmt.Errorf("%s content is nil", typeName) + return nil, fmt.Errorf("'%s' content is nil", expectedType) } return item, nil }) @@ -988,10 +1032,28 @@ func concatContentBlockHelper[T any]( concatenated, err := concatFunc(items) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to concat '%s' content blocks: %w", expectedType, err) + } + + extras := make([]map[string]any, 0, len(blocks)) + for _, block := range blocks { + if len(block.Extra) > 0 { + extras = append(extras, block.Extra) + } + } + + 
var extra map[string]any + if len(extras) > 0 { + extra, err = internal.ConcatItems(extras) + if err != nil { + return nil, fmt.Errorf("failed to concat content block extras: %w", err) + } } - return constructor(concatenated), nil + block := NewContentBlock(concatenated) + block.Extra = extra + + return block, nil } func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { @@ -1006,43 +1068,14 @@ func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter return ret, nil } -// Concatenation strategies for different content block types: -// -// String concatenation (incremental streaming): -// - Reasoning: Summary texts are concatenated, grouped by Index if present -// - UserInputText: Text fields are concatenated -// - AssistantGenText: Text fields are concatenated, annotations/citations are merged -// - FunctionToolCall: Arguments (JSON strings) are concatenated incrementally -// - FunctionToolResult: Result strings are concatenated -// - ServerToolCall: Arguments are merged (last non-nil value for any type) -// - ServerToolResult: Results are merged using internal.ConcatItems -// - MCPToolCall: Arguments (JSON strings) are concatenated incrementally -// - MCPToolResult: Result strings are concatenated -// - MCPListToolsResult: Tools arrays are merged -// - MCPToolApprovalRequest: Arguments are concatenated -// -// Take last block (non-streaming content): -// - UserInputImage, UserInputAudio, UserInputVideo, UserInputFile: Return last block -// - AssistantGenImage, AssistantGenAudio, AssistantGenVideo: Return last block -// - MCPToolApprovalResponse: Return last block -// - func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { if len(reasons) == 0 { return nil, fmt.Errorf("no reasoning found") } - if len(reasons) == 1 { - return reasons[0], nil - } - ret := &Reasoning{ - Summary: make([]*ReasoningSummary, 0), - EncryptedContent: "", - Extra: 
make(map[string]any), - } + ret := &Reasoning{} - // Collect all summaries from all reasons - allSummaries := make([]*ReasoningSummary, 0) + var allSummaries []*ReasoningSummary for _, r := range reasons { if r == nil { continue @@ -1051,157 +1084,269 @@ func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { if r.EncryptedContent != "" { ret.EncryptedContent += r.EncryptedContent } - for k, v := range r.Extra { - ret.Extra[k] = v - } } - // Group by Index and concatenate Text for same Index - // Use dynamic array that expands as needed - var summaryArray []*ReasoningSummary + var ( + indices []int + indexToSummary = map[int]*ReasoningSummary{} + ) + for _, s := range allSummaries { - idx := s.Index - // Expand array if needed - summaryArray = expandSlice(int(idx), summaryArray) - if summaryArray[idx] == nil { - // Create new entry with a copy of Index - summaryArray[idx] = &ReasoningSummary{ - Index: idx, - Text: s.Text, - } - } else { - // Concatenate text for same index - summaryArray[idx].Text += s.Text + if s == nil { + continue + } + if indexToSummary[s.Index] == nil { + indexToSummary[s.Index] = &ReasoningSummary{} + indices = append(indices, s.Index) } + indexToSummary[s.Index].Text += s.Text } - // Convert array to slice, filtering out nil entries - ret.Summary = make([]*ReasoningSummary, 0, len(summaryArray)) - for _, summary := range summaryArray { - if summary != nil { - ret.Summary = append(ret.Summary, summary) - } + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Summary = make([]*ReasoningSummary, 0, len(indices)) + for _, idx := range indices { + ret.Summary = append(ret.Summary, indexToSummary[idx]) } return ret, nil } -func concatUserInputText(texts []*UserInputText) (*UserInputText, error) { +func concatUserInputTexts(texts []*UserInputText) (*UserInputText, error) { if len(texts) == 0 { return nil, fmt.Errorf("no user input text found") } if len(texts) == 1 { return texts[0], nil } - - ret := 
&UserInputText{ - Text: "", - Extra: make(map[string]any), - } - - for _, t := range texts { - if t == nil { - continue - } - ret.Text += t.Text - for k, v := range t.Extra { - ret.Extra[k] = v - } - } - - return ret, nil + return nil, fmt.Errorf("cannot concat multiple user input texts") } -func concatUserInputImage(images []*UserInputImage) (*UserInputImage, error) { +func concatUserInputImages(images []*UserInputImage) (*UserInputImage, error) { if len(images) == 0 { return nil, fmt.Errorf("no user input image found") } - return images[len(images)-1], nil + if len(images) == 1 { + return images[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input images") } -func concatUserInputAudio(audios []*UserInputAudio) (*UserInputAudio, error) { +func concatUserInputAudios(audios []*UserInputAudio) (*UserInputAudio, error) { if len(audios) == 0 { return nil, fmt.Errorf("no user input audio found") } - return audios[len(audios)-1], nil + if len(audios) == 1 { + return audios[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input audios") } -func concatUserInputVideo(videos []*UserInputVideo) (*UserInputVideo, error) { +func concatUserInputVideos(videos []*UserInputVideo) (*UserInputVideo, error) { if len(videos) == 0 { return nil, fmt.Errorf("no user input video found") } - return videos[len(videos)-1], nil + if len(videos) == 1 { + return videos[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input videos") } -func concatUserInputFile(files []*UserInputFile) (*UserInputFile, error) { +func concatUserInputFiles(files []*UserInputFile) (*UserInputFile, error) { if len(files) == 0 { return nil, fmt.Errorf("no user input file found") } - return files[len(files)-1], nil + if len(files) == 1 { + return files[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input files") } -func concatAssistantGenText(texts []*AssistantGenText) (*AssistantGenText, error) { +func concatAssistantGenTexts(texts 
[]*AssistantGenText) (ret *AssistantGenText, err error) { if len(texts) == 0 { - return nil, fmt.Errorf("no assistant gen text found") + return nil, fmt.Errorf("no assistant generated text found") } if len(texts) == 1 { return texts[0], nil } - ret := &AssistantGenText{ - Text: "", - OpenAIExtension: nil, - ClaudeExtension: nil, - Extra: make(map[string]any), - } + ret = &AssistantGenText{} + + openaiExtensions := make([]*openai.AssistantGenTextExtension, 0, len(texts)) + claudeExtensions := make([]*claude.AssistantGenTextExtension, 0, len(texts)) + + var ( + extType reflect.Type + extensions reflect.Value + ) for _, t := range texts { if t == nil { continue } + ret.Text += t.Text + + var isConsistent bool + + if t.Extension != nil { + extType, isConsistent = validateExtensionType(extType, t.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(texts)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(t.Extension)) + } + if t.OpenAIExtension != nil { - if ret.OpenAIExtension == nil { - ret.OpenAIExtension = &openai.AssistantGenTextExtension{} + extType, isConsistent = validateExtensionType(extType, t.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.OpenAIExtension)) } - ret.OpenAIExtension.Annotations = append(ret.OpenAIExtension.Annotations, t.OpenAIExtension.Annotations...) 
+ openaiExtensions = append(openaiExtensions, t.OpenAIExtension) } + if t.ClaudeExtension != nil { - if ret.ClaudeExtension == nil { - ret.ClaudeExtension = &claude.AssistantGenTextExtension{} + extType, isConsistent = validateExtensionType(extType, t.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.ClaudeExtension)) } - ret.ClaudeExtension.Citations = append(ret.ClaudeExtension.Citations, t.ClaudeExtension.Citations...) + claudeExtensions = append(claudeExtensions, t.ClaudeExtension) } - for k, v := range t.Extra { - ret.Extra[k] = v + } + + if extensions.IsValid() && !extensions.IsZero() { + ret.Extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, err + } + ret.Extension = extensions.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatAssistantGenTextExtensions(openaiExtensions) + if err != nil { + return nil, err + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatAssistantGenTextExtensions(claudeExtensions) + if err != nil { + return nil, err } } return ret, nil } -func concatAssistantGenImage(images []*AssistantGenImage) (*AssistantGenImage, error) { +func concatAssistantGenImages(images []*AssistantGenImage) (*AssistantGenImage, error) { if len(images) == 0 { return nil, fmt.Errorf("no assistant gen image found") } - return images[len(images)-1], nil + if len(images) == 1 { + return images[0], nil + } + + ret := &AssistantGenImage{} + + for _, img := range images { + if img == nil { + continue + } + + ret.Base64Data += img.Base64Data + + if ret.URL == "" { + ret.URL = img.URL + } else if img.URL != "" && ret.URL != img.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated image chunks: '%s' vs '%s'", ret.URL, img.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = img.MIMEType + } else if img.MIMEType != "" && 
ret.MIMEType != img.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated image chunks: '%s' vs '%s'", ret.MIMEType, img.MIMEType) + } + } + + return ret, nil } -func concatAssistantGenAudio(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { +func concatAssistantGenAudios(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { if len(audios) == 0 { return nil, fmt.Errorf("no assistant gen audio found") } - return audios[len(audios)-1], nil + if len(audios) == 1 { + return audios[0], nil + } + + ret := &AssistantGenAudio{} + + for _, audio := range audios { + if audio == nil { + continue + } + + ret.Base64Data += audio.Base64Data + + if ret.URL == "" { + ret.URL = audio.URL + } else if audio.URL != "" && ret.URL != audio.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated audio chunks: '%s' vs '%s'", ret.URL, audio.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = audio.MIMEType + } else if audio.MIMEType != "" && ret.MIMEType != audio.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated audio chunks: '%s' vs '%s'", ret.MIMEType, audio.MIMEType) + } + } + + return ret, nil } -func concatAssistantGenVideo(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { +func concatAssistantGenVideos(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { if len(videos) == 0 { return nil, fmt.Errorf("no assistant gen video found") } - return videos[len(videos)-1], nil + if len(videos) == 1 { + return videos[0], nil + } + + ret := &AssistantGenVideo{} + + for _, video := range videos { + if video == nil { + continue + } + + ret.Base64Data += video.Base64Data + + if ret.URL == "" { + ret.URL = video.URL + } else if video.URL != "" && ret.URL != video.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated video chunks: '%s' vs '%s'", ret.URL, video.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = video.MIMEType + } else if video.MIMEType != "" && 
ret.MIMEType != video.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated video chunks: '%s' vs '%s'", ret.MIMEType, video.MIMEType) + } + } + + return ret, nil } -func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error) { +func concatFunctionToolCalls(calls []*FunctionToolCall) (*FunctionToolCall, error) { if len(calls) == 0 { return nil, fmt.Errorf("no function tool call found") } @@ -1209,31 +1354,32 @@ func concatFunctionToolCall(calls []*FunctionToolCall) (*FunctionToolCall, error return calls[0], nil } - // For tool calls, arguments are typically built incrementally during streaming - ret := &FunctionToolCall{ - Extra: make(map[string]any), - } + ret := &FunctionToolCall{} for _, c := range calls { if c == nil { continue } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool call, but got '%s'", ret.CallID, c.CallID) } + if ret.Name == "" { ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool call, but got '%s'", ret.Name, c.Name) } + ret.Arguments += c.Arguments - for k, v := range c.Extra { - ret.Extra[k] = v - } } return ret, nil } -func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResult, error) { +func concatFunctionToolResults(results []*FunctionToolResult) (*FunctionToolResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no function tool result found") } @@ -1241,30 +1387,32 @@ func concatFunctionToolResult(results []*FunctionToolResult) (*FunctionToolResul return results[0], nil } - ret := &FunctionToolResult{ - Extra: make(map[string]any), - } + ret := &FunctionToolResult{} for _, r := range results { if r == nil { continue } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call 
ID '%s' for function tool result, but got '%s'", ret.CallID, r.CallID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool result, but got '%s'", ret.Name, r.Name) } + ret.Result += r.Result - for k, v := range r.Extra { - ret.Extra[k] = v - } } return ret, nil } -func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { +func concatServerToolCalls(calls []*ServerToolCall) (ret *ServerToolCall, err error) { if len(calls) == 0 { return nil, fmt.Errorf("no server tool call found") } @@ -1272,33 +1420,54 @@ func concatServerToolCall(calls []*ServerToolCall) (*ServerToolCall, error) { return calls[0], nil } - // ServerToolCall Arguments is of type any; merge strategy uses the last non-nil value - ret := &ServerToolCall{ - Extra: make(map[string]any), - } + ret = &ServerToolCall{} + + var ( + argsType reflect.Type + argsChunks reflect.Value + ) for _, c := range calls { if c == nil { continue } - if ret.Name == "" { - ret.Name = c.Name - } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool call, but got '%s'", ret.Name, c.Name) } + if c.Arguments != nil { - ret.Arguments = c.Arguments + argsType_ := reflect.TypeOf(c.Arguments) + if argsType == nil { + argsType = argsType_ + argsChunks = reflect.MakeSlice(reflect.SliceOf(argsType), 0, len(calls)) + } else if argsType != argsType_ { + return nil, fmt.Errorf("expected type '%s' for server tool call arguments, but got '%s'", argsType, argsType_) + } + argsChunks = reflect.Append(argsChunks, reflect.ValueOf(c.Arguments)) } - for k, v := range c.Extra { - ret.Extra[k] = v + } + + if 
argsChunks.IsValid() && !argsChunks.IsZero() { + arguments, err := internal.ConcatSliceValue(argsChunks) + if err != nil { + return nil, err } + ret.Arguments = arguments.Interface() } return ret, nil } -func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, error) { +func concatServerToolResults(results []*ServerToolResult) (ret *ServerToolResult, err error) { if len(results) == 0 { return nil, fmt.Errorf("no server tool result found") } @@ -1306,45 +1475,54 @@ func concatServerToolResult(results []*ServerToolResult) (*ServerToolResult, err return results[0], nil } - // ServerToolResult Result is of type any; merge strategy uses the last non-nil value - ret := &ServerToolResult{ - Extra: make(map[string]any), - } + ret = &ServerToolResult{} + + var ( + resType reflect.Type + resChunks reflect.Value + ) - tZeroResult := reflect.TypeOf(results[0].Result) - data := reflect.MakeSlice(reflect.SliceOf(tZeroResult), 0, 0) for _, r := range results { if r == nil { continue } - if ret.Name == "" { - ret.Name = r.Name - } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool result, but got '%s'", ret.CallID, r.CallID) } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool result, but got '%s'", ret.Name, r.Name) + } + if r.Result != nil { - vResult := reflect.ValueOf(r.Result) - if tZeroResult != vResult.Type() { - return nil, fmt.Errorf("tool result types are different: %v %v", tZeroResult, vResult.Type()) + resType_ := reflect.TypeOf(r.Result) + if resType == nil { + resType = resType_ + resChunks = reflect.MakeSlice(reflect.SliceOf(resType), 0, len(results)) + } else if resType != resType_ { + return nil, fmt.Errorf("expected type '%s' for server tool result, but got '%s'", resType, resType_) } - data = reflect.Append(data, vResult) - 
} - for k, v := range r.Extra { - ret.Extra[k] = v + resChunks = reflect.Append(resChunks, reflect.ValueOf(r.Result)) } } - d, err := internal.ConcatSliceValue(data) - if err != nil { - return nil, fmt.Errorf("failed to concat server tool result: %v", err) + if resChunks.IsValid() && !resChunks.IsZero() { + result, err := internal.ConcatSliceValue(resChunks) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = result.Interface() } - ret.Result = d return ret, nil } -func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { +func concatMCPToolCalls(calls []*MCPToolCall) (*MCPToolCall, error) { if len(calls) == 0 { return nil, fmt.Errorf("no mcp tool call found") } @@ -1352,36 +1530,38 @@ func concatMCPToolCall(calls []*MCPToolCall) (*MCPToolCall, error) { return calls[0], nil } - ret := &MCPToolCall{ - Extra: make(map[string]any), - } + ret := &MCPToolCall{} for _, c := range calls { if c == nil { continue } + + ret.Arguments += c.Arguments + if ret.ServerLabel == "" { ret.ServerLabel = c.ServerLabel + } else if c.ServerLabel != "" && c.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool call, but got '%s'", ret.ServerLabel, c.ServerLabel) } - if ret.ApprovalRequestID == "" { - ret.ApprovalRequestID = c.ApprovalRequestID - } + if ret.CallID == "" { ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool call, but got '%s'", ret.CallID, c.CallID) } + if ret.Name == "" { ret.Name = c.Name - } - ret.Arguments += c.Arguments - for k, v := range c.Extra { - ret.Extra[k] = v + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool call, but got '%s'", ret.Name, c.Name) } } return ret, nil } -func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { +func concatMCPToolResults(results []*MCPToolResult) 
(*MCPToolResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no mcp tool result found") } @@ -1389,33 +1569,44 @@ func concatMCPToolResult(results []*MCPToolResult) (*MCPToolResult, error) { return results[0], nil } - ret := &MCPToolResult{ - Extra: make(map[string]any), - } + ret := &MCPToolResult{} for _, r := range results { if r == nil { continue } + + if r.Result != "" { + ret.Result = r.Result + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool result, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + if ret.CallID == "" { ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool result, but got '%s'", ret.CallID, r.CallID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool result, but got '%s'", ret.Name, r.Name) } - ret.Result += r.Result + if r.Error != nil { - ret.Error = r.Error // Use the last error - } - for k, v := range r.Extra { - ret.Extra[k] = v + ret.Error = r.Error } } return ret, nil } -func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResult, error) { +func concatMCPListToolsResults(results []*MCPListToolsResult) (*MCPListToolsResult, error) { if len(results) == 0 { return nil, fmt.Errorf("no mcp list tools result found") } @@ -1423,31 +1614,30 @@ func concatMCPListToolsResult(results []*MCPListToolsResult) (*MCPListToolsResul return results[0], nil } - ret := &MCPListToolsResult{ - Tools: make([]*MCPListToolsItem, 0), - Extra: make(map[string]any), - } + ret := &MCPListToolsResult{} for _, r := range results { if r == nil { continue } - if ret.ServerLabel == "" { - ret.ServerLabel = r.ServerLabel - } + ret.Tools = append(ret.Tools, r.Tools...) 
+ if r.Error != "" { - ret.Error = r.Error // Use the last error + ret.Error = r.Error } - for k, v := range r.Extra { - ret.Extra[k] = v + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp list tools result, but got '%s'", ret.ServerLabel, r.ServerLabel) } } return ret, nil } -func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { +func concatMCPToolApprovalRequests(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { if len(requests) == 0 { return nil, fmt.Errorf("no mcp tool approval request found") } @@ -1455,50 +1645,48 @@ func concatMCPToolApprovalRequest(requests []*MCPToolApprovalRequest) (*MCPToolA return requests[0], nil } - ret := &MCPToolApprovalRequest{ - Extra: make(map[string]any), - } + ret := &MCPToolApprovalRequest{} for _, r := range requests { if r == nil { continue } + + ret.Arguments += r.Arguments + if ret.ID == "" { ret.ID = r.ID + } else if r.ID != "" && r.ID != ret.ID { + return nil, fmt.Errorf("expected request ID '%s' for mcp tool approval request, but got '%s'", ret.ID, r.ID) } + if ret.Name == "" { ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool approval request, but got '%s'", ret.Name, r.Name) } - ret.Arguments += r.Arguments + if ret.ServerLabel == "" { ret.ServerLabel = r.ServerLabel - } - for k, v := range r.Extra { - ret.Extra[k] = v + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool approval request, but got '%s'", ret.ServerLabel, r.ServerLabel) } } return ret, nil } -func concatMCPToolApprovalResponse(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { +func concatMCPToolApprovalResponses(responses []*MCPToolApprovalResponse) 
(*MCPToolApprovalResponse, error) { if len(responses) == 0 { return nil, fmt.Errorf("no mcp tool approval response found") } if len(responses) == 1 { return responses[0], nil } - - return responses[len(responses)-1], nil -} - -func expandSlice[T any](idx int, s []T) []T { - if len(s) > idx { - return s - } - return append(s, make([]T, idx-len(s)+1)...) + return nil, fmt.Errorf("cannot concat multiple mcp tool approval responses") } +// String returns the string representation of AgenticMessage. func (m *AgenticMessage) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("role: %s\n", m.Role)) @@ -1520,6 +1708,7 @@ func (m *AgenticMessage) String() string { return sb.String() } +// String returns the string representation of ContentBlock. func (b *ContentBlock) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) @@ -1603,13 +1792,14 @@ func (b *ContentBlock) String() string { } } - if b.StreamMeta != nil { - sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamMeta.Index)) + if b.StreamingMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamingMeta.Index)) } return sb.String() } +// String returns the string representation of Reasoning. func (r *Reasoning) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) @@ -1622,22 +1812,27 @@ func (r *Reasoning) String() string { return sb.String() } +// String returns the string representation of UserInputText. func (u *UserInputText) String() string { return fmt.Sprintf(" text: %s\n", u.Text) } +// String returns the string representation of UserInputImage. func (u *UserInputImage) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) } +// String returns the string representation of UserInputAudio. 
func (u *UserInputAudio) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") } +// String returns the string representation of UserInputVideo. func (u *UserInputVideo) String() string { return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") } +// String returns the string representation of UserInputFile. func (u *UserInputFile) String() string { sb := &strings.Builder{} if u.Name != "" { @@ -1647,22 +1842,27 @@ func (u *UserInputFile) String() string { return sb.String() } +// String returns the string representation of AssistantGenText. func (a *AssistantGenText) String() string { return fmt.Sprintf(" text: %s\n", a.Text) } +// String returns the string representation of AssistantGenImage. func (a *AssistantGenImage) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of AssistantGenAudio. func (a *AssistantGenAudio) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of AssistantGenVideo. func (a *AssistantGenVideo) String() string { return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") } +// String returns the string representation of FunctionToolCall. func (f *FunctionToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) @@ -1671,6 +1871,7 @@ func (f *FunctionToolCall) String() string { return sb.String() } +// String returns the string representation of FunctionToolResult. func (f *FunctionToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) @@ -1679,6 +1880,7 @@ func (f *FunctionToolResult) String() string { return sb.String() } +// String returns the string representation of ServerToolCall. 
func (s *ServerToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) @@ -1689,6 +1891,7 @@ func (s *ServerToolCall) String() string { return sb.String() } +// String returns the string representation of ServerToolResult. func (s *ServerToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) @@ -1699,18 +1902,17 @@ func (s *ServerToolResult) String() string { return sb.String() } +// String returns the string representation of MCPToolCall. func (m *MCPToolCall) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) - if m.ApprovalRequestID != "" { - sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) - } return sb.String() } +// String returns the string representation of MCPToolResult. func (m *MCPToolResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) @@ -1722,6 +1924,7 @@ func (m *MCPToolResult) String() string { return sb.String() } +// String returns the string representation of MCPListToolsResult. func (m *MCPListToolsResult) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) @@ -1735,6 +1938,7 @@ func (m *MCPListToolsResult) String() string { return sb.String() } +// String returns the string representation of MCPToolApprovalRequest. func (m *MCPToolApprovalRequest) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) @@ -1744,6 +1948,7 @@ func (m *MCPToolApprovalRequest) String() string { return sb.String() } +// String returns the string representation of MCPToolApprovalResponse. 
func (m *MCPToolApprovalResponse) String() string { sb := &strings.Builder{} sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) @@ -1754,6 +1959,7 @@ func (m *MCPToolApprovalResponse) String() string { return sb.String() } +// String returns the string representation of AgenticResponseMeta. func (a *AgenticResponseMeta) String() string { sb := &strings.Builder{} sb.WriteString("response_meta:\n") @@ -1792,3 +1998,17 @@ func formatMediaString(url, base64Data string, mimeType string, detail any) stri } return sb.String() } + +func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, bool) { + if actual == nil { + return expected, true + } + actualType := reflect.TypeOf(actual) + if expected == nil { + return actualType, true + } + if expected != actualType { + return expected, false + } + return expected, true +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 0cafcd9ff..016aa5c4e 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -18,6 +18,7 @@ package schema import ( "context" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -75,7 +76,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Hello ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -87,7 +88,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "World!", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -112,7 +113,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 0, Text: "First "}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -126,7 +127,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 0, Text: "Second"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -137,7 +138,7 @@ func 
TestConcatAgenticMessages(t *testing.T) { assert.Len(t, result.ContentBlocks, 1) assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, int64(0), result.ContentBlocks[0].Reasoning.Summary[0].Index) + assert.Equal(t, 0, result.ContentBlocks[0].Reasoning.Summary[0].Index) }) t.Run("concat reasoning with index", func(t *testing.T) { @@ -153,7 +154,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 1, Text: "Part2-"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -168,7 +169,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Index: 1, Text: "Part4"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -185,26 +186,26 @@ func TestConcatAgenticMessages(t *testing.T) { t.Run("concat user input text", func(t *testing.T) { msgs := []*AgenticMessage{ { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputText, - UserInputText: &UserInputText{ + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ Text: "Hello ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputText, - UserInputText: &UserInputText{ + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ Text: "World!", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -213,35 +214,35 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Equal(t, "Hello World!", result.ContentBlocks[0].UserInputText.Text) + assert.Equal(t, "Hello World!", 
result.ContentBlocks[0].AssistantGenText.Text) }) - t.Run("concat user input image", func(t *testing.T) { - url1 := "https://example.com/image1.jpg" - url2 := "https://example.com/image2.jpg" + t.Run("concat assistant gen image", func(t *testing.T) { + base1 := "1" + base2 := "2" msgs := []*AgenticMessage{ { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputImage, - UserInputImage: &UserInputImage{ - URL: url1, + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, { - Role: AgenticRoleTypeUser, + Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeUserInputImage, - UserInputImage: &UserInputImage{ - URL: url2, + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -250,11 +251,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last image - assert.Equal(t, url2, result.ContentBlocks[0].UserInputImage.URL) + assert.Equal(t, "12", result.ContentBlocks[0].AssistantGenImage.Base64Data) }) - t.Run("concat user input audio", func(t *testing.T) { + t.Run("concat user input audio - should error", func(t *testing.T) { url1 := "https://example.com/audio1.mp3" url2 := "https://example.com/audio2.mp3" @@ -267,7 +267,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputAudio: &UserInputAudio{ URL: url1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -279,20 +279,18 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputAudio: &UserInputAudio{ URL: url2, }, - StreamMeta: &StreamMeta{Index: 0}, + 
StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last audio - assert.Equal(t, url2, result.ContentBlocks[0].UserInputAudio.URL) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input audios") }) - t.Run("concat user input video", func(t *testing.T) { + t.Run("concat user input video - should error", func(t *testing.T) { url1 := "https://example.com/video1.mp4" url2 := "https://example.com/video2.mp4" @@ -305,7 +303,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputVideo: &UserInputVideo{ URL: url1, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -317,17 +315,15 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputVideo: &UserInputVideo{ URL: url2, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last video - assert.Equal(t, url2, result.ContentBlocks[0].UserInputVideo.URL) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input videos") }) t.Run("concat assistant gen text", func(t *testing.T) { @@ -340,7 +336,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Generated ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -352,7 +348,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Text", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -365,9 +361,6 @@ func TestConcatAgenticMessages(t *testing.T) { }) t.Run("concat assistant gen image", func(t *testing.T) { - url1 := 
"https://example.com/gen_image1.jpg" - url2 := "https://example.com/gen_image2.jpg" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -375,9 +368,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{ - URL: url1, + Base64Data: "part1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -387,9 +380,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{ - URL: url2, + Base64Data: "part2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -398,14 +391,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last image - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenImage.URL) + assert.Equal(t, "part1part2", result.ContentBlocks[0].AssistantGenImage.Base64Data) }) t.Run("concat assistant gen audio", func(t *testing.T) { - url1 := "https://example.com/gen_audio1.mp3" - url2 := "https://example.com/gen_audio2.mp3" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -413,9 +402,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{ - URL: url1, + Base64Data: "audio1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -425,9 +414,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{ - URL: url2, + Base64Data: "audio2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -436,14 +425,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, 
result.ContentBlocks, 1) - // Should take the last audio - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenAudio.URL) + assert.Equal(t, "audio1audio2", result.ContentBlocks[0].AssistantGenAudio.Base64Data) }) t.Run("concat assistant gen video", func(t *testing.T) { - url1 := "https://example.com/gen_video1.mp4" - url2 := "https://example.com/gen_video2.mp4" - msgs := []*AgenticMessage{ { Role: AgenticRoleTypeAssistant, @@ -451,9 +436,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{ - URL: url1, + Base64Data: "video1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -463,9 +448,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{ - URL: url2, + Base64Data: "video2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -474,8 +459,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - // Should take the last video - assert.Equal(t, url2, result.ContentBlocks[0].AssistantGenVideo.URL) + assert.Equal(t, "video1video2", result.ContentBlocks[0].AssistantGenVideo.Base64Data) }) t.Run("concat function tool call", func(t *testing.T) { @@ -490,7 +474,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "get_weather", Arguments: `{"location`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -502,7 +486,7 @@ func TestConcatAgenticMessages(t *testing.T) { FunctionToolCall: &FunctionToolCall{ Arguments: `":"NYC"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -528,7 +512,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "get_weather", Result: `{"temp`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: 
&StreamingMeta{Index: 0}, }, }, }, @@ -540,7 +524,7 @@ func TestConcatAgenticMessages(t *testing.T) { FunctionToolResult: &FunctionToolResult{ Result: `":72}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -565,7 +549,7 @@ func TestConcatAgenticMessages(t *testing.T) { CallID: "server_call_1", Name: "server_func", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -577,7 +561,7 @@ func TestConcatAgenticMessages(t *testing.T) { ServerToolCall: &ServerToolCall{ Arguments: map[string]any{"key": "value"}, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -603,7 +587,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "server_func", Result: "result1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -611,11 +595,9 @@ func TestConcatAgenticMessages(t *testing.T) { Role: AgenticRoleTypeAssistant, ContentBlocks: []*ContentBlock{ { - Type: ContentBlockTypeServerToolResult, - ServerToolResult: &ServerToolResult{ - Result: "result2", - }, - StreamMeta: &StreamMeta{Index: 0}, + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -626,6 +608,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Len(t, result.ContentBlocks, 1) assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + assert.Equal(t, "result1", result.ContentBlocks[0].ServerToolResult.Result) }) t.Run("concat mcp tool call", func(t *testing.T) { @@ -641,7 +624,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "mcp_func", Arguments: `{"arg`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -653,7 +636,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolCall: &MCPToolCall{ 
Arguments: `":123}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -676,11 +659,12 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{ - CallID: "mcp_call_1", - Name: "mcp_func", - Result: `{"res`, + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `First`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -690,9 +674,9 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{ - Result: `ult":true}`, + Result: `Second`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -701,9 +685,10 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolResult.ServerLabel) assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) - assert.Equal(t, `{"result":true}`, result.ContentBlocks[0].MCPToolResult.Result) + assert.Equal(t, `Second`, result.ContentBlocks[0].MCPToolResult.Result) }) t.Run("concat mcp list tools", func(t *testing.T) { @@ -719,7 +704,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Name: "tool1"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -733,7 +718,7 @@ func TestConcatAgenticMessages(t *testing.T) { {Name: "tool2"}, }, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -759,7 +744,7 @@ func TestConcatAgenticMessages(t *testing.T) { ServerLabel: "mcp-server", Arguments: `{"request`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -771,7 +756,7 @@ func 
TestConcatAgenticMessages(t *testing.T) { MCPToolApprovalRequest: &MCPToolApprovalRequest{ Arguments: `":1}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -786,7 +771,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Equal(t, `{"request":1}`, result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) }) - t.Run("concat mcp tool approval response", func(t *testing.T) { + t.Run("concat mcp tool approval response - should error", func(t *testing.T) { response1 := &MCPToolApprovalResponse{ ApprovalRequestID: "approval_1", Approve: false, @@ -803,7 +788,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: response1, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -813,17 +798,15 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: response2, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - // Should take the last response - assert.Equal(t, response2, result.ContentBlocks[0].MCPToolApprovalResponse) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple mcp tool approval responses") }) t.Run("concat response meta", func(t *testing.T) { @@ -865,7 +848,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Hello", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -877,7 +860,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "World", }, - // No StreamMeta - non-streaming + // No StreamingMeta - non-streaming }, }, }, @@ -901,7 +884,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: 
"list_files", Arguments: `{"path`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -913,7 +896,7 @@ func TestConcatAgenticMessages(t *testing.T) { MCPToolCall: &MCPToolCall{ Arguments: `":"/tmp"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -927,7 +910,7 @@ func TestConcatAgenticMessages(t *testing.T) { assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) }) - t.Run("concat user input text", func(t *testing.T) { + t.Run("concat user input text - should error", func(t *testing.T) { msgs := []*AgenticMessage{ { Role: AgenticRoleTypeUser, @@ -937,7 +920,7 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputText: &UserInputText{ Text: "What is ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, @@ -949,16 +932,15 @@ func TestConcatAgenticMessages(t *testing.T) { UserInputText: &UserInputText{ Text: "the weather?", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, }, }, } - result, err := ConcatAgenticMessages(msgs) - assert.NoError(t, err) - assert.Len(t, result.ContentBlocks, 1) - assert.Equal(t, "What is the weather?", result.ContentBlocks[0].UserInputText.Text) + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input texts") }) t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { @@ -971,14 +953,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Index0-", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "Index2-", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -990,14 +972,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: 
&AssistantGenText{ Text: "Part2", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "Part2", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1020,7 +1002,7 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Text ", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolCall, @@ -1029,7 +1011,7 @@ func TestConcatAgenticMessages(t *testing.T) { Name: "func1", Arguments: `{"a`, }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, }, }, @@ -1041,14 +1023,14 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "Content", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: &FunctionToolCall{ Arguments: `":1}`, }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, }, }, @@ -1073,21 +1055,21 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "A", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "B", }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "C", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1099,21 +1081,21 @@ func TestConcatAgenticMessages(t *testing.T) { AssistantGenText: &AssistantGenText{ Text: "1", }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ 
Text: "2", }, - StreamMeta: &StreamMeta{Index: 1}, + StreamingMeta: &StreamingMeta{Index: 1}, }, { Type: ContentBlockTypeAssistantGenText, AssistantGenText: &AssistantGenText{ Text: "3", }, - StreamMeta: &StreamMeta{Index: 2}, + StreamingMeta: &StreamingMeta{Index: 2}, }, }, }, @@ -1276,7 +1258,7 @@ func TestAgenticMessageString(t *testing.T) { Name: "get_current_weather", Arguments: `{"location":"New York City","unit":"fahrenheit"}`, }, - StreamMeta: &StreamMeta{Index: 0}, + StreamingMeta: &StreamingMeta{Index: 0}, }, { Type: ContentBlockTypeFunctionToolResult, @@ -1289,11 +1271,10 @@ func TestAgenticMessageString(t *testing.T) { { Type: ContentBlockTypeMCPToolCall, MCPToolCall: &MCPToolCall{ - ServerLabel: "weather-mcp-server", - CallID: "mcp_forecast_456", - Name: "get_7day_forecast", - Arguments: `{"city":"New York","days":7}`, - ApprovalRequestID: "approval_req_789", + ServerLabel: "weather-mcp-server", + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, }, }, { @@ -1363,7 +1344,6 @@ content_blocks: call_id: mcp_forecast_456 name: get_7day_forecast arguments: {"city":"New York","days":7} - approval_request_id: approval_req_789 [7] type: mcp_tool_result call_id: mcp_forecast_456 name: get_7day_forecast @@ -1378,4 +1358,294 @@ content_blocks: response_meta: token_usage: prompt=250, completion=180, total=430 `, output) + + t.Run("full fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "http://audio.com", + Base64Data: "audio_data", + MIMEType: "audio/mp3", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "http://video.com", + Base64Data: "video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "http://file.com", + Name: "file.txt", + 
Base64Data: "file_data", + MIMEType: "text/plain", + }, + }, + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: "http://gen_image.com", + Base64Data: "gen_image_data", + MIMEType: "image/png", + }, + }, + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: "http://gen_audio.com", + Base64Data: "gen_audio_data", + MIMEType: "audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, + }, + } + + s := msg.String() + assert.Contains(t, s, "role: system") + assert.Contains(t, s, "type: user_input_audio") + assert.Contains(t, s, "http://audio.com") + assert.Contains(t, s, "type: user_input_video") + assert.Contains(t, s, "http://video.com") + assert.Contains(t, s, "type: user_input_file") + assert.Contains(t, s, "file.txt") + assert.Contains(t, s, "type: assistant_gen_image") + assert.Contains(t, s, "http://gen_image.com") + assert.Contains(t, s, "type: assistant_gen_audio") + assert.Contains(t, s, "http://gen_audio.com") + assert.Contains(t, s, "type: assistant_gen_video") + assert.Contains(t, s, "http://gen_video.com") 
+ assert.Contains(t, s, "type: server_tool_call") + assert.Contains(t, s, "server_tool") + assert.Contains(t, s, "map[a:1]") + assert.Contains(t, s, "type: server_tool_result") + assert.Contains(t, s, "map[success:true]") + assert.Contains(t, s, "type: mcp_tool_approval_request") + assert.Contains(t, s, "req_1") + assert.Contains(t, s, "type: mcp_tool_approval_response") + assert.Contains(t, s, "looks good") + }) + + t.Run("nil/empty fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: &UserInputAudio{}}, // empty + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: &UserInputVideo{}}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: &UserInputFile{}}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{}}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{}}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{}}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: &ServerToolCall{Name: "t"}}, // No CallID + {Type: ContentBlockTypeServerToolResult, ServerToolResult: &ServerToolResult{Name: "t"}}, // No CallID + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{Name: "t"}}, // No Error + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: &MCPListToolsResult{}}, // No Error + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: &MCPToolApprovalResponse{Approve: false}}, // No Reason + nil, // Nil block in slice + }, + } + + s := msg.String() + assert.Contains(t, s, "type: user_input_audio") + assert.NotContains(t, s, "mime_type:") + assert.Contains(t, s, "type: server_tool_call") + }) + + t.Run("nil content struct in block", func(t *testing.T) { + // Test cases where the specific content struct is nil but type is set + // This shouldn't crash and should just print type + msg := 
&AgenticMessage{ + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeReasoning, Reasoning: nil}, + {Type: ContentBlockTypeUserInputText, UserInputText: nil}, + {Type: ContentBlockTypeUserInputImage, UserInputImage: nil}, + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: nil}, + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: nil}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: nil}, + {Type: ContentBlockTypeAssistantGenText, AssistantGenText: nil}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: nil}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: nil}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: nil}, + {Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: nil}, + {Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: nil}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: nil}, + {Type: ContentBlockTypeServerToolResult, ServerToolResult: nil}, + {Type: ContentBlockTypeMCPToolCall, MCPToolCall: nil}, + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: nil}, + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: nil}, + {Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: nil}, + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: nil}, + }, + } + s := msg.String() + assert.Contains(t, s, "type: reasoning") + // ensure no panic and basic output present + }) +} + +func TestDeveloperAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := DeveloperAgenticMessage("developer") + assert.Equal(t, AgenticRoleTypeDeveloper, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "developer", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestSystemAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := SystemAgenticMessage("system") + assert.Equal(t, AgenticRoleTypeSystem, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "system", 
msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestUserAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := UserAgenticMessage("user") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "user", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestFunctionToolResultAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := FunctionToolResultAgenticMessage("call_1", "tool_name", "result_str") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, ContentBlockTypeFunctionToolResult, msg.ContentBlocks[0].Type) + assert.Equal(t, "call_1", msg.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "tool_name", msg.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, "result_str", msg.ContentBlocks[0].FunctionToolResult.Result) + }) +} + +func TestNewContentBlock(t *testing.T) { + cbType := reflect.TypeOf(ContentBlock{}) + for i := 0; i < cbType.NumField(); i++ { + field := cbType.Field(i) + + // Skip non-content fields + if field.Name == "Type" || field.Name == "Extra" || field.Name == "StreamingMeta" { + continue + } + + t.Run(field.Name, func(t *testing.T) { + // Ensure field is a pointer + assert.Equal(t, reflect.Ptr, field.Type.Kind(), "Field %s should be a pointer", field.Name) + + // Create a new instance of the field's type + // field.Type is *T, so Elem() is T. reflect.New(T) returns *T. 
+ elemType := field.Type.Elem() + inputVal := reflect.New(elemType) + input := inputVal.Interface() + + // Call NewContentBlock (generic) via type switch + var block *ContentBlock + switch v := input.(type) { + case *Reasoning: + block = NewContentBlock(v) + case *UserInputText: + block = NewContentBlock(v) + case *UserInputImage: + block = NewContentBlock(v) + case *UserInputAudio: + block = NewContentBlock(v) + case *UserInputVideo: + block = NewContentBlock(v) + case *UserInputFile: + block = NewContentBlock(v) + case *AssistantGenText: + block = NewContentBlock(v) + case *AssistantGenImage: + block = NewContentBlock(v) + case *AssistantGenAudio: + block = NewContentBlock(v) + case *AssistantGenVideo: + block = NewContentBlock(v) + case *FunctionToolCall: + block = NewContentBlock(v) + case *FunctionToolResult: + block = NewContentBlock(v) + case *ServerToolCall: + block = NewContentBlock(v) + case *ServerToolResult: + block = NewContentBlock(v) + case *MCPToolCall: + block = NewContentBlock(v) + case *MCPToolResult: + block = NewContentBlock(v) + case *MCPListToolsResult: + block = NewContentBlock(v) + case *MCPToolApprovalRequest: + block = NewContentBlock(v) + case *MCPToolApprovalResponse: + block = NewContentBlock(v) + default: + t.Fatalf("unsupported ContentBlock field type: %T", input) + } + + // Assertions + assert.NotNil(t, block, "NewContentBlock should return non-nil for type %T", input) + + // Check if the corresponding field in block is set equals to input + blockVal := reflect.ValueOf(block).Elem() + fieldVal := blockVal.FieldByName(field.Name) + assert.True(t, fieldVal.IsValid(), "Field %s not found in result", field.Name) + assert.Equal(t, input, fieldVal.Interface(), "Field %s should match input", field.Name) + + // Check Type is set + typeVal := blockVal.FieldByName("Type") + assert.NotEmpty(t, typeVal.String(), "Type should be set for %s", field.Name) + }) + } } diff --git a/schema/claude/consts.go b/schema/claude/consts.go index 
cbf8784f6..714b0362e 100644 --- a/schema/claude/consts.go +++ b/schema/claude/consts.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package claude defines constants for claude. package claude type TextCitationType string diff --git a/schema/claude/content_block.go b/schema/claude/extension.go similarity index 51% rename from schema/claude/content_block.go rename to schema/claude/extension.go index 0c43d1045..5df8d8907 100644 --- a/schema/claude/content_block.go +++ b/schema/claude/extension.go @@ -16,6 +16,15 @@ package claude +import ( + "fmt" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + StopReason string `json:"stop_reason,omitempty"` +} + type AssistantGenTextExtension struct { Citations []*TextCitation `json:"citations,omitempty"` } @@ -33,30 +42,30 @@ type CitationCharLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartCharIndex int64 `json:"start_char_index,omitempty"` - EndCharIndex int64 `json:"end_char_index,omitempty"` + StartCharIndex int `json:"start_char_index,omitempty"` + EndCharIndex int `json:"end_char_index,omitempty"` } type CitationPageLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` - StartPageNumber int64 `json:"start_page_number,omitempty"` - EndPageNumber int64 `json:"end_page_number,omitempty"` + StartPageNumber int `json:"start_page_number,omitempty"` + EndPageNumber int `json:"end_page_number,omitempty"` } type CitationContentBlockLocation struct { CitedText string `json:"cited_text,omitempty"` DocumentTitle string `json:"document_title,omitempty"` - DocumentIndex int64 `json:"document_index,omitempty"` + DocumentIndex int 
`json:"document_index,omitempty"` - StartBlockIndex int64 `json:"start_block_index,omitempty"` - EndBlockIndex int64 `json:"end_block_index,omitempty"` + StartBlockIndex int `json:"start_block_index,omitempty"` + EndBlockIndex int `json:"end_block_index,omitempty"` } type CitationWebSearchResultLocation struct { @@ -67,3 +76,46 @@ type CitationWebSearchResultLocation struct { EncryptedIndex string `json:"encrypted_index,omitempty"` } + +// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one. +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &AssistantGenTextExtension{ + Citations: make([]*TextCitation, 0, len(chunks)), + } + + for _, ext := range chunks { + ret.Citations = append(ret.Citations, ext.Citations...) + } + + return ret, nil +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.StopReason != "" { + ret.StopReason = ext.StopReason + } + } + + return ret, nil +} diff --git a/schema/claude/extension_test.go b/schema/claude/extension_test.go new file mode 100644 index 000000000..474fe740b --- /dev/null +++ b/schema/claude/extension_test.go @@ -0,0 +1,190 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("multiple extensions - concatenates all citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "citation 1", + DocumentIndex: 0, + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "citation 2", + StartPageNumber: 1, + EndPageNumber: 2, + }, + }, + { + Type: "web_search_result_location", + WebSearchResultLocation: &CitationWebSearchResultLocation{ + CitedText: "citation 3", + URL: "https://example.com", + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "content_block_location", + ContentBlockLocation: &CitationContentBlockLocation{ + CitedText: "citation 4", + StartBlockIndex: 0, + EndBlockIndex: 5, + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 4) + assert.Equal(t, "citation 1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "citation 2", result.Citations[1].PageLocation.CitedText) + assert.Equal(t, "citation 3", result.Citations[2].WebSearchResultLocation.CitedText) + assert.Equal(t, "citation 4", result.Citations[3].ContentBlockLocation.CitedText) + }) + + t.Run("mixed empty and non-empty citations", func(t *testing.T) { + exts := 
[]*AssistantGenTextExtension{ + {Citations: nil}, + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "text1", + }, + }, + }, + }, + {Citations: []*TextCitation{}}, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "text2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 2) + assert.Equal(t, "text1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "text2", result.Citations[1].PageLocation.CitedText) + }) + + t.Run("streaming scenario - citations arrive in chunks", func(t *testing.T) { + // Simulates streaming where citations arrive progressively + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk1"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk2"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk3"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 3) + assert.Equal(t, "chunk1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "chunk2", result.Citations[1].CharLocation.CitedText) + assert.Equal(t, "chunk3", result.Citations[2].CharLocation.CitedText) + }) +} + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + { + ID: "msg_1", + StopReason: "stop_1", + }, + { + ID: "msg_2", + StopReason: "", + }, + { + ID: "", + StopReason: "stop_3", + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_2", result.ID) // Last 
non-empty ID + assert.Equal(t, "stop_3", result.StopReason) // Last non-empty StopReason + }) + + t.Run("all empty fields", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "", StopReason: ""}, + {ID: "", StopReason: ""}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "", result.ID) + assert.Equal(t, "", result.StopReason) + }) + + t.Run("streaming scenario - ID in first chunk, StopReason in last", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "msg_stream_123", StopReason: ""}, + {ID: "", StopReason: ""}, + {ID: "", StopReason: "end_turn"}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_stream_123", result.ID) + assert.Equal(t, "end_turn", result.StopReason) + }) +} diff --git a/schema/claude/response_meta.go b/schema/claude/response_meta.go deleted file mode 100644 index 9f60dd713..000000000 --- a/schema/claude/response_meta.go +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package claude - -type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - StopReason string `json:"stop_reason,omitempty"` -} diff --git a/schema/gemini/response_meta.go b/schema/gemini/extension.go similarity index 76% rename from schema/gemini/response_meta.go rename to schema/gemini/extension.go index a5b3f626c..efbc4f4bd 100644 --- a/schema/gemini/response_meta.go +++ b/schema/gemini/extension.go @@ -14,8 +14,13 @@ * limitations under the License. */ +// Package gemini defines the extension for gemini. package gemini +import ( + "fmt" +) + type ResponseMetaExtension struct { ID string `json:"id,omitempty"` FinishReason string `json:"finish_reason,omitempty"` @@ -38,7 +43,7 @@ type GroundingChunk struct { Web *GroundingChunkWeb `json:"web,omitempty"` } -// Chunk from the web. +// GroundingChunkWeb is the chunk from the web. type GroundingChunkWeb struct { // Domain of the (original) URI. This field is not supported in Gemini API. Domain string `json:"domain,omitempty"` @@ -56,7 +61,7 @@ type GroundingSupport struct { // A list of indices (into 'grounding_chunk') specifying the citations associated with // the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], // grounding_chunk[4] are the retrieved content attributed to the claim. - GroundingChunkIndices []int32 `json:"grounding_chunk_indices,omitempty"` + GroundingChunkIndices []int `json:"grounding_chunk_indices,omitempty"` // Segment of the content this support belongs to. Segment *Segment `json:"segment,omitempty"` } @@ -65,20 +70,46 @@ type GroundingSupport struct { type Segment struct { // Output only. End index in the given Part, measured in bytes. Offset from the start // of the Part, exclusive, starting at zero. - EndIndex int32 `json:"end_index,omitempty"` + EndIndex int `json:"end_index,omitempty"` // Output only. The index of a Part object within its parent Content object. 
- PartIndex int32 `json:"part_index,omitempty"` + PartIndex int `json:"part_index,omitempty"` // Output only. Start index in the given Part, measured in bytes. Offset from the start // of the Part, inclusive, starting at zero. - StartIndex int32 `json:"start_index,omitempty"` + StartIndex int `json:"start_index,omitempty"` // Output only. The text corresponding to the segment from the response. Text string `json:"text,omitempty"` } -// Google search entry point. +// SearchEntryPoint is the Google search entry point. type SearchEntryPoint struct { // Optional. Web content snippet that can be embedded in a web page or an app webview. RenderedContent string `json:"rendered_content,omitempty"` // Optional. Base64 encoded JSON representing array of tuple. SDKBlob []byte `json:"sdk_blob,omitempty"` } + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.FinishReason != "" { + ret.FinishReason = ext.FinishReason + } + if ext.GroundingMeta != nil { + ret.GroundingMeta = ext.GroundingMeta + } + } + + return ret, nil +} diff --git a/schema/gemini/extension_test.go b/schema/gemini/extension_test.go new file mode 100644 index 000000000..56f390aa8 --- /dev/null +++ b/schema/gemini/extension_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gemini + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + meta1 := &GroundingMetadata{WebSearchQueries: []string{"query1"}} + meta2 := &GroundingMetadata{WebSearchQueries: []string{"query2"}} + + exts := []*ResponseMetaExtension{ + { + ID: "resp_1", + FinishReason: "STOP", + GroundingMeta: meta1, + }, + { + ID: "resp_2", + FinishReason: "", + GroundingMeta: nil, + }, + { + ID: "", + FinishReason: "MAX_TOKENS", + GroundingMeta: meta2, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "resp_2", result.ID) + assert.Equal(t, "MAX_TOKENS", result.FinishReason) + assert.Equal(t, meta2, result.GroundingMeta) + }) + + t.Run("streaming scenario", func(t *testing.T) { + meta := &GroundingMetadata{ + GroundingChunks: []*GroundingChunk{ + { + Web: &GroundingChunkWeb{ + Title: "Example", + URI: "https://example.com", + }, + }, + }, + } + + exts := []*ResponseMetaExtension{ + {ID: "stream_123", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "STOP", GroundingMeta: meta}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "stream_123", result.ID) + assert.Equal(t, "STOP", result.FinishReason) + assert.Equal(t, meta, result.GroundingMeta) + }) +} diff --git a/schema/message.go b/schema/message.go index bc1cc184d..e7514f305 
100644 --- a/schema/message.go +++ b/schema/message.go @@ -698,10 +698,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails `json:"prompt_token_details"` // CompletionTokens is the number of completion tokens. CompletionTokens int `json:"completion_tokens"` - // CompletionTokenDetails is a breakdown of the completion tokens. - CompletionTokenDetails CompletionTokensDetails `json:"completion_token_details"` // TotalTokens is the total number of tokens. TotalTokens int `json:"total_tokens"` + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { diff --git a/schema/openai/consts.go b/schema/openai/consts.go index 321ee2a9e..5958cef40 100644 --- a/schema/openai/consts.go +++ b/schema/openai/consts.go @@ -14,6 +14,7 @@ * limitations under the License. */ +// Package openai defines constants for openai. package openai type TextAnnotationType string @@ -24,3 +25,71 @@ const ( TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" TextAnnotationTypeFilePath TextAnnotationType = "file_path" ) + +type ReasoningEffort string + +const ( + ReasoningEffortMinimal ReasoningEffort = "minimal" + ReasoningEffortLow ReasoningEffort = "low" + ReasoningEffortMedium ReasoningEffort = "medium" + ReasoningEffortHigh ReasoningEffort = "high" +) + +type ReasoningSummary string + +const ( + ReasoningSummaryAuto ReasoningSummary = "auto" + ReasoningSummaryConcise ReasoningSummary = "concise" + ReasoningSummaryDetailed ReasoningSummary = "detailed" +) + +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierScale ServiceTier = "scale" + ServiceTierPriority ServiceTier = "priority" +) + +type PromptCacheRetention string + +const ( + PromptCacheRetentionInMemory PromptCacheRetention = "in-memory" + 
PromptCacheRetention24h PromptCacheRetention = "24h" +) + +type ResponseStatus string + +const ( + ResponseStatusCompleted ResponseStatus = "completed" + ResponseStatusFailed ResponseStatus = "failed" + ResponseStatusInProgress ResponseStatus = "in_progress" + ResponseStatusCancelled ResponseStatus = "cancelled" + ResponseStatusQueued ResponseStatus = "queued" + ResponseStatusIncomplete ResponseStatus = "incomplete" +) + +type ResponseErrorCode string + +const ( + ResponseErrorCodeServerError ResponseErrorCode = "server_error" + ResponseErrorCodeRateLimitExceeded ResponseErrorCode = "rate_limit_exceeded" + ResponseErrorCodeInvalidPrompt ResponseErrorCode = "invalid_prompt" + ResponseErrorCodeVectorStoreTimeout ResponseErrorCode = "vector_store_timeout" + ResponseErrorCodeInvalidImage ResponseErrorCode = "invalid_image" + ResponseErrorCodeInvalidImageFormat ResponseErrorCode = "invalid_image_format" + ResponseErrorCodeInvalidBase64Image ResponseErrorCode = "invalid_base64_image" + ResponseErrorCodeInvalidImageURL ResponseErrorCode = "invalid_image_url" + ResponseErrorCodeImageTooLarge ResponseErrorCode = "image_too_large" + ResponseErrorCodeImageTooSmall ResponseErrorCode = "image_too_small" + ResponseErrorCodeImageParseError ResponseErrorCode = "image_parse_error" + ResponseErrorCodeImageContentPolicyViolation ResponseErrorCode = "image_content_policy_violation" + ResponseErrorCodeInvalidImageMode ResponseErrorCode = "invalid_image_mode" + ResponseErrorCodeImageFileTooLarge ResponseErrorCode = "image_file_too_large" + ResponseErrorCodeUnsupportedImageMediaType ResponseErrorCode = "unsupported_image_media_type" + ResponseErrorCodeEmptyImageFile ResponseErrorCode = "empty_image_file" + ResponseErrorCodeFailedToDownloadImage ResponseErrorCode = "failed_to_download_image" + ResponseErrorCodeImageFileNotFound ResponseErrorCode = "image_file_not_found" +) diff --git a/schema/openai/content_block.go b/schema/openai/content_block.go deleted file mode 100644 index 
5d92be8f7..000000000 --- a/schema/openai/content_block.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package openai - -type AssistantGenTextExtension struct { - Annotations []*TextAnnotation `json:"annotations,omitempty"` -} - -type TextAnnotation struct { - Type TextAnnotationType `json:"type,omitempty"` - - FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` - URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` - ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` - FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` -} - -type TextAnnotationFileCitation struct { - // The ID of the file. - FileID string `json:"file_id,omitempty"` - // The filename of the file cited. - Filename string `json:"filename,omitempty"` - - // The index of the file in the list of files. - Index int64 `json:"index,omitempty"` -} - -type TextAnnotationURLCitation struct { - // The title of the web resource. - Title string `json:"title,omitempty"` - // The URL of the web resource. - URL string `json:"url,omitempty"` - - // The index of the first character of the URL citation in the message. - StartIndex int64 `json:"start_index,omitempty"` - // The index of the last character of the URL citation in the message. 
- EndIndex int64 `json:"end_index,omitempty"` -} - -type TextAnnotationContainerFileCitation struct { - // The ID of the container file. - ContainerID string `json:"container_id,omitempty"` - - // The ID of the file. - FileID string `json:"file_id,omitempty"` - // The filename of the container file cited. - Filename string `json:"filename,omitempty"` - - // The index of the first character of the container file citation in the message. - StartIndex int64 `json:"start_index,omitempty"` - // The index of the last character of the container file citation in the message. - EndIndex int64 `json:"end_index,omitempty"` -} - -type TextAnnotationFilePath struct { - // The ID of the file. - FileID string `json:"file_id,omitempty"` - - // The index of the file in the list of files. - Index int64 `json:"index,omitempty"` -} diff --git a/schema/openai/extension.go b/schema/openai/extension.go new file mode 100644 index 000000000..c30d2d8ec --- /dev/null +++ b/schema/openai/extension.go @@ -0,0 +1,206 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package openai + +import ( + "fmt" + "sort" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + Status ResponseStatus `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` + CreatedAt int64 `json:"created_at,omitempty"` + PromptCacheRetention PromptCacheRetention `json:"prompt_cache_retention,omitempty"` +} + +type AssistantGenTextExtension struct { + Refusal *OutputRefusal `json:"refusal,omitempty"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` +} + +type ResponseError struct { + Code ResponseErrorCode `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +type IncompleteDetails struct { + Reason string `json:"reason,omitempty"` +} + +type Reasoning struct { + Effort ReasoningEffort `json:"effort,omitempty"` + Summary ReasoningSummary `json:"summary,omitempty"` +} + +type OutputRefusal struct { + Reason string `json:"reason,omitempty"` +} + +type TextAnnotation struct { + Index int `json:"index,omitempty"` + + Type TextAnnotationType `json:"type,omitempty"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the file cited. + Filename string `json:"filename,omitempty"` + + // The index of the file in the list of files. 
+	Index int `json:"index,omitempty"`
+}
+
+type TextAnnotationURLCitation struct {
+	// The title of the web resource.
+	Title string `json:"title,omitempty"`
+	// The URL of the web resource.
+	URL string `json:"url,omitempty"`
+
+	// The index of the first character of the URL citation in the message.
+	StartIndex int `json:"start_index,omitempty"`
+	// The index of the last character of the URL citation in the message.
+	EndIndex int `json:"end_index,omitempty"`
+}
+
+type TextAnnotationContainerFileCitation struct {
+	// The ID of the container file.
+	ContainerID string `json:"container_id,omitempty"`
+
+	// The ID of the file.
+	FileID string `json:"file_id,omitempty"`
+	// The filename of the container file cited.
+	Filename string `json:"filename,omitempty"`
+
+	// The index of the first character of the container file citation in the message.
+	StartIndex int `json:"start_index,omitempty"`
+	// The index of the last character of the container file citation in the message.
+	EndIndex int `json:"end_index,omitempty"`
+}
+
+type TextAnnotationFilePath struct {
+	// The ID of the file.
+	FileID string `json:"file_id,omitempty"`
+
+	// The index of the file in the list of files.
+	Index int `json:"index,omitempty"`
+}
+
+// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one.
+//
+// Annotations from all chunks are ordered by their Index; nil annotations are
+// skipped, a repeated Index is reported as an error, and Index is reset to 0 on
+// the returned copies. Refusal reasons are joined in chunk order into a new
+// OutputRefusal, so none of the input chunks is modified.
+// It returns an error when chunks is empty.
+func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) {
+	if len(chunks) == 0 {
+		return nil, fmt.Errorf("no assistant generated text extension found")
+	}
+
+	ret := &AssistantGenTextExtension{}
+
+	var allAnnotations []*TextAnnotation
+	for _, ext := range chunks {
+		allAnnotations = append(allAnnotations, ext.Annotations...)
+	}
+
+	var (
+		indices           []int
+		indexToAnnotation = map[int]*TextAnnotation{}
+	)
+
+	for _, an := range allAnnotations {
+		if an == nil {
+			continue
+		}
+		if indexToAnnotation[an.Index] == nil {
+			indexToAnnotation[an.Index] = an
+			indices = append(indices, an.Index)
+		} else {
+			return nil, fmt.Errorf("duplicate annotation index %d", an.Index)
+		}
+	}
+
+	sort.Slice(indices, func(i, j int) bool {
+		return indices[i] < indices[j]
+	})
+
+	ret.Annotations = make([]*TextAnnotation, 0, len(indices))
+	for _, idx := range indices {
+		an := *indexToAnnotation[idx]
+		an.Index = 0 // clear index
+		ret.Annotations = append(ret.Annotations, &an)
+	}
+
+	for _, ext := range chunks {
+		if ext.Refusal == nil {
+			continue
+		}
+		if ret.Refusal == nil {
+			// Start from a copy: storing ext.Refusal itself would make the
+			// += below rewrite the Reason of the caller's first refusal chunk.
+			ret.Refusal = &OutputRefusal{Reason: ext.Refusal.Reason}
+		} else {
+			ret.Refusal.Reason += ext.Refusal.Reason
+		}
+	}
+
+	return ret, nil
+}
+
+// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one.
+//
+// A single chunk is returned as is. Otherwise every field takes the value of
+// the last chunk in which it is set (non-empty string, non-nil pointer,
+// non-zero CreatedAt). It returns an error when chunks is empty.
+func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) {
+	if len(chunks) == 0 {
+		return nil, fmt.Errorf("no response meta extension found")
+	}
+	if len(chunks) == 1 {
+		return chunks[0], nil
+	}
+
+	ret := &ResponseMetaExtension{}
+
+	for _, ext := range chunks {
+		if ext.ID != "" {
+			ret.ID = ext.ID
+		}
+		if ext.Status != "" {
+			ret.Status = ext.Status
+		}
+		if ext.Error != nil {
+			ret.Error = ext.Error
+		}
+		if ext.IncompleteDetails != nil {
+			ret.IncompleteDetails = ext.IncompleteDetails
+		}
+		if ext.PreviousResponseID != "" {
+			ret.PreviousResponseID = ext.PreviousResponseID
+		}
+		if ext.Reasoning != nil {
+			ret.Reasoning = ext.Reasoning
+		}
+		if ext.ServiceTier != "" {
+			ret.ServiceTier = ext.ServiceTier
+		}
+		// CreatedAt and PromptCacheRetention are merged like the other fields so
+		// the multi-chunk result keeps them, as the single-chunk path does.
+		if ext.CreatedAt != 0 {
+			ret.CreatedAt = ext.CreatedAt
+		}
+		if ext.PromptCacheRetention != "" {
+			ret.PromptCacheRetention = ext.PromptCacheRetention
+		}
+	}
+
+	return ret, nil
+}
diff --git a/schema/openai/extension_test.go b/schema/openai/extension_test.go
new file mode 100644
index 000000000..640982fdf
--- /dev/null
+++ b/schema/openai/extension_test.go
@@ -0,0 +1,193 @@
+/*
+ * Copyright 2025 CloudWeGo
Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + err1 := &ResponseError{Code: "err1", Message: "msg1"} + incomplete := &IncompleteDetails{Reason: "max_tokens"} + + exts := []*ResponseMetaExtension{ + { + ID: "id_1", + Status: "in_progress", + Error: err1, + IncompleteDetails: nil, + }, + { + ID: "id_2", + Status: "", + Error: nil, + IncompleteDetails: nil, + }, + { + ID: "", + Status: "completed", + Error: nil, + IncompleteDetails: incomplete, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "id_2", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + assert.Equal(t, err1, result.Error) + assert.Equal(t, incomplete, result.IncompleteDetails) + }) + + t.Run("streaming scenario", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "chatcmpl_stream", Status: "", Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("in_progress"), Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("completed"), Error: nil, IncompleteDetails: nil}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "chatcmpl_stream", result.ID) + assert.Equal(t, 
ResponseStatus("completed"), result.Status) + }) +} + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("single extension with annotations", func(t *testing.T) { + ext := &AssistantGenTextExtension{ + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_123", + Filename: "doc.pdf", + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext}) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 1) + assert.Equal(t, "file_123", result.Annotations[0].FileCitation.FileID) + }) + + t.Run("multiple extensions - merges annotations by index", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_1", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 2, + Type: "url_citation", + URLCitation: &TextAnnotationURLCitation{ + URL: "https://example.com", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 1, + Type: "file_path", + FilePath: &TextAnnotationFilePath{ + FileID: "file_2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "file_1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "file_2", result.Annotations[1].FilePath.FileID) + assert.Equal(t, "https://example.com", result.Annotations[2].URLCitation.URL) + }) + + t.Run("streaming scenario - annotations arrive in chunks", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "f1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 1, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "url1"}}, + }, + }, + { + 
Annotations: []*TextAnnotation{ + {Index: 2, Type: "file_path", FilePath: &TextAnnotationFilePath{FileID: "f2"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "f1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "url1", result.Annotations[1].URLCitation.URL) + assert.Equal(t, "f2", result.Annotations[2].FilePath.FileID) + }) + + t.Run("multiple extensions - concatenates refusal reason", func(t *testing.T) { + ext1 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "A"}} + ext2 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "B"}} + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext1, ext2}) + assert.NoError(t, err) + assert.NotNil(t, result.Refusal) + assert.Equal(t, "AB", result.Refusal.Reason) + }) + + t.Run("duplicate index - error occurrence", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "first"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "second"}}, + }, + }, + } + + _, err := ConcatAssistantGenTextExtensions(exts) + assert.Error(t, err) + }) +} diff --git a/schema/openai/response_meta.go b/schema/openai/response_meta.go deleted file mode 100644 index e1933065b..000000000 --- a/schema/openai/response_meta.go +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package openai - -type ResponseMetaExtension struct { - ID string `json:"id,omitempty"` - Status string `json:"status,omitempty"` - Error *ResponseError `json:"error,omitempty"` - StreamError *StreamResponseError `json:"stream_error,omitempty"` - IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` -} - -type ResponseError struct { - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -type StreamResponseError struct { - Code string - Message string - Param string -} - -type IncompleteDetails struct { - Reason string `json:"reason,omitempty"` -} diff --git a/schema/tool.go b/schema/tool.go index ccc93b6a3..2d6bf90db 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -59,6 +59,32 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +// AllowedTool identifies a tool by function tool name, as an MCP tool, or as a server tool. +type AllowedTool struct { + // FunctionToolName is the name of the function tool. + FunctionToolName string + + // MCPTool identifies an MCP tool. + MCPTool *AllowedMCPTool + + // ServerTool identifies a server tool. + ServerTool *AllowedServerTool +} + +// AllowedMCPTool identifies an MCP tool by its server label and tool name. +type AllowedMCPTool struct { + // ServerLabel is the label of the MCP server. + ServerLabel string + // Name is the name of the MCP tool. + Name string +} + +// AllowedServerTool identifies a server tool by name. +type AllowedServerTool struct { + // Name is the name of the server tool. + Name string +} + // ToolInfo describes a tool that can be passed to a ChatModel via // [ToolCallingChatModel.WithTools] or [ChatModel.BindTools]. // 
// From 0313a931285cb34b8dd0a8aa661f11df09160105 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 6 Jan 2026 16:48:56 +0800 Subject: [PATCH 17/28] fix: concat agentic messages (#604) --- components/model/callback_extra.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index 2767e2e5e..afff3f0a7 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -29,10 +29,10 @@ type TokenUsage struct { PromptTokenDetails PromptTokenDetails // CompletionTokens is the number of completion tokens. CompletionTokens int - // CompletionTokensDetails is a breakdown of the completion tokens. - CompletionTokensDetails CompletionTokensDetails // TotalTokens is the total number of tokens. TotalTokens int + // CompletionTokensDetails is a breakdown of the completion tokens. + CompletionTokensDetails CompletionTokensDetails } type CompletionTokensDetails struct { From 88fead762793e5feb5a1ab482ae2b530e9242b89 Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 8 Jan 2026 15:34:53 +0800 Subject: [PATCH 18/28] fix(schema): agentic concat support extra (#670) --- schema/agentic_message.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index b2225b2c7..2ba02b689 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -695,8 +695,10 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { role AgenticRoleType blocks []*ContentBlock metas []*AgenticResponseMeta + extra map[string]any blockIndices []int indexToBlocks = map[int][]*ContentBlock{} + extraList = make([]map[string]any, 0, len(msgs)) ) if len(msgs) == 1 { @@ -747,6 +749,10 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { if msg.ResponseMeta != nil { metas = append(metas, msg.ResponseMeta) } + + if msg.Extra != nil { + extraList = append(extraList, 
msg.Extra) + } } meta, err := concatAgenticResponseMeta(metas) @@ -758,7 +764,8 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { // All blocks are streaming, concat each group by index indexToBlock := map[int]*ContentBlock{} for idx, bs := range indexToBlocks { - b, err := concatChunksOfSameContentBlock(bs) + var b *ContentBlock + b, err = concatChunksOfSameContentBlock(bs) if err != nil { return nil, err } @@ -773,10 +780,18 @@ func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { } } + if len(extraList) > 0 { + extra, err = concatExtra(extraList) + if err != nil { + return nil, err + } + } + return &AgenticMessage{ Role: role, ResponseMeta: meta, ContentBlocks: blocks, + Extra: extra, }, nil } From 90e164419fd2361b312bf73468db42a228197c40 Mon Sep 17 00:00:00 2001 From: Megumin Date: Thu, 8 Jan 2026 19:36:17 +0800 Subject: [PATCH 19/28] feat(schema): optimize agent message format (#671) --- schema/agentic_message.go | 28 +++- schema/agentic_message_test.go | 270 +++++++++++++++++---------------- 2 files changed, 164 insertions(+), 134 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 2ba02b689..5008f0c75 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -18,6 +18,7 @@ package schema import ( "context" + "encoding/json" "fmt" "reflect" "sort" @@ -1834,7 +1835,7 @@ func (u *UserInputText) String() string { // String returns the string representation of UserInputImage. func (u *UserInputImage) String() string { - return formatMediaString(u.URL, u.Base64Data, u.MIMEType, u.Detail) + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, string(u.Detail)) } // String returns the string representation of UserInputAudio. 
@@ -1902,7 +1903,7 @@ func (s *ServerToolCall) String() string { if s.CallID != "" { sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) } - sb.WriteString(fmt.Sprintf(" arguments: %v\n", s.Arguments)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", printAny(s.Arguments))) return sb.String() } @@ -1913,7 +1914,7 @@ func (s *ServerToolResult) String() string { if s.CallID != "" { sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) } - sb.WriteString(fmt.Sprintf(" result: %v\n", s.Result)) + sb.WriteString(fmt.Sprintf(" result: %s\n", printAny(s.Result))) return sb.String() } @@ -1996,7 +1997,7 @@ func truncateString(s string, maxLen int) string { } // formatMediaString formats URL, Base64Data, MIMEType and Detail for media content -func formatMediaString(url, base64Data string, mimeType string, detail any) string { +func formatMediaString(url, base64Data string, mimeType string, detail string) string { sb := &strings.Builder{} if url != "" { sb.WriteString(fmt.Sprintf(" url: %s\n", truncateString(url, 100))) @@ -2008,8 +2009,8 @@ func formatMediaString(url, base64Data string, mimeType string, detail any) stri if mimeType != "" { sb.WriteString(fmt.Sprintf(" mime_type: %s\n", mimeType)) } - if detail != nil && detail != "" { - sb.WriteString(fmt.Sprintf(" detail: %v\n", detail)) + if detail != "" { + sb.WriteString(fmt.Sprintf(" detail: %s\n", detail)) } return sb.String() } @@ -2027,3 +2028,18 @@ func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, boo } return expected, true } + +func printAny(a any) string { + switch v := a.(type) { + case string: + return v + case fmt.Stringer: + return v.String() + default: + b, err := json.MarshalIndent(a, "", " ") + if err != nil { + return fmt.Sprintf("%v", a) + } + return string(b) + } +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 016aa5c4e..4beb74930 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -1234,6 +1234,61 @@ 
func TestAgenticMessageString(t *testing.T) { Detail: ImageURLDetailHigh, }, }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "http://audio.com", + Base64Data: "audio_data", + MIMEType: "audio/mp3", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "http://video.com", + Base64Data: "video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "http://file.com", + Name: "file.txt", + Base64Data: "file_data", + MIMEType: "text/plain", + }, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "I'll check the current weather in New York City for you.", + }, + }, + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: "http://gen_image.com", + Base64Data: "gen_image_data", + MIMEType: "image/png", + }, + }, + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: "http://gen_audio.com", + Base64Data: "gen_audio_data", + MIMEType: "audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ @@ -1245,12 +1300,6 @@ func TestAgenticMessageString(t *testing.T) { EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", }, }, - { - Type: ContentBlockTypeAssistantGenText, - AssistantGenText: &AssistantGenText{ - Text: "I'll check the current weather in New York City for you.", - }, - }, { Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: &FunctionToolCall{ @@ -1268,6 +1317,39 @@ func TestAgenticMessageString(t *testing.T) { Result: `{"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8}`, }, }, + { + Type: ContentBlockTypeServerToolCall, + 
ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, { Type: ContentBlockTypeMCPToolCall, MCPToolCall: &MCPToolCall{ @@ -1322,34 +1404,80 @@ content_blocks: base64_data: iVBORw0KGgoAAAANSUhE...... (96 bytes) mime_type: image/jpeg detail: high - [2] type: reasoning + [2] type: user_input_audio + url: http://audio.com + base64_data: audio_data... (10 bytes) + mime_type: audio/mp3 + [3] type: user_input_video + url: http://video.com + base64_data: video_data... (10 bytes) + mime_type: video/mp4 + [4] type: user_input_file + name: file.txt + url: http://file.com + base64_data: file_data... (9 bytes) + mime_type: text/plain + [5] type: assistant_gen_text + text: I'll check the current weather in New York City for you. + [6] type: assistant_gen_image + url: http://gen_image.com + base64_data: gen_image_data... (14 bytes) + mime_type: image/png + [7] type: assistant_gen_audio + url: http://gen_audio.com + base64_data: gen_audio_data... (14 bytes) + mime_type: audio/wav + [8] type: assistant_gen_video + url: http://gen_video.com + base64_data: gen_video_data... (14 bytes) + mime_type: video/mp4 + [9] type: reasoning summary: 3 items [0] First, I need to identify the location (New York City) from the user's query. [1] Then, I should call the weather API to get current conditions. 
[2] Finally, I'll format the response in a user-friendly way with temperature and conditions. encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... - [3] type: assistant_gen_text - text: I'll check the current weather in New York City for you. - [4] type: function_tool_call + [10] type: function_tool_call call_id: call_weather_123 name: get_current_weather arguments: {"location":"New York City","unit":"fahrenheit"} stream_index: 0 - [5] type: function_tool_result + [11] type: function_tool_result call_id: call_weather_123 name: get_current_weather result: {"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8} - [6] type: mcp_tool_call + [12] type: server_tool_call + name: server_tool + call_id: call_1 + arguments: { + "a": 1 +} + [13] type: server_tool_result + name: server_tool + call_id: call_1 + result: { + "success": true +} + [14] type: mcp_tool_approval_request + server_label: mcp_server + id: req_1 + name: mcp_tool + arguments: {} + [15] type: mcp_tool_approval_response + approval_request_id: req_1 + approve: true + reason: looks good + [16] type: mcp_tool_call server_label: weather-mcp-server call_id: mcp_forecast_456 name: get_7day_forecast arguments: {"city":"New York","days":7} - [7] type: mcp_tool_result + [17] type: mcp_tool_result call_id: mcp_forecast_456 name: get_7day_forecast result: {"status":"partial","days_available":3} error: [503] Service temporarily unavailable for full 7-day forecast - [8] type: mcp_list_tools_result + [18] type: mcp_list_tools_result server_label: weather-mcp-server tools: 3 items - get_current_weather: Get current weather conditions for a location @@ -1359,120 +1487,6 @@ response_meta: token_usage: prompt=250, completion=180, total=430 `, output) - t.Run("full fields", func(t *testing.T) { - msg := &AgenticMessage{ - Role: AgenticRoleTypeSystem, - ContentBlocks: []*ContentBlock{ - { - Type: ContentBlockTypeUserInputAudio, - UserInputAudio: &UserInputAudio{ - URL: "http://audio.com", - 
Base64Data: "audio_data", - MIMEType: "audio/mp3", - }, - }, - { - Type: ContentBlockTypeUserInputVideo, - UserInputVideo: &UserInputVideo{ - URL: "http://video.com", - Base64Data: "video_data", - MIMEType: "video/mp4", - }, - }, - { - Type: ContentBlockTypeUserInputFile, - UserInputFile: &UserInputFile{ - URL: "http://file.com", - Name: "file.txt", - Base64Data: "file_data", - MIMEType: "text/plain", - }, - }, - { - Type: ContentBlockTypeAssistantGenImage, - AssistantGenImage: &AssistantGenImage{ - URL: "http://gen_image.com", - Base64Data: "gen_image_data", - MIMEType: "image/png", - }, - }, - { - Type: ContentBlockTypeAssistantGenAudio, - AssistantGenAudio: &AssistantGenAudio{ - URL: "http://gen_audio.com", - Base64Data: "gen_audio_data", - MIMEType: "audio/wav", - }, - }, - { - Type: ContentBlockTypeAssistantGenVideo, - AssistantGenVideo: &AssistantGenVideo{ - URL: "http://gen_video.com", - Base64Data: "gen_video_data", - MIMEType: "video/mp4", - }, - }, - { - Type: ContentBlockTypeServerToolCall, - ServerToolCall: &ServerToolCall{ - Name: "server_tool", - CallID: "call_1", - Arguments: map[string]any{"a": 1}, - }, - }, - { - Type: ContentBlockTypeServerToolResult, - ServerToolResult: &ServerToolResult{ - Name: "server_tool", - CallID: "call_1", - Result: map[string]any{"success": true}, - }, - }, - { - Type: ContentBlockTypeMCPToolApprovalRequest, - MCPToolApprovalRequest: &MCPToolApprovalRequest{ - ID: "req_1", - Name: "mcp_tool", - ServerLabel: "mcp_server", - Arguments: "{}", - }, - }, - { - Type: ContentBlockTypeMCPToolApprovalResponse, - MCPToolApprovalResponse: &MCPToolApprovalResponse{ - ApprovalRequestID: "req_1", - Approve: true, - Reason: "looks good", - }, - }, - }, - } - - s := msg.String() - assert.Contains(t, s, "role: system") - assert.Contains(t, s, "type: user_input_audio") - assert.Contains(t, s, "http://audio.com") - assert.Contains(t, s, "type: user_input_video") - assert.Contains(t, s, "http://video.com") - assert.Contains(t, s, "type: 
user_input_file") - assert.Contains(t, s, "file.txt") - assert.Contains(t, s, "type: assistant_gen_image") - assert.Contains(t, s, "http://gen_image.com") - assert.Contains(t, s, "type: assistant_gen_audio") - assert.Contains(t, s, "http://gen_audio.com") - assert.Contains(t, s, "type: assistant_gen_video") - assert.Contains(t, s, "http://gen_video.com") - assert.Contains(t, s, "type: server_tool_call") - assert.Contains(t, s, "server_tool") - assert.Contains(t, s, "map[a:1]") - assert.Contains(t, s, "type: server_tool_result") - assert.Contains(t, s, "map[success:true]") - assert.Contains(t, s, "type: mcp_tool_approval_request") - assert.Contains(t, s, "req_1") - assert.Contains(t, s, "type: mcp_tool_approval_response") - assert.Contains(t, s, "looks good") - }) - t.Run("nil/empty fields", func(t *testing.T) { msg := &AgenticMessage{ Role: AgenticRoleTypeUser, From d513f301aa9459ca6e439ec9d380f92d14339a38 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 12 Jan 2026 12:08:31 +0800 Subject: [PATCH 20/28] fix: openai ConcatResponseMetaExtensions (#678) --- schema/openai/extension.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/schema/openai/extension.go b/schema/openai/extension.go index c30d2d8ec..1e10c411e 100644 --- a/schema/openai/extension.go +++ b/schema/openai/extension.go @@ -200,6 +200,12 @@ func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMet if ext.ServiceTier != "" { ret.ServiceTier = ext.ServiceTier } + if ext.CreatedAt != 0 { + ret.CreatedAt = ext.CreatedAt + } + if ext.PromptCacheRetention != "" { + ret.PromptCacheRetention = ext.PromptCacheRetention + } } return ret, nil From 20993e340bb5d01bc798e1dc20b04e81f7f11a10 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 12 Jan 2026 18:04:30 +0800 Subject: [PATCH 21/28] feat: improve comment (#679) --- components/agentic/interface.go | 3 + schema/agentic_message.go | 164 +++++++++++++++++++++++++------- schema/tool.go | 1 + 3 files changed, 131 insertions(+), 37 
deletions(-) diff --git a/components/agentic/interface.go b/components/agentic/interface.go index e9960d332..e62a8eeab 100644 --- a/components/agentic/interface.go +++ b/components/agentic/interface.go @@ -25,5 +25,8 @@ import ( type Model interface { Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + + // WithTools returns a new Model instance with the specified tools bound. + // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (Model, error) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 5008f0c75..63c78c3eb 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -66,122 +66,208 @@ const ( ) type AgenticMessage struct { + // ResponseMeta is the response metadata. ResponseMeta *AgenticResponseMeta - Role AgenticRoleType + // Role is the message role. + Role AgenticRoleType + + // ContentBlocks is the list of content blocks. ContentBlocks []*ContentBlock + // Extra is the additional information. Extra map[string]any } type AgenticResponseMeta struct { + // TokenUsage is the token usage. TokenUsage *TokenUsage + // OpenAIExtension is the extension for OpenAI. OpenAIExtension *openai.ResponseMetaExtension + + // GeminiExtension is the extension for Gemini. GeminiExtension *gemini.ResponseMetaExtension + + // ClaudeExtension is the extension for Claude. ClaudeExtension *claude.ResponseMetaExtension - Extension any -} -type StreamingMeta struct { - // Index specifies the index position of this block in the final response. - Index int + // Extension is the extension for other models, supplied by the component implementer. + Extension any } type ContentBlock struct { Type ContentBlockType + // Reasoning contains the reasoning content generated by the model. 
Reasoning *Reasoning - UserInputText *UserInputText + // UserInputText contains the text content provided by the user. + UserInputText *UserInputText + + // UserInputImage contains the image content provided by the user. UserInputImage *UserInputImage + + // UserInputAudio contains the audio content provided by the user. UserInputAudio *UserInputAudio + + // UserInputVideo contains the video content provided by the user. UserInputVideo *UserInputVideo - UserInputFile *UserInputFile - AssistantGenText *AssistantGenText + // UserInputFile contains the file content provided by the user. + UserInputFile *UserInputFile + + // AssistantGenText contains the text content generated by the model. + AssistantGenText *AssistantGenText + + // AssistantGenImage contains the image content generated by the model. AssistantGenImage *AssistantGenImage + + // AssistantGenAudio contains the audio content generated by the model. AssistantGenAudio *AssistantGenAudio + + // AssistantGenVideo contains the video content generated by the model. AssistantGenVideo *AssistantGenVideo - // FunctionToolCall holds invocation details for a user-defined tool. + // FunctionToolCall contains the invocation details for a user-defined tool. FunctionToolCall *FunctionToolCall - // FunctionToolResult is the result from a user-defined tool call. + + // FunctionToolResult contains the result returned from a user-defined tool call. FunctionToolResult *FunctionToolResult - // ServerToolCall holds invocation details for a provider built-in tool run on the model server. + + // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. ServerToolCall *ServerToolCall - // ServerToolResult is the result from a provider built-in tool run on the model server. + + // ServerToolResult contains the result returned from a provider built-in tool executed on the model server. 
ServerToolResult *ServerToolResult - // MCPToolCall holds invocation details for an MCP tool managed by the model server. + // MCPToolCall contains the invocation details for an MCP tool managed by the model server. MCPToolCall *MCPToolCall - // MCPToolResult is the result from an MCP tool managed by the model server. + + // MCPToolResult contains the result returned from an MCP tool managed by the model server. MCPToolResult *MCPToolResult - // MCPListToolsResult lists available MCP tools reported by the model server. + + // MCPListToolsResult contains the list of available MCP tools reported by the model server. MCPListToolsResult *MCPListToolsResult - // MCPToolApprovalRequest requests user approval for an MCP tool call when required. + + // MCPToolApprovalRequest contains the user approval request for an MCP tool call when required. MCPToolApprovalRequest *MCPToolApprovalRequest - // MCPToolApprovalResponse records the user's approval decision for an MCP tool call. + + // MCPToolApprovalResponse contains the user's approval decision for an MCP tool call. MCPToolApprovalResponse *MCPToolApprovalResponse + // StreamingMeta contains metadata for streaming responses. StreamingMeta *StreamingMeta - Extra map[string]any + + // Extra contains additional information for the content block. + Extra map[string]any +} + +type StreamingMeta struct { + // Index specifies the index position of this block in the final response. + Index int } type UserInputText struct { + // Text is the text content. Text string } type UserInputImage struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string - Detail ImageURLDetail + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string + + // Detail is the quality of the image url. + Detail ImageURLDetail } type UserInputAudio struct { - URL string + // URL is the HTTP/HTTPS link. 
+ URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string } type UserInputVideo struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string } type UserInputFile struct { - URL string - Name string + // URL is the HTTP/HTTPS link. + URL string + + // Name is the filename. + Name string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "application/pdf". + MIMEType string } type AssistantGenText struct { + // Text is the generated text. Text string + // OpenAIExtension is the extension for OpenAI. OpenAIExtension *openai.AssistantGenTextExtension + + // ClaudeExtension is the extension for Claude. ClaudeExtension *claude.AssistantGenTextExtension - Extension any + + // Extension is the extension for other models, supplied by the component implementer. + Extension any } type AssistantGenImage struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string } type AssistantGenAudio struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string } type AssistantGenVideo struct { - URL string + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. Base64Data string - MIMEType string + + // MIMEType is the mime type, e.g. "video/mp4". 
+ MIMEType string } type Reasoning struct { @@ -196,6 +282,7 @@ type ReasoningSummary struct { // Index specifies the index position of this summary in the final Reasoning. Index int + // Text is the reasoning content summary. Text string } @@ -284,7 +371,10 @@ type MCPToolResult struct { } type MCPToolCallError struct { - Code *int64 + // Code is the error code. + Code *int64 + + // Message is the error message. Message string } diff --git a/schema/tool.go b/schema/tool.go index 2d6bf90db..efed6d34b 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -67,6 +67,7 @@ type AllowedTool struct { ServerTool *AllowedServerTool } + type AllowedMCPTool struct { // ServerLabel is the label of the MCP server. ServerLabel string From e3dc169340d564a8acd38fc05dd9e64b8a56e8ef Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 13 Jan 2026 21:41:07 +0800 Subject: [PATCH 22/28] feat: add agentic callbacks template (#681) --- components/agentic/callback_extra.go | 85 ----- components/agentic/callback_extra_test.go | 35 -- components/agentic/interface.go | 32 -- components/agentic/option.go | 142 -------- components/agentic/option_test.go | 79 ----- components/model/callback_extra.go | 18 +- components/model/interface.go | 12 + components/model/option.go | 31 +- components/model/option_test.go | 16 + ...te_agentic.go => agentic_chat_template.go} | 16 +- .../prompt/agentic_chat_template_test.go | 124 +++++++ components/prompt/callback_extra.go | 50 +-- components/prompt/callback_extra_test.go | 21 +- .../prompt/chat_template_agentic_test.go | 111 ------- components/prompt/interface.go | 1 + ..._node_agentic.go => agentic_tools_node.go} | 0 ...tic_test.go => agentic_tools_node_test.go} | 0 compose/chain.go | 3 +- compose/chain_branch.go | 3 +- compose/chain_parallel.go | 3 +- compose/component_to_graph_node.go | 3 +- compose/graph.go | 3 +- schema/agentic_message.go | 8 +- utils/callbacks/template.go | 176 +++++++++- utils/callbacks/template_test.go | 304 +++++++++++++++++- 25 files 
changed, 694 insertions(+), 582 deletions(-) delete mode 100644 components/agentic/callback_extra.go delete mode 100644 components/agentic/callback_extra_test.go delete mode 100644 components/agentic/interface.go delete mode 100644 components/agentic/option.go delete mode 100644 components/agentic/option_test.go rename components/prompt/{chat_template_agentic.go => agentic_chat_template.go} (87%) create mode 100644 components/prompt/agentic_chat_template_test.go delete mode 100644 components/prompt/chat_template_agentic_test.go rename compose/{tools_node_agentic.go => agentic_tools_node.go} (100%) rename compose/{tools_node_agentic_test.go => agentic_tools_node_test.go} (100%) diff --git a/components/agentic/callback_extra.go b/components/agentic/callback_extra.go deleted file mode 100644 index 2c5a656fa..000000000 --- a/components/agentic/callback_extra.go +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Package agentic defines callback payloads and configuration types for agentic models. -package agentic - -import ( - "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/schema" -) - -// Config is the config for the model. -type Config struct { - // Model is the model name. - Model string - // Temperature is the temperature, which controls the randomness of the model. - Temperature float64 - // TopP is the top p, which controls the diversity of the model. 
- TopP float64 -} - -// CallbackInput is the input for the model callback. -type CallbackInput struct { - // Messages is the messages to be sent to the model. - Messages []*schema.AgenticMessage - // Tools is the tools to be used in the model. - Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. - ToolChoice *schema.ToolChoice - // Config is the config for the model. - Config *Config - // Extra is the extra information for the callback. - Extra map[string]any -} - -// CallbackOutput is the output for the model callback. -type CallbackOutput struct { - // Message is the message generated by the model. - Message *schema.AgenticMessage - // Config is the config for the model. - Config *Config - // Extra is the extra information for the callback. - Extra map[string]any -} - -// ConvCallbackInput converts the callback input to the model callback input. -func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { - switch t := src.(type) { - case *CallbackInput: // when callback is triggered within component implementation, the input is usually already a typed *model.CallbackInput - return t - case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Chat Model interface, which is []*schema.AgenticMessage - return &CallbackInput{ - Messages: t, - } - default: - return nil - } -} - -// ConvCallbackOutput converts the callback output to the model callback output. 
-func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { - switch t := src.(type) { - case *CallbackOutput: // when callback is triggered within component implementation, the output is usually already a typed *model.CallbackOutput - return t - case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Chat Model interface, which is *schema.AgenticMessage - return &CallbackOutput{ - Message: t, - } - default: - return nil - } -} diff --git a/components/agentic/callback_extra_test.go b/components/agentic/callback_extra_test.go deleted file mode 100644 index a77da6cd2..000000000 --- a/components/agentic/callback_extra_test.go +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package agentic - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestConvModel(t *testing.T) { - assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) - assert.NotNil(t, ConvCallbackInput([]*schema.AgenticMessage{})) - assert.Nil(t, ConvCallbackInput("asd")) - - assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) - assert.NotNil(t, ConvCallbackOutput(&schema.AgenticMessage{})) - assert.Nil(t, ConvCallbackOutput("asd")) -} diff --git a/components/agentic/interface.go b/components/agentic/interface.go deleted file mode 100644 index e62a8eeab..000000000 --- a/components/agentic/interface.go +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "context" - - "github.com/cloudwego/eino/schema" -) - -type Model interface { - Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) - Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) - - // WithTools returns a new Model instance with the specified tools bound. - // This method does not modify the current instance, making it safer for concurrent use. 
- WithTools(tools []*schema.ToolInfo) (Model, error) -} diff --git a/components/agentic/option.go b/components/agentic/option.go deleted file mode 100644 index d8873442a..000000000 --- a/components/agentic/option.go +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package agentic - -import ( - "github.com/cloudwego/eino/schema" -) - -// Options is the common options for the model. -type Options struct { - // Temperature is the temperature for the model, which controls the randomness of the model. - Temperature *float64 - // Model is the model name. - Model *string - // TopP is the top p for the model, which controls the diversity of the model. - TopP *float64 - // Tools is a list of tools the model may call. - Tools []*schema.ToolInfo - // ToolChoice controls how the model call the tools. - ToolChoice *schema.ToolChoice - // AllowedTools is a list of allowed tools the model may call. - AllowedTools []*schema.AllowedTool -} - -// Option is the call option for ChatModel component. -type Option struct { - apply func(opts *Options) - - implSpecificOptFn any -} - -// WithTemperature is the option to set the temperature for the model. -func WithTemperature(temperature float64) Option { - return Option{ - apply: func(opts *Options) { - opts.Temperature = &temperature - }, - } -} - -// WithModel is the option to set the model name. 
-func WithModel(name string) Option { - return Option{ - apply: func(opts *Options) { - opts.Model = &name - }, - } -} - -// WithTopP is the option to set the top p for the model. -func WithTopP(topP float64) Option { - return Option{ - apply: func(opts *Options) { - opts.TopP = &topP - }, - } -} - -// WithTools is the option to set tools for the model. -func WithTools(tools []*schema.ToolInfo) Option { - if tools == nil { - tools = []*schema.ToolInfo{} - } - return Option{ - apply: func(opts *Options) { - opts.Tools = tools - }, - } -} - -// WithToolChoice is the option to set tool choice for the model. -func WithToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { - return Option{ - apply: func(opts *Options) { - opts.ToolChoice = &toolChoice - opts.AllowedTools = allowedTools - }, - } -} - -// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. -func WrapImplSpecificOptFn[T any](optFn func(*T)) Option { - return Option{ - implSpecificOptFn: optFn, - } -} - -// GetCommonOptions extract model Options from Option list, optionally providing a base Options with default values. -func GetCommonOptions(base *Options, opts ...Option) *Options { - if base == nil { - base = &Options{} - } - - for i := range opts { - opt := opts[i] - if opt.apply != nil { - opt.apply(base) - } - } - - return base -} - -// GetImplSpecificOptions extract the implementation specific options from Option list, optionally providing a base options with default values. -// e.g. -// -// myOption := &MyOption{ -// Field1: "default_value", -// } -// -// myOption := model.GetImplSpecificOptions(myOption, opts...) 
-func GetImplSpecificOptions[T any](base *T, opts ...Option) *T { - if base == nil { - base = new(T) - } - - for i := range opts { - opt := opts[i] - if opt.implSpecificOptFn != nil { - optFn, ok := opt.implSpecificOptFn.(func(*T)) - if ok { - optFn(base) - } - } - } - - return base -} diff --git a/components/agentic/option_test.go b/components/agentic/option_test.go deleted file mode 100644 index 2c5bac652..000000000 --- a/components/agentic/option_test.go +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package agentic - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestCommon(t *testing.T) { - o := GetCommonOptions(nil, - WithTools([]*schema.ToolInfo{{Name: "test"}}), - WithModel("test"), - WithTemperature(0.1), - WithToolChoice(schema.ToolChoiceAllowed, []*schema.AllowedTool{{FunctionToolName: "test"}}...), - WithTopP(0.1), - ) - assert.Len(t, o.Tools, 1) - assert.Equal(t, "test", o.Tools[0].Name) - assert.Equal(t, "test", *o.Model) - assert.Equal(t, float64(0.1), *o.Temperature) - assert.Equal(t, schema.ToolChoiceAllowed, *o.ToolChoice) - assert.Equal(t, float64(0.1), *o.TopP) -} - -func TestImplSpecificOpts(t *testing.T) { - type implSpecificOptions struct { - conf string - index int - } - - withConf := func(conf string) func(o *implSpecificOptions) { - return func(o *implSpecificOptions) { - o.conf = conf - } - } - - withIndex := func(index int) func(o *implSpecificOptions) { - return func(o *implSpecificOptions) { - o.index = index - } - } - - documentOption1 := WrapImplSpecificOptFn(withConf("test_conf")) - documentOption2 := WrapImplSpecificOptFn(withIndex(1)) - - implSpecificOpts := GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) - - assert.Equal(t, &implSpecificOptions{ - conf: "test_conf", - index: 1, - }, implSpecificOpts) - documentOption1 = WrapImplSpecificOptFn(withConf("test_conf")) - documentOption2 = WrapImplSpecificOptFn(withIndex(1)) - - implSpecificOpts = GetImplSpecificOptions(&implSpecificOptions{}, documentOption1, documentOption2) - - assert.Equal(t, &implSpecificOptions{ - conf: "test_conf", - index: 1, - }, implSpecificOpts) -} diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index afff3f0a7..ed9096d5c 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -31,15 +31,15 @@ type TokenUsage struct { CompletionTokens int // TotalTokens is the total number of 
tokens. TotalTokens int - // CompletionTokensDetails is a breakdown of the completion tokens. - CompletionTokensDetails CompletionTokensDetails + // CompletionTokensDetails is a breakdown of completion tokens. + CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"` } type CompletionTokensDetails struct { // ReasoningTokens tokens generated by the model for reasoning. // This is currently supported by OpenAI, Gemini, ARK and Qwen chat models. // For other models, this field will be 0. - ReasoningTokens int + ReasoningTokens int `json:"reasoning_tokens,omitempty"` } // PromptTokenDetails provides a breakdown of prompt token usage. @@ -66,6 +66,8 @@ type Config struct { type CallbackInput struct { // Messages is the messages to be sent to the model. Messages []*schema.Message + // AgenticMessages is the agentic messages to be sent to the agentic model. + AgenticMessages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // ToolChoice is the tool choice, which controls the tool to be used in the model. @@ -80,6 +82,8 @@ type CallbackInput struct { type CallbackOutput struct { // Message is the message generated by the model. Message *schema.Message + // AgenticMessage is the agentic message generated by the agentic model. + AgenticMessage *schema.AgenticMessage // Config is the config for the model. Config *Config // TokenUsage is the token usage of this request. 
@@ -97,6 +101,10 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return &CallbackInput{ Messages: t, } + case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Agentic Model interface, which is []*schema.AgenticMessage + return &CallbackInput{ + AgenticMessages: t, + } default: return nil } @@ -111,6 +119,10 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Message: t, } + case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Agentic Model interface, which is *schema.AgenticMessage + return &CallbackOutput{ + AgenticMessage: t, + } default: return nil } diff --git a/components/model/interface.go b/components/model/interface.go index deb7b56dd..cf79785bc 100644 --- a/components/model/interface.go +++ b/components/model/interface.go @@ -89,3 +89,15 @@ type ToolCallingChatModel interface { // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (ToolCallingChatModel, error) } + +// AgenticModel defines the interface for agentic models that support AgenticMessage. +// It provides methods for generating complete and streaming outputs, and supports +// tool calling via the WithTools method. +type AgenticModel interface { + Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) + Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + + // WithTools returns a new Model instance with the specified tools bound. + // This method does not modify the current instance, making it safer for concurrent use. 
+ WithTools(tools []*schema.ToolInfo) (AgenticModel, error) +} diff --git a/components/model/option.go b/components/model/option.go index 9fd96116c..0173d22aa 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -22,21 +22,29 @@ import "github.com/cloudwego/eino/schema" type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. Temperature *float32 - // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". - MaxTokens *int // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. TopP *float32 - // Stop is the stop words for the model, which controls the stopping condition of the model. - Stop []string // Tools is a list of tools the model may call. Tools []*schema.ToolInfo // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice + + // Options only for chat model. + + // MaxTokens is the max number of tokens; once it is reached, the model stops generating and usually returns a finish reason of "length". + MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string + + // Options only for agentic model. + + // AllowedTools is a list of allowed tools the model may call. + AllowedTools []*schema.AllowedTool } // Option is a call-time option for a ChatModel. Options are immutable and @@ -59,6 +67,7 @@ func WithTemperature(temperature float32) Option { } // WithMaxTokens is the option to set the max tokens for the model. +// Only available for ChatModel. 
func WithMaxTokens(maxTokens int) Option { return Option{ apply: func(opts *Options) { @@ -86,6 +95,7 @@ func WithTopP(topP float32) Option { } // WithStop is the option to set the stop words for the model. +// Only available for ChatModel. func WithStop(stop []string) Option { return Option{ apply: func(opts *Options) { @@ -108,6 +118,7 @@ func WithTools(tools []*schema.ToolInfo) Option { // WithToolChoice sets the tool choice for the model. It also allows for providing a list of // tool names to constrain the model to a specific subset of the available tools. +// Only available for ChatModel. func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Option { return Option{ apply: func(opts *Options) { @@ -117,6 +128,18 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op } } +// WithAgenticToolChoice is the option to set tool choice for the agentic model. +// Only available for AgenticModel. +func WithAgenticToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { + return Option{ + apply: func(opts *Options) { + opts.ToolChoice = &toolChoice + opts.AllowedTools = allowedTools + }, + } +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. // WrapImplSpecificOptFn wraps an implementation-specific option function into // an [Option] so it can be passed alongside standard options. 
// diff --git a/components/model/option_test.go b/components/model/option_test.go index 36872c30e..bfacdd17c 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -82,6 +82,22 @@ func TestOptions(t *testing.T) { convey.So(opts.Tools, convey.ShouldNotBeNil) convey.So(len(opts.Tools), convey.ShouldEqual, 0) }) + + convey.Convey("test agentic tool choice option", t, func() { + var ( + toolChoice = schema.ToolChoiceForced + allowedTools = []*schema.AllowedTool{ + {FunctionToolName: "agentic_tool"}, + } + ) + opts := GetCommonOptions( + nil, + WithAgenticToolChoice(toolChoice, allowedTools...), + ) + + convey.So(opts.ToolChoice, convey.ShouldResemble, &toolChoice) + convey.So(opts.AllowedTools, convey.ShouldResemble, allowedTools) + }) } type implOption struct { diff --git a/components/prompt/chat_template_agentic.go b/components/prompt/agentic_chat_template.go similarity index 87% rename from components/prompt/chat_template_agentic.go rename to components/prompt/agentic_chat_template.go index 937d46f26..512a60ecd 100644 --- a/components/prompt/chat_template_agentic.go +++ b/components/prompt/agentic_chat_template.go @@ -1,5 +1,5 @@ /* - * Copyright 2025 CloudWeGo Authors + * Copyright 2026 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -45,9 +45,9 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) - ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ - Variables: vs, - Templates: t.templates, + ctx = callbacks.OnStart(ctx, &CallbackInput{ + Variables: vs, + AgenticTemplates: t.templates, }) defer func() { if err != nil { @@ -65,15 +65,15 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]a result = append(result, msgs...) } - _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ - Result: result, - Templates: t.templates, + _ = callbacks.OnEnd(ctx, &CallbackOutput{ + AgenticResult: result, + AgenticTemplates: t.templates, }) return result, nil } -// GetType returns the type of the chat template (Default). +// GetType returns the type of the agentic chat template (Default). func (t *DefaultAgenticChatTemplate) GetType() string { return "Default" } diff --git a/components/prompt/agentic_chat_template_test.go b/components/prompt/agentic_chat_template_test.go new file mode 100644 index 000000000..42d7a8630 --- /dev/null +++ b/components/prompt/agentic_chat_template_test.go @@ -0,0 +1,124 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package prompt + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" +) + +type mockAgenticTemplate struct { + err error +} + +func (m *mockAgenticTemplate) Format(ctx context.Context, vs map[string]any, formatType schema.FormatType) ([]*schema.AgenticMessage, error) { + if m.err != nil { + return nil, m.err + } + return []*schema.AgenticMessage{schema.UserAgenticMessage("mocked")}, nil +} + +func TestFromAgenticMessages(t *testing.T) { + t.Run("create template", func(t *testing.T) { + tpl := schema.UserAgenticMessage("hello") + ft := schema.FString + at := FromAgenticMessages(ft, tpl) + + assert.NotNil(t, at) + assert.Equal(t, ft, at.formatType) + assert.Len(t, at.templates, 1) + assert.Same(t, tpl, at.templates[0]) + }) +} + +func TestDefaultAgenticTemplate_GetType(t *testing.T) { + t.Run("get type", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.Equal(t, "Default", at.GetType()) + }) +} + +func TestDefaultAgenticTemplate_IsCallbacksEnabled(t *testing.T) { + t.Run("callbacks enabled", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.True(t, at.IsCallbacksEnabled()) + }) +} + +func TestDefaultAgenticTemplate_Format(t *testing.T) { + t.Run("success", func(t *testing.T) { + // Mock callback handler + cb := callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Fail(t, "unexpected error callback") + return ctx + }). 
+ Build() + + tpl := schema.UserAgenticMessage("hello {val}") + at := FromAgenticMessages(schema.FString, tpl) + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{"val": "world"}) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, "hello world", res[0].ContentBlocks[0].UserInputText.Text) + }) + + t.Run("template format error", func(t *testing.T) { + mockErr := errors.New("mock error") + mockTpl := &mockAgenticTemplate{err: mockErr} + at := FromAgenticMessages(schema.FString, mockTpl) + + // Mock callback handler to verify OnError + cb := callbacks.NewHandlerBuilder(). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Equal(t, mockErr, err) + return ctx + }). + Build() + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{}) + assert.Error(t, err) + assert.Nil(t, res) + assert.Equal(t, mockErr, err) + }) +} diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 3be780543..4c27f37c6 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -21,52 +21,14 @@ import ( "github.com/cloudwego/eino/schema" ) -type AgenticCallbackInput struct { - Variables map[string]any - Templates []schema.AgenticMessagesTemplate - Extra map[string]any -} - -type AgenticCallbackOutput struct { - Result []*schema.AgenticMessage - Templates []schema.AgenticMessagesTemplate - Extra map[string]any -} - -// ConvAgenticCallbackInput converts the callback input to the agentic callback input. 
-func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { - switch t := src.(type) { - case *AgenticCallbackInput: - return t - case map[string]any: - return &AgenticCallbackInput{ - Variables: t, - } - default: - return nil - } -} - -// ConvAgenticCallbackOutput converts the callback output to the agentic callback output. -func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { - switch t := src.(type) { - case *AgenticCallbackOutput: - return t - case []*schema.AgenticMessage: - return &AgenticCallbackOutput{ - Result: t, - } - default: - return nil - } -} - // CallbackInput is the input for the callback. type CallbackInput struct { // Variables is the variables for the callback. Variables map[string]any // Templates is the templates for the callback. Templates []schema.MessagesTemplate + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -75,8 +37,12 @@ type CallbackInput struct { type CallbackOutput struct { // Result is the result for the callback. Result []*schema.Message + // AgenticResult is the agentic result for the callback. + AgenticResult []*schema.AgenticMessage // Templates is the templates for the callback. Templates []schema.MessagesTemplate + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. 
Extra map[string]any } @@ -104,6 +70,10 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Result: t, } + case []*schema.AgenticMessage: + return &CallbackOutput{ + AgenticResult: t, + } default: return nil } diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go index 456297e29..4b48ec114 100644 --- a/components/prompt/callback_extra_test.go +++ b/components/prompt/callback_extra_test.go @@ -25,11 +25,28 @@ import ( ) func TestConvPrompt(t *testing.T) { - assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput(&CallbackInput{ + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) - assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{ + AgenticResult: []*schema.AgenticMessage{ + {}, + }, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) + + agenticResult := []*schema.AgenticMessage{{}} + out := ConvCallbackOutput(agenticResult) + assert.NotNil(t, out) + assert.Equal(t, agenticResult, out.AgenticResult) + assert.Nil(t, ConvCallbackOutput("asd")) } diff --git a/components/prompt/chat_template_agentic_test.go b/components/prompt/chat_template_agentic_test.go deleted file mode 100644 index aaa7d6405..000000000 --- a/components/prompt/chat_template_agentic_test.go +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright 2025 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package prompt - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/cloudwego/eino/schema" -) - -func TestAgenticFormat(t *testing.T) { - pyFmtTestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{context}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - jinja2TestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{context}}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - goFmtTestTemplate := []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{ - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "{{.context}}"}}, - }, - }, - schema.AgenticMessagesPlaceholder("chat_history", true), - } - testValues := map[string]any{ - "context": "it's beautiful day", - "chat_history": []*schema.AgenticMessage{ - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: 
[]*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, - }, - }, - }, - } - expected := []*schema.AgenticMessage{ - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "it's beautiful day"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "1"}}, - }, - }, - { - Role: schema.AgenticRoleTypeUser, - ContentBlocks: []*schema.ContentBlock{ - {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "2"}}, - }, - }, - } - - // FString - chatTemplate := FromAgenticMessages(schema.FString, pyFmtTestTemplate...) - msgs, err := chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) - - // Jinja2 - chatTemplate = FromAgenticMessages(schema.Jinja2, jinja2TestTemplate...) - msgs, err = chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) - - // GoTemplate - chatTemplate = FromAgenticMessages(schema.GoTemplate, goFmtTestTemplate...) - msgs, err = chatTemplate.Format(context.Background(), testValues) - assert.Nil(t, err) - assert.Equal(t, expected, msgs) -} diff --git a/components/prompt/interface.go b/components/prompt/interface.go index 7ffe7216a..2d5a2cbed 100644 --- a/components/prompt/interface.go +++ b/components/prompt/interface.go @@ -44,6 +44,7 @@ type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } +// AgenticChatTemplate formats variables into a list of agentic messages according to a prompt schema. 
type AgenticChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.AgenticMessage, error) } diff --git a/compose/tools_node_agentic.go b/compose/agentic_tools_node.go similarity index 100% rename from compose/tools_node_agentic.go rename to compose/agentic_tools_node.go diff --git a/compose/tools_node_agentic_test.go b/compose/agentic_tools_node_test.go similarity index 100% rename from compose/tools_node_agentic_test.go rename to compose/agentic_tools_node_test.go diff --git a/compose/chain.go b/compose/chain.go index 8484e8767..abfa6bf1d 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -22,7 +22,6 @@ import ( "fmt" "reflect" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -181,7 +180,7 @@ func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAdd // model, err := openai.NewAgenticModel(ctx, config) // if err != nil {...} // chain.AppendAgenticModel(model) -func (c *Chain[I, O]) AppendAgenticModel(node agentic.Model, opts ...GraphAddNodeOpt) *Chain[I, O] { +func (c *Chain[I, O]) AppendAgenticModel(node model.AgenticModel, opts ...GraphAddNodeOpt) *Chain[I, O] { gNode, options := toAgenticModelNode(node, opts...) c.addNode(gNode, options) return c diff --git a/compose/chain_branch.go b/compose/chain_branch.go index 004dbfac3..84fb11048 100644 --- a/compose/chain_branch.go +++ b/compose/chain_branch.go @@ -20,7 +20,6 @@ import ( "context" "fmt" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -158,7 +157,7 @@ func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts . 
// }) // cb.AddAgenticModel("agentic_model_key_1", model1) // cb.AddAgenticModel("agentic_model_key_2", model2) -func (cb *ChainBranch) AddAgenticModel(key string, node agentic.Model, opts ...GraphAddNodeOpt) *ChainBranch { +func (cb *ChainBranch) AddAgenticModel(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) *ChainBranch { gNode, options := toAgenticModelNode(node, opts...) return cb.addNode(key, gNode, options) } diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go index 128ed4a26..463140be2 100644 --- a/compose/chain_parallel.go +++ b/compose/chain_parallel.go @@ -19,7 +19,6 @@ package compose import ( "fmt" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -84,7 +83,7 @@ func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts // // p.AddAgenticModel("output_key1", model1) // p.AddAgenticModel("output_key2", model2) -func (p *Parallel) AddAgenticModel(outputKey string, node agentic.Model, opts ...GraphAddNodeOpt) *Parallel { +func (p *Parallel) AddAgenticModel(outputKey string, node model.AgenticModel, opts ...GraphAddNodeOpt) *Parallel { gNode, options := toAgenticModelNode(node, append(opts, WithOutputKey(outputKey))...) return p.addNode(outputKey, gNode, options) } diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go index e64ce4f19..4bd27fe34 100644 --- a/compose/component_to_graph_node.go +++ b/compose/component_to_graph_node.go @@ -18,7 +18,6 @@ package compose import ( "github.com/cloudwego/eino/components" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -102,7 +101,7 @@ func toChatModelNode(node model.BaseChatModel, opts ...GraphAddNodeOpt) (*graphN opts...) 
} -func toAgenticModelNode(node agentic.Model, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { +func toAgenticModelNode(node model.AgenticModel, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, components.ComponentOfAgenticModel, diff --git a/compose/graph.go b/compose/graph.go index 877b8fb42..bcf5ae423 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -23,7 +23,6 @@ import ( "reflect" "strings" - "github.com/cloudwego/eino/components/agentic" "github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/indexer" @@ -361,7 +360,7 @@ func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...G // }) // // graph.AddAgenticModelNode("agentic_model_node_key", model) -func (g *graph) AddAgenticModelNode(key string, node agentic.Model, opts ...GraphAddNodeOpt) error { +func (g *graph) AddAgenticModelNode(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) error { gNode, options := toAgenticModelNode(node, opts...) return g.addNode(key, gNode, options) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 63c78c3eb..a4554d38e 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -66,15 +66,15 @@ const ( ) type AgenticMessage struct { - // ResponseMeta is the response metadata. - ResponseMeta *AgenticResponseMeta - // Role is the message role. Role AgenticRoleType // ContentBlocks is the list of content blocks. ContentBlocks []*ContentBlock + // ResponseMeta is the response metadata. + ResponseMeta *AgenticResponseMeta + // Extra is the additional information. Extra map[string]any } @@ -541,7 +541,7 @@ func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta return block } -// AgenticMessagesTemplate is the interface for messages template. +// AgenticMessagesTemplate is the interface for agentic messages template. 
// It's used to render a template to a list of agentic messages. // e.g. // diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index e04bddd63..4c73e6bbc 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -55,17 +55,20 @@ func NewHandlerHelper() *HandlerHelper { // // then use the handler with runnable.Invoke(ctx, input, compose.WithCallbacks(handler)) type HandlerHelper struct { - promptHandler *PromptCallbackHandler - chatModelHandler *ModelCallbackHandler - embeddingHandler *EmbeddingCallbackHandler - indexerHandler *IndexerCallbackHandler - retrieverHandler *RetrieverCallbackHandler - loaderHandler *LoaderCallbackHandler - transformerHandler *TransformerCallbackHandler - toolHandler *ToolCallbackHandler - toolsNodeHandler *ToolsNodeCallbackHandlers - agentHandler *AgentCallbackHandler - composeTemplates map[components.Component]callbacks.Handler + promptHandler *PromptCallbackHandler + chatModelHandler *ModelCallbackHandler + embeddingHandler *EmbeddingCallbackHandler + indexerHandler *IndexerCallbackHandler + retrieverHandler *RetrieverCallbackHandler + loaderHandler *LoaderCallbackHandler + transformerHandler *TransformerCallbackHandler + toolHandler *ToolCallbackHandler + toolsNodeHandler *ToolsNodeCallbackHandlers + agenticPromptHandler *AgenticPromptCallbackHandler + agenticModelHandler *AgenticModelCallbackHandler + agenticToolsNodeHandler *AgenticToolsNodeCallbackHandlers + agentHandler *AgentCallbackHandler + composeTemplates map[components.Component]callbacks.Handler } // Handler returns the callbacks.Handler created by HandlerHelper. @@ -127,6 +130,24 @@ func (c *HandlerHelper) ToolsNode(handler *ToolsNodeCallbackHandlers) *HandlerHe return c } +// AgenticPrompt sets the agentic prompt handler for the handler helper, which will be called when the agentic prompt component is executed. 
+func (c *HandlerHelper) AgenticPrompt(handler *AgenticPromptCallbackHandler) *HandlerHelper { + c.agenticPromptHandler = handler + return c +} + +// AgenticModel sets the agentic chat model handler for the handler helper, which will be called when the agentic chat model component is executed. +func (c *HandlerHelper) AgenticModel(handler *AgenticModelCallbackHandler) *HandlerHelper { + c.agenticModelHandler = handler + return c +} + +// AgenticToolsNode sets the agentic tools node handler for the handler helper, which will be called when the agentic tools node is executed. +func (c *HandlerHelper) AgenticToolsNode(handler *AgenticToolsNodeCallbackHandlers) *HandlerHelper { + c.agenticToolsNodeHandler = handler + return c +} + // Agent sets the agent handler for the handler helper, which will be called when the agent is executed. func (c *HandlerHelper) Agent(handler *AgentCallbackHandler) *HandlerHelper { c.agentHandler = handler @@ -161,8 +182,12 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: @@ -177,6 +202,8 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnStart(ctx, info, tool.ConvCallbackInput(input)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnStart(ctx, info, convToolsNodeCallbackInput(input)) + case 
compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnStart(ctx, info, convAgenticToolsNodeCallbackInput(input)) case adk.ComponentOfAgent: return c.agentHandler.OnStart(ctx, info, adk.ConvAgentCallbackInput(input)) case compose.ComponentOfGraph, @@ -194,8 +221,12 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: @@ -210,6 +241,8 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou return c.toolHandler.OnEnd(ctx, info, tool.ConvCallbackOutput(output)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnEnd(ctx, info, convToolsNodeCallbackOutput(output)) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEnd(ctx, info, convAgenticToolsNodeCallbackOutput(output)) case adk.ComponentOfAgent: return c.agentHandler.OnEnd(ctx, info, adk.ConvAgentCallbackOutput(output)) case compose.ComponentOfGraph, @@ -227,8 +260,12 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnError(ctx, info, err) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnError(ctx, info, err) case components.ComponentOfChatModel: return c.chatModelHandler.OnError(ctx, info, err) + case 
components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnError(ctx, info, err) case components.ComponentOfEmbedding: return c.embeddingHandler.OnError(ctx, info, err) case components.ComponentOfIndexer: @@ -243,6 +280,8 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnError(ctx, info, err) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnError(ctx, info, err) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnError(ctx, info, err) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -275,6 +314,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { return model.ConvCallbackOutput(item), nil })) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { + return model.ConvCallbackOutput(item), nil + })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*tool.CallbackOutput, error) { @@ -285,6 +329,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.Message, error) { return convToolsNodeCallbackOutput(item), nil })) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.AgenticMessage, error) { + return convAgenticToolsNodeCallbackOutput(item), nil + })) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -295,6 +344,8 @@ 
func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb } // Needed checks if the callback handler is needed for the given timing. +// +//nolint:cyclop func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { if info == nil { return false @@ -305,6 +356,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.chatModelHandler != nil && c.chatModelHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticModel: + if c.agenticModelHandler != nil && c.agenticModelHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.Needed(ctx, info, timing) { return true @@ -321,6 +376,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.promptHandler != nil && c.promptHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticPrompt: + if c.agenticPromptHandler != nil && c.agenticPromptHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.Needed(ctx, info, timing) { return true @@ -337,6 +396,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.toolsNodeHandler != nil && c.toolsNodeHandler.Needed(ctx, info, timing) { return true } + case compose.ComponentOfAgenticToolsNode: + if c.agenticToolsNodeHandler != nil && c.agenticToolsNodeHandler.Needed(ctx, info, timing) { + return true + } case adk.ComponentOfAgent: if c.agentHandler != nil && c.agentHandler.Needed(ctx, info, timing) { return true @@ -596,3 +659,94 @@ func (ch *AgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunI return false } } + +// AgenticPromptCallbackHandler is the handler for the agentic prompt callback. 
+type AgenticPromptCallbackHandler struct { + // OnStart is the callback function for the start of the agentic prompt. + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context + // OnEnd is the callback function for the end of the agentic prompt. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context + // OnError is the callback function for the error of the agentic prompt. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *AgenticPromptCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +// AgenticModelCallbackHandler is the handler for the agentic chat model callback. +type AgenticModelCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. 
+func (ch *AgenticModelCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + default: + return false + } +} + +// AgenticToolsNodeCallbackHandlers defines optional callbacks for the Agentic Tools node +// lifecycle events. +type AgenticToolsNodeCallbackHandlers struct { + OnStart func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context + OnEnd func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context + OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context + OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context +} + +// Needed reports whether a handler is registered for the given timing. 
+func (ch *AgenticToolsNodeCallbackHandlers) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +func convAgenticToolsNodeCallbackInput(src callbacks.CallbackInput) *schema.AgenticMessage { + switch t := src.(type) { + case *schema.AgenticMessage: + return t + default: + return nil + } +} + +func convAgenticToolsNodeCallbackOutput(src callbacks.CallbackInput) []*schema.AgenticMessage { + switch t := src.(type) { + case []*schema.AgenticMessage: + return t + default: + return nil + } +} diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index 84ed6dfc6..f599e5300 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -142,6 +142,58 @@ func TestNewComponentTemplate(t *testing.T) { cnt++ return ctx }).Build()). + AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). 
+ AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). 
Handler() types := []components.Component{ @@ -151,6 +203,9 @@ func TestNewComponentTemplate(t *testing.T) { components.ComponentOfRetriever, components.ComponentOfTool, compose.ComponentOfLambda, + components.ComponentOfAgenticModel, + components.ComponentOfAgenticPrompt, + compose.ComponentOfAgenticToolsNode, } handler := tpl.Handler() @@ -169,28 +224,28 @@ func TestNewComponentTemplate(t *testing.T) { handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: typ}, sor) } - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = context.Background() ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}) callbacks.OnError(ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) tpl.Transformer(&TransformerCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { @@ -250,6 +305,37 @@ func TestNewComponentTemplate(t *testing.T) { } } }, + }).AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output 
*prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }).AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, }) handler = tpl.Handler() @@ -257,36 +343,222 @@ func TestNewComponentTemplate(t *testing.T) { ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 25, cnt) + assert.Equal(t, 36, cnt) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 26, cnt) + assert.Equal(t, 37, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 27, cnt) + assert.Equal(t, 38, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 28, cnt) + assert.Equal(t, 39, cnt) sr, sw := schema.Pipe[any](0) sw.Close() callbacks.OnEndWithStreamOutput[any](ctx, sr) - assert.Equal(t, 29, cnt) + assert.Equal(t, 40, cnt) sr1, sw1 := schema.Pipe[[]*schema.Message](1) sw1.Send([]*schema.Message{{}}, nil) sw1.Close() callbacks.OnEndWithStreamOutput[[]*schema.Message](ctx, sr1) - assert.Equal(t, 30, cnt) - - callbacks.OnError(ctx, nil) - assert.Equal(t, 30, cnt) + // Check AgenticModel stream + sir2, 
siw2 := schema.Pipe[callbacks.CallbackOutput](1) + siw2.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, sir2) + assert.Equal(t, 42, cnt) + + // Check AgenticToolsNode stream + sir3, siw3 := schema.Pipe[callbacks.CallbackOutput](1) + siw3.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, sir3) + assert.Equal(t, 43, cnt) ctx = callbacks.ReuseHandlers(ctx, nil) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 30, cnt) + assert.Equal(t, 43, cnt) + }) + + t.Run("EdgeCases", func(t *testing.T) { + ctx := context.Background() + cnt := 0 + + // 1. Test Graph and Chain Setters and Execution + tpl := NewHandlerHelper(). + Graph(callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + cnt++ + return ctx + }).Build()). + Chain(callbacks.NewHandlerBuilder(). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + cnt++ + return ctx + }).Build()) + + h := tpl.Handler() + + // Trigger Graph OnStart + h.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, nil) + assert.Equal(t, 1, cnt) + + // Trigger Chain OnEnd + h.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, nil) + assert.Equal(t, 2, cnt) + + // 2. Test Needed logic for Graph/Chain when handler is present/absent + // Graph is present (OnStart) + needed := h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Chain is present (OnEnd) - but we check OnStart which is not defined in the builder above? + // NewHandlerBuilder returns a handler that usually returns true for Needed if the specific func is not nil. + // Let's verify Chain OnStart is NOT needed because we only set OnEndFn. 
+ needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, callbacks.TimingOnStart) + assert.False(t, needed) // Should be false because OnStartFn wasn't set for Chain + + // Lambda is NOT present + needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfLambda}, callbacks.TimingOnStart) + assert.False(t, needed) + + // 3. Test Conversion Fallbacks (Default cases) + // We need a handler with ToolsNode and AgenticToolsNode to test their conversion fallbacks + tpl2 := NewHandlerHelper(). + ToolsNode(&ToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }) + + h2 := tpl2.Handler() + + // Pass wrong type (string) to trigger default case in convToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-input-type") + assert.Equal(t, 3, cnt) // +1 + + // Pass wrong type to trigger default case in convToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-output-type") + assert.Equal(t, 4, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-input-type") + 
assert.Equal(t, 5, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-output-type") + assert.Equal(t, 6, cnt) // +1 + + // 4. Test Needed for Agentic components when handlers are Set vs Unset + // tpl2 has AgenticToolsNode set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, callbacks.TimingOnStart) + assert.True(t, needed) + + // tpl2 does NOT have AgenticModel set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Set it now + tpl2.AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + return ctx + }, + }) + + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Check invalid component + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: "UnknownComponent"}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Check RunInfo nil + needed = h2.(callbacks.TimingChecker).Needed(ctx, nil, callbacks.TimingOnStart) + assert.False(t, needed) + + // 5. Test Needed for Transformer, Loader, Indexer, etc to ensure switch coverage + tpl3 := NewHandlerHelper(). + Transformer(&TransformerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { + return ctx + }}). + Loader(&LoaderCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context { + return ctx + }}). 
+ Indexer(&IndexerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *indexer.CallbackInput) context.Context { + return ctx + }}). + Retriever(&RetrieverCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { + return ctx + }}). + Embedding(&EmbeddingCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *embedding.CallbackInput) context.Context { + return ctx + }}). + Tool(&ToolCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackInput) context.Context { + return ctx + }}) + + h3 := tpl3.Handler() + checker := h3.(callbacks.TimingChecker) + + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // Verify False paths (by using a helper without them) + emptyH := NewHandlerHelper().Handler().(callbacks.TimingChecker) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.False(t, 
emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // 6. Test Needed for remaining components (ChatModel, Prompt, AgenticPrompt) + tpl4 := NewHandlerHelper(). + ChatModel(&ModelCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + return ctx + }}). + Prompt(&PromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}). + AgenticPrompt(&AgenticPromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}) + + h4 := tpl4.Handler() + checker4 := h4.(callbacks.TimingChecker) + + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfChatModel}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticPrompt}, callbacks.TimingOnStart)) }) } From 06ea53cc520b7352eebcb3d1b47019b366db4a1c Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 16:19:47 +0800 Subject: [PATCH 23/28] feat: improve AgenticToolChoice (#684) --- components/model/option.go | 17 ++++++++--------- components/model/option_test.go | 13 ++++++++++--- compose/workflow.go | 18 ++++++++++++++++++ schema/tool.go | 31 +++++++++++++++++++++++++++++-- 4 files changed, 65 insertions(+), 14 deletions(-) diff --git a/components/model/option.go b/components/model/option.go index 0173d22aa..a337b7af2 100644 --- 
a/components/model/option.go +++ b/components/model/option.go @@ -28,11 +28,11 @@ type Options struct { TopP *float32 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo - // ToolChoice controls which tool is called by the model. - ToolChoice *schema.ToolChoice - // Options only for chat model. + // Options only available for chat model. + // ToolChoice controls which tool is called by the model. + ToolChoice *schema.ToolChoice // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. @@ -41,10 +41,10 @@ type Options struct { // Stop is the stop words for the model, which controls the stopping condition of the model. Stop []string - // Options only for agentic model. + // Options only available for agentic model. - // AllowedTools is a list of allowed tools the model may call. - AllowedTools []*schema.AllowedTool + // AgenticToolChoice controls how the agentic model calls tools. + AgenticToolChoice *schema.AgenticToolChoice } // Option is a call-time option for a ChatModel. Options are immutable and @@ -130,11 +130,10 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op // WithAgenticToolChoice is the option to set tool choice for the agentic model. // Only available for AgenticModel. 
-func WithAgenticToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option { +func WithAgenticToolChoice(toolChoice *schema.AgenticToolChoice) Option { return Option{ apply: func(opts *Options) { - opts.ToolChoice = &toolChoice - opts.AllowedTools = allowedTools + opts.AgenticToolChoice = toolChoice }, } } diff --git a/components/model/option_test.go b/components/model/option_test.go index bfacdd17c..aa43e6e01 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -92,11 +92,18 @@ func TestOptions(t *testing.T) { ) opts := GetCommonOptions( nil, - WithAgenticToolChoice(toolChoice, allowedTools...), + WithAgenticToolChoice(&schema.AgenticToolChoice{ + Type: toolChoice, + Forced: &schema.AgenticForcedToolChoice{ + Tools: allowedTools, + }, + }), ) - convey.So(opts.ToolChoice, convey.ShouldResemble, &toolChoice) - convey.So(opts.AllowedTools, convey.ShouldResemble, allowedTools) + convey.So(opts.AgenticToolChoice, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Type, convey.ShouldEqual, toolChoice) + convey.So(opts.AgenticToolChoice.Forced, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Forced.Tools, convey.ShouldResemble, allowedTools) }) } diff --git a/compose/workflow.go b/compose/workflow.go index c3e4331a3..6b50962bb 100644 --- a/compose/workflow.go +++ b/compose/workflow.go @@ -89,18 +89,36 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.BaseChatM return wf.initNode(key) } +// AddAgenticModelNode adds an agentic model node and returns it. +func (wf *Workflow[I, O]) AddAgenticModelNode(key string, agenticModel model.AgenticModel, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticModelNode(key, agenticModel, opts...) + return wf.initNode(key) +} + // AddChatTemplateNode adds a chat template node and returns it. 
func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddChatTemplateNode(key, chatTemplate, opts...) return wf.initNode(key) } +// AddAgenticChatTemplateNode adds an agentic chat template node and returns it. +func (wf *Workflow[I, O]) AddAgenticChatTemplateNode(key string, chatTemplate prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticChatTemplateNode(key, chatTemplate, opts...) + return wf.initNode(key) +} + // AddToolsNode adds a tools node and returns it. func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddToolsNode(key, tools, opts...) return wf.initNode(key) } +// AddAgenticToolsNode adds an agentic tools node and returns it. +func (wf *Workflow[I, O]) AddAgenticToolsNode(key string, tools *AgenticToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticToolsNode(key, tools, opts...) + return wf.initNode(key) +} + // AddRetrieverNode adds a retriever node and returns it. func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddRetrieverNode(key, retriever, opts...) diff --git a/schema/tool.go b/schema/tool.go index efed6d34b..a067d87db 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -59,6 +59,31 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +type AgenticToolChoice struct { + // Type is the tool choice mode. + Type ToolChoice + // Allowed optionally specifies the list of tools that the model is permitted to call. + Allowed *AgenticAllowedToolChoice + // Forced optionally specifies the list of tools that the model is required to call. + Forced *AgenticForcedToolChoice +} + +// AgenticAllowedToolChoice specifies a list of allowed tools for the model. +type AgenticAllowedToolChoice struct { + // Tools is the list of allowed tools for the model to call. 
+ // Optional. + Tools []*AllowedTool +} + +// AgenticForcedToolChoice specifies a list of tools that the model must call. +type AgenticForcedToolChoice struct { + // Tools is the list of tools that the model must call. + // Optional. + Tools []*AllowedTool +} + +// AllowedTool represents a tool that the model is allowed or forced to call. +// Exactly one of FunctionToolName, MCPTool, or ServerTool must be specified. type AllowedTool struct { // FunctionToolName is the name of the function tool. FunctionToolName string @@ -68,15 +93,17 @@ type AllowedTool struct { ServerTool *AllowedServerTool } +// AllowedMCPTool contains the information for identifying an MCP tool. type AllowedMCPTool struct { // ServerLabel is the label of the MCP server. ServerLabel string - // The name of the MCP tool. + // Name is the name of the MCP tool. Name string } +// AllowedServerTool contains the information for identifying a server tool. type AllowedServerTool struct { - // The name of the server tool. + // Name is the name of the server tool. 
Name string } From 4c892b849f55ca3621234666f08e29862eab5267 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 18:04:59 +0800 Subject: [PATCH 24/28] feat: define AgenticCallbackInput/Output (#689) --- components/model/agentic_callback_extra.go | 92 +++++++++++++++++++ .../model/agentic_callback_extra_test.go | 35 +++++++ components/model/callback_extra.go | 12 --- components/model/option_test.go | 2 +- components/prompt/agentic_callback_extra.go | 70 ++++++++++++++ .../prompt/agentic_callback_extra_test.go | 46 ++++++++++ components/prompt/agentic_chat_template.go | 4 +- components/prompt/callback_extra.go | 10 -- components/prompt/callback_extra_test.go | 17 +--- schema/agentic_message.go | 9 -- schema/agentic_message_test.go | 9 -- schema/tool.go | 8 +- 12 files changed, 256 insertions(+), 58 deletions(-) create mode 100644 components/model/agentic_callback_extra.go create mode 100644 components/model/agentic_callback_extra_test.go create mode 100644 components/prompt/agentic_callback_extra.go create mode 100644 components/prompt/agentic_callback_extra_test.go diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go new file mode 100644 index 000000000..28dd366e6 --- /dev/null +++ b/components/model/agentic_callback_extra.go @@ -0,0 +1,92 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package model + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticConfig is the config for the agentic model. +type AgenticConfig struct { + // Model is the model name. + Model string + // Temperature is the temperature, which controls the randomness of the agentic model. + Temperature float32 + // TopP is the top p, which controls the diversity of the agentic model. + TopP float32 +} + +// AgenticCallbackInput is the input for the agentic model callback. +type AgenticCallbackInput struct { + // AgenticMessages is the agentic messages to be sent to the agentic model. + AgenticMessages []*schema.AgenticMessage + // Tools is the tools to be used in the agentic model. + Tools []*schema.ToolInfo + // Config is the config for the agentic model. + Config *AgenticConfig + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the agentic model callback. +type AgenticCallbackOutput struct { + // AgenticMessage is the agentic message generated by the agentic model. + AgenticMessage *schema.AgenticMessage + // Config is the config for the agentic model. + Config *AgenticConfig + // TokenUsage is the token usage of this request. + TokenUsage *TokenUsage + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic model callback input. 
+func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + // when callback is triggered within component implementation, + // the input is usually already a typed *model.AgenticCallbackInput + return t + case []*schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the input is the input of Agentic Model interface, which is []*schema.AgenticMessage + return &AgenticCallbackInput{ + AgenticMessages: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic model callback output. +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + // when callback is triggered within component implementation, + // the output is usually already a typed *model.AgenticCallbackOutput + return t + case *schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the output is the output of Agentic Model interface, which is *schema.AgenticMessage + return &AgenticCallbackOutput{ + AgenticMessage: t, + } + default: + return nil + } +} diff --git a/components/model/agentic_callback_extra_test.go b/components/model/agentic_callback_extra_test.go new file mode 100644 index 000000000..937367477 --- /dev/null +++ b/components/model/agentic_callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticModel(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{})) + assert.NotNil(t, ConvAgenticCallbackInput([]*schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{})) + assert.NotNil(t, ConvAgenticCallbackOutput(&schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackOutput("asd")) +} diff --git a/components/model/callback_extra.go b/components/model/callback_extra.go index ed9096d5c..8591c4373 100644 --- a/components/model/callback_extra.go +++ b/components/model/callback_extra.go @@ -66,8 +66,6 @@ type Config struct { type CallbackInput struct { // Messages is the messages to be sent to the model. Messages []*schema.Message - // AgenticMessages is the agentic messages to be sent to the agentic model. - AgenticMessages []*schema.AgenticMessage // Tools is the tools to be used in the model. Tools []*schema.ToolInfo // ToolChoice is the tool choice, which controls the tool to be used in the model. @@ -82,8 +80,6 @@ type CallbackInput struct { type CallbackOutput struct { // Message is the message generated by the model. Message *schema.Message - // AgenticMessage is the agentic message generated by the agentic model. - AgenticMessage *schema.AgenticMessage // Config is the config for the model. Config *Config // TokenUsage is the token usage of this request. 
@@ -101,10 +97,6 @@ func ConvCallbackInput(src callbacks.CallbackInput) *CallbackInput { return &CallbackInput{ Messages: t, } - case []*schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the input is the input of Agentic Model interface, which is []*schema.AgenticMessage - return &CallbackInput{ - AgenticMessages: t, - } default: return nil } @@ -119,10 +111,6 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Message: t, } - case *schema.AgenticMessage: // when callback is injected by graph node, not the component implementation itself, the output is the output of Agentic Model interface, which is *schema.AgenticMessage - return &CallbackOutput{ - AgenticMessage: t, - } default: return nil } diff --git a/components/model/option_test.go b/components/model/option_test.go index aa43e6e01..c836933b7 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -87,7 +87,7 @@ func TestOptions(t *testing.T) { var ( toolChoice = schema.ToolChoiceForced allowedTools = []*schema.AllowedTool{ - {FunctionToolName: "agentic_tool"}, + {FunctionName: "agentic_tool"}, } ) opts := GetCommonOptions( diff --git a/components/prompt/agentic_callback_extra.go b/components/prompt/agentic_callback_extra.go new file mode 100644 index 000000000..1170854a1 --- /dev/null +++ b/components/prompt/agentic_callback_extra.go @@ -0,0 +1,70 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticCallbackInput is the input for the callback. +type AgenticCallbackInput struct { + // Variables is the variables for the callback. + Variables map[string]any + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the callback. +type AgenticCallbackOutput struct { + // AgenticResult is the agentic result for the callback. + AgenticResult []*schema.AgenticMessage + // AgenticTemplates is the agentic templates for the callback. + AgenticTemplates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic prompt callback input. +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + return t + case map[string]any: + return &AgenticCallbackInput{ + Variables: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic prompt callback output. 
+func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + return t + case []*schema.AgenticMessage: + return &AgenticCallbackOutput{ + AgenticResult: t, + } + default: + return nil + } +} diff --git a/components/prompt/agentic_callback_extra_test.go b/components/prompt/agentic_callback_extra_test.go new file mode 100644 index 000000000..6dda1a349 --- /dev/null +++ b/components/prompt/agentic_callback_extra_test.go @@ -0,0 +1,46 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package prompt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticPrompt(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{ + Variables: map[string]any{}, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackInput(map[string]any{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{ + AgenticResult: []*schema.AgenticMessage{ + {}, + }, + AgenticTemplates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackOutput([]*schema.AgenticMessage{})) +} diff --git a/components/prompt/agentic_chat_template.go b/components/prompt/agentic_chat_template.go index 512a60ecd..c6c300d5a 100644 --- a/components/prompt/agentic_chat_template.go +++ b/components/prompt/agentic_chat_template.go @@ -45,7 +45,7 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) - ctx = callbacks.OnStart(ctx, &CallbackInput{ + ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ Variables: vs, AgenticTemplates: t.templates, }) @@ -65,7 +65,7 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]a result = append(result, msgs...) 
} - _ = callbacks.OnEnd(ctx, &CallbackOutput{ + _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ AgenticResult: result, AgenticTemplates: t.templates, }) diff --git a/components/prompt/callback_extra.go b/components/prompt/callback_extra.go index 4c27f37c6..324a418f3 100644 --- a/components/prompt/callback_extra.go +++ b/components/prompt/callback_extra.go @@ -27,8 +27,6 @@ type CallbackInput struct { Variables map[string]any // Templates is the templates for the callback. Templates []schema.MessagesTemplate - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } @@ -37,12 +35,8 @@ type CallbackInput struct { type CallbackOutput struct { // Result is the result for the callback. Result []*schema.Message - // AgenticResult is the agentic result for the callback. - AgenticResult []*schema.AgenticMessage // Templates is the templates for the callback. Templates []schema.MessagesTemplate - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. 
Extra map[string]any } @@ -70,10 +64,6 @@ func ConvCallbackOutput(src callbacks.CallbackOutput) *CallbackOutput { return &CallbackOutput{ Result: t, } - case []*schema.AgenticMessage: - return &CallbackOutput{ - AgenticResult: t, - } default: return nil } diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go index 4b48ec114..ad8a3c0c2 100644 --- a/components/prompt/callback_extra_test.go +++ b/components/prompt/callback_extra_test.go @@ -26,27 +26,20 @@ import ( func TestConvPrompt(t *testing.T) { assert.NotNil(t, ConvCallbackInput(&CallbackInput{ - AgenticTemplates: []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{}, + Templates: []schema.MessagesTemplate{ + &schema.Message{}, }, })) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{ - AgenticResult: []*schema.AgenticMessage{ + Result: []*schema.Message{ {}, }, - AgenticTemplates: []schema.AgenticMessagesTemplate{ - &schema.AgenticMessage{}, + Templates: []schema.MessagesTemplate{ + &schema.Message{}, }, })) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) - - agenticResult := []*schema.AgenticMessage{{}} - out := ConvCallbackOutput(agenticResult) - assert.NotNil(t, out) - assert.Equal(t, agenticResult, out.AgenticResult) - - assert.Nil(t, ConvCallbackOutput("asd")) } diff --git a/schema/agentic_message.go b/schema/agentic_message.go index a4554d38e..743f67855 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -59,7 +59,6 @@ const ( type AgenticRoleType string const ( - AgenticRoleTypeDeveloper AgenticRoleType = "developer" AgenticRoleTypeSystem AgenticRoleType = "system" AgenticRoleTypeUser AgenticRoleType = "user" AgenticRoleTypeAssistant AgenticRoleType = "assistant" @@ -426,14 +425,6 @@ type MCPToolApprovalResponse struct { Reason string } -// DeveloperAgenticMessage represents a message with AgenticRoleType "developer". 
-func DeveloperAgenticMessage(text string) *AgenticMessage { - return &AgenticMessage{ - Role: AgenticRoleTypeDeveloper, - ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, - } -} - // SystemAgenticMessage represents a message with AgenticRoleType "system". func SystemAgenticMessage(text string) *AgenticMessage { return &AgenticMessage{ diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 4beb74930..144c0077e 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -1544,15 +1544,6 @@ response_meta: }) } -func TestDeveloperAgenticMessage(t *testing.T) { - t.Run("basic", func(t *testing.T) { - msg := DeveloperAgenticMessage("developer") - assert.Equal(t, AgenticRoleTypeDeveloper, msg.Role) - assert.Len(t, msg.ContentBlocks, 1) - assert.Equal(t, "developer", msg.ContentBlocks[0].UserInputText.Text) - }) -} - func TestSystemAgenticMessage(t *testing.T) { t.Run("basic", func(t *testing.T) { msg := SystemAgenticMessage("system") diff --git a/schema/tool.go b/schema/tool.go index a067d87db..c195d1f9e 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -83,13 +83,15 @@ type AgenticForcedToolChoice struct { } // AllowedTool represents a tool that the model is allowed or forced to call. -// Exactly one of FunctionToolName, MCPTool, or ServerTool must be specified. +// Exactly one of FunctionName, MCPTool, or ServerTool must be specified. type AllowedTool struct { - // FunctionToolName is the name of the function tool. - FunctionToolName string + // FunctionName specifies a function tool by name. + FunctionName string + // MCPTool specifies an MCP tool. MCPTool *AllowedMCPTool + // ServerTool specifies a server tool. 
ServerTool *AllowedServerTool } From ccfc66f99c742a2458fce1b67b1cf276a95663db Mon Sep 17 00:00:00 2001 From: mrh997 Date: Thu, 15 Jan 2026 21:49:22 +0800 Subject: [PATCH 25/28] feat: improve callback definition (#692) --- components/model/agentic_callback_extra.go | 12 ++++++------ components/prompt/agentic_callback_extra.go | 14 +++++++------- components/prompt/agentic_callback_extra_test.go | 6 +++--- components/prompt/agentic_chat_template.go | 8 ++++---- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go index 28dd366e6..54d49ff72 100644 --- a/components/model/agentic_callback_extra.go +++ b/components/model/agentic_callback_extra.go @@ -33,8 +33,8 @@ type AgenticConfig struct { // AgenticCallbackInput is the input for the agentic model callback. type AgenticCallbackInput struct { - // AgenticMessages is the agentic messages to be sent to the agentic model. - AgenticMessages []*schema.AgenticMessage + // Messages is the agentic messages to be sent to the agentic model. + Messages []*schema.AgenticMessage // Tools is the tools to be used in the agentic model. Tools []*schema.ToolInfo // Config is the config for the agentic model. @@ -45,8 +45,8 @@ type AgenticCallbackInput struct { // AgenticCallbackOutput is the output for the agentic model callback. type AgenticCallbackOutput struct { - // AgenticMessage is the agentic message generated by the agentic model. - AgenticMessage *schema.AgenticMessage + // Message is the agentic message generated by the agentic model. + Message *schema.AgenticMessage // Config is the config for the agentic model. Config *AgenticConfig // TokenUsage is the token usage of this request. 
@@ -66,7 +66,7 @@ func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput // when callback is injected by graph node, not the component implementation itself, // the input is the input of Agentic Model interface, which is []*schema.AgenticMessage return &AgenticCallbackInput{ - AgenticMessages: t, + Messages: t, } default: return nil @@ -84,7 +84,7 @@ func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOut // when callback is injected by graph node, not the component implementation itself, // the output is the output of Agentic Model interface, which is *schema.AgenticMessage return &AgenticCallbackOutput{ - AgenticMessage: t, + Message: t, } default: return nil diff --git a/components/prompt/agentic_callback_extra.go b/components/prompt/agentic_callback_extra.go index 1170854a1..315d5a4da 100644 --- a/components/prompt/agentic_callback_extra.go +++ b/components/prompt/agentic_callback_extra.go @@ -25,18 +25,18 @@ import ( type AgenticCallbackInput struct { // Variables is the variables for the callback. Variables map[string]any - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. Extra map[string]any } // AgenticCallbackOutput is the output for the callback. type AgenticCallbackOutput struct { - // AgenticResult is the agentic result for the callback. - AgenticResult []*schema.AgenticMessage - // AgenticTemplates is the agentic templates for the callback. - AgenticTemplates []schema.AgenticMessagesTemplate + // Result is the agentic result for the callback. + Result []*schema.AgenticMessage + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate // Extra is the extra information for the callback. 
Extra map[string]any } @@ -62,7 +62,7 @@ func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOut return t case []*schema.AgenticMessage: return &AgenticCallbackOutput{ - AgenticResult: t, + Result: t, } default: return nil diff --git a/components/prompt/agentic_callback_extra_test.go b/components/prompt/agentic_callback_extra_test.go index 6dda1a349..67982be80 100644 --- a/components/prompt/agentic_callback_extra_test.go +++ b/components/prompt/agentic_callback_extra_test.go @@ -27,7 +27,7 @@ import ( func TestConvAgenticPrompt(t *testing.T) { assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{ Variables: map[string]any{}, - AgenticTemplates: []schema.AgenticMessagesTemplate{ + Templates: []schema.AgenticMessagesTemplate{ &schema.AgenticMessage{}, }, })) @@ -35,10 +35,10 @@ func TestConvAgenticPrompt(t *testing.T) { assert.Nil(t, ConvAgenticCallbackInput("asd")) assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{ - AgenticResult: []*schema.AgenticMessage{ + Result: []*schema.AgenticMessage{ {}, }, - AgenticTemplates: []schema.AgenticMessagesTemplate{ + Templates: []schema.AgenticMessagesTemplate{ &schema.AgenticMessage{}, }, })) diff --git a/components/prompt/agentic_chat_template.go b/components/prompt/agentic_chat_template.go index c6c300d5a..41d291065 100644 --- a/components/prompt/agentic_chat_template.go +++ b/components/prompt/agentic_chat_template.go @@ -46,8 +46,8 @@ type DefaultAgenticChatTemplate struct { func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ - Variables: vs, - AgenticTemplates: t.templates, + Variables: vs, + Templates: t.templates, }) defer func() { if err != nil { @@ -66,8 +66,8 @@ func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs 
map[string]a } _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ - AgenticResult: result, - AgenticTemplates: t.templates, + Result: result, + Templates: t.templates, }) return result, nil From 766ca832e54dae418f397e426e5612639d53ea54 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 19 Jan 2026 20:58:32 +0800 Subject: [PATCH 26/28] feat: improve callback definition (#702) --- schema/agentic_message.go | 61 +++++++------------------------- schema/agentic_message_test.go | 45 ++++++++--------------- schema/tool.go | 4 +++ utils/callbacks/template.go | 14 ++++---- utils/callbacks/template_test.go | 8 ++--- 5 files changed, 41 insertions(+), 91 deletions(-) diff --git a/schema/agentic_message.go b/schema/agentic_message.go index 743f67855..ead2d866d 100644 --- a/schema/agentic_message.go +++ b/schema/agentic_message.go @@ -270,19 +270,12 @@ type AssistantGenVideo struct { } type Reasoning struct { - // Summary is the reasoning content summary. - Summary []*ReasoningSummary - - // EncryptedContent is the encrypted reasoning content. - EncryptedContent string -} - -type ReasoningSummary struct { - // Index specifies the index position of this summary in the final Reasoning. - Index int - - // Text is the reasoning content summary. + // Text is either the thought summary or the raw reasoning text itself. Text string + + // Signature contains encrypted reasoning tokens. + // Required by some models when passing reasoning text back. + Signature string } type FunctionToolCall struct { @@ -1172,42 +1165,15 @@ func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { ret := &Reasoning{} - var allSummaries []*ReasoningSummary for _, r := range reasons { - if r == nil { - continue + if r.Text != "" { + ret.Text += r.Text } - allSummaries = append(allSummaries, r.Summary...) 
- if r.EncryptedContent != "" { - ret.EncryptedContent += r.EncryptedContent + if r.Signature != "" { + ret.Signature += r.Signature } } - var ( - indices []int - indexToSummary = map[int]*ReasoningSummary{} - ) - - for _, s := range allSummaries { - if s == nil { - continue - } - if indexToSummary[s.Index] == nil { - indexToSummary[s.Index] = &ReasoningSummary{} - indices = append(indices, s.Index) - } - indexToSummary[s.Index].Text += s.Text - } - - sort.Slice(indices, func(i, j int) bool { - return indices[i] < indices[j] - }) - - ret.Summary = make([]*ReasoningSummary, 0, len(indices)) - for _, idx := range indices { - ret.Summary = append(ret.Summary, indexToSummary[idx]) - } - return ret, nil } @@ -1899,12 +1865,9 @@ func (b *ContentBlock) String() string { // String returns the string representation of Reasoning. func (r *Reasoning) String() string { sb := &strings.Builder{} - sb.WriteString(fmt.Sprintf(" summary: %d items\n", len(r.Summary))) - for _, s := range r.Summary { - sb.WriteString(fmt.Sprintf(" [%d] %s\n", s.Index, s.Text)) - } - if r.EncryptedContent != "" { - sb.WriteString(fmt.Sprintf(" encrypted_content: %s\n", truncateString(r.EncryptedContent, 50))) + sb.WriteString(fmt.Sprintf(" text: %s\n", r.Text)) + if r.Signature != "" { + sb.WriteString(fmt.Sprintf(" signature: %s\n", truncateString(r.Signature, 50))) } return sb.String() } diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go index 144c0077e..e8a1003f5 100644 --- a/schema/agentic_message_test.go +++ b/schema/agentic_message_test.go @@ -109,9 +109,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "First "}, - }, + Text: "First ", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -123,9 +121,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: 
"Second"}, - }, + Text: "Second", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -136,9 +132,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 1) - assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, 0, result.ContentBlocks[0].Reasoning.Summary[0].Index) + assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Text) }) t.Run("concat reasoning with index", func(t *testing.T) { @@ -149,10 +143,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "Part1-"}, - {Index: 1, Text: "Part2-"}, - }, + Text: "Part1-", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -164,10 +155,7 @@ func TestConcatAgenticMessages(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "Part3"}, - {Index: 1, Text: "Part4"}, - }, + Text: "Part3", }, StreamingMeta: &StreamingMeta{Index: 0}, }, @@ -178,9 +166,7 @@ func TestConcatAgenticMessages(t *testing.T) { result, err := ConcatAgenticMessages(msgs) assert.NoError(t, err) assert.Len(t, result.ContentBlocks, 1) - assert.Len(t, result.ContentBlocks[0].Reasoning.Summary, 2) - assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Summary[0].Text) - assert.Equal(t, "Part2-Part4", result.ContentBlocks[0].Reasoning.Summary[1].Text) + assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Text) }) t.Run("concat user input text", func(t *testing.T) { @@ -1292,12 +1278,10 @@ func TestAgenticMessageString(t *testing.T) { { Type: ContentBlockTypeReasoning, Reasoning: &Reasoning{ - Summary: []*ReasoningSummary{ - {Index: 0, Text: "First, I need to identify the location (New York City) from the user's query."}, - {Index: 1, Text: "Then, 
I should call the weather API to get current conditions."}, - {Index: 2, Text: "Finally, I'll format the response in a user-friendly way with temperature and conditions."}, - }, - EncryptedContent: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", + Text: "First, I need to identify the location (New York City) from the user's query.\n" + + "Then, I should call the weather API to get current conditions.\n" + + "Finally, I'll format the response in a user-friendly way with temperature and conditions.", + Signature: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", }, }, { @@ -1432,11 +1416,10 @@ content_blocks: base64_data: gen_video_data... (14 bytes) mime_type: video/mp4 [9] type: reasoning - summary: 3 items - [0] First, I need to identify the location (New York City) from the user's query. - [1] Then, I should call the weather API to get current conditions. - [2] Finally, I'll format the response in a user-friendly way with temperature and conditions. - encrypted_content: encrypted_reasoning_content_that_is_very_long_and_... + text: First, I need to identify the location (New York City) from the user's query. +Then, I should call the weather API to get current conditions. +Finally, I'll format the response in a user-friendly way with temperature and conditions. + signature: encrypted_reasoning_content_that_is_very_long_and_... [10] type: function_tool_call call_id: call_weather_123 name: get_current_weather diff --git a/schema/tool.go b/schema/tool.go index c195d1f9e..a49306047 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -62,9 +62,13 @@ const ( type AgenticToolChoice struct { // Type is the tool choice mode. Type ToolChoice + // Allowed optionally specifies the list of tools that the model is permitted to call. + // Optional. Allowed *AgenticAllowedToolChoice + // Forced optionally specifies the list of tools that the model is required to call. + // Optional. 
Forced *AgenticForcedToolChoice } diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index 4c73e6bbc..4c2c709da 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -187,7 +187,7 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) case components.ComponentOfAgenticModel: - return c.agenticModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) + return c.agenticModelHandler.OnStart(ctx, info, model.ConvAgenticCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: @@ -226,7 +226,7 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) case components.ComponentOfAgenticModel: - return c.agenticModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) + return c.agenticModelHandler.OnEnd(ctx, info, model.ConvAgenticCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: @@ -316,8 +316,8 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb })) case components.ComponentOfAgenticModel: return c.agenticModelHandler.OnEndWithStreamOutput(ctx, info, - schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { - return model.ConvCallbackOutput(item), nil + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.AgenticCallbackOutput, error) { + return model.ConvAgenticCallbackOutput(item), nil })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, @@ 
-686,9 +686,9 @@ func (ch *AgenticPromptCallbackHandler) Needed(ctx context.Context, runInfo *cal // AgenticModelCallbackHandler is the handler for the agentic chat model callback. type AgenticModelCallbackHandler struct { - OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context - OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context - OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context } diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index f599e5300..dcc0e5c7f 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -143,15 +143,15 @@ func TestNewComponentTemplate(t *testing.T) { return ctx }).Build()). 
AgenticModel(&AgenticModelCallbackHandler{ - OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { cnt++ return ctx }, - OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.CallbackOutput) context.Context { + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context { cnt++ return ctx }, - OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context { + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context { output.Close() cnt++ return ctx @@ -485,7 +485,7 @@ func TestNewComponentTemplate(t *testing.T) { // Set it now tpl2.AgenticModel(&AgenticModelCallbackHandler{ - OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { return ctx }, }) From 088153726f0fb5d896134a19e8cc99ed142b866c Mon Sep 17 00:00:00 2001 From: mrh997 Date: Mon, 19 Jan 2026 22:12:10 +0800 Subject: [PATCH 27/28] feat: agentic model support MaxTokens (#703) --- components/model/agentic_callback_extra.go | 2 ++ components/model/option.go | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go index 54d49ff72..9a769cf7e 100644 --- a/components/model/agentic_callback_extra.go +++ b/components/model/agentic_callback_extra.go @@ -25,6 +25,8 @@ import ( type AgenticConfig struct { // Model is the model name. 
Model string + // MaxTokens is the max number of output tokens, if reached the max tokens, the model will stop generating. + MaxTokens int // Temperature is the temperature, which controls the randomness of the agentic model. Temperature float32 // TopP is the top p, which controls the diversity of the agentic model. diff --git a/components/model/option.go b/components/model/option.go index a337b7af2..a46b71b19 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -28,13 +28,13 @@ type Options struct { TopP *float32 // Tools is a list of tools the model may call. Tools []*schema.ToolInfo + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". + MaxTokens *int // Options only available for chat model. // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice - // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". - MaxTokens *int // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string From 4bb86f5e73ef4d3924d97749c8e57462258fd666 Mon Sep 17 00:00:00 2001 From: mrh997 Date: Tue, 20 Jan 2026 13:21:40 +0800 Subject: [PATCH 28/28] feat: agentic model support stop option --- components/model/option.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/components/model/option.go b/components/model/option.go index a46b71b19..936b0fbda 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -30,6 +30,8 @@ type Options struct { Tools []*schema.ToolInfo // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". 
MaxTokens *int + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string // Options only available for chat model. @@ -38,8 +40,6 @@ type Options struct { // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string - // Stop is the stop words for the model, which controls the stopping condition of the model. - Stop []string // Options only available for agentic model. @@ -67,7 +67,6 @@ func WithTemperature(temperature float32) Option { } // WithMaxTokens is the option to set the max tokens for the model. -// Only available for ChatModel. func WithMaxTokens(maxTokens int) Option { return Option{ apply: func(opts *Options) { @@ -95,7 +94,6 @@ func WithTopP(topP float32) Option { } // WithStop is the option to set the stop words for the model. -// Only available for ChatModel. func WithStop(stop []string) Option { return Option{ apply: func(opts *Options) {