diff --git a/pkg/compaction/compaction.go b/pkg/compaction/compaction.go index 23367ac9b..90a395f53 100644 --- a/pkg/compaction/compaction.go +++ b/pkg/compaction/compaction.go @@ -1,22 +1,7 @@ -// Package compaction provides conversation compaction (summarization) for -// chat sessions that approach their model's context window limit. -// -// It is designed as a standalone component that can be used independently of -// the runtime loop. The package exposes: -// -// - [BuildPrompt]: prepares a conversation for summarization by appending -// the compaction prompt and sanitizing message costs. -// - [ShouldCompact]: decides whether a session needs compaction based on -// token usage and context window limits. -// - [EstimateMessageTokens]: a fast heuristic for estimating the token -// count of a single chat message. -// - [HasConversationMessages]: checks whether a message list contains any -// non-system messages worth summarizing. package compaction import ( _ "embed" - "time" "github.com/docker/docker-agent/pkg/chat" ) @@ -26,7 +11,7 @@ var ( SystemPrompt string //go:embed prompts/compaction-user.txt - userPrompt string + UserPrompt string ) // contextThreshold is the fraction of the context window at which compaction @@ -34,60 +19,15 @@ var ( // context limit, compaction is recommended. const contextThreshold = 0.9 -// Result holds the outcome of a compaction operation. -type Result struct { - // Summary is the generated summary text. - Summary string - - // InputTokens is the token count reported by the summarization model, - // used as an approximation of the new context size after compaction. - InputTokens int64 - - // Cost is the cost of the summarization request in dollars. - Cost float64 -} - -// BuildPrompt prepares the messages for a summarization request. -// It clones the conversation (zeroing per-message costs so they aren't -// double-counted), then appends a user message containing the compaction -// prompt. If additionalPrompt is non-empty it is included as extra -// instructions. -// -// Callers should first check [HasConversationMessages] to avoid sending -// an empty conversation to the model. -func BuildPrompt(messages []chat.Message, additionalPrompt string) []chat.Message { - prompt := userPrompt - if additionalPrompt != "" { - prompt += "\n\nAdditional instructions from user: " + additionalPrompt - } - - out := make([]chat.Message, len(messages), len(messages)+1) - for i, msg := range messages { - cloned := msg - cloned.Cost = 0 - cloned.CacheControl = false - out[i] = cloned - } - out = append(out, chat.Message{ - Role: chat.MessageRoleUser, - Content: prompt, - CreatedAt: time.Now().Format(time.RFC3339), - }) - - return out -} - // ShouldCompact reports whether a session's context usage has crossed the -// compaction threshold. It returns true when the estimated total token count +// compaction threshold. It returns true when the total token count // (input + output + addedTokens) exceeds [contextThreshold] (90%) of -// contextLimit. A non-positive contextLimit is treated as unlimited and -// always returns false. +// contextLimit. func ShouldCompact(inputTokens, outputTokens, addedTokens, contextLimit int64) bool { if contextLimit <= 0 { return false } - estimated := inputTokens + outputTokens + addedTokens - return estimated > int64(float64(contextLimit)*contextThreshold) + return (inputTokens + outputTokens + addedTokens) > int64(float64(contextLimit)*contextThreshold) } // EstimateMessageTokens returns a rough token-count estimate for a single @@ -121,15 +61,3 @@ func EstimateMessageTokens(msg *chat.Message) int64 { } return int64(chars/charsPerToken) + perMessageOverhead } - -// HasConversationMessages reports whether messages contains at least one -// non-system message. A session with only system prompts has no conversation -// to summarize. -func HasConversationMessages(messages []chat.Message) bool { - for _, msg := range messages { - if msg.Role != chat.MessageRoleSystem { - return true - } - } - return false -} diff --git a/pkg/compaction/compaction_test.go b/pkg/compaction/compaction_test.go index 7dde3b155..490c730c1 100644 --- a/pkg/compaction/compaction_test.go +++ b/pkg/compaction/compaction_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/tools" @@ -164,137 +163,3 @@ func TestShouldCompact(t *testing.T) { }) } } - -func TestHasConversationMessages(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - messages []chat.Message - want bool - }{ - { - name: "empty", - messages: nil, - want: false, - }, - { - name: "system only", - messages: []chat.Message{ - {Role: chat.MessageRoleSystem, Content: "You are helpful."}, - }, - want: false, - }, - { - name: "system and user", - messages: []chat.Message{ - {Role: chat.MessageRoleSystem, Content: "You are helpful."}, - {Role: chat.MessageRoleUser, Content: "Hello"}, - }, - want: true, - }, - { - name: "only user", - messages: []chat.Message{ - {Role: chat.MessageRoleUser, Content: "Hello"}, - }, - want: true, - }, - { - name: "assistant message", - messages: []chat.Message{ - {Role: chat.MessageRoleAssistant, Content: "Hi there"}, - }, - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := HasConversationMessages(tt.messages) - assert.Equal(t, tt.want, got) - }) - } -} - -func TestBuildPrompt(t *testing.T) { - t.Parallel() - - messages := []chat.Message{ - {Role: chat.MessageRoleSystem, Content: "You are helpful."}, - {Role: chat.MessageRoleUser, Content: "Hello", Cost: 0.05}, - {Role: chat.MessageRoleAssistant, Content: "Hi!", Cost: 0.10}, - } - - t.Run("basic", func(t *testing.T) { - t.Parallel() - - out := BuildPrompt(messages, "") - - // Original messages + appended summarization prompt. - require.Len(t, out, 4) - - // Costs are zeroed. - for _, msg := range out[:3] { - assert.Zero(t, msg.Cost, "cost should be zeroed for %q", msg.Content) - } - - // Last message is the summarization prompt. - last := out[len(out)-1] - assert.Equal(t, chat.MessageRoleUser, last.Role) - assert.Contains(t, last.Content, "summary") - assert.NotEmpty(t, last.CreatedAt) - }) - - t.Run("with additional prompt", func(t *testing.T) { - t.Parallel() - - out := BuildPrompt(messages, "focus on code changes") - - last := out[len(out)-1] - assert.Contains(t, last.Content, "Additional instructions from user: focus on code changes") - }) - - t.Run("does not modify original messages", func(t *testing.T) { - t.Parallel() - - original := []chat.Message{ - {Role: chat.MessageRoleUser, Content: "Hello", Cost: 0.05}, - } - - _ = BuildPrompt(original, "") - - assert.InDelta(t, 0.05, original[0].Cost, 1e-9) - assert.Len(t, original, 1) - }) - - t.Run("strips CacheControl from cloned messages", func(t *testing.T) { - t.Parallel() - - input := []chat.Message{ - {Role: chat.MessageRoleSystem, Content: "system", CacheControl: true}, - {Role: chat.MessageRoleSystem, Content: "context", CacheControl: true}, - {Role: chat.MessageRoleUser, Content: "hello"}, - } - - out := BuildPrompt(input, "") - - // All cloned messages should have CacheControl=false - for i, msg := range out { - assert.False(t, msg.CacheControl, "message %d should have CacheControl stripped", i) - } - // Original should be unchanged - assert.True(t, input[0].CacheControl) - assert.True(t, input[1].CacheControl) - }) -} - -func TestPromptsAreEmbedded(t *testing.T) { - t.Parallel() - - assert.NotEmpty(t, SystemPrompt, "compaction system prompt should be embedded") - assert.NotEmpty(t, userPrompt, "compaction user prompt should be embedded") - assert.Contains(t, SystemPrompt, "summary") - assert.Contains(t, userPrompt, "summary") -} diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index fe5f11283..3005022a4 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -128,6 +128,13 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c } loopDetector := newToolLoopDetector(loopThreshold) + // overflowCompactions counts how many consecutive context-overflow + // auto-compactions have been attempted without a successful model + // call in between. This prevents an infinite loop when compaction + // cannot reduce the context below the model's limit. + const maxOverflowCompactions = 1 + var overflowCompactions int + // toolModelOverride holds the per-toolset model from the most recent // tool calls. It applies for one LLM turn, then resets. var toolModelOverride string @@ -248,13 +255,14 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c slog.Debug("Failed to get model definition", "error", err) } + // We can only compact if we know the limit. var contextLimit int64 if m != nil { contextLimit = int64(m.Limit.Context) - } - if m != nil && r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) { - r.Summarize(ctx, sess, "", events) + if r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) { + r.Summarize(ctx, sess, "", events) + } } messages := sess.GetMessages(a) @@ -280,13 +288,18 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c // Auto-recovery: if the error is a context overflow and // session compaction is enabled, compact the conversation // and retry the request instead of surfacing raw errors. - if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction { + // We allow at most maxOverflowCompactions consecutive attempts + // to avoid an infinite loop when compaction cannot reduce + // the context enough. + if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction && overflowCompactions < maxOverflowCompactions { + overflowCompactions++ slog.Warn("Context window overflow detected, attempting auto-compaction", "agent", a.Name(), "session_id", sess.ID, "input_tokens", sess.InputTokens, "output_tokens", sess.OutputTokens, "context_limit", contextLimit, + "attempt", overflowCompactions, ) events <- Warning( "The conversation has exceeded the model's context window. Automatically compacting the conversation history...", @@ -313,6 +326,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c return } + // A successful model call resets the overflow compaction counter. + overflowCompactions = 0 + if usedModel != nil && usedModel.ID() != model.ID() { slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID()) events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage()) diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index 4bb92cb9e..5b6d25caf 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -16,6 +16,7 @@ import ( "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/model/provider/base" + "github.com/docker/docker-agent/pkg/modelerrors" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/permissions" "github.com/docker/docker-agent/pkg/session" @@ -631,6 +632,55 @@ func TestCompaction(t *testing.T) { require.NotEqual(t, -1, compactionStartIdx, "expected a SessionCompaction start event") } +// errorProvider always returns the configured error from CreateChatCompletionStream. +type errorProvider struct { + id string + err error +} + +func (p *errorProvider) ID() string { return p.id } + +func (p *errorProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) { + return nil, p.err +} + +func (p *errorProvider) BaseConfig() base.Config { return base.Config{} } + +func (p *errorProvider) MaxTokens() int { return 0 } + +func TestCompactionOverflowDoesNotLoop(t *testing.T) { + // The model always returns a ContextOverflowError. Without the + // max-retry guard this would loop forever because compaction + // cannot fix the problem. + overflowErr := modelerrors.NewContextOverflowError(errors.New("prompt is too long")) + prov := &errorProvider{id: "test/overflow-model", err: overflowErr} + + root := agent.New("root", "You are a test agent", agent.WithModel(prov)) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithSessionCompaction(true), WithModelStore(mockModelStoreWithLimit{limit: 100})) + require.NoError(t, err) + + sess := session.New(session.WithUserMessage("Hello")) + events := rt.RunStream(t.Context(), sess) + + var compactionCount int + var sawError bool + for ev := range events { + if e, ok := ev.(*SessionCompactionEvent); ok && e.Status == "started" { + compactionCount++ + } + if _, ok := ev.(*ErrorEvent); ok { + sawError = true + } + } + + // Compaction should have been attempted at most once, then the loop + // must give up and surface an error instead of retrying indefinitely. + require.LessOrEqual(t, compactionCount, 1, "expected at most 1 compaction attempt, got %d", compactionCount) + require.True(t, sawError, "expected an ErrorEvent after exhausting compaction retries") +} + func TestSessionWithoutUserMessage(t *testing.T) { stream := newStreamBuilder().AddContent("OK").AddStopWithUsage(1, 1).Build() @@ -737,37 +787,6 @@ func TestNewRuntime_InvalidCurrentAgentError(t *testing.T) { require.Contains(t, err.Error(), "agent not found: other (available agents: root)") } -func TestSummarize_EmptySession(t *testing.T) { - prov := &mockProvider{id: "test/mock-model", stream: &mockStream{}} - root := agent.New("root", "You are a test agent", agent.WithModel(prov)) - tm := team.New(team.WithAgents(root)) - - rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) - require.NoError(t, err) - - sess := session.New() - sess.Title = "Empty Session Test" - - // Try to summarize the empty session - events := make(chan Event, 10) - rt.Summarize(t.Context(), sess, "", events) - close(events) - - // Collect events - var warningFound bool - var warningMsg string - for ev := range events { - if warningEvent, ok := ev.(*WarningEvent); ok { - warningFound = true - warningMsg = warningEvent.Message - } - } - - // Should have received a warning event about empty session - require.True(t, warningFound, "expected a warning event for empty session") - require.Contains(t, warningMsg, "empty", "warning message should mention empty session") -} - func TestProcessToolCalls_UnknownTool_ReturnsErrorResponse(t *testing.T) { root := agent.New("root", "You are a test agent", agent.WithModel(&mockProvider{})) tm := team.New(team.WithAgents(root)) diff --git a/pkg/runtime/session_compaction.go b/pkg/runtime/session_compaction.go index 72db1e184..714038e4c 100644 --- a/pkg/runtime/session_compaction.go +++ b/pkg/runtime/session_compaction.go @@ -2,7 +2,9 @@ package runtime import ( "context" + "errors" "log/slog" + "time" "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/chat" @@ -13,72 +15,155 @@ import ( "github.com/docker/docker-agent/pkg/team" ) -// runSummarization sends the prepared messages through a one-shot runtime -// and returns the model's summary together with the output token count and -// cost. The runtime is created with compaction disabled so it cannot recurse. -func runSummarization(ctx context.Context, model provider.Provider, messages []chat.Message) (compaction.Result, error) { - summaryModel := provider.CloneWithOptions(ctx, model, options.WithStructuredOutput(nil)) - root := agent.New("root", compaction.SystemPrompt, agent.WithModel(summaryModel)) - t := team.New(team.WithAgents(root)) - - sess := session.New() - sess.Title = "Generating summary..." - for _, msg := range messages { - sess.AddMessage(&session.Message{Message: msg}) - } - - rt, err := New(t, WithSessionCompaction(false)) - if err != nil { - return compaction.Result{}, err - } - if _, err = rt.Run(ctx, sess); err != nil { - return compaction.Result{}, err - } - - return compaction.Result{ - Summary: sess.GetLastAssistantMessageContent(), - InputTokens: sess.OutputTokens, - Cost: sess.TotalCost(), - }, nil -} +const maxSummaryTokens = 16_000 // doCompact runs compaction on a session and applies the result (events, // persistence, token count updates). The agent is used to extract the // conversation from the session and to obtain the model for summarization. func (r *LocalRuntime) doCompact(ctx context.Context, sess *session.Session, a *agent.Agent, additionalPrompt string, events chan Event) { slog.Debug("Generating summary for session", "session_id", sess.ID) - events <- SessionCompaction(sess.ID, "started", a.Name()) defer func() { events <- SessionCompaction(sess.ID, "completed", a.Name()) }() - messages := sess.GetMessages(a) - if !compaction.HasConversationMessages(messages) { - if additionalPrompt == "" { - events <- Warning("Session is empty. Start a conversation before compacting.", a.Name()) - } + // Build a model just for compaction. + summaryModel := provider.CloneWithOptions(ctx, a.Model(), + options.WithStructuredOutput(nil), + options.WithMaxTokens(maxSummaryTokens), + ) + m, err := r.modelsStore.GetModel(ctx, summaryModel.ID()) + if err != nil { + slog.Error("Failed to generate session summary", "error", errors.New("failed to get model definition")) + events <- Error("Failed to get model definition") return } - prepared := compaction.BuildPrompt(messages, additionalPrompt) + compactionAgent := agent.New("root", compaction.SystemPrompt, agent.WithModel(summaryModel)) + + // Compute the messages to compact. + messages := extractMessagesToCompact(sess, compactionAgent, int64(m.Limit.Context), additionalPrompt) + + // Run the compaction. + compactionSession := session.New( + session.WithTitle("Generating summary"), + session.WithMessages(toItems(messages)), + ) - result, err := runSummarization(ctx, a.Model(), prepared) + t := team.New(team.WithAgents(compactionAgent)) + rt, err := New(t, WithSessionCompaction(false)) if err != nil { slog.Error("Failed to generate session summary", "error", err) events <- Error(err.Error()) return } - if result.Summary == "" { + if _, err = rt.Run(ctx, compactionSession); err != nil { + slog.Error("Failed to generate session summary", "error", err) + events <- Error(err.Error()) return } - sess.Messages = append(sess.Messages, session.Item{Summary: result.Summary, Cost: result.Cost}) - sess.InputTokens = result.InputTokens - sess.OutputTokens = 0 + summary := compactionSession.GetLastAssistantMessageContent() + if summary == "" { + return + } + // Update the session. + sess.InputTokens = compactionSession.OutputTokens + sess.OutputTokens = 0 + sess.Messages = append(sess.Messages, session.Item{ + Summary: summary, + Cost: compactionSession.TotalCost(), + }) _ = r.sessionStore.UpdateSession(ctx, sess) - slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(result.Summary), "compaction_cost", result.Cost) - events <- SessionSummary(sess.ID, result.Summary, a.Name()) + slog.Debug("Generated session summary", "session_id", sess.ID, "summary_length", len(summary)) + events <- SessionSummary(sess.ID, summary, a.Name()) +} + +func extractMessagesToCompact(sess *session.Session, compactionAgent *agent.Agent, contextLimit int64, additionalPrompt string) []chat.Message { + // Add all the existing messages. + var messages []chat.Message + for _, msg := range sess.GetMessages(compactionAgent) { + if msg.Role == chat.MessageRoleSystem { + continue + } + + msg.Cost = 0 + msg.CacheControl = false + + messages = append(messages, msg) + } + + // Prepare the first (system) message. + systemPromptMessage := chat.Message{ + Role: chat.MessageRoleSystem, + Content: compaction.SystemPrompt, + CreatedAt: time.Now().Format(time.RFC3339), + } + systemPromptMessageLen := compaction.EstimateMessageTokens(&systemPromptMessage) + + // Prepare the last (user) message. + userPrompt := compaction.UserPrompt + if additionalPrompt != "" { + userPrompt += "\n\n" + additionalPrompt + } + userPromptMessage := chat.Message{ + Role: chat.MessageRoleUser, + Content: userPrompt, + CreatedAt: time.Now().Format(time.RFC3339), + } + userPromptMessageLen := compaction.EstimateMessageTokens(&userPromptMessage) + + // Truncate the messages so that they fit in the available context limit + // (minus the expected max length of the summary). + contextAvailable := max(0, contextLimit-maxSummaryTokens-systemPromptMessageLen-userPromptMessageLen) + firstIndex := firstMessageToKeep(messages, contextAvailable) + if firstIndex < len(messages) { + messages = messages[firstIndex:] + } else { + messages = nil + } + + // Prepend the first (system) message. + messages = append([]chat.Message{systemPromptMessage}, messages...) + + // Append the last (user) message. + messages = append(messages, userPromptMessage) + + return messages +} + +func firstMessageToKeep(messages []chat.Message, contextLimit int64) int { + var tokens int64 + + lastValidMessageSeen := len(messages) + + for i := len(messages) - 1; i >= 0; i-- { + tokens += compaction.EstimateMessageTokens(&messages[i]) + if tokens > contextLimit { + return lastValidMessageSeen + } + + role := messages[i].Role + if role == chat.MessageRoleUser || role == chat.MessageRoleAssistant { + lastValidMessageSeen = i + } + } + + return lastValidMessageSeen +} + +func toItems(messages []chat.Message) []session.Item { + var items []session.Item + + for _, message := range messages { + items = append(items, session.Item{ + Message: &session.Message{ + Message: message, + }, + }) + } + + return items } diff --git a/pkg/runtime/session_compaction_test.go b/pkg/runtime/session_compaction_test.go new file mode 100644 index 000000000..57d3c73f3 --- /dev/null +++ b/pkg/runtime/session_compaction_test.go @@ -0,0 +1,123 @@ +package runtime + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/compaction" + "github.com/docker/docker-agent/pkg/session" +) + +func TestExtractMessagesToCompact(t *testing.T) { + newMsg := func(role chat.MessageRole, content string) session.Item { + return session.NewMessageItem(&session.Message{ + Message: chat.Message{Role: role, Content: content}, + }) + } + + tests := []struct { + name string + messages []session.Item + contextLimit int64 + additionalPrompt string + wantConversationMsgCount int + }{ + { + name: "empty session returns system and user prompt only", + messages: nil, + contextLimit: 100_000, + wantConversationMsgCount: 0, + }, + { + name: "system messages are filtered out", + messages: []session.Item{ + newMsg(chat.MessageRoleSystem, "system instruction"), + newMsg(chat.MessageRoleUser, "hello"), + newMsg(chat.MessageRoleAssistant, "hi"), + }, + contextLimit: 100_000, + wantConversationMsgCount: 2, + }, + { + name: "messages fit within context limit", + messages: []session.Item{ + newMsg(chat.MessageRoleUser, "msg1"), + newMsg(chat.MessageRoleAssistant, "msg2"), + newMsg(chat.MessageRoleUser, "msg3"), + newMsg(chat.MessageRoleAssistant, "msg4"), + }, + contextLimit: 100_000, + wantConversationMsgCount: 4, + }, + { + name: "truncation when context limit is very small", + messages: []session.Item{ + newMsg(chat.MessageRoleUser, "first message with lots of content that takes tokens"), + newMsg(chat.MessageRoleAssistant, "first response with lots of content that takes tokens"), + newMsg(chat.MessageRoleUser, "second message"), + newMsg(chat.MessageRoleAssistant, "second response"), + }, + // Set context limit so small that after subtracting maxSummaryTokens + prompt overhead, + // not all messages fit. + contextLimit: maxSummaryTokens + 50, + wantConversationMsgCount: 0, + }, + { + name: "additional prompt is appended", + messages: []session.Item{ + newMsg(chat.MessageRoleUser, "hello"), + }, + contextLimit: 100_000, + additionalPrompt: "focus on code quality", + wantConversationMsgCount: 1, + }, + { + name: "cost and cache control are cleared", + messages: []session.Item{ + session.NewMessageItem(&session.Message{ + Message: chat.Message{ + Role: chat.MessageRoleUser, + Content: "hello", + Cost: 1.5, + CacheControl: true, + }, + }), + }, + contextLimit: 100_000, + wantConversationMsgCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sess := session.New(session.WithMessages(tt.messages)) + + a := agent.New("test", "test prompt") + result := extractMessagesToCompact(sess, a, tt.contextLimit, tt.additionalPrompt) + + assert.GreaterOrEqual(t, len(result), tt.wantConversationMsgCount+2) + assert.Equal(t, chat.MessageRoleSystem, result[0].Role) + assert.Equal(t, compaction.SystemPrompt, result[0].Content) + + last := result[len(result)-1] + assert.Equal(t, chat.MessageRoleUser, last.Role) + expectedPrompt := compaction.UserPrompt + if tt.additionalPrompt != "" { + expectedPrompt += "\n\n" + tt.additionalPrompt + } + assert.Equal(t, expectedPrompt, last.Content) + + // Conversation messages are all except first (system) and last (user prompt) + assert.Equal(t, tt.wantConversationMsgCount, len(result)-2) + + // Verify cost and cache control are cleared on conversation messages + for i := 1; i < len(result)-1; i++ { + assert.Zero(t, result[i].Cost) + assert.False(t, result[i].CacheControl) + } + }) + } +} diff --git a/pkg/session/session.go b/pkg/session/session.go index b9740d4ac..89b0324e4 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -470,6 +470,12 @@ func WithTitle(title string) Opt { } } +func WithMessages(messages []Item) Opt { + return func(s *Session) { + s.Messages = messages + } +} + func WithToolsApproved(toolsApproved bool) Opt { return func(s *Session) { s.ToolsApproved = toolsApproved