diff --git a/README.md b/README.md index 63175050..c55c4e62 100644 --- a/README.md +++ b/README.md @@ -90,17 +90,13 @@ import replicate # https://replicate.com/meta/llama-2-70b-chat model_version = "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" -tokens = [] for event in replicate.stream( model_version, input={ "prompt": "Please write a haiku about llamas.", }, ): - print(event) - tokens.append(str(event)) - -print("".join(tokens)) + print(str(event), end="") ``` For more information, see diff --git a/replicate/stream.py b/replicate/stream.py index 22cea974..47f9df5f 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -52,7 +52,7 @@ class EventType(Enum): retry: Optional[int] def __str__(self) -> str: - if self.event == "output": + if self.event == ServerSentEvent.EventType.OUTPUT: return self.data return "" @@ -114,7 +114,7 @@ def decode(self, line: str) -> Optional[ServerSentEvent]: return None fieldname, _, value = line.partition(":") - value = value.lstrip() + value = value.removeprefix(" ") if fieldname == "event": if event := ServerSentEvent.EventType(value): @@ -138,26 +138,28 @@ def __iter__(self) -> Iterator[ServerSentEvent]: line = line.rstrip("\n") sse = decoder.decode(line) if sse is not None: - if sse.event == "done": - return - if sse.event == "error": + if sse.event == ServerSentEvent.EventType.ERROR: raise RuntimeError(sse.data) yield sse + if sse.event == ServerSentEvent.EventType.DONE: + return + async def __aiter__(self) -> AsyncIterator[ServerSentEvent]: decoder = EventSource.Decoder() async for line in self.response.aiter_lines(): line = line.rstrip("\n") sse = decoder.decode(line) if sse is not None: - if sse.event == "done": - return - if sse.event == "error": + if sse.event == ServerSentEvent.EventType.ERROR: raise RuntimeError(sse.data) yield sse + if sse.event == ServerSentEvent.EventType.DONE: + return + def stream( client: "Client",