From 363c8f53a06125e1696094d73a4fde7f677cfa03 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Fri, 3 Mar 2023 20:52:29 +0530 Subject: [PATCH 1/2] feat: Add async support Signed-off-by: Diwank Singh Tomer --- .gitignore | 3 ++ replicate/base_model.py | 8 +++++ replicate/client.py | 32 +++++++++++++++++- replicate/collection.py | 9 +++++ replicate/files.py | 67 ++++++++++++++++++++++++++++++++----- replicate/prediction.py | 73 +++++++++++++++++++++++++++++++++++++++++ replicate/version.py | 38 +++++++++++++++++++++ setup.py | 2 +- 8 files changed, 222 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 86feca03..180b8005 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Virtualenv +.venv + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/replicate/base_model.py b/replicate/base_model.py index 51fadf21..30d1bed4 100644 --- a/replicate/base_model.py +++ b/replicate/base_model.py @@ -21,3 +21,11 @@ def reload(self): new_model = self._collection.get(self.id) for k, v in new_model.dict().items(): setattr(self, k, v) + + async def reload_async(self): + """ + Load this object from the server again. + """ + new_model = await self._collection.get_async(self.id) + for k, v in new_model.dict().items(): + setattr(self, k, v) diff --git a/replicate/client.py b/replicate/client.py index 27f7ddc8..8a88fdb9 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -1,6 +1,7 @@ import os from json import JSONDecodeError +import httpx import requests from requests.adapters import HTTPAdapter, Retry @@ -20,6 +21,9 @@ def __init__(self, api_token=None) -> None: "REPLICATE_API_BASE_URL", "https://api.replicate.com" ) + max_retries: int = 5 + self.httpx_transport = httpx.AsyncHTTPTransport(retries=max_retries) + # TODO: make thread safe self.session = requests.Session() @@ -29,7 +33,7 @@ def __init__(self, api_token=None) -> None: # We might just want to enable retry logic for iterators, but for now this is a blunt instrument to # make this reliable. retries = Retry( - total=5, + total=max_retries, backoff_factor=2, # Only retry on GET so we don't unintionally mutute data method_whitelist=["GET"], @@ -56,6 +60,32 @@ def _request(self, method: str, path: str, **kwargs): raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}") return resp + async def _request_async(self, method: str, path: str, **kwargs): + # from requests.Session + if method in ["GET", "OPTIONS"]: + kwargs.setdefault("allow_redirects", True) + if method in ["HEAD"]: + kwargs.setdefault("allow_redirects", False) + kwargs.setdefault("headers", {}) + kwargs["headers"].update(self._headers()) + + async with httpx.AsyncClient( + follow_redirects=True, + transport=self.httpx_transport, + ) as client: + if "allow_redirects" in kwargs: + kwargs.pop("allow_redirects") + + resp = await client.request(method, self.base_url + path, **kwargs) + + if 400 <= resp.status_code < 600: + try: + raise ReplicateError(resp.json()["detail"]) + except (JSONDecodeError, KeyError): + pass + raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}") + return resp + def _headers(self): return { "Authorization": f"Token {self._api_token()}", diff --git a/replicate/collection.py b/replicate/collection.py index 1b9c1368..1340a75c 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -21,6 +21,15 @@ def get(self, key): def create(self, attrs=None): raise NotImplementedError + async def list_async(self): + raise NotImplementedError + + async def get_async(self, key): + raise NotImplementedError + + async def create_async(self, attrs=None): + raise NotImplementedError + def prepare_model(self, attrs): """ Create a model from a set of attributes. diff --git a/replicate/files.py b/replicate/files.py index 82f70c89..8520b1ae 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -3,22 +3,16 @@ import mimetypes import os +import httpx import requests -def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: +def to_data_url(fh: io.IOBase) -> str: """ Lifted straight from cog.files """ fh.seek(0) - if output_file_prefix is not None: - name = getattr(fh, "name", "output") - url = output_file_prefix + os.path.basename(name) - resp = requests.put(url, files={"file": fh}) - resp.raise_for_status() - return url - b = fh.read() # The file handle is strings, not bytes if isinstance(b, str): @@ -31,3 +25,60 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: mime_type = "application/octet-stream" s = encoded_body.decode("utf-8") return f"data:{mime_type};base64,{s}" + + +def upload_file_to_server(fh: io.IOBase, output_file_prefix: str) -> str: + """ + Lifted straight from cog.files + """ + fh.seek(0) + + name = getattr(fh, "name", "output") + url = output_file_prefix + os.path.basename(name) + resp = requests.put(url, files={"file": fh}) + resp.raise_for_status() + return url + + +def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: + """ + Lifted straight from cog.files + """ + fh.seek(0) + + if output_file_prefix is not None: + url = upload_file_to_server(fh, output_file_prefix) + return url + + data_url: str = to_data_url(fh) + return data_url + + +async def upload_file_to_server_async(fh: io.IOBase, output_file_prefix: str) -> str: + """ + Lifted straight from cog.files + """ + fh.seek(0) + + name = getattr(fh, "name", "output") + url = output_file_prefix + os.path.basename(name) + + # httpx does not follow redirects by default + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.put(url, files={"file": fh}) + + return url + + +async def upload_file_async(fh: io.IOBase, output_file_prefix: str = None) -> str: + """ + Lifted straight from cog.files + """ + fh.seek(0) + + if output_file_prefix is not None: + url = await upload_file_to_server_async(fh, output_file_prefix) + return url + + data_url: str = to_data_url(fh) + return data_url diff --git a/replicate/prediction.py b/replicate/prediction.py index 54f5db66..fa0ba88d 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,3 +1,4 @@ +import asyncio import time from typing import Any, Dict, Iterator, List, Optional @@ -24,6 +25,34 @@ def wait(self): time.sleep(0.5) self.reload() + async def wait_async(self): + """Wait for prediction to finish.""" + while self.status not in ["succeeded", "failed", "canceled"]: + await asyncio.sleep(0.5) + await self.reload_async() + + async def output_iterator_async(self) -> Iterator[Any]: + # TODO: check output is list + previous_output = self.output or [] + while self.status not in ["succeeded", "failed", "canceled"]: + output = self.output or [] + new_output = output[len(previous_output) :] + for output in new_output: + yield output + previous_output = output + + await asyncio.sleep(0.5) + await self.reload_async() + + if self.status == "failed": + raise ModelError(self.error) + + output = self.output or [] + new_output = output[len(previous_output) :] + for output in new_output: + yield output + + def output_iterator(self) -> Iterator[Any]: # TODO: check output is list previous_output = self.output or [] @@ -48,6 +77,10 @@ def cancel(self): """Cancel a currently running prediction""" self._client._request("POST", f"/v1/predictions/{self.id}/cancel") + async def cancel_async(self): + """Cancel a currently running prediction""" + await self._client._request_async("POST", f"/v1/predictions/{self.id}/cancel") + class PredictionCollection(Collection): model = Prediction @@ -90,3 +123,43 @@ def list(self) -> List[Prediction]: # HACK: resolve this? make it lazy somehow? del prediction["version"] return [self.prepare_model(obj) for obj in predictions] + + async def create_async( + self, + version: Version, + input: Dict[str, Any], + webhook_completed: Optional[str] = None, + ) -> Prediction: + input = encode_json(input, upload_file=upload_file) + body = { + "version": version.id, + "input": input, + } + if webhook_completed is not None: + body["webhook_completed"] = webhook_completed + + resp = await self._client._request_async( + "POST", + "/v1/predictions", + json=body, + ) + + obj = resp.json() + obj["version"] = version + return self.prepare_model(obj) + + async def get_async(self, id: str) -> Prediction: + resp = await self._client._request_async("GET", f"/v1/predictions/{id}") + obj = resp.json() + # HACK: resolve this? make it lazy somehow? + del obj["version"] + return self.prepare_model(obj) + + async def list_async(self) -> List[Prediction]: + resp = await self._client._request_async("GET", f"/v1/predictions") + # TODO: paginate + predictions = resp.json()["results"] + for prediction in predictions: + # HACK: resolve this? make it lazy somehow? + del prediction["version"] + return [self.prepare_model(obj) for obj in predictions] diff --git a/replicate/version.py b/replicate/version.py index cc4cbd0c..bfdf9a28 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -31,6 +31,25 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: raise ModelError(prediction.error) return prediction.output + + async def predict_async(self, **kwargs) -> Union[Any, Iterator[Any]]: + # TODO: support args + prediction = await self._client.predictions.create_async(version=self, input=kwargs) + # Return an iterator of the output + # FIXME: might just be a list, not an iterator. I wonder if we should differentiate? + schema = self.get_transformed_schema() + output = schema["components"]["schemas"]["Output"] + if ( + output.get("type") == "array" + and output.get("x-cog-array-type") == "iterator" + ): + return prediction.output_iterator_async() + + await prediction.wait_async() + if prediction.status == "failed": + raise ModelError(prediction.error) + return prediction.output + def get_transformed_schema(self): schema = self.openapi_schema schema = make_schema_backwards_compatible(schema, self.cog_version) @@ -44,6 +63,25 @@ def __init__(self, client, model): super().__init__(client=client) self._model = model + # doesn't exist yet + async def get_async(self, id: str) -> Version: + """ + Get a specific version. + """ + resp = await self._client._request_async( + "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}" + ) + return self.prepare_model(resp.json()) + + async def list_async(self) -> List[Version]: + """ + Return a list of all versions for a model. + """ + resp = await self._client._request_async( + "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions" + ) + return [self.prepare_model(obj) for obj in resp.json()["results"]] + # doesn't exist yet def get(self, id: str) -> Version: """ diff --git a/setup.py b/setup.py index 826f2483..62971aa6 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,6 @@ license="BSD", url="https://github.com/replicate/replicate-python", python_requires=">=3.6", - install_requires=["requests", "pydantic", "packaging"], + install_requires=["requests", "pydantic", "packaging", "httpx"], classifiers=[], ) From 648e351a1f2519595bed3dcc72094ebf70aaf82f Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Sat, 25 Mar 2023 00:12:21 +0530 Subject: [PATCH 2/2] feat: Add a simple test Signed-off-by: Diwank Singh Tomer --- requirements-dev.txt | 1 + tests/test_async.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 tests/test_async.py diff --git a/requirements-dev.txt b/requirements-dev.txt index b3d9b4fa..a8b62281 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ packaging==21.3 pytest==7.1.2 +pytest-asyncio==0.21.0 responses==0.21.0 diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 00000000..b70e9243 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,20 @@ +import pytest + +import replicate + +@pytest.mark.asyncio +async def test_async_client(): + model = replicate.models.get("creatorrr/instructor-large") + version = await model.versions.get_async("bd2701dac1aea9d598bda71e6ae56b204287c0a79e2cadf96b1393127d044495") + + inputs = { + # Text to embed + 'text': "Hello world! How are you doing?", + + # Embedding instruction + 'instruction': "Represent the following text", + } + + output = await version.predict_async(**inputs) + + assert output["result"]