diff --git a/index.js b/index.js index f4c0e2cf..dd60ceca 100644 --- a/index.js +++ b/index.js @@ -274,7 +274,7 @@ class Replicate { * @yields {ServerSentEvent} Each streamed event from the prediction */ async *stream(ref, options) { - const { wait, ...data } = options; + const { wait, signal, ...data } = options; const identifier = ModelVersionIdentifier.parse(ref); @@ -296,11 +296,10 @@ class Replicate { } if (prediction.urls && prediction.urls.stream) { - const { signal } = options; const stream = createReadableStream({ url: prediction.urls.stream, fetch: this.fetch, - options: { signal }, + ...(signal ? { options: { signal } } : {}), }); yield* streamAsyncIterator(stream); diff --git a/integration/cloudflare-worker/index.js b/integration/cloudflare-worker/index.js index be18d53c..32ec9fc8 100644 --- a/integration/cloudflare-worker/index.js +++ b/integration/cloudflare-worker/index.js @@ -5,12 +5,14 @@ export default { const replicate = new Replicate({ auth: env.REPLICATE_API_TOKEN }); try { + const controller = new AbortController(); const output = replicate.stream( "replicate/canary:30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272", { input: { text: "Colin CloudFlare", }, + signal: controller.signal, } ); const stream = new ReadableStream({