From 13de8159b58a10aea9c1d32124639d63ad50b928 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 25 Mar 2026 09:59:56 +0100 Subject: [PATCH 1/2] Rework compaction Signed-off-by: David Gageot --- pkg/compaction/compaction.go | 80 +----------- pkg/compaction/compaction_test.go | 135 -------------------- pkg/runtime/loop.go | 7 +- pkg/runtime/runtime_test.go | 31 ----- pkg/runtime/session_compaction.go | 169 +++++++++++++++++++------ pkg/runtime/session_compaction_test.go | 123 ++++++++++++++++++ pkg/session/session.go | 6 + 7 files changed, 264 insertions(+), 287 deletions(-) create mode 100644 pkg/runtime/session_compaction_test.go diff --git a/pkg/compaction/compaction.go b/pkg/compaction/compaction.go index 23367ac9b..90a395f53 100644 --- a/pkg/compaction/compaction.go +++ b/pkg/compaction/compaction.go @@ -1,22 +1,7 @@ -// Package compaction provides conversation compaction (summarization) for -// chat sessions that approach their model's context window limit. -// -// It is designed as a standalone component that can be used independently of -// the runtime loop. The package exposes: -// -// - [BuildPrompt]: prepares a conversation for summarization by appending -// the compaction prompt and sanitizing message costs. -// - [ShouldCompact]: decides whether a session needs compaction based on -// token usage and context window limits. -// - [EstimateMessageTokens]: a fast heuristic for estimating the token -// count of a single chat message. -// - [HasConversationMessages]: checks whether a message list contains any -// non-system messages worth summarizing. package compaction import ( _ "embed" - "time" "github.com/docker/docker-agent/pkg/chat" ) @@ -26,7 +11,7 @@ var ( SystemPrompt string //go:embed prompts/compaction-user.txt - userPrompt string + UserPrompt string ) // contextThreshold is the fraction of the context window at which compaction @@ -34,60 +19,15 @@ var ( // context limit, compaction is recommended. const contextThreshold = 0.9 -// Result holds the outcome of a compaction operation. -type Result struct { - // Summary is the generated summary text. - Summary string - - // InputTokens is the token count reported by the summarization model, - // used as an approximation of the new context size after compaction. - InputTokens int64 - - // Cost is the cost of the summarization request in dollars. - Cost float64 -} - -// BuildPrompt prepares the messages for a summarization request. -// It clones the conversation (zeroing per-message costs so they aren't -// double-counted), then appends a user message containing the compaction -// prompt. If additionalPrompt is non-empty it is included as extra -// instructions. -// -// Callers should first check [HasConversationMessages] to avoid sending -// an empty conversation to the model. -func BuildPrompt(messages []chat.Message, additionalPrompt string) []chat.Message { - prompt := userPrompt - if additionalPrompt != "" { - prompt += "\n\nAdditional instructions from user: " + additionalPrompt - } - - out := make([]chat.Message, len(messages), len(messages)+1) - for i, msg := range messages { - cloned := msg - cloned.Cost = 0 - cloned.CacheControl = false - out[i] = cloned - } - out = append(out, chat.Message{ - Role: chat.MessageRoleUser, - Content: prompt, - CreatedAt: time.Now().Format(time.RFC3339), - }) - - return out -} - // ShouldCompact reports whether a session's context usage has crossed the -// compaction threshold. It returns true when the estimated total token count +// compaction threshold. It returns true when the total token count // (input + output + addedTokens) exceeds [contextThreshold] (90%) of -// contextLimit. A non-positive contextLimit is treated as unlimited and -// always returns false. +// contextLimit. func ShouldCompact(inputTokens, outputTokens, addedTokens, contextLimit int64) bool { if contextLimit <= 0 { return false } - estimated := inputTokens + outputTokens + addedTokens - return estimated > int64(float64(contextLimit)*contextThreshold) + return (inputTokens + outputTokens + addedTokens) > int64(float64(contextLimit)*contextThreshold) } // EstimateMessageTokens returns a rough token-count estimate for a single @@ -121,15 +61,3 @@ func EstimateMessageTokens(msg *chat.Message) int64 { } return int64(chars/charsPerToken) + perMessageOverhead } - -// HasConversationMessages reports whether messages contains at least one -// non-system message. A session with only system prompts has no conversation -// to summarize. -func HasConversationMessages(messages []chat.Message) bool { - for _, msg := range messages { - if msg.Role != chat.MessageRoleSystem { - return true - } - } - return false -} diff --git a/pkg/compaction/compaction_test.go b/pkg/compaction/compaction_test.go index 7dde3b155..490c730c1 100644 --- a/pkg/compaction/compaction_test.go +++ b/pkg/compaction/compaction_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/tools" @@ -164,137 +163,3 @@ func TestShouldCompact(t *testing.T) { }) } } - -func TestHasConversationMessages(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - messages []chat.Message - want bool - }{ - { - name: "empty", - messages: nil, - want: false, - }, - { - name: "system only", - messages: []chat.Message{ - {Role: chat.MessageRoleSystem, Content: "You are helpful."}, - }, - want: false, - }, - { - name: "system and user", - messages: []chat.Message{ - {Role: chat.MessageRoleSystem, Content: "You are helpful."}, - {Role: chat.MessageRoleUser, Content: "Hello"}, - }, - want: true, - }, - { - name: "only user", - messages: []chat.Message{ - {Role: chat.MessageRoleUser, Content: "Hello"}, - }, - want: true, - }, - { - name: "assistant message", - messages: []chat.Message{ - {Role: chat.MessageRoleAssistant, Content: "Hi there"}, - }, - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := HasConversationMessages(tt.messages) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestBuildPrompt(t *testing.T) { - t.Parallel() - - messages := []chat.Message{ - {Role: chat.MessageRoleSystem, Content: "You are helpful."}, - {Role: chat.MessageRoleUser, Content: "Hello", Cost: 0.05}, - {Role: chat.MessageRoleAssistant, Content: "Hi!", Cost: 0.10}, - } - - t.Run("basic", func(t *testing.T) { - t.Parallel() - - out := BuildPrompt(messages, "") - - // Original messages + appended summarization prompt. - require.Len(t, out, 4) - - // Costs are zeroed. - for _, msg := range out[:3] { - assert.Zero(t, msg.Cost, "cost should be zeroed for %q", msg.Content) - } - - // Last message is the summarization prompt. - last := out[len(out)-1] - assert.Equal(t, chat.MessageRoleUser, last.Role) - assert.Contains(t, last.Content, "summary") - assert.NotEmpty(t, last.CreatedAt) - }) - - t.Run("with additional prompt", func(t *testing.T) { - t.Parallel() - - out := BuildPrompt(messages, "focus on code changes") - - last := out[len(out)-1] - assert.Contains(t, last.Content, "Additional instructions from user: focus on code changes") - }) - - t.Run("does not modify original messages", func(t *testing.T) { - t.Parallel() - - original := []chat.Message{ - {Role: chat.MessageRoleUser, Content: "Hello", Cost: 0.05}, - } - - _ = BuildPrompt(original, "") - - assert.InDelta(t, 0.05, original[0].Cost, 1e-9) - assert.Len(t, original, 1) - }) - - t.Run("strips CacheControl from cloned messages", func(t *testing.T) { - t.Parallel() - - input := []chat.Message{ - {Role: chat.MessageRoleSystem, Content: "system", CacheControl: true}, - {Role: chat.MessageRoleSystem, Content: "context", CacheControl: true}, - {Role: chat.MessageRoleUser, Content: "hello"}, - } - - out := BuildPrompt(input, "") - - // All cloned messages should have CacheControl=false - for i, msg := range out { - assert.False(t, msg.CacheControl, "message %d should have CacheControl stripped", i) - } - // Original should be unchanged - assert.True(t, input[0].CacheControl) - assert.True(t, input[1].CacheControl) - }) -} - -func TestPromptsAreEmbedded(t *testing.T) { - t.Parallel() - - assert.NotEmpty(t, SystemPrompt, "compaction system prompt should be embedded") - assert.NotEmpty(t, userPrompt, "compaction user prompt should be embedded") - assert.Contains(t, SystemPrompt, "summary") - assert.Contains(t, userPrompt, "summary") -} diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index fe5f11283..b7c194e75 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -248,13 +248,14 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c slog.Debug("Failed to get model definition", "error", err) } + // We can only compact if we know the limit. var contextLimit int64 if m != nil { contextLimit = int64(m.Limit.Context) - } - if m != nil && r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) { - r.Summarize(ctx, sess, "", events) + if r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) { + r.Summarize(ctx, sess, "", events) + } } messages := sess.GetMessages(a) diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 4bb92cb9e..247dc4bbd 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -737,37 +737,6 @@ func TestNewRuntime_InvalidCurrentAgentError(t *testing.T) { require.Contains(t, err.Error(), "agent not found: other (available agents: root)") } -func TestSummarize_EmptySession(t *testing.T) { - prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} - root := agent.New("root", "You are a test agent", agent.WithModel(prov)) - tm := team.New(team.WithAgents(root)) - - rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) - require.NoError(t, err) - - sess := session.New() - sess.Title = "Empty Session Test" - - // Try to summarize the empty session - events := make(chan Event, 10) - rt.Summarize(t.Context(), sess, "", events) - close(events) - - // Collect events - var warningFound bool - var warningMsg string - for ev := range events { - if warningEvent, ok := ev.(*WarningEvent); ok { - warningFound = true - warningMsg = warningEvent.Message - } - } - - // Should have received a warning event about empty session - require.True(t, warningFound, "expected a warning event for empty session") - require.Contains(t, warningMsg, "empty", "warning message should mention empty session") -} - func TestProcessToolCalls_UnknownTool_ReturnsErrorResponse(t *testing.T) { root := agent.New("root", "You are a test agent", agent.WithModel(&mockProvider{})) tm := team.New(team.WithAgents(root)) diff --git a/pkg/runtime/session_compaction.go b/pkg/runtime/session_compaction.go index 72db1e184..714038e4c 100644 --- a/pkg/runtime/session_compaction.go +++ b/pkg/runtime/session_compaction.go @@ -2,7 +2,9 @@ package runtime import ( "context" + "errors" "log/slog" + "time" "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" @@ -13,72 +15,155 @@ import ( "github.com/docker/docker-agent/pkg/team" ) -// runSummarization sends the prepared messages through a one-shot runtime -// and returns the model's summary together with the output token count and -// cost. The runtime is created with compaction disabled so it cannot recurse. -func runSummarization(ctx context.Context, model provider.Provider, messages []chat.Message) (compaction.Result, error) { - summaryModel := provider.CloneWithOptions(ctx, model, options.WithStructuredOutput(nil)) - root := agent.New("root", compaction.SystemPrompt, agent.WithModel(summaryModel)) - t := team.New(team.WithAgents(root)) - - sess := session.New() - sess.Title = "Generating summary..." - for _, msg := range messages { - sess.AddMessage(&session.Message{Message: msg}) - } - - rt, err := New(t, WithSessionCompaction(false)) - if err != nil { - return compaction.Result{}, err - } - if _, err = rt.Run(ctx, sess); err != nil { - return compaction.Result{}, err - } - - return compaction.Result{ - Summary: sess.GetLastAssistantMessageContent(), - InputTokens: sess.OutputTokens, - Cost: sess.TotalCost(), - }, nil -} +const maxSummaryTokens = 16_000 // doCompact runs compaction on a session and applies the result (events, // persistence, token count updates). The agent is used to extract the // conversation from the session and to obtain the model for summarization. func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a *agent.Agent, additionalPrompt string, events chan Event) { slog.Debug("Generating summary for session", "session_id", sess.ID) - events <- SessionCompaction(sess.ID, "started", a.Name()) defer func() { events <- SessionCompaction(sess.ID, "completed", a.Name()) }() - messages := sess.GetMessages(a) - if !compaction.HasConversationMessages(messages) { - if additionalPrompt == "" { - events <- Warning("Session is empty. Start a conversation before compacting.", a.Name()) - } + // Build a model just for compaction. + summaryModel := provider.CloneWithOptions(ctx, a.Model(), + options.WithStructuredOutput(nil), + options.WithMaxTokens(maxSummaryTokens), + ) + m, err := r.modelsStore.GetModel(ctx, summaryModel.ID()) + if err != nil { + slog.Error("Failed to generate session summary", "error", errors.New("failed to get model definition")) + events <- Error("Failed to get model definition") return } - prepared := compaction.BuildPrompt(messages, additionalPrompt) + compactionAgent := agent.New("root", compaction.SystemPrompt, agent.WithModel(summaryModel)) + + // Compute the messages to compact. + messages := extractMessagesToCompact(sess, compactionAgent, int64(m.Limit.Context), additionalPrompt) + + // Run the compaction. + compactionSession := session.New( + session.WithTitle("Generating summary"), + session.WithMessages(toItems(messages)), + ) - result, err := runSummarization(ctx, a.Model(), prepared) + t := team.New(team.WithAgents(compactionAgent)) + rt, err := New(t, WithSessionCompaction(false)) if err != nil { slog.Error("Failed to generate session summary", "error", err) events <- Error(err.Error()) return } - if result.Summary == "" { + if _, err = rt.Run(ctx, compactionSession); err != nil { + slog.Error("Failed to generate session summary", "error", err) + events <- Error(err.Error()) return } - sess.Messages = append(sess.Messages, session.Item{Summary: result.Summary, Cost: result.Cost}) - sess.InputTokens = result.InputTokens - sess.OutputTokens = 0 + summary := compactionSession.GetLastAssistantMessageContent() + if summary == "" { + return + } + // Update the session. + sess.InputTokens = compactionSession.OutputTokens + sess.OutputTokens = 0 + sess.Messages = append(sess.Messages, session.Item{ + Summary: summary, + Cost: compactionSession.TotalCost(), + }) _ = r.sessionStore.UpdateSession(ctx, sess) - slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(result.Summary), "compaction_cost", result.Cost) - events <- SessionSummary(sess.ID, result.Summary, a.Name()) + slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(summary)) + events <- SessionSummary(sess.ID, summary, a.Name()) +} + +func extractMessagesToCompact(sess *session.Session, compactionAgent *agent.Agent, contextLimit int64, additionalPrompt string) []chat.Message { + // Add all the existing messages. + var messages []chat.Message + for _, msg := range sess.GetMessages(compactionAgent) { + if msg.Role == chat.MessageRoleSystem { + continue + } + + msg.Cost = 0 + msg.CacheControl = false + + messages = append(messages, msg) + } + + // Prepare the first (system) message. + systemPromptMessage := chat.Message{ + Role: chat.MessageRoleSystem, + Content: compaction.SystemPrompt, + CreatedAt: time.Now().Format(time.RFC3339), + } + systemPromptMessageLen := compaction.EstimateMessageTokens(&systemPromptMessage) + + // Prepare the last (user) message. + userPrompt := compaction.UserPrompt + if additionalPrompt != "" { + userPrompt += "\n\n" + additionalPrompt + } + userPromptMessage := chat.Message{ + Role: chat.MessageRoleUser, + Content: userPrompt, + CreatedAt: time.Now().Format(time.RFC3339), + } + userPromptMessageLen := compaction.EstimateMessageTokens(&userPromptMessage) + + // Truncate the messages so that they fit in the available context limit + // (minus the expected max length of the summary). + contextAvailable := max(0, contextLimit-maxSummaryTokens-systemPromptMessageLen-userPromptMessageLen) + firstIndex := firstMessageToKeep(messages, contextAvailable) + if firstIndex < len(messages) { + messages = messages[firstIndex:] + } else { + messages = nil + } + + // Prepend the first (system) message. + messages = append([]chat.Message{systemPromptMessage}, messages...) + + // Append the last (user) message. + messages = append(messages, userPromptMessage) + + return messages +} + +func firstMessageToKeep(messages []chat.Message, contextLimit int64) int { + var tokens int64 + + lastValidMessageSeen := len(messages) + + for i := len(messages) - 1; i >= 0; i-- { + tokens += compaction.EstimateMessageTokens(&messages[i]) + if tokens > contextLimit { + return lastValidMessageSeen + } + + role := messages[i].Role + if role == chat.MessageRoleUser || role == chat.MessageRoleAssistant { + lastValidMessageSeen = i + } + } + + return lastValidMessageSeen +} + +func toItems(messages []chat.Message) []session.Item { + var items []session.Item + + for _, message := range messages { + items = append(items, session.Item{ + Message: &session.Message{ + Message: message, + }, + }) + } + + return items } diff --git a/pkg/runtime/session_compaction_test.go b/pkg/runtime/session_compaction_test.go new file mode 100644 index 000000000..57d3c73f3 --- /dev/null +++ b/pkg/runtime/session_compaction_test.go @@ -0,0 +1,123 @@ +package runtime + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/compaction" + "github.com/docker/docker-agent/pkg/session" +) + +func TestExtractMessagesToCompact(t *testing.T) { + newMsg := func(role chat.MessageRole, content string) session.Item { + return session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: role, Content: content}, + }) + } + + tests := []struct { + name string + messages []session.Item + contextLimit int64 + additionalPrompt string + wantConversationMsgCount int + }{ + { + name: "empty session returns system and user prompt only", + messages: nil, + contextLimit: 100_000, + wantConversationMsgCount: 0, + }, + { + name: "system messages are filtered out", + messages: []session.Item{ + newMsg(chat.MessageRoleSystem, "system instruction"), + newMsg(chat.MessageRoleUser, "hello"), + newMsg(chat.MessageRoleAssistant, "hi"), + }, + contextLimit: 100_000, + wantConversationMsgCount: 2, + }, + { + name: "messages fit within context limit", + messages: []session.Item{ + newMsg(chat.MessageRoleUser, "msg1"), + newMsg(chat.MessageRoleAssistant, "msg2"), + newMsg(chat.MessageRoleUser, "msg3"), + newMsg(chat.MessageRoleAssistant, "msg4"), + }, + contextLimit: 100_000, + wantConversationMsgCount: 4, + }, + { + name: "truncation when context limit is very small", + messages: []session.Item{ + newMsg(chat.MessageRoleUser, "first message with lots of content that takes tokens"), + newMsg(chat.MessageRoleAssistant, "first response with lots of content that takes tokens"), + newMsg(chat.MessageRoleUser, "second message"), + newMsg(chat.MessageRoleAssistant, "second response"), + }, + // Set context limit so small that after subtracting maxSummaryTokens + prompt overhead, + // not all messages fit. + contextLimit: maxSummaryTokens + 50, + wantConversationMsgCount: 0, + }, + { + name: "additional prompt is appended", + messages: []session.Item{ + newMsg(chat.MessageRoleUser, "hello"), + }, + contextLimit: 100_000, + additionalPrompt: "focus on code quality", + wantConversationMsgCount: 1, + }, + { + name: "cost and cache control are cleared", + messages: []session.Item{ + session.NewMessageItem(&session.Message{ + Message: chat.Message{ + Role: chat.MessageRoleUser, + Content: "hello", + Cost: 1.5, + CacheControl: true, + }, + }), + }, + contextLimit: 100_000, + wantConversationMsgCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sess := session.New(session.WithMessages(tt.messages)) + + a := agent.New("test", "test prompt") + result := extractMessagesToCompact(sess, a, tt.contextLimit, tt.additionalPrompt) + + assert.GreaterOrEqual(t, len(result), tt.wantConversationMsgCount+2) + assert.Equal(t, chat.MessageRoleSystem, result[0].Role) + assert.Equal(t, compaction.SystemPrompt, result[0].Content) + + last := result[len(result)-1] + assert.Equal(t, chat.MessageRoleUser, last.Role) + expectedPrompt := compaction.UserPrompt + if tt.additionalPrompt != "" { + expectedPrompt += "\n\n" + tt.additionalPrompt + } + assert.Equal(t, expectedPrompt, last.Content) + + // Conversation messages are all except first (system) and last (user prompt) + assert.Equal(t, tt.wantConversationMsgCount, len(result)-2) + + // Verify cost and cache control are cleared on conversation messages + for i := 1; i < len(result)-1; i++ { + assert.Zero(t, result[i].Cost) + assert.False(t, result[i].CacheControl) + } + }) + } +} diff --git a/pkg/session/session.go b/pkg/session/session.go index b9740d4ac..89b0324e4 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -470,6 +470,12 @@ func WithTitle(title string) Opt { } } +func WithMessages(messages []Item) Opt { + return func(s *Session) { + s.Messages = messages + } +} + func WithToolsApproved(toolsApproved bool) Opt { return func(s *Session) { s.ToolsApproved = toolsApproved From 644680cc25a5fcf7706337ec4e6f566e31d2130d Mon Sep 17 00:00:00 2001 From: David Gageot Date: Wed, 25 Mar 2026 13:35:11 +0100 Subject: [PATCH 2/2] Prevent infinite compaction loop on repeated ContextOverflowError Add a retry counter (maxOverflowCompactions=1) to the auto-compaction path in the runtime loop. When every model call returns a ContextOverflowError, compaction is now attempted at most once before the error is surfaced to the user. The counter resets after each successful model call so future overflows can still trigger compaction. Add TestCompactionOverflowDoesNotLoop to verify the guard. Assisted-By: docker-agent --- pkg/runtime/loop.go | 17 ++++++++++++- pkg/runtime/runtime_test.go | 50 +++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index b7c194e75..3005022a4 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -128,6 +128,13 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c } loopDetector := newToolLoopDetector(loopThreshold) + // overflowCompactions counts how many consecutive context-overflow + // auto-compactions have been attempted without a successful model + // call in between. This prevents an infinite loop when compaction + // cannot reduce the context below the model's limit. + const maxOverflowCompactions = 1 + var overflowCompactions int + // toolModelOverride holds the per-toolset model from the most recent // tool calls. It applies for one LLM turn, then resets. var toolModelOverride string @@ -281,13 +288,18 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c // Auto-recovery: if the error is a context overflow and // session compaction is enabled, compact the conversation // and retry the request instead of surfacing raw errors. - if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction { + // We allow at most maxOverflowCompactions consecutive attempts + // to avoid an infinite loop when compaction cannot reduce + // the context enough. + if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction && overflowCompactions < maxOverflowCompactions { + overflowCompactions++ slog.Warn("Context window overflow detected, attempting auto-compaction", "agent", a.Name(), "session_id", sess.ID, "input_tokens", sess.InputTokens, "output_tokens", sess.OutputTokens, "context_limit", contextLimit, + "attempt", overflowCompactions, ) events <- Warning( "The conversation has exceeded the model's context window. Automatically compacting the conversation history...", @@ -314,6 +326,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c return } + // A successful model call resets the overflow compaction counter. + overflowCompactions = 0 + if usedModel != nil && usedModel.ID() != model.ID() { slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID()) events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage()) diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 247dc4bbd..5b6d25caf 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -16,6 +16,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelerrors" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/permissions" "github.com/docker/docker-agent/pkg/session" @@ -631,6 +632,55 @@ func TestCompaction(t *testing.T) { require.NotEqual(t, -1, compactionStartIdx, "expected a SessionCompaction start event") } +// errorProvider always returns the configured error from CreateChatCompletionStream. +type errorProvider struct { + id string + err error +} + +func (p *errorProvider) ID() string { return p.id } + +func (p *errorProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { + return nil, p.err +} + +func (p *errorProvider) BaseConfig() base.Config { return base.Config{} } + +func (p *errorProvider) MaxTokens() int { return 0 } + +func TestCompactionOverflowDoesNotLoop(t *testing.T) { + // The model always returns a ContextOverflowError. Without the + // max-retry guard this would loop forever because compaction + // cannot fix the problem. + overflowErr := modelerrors.NewContextOverflowError(errors.New("prompt is too long")) + prov := &errorProvider{id: "test/overflow-model", err: overflowErr} + + root := agent.New("root", "You are a test agent", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithSessionCompaction(true), WithModelStore(mockModelStoreWithLimit{limit: 100})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("Hello")) + events := rt.RunStream(t.Context(), sess) + + var compactionCount int + var sawError bool + for ev := range events { + if e, ok := ev.(*SessionCompactionEvent); ok && e.Status == "started" { + compactionCount++ + } + if _, ok := ev.(*ErrorEvent); ok { + sawError = true + } + } + + // Compaction should have been attempted at most once, then the loop + // must give up and surface an error instead of retrying indefinitely. + require.LessOrEqual(t, compactionCount, 1, "expected at most 1 compaction attempt, got %d", compactionCount) + require.True(t, sawError, "expected an ErrorEvent after exhausting compaction retries") +} + func TestSessionWithoutUserMessage(t *testing.T) { stream := newStreamBuilder().AddContent("OK").AddStopWithUsage(1, 1).Build()