diff --git a/packages/opencode/src/session/revert.ts b/packages/opencode/src/session/revert.ts index b1e9840e4fdc..92049b12bd25 100644 --- a/packages/opencode/src/session/revert.ts +++ b/packages/opencode/src/session/revert.ts @@ -1,12 +1,14 @@ import z from "zod" -import { SessionID, MessageID, PartID } from "./schema" +import { Effect, Layer, ServiceMap } from "effect" +import { makeRuntime } from "@/effect/run-service" +import { Bus } from "../bus" import { Snapshot } from "../snapshot" -import { MessageV2 } from "./message-v2" -import { Session } from "." -import { Log } from "../util/log" -import { SyncEvent } from "../sync" import { Storage } from "@/storage/storage" -import { Bus } from "../bus" +import { SyncEvent } from "../sync" +import { Log } from "../util/log" +import { Session } from "." +import { MessageV2 } from "./message-v2" +import { SessionID, MessageID, PartID } from "./schema" import { SessionPrompt } from "./prompt" import { SessionSummary } from "./summary" @@ -20,116 +22,152 @@ export namespace SessionRevert { }) export type RevertInput = z.infer - export async function revert(input: RevertInput) { - await SessionPrompt.assertNotBusy(input.sessionID) - const all = await Session.messages({ sessionID: input.sessionID }) - let lastUser: MessageV2.User | undefined - const session = await Session.get(input.sessionID) - - let revert: Session.Info["revert"] - const patches: Snapshot.Patch[] = [] - for (const msg of all) { - if (msg.info.role === "user") lastUser = msg.info - const remaining = [] - for (const part of msg.parts) { - if (revert) { - if (part.type === "patch") { - patches.push(part) + export interface Interface { + readonly revert: (input: RevertInput) => Effect.Effect + readonly unrevert: (input: { sessionID: SessionID }) => Effect.Effect + readonly cleanup: (session: Session.Info) => Effect.Effect + } + + export class Service extends ServiceMap.Service()("@opencode/SessionRevert") {} + + export const layer = Layer.effect( + Service, + Effect.gen(function* () { + const sessions = yield* Session.Service + const snap = yield* Snapshot.Service + const storage = yield* Storage.Service + const bus = yield* Bus.Service + + const revert = Effect.fn("SessionRevert.revert")(function* (input: RevertInput) { + yield* Effect.promise(() => SessionPrompt.assertNotBusy(input.sessionID)) + const all = yield* sessions.messages({ sessionID: input.sessionID }) + let lastUser: MessageV2.User | undefined + const session = yield* sessions.get(input.sessionID) + + let rev: Session.Info["revert"] + const patches: Snapshot.Patch[] = [] + for (const msg of all) { + if (msg.info.role === "user") lastUser = msg.info + const remaining = [] + for (const part of msg.parts) { + if (rev) { + if (part.type === "patch") patches.push(part) + continue + } + + if (!rev) { + if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) { + const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined + rev = { + messageID: !partID && lastUser ? lastUser.id : msg.info.id, + partID, + } + } + remaining.push(part) + } } - continue } - if (!revert) { - if ((msg.info.id === input.messageID && !input.partID) || part.id === input.partID) { - // if no useful parts left in message, same as reverting whole message - const partID = remaining.some((item) => ["text", "tool"].includes(item.type)) ? input.partID : undefined - revert = { - messageID: !partID && lastUser ? lastUser.id : msg.info.id, - partID, + if (!rev) return session + + rev.snapshot = session.revert?.snapshot ?? (yield* snap.track()) + yield* snap.revert(patches) + if (rev.snapshot) rev.diff = yield* snap.diff(rev.snapshot as string) + const range = all.filter((msg) => msg.info.id >= rev!.messageID) + const diffs = yield* Effect.promise(() => SessionSummary.computeDiff({ messages: range })) + yield* storage.write(["session_diff", input.sessionID], diffs).pipe(Effect.ignore) + yield* bus.publish(Session.Event.Diff, { sessionID: input.sessionID, diff: diffs }) + yield* sessions.setRevert({ + sessionID: input.sessionID, + revert: rev, + summary: { + additions: diffs.reduce((sum, x) => sum + x.additions, 0), + deletions: diffs.reduce((sum, x) => sum + x.deletions, 0), + files: diffs.length, + }, + }) + return yield* sessions.get(input.sessionID) + }) + + const unrevert = Effect.fn("SessionRevert.unrevert")(function* (input: { sessionID: SessionID }) { + log.info("unreverting", input) + yield* Effect.promise(() => SessionPrompt.assertNotBusy(input.sessionID)) + const session = yield* sessions.get(input.sessionID) + if (!session.revert) return session + if (session.revert.snapshot) yield* snap.restore(session.revert!.snapshot!) + yield* sessions.clearRevert(input.sessionID) + return yield* sessions.get(input.sessionID) + }) + + const cleanup = Effect.fn("SessionRevert.cleanup")(function* (session: Session.Info) { + if (!session.revert) return + const sessionID = session.id + const msgs = yield* sessions.messages({ sessionID }) + const messageID = session.revert.messageID + const remove = [] as MessageV2.WithParts[] + let target: MessageV2.WithParts | undefined + for (const msg of msgs) { + if (msg.info.id < messageID) continue + if (msg.info.id > messageID) { + remove.push(msg) + continue + } + if (session.revert.partID) { + target = msg + continue + } + remove.push(msg) + } + for (const msg of remove) { + SyncEvent.run(MessageV2.Event.Removed, { + sessionID, + messageID: msg.info.id, + }) + } + if (session.revert.partID && target) { + const partID = session.revert.partID + const idx = target.parts.findIndex((part) => part.id === partID) + if (idx >= 0) { + const removeParts = target.parts.slice(idx) + target.parts = target.parts.slice(0, idx) + for (const part of removeParts) { + SyncEvent.run(MessageV2.Event.PartRemoved, { + sessionID, + messageID: target.info.id, + partID: part.id, + }) } } - remaining.push(part) } - } - } - - if (revert) { - const session = await Session.get(input.sessionID) - revert.snapshot = session.revert?.snapshot ?? (await Snapshot.track()) - await Snapshot.revert(patches) - if (revert.snapshot) revert.diff = await Snapshot.diff(revert.snapshot) - const rangeMessages = all.filter((msg) => msg.info.id >= revert!.messageID) - const diffs = await SessionSummary.computeDiff({ messages: rangeMessages }) - await Storage.write(["session_diff", input.sessionID], diffs) - Bus.publish(Session.Event.Diff, { - sessionID: input.sessionID, - diff: diffs, - }) - return Session.setRevert({ - sessionID: input.sessionID, - revert, - summary: { - additions: diffs.reduce((sum, x) => sum + x.additions, 0), - deletions: diffs.reduce((sum, x) => sum + x.deletions, 0), - files: diffs.length, - }, + yield* sessions.clearRevert(sessionID) }) - } - return session + + return Service.of({ revert, unrevert, cleanup }) + }), + ) + + export const defaultLayer = Layer.unwrap( + Effect.sync(() => + layer.pipe( + Layer.provide(Session.defaultLayer), + Layer.provide(Snapshot.defaultLayer), + Layer.provide(Storage.defaultLayer), + Layer.provide(Bus.layer), + ), + ), + ) + + const { runPromise } = makeRuntime(Service, defaultLayer) + + export async function revert(input: RevertInput) { + return runPromise((svc) => svc.revert(input)) } export async function unrevert(input: { sessionID: SessionID }) { - log.info("unreverting", input) - await SessionPrompt.assertNotBusy(input.sessionID) - const session = await Session.get(input.sessionID) - if (!session.revert) return session - if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot) - return Session.clearRevert(input.sessionID) + return runPromise((svc) => svc.unrevert(input)) } export async function cleanup(session: Session.Info) { - if (!session.revert) return - const sessionID = session.id - const msgs = await Session.messages({ sessionID }) - const messageID = session.revert.messageID - const remove = [] as MessageV2.WithParts[] - let target: MessageV2.WithParts | undefined - for (const msg of msgs) { - if (msg.info.id < messageID) { - continue - } - if (msg.info.id > messageID) { - remove.push(msg) - continue - } - if (session.revert.partID) { - target = msg - continue - } - remove.push(msg) - } - for (const msg of remove) { - SyncEvent.run(MessageV2.Event.Removed, { - sessionID: sessionID, - messageID: msg.info.id, - }) - } - if (session.revert.partID && target) { - const partID = session.revert.partID - const removeStart = target.parts.findIndex((part) => part.id === partID) - if (removeStart >= 0) { - const preserveParts = target.parts.slice(0, removeStart) - const removeParts = target.parts.slice(removeStart) - target.parts = preserveParts - for (const part of removeParts) { - SyncEvent.run(MessageV2.Event.PartRemoved, { - sessionID: sessionID, - messageID: target.info.id, - partID: part.id, - }) - } - } - } - await Session.clearRevert(sessionID) + return runPromise((svc) => svc.cleanup(session)) } } diff --git a/packages/opencode/test/session/revert-compact.test.ts b/packages/opencode/test/session/revert-compact.test.ts index fb37a3a8dca1..fe7055779c57 100644 --- a/packages/opencode/test/session/revert-compact.test.ts +++ b/packages/opencode/test/session/revert-compact.test.ts @@ -10,9 +10,59 @@ import { Instance } from "../../src/project/instance" import { MessageID, PartID } from "../../src/session/schema" import { tmpdir } from "../fixture/fixture" -const projectRoot = path.join(__dirname, "../..") Log.init({ print: false }) +function user(sessionID: string, agent = "default") { + return Session.updateMessage({ + id: MessageID.ascending(), + role: "user" as const, + sessionID: sessionID as any, + agent, + model: { providerID: ProviderID.make("openai"), modelID: ModelID.make("gpt-4") }, + time: { created: Date.now() }, + }) +} + +function assistant(sessionID: string, parentID: string, dir: string) { + return Session.updateMessage({ + id: MessageID.ascending(), + role: "assistant" as const, + sessionID: sessionID as any, + mode: "default", + agent: "default", + path: { cwd: dir, root: dir }, + cost: 0, + tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + modelID: ModelID.make("gpt-4"), + providerID: ProviderID.make("openai"), + parentID: parentID as any, + time: { created: Date.now() }, + finish: "end_turn", + }) +} + +function text(sessionID: string, messageID: string, content: string) { + return Session.updatePart({ + id: PartID.ascending(), + messageID: messageID as any, + sessionID: sessionID as any, + type: "text" as const, + text: content, + }) +} + +function tool(sessionID: string, messageID: string) { + return Session.updatePart({ + id: PartID.ascending(), + messageID: messageID as any, + sessionID: sessionID as any, + type: "tool" as const, + tool: "bash", + callID: "call-1", + state: { status: "completed" as const, input: {}, output: "done", title: "", metadata: {}, time: { start: 0, end: 1 } }, + }) +} + describe("revert + compact workflow", () => { test("should properly handle compact command after revert", async () => { await using tmp = await tmpdir({ git: true }) @@ -283,4 +333,98 @@ describe("revert + compact workflow", () => { }, }) }) + + test("cleanup with partID removes parts from the revert point onward", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const sid = session.id + + const u1 = await user(sid) + const p1 = await text(sid, u1.id, "first part") + const p2 = await tool(sid, u1.id) + const p3 = await text(sid, u1.id, "third part") + + // Set revert state pointing at a specific part + await Session.setRevert({ + sessionID: sid, + revert: { messageID: u1.id, partID: p2.id }, + summary: { additions: 0, deletions: 0, files: 0 }, + }) + + const info = await Session.get(sid) + await SessionRevert.cleanup(info) + + const msgs = await Session.messages({ sessionID: sid }) + expect(msgs.length).toBe(1) + // Only the first part should remain (before the revert partID) + expect(msgs[0].parts.length).toBe(1) + expect(msgs[0].parts[0].id).toBe(p1.id) + + const cleared = await Session.get(sid) + expect(cleared.revert).toBeUndefined() + }, + }) + }) + + test("cleanup removes messages after revert point but keeps earlier ones", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const sid = session.id + + const u1 = await user(sid) + await text(sid, u1.id, "hello") + const a1 = await assistant(sid, u1.id, tmp.path) + await text(sid, a1.id, "hi back") + + const u2 = await user(sid) + await text(sid, u2.id, "second question") + const a2 = await assistant(sid, u2.id, tmp.path) + await text(sid, a2.id, "second answer") + + // Revert from u2 onward + await Session.setRevert({ + sessionID: sid, + revert: { messageID: u2.id }, + summary: { additions: 0, deletions: 0, files: 0 }, + }) + + const info = await Session.get(sid) + await SessionRevert.cleanup(info) + + const msgs = await Session.messages({ sessionID: sid }) + const ids = msgs.map((m) => m.info.id) + expect(ids).toContain(u1.id) + expect(ids).toContain(a1.id) + expect(ids).not.toContain(u2.id) + expect(ids).not.toContain(a2.id) + }, + }) + }) + + test("cleanup is a no-op when session has no revert state", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const sid = session.id + + const u1 = await user(sid) + await text(sid, u1.id, "hello") + + const info = await Session.get(sid) + expect(info.revert).toBeUndefined() + await SessionRevert.cleanup(info) + + const msgs = await Session.messages({ sessionID: sid }) + expect(msgs.length).toBe(1) + }, + }) + }) })