From 08174e86c2f83e494ba238d89c25912cc37216f3 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 4 Dec 2023 09:37:03 -0800 Subject: [PATCH 1/3] Fix event type checks Signed-off-by: Mattt Zmuda --- replicate/stream.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/replicate/stream.py b/replicate/stream.py index 22cea974..fb6a066b 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -52,7 +52,7 @@ class EventType(Enum): retry: Optional[int] def __str__(self) -> str: - if self.event == "output": + if self.event == ServerSentEvent.EventType.OUTPUT: return self.data return "" @@ -138,26 +138,28 @@ def __iter__(self) -> Iterator[ServerSentEvent]: line = line.rstrip("\n") sse = decoder.decode(line) if sse is not None: - if sse.event == "done": - return - if sse.event == "error": + if sse.event == ServerSentEvent.EventType.ERROR: raise RuntimeError(sse.data) yield sse + if sse.event == ServerSentEvent.EventType.DONE: + return + async def __aiter__(self) -> AsyncIterator[ServerSentEvent]: decoder = EventSource.Decoder() async for line in self.response.aiter_lines(): line = line.rstrip("\n") sse = decoder.decode(line) if sse is not None: - if sse.event == "done": - return - if sse.event == "error": + if sse.event == ServerSentEvent.EventType.ERROR: raise RuntimeError(sse.data) yield sse + if sse.event == ServerSentEvent.EventType.DONE: + return + def stream( client: "Client", From 5f22f9f51472d2ea50425ca14c5cecb8b6604dff Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 4 Dec 2023 09:37:38 -0800 Subject: [PATCH 2/3] Fix how leading spaces are stripped Signed-off-by: Mattt Zmuda --- replicate/stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/stream.py b/replicate/stream.py index fb6a066b..47f9df5f 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -114,7 +114,7 @@ def decode(self, line: str) -> Optional[ServerSentEvent]: return None fieldname, _, value = line.partition(":") - value = value.lstrip() + value = value.removeprefix(" ") if fieldname == "event": if event := ServerSentEvent.EventType(value): From 003d595601786b96be2d2d73227098321b05d42b Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 4 Dec 2023 09:39:21 -0800 Subject: [PATCH 3/3] Update example usage in README Signed-off-by: Mattt Zmuda --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 63175050..c55c4e62 100644 --- a/README.md +++ b/README.md @@ -90,17 +90,13 @@ import replicate # https://replicate.com/meta/llama-2-70b-chat model_version = "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" -tokens = [] for event in replicate.stream( model_version, input={ "prompt": "Please write a haiku about llamas.", }, ): - print(event) - tokens.append(str(event)) - -print("".join(tokens)) + print(str(event), end="") ``` For more information, see