diff --git a/components/model/agentic_callback_extra.go b/components/model/agentic_callback_extra.go new file mode 100644 index 000000000..9a769cf7e --- /dev/null +++ b/components/model/agentic_callback_extra.go @@ -0,0 +1,94 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticConfig is the config for the agentic model. +type AgenticConfig struct { + // Model is the model name. + Model string + // MaxTokens is the max number of output tokens, if reached the max tokens, the model will stop generating. + MaxTokens int + // Temperature is the temperature, which controls the randomness of the agentic model. + Temperature float32 + // TopP is the top p, which controls the diversity of the agentic model. + TopP float32 +} + +// AgenticCallbackInput is the input for the agentic model callback. +type AgenticCallbackInput struct { + // Messages is the agentic messages to be sent to the agentic model. + Messages []*schema.AgenticMessage + // Tools is the tools to be used in the agentic model. + Tools []*schema.ToolInfo + // Config is the config for the agentic model. + Config *AgenticConfig + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the agentic model callback. +type AgenticCallbackOutput struct { + // Message is the agentic message generated by the agentic model. + Message *schema.AgenticMessage + // Config is the config for the agentic model. + Config *AgenticConfig + // TokenUsage is the token usage of this request. + TokenUsage *TokenUsage + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic model callback input. +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + // when callback is triggered within component implementation, + // the input is usually already a typed *model.AgenticCallbackInput + return t + case []*schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the input is the input of Agentic Model interface, which is []*schema.AgenticMessage + return &AgenticCallbackInput{ + Messages: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic model callback output. +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + // when callback is triggered within component implementation, + // the output is usually already a typed *model.AgenticCallbackOutput + return t + case *schema.AgenticMessage: + // when callback is injected by graph node, not the component implementation itself, + // the output is the output of Agentic Model interface, which is *schema.AgenticMessage + return &AgenticCallbackOutput{ + Message: t, + } + default: + return nil + } +} diff --git a/components/model/agentic_callback_extra_test.go b/components/model/agentic_callback_extra_test.go new file mode 100644 index 000000000..937367477 --- /dev/null +++ b/components/model/agentic_callback_extra_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package model + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticModel(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{})) + assert.NotNil(t, ConvAgenticCallbackInput([]*schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{})) + assert.NotNil(t, ConvAgenticCallbackOutput(&schema.AgenticMessage{})) + assert.Nil(t, ConvAgenticCallbackOutput("asd")) +} diff --git a/components/model/interface.go b/components/model/interface.go index deb7b56dd..cf79785bc 100644 --- a/components/model/interface.go +++ b/components/model/interface.go @@ -89,3 +89,15 @@ type ToolCallingChatModel interface { // This method does not modify the current instance, making it safer for concurrent use. WithTools(tools []*schema.ToolInfo) (ToolCallingChatModel, error) } + +// AgenticModel defines the interface for agentic models that support AgenticMessage. +// It provides methods for generating complete and streaming outputs, and supports +// tool calling via the WithTools method. +type AgenticModel interface { + Generate(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.AgenticMessage, error) + Stream(ctx context.Context, input []*schema.AgenticMessage, opts ...Option) (*schema.StreamReader[*schema.AgenticMessage], error) + + // WithTools returns a new Model instance with the specified tools bound. + // This method does not modify the current instance, making it safer for concurrent use. + WithTools(tools []*schema.ToolInfo) (AgenticModel, error) +} diff --git a/components/model/option.go b/components/model/option.go index 9fd96116c..936b0fbda 100644 --- a/components/model/option.go +++ b/components/model/option.go @@ -22,21 +22,29 @@ import "github.com/cloudwego/eino/schema" type Options struct { // Temperature is the temperature for the model, which controls the randomness of the model. Temperature *float32 - // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length". - MaxTokens *int // Model is the model name. Model *string // TopP is the top p for the model, which controls the diversity of the model. TopP *float32 - // Stop is the stop words for the model, which controls the stopping condition of the model. - Stop []string // Tools is a list of tools the model may call. Tools []*schema.ToolInfo + // MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return a finish reason of "length". + MaxTokens *int + // Stop is the stop words for the model, which controls the stopping condition of the model. + Stop []string + + // Options only available for chat model. + // ToolChoice controls which tool is called by the model. ToolChoice *schema.ToolChoice // AllowedToolNames specifies a list of tool names that the model is allowed to call. // This allows for constraining the model to a specific subset of the available tools. AllowedToolNames []string + + // Options only available for agentic model. + + // AgenticToolChoice controls how the agentic model calls tools. + AgenticToolChoice *schema.AgenticToolChoice } // Option is a call-time option for a ChatModel. Options are immutable and @@ -108,6 +116,7 @@ func WithTools(tools []*schema.ToolInfo) Option { // WithToolChoice sets the tool choice for the model. It also allows for providing a list of // tool names to constrain the model to a specific subset of the available tools. +// Only available for ChatModel. func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Option { return Option{ apply: func(opts *Options) { @@ -117,6 +126,17 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op } } +// WithAgenticToolChoice is the option to set tool choice for the agentic model. +// Only available for AgenticModel. +func WithAgenticToolChoice(toolChoice *schema.AgenticToolChoice) Option { + return Option{ + apply: func(opts *Options) { + opts.AgenticToolChoice = toolChoice + }, + } +} + +// WrapImplSpecificOptFn is the option to wrap the implementation specific option function. // WrapImplSpecificOptFn wraps an implementation-specific option function into // an [Option] so it can be passed alongside standard options. // diff --git a/components/model/option_test.go b/components/model/option_test.go index 36872c30e..c836933b7 100644 --- a/components/model/option_test.go +++ b/components/model/option_test.go @@ -82,6 +82,29 @@ func TestOptions(t *testing.T) { convey.So(opts.Tools, convey.ShouldNotBeNil) convey.So(len(opts.Tools), convey.ShouldEqual, 0) }) + + convey.Convey("test agentic tool choice option", t, func() { + var ( + toolChoice = schema.ToolChoiceForced + allowedTools = []*schema.AllowedTool{ + {FunctionName: "agentic_tool"}, + } + ) + opts := GetCommonOptions( + nil, + WithAgenticToolChoice(&schema.AgenticToolChoice{ + Type: toolChoice, + Forced: &schema.AgenticForcedToolChoice{ + Tools: allowedTools, + }, + }), + ) + + convey.So(opts.AgenticToolChoice, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Type, convey.ShouldEqual, toolChoice) + convey.So(opts.AgenticToolChoice.Forced, convey.ShouldNotBeNil) + convey.So(opts.AgenticToolChoice.Forced.Tools, convey.ShouldResemble, allowedTools) + }) } type implOption struct { diff --git a/components/prompt/agentic_callback_extra.go b/components/prompt/agentic_callback_extra.go new file mode 100644 index 000000000..315d5a4da --- /dev/null +++ b/components/prompt/agentic_callback_extra.go @@ -0,0 +1,70 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" +) + +// AgenticCallbackInput is the input for the callback. +type AgenticCallbackInput struct { + // Variables is the variables for the callback. + Variables map[string]any + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// AgenticCallbackOutput is the output for the callback. +type AgenticCallbackOutput struct { + // Result is the agentic result for the callback. + Result []*schema.AgenticMessage + // Templates is the agentic templates for the callback. + Templates []schema.AgenticMessagesTemplate + // Extra is the extra information for the callback. + Extra map[string]any +} + +// ConvAgenticCallbackInput converts the callback input to the agentic prompt callback input. +func ConvAgenticCallbackInput(src callbacks.CallbackInput) *AgenticCallbackInput { + switch t := src.(type) { + case *AgenticCallbackInput: + return t + case map[string]any: + return &AgenticCallbackInput{ + Variables: t, + } + default: + return nil + } +} + +// ConvAgenticCallbackOutput converts the callback output to the agentic prompt callback output. +func ConvAgenticCallbackOutput(src callbacks.CallbackOutput) *AgenticCallbackOutput { + switch t := src.(type) { + case *AgenticCallbackOutput: + return t + case []*schema.AgenticMessage: + return &AgenticCallbackOutput{ + Result: t, + } + default: + return nil + } +} diff --git a/components/prompt/agentic_callback_extra_test.go b/components/prompt/agentic_callback_extra_test.go new file mode 100644 index 000000000..67982be80 --- /dev/null +++ b/components/prompt/agentic_callback_extra_test.go @@ -0,0 +1,46 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestConvAgenticPrompt(t *testing.T) { + assert.NotNil(t, ConvAgenticCallbackInput(&AgenticCallbackInput{ + Variables: map[string]any{}, + Templates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackInput(map[string]any{})) + assert.Nil(t, ConvAgenticCallbackInput("asd")) + + assert.NotNil(t, ConvAgenticCallbackOutput(&AgenticCallbackOutput{ + Result: []*schema.AgenticMessage{ + {}, + }, + Templates: []schema.AgenticMessagesTemplate{ + &schema.AgenticMessage{}, + }, + })) + assert.NotNil(t, ConvAgenticCallbackOutput([]*schema.AgenticMessage{})) +} diff --git a/components/prompt/agentic_chat_template.go b/components/prompt/agentic_chat_template.go new file mode 100644 index 000000000..41d291065 --- /dev/null +++ b/components/prompt/agentic_chat_template.go @@ -0,0 +1,84 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/schema" +) + +// FromAgenticMessages creates a new DefaultAgenticChatTemplate from the given templates and format type. +// eg. +// +// template := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// // in chain, or graph +// chain := compose.NewChain[map[string]any, []*schema.AgenticMessage]() +// chain.AppendAgenticChatTemplate(template) +func FromAgenticMessages(formatType schema.FormatType, templates ...schema.AgenticMessagesTemplate) *DefaultAgenticChatTemplate { + return &DefaultAgenticChatTemplate{ + templates: templates, + formatType: formatType, + } +} + +type DefaultAgenticChatTemplate struct { + templates []schema.AgenticMessagesTemplate + formatType schema.FormatType +} + +func (t *DefaultAgenticChatTemplate) Format(ctx context.Context, vs map[string]any, opts ...Option) (result []*schema.AgenticMessage, err error) { + ctx = callbacks.EnsureRunInfo(ctx, t.GetType(), components.ComponentOfAgenticPrompt) + ctx = callbacks.OnStart(ctx, &AgenticCallbackInput{ + Variables: vs, + Templates: t.templates, + }) + defer func() { + if err != nil { + _ = callbacks.OnError(ctx, err) + } + }() + + result = make([]*schema.AgenticMessage, 0, len(t.templates)) + for _, template := range t.templates { + msgs, err := template.Format(ctx, vs, t.formatType) + if err != nil { + return nil, err + } + + result = append(result, msgs...) + } + + _ = callbacks.OnEnd(ctx, &AgenticCallbackOutput{ + Result: result, + Templates: t.templates, + }) + + return result, nil +} + +// GetType returns the type of the agentic template (DefaultAgentic). +func (t *DefaultAgenticChatTemplate) GetType() string { + return "Default" +} + +// IsCallbacksEnabled checks if the callbacks are enabled for the chat template. +func (t *DefaultAgenticChatTemplate) IsCallbacksEnabled() bool { + return true +} diff --git a/components/prompt/agentic_chat_template_test.go b/components/prompt/agentic_chat_template_test.go new file mode 100644 index 000000000..42d7a8630 --- /dev/null +++ b/components/prompt/agentic_chat_template_test.go @@ -0,0 +1,124 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package prompt + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" +) + +type mockAgenticTemplate struct { + err error +} + +func (m *mockAgenticTemplate) Format(ctx context.Context, vs map[string]any, formatType schema.FormatType) ([]*schema.AgenticMessage, error) { + if m.err != nil { + return nil, m.err + } + return []*schema.AgenticMessage{schema.UserAgenticMessage("mocked")}, nil +} + +func TestFromAgenticMessages(t *testing.T) { + t.Run("create template", func(t *testing.T) { + tpl := schema.UserAgenticMessage("hello") + ft := schema.FString + at := FromAgenticMessages(ft, tpl) + + assert.NotNil(t, at) + assert.Equal(t, ft, at.formatType) + assert.Len(t, at.templates, 1) + assert.Same(t, tpl, at.templates[0]) + }) +} + +func TestDefaultAgenticTemplate_GetType(t *testing.T) { + t.Run("get type", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.Equal(t, "Default", at.GetType()) + }) +} + +func TestDefaultAgenticTemplate_IsCallbacksEnabled(t *testing.T) { + t.Run("callbacks enabled", func(t *testing.T) { + at := &DefaultAgenticChatTemplate{} + assert.True(t, at.IsCallbacksEnabled()) + }) +} + +func TestDefaultAgenticTemplate_Format(t *testing.T) { + t.Run("success", func(t *testing.T) { + // Mock callback handler + cb := callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + assert.Equal(t, "Default", info.Type) + return ctx + }). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Fail(t, "unexpected error callback") + return ctx + }). + Build() + + tpl := schema.UserAgenticMessage("hello {val}") + at := FromAgenticMessages(schema.FString, tpl) + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{"val": "world"}) + assert.NoError(t, err) + assert.Len(t, res, 1) + assert.Equal(t, "hello world", res[0].ContentBlocks[0].UserInputText.Text) + }) + + t.Run("template format error", func(t *testing.T) { + mockErr := errors.New("mock error") + mockTpl := &mockAgenticTemplate{err: mockErr} + at := FromAgenticMessages(schema.FString, mockTpl) + + // Mock callback handler to verify OnError + cb := callbacks.NewHandlerBuilder(). + OnErrorFn(func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + assert.Equal(t, mockErr, err) + return ctx + }). + Build() + + ctx := context.Background() + ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{ + Type: "Default", + Component: "agentic_prompt", + }, cb) + + res, err := at.Format(ctx, map[string]any{}) + assert.Error(t, err) + assert.Nil(t, res) + assert.Equal(t, mockErr, err) + }) +} diff --git a/components/prompt/callback_extra_test.go b/components/prompt/callback_extra_test.go index 456297e29..ad8a3c0c2 100644 --- a/components/prompt/callback_extra_test.go +++ b/components/prompt/callback_extra_test.go @@ -25,11 +25,21 @@ import ( ) func TestConvPrompt(t *testing.T) { - assert.NotNil(t, ConvCallbackInput(&CallbackInput{})) + assert.NotNil(t, ConvCallbackInput(&CallbackInput{ + Templates: []schema.MessagesTemplate{ + &schema.Message{}, + }, + })) assert.NotNil(t, ConvCallbackInput(map[string]any{})) assert.Nil(t, ConvCallbackInput("asd")) - assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{})) + assert.NotNil(t, ConvCallbackOutput(&CallbackOutput{ + Result: []*schema.Message{ + {}, + }, + Templates: []schema.MessagesTemplate{ + &schema.Message{}, + }, + })) assert.NotNil(t, ConvCallbackOutput([]*schema.Message{})) - assert.Nil(t, ConvCallbackOutput("asd")) } diff --git a/components/prompt/interface.go b/components/prompt/interface.go index eac695eda..2d5a2cbed 100644 --- a/components/prompt/interface.go +++ b/components/prompt/interface.go @@ -23,6 +23,7 @@ import ( ) var _ ChatTemplate = &DefaultChatTemplate{} +var _ AgenticChatTemplate = &DefaultAgenticChatTemplate{} // ChatTemplate formats a variables map into a list of messages for a ChatModel. // @@ -42,3 +43,8 @@ var _ ChatTemplate = &DefaultChatTemplate{} type ChatTemplate interface { Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.Message, error) } + +// AgenticChatTemplate formats variables into a list of agentic messages according to a prompt schema. +type AgenticChatTemplate interface { + Format(ctx context.Context, vs map[string]any, opts ...Option) ([]*schema.AgenticMessage, error) +} diff --git a/components/types.go b/components/types.go index a546ae59f..2b0ad8f0e 100644 --- a/components/types.go +++ b/components/types.go @@ -66,8 +66,12 @@ type Component string const ( // ComponentOfPrompt identifies chat template components. ComponentOfPrompt Component = "ChatTemplate" + // ComponentOfAgenticPrompt identifies agentic template components. + ComponentOfAgenticPrompt Component = "AgenticChatTemplate" // ComponentOfChatModel identifies chat model components. ComponentOfChatModel Component = "ChatModel" + // ComponentOfAgenticModel identifies agentic model components. + ComponentOfAgenticModel Component = "AgenticModel" // ComponentOfEmbedding identifies embedding components. ComponentOfEmbedding Component = "Embedding" // ComponentOfIndexer identifies indexer components. diff --git a/compose/agentic_tools_node.go b/compose/agentic_tools_node.go new file mode 100644 index 000000000..96aef7b72 --- /dev/null +++ b/compose/agentic_tools_node.go @@ -0,0 +1,126 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "context" + + "github.com/cloudwego/eino/schema" +) + +// NewAgenticToolsNode creates a new AgenticToolsNode. +// e.g. +// +// conf := &ToolsNodeConfig{ +// Tools: []tool.BaseTool{invokableTool1, streamableTool2}, +// } +// toolsNode, err := NewAgenticToolsNode(ctx, conf) +func NewAgenticToolsNode(ctx context.Context, conf *ToolsNodeConfig) (*AgenticToolsNode, error) { + tn, err := NewToolNode(ctx, conf) + if err != nil { + return nil, err + } + return &AgenticToolsNode{inner: tn}, nil +} + +type AgenticToolsNode struct { + inner *ToolsNode +} + +func (a *AgenticToolsNode) Invoke(ctx context.Context, input *schema.AgenticMessage, opts ...ToolsNodeOption) ([]*schema.AgenticMessage, error) { + result, err := a.inner.Invoke(ctx, agenticMessageToToolCallMessage(input), opts...) + if err != nil { + return nil, err + } + return toolMessageToAgenticMessage(result), nil +} + +func (a *AgenticToolsNode) Stream(ctx context.Context, input *schema.AgenticMessage, + opts ...ToolsNodeOption) (*schema.StreamReader[[]*schema.AgenticMessage], error) { + result, err := a.inner.Stream(ctx, agenticMessageToToolCallMessage(input), opts...) + if err != nil { + return nil, err + } + return streamToolMessageToAgenticMessage(result), nil +} + +func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Message { + var tc []schema.ToolCall + for _, block := range input.ContentBlocks { + if block.Type != schema.ContentBlockTypeFunctionToolCall || block.FunctionToolCall == nil { + continue + } + tc = append(tc, schema.ToolCall{ + ID: block.FunctionToolCall.CallID, + Function: schema.FunctionCall{ + Name: block.FunctionToolCall.Name, + Arguments: block.FunctionToolCall.Arguments, + }, + Extra: block.Extra, + }) + } + return &schema.Message{ + Role: schema.Assistant, + ToolCalls: tc, + } +} + +func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessage { + var results []*schema.ContentBlock + for _, m := range input { + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + }, + Extra: m.Extra, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }} +} + +func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Message]) *schema.StreamReader[[]*schema.AgenticMessage] { + return schema.StreamReaderWithConvert(input, func(t []*schema.Message) ([]*schema.AgenticMessage, error) { + var results []*schema.ContentBlock + for i, m := range t { + if m == nil { + continue + } + results = append(results, &schema.ContentBlock{ + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: m.ToolCallID, + Name: m.ToolName, + Result: m.Content, + }, + StreamingMeta: &schema.StreamingMeta{Index: i}, + Extra: m.Extra, + }) + } + return []*schema.AgenticMessage{{ + Role: schema.AgenticRoleTypeUser, + ContentBlocks: results, + }}, nil + }) +} + +func (a *AgenticToolsNode) GetType() string { return "" } diff --git a/compose/agentic_tools_node_test.go b/compose/agentic_tools_node_test.go new file mode 100644 index 000000000..4641dd8ae --- /dev/null +++ b/compose/agentic_tools_node_test.go @@ -0,0 +1,239 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package compose + +import ( + "io" + "testing" + + "github.com/bytedance/sonic" + "github.com/stretchr/testify/assert" + + "github.com/cloudwego/eino/schema" +) + +func TestAgenticMessageToToolCallMessage(t *testing.T) { + input := &schema.AgenticMessage{ + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "1", + Name: "name1", + Arguments: "arg1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "2", + Name: "name2", + Arguments: "arg2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolCall, + FunctionToolCall: &schema.FunctionToolCall{ + CallID: "3", + Name: "name3", + Arguments: "arg3", + }, + }, + }, + } + ret := agenticMessageToToolCallMessage(input) + assert.Equal(t, schema.Assistant, ret.Role) + assert.Equal(t, []schema.ToolCall{ + { + ID: "1", + Function: schema.FunctionCall{ + Name: "name1", + Arguments: "arg1", + }, + }, + { + ID: "2", + Function: schema.FunctionCall{ + Name: "name2", + Arguments: "arg2", + }, + }, + { + ID: "3", + Function: schema.FunctionCall{ + Name: "name3", + Arguments: "arg3", + }, + }, + }, ret.ToolCalls) +} + +func TestToolMessageToAgenticMessage(t *testing.T) { + input := []*schema.Message{ + { + Role: schema.Tool, + Content: "content1", + ToolCallID: "1", + ToolName: "name1", + }, + { + Role: schema.Tool, + Content: "content2", + ToolCallID: "2", + ToolName: "name2", + }, + { + Role: schema.Tool, + Content: "content3", + ToolCallID: "3", + ToolName: "name3", + }, + } + ret := toolMessageToAgenticMessage(input) + assert.Equal(t, 1, len(ret)) + assert.Equal(t, schema.AgenticRoleTypeUser, ret[0].Role) + assert.Equal(t, []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3", + }, + }, + }, ret[0].ContentBlocks) +} + +func TestStreamToolMessageToAgenticMessage(t *testing.T) { + input := schema.StreamReaderFromArray([][]*schema.Message{ + { + { + Role: schema.Tool, + Content: "content1-1", + ToolName: "name1", + ToolCallID: "1", + }, + nil, nil, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-1", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + nil, + { + Role: schema.Tool, + Content: "content2-2", + ToolName: "name2", + ToolCallID: "2", + }, + nil, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-1", + ToolName: "name3", + ToolCallID: "3", + }, + }, + { + nil, nil, + { + Role: schema.Tool, + Content: "content3-2", + ToolName: "name3", + ToolCallID: "3", + }, + }, + }) + ret := streamToolMessageToAgenticMessage(input) + var chunks [][]*schema.AgenticMessage + for { + chunk, err := ret.Recv() + if err == io.EOF { + break + } + assert.NoError(t, err) + chunks = append(chunks, chunk) + } + result, err := schema.ConcatAgenticMessagesArray(chunks) + assert.NoError(t, err) + + actualStr, err := sonic.MarshalString(result) + assert.NoError(t, err) + + expected := []*schema.AgenticMessage{ + { + Role: schema.AgenticRoleTypeUser, + ContentBlocks: []*schema.ContentBlock{ + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "1", + Name: "name1", + Result: "content1-1", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "2", + Name: "name2", + Result: "content2-1content2-2", + }, + }, + { + Type: schema.ContentBlockTypeFunctionToolResult, + FunctionToolResult: &schema.FunctionToolResult{ + CallID: "3", + Name: "name3", + Result: "content3-1content3-2", + }, + }, + }, + }, + } + + expectedStr, err := sonic.MarshalString(expected) + assert.NoError(t, err) + + assert.Equal(t, expectedStr, actualStr) +} diff --git a/compose/chain.go b/compose/chain.go index 5e4a8e1c0..abfa6bf1d 100644 --- a/compose/chain.go +++ b/compose/chain.go @@ -174,6 +174,18 @@ func (c *Chain[I, O]) AppendChatModel(node model.BaseChatModel, opts ...GraphAdd return c } +// AppendAgenticModel add a agentic.Model node to the chain. +// e.g. +// +// model, err := openai.NewAgenticModel(ctx, config) +// if err != nil {...} +// chain.AppendAgenticModel(model) +func (c *Chain[I, O]) AppendAgenticModel(node model.AgenticModel, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticModelNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendChatTemplate add a ChatTemplate node to the chain. // eg. // @@ -189,11 +201,23 @@ func (c *Chain[I, O]) AppendChatTemplate(node prompt.ChatTemplate, opts ...Graph return c } +// AppendAgenticChatTemplate add a prompt.AgenticChatTemplate node to the chain. +// eg. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// chain.AppendAgenticChatTemplate(chatTemplate) +func (c *Chain[I, O]) AppendAgenticChatTemplate(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticChatTemplateNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendToolsNode add a ToolsNode node to the chain. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, // }) // // chain.AppendToolsNode(toolsNode) @@ -203,6 +227,20 @@ func (c *Chain[I, O]) AppendToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) return c } +// AppendAgenticToolsNode add a AgenticToolsNode node to the chain. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, +// }) +// +// chain.AppendAgenticToolsNode(toolsNode) +func (c *Chain[I, O]) AppendAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Chain[I, O] { + gNode, options := toAgenticToolsNode(node, opts...) + c.addNode(gNode, options) + return c +} + // AppendDocumentTransformer add a DocumentTransformer node to the chain. // e.g. // diff --git a/compose/chain_branch.go b/compose/chain_branch.go index ec3a433af..84fb11048 100644 --- a/compose/chain_branch.go +++ b/compose/chain_branch.go @@ -146,6 +146,22 @@ func (cb *ChainBranch) AddChatModel(key string, node model.BaseChatModel, opts . return cb.addNode(key, gNode, options) } +// AddAgenticModel adds a agentic.Model node to the branch. +// eg. +// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o-mini", +// }) +// cb.AddAgenticModel("agentic_model_key_1", model1) +// cb.AddAgenticModel("agentic_model_key_2", model2) +func (cb *ChainBranch) AddAgenticModel(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticModelNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddChatTemplate adds a ChatTemplate node to the branch. // eg. // @@ -167,11 +183,26 @@ func (cb *ChainBranch) AddChatTemplate(key string, node prompt.ChatTemplate, opt return cb.addNode(key, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate node to the branch. +// eg. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_01", chatTemplate) +// +// chatTemplate2, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// cb.AddAgenticChatTemplate("chat_template_key_02", chatTemplate2) +func (cb *ChainBranch) AddAgenticChatTemplate(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddToolsNode adds a ToolsNode to the branch. // eg. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{ -// Tools: []tools.Tool{...}, +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, // }) // // cb.AddToolsNode("tools_node_key", toolsNode) @@ -180,6 +211,19 @@ func (cb *ChainBranch) AddToolsNode(key string, node *ToolsNode, opts ...GraphAd return cb.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a AgenticToolsNode to the branch. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tools.BaseTool{...}, +// }) +// +// cb.AddAgenticToolsNode("tools_node_key", toolsNode) +func (cb *ChainBranch) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *ChainBranch { + gNode, options := toAgenticToolsNode(node, opts...) + return cb.addNode(key, gNode, options) +} + // AddLambda adds a Lambda node to the branch. // eg. // diff --git a/compose/chain_parallel.go b/compose/chain_parallel.go index 64cdf2db1..463140be2 100644 --- a/compose/chain_parallel.go +++ b/compose/chain_parallel.go @@ -70,6 +70,24 @@ func (p *Parallel) AddChatModel(outputKey string, node model.BaseChatModel, opts return p.addNode(outputKey, gNode, options) } +// AddAgenticModel adds a agentic.Model to the parallel. +// eg. +// +// model1, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// model2, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// p.AddAgenticModel("output_key1", model1) +// p.AddAgenticModel("output_key2", model2) +func (p *Parallel) AddAgenticModel(outputKey string, node model.AgenticModel, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticModelNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddChatTemplate adds a chat template to the parallel. // eg. // @@ -84,6 +102,17 @@ func (p *Parallel) AddChatTemplate(outputKey string, node prompt.ChatTemplate, o return p.addNode(outputKey, gNode, options) } +// AddAgenticChatTemplate adds a prompt.AgenticChatTemplate to the parallel. +// eg. +// +// chatTemplate01, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// p.AddAgenticChatTemplate("output_key01", chatTemplate01) +func (p *Parallel) AddAgenticChatTemplate(outputKey string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticChatTemplateNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddToolsNode adds a tools node to the parallel. // eg. // @@ -97,6 +126,19 @@ func (p *Parallel) AddToolsNode(outputKey string, node *ToolsNode, opts ...Graph return p.addNode(outputKey, gNode, options) } +// AddAgenticToolsNode adds a tools node to the parallel. +// eg. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{ +// Tools: []tool.BaseTool{...}, +// }) +// +// p.AddAgenticToolsNode("output_key01", toolsNode) +func (p *Parallel) AddAgenticToolsNode(outputKey string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) *Parallel { + gNode, options := toAgenticToolsNode(node, append(opts, WithOutputKey(outputKey))...) + return p.addNode(outputKey, gNode, options) +} + // AddLambda adds a lambda node to the parallel. // eg. // diff --git a/compose/component_to_graph_node.go b/compose/component_to_graph_node.go index ab4694f1a..4bd27fe34 100644 --- a/compose/component_to_graph_node.go +++ b/compose/component_to_graph_node.go @@ -101,6 +101,17 @@ func toChatModelNode(node model.BaseChatModel, opts ...GraphAddNodeOpt) (*graphN opts...) } +func toAgenticModelNode(node model.AgenticModel, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticModel, + node.Generate, + node.Stream, + nil, nil, + opts..., + ) +} + func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -112,6 +123,16 @@ func toChatTemplateNode(node prompt.ChatTemplate, opts ...GraphAddNodeOpt) (*gra opts...) } +func toAgenticChatTemplateNode(node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + components.ComponentOfAgenticPrompt, + node.Format, + nil, nil, nil, + opts..., + ) +} + func toDocumentTransformerNode(node document.Transformer, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { return toComponentNode( node, @@ -134,6 +155,17 @@ func toToolsNode(node *ToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAd opts...) } +func toAgenticToolsNode(node *AgenticToolsNode, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { + return toComponentNode( + node, + ComponentOfAgenticToolsNode, + node.Invoke, + node.Stream, + nil, nil, + opts..., + ) +} + func toLambdaNode(node *Lambda, opts ...GraphAddNodeOpt) (*graphNode, *graphAddNodeOpts) { info, options := getNodeInfo(opts...) diff --git a/compose/graph.go b/compose/graph.go index 9370665f0..bcf5ae423 100644 --- a/compose/graph.go +++ b/compose/graph.go @@ -352,6 +352,19 @@ func (g *graph) AddChatModelNode(key string, node model.BaseChatModel, opts ...G return g.addNode(key, gNode, options) } +// AddAgenticModelNode add node that implements agentic.Model. +// e.g. +// +// model, err := openai.NewAgenticModel(ctx, &openai.AgenticModelConfig{ +// Model: "gpt-4o", +// }) +// +// graph.AddAgenticModelNode("agentic_model_node_key", model) +func (g *graph) AddAgenticModelNode(key string, node model.AgenticModel, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticModelNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddChatTemplateNode add node that implements prompt.ChatTemplate. // e.g. // @@ -366,10 +379,21 @@ func (g *graph) AddChatTemplateNode(key string, node prompt.ChatTemplate, opts . return g.addNode(key, gNode, options) } -// AddToolsNode adds a node that implements tools.ToolsNode. +// AddAgenticChatTemplateNode add node that implements prompt.AgenticChatTemplate. +// e.g. +// +// chatTemplate, err := prompt.FromAgenticMessages(schema.FString, &schema.AgenticMessage{}) +// +// graph.AddAgenticChatTemplateNode("chat_template_node_key", chatTemplate) +func (g *graph) AddAgenticChatTemplateNode(key string, node prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticChatTemplateNode(node, opts...) + return g.addNode(key, gNode, options) +} + +// AddToolsNode adds a node that implements ToolsNode. // e.g. // -// toolsNode, err := tools.NewToolNode(ctx, &tools.ToolsNodeConfig{}) +// toolsNode, err := compose.NewToolNode(ctx, &compose.ToolsNodeConfig{}) // // graph.AddToolsNode("tools_node_key", toolsNode) func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOpt) error { @@ -377,6 +401,17 @@ func (g *graph) AddToolsNode(key string, node *ToolsNode, opts ...GraphAddNodeOp return g.addNode(key, gNode, options) } +// AddAgenticToolsNode adds a node that implements AgenticToolsNode. +// e.g. +// +// toolsNode, err := compose.NewAgenticToolsNode(ctx, &compose.ToolsNodeConfig{}) +// +// graph.AddAgenticToolsNode("tools_node_key", toolsNode) +func (g *graph) AddAgenticToolsNode(key string, node *AgenticToolsNode, opts ...GraphAddNodeOpt) error { + gNode, options := toAgenticToolsNode(node, opts...) + return g.addNode(key, gNode, options) +} + // AddDocumentTransformerNode adds a node that implements document.Transformer. // e.g. // diff --git a/compose/types.go b/compose/types.go index 13d925df2..54f8e2be3 100644 --- a/compose/types.go +++ b/compose/types.go @@ -25,13 +25,14 @@ type component = components.Component // built-in component types in graph node. // it represents the type of the most primitive executable object provided by the user. const ( - ComponentOfUnknown component = "Unknown" - ComponentOfGraph component = "Graph" - ComponentOfWorkflow component = "Workflow" - ComponentOfChain component = "Chain" - ComponentOfPassthrough component = "Passthrough" - ComponentOfToolsNode component = "ToolsNode" - ComponentOfLambda component = "Lambda" + ComponentOfUnknown component = "Unknown" + ComponentOfGraph component = "Graph" + ComponentOfWorkflow component = "Workflow" + ComponentOfChain component = "Chain" + ComponentOfPassthrough component = "Passthrough" + ComponentOfToolsNode component = "ToolsNode" + ComponentOfAgenticToolsNode component = "AgenticToolsNode" + ComponentOfLambda component = "Lambda" ) // NodeTriggerMode controls the triggering mode of graph nodes. diff --git a/compose/workflow.go b/compose/workflow.go index c3e4331a3..6b50962bb 100644 --- a/compose/workflow.go +++ b/compose/workflow.go @@ -89,18 +89,36 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.BaseChatM return wf.initNode(key) } +// AddAgenticModelNode adds an agentic model node and returns it. +func (wf *Workflow[I, O]) AddAgenticModelNode(key string, agenticModel model.AgenticModel, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticModelNode(key, agenticModel, opts...) + return wf.initNode(key) +} + // AddChatTemplateNode adds a chat template node and returns it. func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddChatTemplateNode(key, chatTemplate, opts...) return wf.initNode(key) } +// AddAgenticChatTemplateNode adds an agentic chat template node and returns it. +func (wf *Workflow[I, O]) AddAgenticChatTemplateNode(key string, chatTemplate prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticChatTemplateNode(key, chatTemplate, opts...) + return wf.initNode(key) +} + // AddToolsNode adds a tools node and returns it. func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddToolsNode(key, tools, opts...) return wf.initNode(key) } +// AddAgenticToolsNode adds an agentic tools node and returns it. +func (wf *Workflow[I, O]) AddAgenticToolsNode(key string, tools *AgenticToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode { + _ = wf.g.AddAgenticToolsNode(key, tools, opts...) + return wf.initNode(key) +} + // AddRetrieverNode adds a retriever node and returns it. func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode { _ = wf.g.AddRetrieverNode(key, retriever, opts...) diff --git a/go.mod b/go.mod index cfa6957cc..0b87a6cab 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( github.com/yargevad/filepathx v1.0.0 // indirect golang.org/x/arch v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/term v0.28.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index a80d6399b..5813766b2 100644 --- a/go.sum +++ b/go.sum @@ -117,9 +117,10 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/internal/concat.go b/internal/concat.go index 2681322ab..fd9b8abc5 100644 --- a/internal/concat.go +++ b/internal/concat.go @@ -99,7 +99,7 @@ func ConcatItems[T any](items []T) (T, error) { if typ.Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -158,7 +158,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { if v.Type().Elem().Kind() == reflect.Map { cv, err = concatMaps(v) } else { - cv, err = concatSliceValue(v) + cv, err = ConcatSliceValue(v) } if err != nil { @@ -171,7 +171,7 @@ func concatMaps(ms reflect.Value) (reflect.Value, error) { return ret, nil } -func concatSliceValue(val reflect.Value) (reflect.Value, error) { +func ConcatSliceValue(val reflect.Value) (reflect.Value, error) { elmType := val.Type().Elem() if val.Len() == 1 { diff --git a/schema/agentic_message.go b/schema/agentic_message.go new file mode 100644 index 000000000..ead2d866d --- /dev/null +++ b/schema/agentic_message.go @@ -0,0 +1,2089 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "sort" + "strings" + + "github.com/eino-contrib/jsonschema" + + "github.com/cloudwego/eino/internal" + "github.com/cloudwego/eino/schema/claude" + "github.com/cloudwego/eino/schema/gemini" + "github.com/cloudwego/eino/schema/openai" +) + +type ContentBlockType string + +const ( + ContentBlockTypeReasoning ContentBlockType = "reasoning" + ContentBlockTypeUserInputText ContentBlockType = "user_input_text" + ContentBlockTypeUserInputImage ContentBlockType = "user_input_image" + ContentBlockTypeUserInputAudio ContentBlockType = "user_input_audio" + ContentBlockTypeUserInputVideo ContentBlockType = "user_input_video" + ContentBlockTypeUserInputFile ContentBlockType = "user_input_file" + ContentBlockTypeAssistantGenText ContentBlockType = "assistant_gen_text" + ContentBlockTypeAssistantGenImage ContentBlockType = "assistant_gen_image" + ContentBlockTypeAssistantGenAudio ContentBlockType = "assistant_gen_audio" + ContentBlockTypeAssistantGenVideo ContentBlockType = "assistant_gen_video" + ContentBlockTypeFunctionToolCall ContentBlockType = "function_tool_call" + ContentBlockTypeFunctionToolResult ContentBlockType = "function_tool_result" + ContentBlockTypeServerToolCall ContentBlockType = "server_tool_call" + ContentBlockTypeServerToolResult ContentBlockType = "server_tool_result" + ContentBlockTypeMCPToolCall ContentBlockType = "mcp_tool_call" + ContentBlockTypeMCPToolResult ContentBlockType = "mcp_tool_result" + ContentBlockTypeMCPListToolsResult ContentBlockType = "mcp_list_tools_result" + ContentBlockTypeMCPToolApprovalRequest ContentBlockType = "mcp_tool_approval_request" + ContentBlockTypeMCPToolApprovalResponse ContentBlockType = "mcp_tool_approval_response" +) + +type AgenticRoleType string + +const ( + AgenticRoleTypeSystem AgenticRoleType = "system" + AgenticRoleTypeUser AgenticRoleType = "user" + AgenticRoleTypeAssistant AgenticRoleType = "assistant" +) + +type AgenticMessage struct { + // Role is the message role. + Role AgenticRoleType + + // ContentBlocks is the list of content blocks. + ContentBlocks []*ContentBlock + + // ResponseMeta is the response metadata. + ResponseMeta *AgenticResponseMeta + + // Extra is the additional information. + Extra map[string]any +} + +type AgenticResponseMeta struct { + // TokenUsage is the token usage. + TokenUsage *TokenUsage + + // OpenAIExtension is the extension for OpenAI. + OpenAIExtension *openai.ResponseMetaExtension + + // GeminiExtension is the extension for Gemini. + GeminiExtension *gemini.ResponseMetaExtension + + // ClaudeExtension is the extension for Claude. + ClaudeExtension *claude.ResponseMetaExtension + + // Extension is the extension for other models, supplied by the component implementer. + Extension any +} + +type ContentBlock struct { + Type ContentBlockType + + // Reasoning contains the reasoning content generated by the model. + Reasoning *Reasoning + + // UserInputText contains the text content provided by the user. + UserInputText *UserInputText + + // UserInputImage contains the image content provided by the user. + UserInputImage *UserInputImage + + // UserInputAudio contains the audio content provided by the user. + UserInputAudio *UserInputAudio + + // UserInputVideo contains the video content provided by the user. + UserInputVideo *UserInputVideo + + // UserInputFile contains the file content provided by the user. + UserInputFile *UserInputFile + + // AssistantGenText contains the text content generated by the model. + AssistantGenText *AssistantGenText + + // AssistantGenImage contains the image content generated by the model. + AssistantGenImage *AssistantGenImage + + // AssistantGenAudio contains the audio content generated by the model. + AssistantGenAudio *AssistantGenAudio + + // AssistantGenVideo contains the video content generated by the model. + AssistantGenVideo *AssistantGenVideo + + // FunctionToolCall contains the invocation details for a user-defined tool. + FunctionToolCall *FunctionToolCall + + // FunctionToolResult contains the result returned from a user-defined tool call. + FunctionToolResult *FunctionToolResult + + // ServerToolCall contains the invocation details for a provider built-in tool executed on the model server. + ServerToolCall *ServerToolCall + + // ServerToolResult contains the result returned from a provider built-in tool executed on the model server. + ServerToolResult *ServerToolResult + + // MCPToolCall contains the invocation details for an MCP tool managed by the model server. + MCPToolCall *MCPToolCall + + // MCPToolResult contains the result returned from an MCP tool managed by the model server. + MCPToolResult *MCPToolResult + + // MCPListToolsResult contains the list of available MCP tools reported by the model server. + MCPListToolsResult *MCPListToolsResult + + // MCPToolApprovalRequest contains the user approval request for an MCP tool call when required. + MCPToolApprovalRequest *MCPToolApprovalRequest + + // MCPToolApprovalResponse contains the user's approval decision for an MCP tool call. + MCPToolApprovalResponse *MCPToolApprovalResponse + + // StreamingMeta contains metadata for streaming responses. + StreamingMeta *StreamingMeta + + // Extra contains additional information for the content block. + Extra map[string]any +} + +type StreamingMeta struct { + // Index specifies the index position of this block in the final response. + Index int +} + +type UserInputText struct { + // Text is the text content. + Text string +} + +type UserInputImage struct { + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string + + // Detail is the quality of the image url. + Detail ImageURLDetail +} + +type UserInputAudio struct { + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string +} + +type UserInputVideo struct { + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string +} + +type UserInputFile struct { + // URL is the HTTP/HTTPS link. + URL string + + // Name is the filename. + Name string + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string + + // MIMEType is the mime type, e.g. "application/pdf". + MIMEType string +} + +type AssistantGenText struct { + // Text is the generated text. + Text string + + // OpenAIExtension is the extension for OpenAI. + OpenAIExtension *openai.AssistantGenTextExtension + + // ClaudeExtension is the extension for Claude. + ClaudeExtension *claude.AssistantGenTextExtension + + // Extension is the extension for other models, supplied by the component implementer. + Extension any +} + +type AssistantGenImage struct { + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string + + // MIMEType is the mime type, e.g. "image/png". + MIMEType string +} + +type AssistantGenAudio struct { + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string + + // MIMEType is the mime type, e.g. "audio/wav". + MIMEType string +} + +type AssistantGenVideo struct { + // URL is the HTTP/HTTPS link. + URL string + + // Base64Data is the binary data in Base64 encoded string format. + Base64Data string + + // MIMEType is the mime type, e.g. "video/mp4". + MIMEType string +} + +type Reasoning struct { + // Text is either the thought summary or the raw reasoning text itself. + Text string + + // Signature contains encrypted reasoning tokens. + // Required by some models when passing reasoning text back. + Signature string +} + +type FunctionToolCall struct { + // CallID is the unique identifier for the tool call. + CallID string + + // Name specifies the function tool invoked. + Name string + + // Arguments is the JSON string arguments for the function tool call. + Arguments string +} + +type FunctionToolResult struct { + // CallID is the unique identifier for the tool call. + CallID string + + // Name specifies the function tool invoked. + Name string + + // Result is the function tool result returned by the user + Result string +} + +type ServerToolCall struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string + + // Arguments are the raw inputs to the server-side tool, + // supplied by the component implementer. + Arguments any +} + +type ServerToolResult struct { + // Name specifies the server-side tool invoked. + // Supplied by the model server (e.g., `web_search` for OpenAI, `googleSearch` for Gemini). + Name string + + // CallID is the unique identifier for the tool call. + // Empty if not provided by the model server. + CallID string + + // Result refers to the raw output generated by the server-side tool, + // supplied by the component implementer. + Result any +} + +type MCPToolCall struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string + + // ApprovalRequestID is the approval request ID. + ApprovalRequestID string + + // CallID is the unique ID of the tool call. + CallID string + + // Name is the name of the tool to run. + Name string + + // Arguments is the JSON string arguments for the tool call. + Arguments string +} + +type MCPToolResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls + ServerLabel string + + // CallID is the unique ID of the tool call. + CallID string + + // Name is the name of the tool to run. + Name string + + // Result is the JSON string with the tool result. + Result string + + // Error returned when the server fails to run the tool. + Error *MCPToolCallError +} + +type MCPToolCallError struct { + // Code is the error code. + Code *int64 + + // Message is the error message. + Message string +} + +type MCPListToolsResult struct { + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string + + // Tools is the list of tools available on the server. + Tools []*MCPListToolsItem + + // Error returned when the server fails to list tools. + Error string +} + +type MCPListToolsItem struct { + // Name is the name of the tool. + Name string + + // Description is the description of the tool. + Description string + + // InputSchema is the JSON schema that describes the tool input parameters. + InputSchema *jsonschema.Schema +} + +type MCPToolApprovalRequest struct { + // ID is the approval request ID. + ID string + + // Name is the name of the tool to run. + Name string + + // Arguments is the JSON string arguments for the tool call. + Arguments string + + // ServerLabel is the MCP server label used to identify it in tool calls. + ServerLabel string +} + +type MCPToolApprovalResponse struct { + // ApprovalRequestID is the approval request ID being responded to. + ApprovalRequestID string + + // Approve indicates whether the request is approved. + Approve bool + + // Reason is the rationale for the decision. + // Optional. + Reason string +} + +// SystemAgenticMessage represents a message with AgenticRoleType "system". +func SystemAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeSystem, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// UserAgenticMessage represents a message with AgenticRoleType "user". +func UserAgenticMessage(text string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{NewContentBlock(&UserInputText{Text: text})}, + } +} + +// FunctionToolResultAgenticMessage represents a function tool result message with AgenticRoleType "user". +func FunctionToolResultAgenticMessage(callID, name, result string) *AgenticMessage { + return &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + NewContentBlock(&FunctionToolResult{ + CallID: callID, + Name: name, + Result: result, + }), + }, + } +} + +type contentBlockVariant interface { + Reasoning | userInputVariant | assistantGenVariant | functionToolCallVariant | serverToolCallVariant | mcpToolCallVariant +} + +type userInputVariant interface { + UserInputText | UserInputImage | UserInputAudio | UserInputVideo | UserInputFile +} + +type assistantGenVariant interface { + AssistantGenText | AssistantGenImage | AssistantGenAudio | AssistantGenVideo +} + +type functionToolCallVariant interface { + FunctionToolCall | FunctionToolResult +} + +type serverToolCallVariant interface { + ServerToolCall | ServerToolResult +} + +type mcpToolCallVariant interface { + MCPToolCall | MCPToolResult | MCPListToolsResult | MCPToolApprovalRequest | MCPToolApprovalResponse +} + +// NewContentBlock creates a new ContentBlock with the given content. +func NewContentBlock[T contentBlockVariant](content *T) *ContentBlock { + switch b := any(content).(type) { + case *Reasoning: + return &ContentBlock{Type: ContentBlockTypeReasoning, Reasoning: b} + case *UserInputText: + return &ContentBlock{Type: ContentBlockTypeUserInputText, UserInputText: b} + case *UserInputImage: + return &ContentBlock{Type: ContentBlockTypeUserInputImage, UserInputImage: b} + case *UserInputAudio: + return &ContentBlock{Type: ContentBlockTypeUserInputAudio, UserInputAudio: b} + case *UserInputVideo: + return &ContentBlock{Type: ContentBlockTypeUserInputVideo, UserInputVideo: b} + case *UserInputFile: + return &ContentBlock{Type: ContentBlockTypeUserInputFile, UserInputFile: b} + case *AssistantGenText: + return &ContentBlock{Type: ContentBlockTypeAssistantGenText, AssistantGenText: b} + case *AssistantGenImage: + return &ContentBlock{Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: b} + case *AssistantGenAudio: + return &ContentBlock{Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: b} + case *AssistantGenVideo: + return &ContentBlock{Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: b} + case *FunctionToolCall: + return &ContentBlock{Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: b} + case *FunctionToolResult: + return &ContentBlock{Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: b} + case *ServerToolCall: + return &ContentBlock{Type: ContentBlockTypeServerToolCall, ServerToolCall: b} + case *ServerToolResult: + return &ContentBlock{Type: ContentBlockTypeServerToolResult, ServerToolResult: b} + case *MCPToolCall: + return &ContentBlock{Type: ContentBlockTypeMCPToolCall, MCPToolCall: b} + case *MCPToolResult: + return &ContentBlock{Type: ContentBlockTypeMCPToolResult, MCPToolResult: b} + case *MCPListToolsResult: + return &ContentBlock{Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: b} + case *MCPToolApprovalRequest: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: b} + case *MCPToolApprovalResponse: + return &ContentBlock{Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: b} + default: + return nil + } +} + +// NewContentBlockChunk creates a new ContentBlock with the given content and streaming metadata. +func NewContentBlockChunk[T contentBlockVariant](content *T, meta *StreamingMeta) *ContentBlock { + block := NewContentBlock(content) + block.StreamingMeta = meta + return block +} + +// AgenticMessagesTemplate is the interface for agentic messages template. +// It's used to render a template to a list of agentic messages. +// e.g. +// +// chatTemplate := prompt.FromAgenticMessages( +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +type AgenticMessagesTemplate interface { + Format(ctx context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) +} + +var _ AgenticMessagesTemplate = &AgenticMessage{} +var _ AgenticMessagesTemplate = AgenticMessagesPlaceholder("", false) + +type agenticMessagesPlaceholder struct { + key string + optional bool +} + +// AgenticMessagesPlaceholder can render a placeholder to a list of agentic messages in params. +// e.g. +// +// placeholder := AgenticMessagesPlaceholder("history", false) +// params := map[string]any{ +// "history": []*schema.AgenticMessage{ +// &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeSystem, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "you are an eino helper"}}, +// }, +// }, +// }, +// } +// chatTemplate := chatTpl := prompt.FromMessages( +// schema.AgenticMessagesPlaceholder("history", false), // <= this will use the value of "history" in params +// ) +// msgs, err := chatTemplate.Format(ctx, params) +func AgenticMessagesPlaceholder(key string, optional bool) AgenticMessagesTemplate { + return &agenticMessagesPlaceholder{ + key: key, + optional: optional, + } +} + +func (p *agenticMessagesPlaceholder) Format(_ context.Context, vs map[string]any, _ FormatType) ([]*AgenticMessage, error) { + v, ok := vs[p.key] + if !ok { + if p.optional { + return []*AgenticMessage{}, nil + } + + return nil, fmt.Errorf("message placeholder format: %s not found", p.key) + } + + msgs, ok := v.([]*AgenticMessage) + if !ok { + return nil, fmt.Errorf("only agentic messages can be used to format message placeholder, key: %v, actual type: %v", p.key, reflect.TypeOf(v)) + } + + return msgs, nil +} + +// Format returns the agentic messages after rendering by the given formatType. +// It formats only the user input fields (UserInputText, UserInputImage, UserInputAudio, UserInputVideo, UserInputFile). +// e.g. +// +// msg := &schema.AgenticMessage{ +// Role: schema.AgenticRoleTypeUser, +// ContentBlocks: []*schema.ContentBlock{ +// {Type: schema.ContentBlockTypeUserInputText, UserInputText: &schema.UserInputText{Text: "hello {name}"}}, +// }, +// } +// msgs, err := msg.Format(ctx, map[string]any{"name": "eino"}, schema.FString) +// // msgs[0].ContentBlocks[0].UserInputText.Text will be "hello eino" +func (m *AgenticMessage) Format(_ context.Context, vs map[string]any, formatType FormatType) ([]*AgenticMessage, error) { + copied := *m + + if len(m.ContentBlocks) > 0 { + copiedBlocks := make([]*ContentBlock, len(m.ContentBlocks)) + for i, block := range m.ContentBlocks { + if block == nil { + copiedBlocks[i] = nil + continue + } + + copiedBlock := *block + var err error + + switch block.Type { + case ContentBlockTypeUserInputText: + if block.UserInputText != nil { + copiedBlock.UserInputText, err = formatUserInputText(block.UserInputText, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputImage: + if block.UserInputImage != nil { + copiedBlock.UserInputImage, err = formatUserInputImage(block.UserInputImage, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputAudio: + if block.UserInputAudio != nil { + copiedBlock.UserInputAudio, err = formatUserInputAudio(block.UserInputAudio, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputVideo: + if block.UserInputVideo != nil { + copiedBlock.UserInputVideo, err = formatUserInputVideo(block.UserInputVideo, vs, formatType) + if err != nil { + return nil, err + } + } + case ContentBlockTypeUserInputFile: + if block.UserInputFile != nil { + copiedBlock.UserInputFile, err = formatUserInputFile(block.UserInputFile, vs, formatType) + if err != nil { + return nil, err + } + } + } + + copiedBlocks[i] = &copiedBlock + } + copied.ContentBlocks = copiedBlocks + } + + return []*AgenticMessage{&copied}, nil +} + +func formatUserInputText(uit *UserInputText, vs map[string]any, formatType FormatType) (*UserInputText, error) { + text, err := formatContent(uit.Text, vs, formatType) + if err != nil { + return nil, err + } + copied := *uit + copied.Text = text + return &copied, nil +} + +func formatUserInputImage(uii *UserInputImage, vs map[string]any, formatType FormatType) (*UserInputImage, error) { + copied := *uii + if uii.URL != "" { + url, err := formatContent(uii.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uii.Base64Data != "" { + base64data, err := formatContent(uii.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputAudio(uia *UserInputAudio, vs map[string]any, formatType FormatType) (*UserInputAudio, error) { + copied := *uia + if uia.URL != "" { + url, err := formatContent(uia.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uia.Base64Data != "" { + base64data, err := formatContent(uia.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputVideo(uiv *UserInputVideo, vs map[string]any, formatType FormatType) (*UserInputVideo, error) { + copied := *uiv + if uiv.URL != "" { + url, err := formatContent(uiv.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uiv.Base64Data != "" { + base64data, err := formatContent(uiv.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +func formatUserInputFile(uif *UserInputFile, vs map[string]any, formatType FormatType) (*UserInputFile, error) { + copied := *uif + if uif.URL != "" { + url, err := formatContent(uif.URL, vs, formatType) + if err != nil { + return nil, err + } + copied.URL = url + } + if uif.Name != "" { + name, err := formatContent(uif.Name, vs, formatType) + if err != nil { + return nil, err + } + copied.Name = name + } + if uif.Base64Data != "" { + base64data, err := formatContent(uif.Base64Data, vs, formatType) + if err != nil { + return nil, err + } + copied.Base64Data = base64data + } + return &copied, nil +} + +// ConcatAgenticMessagesArray concatenates multiple streams of AgenticMessage into a single slice of AgenticMessage. +func ConcatAgenticMessagesArray(mas [][]*AgenticMessage) ([]*AgenticMessage, error) { + return buildConcatGenericArray[AgenticMessage](ConcatAgenticMessages)(mas) +} + +// ConcatAgenticMessages concatenates a list of AgenticMessage chunks into a single AgenticMessage. +func ConcatAgenticMessages(msgs []*AgenticMessage) (*AgenticMessage, error) { + var ( + role AgenticRoleType + blocks []*ContentBlock + metas []*AgenticResponseMeta + extra map[string]any + blockIndices []int + indexToBlocks = map[int][]*ContentBlock{} + extraList = make([]map[string]any, 0, len(msgs)) + ) + + if len(msgs) == 1 { + return msgs[0], nil + } + + for idx, msg := range msgs { + if msg == nil { + return nil, fmt.Errorf("message at index %d is nil", idx) + } + + if msg.Role != "" { + if role == "" { + role = msg.Role + } else if role != msg.Role { + return nil, fmt.Errorf("cannot concat messages with different roles: got '%s' and '%s'", role, msg.Role) + } + } + + for _, block := range msg.ContentBlocks { + if block == nil { + continue + } + if block.StreamingMeta == nil { + // Non-streaming block + if len(blockIndices) > 0 { + // Cannot mix streaming and non-streaming blocks + return nil, fmt.Errorf("found non-streaming block after streaming blocks") + } + // Collect non-streaming block + blocks = append(blocks, block) + } else { + // Streaming block + if len(blocks) > 0 { + // Cannot mix non-streaming and streaming blocks + return nil, fmt.Errorf("found streaming block after non-streaming blocks") + } + // Collect streaming block by index + if blocks_, ok := indexToBlocks[block.StreamingMeta.Index]; ok { + indexToBlocks[block.StreamingMeta.Index] = append(blocks_, block) + } else { + blockIndices = append(blockIndices, block.StreamingMeta.Index) + indexToBlocks[block.StreamingMeta.Index] = []*ContentBlock{block} + } + } + } + + if msg.ResponseMeta != nil { + metas = append(metas, msg.ResponseMeta) + } + + if msg.Extra != nil { + extraList = append(extraList, msg.Extra) + } + } + + meta, err := concatAgenticResponseMeta(metas) + if err != nil { + return nil, fmt.Errorf("failed to concat agentic response meta: %w", err) + } + + if len(blockIndices) > 0 { + // All blocks are streaming, concat each group by index + indexToBlock := map[int]*ContentBlock{} + for idx, bs := range indexToBlocks { + var b *ContentBlock + b, err = concatChunksOfSameContentBlock(bs) + if err != nil { + return nil, err + } + indexToBlock[idx] = b + } + blocks = make([]*ContentBlock, 0, len(blockIndices)) + sort.Slice(blockIndices, func(i, j int) bool { + return blockIndices[i] < blockIndices[j] + }) + for _, idx := range blockIndices { + blocks = append(blocks, indexToBlock[idx]) + } + } + + if len(extraList) > 0 { + extra, err = concatExtra(extraList) + if err != nil { + return nil, err + } + } + + return &AgenticMessage{ + Role: role, + ResponseMeta: meta, + ContentBlocks: blocks, + Extra: extra, + }, nil +} + +func concatAgenticResponseMeta(metas []*AgenticResponseMeta) (ret *AgenticResponseMeta, err error) { + if len(metas) == 0 { + return nil, nil + } + + openaiExtensions := make([]*openai.ResponseMetaExtension, 0, len(metas)) + claudeExtensions := make([]*claude.ResponseMetaExtension, 0, len(metas)) + geminiExtensions := make([]*gemini.ResponseMetaExtension, 0, len(metas)) + tokenUsages := make([]*TokenUsage, 0, len(metas)) + + var ( + extType reflect.Type + extensions reflect.Value + ) + + for _, meta := range metas { + if meta.TokenUsage != nil { + tokenUsages = append(tokenUsages, meta.TokenUsage) + } + + var isConsistent bool + + if meta.Extension != nil { + extType, isConsistent = validateExtensionType(extType, meta.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(metas)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(meta.Extension)) + } + + if meta.OpenAIExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.OpenAIExtension)) + } + openaiExtensions = append(openaiExtensions, meta.OpenAIExtension) + } + + if meta.ClaudeExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.ClaudeExtension)) + } + claudeExtensions = append(claudeExtensions, meta.ClaudeExtension) + } + + if meta.GeminiExtension != nil { + extType, isConsistent = validateExtensionType(extType, meta.GeminiExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in response meta chunks: '%s' vs '%s'", + extType, reflect.TypeOf(meta.GeminiExtension)) + } + geminiExtensions = append(geminiExtensions, meta.GeminiExtension) + } + } + + ret = &AgenticResponseMeta{ + TokenUsage: concatTokenUsage(tokenUsages), + } + + if extensions.IsValid() && !extensions.IsZero() { + var extension reflect.Value + extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, fmt.Errorf("failed to concat extensions: %w", err) + } + ret.Extension = extension.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatResponseMetaExtensions(openaiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat openai extensions: %w", err) + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatResponseMetaExtensions(claudeExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat claude extensions: %w", err) + } + } + + if len(geminiExtensions) > 0 { + ret.GeminiExtension, err = gemini.ConcatResponseMetaExtensions(geminiExtensions) + if err != nil { + return nil, fmt.Errorf("failed to concat gemini extensions: %w", err) + } + } + + return ret, nil +} + +func concatTokenUsage(usages []*TokenUsage) *TokenUsage { + if len(usages) == 0 { + return nil + } + + ret := &TokenUsage{} + + for _, usage := range usages { + if usage == nil { + continue + } + ret.CompletionTokens += usage.CompletionTokens + ret.CompletionTokensDetails.ReasoningTokens += usage.CompletionTokensDetails.ReasoningTokens + ret.PromptTokens += usage.PromptTokens + ret.PromptTokenDetails.CachedTokens += usage.PromptTokenDetails.CachedTokens + ret.TotalTokens += usage.TotalTokens + } + + return ret +} + +func concatChunksOfSameContentBlock(blocks []*ContentBlock) (*ContentBlock, error) { + if len(blocks) == 0 { + return nil, fmt.Errorf("no content blocks to concat") + } + + blockType := blocks[0].Type + + switch blockType { + case ContentBlockTypeReasoning: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *Reasoning { return b.Reasoning }, + concatReasoning) + + case ContentBlockTypeUserInputText: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputText { return b.UserInputText }, + concatUserInputTexts) + + case ContentBlockTypeUserInputImage: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputImage { return b.UserInputImage }, + concatUserInputImages) + + case ContentBlockTypeUserInputAudio: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputAudio { return b.UserInputAudio }, + concatUserInputAudios) + + case ContentBlockTypeUserInputVideo: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputVideo { return b.UserInputVideo }, + concatUserInputVideos) + + case ContentBlockTypeUserInputFile: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *UserInputFile { return b.UserInputFile }, + concatUserInputFiles) + + case ContentBlockTypeAssistantGenText: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *AssistantGenText { return b.AssistantGenText }, + concatAssistantGenTexts) + + case ContentBlockTypeAssistantGenImage: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *AssistantGenImage { return b.AssistantGenImage }, + concatAssistantGenImages) + + case ContentBlockTypeAssistantGenAudio: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *AssistantGenAudio { return b.AssistantGenAudio }, + concatAssistantGenAudios) + + case ContentBlockTypeAssistantGenVideo: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *AssistantGenVideo { return b.AssistantGenVideo }, + concatAssistantGenVideos) + + case ContentBlockTypeFunctionToolCall: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *FunctionToolCall { return b.FunctionToolCall }, + concatFunctionToolCalls) + + case ContentBlockTypeFunctionToolResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *FunctionToolResult { return b.FunctionToolResult }, + concatFunctionToolResults) + + case ContentBlockTypeServerToolCall: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *ServerToolCall { return b.ServerToolCall }, + concatServerToolCalls) + + case ContentBlockTypeServerToolResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *ServerToolResult { return b.ServerToolResult }, + concatServerToolResults) + + case ContentBlockTypeMCPToolCall: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPToolCall { return b.MCPToolCall }, + concatMCPToolCalls) + + case ContentBlockTypeMCPToolResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPToolResult { return b.MCPToolResult }, + concatMCPToolResults) + + case ContentBlockTypeMCPListToolsResult: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPListToolsResult { return b.MCPListToolsResult }, + concatMCPListToolsResults) + + case ContentBlockTypeMCPToolApprovalRequest: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPToolApprovalRequest { return b.MCPToolApprovalRequest }, + concatMCPToolApprovalRequests) + + case ContentBlockTypeMCPToolApprovalResponse: + return concatContentBlockHelper(blocks, blockType, + func(b *ContentBlock) *MCPToolApprovalResponse { return b.MCPToolApprovalResponse }, + concatMCPToolApprovalResponses) + + default: + return nil, fmt.Errorf("unknown content block type: %s", blockType) + } +} + +// concatContentBlockHelper is a generic helper function that reduces code duplication +// for concatenating content blocks of a specific type. +func concatContentBlockHelper[T contentBlockVariant]( + blocks []*ContentBlock, + expectedType ContentBlockType, + getter func(*ContentBlock) *T, + concatFunc func([]*T) (*T, error), +) (*ContentBlock, error) { + items, err := genericGetTFromContentBlocks(blocks, func(block *ContentBlock) (*T, error) { + if block.Type != expectedType { + return nil, fmt.Errorf("content block type mismatch: expected '%s', but got '%s'", expectedType, block.Type) + } + item := getter(block) + if item == nil { + return nil, fmt.Errorf("'%s' content is nil", expectedType) + } + return item, nil + }) + if err != nil { + return nil, err + } + + concatenated, err := concatFunc(items) + if err != nil { + return nil, fmt.Errorf("failed to concat '%s' content blocks: %w", expectedType, err) + } + + extras := make([]map[string]any, 0, len(blocks)) + for _, block := range blocks { + if len(block.Extra) > 0 { + extras = append(extras, block.Extra) + } + } + + var extra map[string]any + if len(extras) > 0 { + extra, err = internal.ConcatItems(extras) + if err != nil { + return nil, fmt.Errorf("failed to concat content block extras: %w", err) + } + } + + block := NewContentBlock(concatenated) + block.Extra = extra + + return block, nil +} + +func genericGetTFromContentBlocks[T any](blocks []*ContentBlock, checkAndGetter func(block *ContentBlock) (T, error)) ([]T, error) { + ret := make([]T, 0, len(blocks)) + for _, block := range blocks { + t, err := checkAndGetter(block) + if err != nil { + return nil, err + } + ret = append(ret, t) + } + return ret, nil +} + +func concatReasoning(reasons []*Reasoning) (*Reasoning, error) { + if len(reasons) == 0 { + return nil, fmt.Errorf("no reasoning found") + } + + ret := &Reasoning{} + + for _, r := range reasons { + if r.Text != "" { + ret.Text += r.Text + } + if r.Signature != "" { + ret.Signature += r.Signature + } + } + + return ret, nil +} + +func concatUserInputTexts(texts []*UserInputText) (*UserInputText, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no user input text found") + } + if len(texts) == 1 { + return texts[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input texts") +} + +func concatUserInputImages(images []*UserInputImage) (*UserInputImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no user input image found") + } + if len(images) == 1 { + return images[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input images") +} + +func concatUserInputAudios(audios []*UserInputAudio) (*UserInputAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no user input audio found") + } + if len(audios) == 1 { + return audios[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input audios") +} + +func concatUserInputVideos(videos []*UserInputVideo) (*UserInputVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no user input video found") + } + if len(videos) == 1 { + return videos[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input videos") +} + +func concatUserInputFiles(files []*UserInputFile) (*UserInputFile, error) { + if len(files) == 0 { + return nil, fmt.Errorf("no user input file found") + } + if len(files) == 1 { + return files[0], nil + } + return nil, fmt.Errorf("cannot concat multiple user input files") +} + +func concatAssistantGenTexts(texts []*AssistantGenText) (ret *AssistantGenText, err error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no assistant generated text found") + } + if len(texts) == 1 { + return texts[0], nil + } + + ret = &AssistantGenText{} + + openaiExtensions := make([]*openai.AssistantGenTextExtension, 0, len(texts)) + claudeExtensions := make([]*claude.AssistantGenTextExtension, 0, len(texts)) + + var ( + extType reflect.Type + extensions reflect.Value + ) + + for _, t := range texts { + if t == nil { + continue + } + + ret.Text += t.Text + + var isConsistent bool + + if t.Extension != nil { + extType, isConsistent = validateExtensionType(extType, t.Extension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.Extension)) + } + if !extensions.IsValid() { + extensions = reflect.MakeSlice(reflect.SliceOf(extType), 0, len(texts)) + } + extensions = reflect.Append(extensions, reflect.ValueOf(t.Extension)) + } + + if t.OpenAIExtension != nil { + extType, isConsistent = validateExtensionType(extType, t.OpenAIExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.OpenAIExtension)) + } + openaiExtensions = append(openaiExtensions, t.OpenAIExtension) + } + + if t.ClaudeExtension != nil { + extType, isConsistent = validateExtensionType(extType, t.ClaudeExtension) + if !isConsistent { + return nil, fmt.Errorf("inconsistent extension types in assistant generated text chunks: '%s' vs '%s'", + extType, reflect.TypeOf(t.ClaudeExtension)) + } + claudeExtensions = append(claudeExtensions, t.ClaudeExtension) + } + } + + if extensions.IsValid() && !extensions.IsZero() { + ret.Extension, err = internal.ConcatSliceValue(extensions) + if err != nil { + return nil, err + } + ret.Extension = extensions.Interface() + } + + if len(openaiExtensions) > 0 { + ret.OpenAIExtension, err = openai.ConcatAssistantGenTextExtensions(openaiExtensions) + if err != nil { + return nil, err + } + } + + if len(claudeExtensions) > 0 { + ret.ClaudeExtension, err = claude.ConcatAssistantGenTextExtensions(claudeExtensions) + if err != nil { + return nil, err + } + } + + return ret, nil +} + +func concatAssistantGenImages(images []*AssistantGenImage) (*AssistantGenImage, error) { + if len(images) == 0 { + return nil, fmt.Errorf("no assistant gen image found") + } + if len(images) == 1 { + return images[0], nil + } + + ret := &AssistantGenImage{} + + for _, img := range images { + if img == nil { + continue + } + + ret.Base64Data += img.Base64Data + + if ret.URL == "" { + ret.URL = img.URL + } else if img.URL != "" && ret.URL != img.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated image chunks: '%s' vs '%s'", ret.URL, img.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = img.MIMEType + } else if img.MIMEType != "" && ret.MIMEType != img.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated image chunks: '%s' vs '%s'", ret.MIMEType, img.MIMEType) + } + } + + return ret, nil +} + +func concatAssistantGenAudios(audios []*AssistantGenAudio) (*AssistantGenAudio, error) { + if len(audios) == 0 { + return nil, fmt.Errorf("no assistant gen audio found") + } + if len(audios) == 1 { + return audios[0], nil + } + + ret := &AssistantGenAudio{} + + for _, audio := range audios { + if audio == nil { + continue + } + + ret.Base64Data += audio.Base64Data + + if ret.URL == "" { + ret.URL = audio.URL + } else if audio.URL != "" && ret.URL != audio.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated audio chunks: '%s' vs '%s'", ret.URL, audio.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = audio.MIMEType + } else if audio.MIMEType != "" && ret.MIMEType != audio.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated audio chunks: '%s' vs '%s'", ret.MIMEType, audio.MIMEType) + } + } + + return ret, nil +} + +func concatAssistantGenVideos(videos []*AssistantGenVideo) (*AssistantGenVideo, error) { + if len(videos) == 0 { + return nil, fmt.Errorf("no assistant gen video found") + } + if len(videos) == 1 { + return videos[0], nil + } + + ret := &AssistantGenVideo{} + + for _, video := range videos { + if video == nil { + continue + } + + ret.Base64Data += video.Base64Data + + if ret.URL == "" { + ret.URL = video.URL + } else if video.URL != "" && ret.URL != video.URL { + return nil, fmt.Errorf("inconsistent URLs in assistant generated video chunks: '%s' vs '%s'", ret.URL, video.URL) + } + + if ret.MIMEType == "" { + ret.MIMEType = video.MIMEType + } else if video.MIMEType != "" && ret.MIMEType != video.MIMEType { + return nil, fmt.Errorf("inconsistent MIME types in assistant generated video chunks: '%s' vs '%s'", ret.MIMEType, video.MIMEType) + } + } + + return ret, nil +} + +func concatFunctionToolCalls(calls []*FunctionToolCall) (*FunctionToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no function tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret := &FunctionToolCall{} + + for _, c := range calls { + if c == nil { + continue + } + + if ret.CallID == "" { + ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool call, but got '%s'", ret.Name, c.Name) + } + + ret.Arguments += c.Arguments + } + + return ret, nil +} + +func concatFunctionToolResults(results []*FunctionToolResult) (*FunctionToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no function tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &FunctionToolResult{} + + for _, r := range results { + if r == nil { + continue + } + + if ret.CallID == "" { + ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for function tool result, but got '%s'", ret.CallID, r.CallID) + } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for function tool result, but got '%s'", ret.Name, r.Name) + } + + ret.Result += r.Result + } + + return ret, nil +} + +func concatServerToolCalls(calls []*ServerToolCall) (ret *ServerToolCall, err error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no server tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret = &ServerToolCall{} + + var ( + argsType reflect.Type + argsChunks reflect.Value + ) + + for _, c := range calls { + if c == nil { + continue + } + + if ret.CallID == "" { + ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool call, but got '%s'", ret.Name, c.Name) + } + + if c.Arguments != nil { + argsType_ := reflect.TypeOf(c.Arguments) + if argsType == nil { + argsType = argsType_ + argsChunks = reflect.MakeSlice(reflect.SliceOf(argsType), 0, len(calls)) + } else if argsType != argsType_ { + return nil, fmt.Errorf("expected type '%s' for server tool call arguments, but got '%s'", argsType, argsType_) + } + argsChunks = reflect.Append(argsChunks, reflect.ValueOf(c.Arguments)) + } + } + + if argsChunks.IsValid() && !argsChunks.IsZero() { + arguments, err := internal.ConcatSliceValue(argsChunks) + if err != nil { + return nil, err + } + ret.Arguments = arguments.Interface() + } + + return ret, nil +} + +func concatServerToolResults(results []*ServerToolResult) (ret *ServerToolResult, err error) { + if len(results) == 0 { + return nil, fmt.Errorf("no server tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret = &ServerToolResult{} + + var ( + resType reflect.Type + resChunks reflect.Value + ) + + for _, r := range results { + if r == nil { + continue + } + + if ret.CallID == "" { + ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for server tool result, but got '%s'", ret.CallID, r.CallID) + } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for server tool result, but got '%s'", ret.Name, r.Name) + } + + if r.Result != nil { + resType_ := reflect.TypeOf(r.Result) + if resType == nil { + resType = resType_ + resChunks = reflect.MakeSlice(reflect.SliceOf(resType), 0, len(results)) + } else if resType != resType_ { + return nil, fmt.Errorf("expected type '%s' for server tool result, but got '%s'", resType, resType_) + } + resChunks = reflect.Append(resChunks, reflect.ValueOf(r.Result)) + } + } + + if resChunks.IsValid() && !resChunks.IsZero() { + result, err := internal.ConcatSliceValue(resChunks) + if err != nil { + return nil, fmt.Errorf("failed to concat server tool result: %v", err) + } + ret.Result = result.Interface() + } + + return ret, nil +} + +func concatMCPToolCalls(calls []*MCPToolCall) (*MCPToolCall, error) { + if len(calls) == 0 { + return nil, fmt.Errorf("no mcp tool call found") + } + if len(calls) == 1 { + return calls[0], nil + } + + ret := &MCPToolCall{} + + for _, c := range calls { + if c == nil { + continue + } + + ret.Arguments += c.Arguments + + if ret.ServerLabel == "" { + ret.ServerLabel = c.ServerLabel + } else if c.ServerLabel != "" && c.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool call, but got '%s'", ret.ServerLabel, c.ServerLabel) + } + + if ret.CallID == "" { + ret.CallID = c.CallID + } else if c.CallID != "" && c.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool call, but got '%s'", ret.CallID, c.CallID) + } + + if ret.Name == "" { + ret.Name = c.Name + } else if c.Name != "" && c.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool call, but got '%s'", ret.Name, c.Name) + } + } + + return ret, nil +} + +func concatMCPToolResults(results []*MCPToolResult) (*MCPToolResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp tool result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPToolResult{} + + for _, r := range results { + if r == nil { + continue + } + + if r.Result != "" { + ret.Result = r.Result + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool result, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + + if ret.CallID == "" { + ret.CallID = r.CallID + } else if r.CallID != "" && r.CallID != ret.CallID { + return nil, fmt.Errorf("expected call ID '%s' for mcp tool result, but got '%s'", ret.CallID, r.CallID) + } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool result, but got '%s'", ret.Name, r.Name) + } + + if r.Error != nil { + ret.Error = r.Error + } + } + + return ret, nil +} + +func concatMCPListToolsResults(results []*MCPListToolsResult) (*MCPListToolsResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no mcp list tools result found") + } + if len(results) == 1 { + return results[0], nil + } + + ret := &MCPListToolsResult{} + + for _, r := range results { + if r == nil { + continue + } + + ret.Tools = append(ret.Tools, r.Tools...) + + if r.Error != "" { + ret.Error = r.Error + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp list tools result, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + } + + return ret, nil +} + +func concatMCPToolApprovalRequests(requests []*MCPToolApprovalRequest) (*MCPToolApprovalRequest, error) { + if len(requests) == 0 { + return nil, fmt.Errorf("no mcp tool approval request found") + } + if len(requests) == 1 { + return requests[0], nil + } + + ret := &MCPToolApprovalRequest{} + + for _, r := range requests { + if r == nil { + continue + } + + ret.Arguments += r.Arguments + + if ret.ID == "" { + ret.ID = r.ID + } else if r.ID != "" && r.ID != ret.ID { + return nil, fmt.Errorf("expected request ID '%s' for mcp tool approval request, but got '%s'", ret.ID, r.ID) + } + + if ret.Name == "" { + ret.Name = r.Name + } else if r.Name != "" && r.Name != ret.Name { + return nil, fmt.Errorf("expected tool name '%s' for mcp tool approval request, but got '%s'", ret.Name, r.Name) + } + + if ret.ServerLabel == "" { + ret.ServerLabel = r.ServerLabel + } else if r.ServerLabel != "" && r.ServerLabel != ret.ServerLabel { + return nil, fmt.Errorf("expected server label '%s' for mcp tool approval request, but got '%s'", ret.ServerLabel, r.ServerLabel) + } + } + + return ret, nil +} + +func concatMCPToolApprovalResponses(responses []*MCPToolApprovalResponse) (*MCPToolApprovalResponse, error) { + if len(responses) == 0 { + return nil, fmt.Errorf("no mcp tool approval response found") + } + if len(responses) == 1 { + return responses[0], nil + } + return nil, fmt.Errorf("cannot concat multiple mcp tool approval responses") +} + +// String returns the string representation of AgenticMessage. +func (m *AgenticMessage) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("role: %s\n", m.Role)) + + if len(m.ContentBlocks) > 0 { + sb.WriteString("content_blocks:\n") + for i, block := range m.ContentBlocks { + if block == nil { + continue + } + sb.WriteString(fmt.Sprintf(" [%d] %s", i, block.String())) + } + } + + if m.ResponseMeta != nil { + sb.WriteString(m.ResponseMeta.String()) + } + + return sb.String() +} + +// String returns the string representation of ContentBlock. +func (b *ContentBlock) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf("type: %s\n", b.Type)) + + switch b.Type { + case ContentBlockTypeReasoning: + if b.Reasoning != nil { + sb.WriteString(b.Reasoning.String()) + } + case ContentBlockTypeUserInputText: + if b.UserInputText != nil { + sb.WriteString(b.UserInputText.String()) + } + case ContentBlockTypeUserInputImage: + if b.UserInputImage != nil { + sb.WriteString(b.UserInputImage.String()) + } + case ContentBlockTypeUserInputAudio: + if b.UserInputAudio != nil { + sb.WriteString(b.UserInputAudio.String()) + } + case ContentBlockTypeUserInputVideo: + if b.UserInputVideo != nil { + sb.WriteString(b.UserInputVideo.String()) + } + case ContentBlockTypeUserInputFile: + if b.UserInputFile != nil { + sb.WriteString(b.UserInputFile.String()) + } + case ContentBlockTypeAssistantGenText: + if b.AssistantGenText != nil { + sb.WriteString(b.AssistantGenText.String()) + } + case ContentBlockTypeAssistantGenImage: + if b.AssistantGenImage != nil { + sb.WriteString(b.AssistantGenImage.String()) + } + case ContentBlockTypeAssistantGenAudio: + if b.AssistantGenAudio != nil { + sb.WriteString(b.AssistantGenAudio.String()) + } + case ContentBlockTypeAssistantGenVideo: + if b.AssistantGenVideo != nil { + sb.WriteString(b.AssistantGenVideo.String()) + } + case ContentBlockTypeFunctionToolCall: + if b.FunctionToolCall != nil { + sb.WriteString(b.FunctionToolCall.String()) + } + case ContentBlockTypeFunctionToolResult: + if b.FunctionToolResult != nil { + sb.WriteString(b.FunctionToolResult.String()) + } + case ContentBlockTypeServerToolCall: + if b.ServerToolCall != nil { + sb.WriteString(b.ServerToolCall.String()) + } + case ContentBlockTypeServerToolResult: + if b.ServerToolResult != nil { + sb.WriteString(b.ServerToolResult.String()) + } + case ContentBlockTypeMCPToolCall: + if b.MCPToolCall != nil { + sb.WriteString(b.MCPToolCall.String()) + } + case ContentBlockTypeMCPToolResult: + if b.MCPToolResult != nil { + sb.WriteString(b.MCPToolResult.String()) + } + case ContentBlockTypeMCPListToolsResult: + if b.MCPListToolsResult != nil { + sb.WriteString(b.MCPListToolsResult.String()) + } + case ContentBlockTypeMCPToolApprovalRequest: + if b.MCPToolApprovalRequest != nil { + sb.WriteString(b.MCPToolApprovalRequest.String()) + } + case ContentBlockTypeMCPToolApprovalResponse: + if b.MCPToolApprovalResponse != nil { + sb.WriteString(b.MCPToolApprovalResponse.String()) + } + } + + if b.StreamingMeta != nil { + sb.WriteString(fmt.Sprintf(" stream_index: %d\n", b.StreamingMeta.Index)) + } + + return sb.String() +} + +// String returns the string representation of Reasoning. +func (r *Reasoning) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" text: %s\n", r.Text)) + if r.Signature != "" { + sb.WriteString(fmt.Sprintf(" signature: %s\n", truncateString(r.Signature, 50))) + } + return sb.String() +} + +// String returns the string representation of UserInputText. +func (u *UserInputText) String() string { + return fmt.Sprintf(" text: %s\n", u.Text) +} + +// String returns the string representation of UserInputImage. +func (u *UserInputImage) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, string(u.Detail)) +} + +// String returns the string representation of UserInputAudio. +func (u *UserInputAudio) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +// String returns the string representation of UserInputVideo. +func (u *UserInputVideo) String() string { + return formatMediaString(u.URL, u.Base64Data, u.MIMEType, "") +} + +// String returns the string representation of UserInputFile. +func (u *UserInputFile) String() string { + sb := &strings.Builder{} + if u.Name != "" { + sb.WriteString(fmt.Sprintf(" name: %s\n", u.Name)) + } + sb.WriteString(formatMediaString(u.URL, u.Base64Data, u.MIMEType, "")) + return sb.String() +} + +// String returns the string representation of AssistantGenText. +func (a *AssistantGenText) String() string { + return fmt.Sprintf(" text: %s\n", a.Text) +} + +// String returns the string representation of AssistantGenImage. +func (a *AssistantGenImage) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +// String returns the string representation of AssistantGenAudio. +func (a *AssistantGenAudio) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +// String returns the string representation of AssistantGenVideo. +func (a *AssistantGenVideo) String() string { + return formatMediaString(a.URL, a.Base64Data, a.MIMEType, "") +} + +// String returns the string representation of FunctionToolCall. +func (f *FunctionToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", f.Arguments)) + return sb.String() +} + +// String returns the string representation of FunctionToolResult. +func (f *FunctionToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", f.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", f.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", f.Result)) + return sb.String() +} + +// String returns the string representation of ServerToolCall. +func (s *ServerToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" arguments: %s\n", printAny(s.Arguments))) + return sb.String() +} + +// String returns the string representation of ServerToolResult. +func (s *ServerToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" name: %s\n", s.Name)) + if s.CallID != "" { + sb.WriteString(fmt.Sprintf(" call_id: %s\n", s.CallID)) + } + sb.WriteString(fmt.Sprintf(" result: %s\n", printAny(s.Result))) + return sb.String() +} + +// String returns the string representation of MCPToolCall. +func (m *MCPToolCall) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + return sb.String() +} + +// String returns the string representation of MCPToolResult. +func (m *MCPToolResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" call_id: %s\n", m.CallID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" result: %s\n", m.Result)) + if m.Error != nil { + sb.WriteString(fmt.Sprintf(" error: [%d] %s\n", *m.Error.Code, m.Error.Message)) + } + return sb.String() +} + +// String returns the string representation of MCPListToolsResult. +func (m *MCPListToolsResult) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" tools: %d items\n", len(m.Tools))) + for _, tool := range m.Tools { + sb.WriteString(fmt.Sprintf(" - %s: %s\n", tool.Name, tool.Description)) + } + if m.Error != "" { + sb.WriteString(fmt.Sprintf(" error: %s\n", m.Error)) + } + return sb.String() +} + +// String returns the string representation of MCPToolApprovalRequest. +func (m *MCPToolApprovalRequest) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" server_label: %s\n", m.ServerLabel)) + sb.WriteString(fmt.Sprintf(" id: %s\n", m.ID)) + sb.WriteString(fmt.Sprintf(" name: %s\n", m.Name)) + sb.WriteString(fmt.Sprintf(" arguments: %s\n", m.Arguments)) + return sb.String() +} + +// String returns the string representation of MCPToolApprovalResponse. +func (m *MCPToolApprovalResponse) String() string { + sb := &strings.Builder{} + sb.WriteString(fmt.Sprintf(" approval_request_id: %s\n", m.ApprovalRequestID)) + sb.WriteString(fmt.Sprintf(" approve: %v\n", m.Approve)) + if m.Reason != "" { + sb.WriteString(fmt.Sprintf(" reason: %s\n", m.Reason)) + } + return sb.String() +} + +// String returns the string representation of AgenticResponseMeta. +func (a *AgenticResponseMeta) String() string { + sb := &strings.Builder{} + sb.WriteString("response_meta:\n") + if a.TokenUsage != nil { + sb.WriteString(fmt.Sprintf(" token_usage: prompt=%d, completion=%d, total=%d\n", + a.TokenUsage.PromptTokens, + a.TokenUsage.CompletionTokens, + a.TokenUsage.TotalTokens)) + } + return sb.String() +} + +// truncateString truncates a string to maxLen characters, adding "..." if truncated +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +// formatMediaString formats URL, Base64Data, MIMEType and Detail for media content +func formatMediaString(url, base64Data string, mimeType string, detail string) string { + sb := &strings.Builder{} + if url != "" { + sb.WriteString(fmt.Sprintf(" url: %s\n", truncateString(url, 100))) + } + if base64Data != "" { + // Only show first few characters of base64 data + sb.WriteString(fmt.Sprintf(" base64_data: %s... (%d bytes)\n", truncateString(base64Data, 20), len(base64Data))) + } + if mimeType != "" { + sb.WriteString(fmt.Sprintf(" mime_type: %s\n", mimeType)) + } + if detail != "" { + sb.WriteString(fmt.Sprintf(" detail: %s\n", detail)) + } + return sb.String() +} + +func validateExtensionType(expected reflect.Type, actual any) (reflect.Type, bool) { + if actual == nil { + return expected, true + } + actualType := reflect.TypeOf(actual) + if expected == nil { + return actualType, true + } + if expected != actualType { + return expected, false + } + return expected, true +} + +func printAny(a any) string { + switch v := a.(type) { + case string: + return v + case fmt.Stringer: + return v.String() + default: + b, err := json.MarshalIndent(a, "", " ") + if err != nil { + return fmt.Sprintf("%v", a) + } + return string(b) + } +} diff --git a/schema/agentic_message_test.go b/schema/agentic_message_test.go new file mode 100644 index 000000000..e8a1003f5 --- /dev/null +++ b/schema/agentic_message_test.go @@ -0,0 +1,1639 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package schema + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAgenticMessages(t *testing.T) { + t.Run("single message", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + }, + }, + } + + result, err := ConcatAgenticMessages([]*AgenticMessage{msg}) + assert.NoError(t, err) + assert.Equal(t, msg, result) + }) + + t.Run("nil message in stream", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeAssistant}, + nil, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "message at index 1 is nil") + }) + + t.Run("different roles", func(t *testing.T) { + msgs := []*AgenticMessage{ + {Role: AgenticRoleTypeUser}, + {Role: AgenticRoleTypeAssistant}, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat messages with different roles") + }) + + t.Run("concat text blocks", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World!", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Equal(t, AgenticRoleTypeAssistant, result.Role) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat reasoning with nil index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "First ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "Second", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "First Second", result.ContentBlocks[0].Reasoning.Text) + }) + + t.Run("concat reasoning with index", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "Part1-", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "Part3", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Part1-Part3", result.ContentBlocks[0].Reasoning.Text) + }) + + t.Run("concat user input text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World!", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Hello World!", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat assistant gen image", func(t *testing.T) { + base1 := "1" + base2 := "2" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base1, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: base2, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "12", result.ContentBlocks[0].AssistantGenImage.Base64Data) + }) + + t.Run("concat user input audio - should error", func(t *testing.T) { + url1 := "https://example.com/audio1.mp3" + url2 := "https://example.com/audio2.mp3" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url1, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: url2, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input audios") + }) + + t.Run("concat user input video - should error", func(t *testing.T) { + url1 := "https://example.com/video1.mp4" + url2 := "https://example.com/video2.mp4" + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url1, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: url2, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input videos") + }) + + t.Run("concat assistant gen text", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Generated ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "Generated Text", result.ContentBlocks[0].AssistantGenText.Text) + }) + + t.Run("concat assistant gen image", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: "part1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + Base64Data: "part2", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "part1part2", result.ContentBlocks[0].AssistantGenImage.Base64Data) + }) + + t.Run("concat assistant gen audio", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + Base64Data: "audio1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + Base64Data: "audio2", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "audio1audio2", result.ContentBlocks[0].AssistantGenAudio.Base64Data) + }) + + t.Run("concat assistant gen video", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + Base64Data: "video1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + Base64Data: "video2", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "video1video2", result.ContentBlocks[0].AssistantGenVideo.Base64Data) + }) + + t.Run("concat function tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_123", + Name: "get_weather", + Arguments: `{"location`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":"NYC"}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolCall.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolCall.Name) + assert.Equal(t, `{"location":"NYC"}`, result.ContentBlocks[0].FunctionToolCall.Arguments) + }) + + t.Run("concat function tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_123", + Name: "get_weather", + Result: `{"temp`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + Result: `":72}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "call_123", result.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "get_weather", result.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, `{"temp":72}`, result.ContentBlocks[0].FunctionToolResult.Result) + }) + + t.Run("concat server tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + CallID: "server_call_1", + Name: "server_func", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Arguments: map[string]any{"key": "value"}, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolCall.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolCall.Name) + assert.NotNil(t, result.ContentBlocks[0].ServerToolCall.Arguments) + }) + + t.Run("concat server tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + CallID: "server_call_1", + Name: "server_func", + Result: "result1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{}, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "server_call_1", result.ContentBlocks[0].ServerToolResult.CallID) + assert.Equal(t, "server_func", result.ContentBlocks[0].ServerToolResult.Name) + assert.Equal(t, "result1", result.ContentBlocks[0].ServerToolResult.Result) + }) + + t.Run("concat mcp tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Arguments: `{"arg`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":123}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolCall.Name) + assert.Equal(t, `{"arg":123}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat mcp tool result", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + ServerLabel: "mcp-server", + CallID: "mcp_call_1", + Name: "mcp_func", + Result: `First`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + Result: `Second`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolResult.ServerLabel) + assert.Equal(t, "mcp_call_1", result.ContentBlocks[0].MCPToolResult.CallID) + assert.Equal(t, "mcp_func", result.ContentBlocks[0].MCPToolResult.Name) + assert.Equal(t, `Second`, result.ContentBlocks[0].MCPToolResult.Result) + }) + + t.Run("concat mcp list tools", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "tool1"}, + }, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + Tools: []*MCPListToolsItem{ + {Name: "tool2"}, + }, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPListToolsResult.ServerLabel) + assert.Len(t, result.ContentBlocks[0].MCPListToolsResult.Tools, 2) + }) + + t.Run("concat mcp tool approval request", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "approval_1", + Name: "approval_func", + ServerLabel: "mcp-server", + Arguments: `{"request`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + Arguments: `":1}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "approval_1", result.ContentBlocks[0].MCPToolApprovalRequest.ID) + assert.Equal(t, "approval_func", result.ContentBlocks[0].MCPToolApprovalRequest.Name) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolApprovalRequest.ServerLabel) + assert.Equal(t, `{"request":1}`, result.ContentBlocks[0].MCPToolApprovalRequest.Arguments) + }) + + t.Run("concat mcp tool approval response - should error", func(t *testing.T) { + response1 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: false, + } + response2 := &MCPToolApprovalResponse{ + ApprovalRequestID: "approval_1", + Approve: true, + } + + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response1, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: response2, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple mcp tool approval responses") + }) + + t.Run("concat response meta", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 5, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 10, + CompletionTokens: 15, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.NotNil(t, result.ResponseMeta) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.CompletionTokens) + assert.Equal(t, 20, result.ResponseMeta.TokenUsage.PromptTokens) + }) + + t.Run("mixed streaming and non-streaming blocks error", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Hello", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "World", + }, + // No StreamingMeta - non-streaming + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "found non-streaming block after streaming blocks") + }) + + t.Run("concat MCP tool call", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "mcp-server", + CallID: "call_456", + Name: "list_files", + Arguments: `{"path`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + Arguments: `":"/tmp"}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 1) + assert.Equal(t, "mcp-server", result.ContentBlocks[0].MCPToolCall.ServerLabel) + assert.Equal(t, "call_456", result.ContentBlocks[0].MCPToolCall.CallID) + assert.Equal(t, `{"path":"/tmp"}`, result.ContentBlocks[0].MCPToolCall.Arguments) + }) + + t.Run("concat user input text - should error", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What is ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + { + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "the weather?", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + }, + }, + } + + _, err := ConcatAgenticMessages(msgs) + assert.Error(t, err) + assert.ErrorContains(t, err, "cannot concat multiple user input texts") + }) + + t.Run("multiple stream indexes - sparse indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index0-", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Index2-", + }, + StreamingMeta: &StreamingMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Part2", + }, + StreamingMeta: &StreamingMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Index0-Part2", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "Index2-Part2", result.ContentBlocks[1].AssistantGenText.Text) + }) + + t.Run("multiple stream indexes - mixed content types", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Text ", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_1", + Name: "func1", + Arguments: `{"a`, + }, + StreamingMeta: &StreamingMeta{Index: 1}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "Content", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + Arguments: `":1}`, + }, + StreamingMeta: &StreamingMeta{Index: 1}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 2) + assert.Equal(t, "Text Content", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "call_1", result.ContentBlocks[1].FunctionToolCall.CallID) + assert.Equal(t, "func1", result.ContentBlocks[1].FunctionToolCall.Name) + assert.Equal(t, `{"a":1}`, result.ContentBlocks[1].FunctionToolCall.Arguments) + }) + + t.Run("multiple stream indexes - three indexes", func(t *testing.T) { + msgs := []*AgenticMessage{ + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "A", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "B", + }, + StreamingMeta: &StreamingMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "C", + }, + StreamingMeta: &StreamingMeta{Index: 2}, + }, + }, + }, + { + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "1", + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "2", + }, + StreamingMeta: &StreamingMeta{Index: 1}, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "3", + }, + StreamingMeta: &StreamingMeta{Index: 2}, + }, + }, + }, + } + + result, err := ConcatAgenticMessages(msgs) + assert.NoError(t, err) + assert.Len(t, result.ContentBlocks, 3) + assert.Equal(t, "A1", result.ContentBlocks[0].AssistantGenText.Text) + assert.Equal(t, "B2", result.ContentBlocks[1].AssistantGenText.Text) + assert.Equal(t, "C3", result.ContentBlocks[2].AssistantGenText.Text) + }) +} + +func TestAgenticMessageFormat(t *testing.T) { + m := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "{a}"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "{b}", + Base64Data: "{c}", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "{d}", + Base64Data: "{e}", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "{f}", + Base64Data: "{g}", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "{h}", + Base64Data: "{i}", + }, + }, + }, + } + + result, err := m.Format(context.Background(), map[string]any{ + "a": "1", "b": "2", "c": "3", "d": "4", "e": "5", "f": "6", "g": "7", "h": "8", "i": "9", + }, FString) + assert.NoError(t, err) + assert.Equal(t, []*AgenticMessage{{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{Text: "1"}, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "2", + Base64Data: "3", + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "4", + Base64Data: "5", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "6", + Base64Data: "7", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "8", + Base64Data: "9", + }, + }, + }, + }}, result) +} + +func TestAgenticPlaceholderFormat(t *testing.T) { + ctx := context.Background() + ph := AgenticMessagesPlaceholder("a", false) + + result, err := ph.Format(ctx, map[string]any{ + "a": []*AgenticMessage{{Role: AgenticRoleTypeUser}, {Role: AgenticRoleTypeUser}}, + }, FString) + assert.NoError(t, err) + assert.Equal(t, 2, len(result)) + + ph = AgenticMessagesPlaceholder("a", true) + + result, err = ph.Format(ctx, map[string]any{}, FString) + assert.NoError(t, err) + assert.Equal(t, 0, len(result)) +} + +func ptrOf[T any](v T) *T { + return &v +} + +func TestAgenticMessageString(t *testing.T) { + longBase64 := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + msg := &AgenticMessage{ + Role: AgenticRoleTypeAssistant, + ContentBlocks: []*ContentBlock{ + { + Type: ContentBlockTypeUserInputText, + UserInputText: &UserInputText{ + Text: "What's the weather like in New York City today?", + }, + }, + { + Type: ContentBlockTypeUserInputImage, + UserInputImage: &UserInputImage{ + URL: "https://example.com/weather-map.jpg", + Base64Data: longBase64, + MIMEType: "image/jpeg", + Detail: ImageURLDetailHigh, + }, + }, + { + Type: ContentBlockTypeUserInputAudio, + UserInputAudio: &UserInputAudio{ + URL: "http://audio.com", + Base64Data: "audio_data", + MIMEType: "audio/mp3", + }, + }, + { + Type: ContentBlockTypeUserInputVideo, + UserInputVideo: &UserInputVideo{ + URL: "http://video.com", + Base64Data: "video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeUserInputFile, + UserInputFile: &UserInputFile{ + URL: "http://file.com", + Name: "file.txt", + Base64Data: "file_data", + MIMEType: "text/plain", + }, + }, + { + Type: ContentBlockTypeAssistantGenText, + AssistantGenText: &AssistantGenText{ + Text: "I'll check the current weather in New York City for you.", + }, + }, + { + Type: ContentBlockTypeAssistantGenImage, + AssistantGenImage: &AssistantGenImage{ + URL: "http://gen_image.com", + Base64Data: "gen_image_data", + MIMEType: "image/png", + }, + }, + { + Type: ContentBlockTypeAssistantGenAudio, + AssistantGenAudio: &AssistantGenAudio{ + URL: "http://gen_audio.com", + Base64Data: "gen_audio_data", + MIMEType: "audio/wav", + }, + }, + { + Type: ContentBlockTypeAssistantGenVideo, + AssistantGenVideo: &AssistantGenVideo{ + URL: "http://gen_video.com", + Base64Data: "gen_video_data", + MIMEType: "video/mp4", + }, + }, + { + Type: ContentBlockTypeReasoning, + Reasoning: &Reasoning{ + Text: "First, I need to identify the location (New York City) from the user's query.\n" + + "Then, I should call the weather API to get current conditions.\n" + + "Finally, I'll format the response in a user-friendly way with temperature and conditions.", + Signature: "encrypted_reasoning_content_that_is_very_long_and_will_be_truncated_for_display", + }, + }, + { + Type: ContentBlockTypeFunctionToolCall, + FunctionToolCall: &FunctionToolCall{ + CallID: "call_weather_123", + Name: "get_current_weather", + Arguments: `{"location":"New York City","unit":"fahrenheit"}`, + }, + StreamingMeta: &StreamingMeta{Index: 0}, + }, + { + Type: ContentBlockTypeFunctionToolResult, + FunctionToolResult: &FunctionToolResult{ + CallID: "call_weather_123", + Name: "get_current_weather", + Result: `{"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8}`, + }, + }, + { + Type: ContentBlockTypeServerToolCall, + ServerToolCall: &ServerToolCall{ + Name: "server_tool", + CallID: "call_1", + Arguments: map[string]any{"a": 1}, + }, + }, + { + Type: ContentBlockTypeServerToolResult, + ServerToolResult: &ServerToolResult{ + Name: "server_tool", + CallID: "call_1", + Result: map[string]any{"success": true}, + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalRequest, + MCPToolApprovalRequest: &MCPToolApprovalRequest{ + ID: "req_1", + Name: "mcp_tool", + ServerLabel: "mcp_server", + Arguments: "{}", + }, + }, + { + Type: ContentBlockTypeMCPToolApprovalResponse, + MCPToolApprovalResponse: &MCPToolApprovalResponse{ + ApprovalRequestID: "req_1", + Approve: true, + Reason: "looks good", + }, + }, + { + Type: ContentBlockTypeMCPToolCall, + MCPToolCall: &MCPToolCall{ + ServerLabel: "weather-mcp-server", + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Arguments: `{"city":"New York","days":7}`, + }, + }, + { + Type: ContentBlockTypeMCPToolResult, + MCPToolResult: &MCPToolResult{ + CallID: "mcp_forecast_456", + Name: "get_7day_forecast", + Result: `{"status":"partial","days_available":3}`, + Error: &MCPToolCallError{ + Code: ptrOf[int64](503), + Message: "Service temporarily unavailable for full 7-day forecast", + }, + }, + }, + { + Type: ContentBlockTypeMCPListToolsResult, + MCPListToolsResult: &MCPListToolsResult{ + ServerLabel: "weather-mcp-server", + Tools: []*MCPListToolsItem{ + {Name: "get_current_weather", Description: "Get current weather conditions for a location"}, + {Name: "get_7day_forecast", Description: "Get 7-day weather forecast"}, + {Name: "get_weather_alerts", Description: "Get active weather alerts and warnings"}, + }, + }, + }, + }, + ResponseMeta: &AgenticResponseMeta{ + TokenUsage: &TokenUsage{ + PromptTokens: 250, + CompletionTokens: 180, + TotalTokens: 430, + }, + }, + } + + // Print the formatted output + output := msg.String() + + assert.Equal(t, `role: assistant +content_blocks: + [0] type: user_input_text + text: What's the weather like in New York City today? + [1] type: user_input_image + url: https://example.com/weather-map.jpg + base64_data: iVBORw0KGgoAAAANSUhE...... (96 bytes) + mime_type: image/jpeg + detail: high + [2] type: user_input_audio + url: http://audio.com + base64_data: audio_data... (10 bytes) + mime_type: audio/mp3 + [3] type: user_input_video + url: http://video.com + base64_data: video_data... (10 bytes) + mime_type: video/mp4 + [4] type: user_input_file + name: file.txt + url: http://file.com + base64_data: file_data... (9 bytes) + mime_type: text/plain + [5] type: assistant_gen_text + text: I'll check the current weather in New York City for you. + [6] type: assistant_gen_image + url: http://gen_image.com + base64_data: gen_image_data... (14 bytes) + mime_type: image/png + [7] type: assistant_gen_audio + url: http://gen_audio.com + base64_data: gen_audio_data... (14 bytes) + mime_type: audio/wav + [8] type: assistant_gen_video + url: http://gen_video.com + base64_data: gen_video_data... (14 bytes) + mime_type: video/mp4 + [9] type: reasoning + text: First, I need to identify the location (New York City) from the user's query. +Then, I should call the weather API to get current conditions. +Finally, I'll format the response in a user-friendly way with temperature and conditions. + signature: encrypted_reasoning_content_that_is_very_long_and_... + [10] type: function_tool_call + call_id: call_weather_123 + name: get_current_weather + arguments: {"location":"New York City","unit":"fahrenheit"} + stream_index: 0 + [11] type: function_tool_result + call_id: call_weather_123 + name: get_current_weather + result: {"temperature":72,"condition":"sunny","humidity":45,"wind_speed":8} + [12] type: server_tool_call + name: server_tool + call_id: call_1 + arguments: { + "a": 1 +} + [13] type: server_tool_result + name: server_tool + call_id: call_1 + result: { + "success": true +} + [14] type: mcp_tool_approval_request + server_label: mcp_server + id: req_1 + name: mcp_tool + arguments: {} + [15] type: mcp_tool_approval_response + approval_request_id: req_1 + approve: true + reason: looks good + [16] type: mcp_tool_call + server_label: weather-mcp-server + call_id: mcp_forecast_456 + name: get_7day_forecast + arguments: {"city":"New York","days":7} + [17] type: mcp_tool_result + call_id: mcp_forecast_456 + name: get_7day_forecast + result: {"status":"partial","days_available":3} + error: [503] Service temporarily unavailable for full 7-day forecast + [18] type: mcp_list_tools_result + server_label: weather-mcp-server + tools: 3 items + - get_current_weather: Get current weather conditions for a location + - get_7day_forecast: Get 7-day weather forecast + - get_weather_alerts: Get active weather alerts and warnings +response_meta: + token_usage: prompt=250, completion=180, total=430 +`, output) + + t.Run("nil/empty fields", func(t *testing.T) { + msg := &AgenticMessage{ + Role: AgenticRoleTypeUser, + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: &UserInputAudio{}}, // empty + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: &UserInputVideo{}}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: &UserInputFile{}}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: &AssistantGenImage{}}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: &AssistantGenAudio{}}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: &AssistantGenVideo{}}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: &ServerToolCall{Name: "t"}}, // No CallID + {Type: ContentBlockTypeServerToolResult, ServerToolResult: &ServerToolResult{Name: "t"}}, // No CallID + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: &MCPToolResult{Name: "t"}}, // No Error + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: &MCPListToolsResult{}}, // No Error + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: &MCPToolApprovalResponse{Approve: false}}, // No Reason + nil, // Nil block in slice + }, + } + + s := msg.String() + assert.Contains(t, s, "type: user_input_audio") + assert.NotContains(t, s, "mime_type:") + assert.Contains(t, s, "type: server_tool_call") + }) + + t.Run("nil content struct in block", func(t *testing.T) { + // Test cases where the specific content struct is nil but type is set + // This shouldn't crash and should just print type + msg := &AgenticMessage{ + ContentBlocks: []*ContentBlock{ + {Type: ContentBlockTypeReasoning, Reasoning: nil}, + {Type: ContentBlockTypeUserInputText, UserInputText: nil}, + {Type: ContentBlockTypeUserInputImage, UserInputImage: nil}, + {Type: ContentBlockTypeUserInputAudio, UserInputAudio: nil}, + {Type: ContentBlockTypeUserInputVideo, UserInputVideo: nil}, + {Type: ContentBlockTypeUserInputFile, UserInputFile: nil}, + {Type: ContentBlockTypeAssistantGenText, AssistantGenText: nil}, + {Type: ContentBlockTypeAssistantGenImage, AssistantGenImage: nil}, + {Type: ContentBlockTypeAssistantGenAudio, AssistantGenAudio: nil}, + {Type: ContentBlockTypeAssistantGenVideo, AssistantGenVideo: nil}, + {Type: ContentBlockTypeFunctionToolCall, FunctionToolCall: nil}, + {Type: ContentBlockTypeFunctionToolResult, FunctionToolResult: nil}, + {Type: ContentBlockTypeServerToolCall, ServerToolCall: nil}, + {Type: ContentBlockTypeServerToolResult, ServerToolResult: nil}, + {Type: ContentBlockTypeMCPToolCall, MCPToolCall: nil}, + {Type: ContentBlockTypeMCPToolResult, MCPToolResult: nil}, + {Type: ContentBlockTypeMCPListToolsResult, MCPListToolsResult: nil}, + {Type: ContentBlockTypeMCPToolApprovalRequest, MCPToolApprovalRequest: nil}, + {Type: ContentBlockTypeMCPToolApprovalResponse, MCPToolApprovalResponse: nil}, + }, + } + s := msg.String() + assert.Contains(t, s, "type: reasoning") + // ensure no panic and basic output present + }) +} + +func TestSystemAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := SystemAgenticMessage("system") + assert.Equal(t, AgenticRoleTypeSystem, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "system", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestUserAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := UserAgenticMessage("user") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, "user", msg.ContentBlocks[0].UserInputText.Text) + }) +} + +func TestFunctionToolResultAgenticMessage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + msg := FunctionToolResultAgenticMessage("call_1", "tool_name", "result_str") + assert.Equal(t, AgenticRoleTypeUser, msg.Role) + assert.Len(t, msg.ContentBlocks, 1) + assert.Equal(t, ContentBlockTypeFunctionToolResult, msg.ContentBlocks[0].Type) + assert.Equal(t, "call_1", msg.ContentBlocks[0].FunctionToolResult.CallID) + assert.Equal(t, "tool_name", msg.ContentBlocks[0].FunctionToolResult.Name) + assert.Equal(t, "result_str", msg.ContentBlocks[0].FunctionToolResult.Result) + }) +} + +func TestNewContentBlock(t *testing.T) { + cbType := reflect.TypeOf(ContentBlock{}) + for i := 0; i < cbType.NumField(); i++ { + field := cbType.Field(i) + + // Skip non-content fields + if field.Name == "Type" || field.Name == "Extra" || field.Name == "StreamingMeta" { + continue + } + + t.Run(field.Name, func(t *testing.T) { + // Ensure field is a pointer + assert.Equal(t, reflect.Ptr, field.Type.Kind(), "Field %s should be a pointer", field.Name) + + // Create a new instance of the field's type + // field.Type is *T, so Elem() is T. reflect.New(T) returns *T. + elemType := field.Type.Elem() + inputVal := reflect.New(elemType) + input := inputVal.Interface() + + // Call NewContentBlock (generic) via type switch + var block *ContentBlock + switch v := input.(type) { + case *Reasoning: + block = NewContentBlock(v) + case *UserInputText: + block = NewContentBlock(v) + case *UserInputImage: + block = NewContentBlock(v) + case *UserInputAudio: + block = NewContentBlock(v) + case *UserInputVideo: + block = NewContentBlock(v) + case *UserInputFile: + block = NewContentBlock(v) + case *AssistantGenText: + block = NewContentBlock(v) + case *AssistantGenImage: + block = NewContentBlock(v) + case *AssistantGenAudio: + block = NewContentBlock(v) + case *AssistantGenVideo: + block = NewContentBlock(v) + case *FunctionToolCall: + block = NewContentBlock(v) + case *FunctionToolResult: + block = NewContentBlock(v) + case *ServerToolCall: + block = NewContentBlock(v) + case *ServerToolResult: + block = NewContentBlock(v) + case *MCPToolCall: + block = NewContentBlock(v) + case *MCPToolResult: + block = NewContentBlock(v) + case *MCPListToolsResult: + block = NewContentBlock(v) + case *MCPToolApprovalRequest: + block = NewContentBlock(v) + case *MCPToolApprovalResponse: + block = NewContentBlock(v) + default: + t.Fatalf("unsupported ContentBlock field type: %T", input) + } + + // Assertions + assert.NotNil(t, block, "NewContentBlock should return non-nil for type %T", input) + + // Check if the corresponding field in block is set equals to input + blockVal := reflect.ValueOf(block).Elem() + fieldVal := blockVal.FieldByName(field.Name) + assert.True(t, fieldVal.IsValid(), "Field %s not found in result", field.Name) + assert.Equal(t, input, fieldVal.Interface(), "Field %s should match input", field.Name) + + // Check Type is set + typeVal := blockVal.FieldByName("Type") + assert.NotEmpty(t, typeVal.String(), "Type should be set for %s", field.Name) + }) + } +} diff --git a/schema/claude/consts.go b/schema/claude/consts.go new file mode 100644 index 000000000..714b0362e --- /dev/null +++ b/schema/claude/consts.go @@ -0,0 +1,27 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package claude defines constants for claude. +package claude + +type TextCitationType string + +const ( + TextCitationTypeCharLocation TextCitationType = "char_location" + TextCitationTypePageLocation TextCitationType = "page_location" + TextCitationTypeContentBlockLocation TextCitationType = "content_block_location" + TextCitationTypeWebSearchResultLocation TextCitationType = "web_search_result_location" +) diff --git a/schema/claude/extension.go b/schema/claude/extension.go new file mode 100644 index 000000000..5df8d8907 --- /dev/null +++ b/schema/claude/extension.go @@ -0,0 +1,121 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +import ( + "fmt" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + StopReason string `json:"stop_reason,omitempty"` +} + +type AssistantGenTextExtension struct { + Citations []*TextCitation `json:"citations,omitempty"` +} + +type TextCitation struct { + Type TextCitationType `json:"type,omitempty"` + + CharLocation *CitationCharLocation `json:"char_location,omitempty"` + PageLocation *CitationPageLocation `json:"page_location,omitempty"` + ContentBlockLocation *CitationContentBlockLocation `json:"content_block_location,omitempty"` + WebSearchResultLocation *CitationWebSearchResultLocation `json:"web_search_result_location,omitempty"` +} + +type CitationCharLocation struct { + CitedText string `json:"cited_text,omitempty"` + + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` + + StartCharIndex int `json:"start_char_index,omitempty"` + EndCharIndex int `json:"end_char_index,omitempty"` +} + +type CitationPageLocation struct { + CitedText string `json:"cited_text,omitempty"` + + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` + + StartPageNumber int `json:"start_page_number,omitempty"` + EndPageNumber int `json:"end_page_number,omitempty"` +} + +type CitationContentBlockLocation struct { + CitedText string `json:"cited_text,omitempty"` + + DocumentTitle string `json:"document_title,omitempty"` + DocumentIndex int `json:"document_index,omitempty"` + + StartBlockIndex int `json:"start_block_index,omitempty"` + EndBlockIndex int `json:"end_block_index,omitempty"` +} + +type CitationWebSearchResultLocation struct { + CitedText string `json:"cited_text,omitempty"` + + Title string `json:"title,omitempty"` + URL string `json:"url,omitempty"` + + EncryptedIndex string `json:"encrypted_index,omitempty"` +} + +// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one. +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &AssistantGenTextExtension{ + Citations: make([]*TextCitation, 0, len(chunks)), + } + + for _, ext := range chunks { + ret.Citations = append(ret.Citations, ext.Citations...) + } + + return ret, nil +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.StopReason != "" { + ret.StopReason = ext.StopReason + } + } + + return ret, nil +} diff --git a/schema/claude/extension_test.go b/schema/claude/extension_test.go new file mode 100644 index 000000000..474fe740b --- /dev/null +++ b/schema/claude/extension_test.go @@ -0,0 +1,190 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package claude + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("multiple extensions - concatenates all citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "citation 1", + DocumentIndex: 0, + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "citation 2", + StartPageNumber: 1, + EndPageNumber: 2, + }, + }, + { + Type: "web_search_result_location", + WebSearchResultLocation: &CitationWebSearchResultLocation{ + CitedText: "citation 3", + URL: "https://example.com", + }, + }, + }, + }, + { + Citations: []*TextCitation{ + { + Type: "content_block_location", + ContentBlockLocation: &CitationContentBlockLocation{ + CitedText: "citation 4", + StartBlockIndex: 0, + EndBlockIndex: 5, + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 4) + assert.Equal(t, "citation 1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "citation 2", result.Citations[1].PageLocation.CitedText) + assert.Equal(t, "citation 3", result.Citations[2].WebSearchResultLocation.CitedText) + assert.Equal(t, "citation 4", result.Citations[3].ContentBlockLocation.CitedText) + }) + + t.Run("mixed empty and non-empty citations", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + {Citations: nil}, + { + Citations: []*TextCitation{ + { + Type: "char_location", + CharLocation: &CitationCharLocation{ + CitedText: "text1", + }, + }, + }, + }, + {Citations: []*TextCitation{}}, + { + Citations: []*TextCitation{ + { + Type: "page_location", + PageLocation: &CitationPageLocation{ + CitedText: "text2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 2) + assert.Equal(t, "text1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "text2", result.Citations[1].PageLocation.CitedText) + }) + + t.Run("streaming scenario - citations arrive in chunks", func(t *testing.T) { + // Simulates streaming where citations arrive progressively + exts := []*AssistantGenTextExtension{ + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk1"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk2"}}, + }, + }, + { + Citations: []*TextCitation{ + {Type: "char_location", CharLocation: &CitationCharLocation{CitedText: "chunk3"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Citations, 3) + assert.Equal(t, "chunk1", result.Citations[0].CharLocation.CitedText) + assert.Equal(t, "chunk2", result.Citations[1].CharLocation.CitedText) + assert.Equal(t, "chunk3", result.Citations[2].CharLocation.CitedText) + }) +} + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + { + ID: "msg_1", + StopReason: "stop_1", + }, + { + ID: "msg_2", + StopReason: "", + }, + { + ID: "", + StopReason: "stop_3", + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_2", result.ID) // Last non-empty ID + assert.Equal(t, "stop_3", result.StopReason) // Last non-empty StopReason + }) + + t.Run("all empty fields", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "", StopReason: ""}, + {ID: "", StopReason: ""}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "", result.ID) + assert.Equal(t, "", result.StopReason) + }) + + t.Run("streaming scenario - ID in first chunk, StopReason in last", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "msg_stream_123", StopReason: ""}, + {ID: "", StopReason: ""}, + {ID: "", StopReason: "end_turn"}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "msg_stream_123", result.ID) + assert.Equal(t, "end_turn", result.StopReason) + }) +} diff --git a/schema/gemini/extension.go b/schema/gemini/extension.go new file mode 100644 index 000000000..efbc4f4bd --- /dev/null +++ b/schema/gemini/extension.go @@ -0,0 +1,115 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package gemini defines the extension for gemini. +package gemini + +import ( + "fmt" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + GroundingMeta *GroundingMetadata `json:"grounding_meta,omitempty"` +} + +type GroundingMetadata struct { + // List of supporting references retrieved from specified grounding source. + GroundingChunks []*GroundingChunk `json:"grounding_chunks,omitempty"` + // Optional. List of grounding support. + GroundingSupports []*GroundingSupport `json:"grounding_supports,omitempty"` + // Optional. Google search entry for the following-up web searches. + SearchEntryPoint *SearchEntryPoint `json:"search_entry_point,omitempty"` + // Optional. Web search queries for the following-up web search. + WebSearchQueries []string `json:"web_search_queries,omitempty"` +} + +type GroundingChunk struct { + // Grounding chunk from the web. + Web *GroundingChunkWeb `json:"web,omitempty"` +} + +// GroundingChunkWeb is the chunk from the web. +type GroundingChunkWeb struct { + // Domain of the (original) URI. This field is not supported in Gemini API. + Domain string `json:"domain,omitempty"` + // Title of the chunk. + Title string `json:"title,omitempty"` + // URI reference of the chunk. + URI string `json:"uri,omitempty"` +} + +type GroundingSupport struct { + // Confidence score of the support references. Ranges from 0 to 1. 1 is the most confident. + // For Gemini 2.0 and before, this list must have the same size as the grounding_chunk_indices. + // For Gemini 2.5 and after, this list will be empty and should be ignored. + ConfidenceScores []float32 `json:"confidence_scores,omitempty"` + // A list of indices (into 'grounding_chunk') specifying the citations associated with + // the claim. For instance [1,3,4] means that grounding_chunk[1], grounding_chunk[3], + // grounding_chunk[4] are the retrieved content attributed to the claim. + GroundingChunkIndices []int `json:"grounding_chunk_indices,omitempty"` + // Segment of the content this support belongs to. + Segment *Segment `json:"segment,omitempty"` +} + +// Segment of the content. +type Segment struct { + // Output only. End index in the given Part, measured in bytes. Offset from the start + // of the Part, exclusive, starting at zero. + EndIndex int `json:"end_index,omitempty"` + // Output only. The index of a Part object within its parent Content object. + PartIndex int `json:"part_index,omitempty"` + // Output only. Start index in the given Part, measured in bytes. Offset from the start + // of the Part, inclusive, starting at zero. + StartIndex int `json:"start_index,omitempty"` + // Output only. The text corresponding to the segment from the response. + Text string `json:"text,omitempty"` +} + +// SearchEntryPoint is the Google search entry point. +type SearchEntryPoint struct { + // Optional. Web content snippet that can be embedded in a web page or an app webview. + RenderedContent string `json:"rendered_content,omitempty"` + // Optional. Base64 encoded JSON representing array of tuple. + SDKBlob []byte `json:"sdk_blob,omitempty"` +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.FinishReason != "" { + ret.FinishReason = ext.FinishReason + } + if ext.GroundingMeta != nil { + ret.GroundingMeta = ext.GroundingMeta + } + } + + return ret, nil +} diff --git a/schema/gemini/extension_test.go b/schema/gemini/extension_test.go new file mode 100644 index 000000000..56f390aa8 --- /dev/null +++ b/schema/gemini/extension_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package gemini + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + meta1 := &GroundingMetadata{WebSearchQueries: []string{"query1"}} + meta2 := &GroundingMetadata{WebSearchQueries: []string{"query2"}} + + exts := []*ResponseMetaExtension{ + { + ID: "resp_1", + FinishReason: "STOP", + GroundingMeta: meta1, + }, + { + ID: "resp_2", + FinishReason: "", + GroundingMeta: nil, + }, + { + ID: "", + FinishReason: "MAX_TOKENS", + GroundingMeta: meta2, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "resp_2", result.ID) + assert.Equal(t, "MAX_TOKENS", result.FinishReason) + assert.Equal(t, meta2, result.GroundingMeta) + }) + + t.Run("streaming scenario", func(t *testing.T) { + meta := &GroundingMetadata{ + GroundingChunks: []*GroundingChunk{ + { + Web: &GroundingChunkWeb{ + Title: "Example", + URI: "https://example.com", + }, + }, + }, + } + + exts := []*ResponseMetaExtension{ + {ID: "stream_123", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "", GroundingMeta: nil}, + {ID: "", FinishReason: "STOP", GroundingMeta: meta}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "stream_123", result.ID) + assert.Equal(t, "STOP", result.FinishReason) + assert.Equal(t, meta, result.GroundingMeta) + }) +} diff --git a/schema/message.go b/schema/message.go index 127a520a1..e7514f305 100644 --- a/schema/message.go +++ b/schema/message.go @@ -40,47 +40,56 @@ func init() { internal.RegisterStreamChunkConcatFunc(ConcatMessages) internal.RegisterStreamChunkConcatFunc(ConcatMessageArray) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessages) + internal.RegisterStreamChunkConcatFunc(ConcatAgenticMessagesArray) + internal.RegisterStreamChunkConcatFunc(ConcatToolResults) } -// ConcatMessageArray merges aligned slices of messages into a single slice, -// concatenating messages at the same index across the input arrays. -func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { - arrayLen := len(mas[0]) +func buildConcatGenericArray[T any](f func([]*T) (*T, error)) func([][]*T) ([]*T, error) { + return func(mas [][]*T) ([]*T, error) { + arrayLen := len(mas[0]) - ret := make([]*Message, arrayLen) - slicesToConcat := make([][]*Message, arrayLen) + ret := make([]*T, arrayLen) + slicesToConcat := make([][]*T, arrayLen) - for _, ma := range mas { - if len(ma) != arrayLen { - return nil, fmt.Errorf("unexpected array length. "+ - "Got %d, expected %d", len(ma), arrayLen) - } + for _, ma := range mas { + if len(ma) != arrayLen { + return nil, fmt.Errorf("unexpected array length. "+ + "Got %d, expected %d", len(ma), arrayLen) + } - for i := 0; i < arrayLen; i++ { - m := ma[i] - if m != nil { - slicesToConcat[i] = append(slicesToConcat[i], m) + for i := 0; i < arrayLen; i++ { + m := ma[i] + if m != nil { + slicesToConcat[i] = append(slicesToConcat[i], m) + } } } - } - for i, slice := range slicesToConcat { - if len(slice) == 0 { - ret[i] = nil - } else if len(slice) == 1 { - ret[i] = slice[0] - } else { - cm, err := ConcatMessages(slice) - if err != nil { - return nil, err - } + for i, slice := range slicesToConcat { + if len(slice) == 0 { + ret[i] = nil + } else if len(slice) == 1 { + ret[i] = slice[0] + } else { + cm, err := f(slice) + if err != nil { + return nil, err + } - ret[i] = cm + ret[i] = cm + } } + + return ret, nil } +} - return ret, nil +// ConcatMessageArray merges aligned slices of messages into a single slice, +// concatenating messages at the same index across the input arrays. +func ConcatMessageArray(mas [][]*Message) ([]*Message, error) { + return buildConcatGenericArray[Message](ConcatMessages)(mas) } // FormatType used by MessageTemplate.Format @@ -716,7 +725,7 @@ var _ MessagesTemplate = MessagesPlaceholder("", false) // e.g. // // chatTemplate := prompt.FromMessages( -// schema.SystemMessage("you are eino helper"), +// schema.SystemMessage("you are an eino helper"), // schema.MessagesPlaceholder("history", false), // <= this will use the value of "history" in params // ) // msgs, err := chatTemplate.Format(ctx, params) @@ -734,7 +743,7 @@ type messagesPlaceholder struct { // // placeholder := MessagesPlaceholder("history", false) // params := map[string]any{ -// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great freamwork to build llm apps"}}, +// "history": []*schema.Message{{Role: "user", Content: "what is eino?"}, {Role: "assistant", Content: "eino is a great framework to build llm apps"}}, // "query": "how to use eino?", // } // chatTemplate := chatTpl := prompt.FromMessages( diff --git a/schema/openai/consts.go b/schema/openai/consts.go new file mode 100644 index 000000000..5958cef40 --- /dev/null +++ b/schema/openai/consts.go @@ -0,0 +1,95 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package openai defines constants for openai. +package openai + +type TextAnnotationType string + +const ( + TextAnnotationTypeFileCitation TextAnnotationType = "file_citation" + TextAnnotationTypeURLCitation TextAnnotationType = "url_citation" + TextAnnotationTypeContainerFileCitation TextAnnotationType = "container_file_citation" + TextAnnotationTypeFilePath TextAnnotationType = "file_path" +) + +type ReasoningEffort string + +const ( + ReasoningEffortMinimal ReasoningEffort = "minimal" + ReasoningEffortLow ReasoningEffort = "low" + ReasoningEffortMedium ReasoningEffort = "medium" + ReasoningEffortHigh ReasoningEffort = "high" +) + +type ReasoningSummary string + +const ( + ReasoningSummaryAuto ReasoningSummary = "auto" + ReasoningSummaryConcise ReasoningSummary = "concise" + ReasoningSummaryDetailed ReasoningSummary = "detailed" +) + +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierScale ServiceTier = "scale" + ServiceTierPriority ServiceTier = "priority" +) + +type PromptCacheRetention string + +const ( + PromptCacheRetentionInMemory PromptCacheRetention = "in-memory" + PromptCacheRetention24h PromptCacheRetention = "24h" +) + +type ResponseStatus string + +const ( + ResponseStatusCompleted ResponseStatus = "completed" + ResponseStatusFailed ResponseStatus = "failed" + ResponseStatusInProgress ResponseStatus = "in_progress" + ResponseStatusCancelled ResponseStatus = "cancelled" + ResponseStatusQueued ResponseStatus = "queued" + ResponseStatusIncomplete ResponseStatus = "incomplete" +) + +type ResponseErrorCode string + +const ( + ResponseErrorCodeServerError ResponseErrorCode = "server_error" + ResponseErrorCodeRateLimitExceeded ResponseErrorCode = "rate_limit_exceeded" + ResponseErrorCodeInvalidPrompt ResponseErrorCode = "invalid_prompt" + ResponseErrorCodeVectorStoreTimeout ResponseErrorCode = "vector_store_timeout" + ResponseErrorCodeInvalidImage ResponseErrorCode = "invalid_image" + ResponseErrorCodeInvalidImageFormat ResponseErrorCode = "invalid_image_format" + ResponseErrorCodeInvalidBase64Image ResponseErrorCode = "invalid_base64_image" + ResponseErrorCodeInvalidImageURL ResponseErrorCode = "invalid_image_url" + ResponseErrorCodeImageTooLarge ResponseErrorCode = "image_too_large" + ResponseErrorCodeImageTooSmall ResponseErrorCode = "image_too_small" + ResponseErrorCodeImageParseError ResponseErrorCode = "image_parse_error" + ResponseErrorCodeImageContentPolicyViolation ResponseErrorCode = "image_content_policy_violation" + ResponseErrorCodeInvalidImageMode ResponseErrorCode = "invalid_image_mode" + ResponseErrorCodeImageFileTooLarge ResponseErrorCode = "image_file_too_large" + ResponseErrorCodeUnsupportedImageMediaType ResponseErrorCode = "unsupported_image_media_type" + ResponseErrorCodeEmptyImageFile ResponseErrorCode = "empty_image_file" + ResponseErrorCodeFailedToDownloadImage ResponseErrorCode = "failed_to_download_image" + ResponseErrorCodeImageFileNotFound ResponseErrorCode = "image_file_not_found" +) diff --git a/schema/openai/extension.go b/schema/openai/extension.go new file mode 100644 index 000000000..1e10c411e --- /dev/null +++ b/schema/openai/extension.go @@ -0,0 +1,212 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "fmt" + "sort" +) + +type ResponseMetaExtension struct { + ID string `json:"id,omitempty"` + Status ResponseStatus `json:"status,omitempty"` + Error *ResponseError `json:"error,omitempty"` + IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` + CreatedAt int64 `json:"created_at,omitempty"` + PromptCacheRetention PromptCacheRetention `json:"prompt_cache_retention,omitempty"` +} + +type AssistantGenTextExtension struct { + Refusal *OutputRefusal `json:"refusal,omitempty"` + Annotations []*TextAnnotation `json:"annotations,omitempty"` +} + +type ResponseError struct { + Code ResponseErrorCode `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +type IncompleteDetails struct { + Reason string `json:"reason,omitempty"` +} + +type Reasoning struct { + Effort ReasoningEffort `json:"effort,omitempty"` + Summary ReasoningSummary `json:"summary,omitempty"` +} + +type OutputRefusal struct { + Reason string `json:"reason,omitempty"` +} + +type TextAnnotation struct { + Index int `json:"index,omitempty"` + + Type TextAnnotationType `json:"type,omitempty"` + + FileCitation *TextAnnotationFileCitation `json:"file_citation,omitempty"` + URLCitation *TextAnnotationURLCitation `json:"url_citation,omitempty"` + ContainerFileCitation *TextAnnotationContainerFileCitation `json:"container_file_citation,omitempty"` + FilePath *TextAnnotationFilePath `json:"file_path,omitempty"` +} + +type TextAnnotationFileCitation struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the file cited. + Filename string `json:"filename,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +type TextAnnotationURLCitation struct { + // The title of the web resource. + Title string `json:"title,omitempty"` + // The URL of the web resource. + URL string `json:"url,omitempty"` + + // The index of the first character of the URL citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the URL citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationContainerFileCitation struct { + // The ID of the container file. + ContainerID string `json:"container_id,omitempty"` + + // The ID of the file. + FileID string `json:"file_id,omitempty"` + // The filename of the container file cited. + Filename string `json:"filename,omitempty"` + + // The index of the first character of the container file citation in the message. + StartIndex int `json:"start_index,omitempty"` + // The index of the last character of the container file citation in the message. + EndIndex int `json:"end_index,omitempty"` +} + +type TextAnnotationFilePath struct { + // The ID of the file. + FileID string `json:"file_id,omitempty"` + + // The index of the file in the list of files. + Index int `json:"index,omitempty"` +} + +// ConcatAssistantGenTextExtensions concatenates multiple AssistantGenTextExtension chunks into a single one. +func ConcatAssistantGenTextExtensions(chunks []*AssistantGenTextExtension) (*AssistantGenTextExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no assistant generated text extension found") + } + + ret := &AssistantGenTextExtension{} + + var allAnnotations []*TextAnnotation + for _, ext := range chunks { + allAnnotations = append(allAnnotations, ext.Annotations...) + } + + var ( + indices []int + indexToAnnotation = map[int]*TextAnnotation{} + ) + + for _, an := range allAnnotations { + if an == nil { + continue + } + if indexToAnnotation[an.Index] == nil { + indexToAnnotation[an.Index] = an + indices = append(indices, an.Index) + } else { + return nil, fmt.Errorf("duplicate annotation index %d", an.Index) + } + } + + sort.Slice(indices, func(i, j int) bool { + return indices[i] < indices[j] + }) + + ret.Annotations = make([]*TextAnnotation, 0, len(indices)) + for _, idx := range indices { + an := *indexToAnnotation[idx] + an.Index = 0 // clear index + ret.Annotations = append(ret.Annotations, &an) + } + + for _, ext := range chunks { + if ext.Refusal == nil { + continue + } + if ret.Refusal == nil { + ret.Refusal = ext.Refusal + } else { + ret.Refusal.Reason += ext.Refusal.Reason + } + } + + return ret, nil +} + +// ConcatResponseMetaExtensions concatenates multiple ResponseMetaExtension chunks into a single one. +func ConcatResponseMetaExtensions(chunks []*ResponseMetaExtension) (*ResponseMetaExtension, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("no response meta extension found") + } + if len(chunks) == 1 { + return chunks[0], nil + } + + ret := &ResponseMetaExtension{} + + for _, ext := range chunks { + if ext.ID != "" { + ret.ID = ext.ID + } + if ext.Status != "" { + ret.Status = ext.Status + } + if ext.Error != nil { + ret.Error = ext.Error + } + if ext.IncompleteDetails != nil { + ret.IncompleteDetails = ext.IncompleteDetails + } + if ext.PreviousResponseID != "" { + ret.PreviousResponseID = ext.PreviousResponseID + } + if ext.Reasoning != nil { + ret.Reasoning = ext.Reasoning + } + if ext.ServiceTier != "" { + ret.ServiceTier = ext.ServiceTier + } + if ext.CreatedAt != 0 { + ret.CreatedAt = ext.CreatedAt + } + if ext.PromptCacheRetention != "" { + ret.PromptCacheRetention = ext.PromptCacheRetention + } + } + + return ret, nil +} diff --git a/schema/openai/extension_test.go b/schema/openai/extension_test.go new file mode 100644 index 000000000..640982fdf --- /dev/null +++ b/schema/openai/extension_test.go @@ -0,0 +1,193 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package openai + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConcatResponseMetaExtensions(t *testing.T) { + t.Run("multiple extensions - takes last non-empty values", func(t *testing.T) { + err1 := &ResponseError{Code: "err1", Message: "msg1"} + incomplete := &IncompleteDetails{Reason: "max_tokens"} + + exts := []*ResponseMetaExtension{ + { + ID: "id_1", + Status: "in_progress", + Error: err1, + IncompleteDetails: nil, + }, + { + ID: "id_2", + Status: "", + Error: nil, + IncompleteDetails: nil, + }, + { + ID: "", + Status: "completed", + Error: nil, + IncompleteDetails: incomplete, + }, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "id_2", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + assert.Equal(t, err1, result.Error) + assert.Equal(t, incomplete, result.IncompleteDetails) + }) + + t.Run("streaming scenario", func(t *testing.T) { + exts := []*ResponseMetaExtension{ + {ID: "chatcmpl_stream", Status: "", Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("in_progress"), Error: nil, IncompleteDetails: nil}, + {ID: "", Status: ResponseStatus("completed"), Error: nil, IncompleteDetails: nil}, + } + + result, err := ConcatResponseMetaExtensions(exts) + assert.NoError(t, err) + assert.Equal(t, "chatcmpl_stream", result.ID) + assert.Equal(t, ResponseStatus("completed"), result.Status) + }) +} + +func TestConcatAssistantGenTextExtensions(t *testing.T) { + t.Run("single extension with annotations", func(t *testing.T) { + ext := &AssistantGenTextExtension{ + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_123", + Filename: "doc.pdf", + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext}) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 1) + assert.Equal(t, "file_123", result.Annotations[0].FileCitation.FileID) + }) + + t.Run("multiple extensions - merges annotations by index", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + { + Index: 0, + Type: "file_citation", + FileCitation: &TextAnnotationFileCitation{ + FileID: "file_1", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 2, + Type: "url_citation", + URLCitation: &TextAnnotationURLCitation{ + URL: "https://example.com", + }, + }, + }, + }, + { + Annotations: []*TextAnnotation{ + { + Index: 1, + Type: "file_path", + FilePath: &TextAnnotationFilePath{ + FileID: "file_2", + }, + }, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "file_1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "file_2", result.Annotations[1].FilePath.FileID) + assert.Equal(t, "https://example.com", result.Annotations[2].URLCitation.URL) + }) + + t.Run("streaming scenario - annotations arrive in chunks", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "f1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 1, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "url1"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 2, Type: "file_path", FilePath: &TextAnnotationFilePath{FileID: "f2"}}, + }, + }, + } + + result, err := ConcatAssistantGenTextExtensions(exts) + assert.NoError(t, err) + assert.Len(t, result.Annotations, 3) + assert.Equal(t, "f1", result.Annotations[0].FileCitation.FileID) + assert.Equal(t, "url1", result.Annotations[1].URLCitation.URL) + assert.Equal(t, "f2", result.Annotations[2].FilePath.FileID) + }) + + t.Run("multiple extensions - concatenates refusal reason", func(t *testing.T) { + ext1 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "A"}} + ext2 := &AssistantGenTextExtension{Refusal: &OutputRefusal{Reason: "B"}} + + result, err := ConcatAssistantGenTextExtensions([]*AssistantGenTextExtension{ext1, ext2}) + assert.NoError(t, err) + assert.NotNil(t, result.Refusal) + assert.Equal(t, "AB", result.Refusal.Reason) + }) + + t.Run("duplicate index - error occurrence", func(t *testing.T) { + exts := []*AssistantGenTextExtension{ + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "file_citation", FileCitation: &TextAnnotationFileCitation{FileID: "first"}}, + }, + }, + { + Annotations: []*TextAnnotation{ + {Index: 0, Type: "url_citation", URLCitation: &TextAnnotationURLCitation{URL: "second"}}, + }, + }, + } + + _, err := ConcatAssistantGenTextExtensions(exts) + assert.Error(t, err) + }) +} diff --git a/schema/tool.go b/schema/tool.go index ccc93b6a3..a49306047 100644 --- a/schema/tool.go +++ b/schema/tool.go @@ -59,6 +59,61 @@ const ( ToolChoiceForced ToolChoice = "forced" ) +type AgenticToolChoice struct { + // Type is the tool choice mode. + Type ToolChoice + + // Allowed optionally specifies the list of tools that the model is permitted to call. + // Optional. + Allowed *AgenticAllowedToolChoice + + // Forced optionally specifies the list of tools that the model is required to call. + // Optional. + Forced *AgenticForcedToolChoice +} + +// AgenticAllowedToolChoice specifies a list of allowed tools for the model. +type AgenticAllowedToolChoice struct { + // Tools is the list of allowed tools for the model to call. + // Optional. + Tools []*AllowedTool +} + +// AgenticForcedToolChoice specifies a list of tools that the model must call. +type AgenticForcedToolChoice struct { + // Tools is the list of tools that the model must call. + // Optional. + Tools []*AllowedTool +} + +// AllowedTool represents a tool that the model is allowed or forced to call. +// Exactly one of FunctionName, MCPTool, or ServerTool must be specified. +type AllowedTool struct { + // FunctionName specifies a function tool by name. + FunctionName string + + // MCPTool specifies an MCP tool. + MCPTool *AllowedMCPTool + + // ServerTool specifies a server tool. + ServerTool *AllowedServerTool +} + +// AllowedMCPTool contains the information for identifying an MCP tool. +type AllowedMCPTool struct { + // ServerLabel is the label of the MCP server. + ServerLabel string + // Name is the name of the MCP tool. + Name string +} + +// AllowedServerTool contains the information for identifying a server tool. +type AllowedServerTool struct { + // Name is the name of the server tool. + Name string +} + +// ToolInfo is the information of a tool. // ToolInfo describes a tool that can be passed to a ChatModel via // [ToolCallingChatModel.WithTools] or [ChatModel.BindTools]. // diff --git a/utils/callbacks/template.go b/utils/callbacks/template.go index e04bddd63..4c2c709da 100644 --- a/utils/callbacks/template.go +++ b/utils/callbacks/template.go @@ -55,17 +55,20 @@ func NewHandlerHelper() *HandlerHelper { // // then use the handler with runnable.Invoke(ctx, input, compose.WithCallbacks(handler)) type HandlerHelper struct { - promptHandler *PromptCallbackHandler - chatModelHandler *ModelCallbackHandler - embeddingHandler *EmbeddingCallbackHandler - indexerHandler *IndexerCallbackHandler - retrieverHandler *RetrieverCallbackHandler - loaderHandler *LoaderCallbackHandler - transformerHandler *TransformerCallbackHandler - toolHandler *ToolCallbackHandler - toolsNodeHandler *ToolsNodeCallbackHandlers - agentHandler *AgentCallbackHandler - composeTemplates map[components.Component]callbacks.Handler + promptHandler *PromptCallbackHandler + chatModelHandler *ModelCallbackHandler + embeddingHandler *EmbeddingCallbackHandler + indexerHandler *IndexerCallbackHandler + retrieverHandler *RetrieverCallbackHandler + loaderHandler *LoaderCallbackHandler + transformerHandler *TransformerCallbackHandler + toolHandler *ToolCallbackHandler + toolsNodeHandler *ToolsNodeCallbackHandlers + agenticPromptHandler *AgenticPromptCallbackHandler + agenticModelHandler *AgenticModelCallbackHandler + agenticToolsNodeHandler *AgenticToolsNodeCallbackHandlers + agentHandler *AgentCallbackHandler + composeTemplates map[components.Component]callbacks.Handler } // Handler returns the callbacks.Handler created by HandlerHelper. @@ -127,6 +130,24 @@ func (c *HandlerHelper) ToolsNode(handler *ToolsNodeCallbackHandlers) *HandlerHe return c } +// AgenticPrompt sets the agentic prompt handler for the handler helper, which will be called when the agentic prompt component is executed. +func (c *HandlerHelper) AgenticPrompt(handler *AgenticPromptCallbackHandler) *HandlerHelper { + c.agenticPromptHandler = handler + return c +} + +// AgenticModel sets the agentic chat model handler for the handler helper, which will be called when the agentic chat model component is executed. +func (c *HandlerHelper) AgenticModel(handler *AgenticModelCallbackHandler) *HandlerHelper { + c.agenticModelHandler = handler + return c +} + +// AgenticToolsNode sets the agentic tools node handler for the handler helper, which will be called when the agentic tools node is executed. +func (c *HandlerHelper) AgenticToolsNode(handler *AgenticToolsNodeCallbackHandlers) *HandlerHelper { + c.agenticToolsNodeHandler = handler + return c +} + // Agent sets the agent handler for the handler helper, which will be called when the agent is executed. func (c *HandlerHelper) Agent(handler *AgentCallbackHandler) *HandlerHelper { c.agentHandler = handler @@ -161,8 +182,12 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnStart(ctx, info, prompt.ConvCallbackInput(input)) case components.ComponentOfChatModel: return c.chatModelHandler.OnStart(ctx, info, model.ConvCallbackInput(input)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnStart(ctx, info, model.ConvAgenticCallbackInput(input)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnStart(ctx, info, embedding.ConvCallbackInput(input)) case components.ComponentOfIndexer: @@ -177,6 +202,8 @@ func (c *handlerTemplate) OnStart(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnStart(ctx, info, tool.ConvCallbackInput(input)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnStart(ctx, info, convToolsNodeCallbackInput(input)) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnStart(ctx, info, convAgenticToolsNodeCallbackInput(input)) case adk.ComponentOfAgent: return c.agentHandler.OnStart(ctx, info, adk.ConvAgentCallbackInput(input)) case compose.ComponentOfGraph, @@ -194,8 +221,12 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnEnd(ctx, info, prompt.ConvCallbackOutput(output)) case components.ComponentOfChatModel: return c.chatModelHandler.OnEnd(ctx, info, model.ConvCallbackOutput(output)) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEnd(ctx, info, model.ConvAgenticCallbackOutput(output)) case components.ComponentOfEmbedding: return c.embeddingHandler.OnEnd(ctx, info, embedding.ConvCallbackOutput(output)) case components.ComponentOfIndexer: @@ -210,6 +241,8 @@ func (c *handlerTemplate) OnEnd(ctx context.Context, info *callbacks.RunInfo, ou return c.toolHandler.OnEnd(ctx, info, tool.ConvCallbackOutput(output)) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnEnd(ctx, info, convToolsNodeCallbackOutput(output)) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEnd(ctx, info, convAgenticToolsNodeCallbackOutput(output)) case adk.ComponentOfAgent: return c.agentHandler.OnEnd(ctx, info, adk.ConvAgentCallbackOutput(output)) case compose.ComponentOfGraph, @@ -227,8 +260,12 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, switch info.Component { case components.ComponentOfPrompt: return c.promptHandler.OnError(ctx, info, err) + case components.ComponentOfAgenticPrompt: + return c.agenticPromptHandler.OnError(ctx, info, err) case components.ComponentOfChatModel: return c.chatModelHandler.OnError(ctx, info, err) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnError(ctx, info, err) case components.ComponentOfEmbedding: return c.embeddingHandler.OnError(ctx, info, err) case components.ComponentOfIndexer: @@ -243,6 +280,8 @@ func (c *handlerTemplate) OnError(ctx context.Context, info *callbacks.RunInfo, return c.toolHandler.OnError(ctx, info, err) case compose.ComponentOfToolsNode: return c.toolsNodeHandler.OnError(ctx, info, err) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnError(ctx, info, err) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -275,6 +314,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.CallbackOutput, error) { return model.ConvCallbackOutput(item), nil })) + case components.ComponentOfAgenticModel: + return c.agenticModelHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*model.AgenticCallbackOutput, error) { + return model.ConvAgenticCallbackOutput(item), nil + })) case components.ComponentOfTool: return c.toolHandler.OnEndWithStreamOutput(ctx, info, schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) (*tool.CallbackOutput, error) { @@ -285,6 +329,11 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.Message, error) { return convToolsNodeCallbackOutput(item), nil })) + case compose.ComponentOfAgenticToolsNode: + return c.agenticToolsNodeHandler.OnEndWithStreamOutput(ctx, info, + schema.StreamReaderWithConvert(output, func(item callbacks.CallbackOutput) ([]*schema.AgenticMessage, error) { + return convAgenticToolsNodeCallbackOutput(item), nil + })) case compose.ComponentOfGraph, compose.ComponentOfChain, compose.ComponentOfLambda: @@ -295,6 +344,8 @@ func (c *handlerTemplate) OnEndWithStreamOutput(ctx context.Context, info *callb } // Needed checks if the callback handler is needed for the given timing. +// +//nolint:cyclop func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { if info == nil { return false @@ -305,6 +356,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.chatModelHandler != nil && c.chatModelHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticModel: + if c.agenticModelHandler != nil && c.agenticModelHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfEmbedding: if c.embeddingHandler != nil && c.embeddingHandler.Needed(ctx, info, timing) { return true @@ -321,6 +376,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.promptHandler != nil && c.promptHandler.Needed(ctx, info, timing) { return true } + case components.ComponentOfAgenticPrompt: + if c.agenticPromptHandler != nil && c.agenticPromptHandler.Needed(ctx, info, timing) { + return true + } case components.ComponentOfRetriever: if c.retrieverHandler != nil && c.retrieverHandler.Needed(ctx, info, timing) { return true @@ -337,6 +396,10 @@ func (c *handlerTemplate) Needed(ctx context.Context, info *callbacks.RunInfo, t if c.toolsNodeHandler != nil && c.toolsNodeHandler.Needed(ctx, info, timing) { return true } + case compose.ComponentOfAgenticToolsNode: + if c.agenticToolsNodeHandler != nil && c.agenticToolsNodeHandler.Needed(ctx, info, timing) { + return true + } case adk.ComponentOfAgent: if c.agentHandler != nil && c.agentHandler.Needed(ctx, info, timing) { return true @@ -596,3 +659,94 @@ func (ch *AgentCallbackHandler) Needed(ctx context.Context, info *callbacks.RunI return false } } + +// AgenticPromptCallbackHandler is the handler for the agentic prompt callback. +type AgenticPromptCallbackHandler struct { + // OnStart is the callback function for the start of the agentic prompt. + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context + // OnEnd is the callback function for the end of the agentic prompt. + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context + // OnError is the callback function for the error of the agentic prompt. + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *AgenticPromptCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +// AgenticModelCallbackHandler is the handler for the agentic chat model callback. +type AgenticModelCallbackHandler struct { + OnStart func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context + OnEnd func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context + OnEndWithStreamOutput func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context + OnError func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context +} + +// Needed checks if the callback handler is needed for the given timing. +func (ch *AgenticModelCallbackHandler) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnError: + return ch.OnError != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + default: + return false + } +} + +// AgenticToolsNodeCallbackHandlers defines optional callbacks for the Agentic Tools node +// lifecycle events. +type AgenticToolsNodeCallbackHandlers struct { + OnStart func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context + OnEnd func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context + OnEndWithStreamOutput func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context + OnError func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context +} + +// Needed reports whether a handler is registered for the given timing. +func (ch *AgenticToolsNodeCallbackHandlers) Needed(ctx context.Context, runInfo *callbacks.RunInfo, timing callbacks.CallbackTiming) bool { + switch timing { + case callbacks.TimingOnStart: + return ch.OnStart != nil + case callbacks.TimingOnEnd: + return ch.OnEnd != nil + case callbacks.TimingOnEndWithStreamOutput: + return ch.OnEndWithStreamOutput != nil + case callbacks.TimingOnError: + return ch.OnError != nil + default: + return false + } +} + +func convAgenticToolsNodeCallbackInput(src callbacks.CallbackInput) *schema.AgenticMessage { + switch t := src.(type) { + case *schema.AgenticMessage: + return t + default: + return nil + } +} + +func convAgenticToolsNodeCallbackOutput(src callbacks.CallbackInput) []*schema.AgenticMessage { + switch t := src.(type) { + case []*schema.AgenticMessage: + return t + default: + return nil + } +} diff --git a/utils/callbacks/template_test.go b/utils/callbacks/template_test.go index 84ed6dfc6..dcc0e5c7f 100644 --- a/utils/callbacks/template_test.go +++ b/utils/callbacks/template_test.go @@ -142,6 +142,58 @@ func TestNewComponentTemplate(t *testing.T) { cnt++ return ctx }).Build()). + AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *model.AgenticCallbackOutput) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, runInfo *callbacks.RunInfo, output *schema.StreamReader[*model.AgenticCallbackOutput]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }). Handler() types := []components.Component{ @@ -151,6 +203,9 @@ func TestNewComponentTemplate(t *testing.T) { components.ComponentOfRetriever, components.ComponentOfTool, compose.ComponentOfLambda, + components.ComponentOfAgenticModel, + components.ComponentOfAgenticPrompt, + compose.ComponentOfAgenticToolsNode, } handler := tpl.Handler() @@ -169,28 +224,28 @@ func TestNewComponentTemplate(t *testing.T) { handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: typ}, sor) } - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = context.Background() ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 22, cnt) + assert.Equal(t, 33, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 23, cnt) + assert.Equal(t, 34, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}) callbacks.OnError(ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 24, cnt) + assert.Equal(t, 35, cnt) tpl.Transformer(&TransformerCallbackHandler{ OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { @@ -250,6 +305,37 @@ func TestNewComponentTemplate(t *testing.T) { } } }, + }).AgenticPrompt(&AgenticPromptCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *prompt.CallbackOutput) context.Context { + cnt++ + return ctx + }, + OnError: func(ctx context.Context, runInfo *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, + }).AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + cnt++ + return ctx + }, + OnEndWithStreamOutput: func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[[]*schema.AgenticMessage]) context.Context { + output.Close() + cnt++ + return ctx + }, + OnError: func(ctx context.Context, info *callbacks.RunInfo, err error) context.Context { + cnt++ + return ctx + }, }) handler = tpl.Handler() @@ -257,36 +343,222 @@ func TestNewComponentTemplate(t *testing.T) { ctx = callbacks.InitCallbacks(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, handler) ctx = callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 25, cnt) + assert.Equal(t, 36, cnt) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 26, cnt) + assert.Equal(t, 37, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}) callbacks.OnEnd[any](ctx, nil) - assert.Equal(t, 27, cnt) + assert.Equal(t, 38, cnt) ctx = callbacks.ReuseHandlers(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 28, cnt) + assert.Equal(t, 39, cnt) sr, sw := schema.Pipe[any](0) sw.Close() callbacks.OnEndWithStreamOutput[any](ctx, sr) - assert.Equal(t, 29, cnt) + assert.Equal(t, 40, cnt) sr1, sw1 := schema.Pipe[[]*schema.Message](1) sw1.Send([]*schema.Message{{}}, nil) sw1.Close() callbacks.OnEndWithStreamOutput[[]*schema.Message](ctx, sr1) - assert.Equal(t, 30, cnt) - - callbacks.OnError(ctx, nil) - assert.Equal(t, 30, cnt) + // Check AgenticModel stream + sir2, siw2 := schema.Pipe[callbacks.CallbackOutput](1) + siw2.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, sir2) + assert.Equal(t, 42, cnt) + + // Check AgenticToolsNode stream + sir3, siw3 := schema.Pipe[callbacks.CallbackOutput](1) + siw3.Close() + handler.OnEndWithStreamOutput(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, sir3) + assert.Equal(t, 43, cnt) ctx = callbacks.ReuseHandlers(ctx, nil) callbacks.OnStart[any](ctx, nil) - assert.Equal(t, 30, cnt) + assert.Equal(t, 43, cnt) + }) + + t.Run("EdgeCases", func(t *testing.T) { + ctx := context.Background() + cnt := 0 + + // 1. Test Graph and Chain Setters and Execution + tpl := NewHandlerHelper(). + Graph(callbacks.NewHandlerBuilder(). + OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + cnt++ + return ctx + }).Build()). + Chain(callbacks.NewHandlerBuilder(). + OnEndFn(func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context { + cnt++ + return ctx + }).Build()) + + h := tpl.Handler() + + // Trigger Graph OnStart + h.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, nil) + assert.Equal(t, 1, cnt) + + // Trigger Chain OnEnd + h.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, nil) + assert.Equal(t, 2, cnt) + + // 2. Test Needed logic for Graph/Chain when handler is present/absent + // Graph is present (OnStart) + needed := h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfGraph}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Chain is present (OnEnd) - but we check OnStart which is not defined in the builder above? + // NewHandlerBuilder returns a handler that usually returns true for Needed if the specific func is not nil. + // Let's verify Chain OnStart is NOT needed because we only set OnEndFn. + needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfChain}, callbacks.TimingOnStart) + assert.False(t, needed) // Should be false because OnStartFn wasn't set for Chain + + // Lambda is NOT present + needed = h.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfLambda}, callbacks.TimingOnStart) + assert.False(t, needed) + + // 3. Test Conversion Fallbacks (Default cases) + // We need a handler with ToolsNode and AgenticToolsNode to test their conversion fallbacks + tpl2 := NewHandlerHelper(). + ToolsNode(&ToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.Message) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }). + AgenticToolsNode(&AgenticToolsNodeCallbackHandlers{ + OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + OnEnd: func(ctx context.Context, info *callbacks.RunInfo, input []*schema.AgenticMessage) context.Context { + if input == nil { + cnt++ + } + return ctx + }, + }) + + h2 := tpl2.Handler() + + // Pass wrong type (string) to trigger default case in convToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-input-type") + assert.Equal(t, 3, cnt) // +1 + + // Pass wrong type to trigger default case in convToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfToolsNode}, "wrong-output-type") + assert.Equal(t, 4, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackInput -> returns nil + h2.OnStart(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-input-type") + assert.Equal(t, 5, cnt) // +1 + + // Pass wrong type to trigger default case in convAgenticToolsNodeCallbackOutput -> returns nil + h2.OnEnd(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, "wrong-output-type") + assert.Equal(t, 6, cnt) // +1 + + // 4. Test Needed for Agentic components when handlers are Set vs Unset + // tpl2 has AgenticToolsNode set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: compose.ComponentOfAgenticToolsNode}, callbacks.TimingOnStart) + assert.True(t, needed) + + // tpl2 does NOT have AgenticModel set + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Set it now + tpl2.AgenticModel(&AgenticModelCallbackHandler{ + OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.AgenticCallbackInput) context.Context { + return ctx + }, + }) + + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticModel}, callbacks.TimingOnStart) + assert.True(t, needed) + + // Check invalid component + needed = h2.(callbacks.TimingChecker).Needed(ctx, &callbacks.RunInfo{Component: "UnknownComponent"}, callbacks.TimingOnStart) + assert.False(t, needed) + + // Check RunInfo nil + needed = h2.(callbacks.TimingChecker).Needed(ctx, nil, callbacks.TimingOnStart) + assert.False(t, needed) + + // 5. Test Needed for Transformer, Loader, Indexer, etc to ensure switch coverage + tpl3 := NewHandlerHelper(). + Transformer(&TransformerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.TransformerCallbackInput) context.Context { + return ctx + }}). + Loader(&LoaderCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *document.LoaderCallbackInput) context.Context { + return ctx + }}). + Indexer(&IndexerCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *indexer.CallbackInput) context.Context { + return ctx + }}). + Retriever(&RetrieverCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *retriever.CallbackInput) context.Context { + return ctx + }}). + Embedding(&EmbeddingCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *embedding.CallbackInput) context.Context { + return ctx + }}). + Tool(&ToolCallbackHandler{OnStart: func(ctx context.Context, info *callbacks.RunInfo, input *tool.CallbackInput) context.Context { + return ctx + }}) + + h3 := tpl3.Handler() + checker := h3.(callbacks.TimingChecker) + + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.True(t, checker.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // Verify False paths (by using a helper without them) + emptyH := NewHandlerHelper().Handler().(callbacks.TimingChecker) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTransformer}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfLoader}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfIndexer}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfRetriever}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfEmbedding}, callbacks.TimingOnStart)) + assert.False(t, emptyH.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfTool}, callbacks.TimingOnStart)) + + // 6. Test Needed for remaining components (ChatModel, Prompt, AgenticPrompt) + tpl4 := NewHandlerHelper(). + ChatModel(&ModelCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *model.CallbackInput) context.Context { + return ctx + }}). + Prompt(&PromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}). + AgenticPrompt(&AgenticPromptCallbackHandler{OnStart: func(ctx context.Context, runInfo *callbacks.RunInfo, input *prompt.CallbackInput) context.Context { + return ctx + }}) + + h4 := tpl4.Handler() + checker4 := h4.(callbacks.TimingChecker) + + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfChatModel}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfPrompt}, callbacks.TimingOnStart)) + assert.True(t, checker4.Needed(ctx, &callbacks.RunInfo{Component: components.ComponentOfAgenticPrompt}, callbacks.TimingOnStart)) }) }