Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions packages/appkit/src/plugin/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,16 @@ export abstract class Plugin<
}
};

// stream the result to the client
await this.streamManager.stream(res, asyncWrapperFn, streamConfig);
// stream the result to the client. The effective user key is forwarded
// to the stream manager so that reconnections to existing streamIds are
// bound to the original creator (prevents cross-user stream takeover via
// guessed/leaked IDs).
await this.streamManager.stream(
res,
asyncWrapperFn,
streamConfig,
effectiveUserKey,
);
}

/**
Expand Down
4 changes: 4 additions & 0 deletions packages/appkit/src/plugin/tests/plugin.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ describe("Plugin", () => {
mockResponse,
expect.any(Function),
{},
// The plugin forwards the resolved user key as the 4th argument to
// bind the stream to its creator. The test passes `false` as an
// explicit override, which propagates through `userKey ?? getCurrentUserId()`.
false,
);
});

Expand Down
20 changes: 18 additions & 2 deletions packages/appkit/src/stream/stream-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export class StreamManager {
res: IAppResponse,
handler: (signal: AbortSignal) => AsyncGenerator<any, void, unknown>,
options?: StreamConfig,
ownerKey?: string,
): Promise<void> {
const { streamId } = options || {};

Expand All @@ -45,14 +46,27 @@ export class StreamManager {
// handle reconnection
if (streamId && StreamValidator.validateStreamId(streamId)) {
const existingStream = this.streamRegistry.get(streamId);
// if stream exists, attach to it
if (existingStream) {
// Enforce per-user binding: the stream's owner key must match the
// requesting caller's owner key. This prevents cross-user stream
// takeover via guessed/leaked stream IDs (the SSE registry was
// previously a global lookup with no authorization step).
if (existingStream.ownerKey !== ownerKey) {
this.sseWriter.writeError(
res,
randomUUID(),
"Stream not found or access denied",
SSEErrorCode.STREAM_FORBIDDEN,
);
res.end();
return;
}
return this._attachToExistingStream(res, existingStream, options);
}
}

// if stream does not exist, create a new one
return this._createNewStream(res, handler, options);
return this._createNewStream(res, handler, options, ownerKey);
}

// abort all active operations
Expand Down Expand Up @@ -143,6 +157,7 @@ export class StreamManager {
res: IAppResponse,
handler: (signal: AbortSignal) => AsyncGenerator<any, void, unknown>,
options?: StreamConfig,
ownerKey?: string,
): Promise<void> {
const streamId = options?.streamId ?? randomUUID();

Expand Down Expand Up @@ -177,6 +192,7 @@ export class StreamManager {
// create stream entry
const streamEntry: StreamEntry = {
streamId,
ownerKey,
generator: handler(combinedSignal),
eventBuffer,
clients: new Set([res]),
Expand Down
115 changes: 115 additions & 0 deletions packages/appkit/src/stream/tests/stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,121 @@ describe("StreamManager", () => {
expect(hasNewStream).toBe(false);
});

test("rejects reconnect from a different owner", async () => {
const streamId = "owner-bound-123";

const { mockRes: mockRes1 } = createMockResponse();

async function* generator1() {
for (let i = 0; i < 5; i++) {
yield { type: "message", data: `secret-${i}` };
}
}

await streamManager.stream(
mockRes1 as any,
generator1,
{ streamId },
"user-alice",
);

const { mockRes: mockRes2, events: events2 } = createMockResponse();

async function* generator2() {
yield { type: "should-not-run" };
}

await streamManager.stream(
mockRes2 as any,
generator2,
{ streamId },
"user-bob",
);

// Bob must not see any of Alice's events or replays.
expect(events2.some((e) => e.includes("secret-"))).toBe(false);
expect(events2.some((e) => e.includes("should-not-run"))).toBe(false);

// A STREAM_FORBIDDEN error must be emitted and the connection ended.
expect(events2.some((e) => e.includes("STREAM_FORBIDDEN"))).toBe(true);
expect(mockRes2.end).toHaveBeenCalled();
});

test("allows reconnect from the same owner", async () => {
const streamId = "owner-bound-456";

const { mockRes: mockRes1, events: events1 } = createMockResponse();

async function* generator1() {
yield { type: "message", data: "event-1" };
yield { type: "message", data: "event-2" };
yield { type: "message", data: "event-3" };
}

await streamManager.stream(
mockRes1 as any,
generator1,
{ streamId },
"user-alice",
);

const eventIds = events1
.filter((e) => e.startsWith("id: "))
.map((e) => e.replace("id: ", "").replace("\n", ""));

const { mockRes: mockRes2, events: events2 } = createMockResponse({
"last-event-id": eventIds[1],
});

async function* generator2() {
yield { type: "should-not-run" };
}

await streamManager.stream(
mockRes2 as any,
generator2,
{ streamId },
"user-alice",
);

const replayedData = events2
.filter((e) => e.startsWith("data: "))
.map((e) => e.replace("data: ", "").replace("\n\n", ""));
expect(replayedData.length).toBe(1);
expect(replayedData[0]).toContain("event-3");
expect(events2.some((e) => e.includes("STREAM_FORBIDDEN"))).toBe(false);
});

test("treats a missing owner as a distinct identity from a named owner", async () => {
const streamId = "owner-bound-789";

const { mockRes: mockRes1 } = createMockResponse();

async function* generator1() {
yield { type: "message", data: "scoped" };
}

await streamManager.stream(
mockRes1 as any,
generator1,
{ streamId },
"user-alice",
);

const { mockRes: mockRes2, events: events2 } = createMockResponse();

async function* generator2() {
yield { type: "should-not-run" };
}

// Caller without an owner key must not attach to a stream that
// was created with one.
await streamManager.stream(mockRes2 as any, generator2, { streamId });

expect(events2.some((e) => e.includes("scoped"))).toBe(false);
expect(events2.some((e) => e.includes("STREAM_FORBIDDEN"))).toBe(true);
});

test("should replay successfully when within buffer capacity", async () => {
const streamId = "no-overflow-test-456";

Expand Down
7 changes: 7 additions & 0 deletions packages/appkit/src/stream/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export const SSEErrorCode = {
INVALID_REQUEST: "INVALID_REQUEST",
STREAM_ABORTED: "STREAM_ABORTED",
STREAM_EVICTED: "STREAM_EVICTED",
STREAM_FORBIDDEN: "STREAM_FORBIDDEN",
UPSTREAM_ERROR: "UPSTREAM_ERROR",
} as const satisfies Record<string, string>;

Expand All @@ -35,6 +36,12 @@ export interface BufferedEvent {

export interface StreamEntry {
streamId: string;
/**
* Identifier of the principal that created the stream (e.g. end-user ID
* or service principal user ID). When set, only requests sharing the
* same owner key may reconnect to the stream.
*/
ownerKey?: string;
generator: AsyncGenerator<any, void, unknown>;
eventBuffer: EventRingBuffer;
clients: Set<IAppResponse>;
Expand Down
Loading