diff --git a/packages/opencode/src/cli/cmd/tui/context/sync.tsx b/packages/opencode/src/cli/cmd/tui/context/sync.tsx index 269ed7ae0bd1..eec04250a5e4 100644 --- a/packages/opencode/src/cli/cmd/tui/context/sync.tsx +++ b/packages/opencode/src/cli/cmd/tui/context/sync.tsx @@ -103,6 +103,30 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ }) const sdk = useSDK() + const fullSyncedSessions = new Set() + async function syncSession(sessionID: string) { + if (fullSyncedSessions.has(sessionID)) return + const [session, messages, todo, diff] = await Promise.all([ + sdk.client.session.get({ sessionID }, { throwOnError: true }), + sdk.client.session.messages({ sessionID, limit: 100 }), + sdk.client.session.todo({ sessionID }), + sdk.client.session.diff({ sessionID }), + ]) + setStore( + produce((draft) => { + const match = Binary.search(draft.session, sessionID, (s) => s.id) + if (match.found) draft.session[match.index] = session.data! + if (!match.found) draft.session.splice(match.index, 0, session.data!) + draft.todo[sessionID] = todo.data ?? [] + draft.message[sessionID] = messages.data!.map((x) => x.info) + for (const message of messages.data!) { + draft.part[message.info.id] = message.parts + } + draft.session_diff[sessionID] = diff.data ?? [] + }), + ) + fullSyncedSessions.add(sessionID) + } sdk.event.listen((e) => { const event = e.details @@ -192,6 +216,10 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ case "session.diff": setStore("session_diff", event.properties.sessionID, event.properties.diff) break + case "session.compacted": + fullSyncedSessions.delete(event.properties.sessionID) + void syncSession(event.properties.sessionID) + break case "session.deleted": { const result = Binary.search(store.session, event.properties.info.id, (s) => s.id) @@ -431,7 +459,6 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ bootstrap() }) - const fullSyncedSessions = new Set() const result = { data: store, set: setStore, @@ -458,27 +485,7 @@ export const { use: useSync, provider: SyncProvider } = createSimpleContext({ return last.time.completed ? "idle" : "working" }, async sync(sessionID: string) { - if (fullSyncedSessions.has(sessionID)) return - const [session, messages, todo, diff] = await Promise.all([ - sdk.client.session.get({ sessionID }, { throwOnError: true }), - sdk.client.session.messages({ sessionID, limit: 100 }), - sdk.client.session.todo({ sessionID }), - sdk.client.session.diff({ sessionID }), - ]) - setStore( - produce((draft) => { - const match = Binary.search(draft.session, sessionID, (s) => s.id) - if (match.found) draft.session[match.index] = session.data! - if (!match.found) draft.session.splice(match.index, 0, session.data!) - draft.todo[sessionID] = todo.data ?? [] - draft.message[sessionID] = messages.data!.map((x) => x.info) - for (const message of messages.data!) { - draft.part[message.info.id] = message.parts - } - draft.session_diff[sessionID] = diff.data ?? [] - }), - ) - fullSyncedSessions.add(sessionID) + return syncSession(sessionID) }, }, bootstrap, diff --git a/packages/opencode/src/cli/cmd/tui/routes/session/sidebar.tsx b/packages/opencode/src/cli/cmd/tui/routes/session/sidebar.tsx index 42ac5fbe080a..8432de838bd6 100644 --- a/packages/opencode/src/cli/cmd/tui/routes/session/sidebar.tsx +++ b/packages/opencode/src/cli/cmd/tui/routes/session/sidebar.tsx @@ -19,6 +19,11 @@ export function Sidebar(props: { sessionID: string; overlay?: boolean }) { const diff = createMemo(() => sync.data.session_diff[props.sessionID] ?? []) const todo = createMemo(() => sync.data.todo[props.sessionID] ?? []) const messages = createMemo(() => sync.data.message[props.sessionID] ?? []) + const active = createMemo(() => { + const messageID = session().revert?.messageID + if (!messageID) return messages() + return messages().filter((item) => item.id < messageID) + }) const [expanded, setExpanded] = createStore({ mcp: true, @@ -41,7 +46,7 @@ export function Sidebar(props: { sessionID: string; overlay?: boolean }) { ) const cost = createMemo(() => { - const total = messages().reduce((sum, x) => sum + (x.role === "assistant" ? x.cost : 0), 0) + const total = active().reduce((sum, x) => sum + (x.role === "assistant" ? x.cost : 0), 0) return new Intl.NumberFormat("en-US", { style: "currency", currency: "USD", @@ -49,7 +54,7 @@ export function Sidebar(props: { sessionID: string; overlay?: boolean }) { }) const context = createMemo(() => { - const last = messages().findLast((x) => x.role === "assistant" && x.tokens.output > 0) as AssistantMessage + const last = active().findLast((x) => x.role === "assistant" && x.tokens.output > 0) as AssistantMessage if (!last) return const total = last.tokens.input + last.tokens.output + last.tokens.reasoning + last.tokens.cache.read + last.tokens.cache.write diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index ef9c7e2aace9..547779574c66 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -14,6 +14,32 @@ import { SessionSummary } from "./summary" export namespace SessionRevert { const log = Log.create({ service: "session.revert" }) + function active(messages: MessageV2.WithParts[], revert: Session.Info["revert"]) { + if (!revert) return messages + return messages.flatMap((msg) => { + if (msg.info.id < revert.messageID) return [msg] + if (msg.info.id > revert.messageID) return [] + if (!revert.partID) return [] + const index = msg.parts.findIndex((part) => part.id === revert.partID) + if (index <= 0) return [] + return [{ ...msg, parts: msg.parts.slice(0, index) }] + }) + } + + async function sync(sessionID: string, messages: MessageV2.WithParts[]) { + const diff = await SessionSummary.computeDiff({ messages }) + await Storage.write(["session_diff", sessionID], diff) + Bus.publish(Session.Event.Diff, { + sessionID, + diff, + }) + return { + additions: diff.reduce((sum, x) => sum + x.additions, 0), + deletions: diff.reduce((sum, x) => sum + x.deletions, 0), + files: diff.length, + } + } + export const RevertInput = z.object({ sessionID: Identifier.schema("session"), messageID: Identifier.schema("message"), @@ -59,21 +85,10 @@ export namespace SessionRevert { revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track()) await Snapshot.revert(patches) if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot) - const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID) - const diffs = await SessionSummary.computeDiff({ messages: rangeMessages }) - await Storage.write(["session_diff", input.sessionID], diffs) - Bus.publish(Session.Event.Diff, { - sessionID: input.sessionID, - diff: diffs, - }) return Session.setRevert({ sessionID: input.sessionID, revert, - summary: { - additions: diffs.reduce((sum, x) => sum + x.additions, 0), - deletions: diffs.reduce((sum, x) => sum + x.deletions, 0), - files: diffs.length, - }, + summary: await sync(input.sessionID, active(all, revert)), }) } return session @@ -85,7 +100,11 @@ export namespace SessionRevert { const session = await Session.get(input.sessionID) if (!session.revert) return session if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot) - return Session.clearRevert(input.sessionID) + await Session.clearRevert(input.sessionID) + return Session.setSummary({ + sessionID: input.sessionID, + summary: await sync(input.sessionID, await Session.messages({ sessionID: input.sessionID })), + }) } export async function cleanup(session: Session.Info) { @@ -134,5 +153,9 @@ export namespace SessionRevert { } } await Session.clearRevert(sessionID) + await Session.setSummary({ + sessionID, + summary: await sync(sessionID, await Session.messages({ sessionID })), + }) } } diff --git a/packages/opencode/test/session/revert-compact.test.ts b/packages/opencode/test/session/revert-compact.test.ts index de2b14573f43..7d9c63be8a73 100644 --- a/packages/opencode/test/session/revert-compact.test.ts +++ b/packages/opencode/test/session/revert-compact.test.ts @@ -7,6 +7,7 @@ import { MessageV2 } from "../../src/session/message-v2" import { Log } from "../../src/util/log" import { Instance } from "../../src/project/instance" import { Identifier } from "../../src/id/id" +import { Snapshot } from "../../src/snapshot" import { tmpdir } from "../fixture/fixture" const projectRoot = path.join(__dirname, "../..") @@ -282,4 +283,132 @@ describe("revert + compact workflow", () => { }, }) }) + + test("should keep session diff aligned across revert and unrevert", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const sessionID = session.id + + async function turn(input: { file: string; text: string }) { + const user = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "user", + sessionID, + agent: "default", + model: { + providerID: "openai", + modelID: "gpt-4", + }, + time: { + created: Date.now(), + }, + }) + + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: user.id, + sessionID, + type: "text", + text: `Edit ${input.file}`, + }) + + const assistant: MessageV2.Assistant = { + id: Identifier.ascending("message"), + role: "assistant", + sessionID, + mode: "default", + agent: "default", + path: { + cwd: tmp.path, + root: tmp.path, + }, + cost: 0, + tokens: { + output: 0, + input: 0, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + modelID: "gpt-4", + providerID: "openai", + parentID: user.id, + time: { + created: Date.now(), + }, + finish: "end_turn", + } + await Session.updateMessage(assistant) + + const from = await Snapshot.track() + await Bun.write(path.join(tmp.path, input.file), input.text) + const to = await Snapshot.track() + if (!from || !to) throw new Error("expected snapshot hashes") + const patch = await Snapshot.patch(from) + + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistant.id, + sessionID, + type: "step-start", + snapshot: from, + }) + + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistant.id, + sessionID, + type: "step-finish", + reason: "stop", + snapshot: to, + tokens: { + output: 1, + input: 1, + reasoning: 0, + cache: { read: 0, write: 0 }, + }, + cost: 0, + }) + + if (patch.files.length > 0) { + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistant.id, + sessionID, + type: "patch", + hash: patch.hash, + files: patch.files, + }) + } + + return user + } + + const has = (diffs: Awaited>, file: string) => + diffs.some((item) => item.file === file || item.file.endsWith(`/${file}`)) + + await turn({ file: "one.txt", text: "one\n" }) + const second = await turn({ file: "two.txt", text: "two\n" }) + + await SessionRevert.revert({ + sessionID, + messageID: second.id, + }) + + const undone = await Session.diff(sessionID) + expect(has(undone, "one.txt")).toBe(true) + expect(has(undone, "two.txt")).toBe(false) + + await SessionRevert.unrevert({ sessionID }) + + const restored = await Session.diff(sessionID) + expect(has(restored, "one.txt")).toBe(true) + expect(has(restored, "two.txt")).toBe(true) + + await Session.remove(sessionID) + }, + }) + }) })