From 28ae44e39aa6379a106090851d0e88af0c8470b8 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 23 Jan 2024 14:13:11 -0800 Subject: [PATCH 1/2] Update async_run to use async output iterator --- replicate/client.py | 5 +++-- replicate/prediction.py | 36 +++++++++++++++++++++++++++++++++++- replicate/run.py | 40 ++++++++++++++++++++++++++++++++-------- 3 files changed, 70 insertions(+), 11 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 87da11f7..5cde2e36 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -1,3 +1,4 @@ +import asyncio import os import random import time @@ -151,7 +152,7 @@ async def async_run( ref: str, input: Optional[Dict[str, Any]] = None, **params: Unpack["Predictions.CreatePredictionParams"], - ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 + ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 """ Run a model and wait for its output asynchronously. """ @@ -298,7 +299,7 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: response.close() sleep_for = self._calculate_sleep(attempts_made, response.headers) - time.sleep(sleep_for) + await asyncio.sleep(sleep_for) response = await self._wrapped_transport.handle_async_request(request) # type: ignore diff --git a/replicate/prediction.py b/replicate/prediction.py index be2ceffe..4319856e 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -2,7 +2,17 @@ import re import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterator, + List, + Literal, + Optional, + Union, +) from typing_extensions import NotRequired, TypedDict, Unpack @@ -208,6 +218,30 @@ def output_iterator(self) -> Iterator[Any]: for output in new_output: yield output + async def async_output_iterator(self) -> AsyncIterator[Any]: + """ + Return an asynchronous iterator of the prediction output. + """ + + # TODO: check output is list + previous_output = self.output or [] + while self.status not in ["succeeded", "failed", "canceled"]: + output = self.output or [] + new_output = output[len(previous_output) :] + for item in new_output: + yield item + previous_output = output + await asyncio.sleep(self._client.poll_interval) # pylint: disable=no-member + await self.async_reload() + + if self.status == "failed": + raise ModelError(self.error) + + output = self.output or [] + new_output = output[len(previous_output) :] + for output in new_output: + yield output + class Predictions(Namespace): """ diff --git a/replicate/run.py b/replicate/run.py index a957f9a3..5d5d5634 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -1,4 +1,13 @@ -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Union, +) from typing_extensions import Unpack @@ -59,7 +68,7 @@ async def async_run( ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, **params: Unpack["Predictions.CreatePredictionParams"], -) -> Union[Any, Iterator[Any]]: # noqa: ANN401 +) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 """ Run a model and wait for its output asynchronously. """ @@ -82,7 +91,7 @@ async def async_run( if not version and (owner and name and version_id): version = await Versions(client, model=(owner, name)).async_get(version_id) - if version and (iterator := _make_output_iterator(version, prediction)): + if version and (iterator := await _async_make_output_iterator(version, prediction)): return iterator await prediction.async_wait() @@ -93,17 +102,32 @@ async def async_run( return prediction.output -def _make_output_iterator( - version: Version, prediction: Prediction -) -> Optional[Iterator[Any]]: +def _has_output_iterator_array_type(version: Version) -> bool: schema = make_schema_backwards_compatible( version.openapi_schema, version.cog_version ) - output = schema["components"]["schemas"]["Output"] - if output.get("type") == "array" and output.get("x-cog-array-type") == "iterator": + output = schema.get("components", {}).get("schemas", {}).get("Output", {}) + return ( + output.get("type") == "array" and output.get("x-cog-array-type") == "iterator" + ) + + +def _make_output_iterator( + version: Version, prediction: Prediction +) -> Optional[Iterator[Any]]: + if _has_output_iterator_array_type(version): return prediction.output_iterator() return None +async def _async_make_output_iterator( + version: Version, prediction: Prediction +) -> Optional[AsyncIterator[Any]]: + if _has_output_iterator_array_type(version): + return prediction.async_output_iterator() + + return None + + __all__: List = [] From 7f31323bd9d506ccbae4a9f417ceb6b1399f753b Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sat, 27 Jan 2024 12:27:19 -0800 Subject: [PATCH 2/2] Make helper method returning async iterator synchronous Signed-off-by: Mattt Zmuda --- .vscode/settings.json | 4 ++-- replicate/run.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index a9dcf817..62c9bbbb 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,8 +10,8 @@ "editor.defaultFormatter": "charliermarsh.ruff", "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.fixAll": true, - "source.organizeImports": true + "source.fixAll": "explicit", + "source.organizeImports": "explicit" } }, "python.languageServer": "Pylance", diff --git a/replicate/run.py b/replicate/run.py index 5d5d5634..975cc4dc 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -91,7 +91,7 @@ async def async_run( if not version and (owner and name and version_id): version = await Versions(client, model=(owner, name)).async_get(version_id) - if version and (iterator := await _async_make_output_iterator(version, prediction)): + if version and (iterator := _make_async_output_iterator(version, prediction)): return iterator await prediction.async_wait() @@ -121,7 +121,7 @@ def _make_output_iterator( return None -async def _async_make_output_iterator( +def _make_async_output_iterator( version: Version, prediction: Prediction ) -> Optional[AsyncIterator[Any]]: if _has_output_iterator_array_type(version):