From 65a3e7efc399150104ca3932cb5b9fc1407ea376 Mon Sep 17 00:00:00 2001 From: Ainur Timerbaev Date: Sun, 7 Jan 2024 16:59:46 +0000 Subject: [PATCH 1/5] Add async_wait method to Prediction class Signed-off-by: John Doe --- replicate/prediction.py | 9 +++++++++ replicate/run.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index d4edf735..54b0278f 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,3 +1,4 @@ +import asyncio import re import time from dataclasses import dataclass @@ -127,6 +128,14 @@ def wait(self) -> None: time.sleep(self._client.poll_interval) self.reload() + async def async_wait(self) -> None: + """ + Wait for prediction to finish. + """ + while self.status not in ["succeeded", "failed", "canceled"]: + await asyncio.sleep(self._client.poll_interval) + self.reload() + def stream(self) -> Optional[Iterator["ServerSentEvent"]]: """ Stream the prediction output. diff --git a/replicate/run.py b/replicate/run.py index 6bbab588..e1b421db 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -85,7 +85,7 @@ async def async_run( if version and (iterator := _make_output_iterator(version, prediction)): return iterator - prediction.wait() + await prediction.async_wait() if prediction.status == "failed": raise ModelError(prediction.error) From 2c69f344d4c5380f137c3580626c362092fd81de Mon Sep 17 00:00:00 2001 From: John Doe Date: Sun, 14 Jan 2024 11:04:07 +0000 Subject: [PATCH 2/5] Refactor Prediction class to use async/await for reloading Signed-off-by: John Doe --- replicate/prediction.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index 54b0278f..62b21bc4 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -134,7 +134,7 @@ async def async_wait(self) -> None: """ while self.status not in ["succeeded", "failed", "canceled"]: await asyncio.sleep(self._client.poll_interval) - self.reload() + await self.async_reload() def stream(self) -> Optional[Iterator["ServerSentEvent"]]: """ @@ -173,6 +173,15 @@ def reload(self) -> None: for name, value in updated.dict().items(): setattr(self, name, value) + async def async_reload(self) -> None: + """ + Load this prediction from the server. + """ + + updated = await self._client.predictions.async_get(self.id) + for name, value in updated.dict().items(): + setattr(self, name, value) + def output_iterator(self) -> Iterator[Any]: """ Return an iterator of the prediction output. From 439ccae970983028a47ec2f7f8c9b8a564921b02 Mon Sep 17 00:00:00 2001 From: John Doe Date: Sun, 14 Jan 2024 11:07:27 +0000 Subject: [PATCH 3/5] Refactor async_get method call in run.py Signed-off-by: John Doe --- replicate/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/run.py b/replicate/run.py index e1b421db..a957f9a3 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -80,7 +80,7 @@ async def async_run( ) if not version and (owner and name and version_id): - version = Versions(client, model=(owner, name)).get(version_id) + version = await Versions(client, model=(owner, name)).async_get(version_id) if version and (iterator := _make_output_iterator(version, prediction)): return iterator From f6f2a5f94e8aab252616bb6e63936f749fec0ae4 Mon Sep 17 00:00:00 2001 From: Mattt Date: Mon, 22 Jan 2024 18:37:07 -0800 Subject: [PATCH 4/5] Apply suggestions from code review --- replicate/prediction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index 62b21bc4..4ab0c72d 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -130,7 +130,7 @@ def wait(self) -> None: async def async_wait(self) -> None: """ - Wait for prediction to finish. + Wait for prediction to finish asynchronously. """ while self.status not in ["succeeded", "failed", "canceled"]: await asyncio.sleep(self._client.poll_interval) @@ -175,7 +175,7 @@ def reload(self) -> None: async def async_reload(self) -> None: """ - Load this prediction from the server. + Load this prediction from the server asynchronously. """ updated = await self._client.predictions.async_get(self.id) From 9b927f005af0f14a37f46871e9b04724350f2d9e Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 22 Jan 2024 18:39:23 -0800 Subject: [PATCH 5/5] Formatting --- replicate/prediction.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/replicate/prediction.py b/replicate/prediction.py index 4ab0c72d..be2ceffe 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -115,6 +115,7 @@ def progress(self) -> Optional[Progress]: """ The progress of the prediction, if available. """ + if self.logs is None or self.logs == "": return None @@ -124,6 +125,7 @@ def wait(self) -> None: """ Wait for prediction to finish. """ + while self.status not in ["succeeded", "failed", "canceled"]: time.sleep(self._client.poll_interval) self.reload() @@ -132,6 +134,7 @@ async def async_wait(self) -> None: """ Wait for prediction to finish asynchronously. """ + while self.status not in ["succeeded", "failed", "canceled"]: await asyncio.sleep(self._client.poll_interval) await self.async_reload()