Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll": true,
"source.organizeImports": true
"source.fixAll": "explicit",
"source.organizeImports": "explicit"
}
},
"python.languageServer": "Pylance",
Expand Down
5 changes: 3 additions & 2 deletions replicate/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import random
import time
Expand Down Expand Up @@ -151,7 +152,7 @@ async def async_run(
ref: str,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output asynchronously.
"""
Expand Down Expand Up @@ -298,7 +299,7 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
response.close()

sleep_for = self._calculate_sleep(attempts_made, response.headers)
time.sleep(sleep_for)
await asyncio.sleep(sleep_for)

response = await self._wrapped_transport.handle_async_request(request) # type: ignore

Expand Down
36 changes: 35 additions & 1 deletion replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import re
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Optional,
Union,
)

from typing_extensions import NotRequired, TypedDict, Unpack

Expand Down Expand Up @@ -208,6 +218,30 @@ def output_iterator(self) -> Iterator[Any]:
for output in new_output:
yield output

async def async_output_iterator(self) -> AsyncIterator[Any]:
"""
Return an asynchronous iterator of the prediction output.
"""

# TODO: check output is list
previous_output = self.output or []
while self.status not in ["succeeded", "failed", "canceled"]:
output = self.output or []
new_output = output[len(previous_output) :]
for item in new_output:
yield item
previous_output = output
await asyncio.sleep(self._client.poll_interval) # pylint: disable=no-member
await self.async_reload()

if self.status == "failed":
raise ModelError(self.error)

output = self.output or []
new_output = output[len(previous_output) :]
for output in new_output:
yield output


class Predictions(Namespace):
"""
Expand Down
40 changes: 32 additions & 8 deletions replicate/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Union,
)

from typing_extensions import Unpack

Expand Down Expand Up @@ -59,7 +68,7 @@ async def async_run(
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output asynchronously.
"""
Expand All @@ -82,7 +91,7 @@ async def async_run(
if not version and (owner and name and version_id):
version = await Versions(client, model=(owner, name)).async_get(version_id)

if version and (iterator := _make_output_iterator(version, prediction)):
if version and (iterator := _make_async_output_iterator(version, prediction)):
return iterator

await prediction.async_wait()
Expand All @@ -93,17 +102,32 @@ async def async_run(
return prediction.output


def _make_output_iterator(
version: Version, prediction: Prediction
) -> Optional[Iterator[Any]]:
def _has_output_iterator_array_type(version: Version) -> bool:
schema = make_schema_backwards_compatible(
version.openapi_schema, version.cog_version
)
output = schema["components"]["schemas"]["Output"]
if output.get("type") == "array" and output.get("x-cog-array-type") == "iterator":
output = schema.get("components", {}).get("schemas", {}).get("Output", {})
return (
output.get("type") == "array" and output.get("x-cog-array-type") == "iterator"
)


def _make_output_iterator(
version: Version, prediction: Prediction
) -> Optional[Iterator[Any]]:
if _has_output_iterator_array_type(version):
return prediction.output_iterator()

return None


def _make_async_output_iterator(
version: Version, prediction: Prediction
) -> Optional[AsyncIterator[Any]]:
if _has_output_iterator_array_type(version):
return prediction.async_output_iterator()

return None


__all__: List = []