diff --git a/pkg/model/provider/rulebased/client.go b/pkg/model/provider/rulebased/client.go index ea972cc73..5894de2b5 100644 --- a/pkg/model/provider/rulebased/client.go +++ b/pkg/model/provider/rulebased/client.go @@ -41,9 +41,10 @@ type ProviderFactory func(ctx context.Context, modelSpec string, models map[stri // Client implements the Provider interface for rule-based model routing. type Client struct { base.Config - routes []Provider - fallback Provider - index bleve.Index + routes []Provider + fallback Provider + index bleve.Index + lastSelectedID string // ID of the provider selected by the most recent call } // NewClient creates a new rule-based routing client. @@ -152,6 +153,7 @@ func filterOutMaxTokens(opts []options.Opt) []options.Opt { } // CreateChatCompletionStream selects a provider based on input and delegates the call. +// The selected provider's ID is recorded in LastSelectedModelID. func (c *Client) CreateChatCompletionStream( ctx context.Context, messages []chat.Message, @@ -162,15 +164,23 @@ func (c *Client) CreateChatCompletionStream( return nil, errors.New("no provider available for routing") } + c.lastSelectedID = provider.ID() slog.Debug("Rule-based router selected model", "router", c.ID(), - "selected_model", provider.ID(), + "selected_model", c.lastSelectedID, "message_count", len(messages), ) return provider.CreateChatCompletionStream(ctx, messages, availableTools) } +// LastSelectedModelID returns the ID of the provider selected by the most +// recent CreateChatCompletionStream call. This allows callers to display +// the YAML-configured sub-model name for rule-based routing. +func (c *Client) LastSelectedModelID() string { + return c.lastSelectedID +} + // selectProvider finds the best matching provider for the messages. // Bleve returns hits sorted by score, so the top hit determines the route. func (c *Client) selectProvider(messages []chat.Message) Provider { diff --git a/pkg/runtime/fallback.go b/pkg/runtime/fallback.go index a9caa0456..024665363 100644 --- a/pkg/runtime/fallback.go +++ b/pkg/runtime/fallback.go @@ -283,6 +283,15 @@ func (r *LocalRuntime) tryModelWithFallback( // Stream created successfully, now handle it slog.Debug("Processing stream", "agent", a.Name(), "model", modelEntry.provider.ID()) + + // If the provider is a rule-based router, notify the sidebar + // of the selected sub-model's YAML-configured name. + if rp, ok := modelEntry.provider.(interface{ LastSelectedModelID() string }); ok { + if selected := rp.LastSelectedModelID(); selected != "" { + events <- AgentInfo(a.Name(), selected, a.Description(), a.WelcomeMessage()) + } + } + res, err := r.handleStream(ctx, stream, a, agentTools, sess, m, events) if err != nil { lastErr = err diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index 10a570ebf..a4fb05f49 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -1,7 +1,6 @@ package runtime import ( - "cmp" "context" "errors" "fmt" @@ -86,10 +85,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c a := r.resolveSessionAgent(sess) - // Emit agent information for sidebar display - // Use getEffectiveModelID to account for active fallback cooldowns - events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage()) - // Emit team information events <- TeamInfo(r.agentDetailsFromTeam(), a.Name()) @@ -210,7 +205,6 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c )) model := a.Model() - defaultModelID := r.getEffectiveModelID(a) // Per-tool model routing: use a cheaper model for this turn // if the previous tool calls specified one, then reset. @@ -236,10 +230,10 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c modelID := model.ID() - // Notify sidebar when this turn uses a different model (per-tool override). - if modelID != defaultModelID { - events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage()) - } + // Notify sidebar of the model for this turn. For rule-based + // routing, the actual routed model is emitted from within the + // stream once the first chunk arrives. + events <- AgentInfo(a.Name(), modelID, a.Description(), a.WelcomeMessage()) slog.Debug("Using agent", "agent", a.Name(), "model", modelID) slog.Debug("Getting model definition", "model_id", modelID) @@ -311,16 +305,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c return } - // Update sidebar model info to reflect what was actually used this turn. - // Fallback models are sticky (cooldown system persists them), so we only - // emit once. Per-tool model overrides are temporary (one turn), so we - // emit the override and then revert to the agent's default. if usedModel != nil && usedModel.ID() != model.ID() { slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID()) events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage()) - } else if model.ID() != defaultModelID { - // Per-tool override was active: revert sidebar to the agent's default model. - events <- AgentInfo(a.Name(), defaultModelID, a.Description(), a.WelcomeMessage()) } streamSpan.SetAttributes( attribute.Int("tool.calls", len(res.Calls)), @@ -410,7 +397,7 @@ func (r *LocalRuntime) recordAssistantMessage( float64(res.Usage.CacheWriteTokens)*m.Cost.CacheWrite) / 1e6 } - messageModel := cmp.Or(res.ActualModel, modelID) + messageModel := modelID assistantMessage := chat.Message{ Role: chat.MessageRoleAssistant, diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 916f9d3fa..216e66985 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -276,12 +276,12 @@ func TestSimple(t *testing.T) { require.Equal(t, chat.MessageRoleAssistant, msgAdded.Message.Message.Role) expectedEvents := []Event{ - AgentInfo("root", "test/mock-model", "", ""), TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"), ToolsetInfo(0, false, "root"), UserMessage("Hi", sess.ID, nil, 0), StreamStarted(sess.ID, "root"), ToolsetInfo(0, false, "root"), + AgentInfo("root", "test/mock-model", "", ""), AgentChoice("root", sess.ID, "Hello"), MessageAdded(sess.ID, msgAdded.Message, "root"), NewTokenUsageEvent(sess.ID, "root", &Usage{InputTokens: 3, OutputTokens: 2, ContextLength: 5, LastMessage: &MessageUsage{ @@ -315,12 +315,12 @@ func TestMultipleContentChunks(t *testing.T) { require.NotNil(t, msgAdded.Message) expectedEvents := []Event{ - AgentInfo("root", "test/mock-model", "", ""), TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"), ToolsetInfo(0, false, "root"), UserMessage("Please greet me", sess.ID, nil, 0), StreamStarted(sess.ID, "root"), ToolsetInfo(0, false, "root"), + AgentInfo("root", "test/mock-model", "", ""), AgentChoice("root", sess.ID, "Hello "), AgentChoice("root", sess.ID, "there, "), AgentChoice("root", sess.ID, "how "), @@ -356,12 +356,12 @@ func TestWithReasoning(t *testing.T) { require.NotNil(t, msgAdded.Message) expectedEvents := []Event{ - AgentInfo("root", "test/mock-model", "", ""), TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"), ToolsetInfo(0, false, "root"), UserMessage("Hi", sess.ID, nil, 0), StreamStarted(sess.ID, "root"), ToolsetInfo(0, false, "root"), + AgentInfo("root", "test/mock-model", "", ""), AgentChoiceReasoning("root", sess.ID, "Let me think about this..."), AgentChoiceReasoning("root", sess.ID, " I should respond politely."), AgentChoice("root", sess.ID, "Hello, how can I help you?"), @@ -396,12 +396,12 @@ func TestMixedContentAndReasoning(t *testing.T) { require.NotNil(t, msgAdded.Message) expectedEvents := []Event{ - AgentInfo("root", "test/mock-model", "", ""), TeamInfo([]AgentDetails{{Name: "root", Provider: "test", Model: "mock-model"}}, "root"), ToolsetInfo(0, false, "root"), UserMessage("Hi there", sess.ID, nil, 0), StreamStarted(sess.ID, "root"), ToolsetInfo(0, false, "root"), + AgentInfo("root", "test/mock-model", "", ""), AgentChoiceReasoning("root", sess.ID, "The user wants a greeting"), AgentChoice("root", sess.ID, "Hello!"), AgentChoiceReasoning("root", sess.ID, " I should be friendly"), @@ -454,12 +454,12 @@ func TestErrorEvent(t *testing.T) { } require.Len(t, events, 8) - require.IsType(t, &AgentInfoEvent{}, events[0]) - require.IsType(t, &TeamInfoEvent{}, events[1]) - require.IsType(t, &ToolsetInfoEvent{}, events[2]) - require.IsType(t, &UserMessageEvent{}, events[3]) - require.IsType(t, &StreamStartedEvent{}, events[4]) - require.IsType(t, &ToolsetInfoEvent{}, events[5]) + require.IsType(t, &TeamInfoEvent{}, events[0]) + require.IsType(t, &ToolsetInfoEvent{}, events[1]) + require.IsType(t, &UserMessageEvent{}, events[2]) + require.IsType(t, &StreamStartedEvent{}, events[3]) + require.IsType(t, &ToolsetInfoEvent{}, events[4]) + require.IsType(t, &AgentInfoEvent{}, events[5]) require.IsType(t, &ErrorEvent{}, events[6]) require.IsType(t, &StreamStoppedEvent{}, events[7]) @@ -493,12 +493,11 @@ func TestContextCancellation(t *testing.T) { events = append(events, ev) } - require.GreaterOrEqual(t, len(events), 5) - require.IsType(t, &AgentInfoEvent{}, events[0]) - require.IsType(t, &TeamInfoEvent{}, events[1]) - require.IsType(t, &ToolsetInfoEvent{}, events[2]) - require.IsType(t, &UserMessageEvent{}, events[3]) - require.IsType(t, &StreamStartedEvent{}, events[4]) + require.GreaterOrEqual(t, len(events), 4) + require.IsType(t, &TeamInfoEvent{}, events[0]) + require.IsType(t, &ToolsetInfoEvent{}, events[1]) + require.IsType(t, &UserMessageEvent{}, events[2]) + require.IsType(t, &StreamStartedEvent{}, events[3]) require.IsType(t, &StreamStoppedEvent{}, events[len(events)-1]) } diff --git a/pkg/runtime/streaming.go b/pkg/runtime/streaming.go index 62b132076..ceea95c18 100644 --- a/pkg/runtime/streaming.go +++ b/pkg/runtime/streaming.go @@ -26,7 +26,6 @@ type streamResult struct { ThinkingSignature string ThoughtSignature []byte Stopped bool - ActualModel string Usage *chat.Usage RateLimit *chat.RateLimit } @@ -43,7 +42,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre var thinkingSignature string var thoughtSignature []byte var toolCalls []tools.ToolCall - var actualModel string var messageUsage *chat.Usage var messageRateLimit *chat.RateLimit @@ -102,11 +100,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre thoughtSignature = choice.Delta.ThoughtSignature } - // Capture the actual model from the stream response (useful for model routing) - if actualModel == "" && response.Model != "" { - actualModel = response.Model - } - if choice.FinishReason == chat.FinishReasonStop || choice.FinishReason == chat.FinishReasonLength { recordUsage() return streamResult{ @@ -116,7 +109,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre ThinkingSignature: thinkingSignature, ThoughtSignature: thoughtSignature, Stopped: true, - ActualModel: actualModel, Usage: messageUsage, RateLimit: messageRateLimit, }, nil @@ -191,7 +183,6 @@ func (r *LocalRuntime) handleStream(ctx context.Context, stream chat.MessageStre ThinkingSignature: thinkingSignature, ThoughtSignature: thoughtSignature, Stopped: stoppedDueToNoOutput, - ActualModel: actualModel, Usage: messageUsage, RateLimit: messageRateLimit, }, nil diff --git a/pkg/tui/components/sidebar/sidebar.go b/pkg/tui/components/sidebar/sidebar.go index e3db64c81..e01327c14 100644 --- a/pkg/tui/components/sidebar/sidebar.go +++ b/pkg/tui/components/sidebar/sidebar.go @@ -268,8 +268,13 @@ func checkReasoningSupportCmd(ctx context.Context, modelID string) tea.Cmd { } } -// SetAgentInfo sets the current agent information and updates the model in availableAgents +// SetAgentInfo sets the current agent information and updates the model in availableAgents. +// It no-ops when the values are unchanged to avoid unnecessary cache invalidation and re-renders. func (m *model) SetAgentInfo(agentName, modelID, description string) tea.Cmd { + if m.currentAgent == agentName && m.agentModel == modelID && m.agentDescription == description { + return nil + } + m.currentAgent = agentName m.agentModel = modelID m.agentDescription = description