diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index ee3fa148b41..acdc7f5412e 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -25,7 +25,6 @@ import { switchModeTool } from "../tools/switchModeTool" import { attemptCompletionTool } from "../tools/attemptCompletionTool" import { newTaskTool } from "../tools/newTaskTool" -import { checkpointSave } from "../checkpoints" import { updateTodoListTool } from "../tools/updateTodoListTool" import { formatResponse } from "../prompts/responses" @@ -411,6 +410,7 @@ export async function presentAssistantMessage(cline: Task) { switch (block.name) { case "write_to_file": + await checkpointSaveAndMark(cline) await writeToFileTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag) break case "update_todo_list": @@ -430,8 +430,10 @@ export async function presentAssistantMessage(cline: Task) { } if (isMultiFileApplyDiffEnabled) { + await checkpointSaveAndMark(cline) await applyDiffTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag) } else { + await checkpointSaveAndMark(cline) await applyDiffToolLegacy( cline, block, @@ -444,9 +446,11 @@ export async function presentAssistantMessage(cline: Task) { break } case "insert_content": + await checkpointSaveAndMark(cline) await insertContentTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag) break case "search_and_replace": + await checkpointSaveAndMark(cline) await searchAndReplaceTool(cline, block, askApproval, handleError, pushToolResult, removeClosingTag) break case "read_file": @@ -527,14 +531,6 @@ export async function presentAssistantMessage(cline: Task) { break } - const recentlyModifiedFiles = cline.fileContextTracker.getAndClearCheckpointPossibleFile() - - if (recentlyModifiedFiles.length > 0) { - // TODO: We can track what file changes were made and only - // checkpoint those files, this will be save storage. - await checkpointSave(cline) - } - // Seeing out of bounds is fine, it means that the next too call is being // built up and ready to add to assistantMessageContent to present. // When you see the UI inactive during this, it means that a tool is @@ -583,3 +579,20 @@ export async function presentAssistantMessage(cline: Task) { presentAssistantMessage(cline) } } + +/** + * save checkpoint and mark done in the current streaming task. + * @param task The Task instance to checkpoint save and mark. + * @returns + */ +async function checkpointSaveAndMark(task: Task) { + if (task.currentStreamingDidCheckpoint) { + return + } + try { + await task.checkpointSave(true) + task.currentStreamingDidCheckpoint = true + } catch (error) { + console.error(`[Task#presentAssistantMessage] Error saving checkpoint: ${error.message}`, error) + } +} diff --git a/src/core/checkpoints/index.ts b/src/core/checkpoints/index.ts index 02fb5dfc5a1..f08dc24e167 100644 --- a/src/core/checkpoints/index.ts +++ b/src/core/checkpoints/index.ts @@ -16,18 +16,29 @@ import { DIFF_VIEW_URI_SCHEME } from "../../integrations/editor/DiffViewProvider import { CheckpointServiceOptions, RepoPerTaskCheckpointService } from "../../services/checkpoints" -export function getCheckpointService(cline: Task) { +export async function getCheckpointService( + cline: Task, + { interval = 250, timeout = 15_000 }: { interval?: number; timeout?: number } = {}, +) { if (!cline.enableCheckpoints) { return undefined } if (cline.checkpointService) { - return cline.checkpointService - } - - if (cline.checkpointServiceInitializing) { - console.log("[Task#getCheckpointService] checkpoint service is still initializing") - return undefined + if (cline.checkpointServiceInitializing) { + console.log("[Task#getCheckpointService] checkpoint service is still initializing") + const service = cline.checkpointService + await pWaitFor( + () => { + console.log("[Task#getCheckpointService] waiting for service to initialize") + return service.isInitialized + }, + { interval, timeout }, + ) + return service.isInitialized ? cline.checkpointService : undefined + } else { + return cline.checkpointService + } } const provider = cline.providerRef.deref() @@ -69,15 +80,20 @@ export function getCheckpointService(cline: Task) { } const service = RepoPerTaskCheckpointService.create(options) - cline.checkpointServiceInitializing = true // Check if Git is installed before initializing the service - // Note: This is intentionally fire-and-forget to match the original IIFE pattern - // The service is returned immediately while Git check happens asynchronously - checkGitInstallation(cline, service, log, provider) - - return service + // Only assign the service after successful initialization + try { + await checkGitInstallation(cline, service, log, provider) + cline.checkpointService = service + return service + } catch (err) { + // Clean up on failure + cline.checkpointServiceInitializing = false + cline.enableCheckpoints = false + throw err + } } catch (err) { log(`[Task#getCheckpointService] ${err.message}`) cline.enableCheckpoints = false @@ -115,22 +131,7 @@ async function checkGitInstallation( // Git is installed, proceed with initialization service.on("initialize", () => { log("[Task#getCheckpointService] service initialized") - - try { - const isCheckpointNeeded = - typeof cline.clineMessages.find(({ say }) => say === "checkpoint_saved") === "undefined" - - cline.checkpointService = service - cline.checkpointServiceInitializing = false - - if (isCheckpointNeeded) { - log("[Task#getCheckpointService] no checkpoints found, saving initial checkpoint") - checkpointSave(cline) - } - } catch (err) { - log("[Task#getCheckpointService] caught error in on('initialize'), disabling checkpoints") - cline.enableCheckpoints = false - } + cline.checkpointServiceInitializing = false }) service.on("checkpoint", ({ isFirst, fromHash: from, toHash: to }) => { @@ -153,11 +154,12 @@ async function checkGitInstallation( }) log("[Task#getCheckpointService] initializing shadow git") - - service.initShadowGit().catch((err) => { + try { + await service.initShadowGit() + } catch (err) { log(`[Task#getCheckpointService] initShadowGit -> ${err.message}`) cline.enableCheckpoints = false - }) + } } catch (err) { log(`[Task#getCheckpointService] Unexpected error during Git check: ${err.message}`) console.error("Git check error:", err) @@ -166,33 +168,8 @@ async function checkGitInstallation( } } -async function getInitializedCheckpointService( - cline: Task, - { interval = 250, timeout = 15_000 }: { interval?: number; timeout?: number } = {}, -) { - const service = getCheckpointService(cline) - - if (!service || service.isInitialized) { - return service - } - - try { - await pWaitFor( - () => { - console.log("[Task#getCheckpointService] waiting for service to initialize") - return service.isInitialized - }, - { interval, timeout }, - ) - - return service - } catch (err) { - return undefined - } -} - export async function checkpointSave(cline: Task, force = false) { - const service = getCheckpointService(cline) + const service = await getCheckpointService(cline) if (!service) { return @@ -221,7 +198,7 @@ export type CheckpointRestoreOptions = { } export async function checkpointRestore(cline: Task, { ts, commitHash, mode }: CheckpointRestoreOptions) { - const service = await getInitializedCheckpointService(cline) + const service = await getCheckpointService(cline) if (!service) { return @@ -289,7 +266,7 @@ export type CheckpointDiffOptions = { } export async function checkpointDiff(cline: Task, { ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions) { - const service = await getInitializedCheckpointService(cline) + const service = await getCheckpointService(cline) if (!service) { return @@ -297,17 +274,19 @@ export async function checkpointDiff(cline: Task, { ts, previousCommitHash, comm TelemetryService.instance.captureCheckpointDiffed(cline.taskId) - if (!previousCommitHash && mode === "checkpoint") { - const previousCheckpoint = cline.clineMessages - .filter(({ say }) => say === "checkpoint_saved") - .sort((a, b) => b.ts - a.ts) - .find((message) => message.ts < ts) + let prevHash = commitHash + let nextHash: string | undefined - previousCommitHash = previousCheckpoint?.text + const checkpoints = typeof service.getCheckpoints === "function" ? service.getCheckpoints() : [] + const idx = checkpoints.indexOf(commitHash) + if (idx !== -1 && idx < checkpoints.length - 1) { + nextHash = checkpoints[idx + 1] + } else { + nextHash = undefined } try { - const changes = await service.getDiff({ from: previousCommitHash, to: commitHash }) + const changes = await service.getDiff({ from: prevHash, to: nextHash }) if (!changes?.length) { vscode.window.showInformationMessage("No changes found.") diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index edbde32ea7a..e1511568b0e 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -247,6 +247,7 @@ export class Task extends EventEmitter { isWaitingForFirstChunk = false isStreaming = false currentStreamingContentIndex = 0 + currentStreamingDidCheckpoint = false assistantMessageContent: AssistantMessageContent[] = [] presentAssistantMessageLocked = false presentAssistantMessageHasPendingUpdates = false @@ -1523,6 +1524,7 @@ export class Task extends EventEmitter { // Reset streaming state. this.currentStreamingContentIndex = 0 + this.currentStreamingDidCheckpoint = false this.assistantMessageContent = [] this.didCompleteReadingStream = false this.userMessageContent = [] diff --git a/src/services/checkpoints/ShadowCheckpointService.ts b/src/services/checkpoints/ShadowCheckpointService.ts index be2c86852ab..280cbd81183 100644 --- a/src/services/checkpoints/ShadowCheckpointService.ts +++ b/src/services/checkpoints/ShadowCheckpointService.ts @@ -38,6 +38,10 @@ export abstract class ShadowCheckpointService extends EventEmitter { return !!this.git } + public getCheckpoints(): string[] { + return this._checkpoints.slice() + } + constructor(taskId: string, checkpointsDir: string, workspaceDir: string, log: (message: string) => void) { super() diff --git a/webview-ui/src/components/chat/checkpoints/CheckpointMenu.tsx b/webview-ui/src/components/chat/checkpoints/CheckpointMenu.tsx index 21b4f486c7b..eba47699abc 100644 --- a/webview-ui/src/components/chat/checkpoints/CheckpointMenu.tsx +++ b/webview-ui/src/components/chat/checkpoints/CheckpointMenu.tsx @@ -22,9 +22,6 @@ export const CheckpointMenu = ({ ts, commitHash, currentHash, checkpoint }: Chec const portalContainer = useRooPortal("roo-portal") const isCurrent = currentHash === commitHash - const isFirst = checkpoint.isFirst - const isDiffAvailable = !isFirst - const isRestoreAvailable = !isFirst || !isCurrent const previousCommitHash = checkpoint?.from @@ -47,78 +44,72 @@ export const CheckpointMenu = ({ ts, commitHash, currentHash, checkpoint }: Chec return (
- {isDiffAvailable && ( - - + + + + { + setIsOpen(open) + setIsConfirming(false) + }}> + + + + - )} - {isRestoreAvailable && ( - { - setIsOpen(open) - setIsConfirming(false) - }}> - - - - - - -
- {!isCurrent && ( -
- -
- {t("chat:checkpoint.menu.restoreFilesDescription")} -
+ +
+ {!isCurrent && ( +
+ +
+ {t("chat:checkpoint.menu.restoreFilesDescription")}
- )} - {!isFirst && ( -
-
- {!isConfirming ? ( - - ) : ( - <> - - - - )} - {isConfirming ? ( -
- {t("chat:checkpoint.menu.cannotUndo")} +
+ )} +
+
+ {!isConfirming ? ( + + ) : ( + <> + + + + )} + {isConfirming ? ( +
+ {t("chat:checkpoint.menu.cannotUndo")}
-
- )} + ) : ( +
+ {t("chat:checkpoint.menu.restoreFilesAndTaskDescription")} +
+ )} +
- - - )} +
+ +
) }