diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5ccf7761..acfb932b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,6 +13,9 @@ jobs: name: "Test Python ${{ matrix.python-version }}" + env: + REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }} + timeout-minutes: 10 strategy: diff --git a/README.md b/README.md index e1d8a885..16e1f8bf 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,19 @@ for event in replicate.stream( print(str(event), end="") ``` +You can also stream the output of a prediction you create. +This is helpful when you want the ID of the prediction separate from its output. + +```python +version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" +prediction = replicate.predictions.create(version=version, input={ + "prompt": "Please write a haiku about llamas.", +}) + +for event in prediction.stream(): + print(str(event), end="") +``` + For more information, see ["Streaming output"](https://replicate.com/docs/streaming) in Replicate's docs. diff --git a/replicate/prediction.py b/replicate/prediction.py index 014a77ce..2d59791e 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -149,7 +149,7 @@ async def async_wait(self) -> None: await asyncio.sleep(self._client.poll_interval) await self.async_reload() - def stream(self) -> Optional[Iterator["ServerSentEvent"]]: + def stream(self) -> Iterator["ServerSentEvent"]: """ Stream the prediction output. @@ -168,6 +168,31 @@ def stream(self) -> Optional[Iterator["ServerSentEvent"]]: with self._client._client.stream("GET", url, headers=headers) as response: yield from EventSource(response) + async def async_stream(self) -> AsyncIterator["ServerSentEvent"]: + """ + Stream the prediction output asynchronously. + + Raises: + ReplicateError: If the model does not support streaming. 
+ """ + + # no-op to enforce the use of 'await' when calling this method + await asyncio.sleep(0) + + url = self.urls and self.urls.get("stream", None) + if not url or not isinstance(url, str): + raise ReplicateError("Model does not support streaming") + + headers = {} + headers["Accept"] = "text/event-stream" + headers["Cache-Control"] = "no-store" + + async with self._client._async_client.stream( + "GET", url, headers=headers + ) as response: + async for event in EventSource(response): + yield event + def cancel(self) -> None: """ Cancels a running prediction. diff --git a/tests/test_stream.py b/tests/test_stream.py index b9ee2776..0bf673d6 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,55 +1,56 @@ import pytest import replicate +from replicate.stream import ServerSentEvent @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) async def test_stream(async_flag, record_mode): - if record_mode == "none": - return - - version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" - + model = "replicate/canary:30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" input = { - "prompt": "Please write a haiku about llamas.", + "text": "Hello", } events = [] if async_flag: async for event in await replicate.async_stream( - f"meta/llama-2-70b-chat:{version}", + model, input=input, ): events.append(event) else: for event in replicate.stream( - f"meta/llama-2-70b-chat:{version}", + model, input=input, ): events.append(event) assert len(events) > 0 - assert events[0].event == "output" + assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events) + assert any(event.event == ServerSentEvent.EventType.DONE for event in events) @pytest.mark.asyncio -async def test_stream_prediction(record_mode): - if record_mode == "none": - return - - version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" - +@pytest.mark.parametrize("async_flag", [True, False]) +async def 
test_stream_prediction(async_flag, record_mode): + version = "30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272" input = { - "prompt": "Please write a haiku about llamas.", + "text": "Hello", } - prediction = replicate.predictions.create(version=version, input=input) - events = [] - for event in prediction.stream(): - events.append(event) + + if async_flag: + async for event in replicate.predictions.create( + version=version, input=input, stream=True + ).async_stream(): + events.append(event) + else: + for event in replicate.predictions.create( + version=version, input=input, stream=True + ).stream(): + events.append(event) assert len(events) > 0 - assert events[0].event == "output"