From c8853c57ff1f56775f1f42eece1a5eacb9aa3fed Mon Sep 17 00:00:00 2001
From: Mattt Zmuda
Date: Tue, 19 Mar 2024 04:50:57 -0700
Subject: [PATCH 1/4] Implement Prediction.async_stream method

Signed-off-by: Mattt Zmuda
---
 replicate/prediction.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/replicate/prediction.py b/replicate/prediction.py
index 014a77ce..2d59791e 100644
--- a/replicate/prediction.py
+++ b/replicate/prediction.py
@@ -149,7 +149,7 @@ async def async_wait(self) -> None:
             await asyncio.sleep(self._client.poll_interval)
             await self.async_reload()
 
-    def stream(self) -> Optional[Iterator["ServerSentEvent"]]:
+    def stream(self) -> Iterator["ServerSentEvent"]:
         """
         Stream the prediction output.
 
@@ -168,6 +168,31 @@ def stream(self) -> Optional[Iterator["ServerSentEvent"]]:
         with self._client._client.stream("GET", url, headers=headers) as response:
             yield from EventSource(response)
 
+    async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
+        """
+        Stream the prediction output asynchronously.
+
+        Raises:
+            ReplicateError: If the model does not support streaming.
+        """
+
+        # no-op to enforce the use of 'await' when calling this method
+        await asyncio.sleep(0)
+
+        url = self.urls and self.urls.get("stream", None)
+        if not url or not isinstance(url, str):
+            raise ReplicateError("Model does not support streaming")
+
+        headers = {}
+        headers["Accept"] = "text/event-stream"
+        headers["Cache-Control"] = "no-store"
+
+        async with self._client._async_client.stream(
+            "GET", url, headers=headers
+        ) as response:
+            async for event in EventSource(response):
+                yield event
+
     def cancel(self) -> None:
         """
         Cancels a running prediction.
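[Editor's note, not part of the patch series: a minimal sketch of how the `Prediction.async_stream` method added above might be consumed. It assumes a prediction created with `stream=True`, as in the updated tests later in this series; the version string is the canary placeholder used in those tests, not a real model recommendation.]

```python
import asyncio

import replicate


async def main() -> None:
    # Create a prediction and request a streaming URL, mirroring the updated tests.
    prediction = replicate.predictions.create(
        version="30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272",
        input={"text": "Hello"},
        stream=True,
    )

    # Consume server-sent events as they arrive, without blocking the event loop.
    async for event in prediction.async_stream():
        print(str(event), end="")


asyncio.run(main())
```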
"none": + # return + version = "30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" input = { - "prompt": "Please write a haiku about llamas.", + "text": "Hello", } - prediction = replicate.predictions.create(version=version, input=input) - events = [] - for event in prediction.stream(): - events.append(event) + + if async_flag: + async for event in replicate.predictions.create( + version=version, input=input, stream=True + ).async_stream(): + events.append(event) + else: + for event in replicate.predictions.create( + version=version, input=input, stream=True + ).stream(): + events.append(event) assert len(events) > 0 - assert events[0].event == "output" From 333d0900dccc547604bc749d102bd8c5cb739fe9 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 19 Mar 2024 04:53:21 -0700 Subject: [PATCH 3/4] Update README with discussion of prediction.stream() method Signed-off-by: Mattt Zmuda --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index e1d8a885..16e1f8bf 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,19 @@ for event in replicate.stream( print(str(event), end="") ``` +You can also stream the output of a prediction you create. +This is helpful when you want the ID of the prediction separate from its output. + +```python +version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3 +prediction = replicate.predictions.create(version=version, input={ + "prompt": "Please write a haiku about llamas.", +}) + +for event in prediction.stream(): + print(str(event), end="") +``` + For more information, see ["Streaming output"](https://replicate.com/docs/streaming) in Replicate's docs. From 1476928c2d4c454b28f86990c72f31caf8abb329 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 19 Mar 2024 05:05:59 -0700 Subject: [PATCH 4/4] Run streaming tests in CI Pass REPLICATE_API_TOKEN environment variable Signed-off-by: Mattt Zmuda --- .github/workflows/ci.yaml | 3 +++ tests/test_stream.py | 6 ------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5ccf7761..acfb932b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,6 +13,9 @@ jobs: name: "Test Python ${{ matrix.python-version }}" + env: + REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} + timeout-minutes: 10 strategy: diff --git a/tests/test_stream.py b/tests/test_stream.py index e0c71c78..0bf673d6 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -7,9 +7,6 @@ @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) async def test_stream(async_flag, record_mode): - # if record_mode == "none": - # return - model = "replicate/canary:30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" input = { "text": "Hello", @@ -38,9 +35,6 @@ async def test_stream(async_flag, record_mode): @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) async def test_stream_prediction(async_flag, record_mode): - # if record_mode == "none": - # return - version = "30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" input = { "text": "Hello",