Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions pkg/model/provider/rulebased/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ type ProviderFactory func(ctx context.Context, modelSpec string, models map[stri
// Client implements the Provider interface for rule-based model routing.
type Client struct {
base.Config
routes []Provider
fallback Provider
index bleve.Index
routes []Provider
fallback Provider
index bleve.Index
lastSelectedID string // ID of the provider selected by the most recent call
}

// NewClient creates a new rule-based routing client.
Expand Down Expand Up @@ -152,6 +153,7 @@ func filterOutMaxTokens(opts []options.Opt) []options.Opt {
}

// CreateChatCompletionStream selects a provider based on input and delegates the call.
// The selected provider's ID is recorded in LastSelectedModelID.
func (c *Client) CreateChatCompletionStream(
ctx context.Context,
messages []chat.Message,
Expand All @@ -162,15 +164,23 @@ func (c *Client) CreateChatCompletionStream(
return nil, errors.New("no provider available for routing")
}

c.lastSelectedID = provider.ID()
slog.Debug("Rule-based router selected model",
"router", c.ID(),
"selected_model", provider.ID(),
"selected_model", c.lastSelectedID,
"message_count", len(messages),
)

return provider.CreateChatCompletionStream(ctx, messages, availableTools)
}

// LastSelectedModelID returns the ID of the provider selected by the most
// recent CreateChatCompletionStream call. This allows callers to display
// the YAML-configured sub-model name for rule-based routing.
func (c *Client) LastSelectedModelID() string {
return c.lastSelectedID
}

// selectProvider finds the best matching provider for the messages.
// Bleve returns hits sorted by score, so the top hit determines the route.
func (c *Client) selectProvider(messages []chat.Message) Provider {
Expand Down
9 changes: 9 additions & 0 deletions pkg/runtime/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,15 @@ func (r *LocalRuntime) tryModelWithFallback(

// Stream created successfully, now handle it
slog.Debug("Processing stream", "agent", a.Name(), "model", modelEntry.provider.ID())

// If the provider is a rule-based router, notify the sidebar
// of the selected sub-model's YAML-configured name.
if rp, ok := modelEntry.provider.(interface{ LastSelectedModelID() string }); ok {
if selected := rp.LastSelectedModelID(); selected != "" {
events <- AgentInfo(a.Name(), selected, a.Description(), a.WelcomeMessage())
}
}

res, err := r.handleStream(ctx, stream, a, agentTools, sess, m, events)
if err != nil {
lastErr = err
Expand Down
23 changes: 5 additions & 18 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package runtime

import (
"cmp"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -86,10 +85,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

a := r.resolveSessionAgent(sess)

// Emit agent information for sidebar display
// Use getEffectiveModelID to account for active fallback cooldowns
events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())

// Emit team information
events <- TeamInfo(r.agentDetailsFromTeam(), a.Name())

Expand Down Expand Up @@ -210,7 +205,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
))

model := a.Model()
defaultModelID := r.getEffectiveModelID(a)

// Per-tool model routing: use a cheaper model for this turn
// if the previous tool calls specified one, then reset.
Expand All @@ -236,10 +230,10 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

modelID := model.ID()

// Notify sidebar when this turn uses a different model (per-tool override).
if modelID != defaultModelID {
events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())
}
// Notify sidebar of the model for this turn. For rule-based
// routing, the actual routed model is emitted from within the
// stream once the first chunk arrives.
events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage())

slog.Debug("Using agent", "agent", a.Name(), "model", modelID)
slog.Debug("Getting model definition", "model_id", modelID)
Expand Down Expand Up @@ -311,16 +305,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
return
}

// Update sidebar model info to reflect what was actually used this turn.
// Fallback models are sticky (cooldown system persists them), so we only
// emit once. Per-tool model overrides are temporary (one turn), so we
// emit the override and then revert to the agent's default.
if usedModel != nil && usedModel.ID() != model.ID() {
slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID())
events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage())
} else if model.ID() != defaultModelID {
// Per-tool override was active: revert sidebar to the agent's default model.
events <- AgentInfo(a.Name(), defaultModelID, a.Description(), a.WelcomeMessage())
}
streamSpan.SetAttributes(
attribute.Int("tool.calls", len(res.Calls)),
Expand Down Expand Up @@ -410,7 +397,7 @@ func (r *LocalRuntime) recordAssistantMessage(
float64(res.Usage.CacheWriteTokens)*m.Cost.CacheWrite) / 1e6
}

messageModel := cmp.Or(res.ActualModel, modelID)
messageModel := modelID

assistantMessage := chat.Message{
Role: chat.MessageRoleAssistant,
Expand Down
31 changes: 15 additions & 16 deletions pkg/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,12 @@ func TestSimple(t *testing.T) {
require.Equal(t, chat.MessageRoleAssistant, msgAdded.Message.Message.Role)

expectedEvents := []Event{
AgentInfo("root", "test/mock-model", "", ""),
TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"),
ToolsetInfo(0, false, "root"),
UserMessage("Hi", sess.ID, nil, 0),
StreamStarted(sess.ID, "root"),
ToolsetInfo(0, false, "root"),
AgentInfo("root", "test/mock-model", "", ""),
AgentChoice("root", sess.ID, "Hello"),
MessageAdded(sess.ID, msgAdded.Message, "root"),
NewTokenUsageEvent(sess.ID, "root", &Usage{InputTokens: 3, OutputTokens: 2, ContextLength: 5, LastMessage: &MessageUsage{
Expand Down Expand Up @@ -315,12 +315,12 @@ func TestMultipleContentChunks(t *testing.T) {
require.NotNil(t, msgAdded.Message)

expectedEvents := []Event{
AgentInfo("root", "test/mock-model", "", ""),
TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"),
ToolsetInfo(0, false, "root"),
UserMessage("Please greet me", sess.ID, nil, 0),
StreamStarted(sess.ID, "root"),
ToolsetInfo(0, false, "root"),
AgentInfo("root", "test/mock-model", "", ""),
AgentChoice("root", sess.ID, "Hello "),
AgentChoice("root", sess.ID, "there, "),
AgentChoice("root", sess.ID, "how "),
Expand Down Expand Up @@ -356,12 +356,12 @@ func TestWithReasoning(t *testing.T) {
require.NotNil(t, msgAdded.Message)

expectedEvents := []Event{
AgentInfo("root", "test/mock-model", "", ""),
TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"),
ToolsetInfo(0, false, "root"),
UserMessage("Hi", sess.ID, nil, 0),
StreamStarted(sess.ID, "root"),
ToolsetInfo(0, false, "root"),
AgentInfo("root", "test/mock-model", "", ""),
AgentChoiceReasoning("root", sess.ID, "Let me think about this..."),
AgentChoiceReasoning("root", sess.ID, " I should respond politely."),
AgentChoice("root", sess.ID, "Hello, how can I help you?"),
Expand Down Expand Up @@ -396,12 +396,12 @@ func TestMixedContentAndReasoning(t *testing.T) {
require.NotNil(t, msgAdded.Message)

expectedEvents := []Event{
AgentInfo("root", "test/mock-model", "", ""),
TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"),
ToolsetInfo(0, false, "root"),
UserMessage("Hi there", sess.ID, nil, 0),
StreamStarted(sess.ID, "root"),
ToolsetInfo(0, false, "root"),
AgentInfo("root", "test/mock-model", "", ""),
AgentChoiceReasoning("root", sess.ID, "The user wants a greeting"),
AgentChoice("root", sess.ID, "Hello!"),
AgentChoiceReasoning("root", sess.ID, " I should be friendly"),
Expand Down Expand Up @@ -454,12 +454,12 @@ func TestErrorEvent(t *testing.T) {
}

require.Len(t, events, 8)
require.IsType(t, &AgentInfoEvent{}, events[0])
require.IsType(t, &TeamInfoEvent{}, events[1])
require.IsType(t, &ToolsetInfoEvent{}, events[2])
require.IsType(t, &UserMessageEvent{}, events[3])
require.IsType(t, &StreamStartedEvent{}, events[4])
require.IsType(t, &ToolsetInfoEvent{}, events[5])
require.IsType(t, &TeamInfoEvent{}, events[0])
require.IsType(t, &ToolsetInfoEvent{}, events[1])
require.IsType(t, &UserMessageEvent{}, events[2])
require.IsType(t, &StreamStartedEvent{}, events[3])
require.IsType(t, &ToolsetInfoEvent{}, events[4])
require.IsType(t, &AgentInfoEvent{}, events[5])
require.IsType(t, &ErrorEvent{}, events[6])
require.IsType(t, &StreamStoppedEvent{}, events[7])

Expand Down Expand Up @@ -493,12 +493,11 @@ func TestContextCancellation(t *testing.T) {
events = append(events, ev)
}

require.GreaterOrEqual(t, len(events), 5)
require.IsType(t, &AgentInfoEvent{}, events[0])
require.IsType(t, &TeamInfoEvent{}, events[1])
require.IsType(t, &ToolsetInfoEvent{}, events[2])
require.IsType(t, &UserMessageEvent{}, events[3])
require.IsType(t, &StreamStartedEvent{}, events[4])
require.GreaterOrEqual(t, len(events), 4)
require.IsType(t, &TeamInfoEvent{}, events[0])
require.IsType(t, &ToolsetInfoEvent{}, events[1])
require.IsType(t, &UserMessageEvent{}, events[2])
require.IsType(t, &StreamStartedEvent{}, events[3])
require.IsType(t, &StreamStoppedEvent{}, events[len(events)-1])
}

Expand Down
9 changes: 0 additions & 9 deletions pkg/runtime/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ type streamResult struct {
ThinkingSignature string
ThoughtSignature []byte
Stopped bool
ActualModel string
Usage *chat.Usage
RateLimit *chat.RateLimit
}
Expand All @@ -43,7 +42,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
var thinkingSignature string
var thoughtSignature []byte
var toolCalls []tools.ToolCall
var actualModel string
var messageUsage *chat.Usage
var messageRateLimit *chat.RateLimit

Expand Down Expand Up @@ -102,11 +100,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
thoughtSignature = choice.Delta.ThoughtSignature
}

// Capture the actual model from the stream response (useful for model routing)
if actualModel == "" && response.Model != "" {
actualModel = response.Model
}

if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength {
recordUsage()
return streamResult{
Expand All @@ -116,7 +109,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
ThinkingSignature: thinkingSignature,
ThoughtSignature: thoughtSignature,
Stopped: true,
ActualModel: actualModel,
Usage: messageUsage,
RateLimit: messageRateLimit,
}, nil
Expand Down Expand Up @@ -191,7 +183,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre
ThinkingSignature: thinkingSignature,
ThoughtSignature: thoughtSignature,
Stopped: stoppedDueToNoOutput,
ActualModel: actualModel,
Usage: messageUsage,
RateLimit: messageRateLimit,
}, nil
Expand Down
7 changes: 6 additions & 1 deletion pkg/tui/components/sidebar/sidebar.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,13 @@ func checkReasoningSupportCmd(ctx context.Context, modelID string) tea.Cmd {
}
}

// SetAgentInfo sets the current agent information and updates the model in availableAgents
// SetAgentInfo sets the current agent information and updates the model in availableAgents.
// It no-ops when the values are unchanged to avoid unnecessary cache invalidation and re-renders.
func (m *model) SetAgentInfo(agentName, modelID, description string) tea.Cmd {
if m.currentAgent == agentName && m.agentModel == modelID && m.agentDescription == description {
return nil
}

m.currentAgent = agentName
m.agentModel = modelID
m.agentDescription = description
Expand Down
Loading