diff --git a/replicate/prediction.py b/replicate/prediction.py index d4edf735..be2ceffe 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,3 +1,4 @@ +import asyncio import re import time from dataclasses import dataclass @@ -114,6 +115,7 @@ def progress(self) -> Optional[Progress]: """ The progress of the prediction, if available. """ + if self.logs is None or self.logs == "": return None @@ -123,10 +125,20 @@ def wait(self) -> None: """ Wait for prediction to finish. """ + while self.status not in ["succeeded", "failed", "canceled"]: time.sleep(self._client.poll_interval) self.reload() + async def async_wait(self) -> None: + """ + Wait for prediction to finish asynchronously. + """ + + while self.status not in ["succeeded", "failed", "canceled"]: + await asyncio.sleep(self._client.poll_interval) + await self.async_reload() + def stream(self) -> Optional[Iterator["ServerSentEvent"]]: """ Stream the prediction output. @@ -164,6 +176,15 @@ def reload(self) -> None: for name, value in updated.dict().items(): setattr(self, name, value) + async def async_reload(self) -> None: + """ + Load this prediction from the server asynchronously. + """ + + updated = await self._client.predictions.async_get(self.id) + for name, value in updated.dict().items(): + setattr(self, name, value) + def output_iterator(self) -> Iterator[Any]: """ Return an iterator of the prediction output. diff --git a/replicate/run.py b/replicate/run.py index 6bbab588..a957f9a3 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -80,12 +80,12 @@ async def async_run( ) if not version and (owner and name and version_id): - version = Versions(client, model=(owner, name)).get(version_id) + version = await Versions(client, model=(owner, name)).async_get(version_id) if version and (iterator := _make_output_iterator(version, prediction)): return iterator - prediction.wait() + await prediction.async_wait() if prediction.status == "failed": raise ModelError(prediction.error)