Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ jobs:

name: "Test Python ${{ matrix.python-version }}"

env:
REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}

timeout-minutes: 10

strategy:
Expand Down
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ for event in replicate.stream(
print(str(event), end="")
```

You can also stream the output of a prediction you create.
This is helpful when you want access to the prediction's ID before its output is complete.

```python
version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
prediction = replicate.predictions.create(version=version, input={
"prompt": "Please write a haiku about llamas.",
})

for event in prediction.stream():
print(str(event), end="")
```

For more information, see
["Streaming output"](https://replicate.com/docs/streaming) in Replicate's docs.

Expand Down
27 changes: 26 additions & 1 deletion replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async def async_wait(self) -> None:
await asyncio.sleep(self._client.poll_interval)
await self.async_reload()

def stream(self) -> Optional[Iterator["ServerSentEvent"]]:
def stream(self) -> Iterator["ServerSentEvent"]:
"""
Stream the prediction output.

Expand All @@ -168,6 +168,31 @@ def stream(self) -> Optional[Iterator["ServerSentEvent"]]:
with self._client._client.stream("GET", url, headers=headers) as response:
yield from EventSource(response)

async def async_stream(self) -> AsyncIterator["ServerSentEvent"]:
    """
    Stream the prediction output asynchronously.

    Yields:
        ServerSentEvent: Events emitted by the model as they arrive.
    Raises:
        ReplicateError: If the model does not support streaming.
    """

    # Yield control once so callers must 'await' this generator's first step.
    await asyncio.sleep(0)

    stream_url = None
    if self.urls:
        stream_url = self.urls.get("stream", None)
    if not isinstance(stream_url, str) or not stream_url:
        raise ReplicateError("Model does not support streaming")

    request_headers = {
        "Accept": "text/event-stream",
        "Cache-Control": "no-store",
    }

    async with self._client._async_client.stream(
        "GET", stream_url, headers=request_headers
    ) as response:
        async for event in EventSource(response):
            yield event

def cancel(self) -> None:
"""
Cancels a running prediction.
Expand Down
43 changes: 22 additions & 21 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,56 @@
import pytest

import replicate
from replicate.stream import ServerSentEvent


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_stream(async_flag, record_mode):
if record_mode == "none":
return

version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"

model = "replicate/canary:30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272"
input = {
"prompt": "Please write a haiku about llamas.",
"text": "Hello",
}

events = []

if async_flag:
async for event in await replicate.async_stream(
f"meta/llama-2-70b-chat:{version}",
model,
input=input,
):
events.append(event)
else:
for event in replicate.stream(
f"meta/llama-2-70b-chat:{version}",
model,
input=input,
):
events.append(event)

assert len(events) > 0
assert events[0].event == "output"
assert any(event.event == ServerSentEvent.EventType.OUTPUT for event in events)
assert any(event.event == ServerSentEvent.EventType.DONE for event in events)


@pytest.mark.asyncio
async def test_stream_prediction(record_mode):
if record_mode == "none":
return

version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"

@pytest.mark.parametrize("async_flag", [True, False])
async def test_stream_prediction(async_flag, record_mode):
version = "30e22229542eb3f79d4f945dacb58d32001b02cc313ae6f54eef27904edf3272"
input = {
"prompt": "Please write a haiku about llamas.",
"text": "Hello",
}

prediction = replicate.predictions.create(version=version, input=input)

events = []
for event in prediction.stream():
events.append(event)

if async_flag:
async for event in replicate.predictions.create(
version=version, input=input, stream=True
).async_stream():
events.append(event)
else:
for event in replicate.predictions.create(
version=version, input=input, stream=True
).stream():
events.append(event)

assert len(events) > 0
assert events[0].event == "output"