Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions components/agentic/callback_extra.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ type Config struct {
// Model is the model name.
Model string
// Temperature is the temperature, which controls the randomness of the model.
Temperature float32
Temperature float64
// TopP is the top p, which controls the diversity of the model.
TopP float32
TopP float64
}

// CallbackInput is the input for the model callback.
Expand Down
7 changes: 5 additions & 2 deletions components/agentic/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ type Options struct {
TopP *float64
// Tools is a list of tools the model may call.
Tools []*schema.ToolInfo
// ToolChoice controls which tool is called by the model.
// ToolChoice controls how the model call the tools.
ToolChoice *schema.ToolChoice
// AllowedTools is a list of allowed tools the model may call.
AllowedTools []*schema.AllowedTool
}

// Option is the call option for ChatModel component.
Expand Down Expand Up @@ -81,10 +83,11 @@ func WithTools(tools []*schema.ToolInfo) Option {
}

// WithToolChoice is the option to set tool choice for the model.
func WithToolChoice(toolChoice schema.ToolChoice) Option {
func WithToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option {
return Option{
apply: func(opts *Options) {
opts.ToolChoice = &toolChoice
opts.AllowedTools = allowedTools
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion components/agentic/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestCommon(t *testing.T) {
WithTools([]*schema.ToolInfo{{Name: "test"}}),
WithModel("test"),
WithTemperature(0.1),
WithToolChoice(schema.ToolChoiceAllowed),
WithToolChoice(schema.ToolChoiceAllowed, []*schema.AllowedTool{{FunctionToolName: "test"}}...),
WithTopP(0.1),
)
assert.Len(t, o.Tools, 1)
Expand Down
6 changes: 3 additions & 3 deletions components/model/callback_extra.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ type TokenUsage struct {
PromptTokenDetails PromptTokenDetails
// CompletionTokens is the number of completion tokens.
CompletionTokens int
// CompletionTokensDetails is a breakdown of the completion tokens.
CompletionTokensDetails CompletionTokensDetails
// TotalTokens is the total number of tokens.
TotalTokens int
// CompletionTokensDetails is breakdown of completion tokens.
CompletionTokensDetails CompletionTokensDetails `json:"completion_token_details"`
}

type CompletionTokensDetails struct {
// ReasoningTokens tokens generated by the model for reasoning.
// This is currently supported by OpenAI, Gemini, ARK and Qwen chat models.
// For other models, this field will be 0.
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
ReasoningTokens int
}

type PromptTokenDetails struct {
Expand Down
7 changes: 4 additions & 3 deletions compose/tools_node_agentic.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func agenticMessageToToolCallMessage(input *schema.AgenticMessage) *schema.Messa
Name: block.FunctionToolCall.Name,
Arguments: block.FunctionToolCall.Arguments,
},
Extra: block.Extra,
})
}
return &schema.Message{
Expand All @@ -87,8 +88,8 @@ func toolMessageToAgenticMessage(input []*schema.Message) []*schema.AgenticMessa
CallID: m.ToolCallID,
Name: m.ToolName,
Result: m.Content,
Extra: m.Extra,
},
Extra: m.Extra,
})
}
return []*schema.AgenticMessage{{
Expand All @@ -110,9 +111,9 @@ func streamToolMessageToAgenticMessage(input *schema.StreamReader[[]*schema.Mess
CallID: m.ToolCallID,
Name: m.ToolName,
Result: m.Content,
Extra: m.Extra,
},
StreamMeta: &schema.StreamMeta{Index: int64(i)},
StreamingMeta: &schema.StreamingMeta{Index: i},
Extra: m.Extra,
})
}
return []*schema.AgenticMessage{{
Expand Down
37 changes: 16 additions & 21 deletions compose/tools_node_agentic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"io"
"testing"

"github.com/bytedance/sonic"
"github.com/stretchr/testify/assert"

"github.com/cloudwego/eino/schema"
Expand Down Expand Up @@ -155,13 +156,14 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) {
nil,
},
{
nil,
{
Role: schema.Tool,
Content: "content1-2",
Content: "content2-2",
ToolName: "name2",
ToolCallID: "2",
},
nil, nil,
nil,
},
{
nil, nil,
Expand All @@ -172,16 +174,6 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) {
ToolCallID: "3",
},
},
{
nil,
{
Role: schema.Tool,
Content: "content2-2",
ToolName: "name2",
ToolCallID: "2",
},
nil,
},
{
nil, nil,
{
Expand All @@ -204,7 +196,11 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) {
}
result, err := schema.ConcatAgenticMessagesArray(chunks)
assert.NoError(t, err)
assert.Equal(t, []*schema.AgenticMessage{

actualStr, err := sonic.MarshalString(result)
assert.NoError(t, err)

expected := []*schema.AgenticMessage{
{
Role: schema.AgenticRoleTypeUser,
ContentBlocks: []*schema.ContentBlock{
Expand All @@ -213,32 +209,31 @@ func TestStreamToolMessageToAgenticMessage(t *testing.T) {
FunctionToolResult: &schema.FunctionToolResult{
CallID: "1",
Name: "name1",
Result: "content1-1content1-2",
Extra: map[string]interface{}{},
Result: "content1-1",
},
StreamMeta: &schema.StreamMeta{Index: 0},
},
{
Type: schema.ContentBlockTypeFunctionToolResult,
FunctionToolResult: &schema.FunctionToolResult{
CallID: "2",
Name: "name2",
Result: "content2-1content2-2",
Extra: map[string]interface{}{},
},
StreamMeta: &schema.StreamMeta{Index: 1},
},
{
Type: schema.ContentBlockTypeFunctionToolResult,
FunctionToolResult: &schema.FunctionToolResult{
CallID: "3",
Name: "name3",
Result: "content3-1content3-2",
Extra: map[string]interface{}{},
},
StreamMeta: &schema.StreamMeta{Index: 2},
},
},
},
}, result)
}

expectedStr, err := sonic.MarshalString(expected)
assert.NoError(t, err)

assert.Equal(t, expectedStr, actualStr)
}
Loading