Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/tired-dogs-worry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"roo-cline": patch
---

Adds a button to intelligently condense the context window
13 changes: 8 additions & 5 deletions src/core/condense/__tests__/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,16 @@ describe("summarizeConversation", () => {
} as unknown as ApiHandler
})

// Default system prompt for tests
const defaultSystemPrompt = "You are a helpful assistant."

it("should not summarize when there are not enough messages", async () => {
const messages: ApiMessage[] = [
{ role: "user", content: "Hello", ts: 1 },
{ role: "assistant", content: "Hi there", ts: 2 },
]

const result = await summarizeConversation(messages, mockApiHandler)
const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)
expect(result.messages).toEqual(messages)
expect(result.cost).toBe(0)
expect(result.summary).toBe("")
Expand All @@ -122,7 +125,7 @@ describe("summarizeConversation", () => {
{ role: "user", content: "Tell me more", ts: 7 },
]

const result = await summarizeConversation(messages, mockApiHandler)
const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)
expect(result.messages).toEqual(messages)
expect(result.cost).toBe(0)
expect(result.summary).toBe("")
Expand All @@ -141,7 +144,7 @@ describe("summarizeConversation", () => {
{ role: "user", content: "Tell me more", ts: 7 },
]

const result = await summarizeConversation(messages, mockApiHandler)
const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)

// Check that the API was called correctly
expect(mockApiHandler.createMessage).toHaveBeenCalled()
Expand Down Expand Up @@ -199,7 +202,7 @@ describe("summarizeConversation", () => {
return messages.map(({ role, content }: { role: string; content: any }) => ({ role, content }))
})

const result = await summarizeConversation(messages, mockApiHandler)
const result = await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)

// Should return original messages when summary is empty
expect(result.messages).toEqual(messages)
Expand All @@ -222,7 +225,7 @@ describe("summarizeConversation", () => {
{ role: "user", content: "Tell me more", ts: 7 },
]

await summarizeConversation(messages, mockApiHandler)
await summarizeConversation(messages, mockApiHandler, defaultSystemPrompt)

// Verify the final request message
const expectedFinalMessage = {
Expand Down
11 changes: 6 additions & 5 deletions src/core/condense/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ export type SummarizeResponse = {
*
* @param {ApiMessage[]} messages - The conversation messages
* @param {ApiHandler} apiHandler - The API handler to use for token counting.
* @param {string} systemPrompt - The system prompt for API requests, which should be considered in the context token count
* @returns {SummarizeResponse} - The result of the summarization operation (see above)
*/
export async function summarizeConversation(
messages: ApiMessage[],
apiHandler: ApiHandler,
systemPrompt?: string,
systemPrompt: string,
): Promise<SummarizeResponse> {
const response: SummarizeResponse = { messages, cost: 0, summary: "" }
const messagesToSummarize = getMessagesSinceLastSummary(messages.slice(0, -N_MESSAGES_TO_KEEP))
Expand Down Expand Up @@ -111,10 +112,10 @@ export async function summarizeConversation(

// Count the tokens in the context for the next API request
// We only estimate the tokens in summaryMesage if outputTokens is 0, otherwise we use outputTokens
const contextMessages = outputTokens ? [...keepMessages] : [summaryMessage, ...keepMessages]
if (systemPrompt) {
contextMessages.unshift({ role: "user", content: systemPrompt })
}
const systemPromptMessage: ApiMessage = { role: "user", content: systemPrompt }
const contextMessages = outputTokens
? [systemPromptMessage, ...keepMessages]
: [systemPromptMessage, summaryMessage, ...keepMessages]
const contextBlocks = contextMessages.flatMap((message) =>
typeof message.content === "string" ? [{ text: message.content, type: "text" as const }] : message.content,
)
Expand Down
20 changes: 20 additions & 0 deletions src/core/sliding-window/__tests__/sliding-window.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})

// Check the new return type
Expand Down Expand Up @@ -276,6 +277,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})

expect(result).toEqual({
Expand All @@ -302,6 +304,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo1.contextWindow,
maxTokens: modelInfo1.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})

const result2 = await truncateConversationIfNeeded({
Expand All @@ -310,6 +313,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo2.contextWindow,
maxTokens: modelInfo2.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})

expect(result1.messages).toEqual(result2.messages)
Expand All @@ -325,6 +329,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo1.contextWindow,
maxTokens: modelInfo1.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})

const result4 = await truncateConversationIfNeeded({
Expand All @@ -333,6 +338,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo2.contextWindow,
maxTokens: modelInfo2.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})

expect(result3.messages).toEqual(result4.messages)
Expand Down Expand Up @@ -363,6 +369,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo.contextWindow,
maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(resultWithSmall).toEqual({
messages: messagesWithSmallContent,
Expand Down Expand Up @@ -392,6 +399,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo.contextWindow,
maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(resultWithLarge.messages).not.toEqual(messagesWithLargeContent) // Should truncate
expect(resultWithLarge.summary).toBe("")
Expand All @@ -414,6 +422,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo.contextWindow,
maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(resultWithVeryLarge.messages).not.toEqual(messagesWithVeryLargeContent) // Should truncate
expect(resultWithVeryLarge.summary).toBe("")
Expand All @@ -439,6 +448,7 @@ describe("truncateConversationIfNeeded", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result).toEqual({
messages: expectedResult,
Expand Down Expand Up @@ -524,6 +534,7 @@ describe("truncateConversationIfNeeded", () => {
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
autoCondenseContext: true,
systemPrompt: "System prompt",
})

// Verify summarizeConversation was called
Expand Down Expand Up @@ -559,6 +570,7 @@ describe("truncateConversationIfNeeded", () => {
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
autoCondenseContext: false,
systemPrompt: "System prompt",
})

// Verify summarizeConversation was not called
Expand Down Expand Up @@ -612,6 +624,7 @@ describe("getMaxTokens", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result1).toEqual({
messages: messagesWithSmallContent,
Expand All @@ -627,6 +640,7 @@ describe("getMaxTokens", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result2.messages).not.toEqual(messagesWithSmallContent)
expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction
Expand All @@ -650,6 +664,7 @@ describe("getMaxTokens", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result1).toEqual({
messages: messagesWithSmallContent,
Expand All @@ -665,6 +680,7 @@ describe("getMaxTokens", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result2.messages).not.toEqual(messagesWithSmallContent)
expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction
Expand All @@ -687,6 +703,7 @@ describe("getMaxTokens", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result1.messages).toEqual(messagesWithSmallContent)

Expand All @@ -697,6 +714,7 @@ describe("getMaxTokens", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result2).not.toEqual(messagesWithSmallContent)
expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction
Expand All @@ -717,6 +735,7 @@ describe("getMaxTokens", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result1.messages).toEqual(messagesWithSmallContent)

Expand All @@ -727,6 +746,7 @@ describe("getMaxTokens", () => {
contextWindow: modelInfo.contextWindow,
maxTokens: modelInfo.maxTokens,
apiHandler: mockApiHandler,
systemPrompt: "System prompt",
})
expect(result2).not.toEqual(messagesWithSmallContent)
expect(result2.messages.length).toBe(3) // Truncated with 0.5 fraction
Expand Down
2 changes: 1 addition & 1 deletion src/core/sliding-window/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type TruncateOptions = {
maxTokens?: number | null
apiHandler: ApiHandler
autoCondenseContext?: boolean
systemPrompt?: string
systemPrompt: string
}

type TruncateResponse = SummarizeResponse & { prevContextTokens: number }
Expand Down
Loading
Loading