Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 4 additions & 76 deletions pkg/compaction/compaction.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
// Package compaction provides conversation compaction (summarization) for
// chat sessions that approach their model's context window limit.
//
// It is designed as a standalone component that can be used independently of
// the runtime loop. The package exposes:
//
// - [BuildPrompt]: prepares a conversation for summarization by appending
// the compaction prompt and sanitizing message costs.
// - [ShouldCompact]: decides whether a session needs compaction based on
// token usage and context window limits.
// - [EstimateMessageTokens]: a fast heuristic for estimating the token
// count of a single chat message.
// - [HasConversationMessages]: checks whether a message list contains any
// non-system messages worth summarizing.
package compaction

import (
_ "embed"
"time"

"github.com/docker/docker-agent/pkg/chat"
)
Expand All @@ -26,68 +11,23 @@ var (
SystemPrompt string

//go:embed prompts/compaction-user.txt
userPrompt string
UserPrompt string
)

// contextThreshold is the fraction of the context window at which compaction
// is triggered. When the estimated token usage exceeds this fraction of the
// context limit, compaction is recommended.
const contextThreshold = 0.9

// Result holds the outcome of a compaction operation.
type Result struct {
// Summary is the generated summary text.
Summary string

// InputTokens is the token count reported by the summarization model,
// used as an approximation of the new context size after compaction.
InputTokens int64

// Cost is the cost of the summarization request in dollars.
Cost float64
}

// BuildPrompt prepares the messages for a summarization request.
// It clones the conversation (zeroing per-message costs so they aren't
// double-counted), then appends a user message containing the compaction
// prompt. If additionalPrompt is non-empty it is included as extra
// instructions.
//
// Callers should first check [HasConversationMessages] to avoid sending
// an empty conversation to the model.
func BuildPrompt(messages []chat.Message, additionalPrompt string) []chat.Message {
prompt := userPrompt
if additionalPrompt != "" {
prompt += "\n\nAdditional instructions from user: " + additionalPrompt
}

out := make([]chat.Message, len(messages), len(messages)+1)
for i, msg := range messages {
cloned := msg
cloned.Cost = 0
cloned.CacheControl = false
out[i] = cloned
}
out = append(out, chat.Message{
Role: chat.MessageRoleUser,
Content: prompt,
CreatedAt: time.Now().Format(time.RFC3339),
})

return out
}

// ShouldCompact reports whether a session's context usage has crossed the
// compaction threshold. It returns true when the estimated total token count
// compaction threshold. It returns true when the total token count
// (input + output + addedTokens) exceeds [contextThreshold] (90%) of
// contextLimit. A non-positive contextLimit is treated as unlimited and
// always returns false.
// contextLimit.
func ShouldCompact(inputTokens, outputTokens, addedTokens, contextLimit int64) bool {
if contextLimit <= 0 {
return false
}
estimated := inputTokens + outputTokens + addedTokens
return estimated > int64(float64(contextLimit)*contextThreshold)
return (inputTokens + outputTokens + addedTokens) > int64(float64(contextLimit)*contextThreshold)
}

// EstimateMessageTokens returns a rough token-count estimate for a single
Expand Down Expand Up @@ -121,15 +61,3 @@ func EstimateMessageTokens(msg *chat.Message) int64 {
}
return int64(chars/charsPerToken) + perMessageOverhead
}

// HasConversationMessages reports whether messages contains at least one
// non-system message. A session with only system prompts has no conversation
// to summarize.
func HasConversationMessages(messages []chat.Message) bool {
for _, msg := range messages {
if msg.Role != chat.MessageRoleSystem {
return true
}
}
return false
}
135 changes: 0 additions & 135 deletions pkg/compaction/compaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/tools"
Expand Down Expand Up @@ -164,137 +163,3 @@ func TestShouldCompact(t *testing.T) {
})
}
}

func TestHasConversationMessages(t *testing.T) {
t.Parallel()

tests := []struct {
name string
messages []chat.Message
want bool
}{
{
name: "empty",
messages: nil,
want: false,
},
{
name: "system only",
messages: []chat.Message{
{Role: chat.MessageRoleSystem, Content: "You are helpful."},
},
want: false,
},
{
name: "system and user",
messages: []chat.Message{
{Role: chat.MessageRoleSystem, Content: "You are helpful."},
{Role: chat.MessageRoleUser, Content: "Hello"},
},
want: true,
},
{
name: "only user",
messages: []chat.Message{
{Role: chat.MessageRoleUser, Content: "Hello"},
},
want: true,
},
{
name: "assistant message",
messages: []chat.Message{
{Role: chat.MessageRoleAssistant, Content: "Hi there"},
},
want: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := HasConversationMessages(tt.messages)
assert.Equal(t, tt.want, got)
})
}
}

func TestBuildPrompt(t *testing.T) {
t.Parallel()

messages := []chat.Message{
{Role: chat.MessageRoleSystem, Content: "You are helpful."},
{Role: chat.MessageRoleUser, Content: "Hello", Cost: 0.05},
{Role: chat.MessageRoleAssistant, Content: "Hi!", Cost: 0.10},
}

t.Run("basic", func(t *testing.T) {
t.Parallel()

out := BuildPrompt(messages, "")

// Original messages + appended summarization prompt.
require.Len(t, out, 4)

// Costs are zeroed.
for _, msg := range out[:3] {
assert.Zero(t, msg.Cost, "cost should be zeroed for %q", msg.Content)
}

// Last message is the summarization prompt.
last := out[len(out)-1]
assert.Equal(t, chat.MessageRoleUser, last.Role)
assert.Contains(t, last.Content, "summary")
assert.NotEmpty(t, last.CreatedAt)
})

t.Run("with additional prompt", func(t *testing.T) {
t.Parallel()

out := BuildPrompt(messages, "focus on code changes")

last := out[len(out)-1]
assert.Contains(t, last.Content, "Additional instructions from user: focus on code changes")
})

t.Run("does not modify original messages", func(t *testing.T) {
t.Parallel()

original := []chat.Message{
{Role: chat.MessageRoleUser, Content: "Hello", Cost: 0.05},
}

_ = BuildPrompt(original, "")

assert.InDelta(t, 0.05, original[0].Cost, 1e-9)
assert.Len(t, original, 1)
})

t.Run("strips CacheControl from cloned messages", func(t *testing.T) {
t.Parallel()

input := []chat.Message{
{Role: chat.MessageRoleSystem, Content: "system", CacheControl: true},
{Role: chat.MessageRoleSystem, Content: "context", CacheControl: true},
{Role: chat.MessageRoleUser, Content: "hello"},
}

out := BuildPrompt(input, "")

// All cloned messages should have CacheControl=false
for i, msg := range out {
assert.False(t, msg.CacheControl, "message %d should have CacheControl stripped", i)
}
// Original should be unchanged
assert.True(t, input[0].CacheControl)
assert.True(t, input[1].CacheControl)
})
}

func TestPromptsAreEmbedded(t *testing.T) {
t.Parallel()

assert.NotEmpty(t, SystemPrompt, "compaction system prompt should be embedded")
assert.NotEmpty(t, userPrompt, "compaction user prompt should be embedded")
assert.Contains(t, SystemPrompt, "summary")
assert.Contains(t, userPrompt, "summary")
}
24 changes: 20 additions & 4 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
}
loopDetector := newToolLoopDetector(loopThreshold)

// overflowCompactions counts how many consecutive context-overflow
// auto-compactions have been attempted without a successful model
// call in between. This prevents an infinite loop when compaction
// cannot reduce the context below the model's limit.
const maxOverflowCompactions = 1
var overflowCompactions int

// toolModelOverride holds the per-toolset model from the most recent
// tool calls. It applies for one LLM turn, then resets.
var toolModelOverride string
Expand Down Expand Up @@ -248,13 +255,14 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
slog.Debug("Failed to get model definition", "error", err)
}

// We can only compact if we know the limit.
var contextLimit int64
if m != nil {
contextLimit = int64(m.Limit.Context)
}

if m != nil && r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) {
r.Summarize(ctx, sess, "", events)
if r.sessionCompaction && compaction.ShouldCompact(sess.InputTokens, sess.OutputTokens, 0, contextLimit) {
r.Summarize(ctx, sess, "", events)
}
}

messages := sess.GetMessages(a)
Expand All @@ -280,13 +288,18 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
// Auto-recovery: if the error is a context overflow and
// session compaction is enabled, compact the conversation
// and retry the request instead of surfacing raw errors.
if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction {
// We allow at most maxOverflowCompactions consecutive attempts
// to avoid an infinite loop when compaction cannot reduce
// the context enough.
if _, ok := errors.AsType[*modelerrors.ContextOverflowError](err); ok && r.sessionCompaction && overflowCompactions < maxOverflowCompactions {
overflowCompactions++
slog.Warn("Context window overflow detected, attempting auto-compaction",
"agent", a.Name(),
"session_id", sess.ID,
"input_tokens", sess.InputTokens,
"output_tokens", sess.OutputTokens,
"context_limit", contextLimit,
"attempt", overflowCompactions,
)
events <- Warning(
"The conversation has exceeded the model's context window. Automatically compacting the conversation history...",
Expand All @@ -313,6 +326,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
return
}

// A successful model call resets the overflow compaction counter.
overflowCompactions = 0

if usedModel != nil && usedModel.ID() != model.ID() {
slog.Info("Used fallback model", "agent", a.Name(), "primary", model.ID(), "used", usedModel.ID())
events <- AgentInfo(a.Name(), usedModel.ID(), a.Description(), a.WelcomeMessage())
Expand Down
Loading
Loading