diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index ef9c7e2aace9..ada4f82bfd57 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -56,11 +56,21 @@ export namespace SessionRevert { if (revert) { const session = await Session.get(input.sessionID) - revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track()) + const snapshot = session.revert?.snapshot ?? (await Snapshot.track()) + // When stacking reverts, ensure we restore the original snapshot baseline so the + // next diff only reflects the newly reverted content, not prior revert state. + if (session.revert?.snapshot) await Snapshot.restore(session.revert.snapshot) + // Preserve the baseline snapshot for future redo/unrevert operations. + revert.snapshot = snapshot await Snapshot.revert(patches) if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot) - const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID) - const diffs = await SessionSummary.computeDiff({ messages: rangeMessages }) + const remainingMessages = all.filter((msg) => + revert?.partID ? msg.info.id <= revert.messageID : msg.info.id < revert!.messageID, + ) + // Make sure new summary is made up of remainingMessgaes, messages with id less (<) then the reverted one + // If reverting a part, we add in the message itself + + const diffs = await SessionSummary.computeDiff({ messages: remainingMessages }) await Storage.write(["session_diff", input.sessionID], diffs) Bus.publish(Session.Event.Diff, { sessionID: input.sessionID, @@ -85,6 +95,23 @@ export namespace SessionRevert { const session = await Session.get(input.sessionID) if (!session.revert) return session if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot) + const messages = await Session.messages({ sessionID: input.sessionID }) + const diffs = await SessionSummary.computeDiff({ messages }) + await Storage.write(["session_diff", input.sessionID], diffs) + Bus.publish(Session.Event.Diff, { + sessionID: input.sessionID, + diff: diffs, + }) + + await Session.setSummary({ + sessionID: input.sessionID, + summary: { + ...session.summary, + additions: diffs.reduce((sum, x) => sum + x.additions, 0), + deletions: diffs.reduce((sum, x) => sum + x.deletions, 0), + files: diffs.length, + }, + }) return Session.clearRevert(input.sessionID) }