diff --git a/packages/agent/src/server/agent-server.test.ts b/packages/agent/src/server/agent-server.test.ts index c487f8a35..9bbc755ca 100644 --- a/packages/agent/src/server/agent-server.test.ts +++ b/packages/agent/src/server/agent-server.test.ts @@ -13,7 +13,7 @@ import { import { createTestRepo, type TestRepo } from "../test/fixtures/api"; import { createPostHogHandlers } from "../test/mocks/msw-handlers"; import type { TaskRun } from "../types"; -import { AgentServer } from "./agent-server"; +import { AgentServer, SSE_KEEPALIVE_INTERVAL_MS } from "./agent-server"; import { type JwtPayload, SANDBOX_CONNECTION_AUDIENCE } from "./jwt"; interface TestableServer { @@ -274,6 +274,64 @@ describe("AgentServer HTTP Mode", () => { expect(response.status).toBe(200); expect(response.headers.get("content-type")).toBe("text/event-stream"); }, 20000); + + it("emits transport keepalive comments while idle", async () => { + const keepaliveCallback: { current: (() => void) | null } = { + current: null, + }; + const setIntervalSpy = vi + .spyOn(globalThis, "setInterval") + .mockImplementation( + (callback: (_: undefined) => void, timeout?: number) => { + if (timeout === SSE_KEEPALIVE_INTERVAL_MS) { + keepaliveCallback.current = () => callback(undefined); + } + return setTimeout(() => undefined, 60_000); + }, + ); + + let reader: ReadableStreamDefaultReader | null = null; + try { + await createServer().start(); + const token = createToken(); + + const response = await fetch(`http://localhost:${port}/events`, { + headers: { Authorization: `Bearer ${token}` }, + }); + + expect(response.status).toBe(200); + expect(response.body).not.toBeNull(); + reader = response.body?.getReader() ?? null; + expect(reader).not.toBeNull(); + if (!reader) { + throw new Error("Expected SSE response body reader"); + } + + await vi.waitFor(() => + expect(keepaliveCallback.current).not.toBeNull(), + ); + const emitKeepalive = keepaliveCallback.current; + if (!emitKeepalive) { + throw new Error("Expected keepalive callback to be registered"); + } + emitKeepalive(); + + const decoder = new TextDecoder(); + let streamText = ""; + for (let attempts = 0; attempts < 5; attempts++) { + const { done, value } = await reader.read(); + if (done) break; + streamText += decoder.decode(value, { stream: true }); + if (streamText.includes(": keepalive\n\n")) break; + } + + expect(streamText).toContain(": keepalive\n\n"); + expect(streamText).not.toContain('"type":"keepalive"'); + } finally { + await reader?.cancel(); + setIntervalSpy.mockRestore(); + } + }, 20000); }); describe("POST /command", () => { diff --git a/packages/agent/src/server/agent-server.ts b/packages/agent/src/server/agent-server.ts index 18acf4506..357ceef72 100644 --- a/packages/agent/src/server/agent-server.ts +++ b/packages/agent/src/server/agent-server.ts @@ -73,6 +73,8 @@ const errorWithClassificationSchema = z.object({ type MessageCallback = (message: unknown) => void; +export const SSE_KEEPALIVE_INTERVAL_MS = 25_000; + class NdJsonTap { private decoder = new TextDecoder(); private buffer = ""; @@ -329,41 +331,73 @@ export class AgentServer { ); } + let keepaliveInterval: ReturnType | null = null; + const clearKeepalive = (): void => { + if (keepaliveInterval) { + clearInterval(keepaliveInterval); + keepaliveInterval = null; + } + }; + const stream = new ReadableStream({ start: async (controller) => { - const sseController: SseController = { + let sseController: SseController | null = null; + const encoder = new TextEncoder(); + const detachCurrentSseController = (): void => { + if (sseController) { + this.detachSseController(sseController); + } + }; + const enqueueSseFrame = (frame: string): void => { + try { + controller.enqueue(encoder.encode(frame)); + } catch { + clearKeepalive(); + detachCurrentSseController(); + } + }; + + sseController = { send: (data: unknown) => { - try { - controller.enqueue( - new TextEncoder().encode(`data: ${JSON.stringify(data)}\n\n`), - ); - } catch { - this.detachSseController(sseController); - } + enqueueSseFrame(`data: ${JSON.stringify(data)}\n\n`); }, close: () => { try { + clearKeepalive(); controller.close(); } catch { - this.detachSseController(sseController); + detachCurrentSseController(); } }, }; - if (!this.session || this.session.payload.run_id !== payload.run_id) { - await this.initializeSession(payload, sseController); - } else { - this.session.sseController = sseController; - this.session.hasDesktopConnected = true; - this.replayPendingEvents(); - } + keepaliveInterval = setInterval(() => { + enqueueSseFrame(": keepalive\n\n"); + }, SSE_KEEPALIVE_INTERVAL_MS); + + try { + if ( + !this.session || + this.session.payload.run_id !== payload.run_id + ) { + await this.initializeSession(payload, sseController); + } else { + this.session.sseController = sseController; + this.session.hasDesktopConnected = true; + this.replayPendingEvents(); + } - this.sendSseEvent(sseController, { - type: "connected", - run_id: payload.run_id, - }); + this.sendSseEvent(sseController, { + type: "connected", + run_id: payload.run_id, + }); + } catch (error) { + clearKeepalive(); + throw error; + } }, cancel: () => { + clearKeepalive(); this.logger.debug("SSE connection closed"); if (this.session?.sseController) { this.session.sseController = null;