diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..7bda9a2 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "python.formatting.provider": "black", + "editor.formatOnSave": true +} diff --git a/README.md b/README.md index e69de29..f3a1e1d 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,73 @@ +# Replicate Python client + +This is a Python client for Replicate. It lets you run models from your Python code or Jupyter notebook, and do various other things on Replicate. + +You can run a model and get its output: + +```python +>>> import replicate + +>>> model = replicate.models.get("bfirsh/resnet") +>>> model.predict(open("mystery.jpg")) +[('n02123597', 'Siamese_cat', 0.88293666), ('n02123394', 'Persian_cat', 0.09810519), ('n02123045', 'tabby', 0.0057580653)] +``` + +You can run a model and feed the output into another model: + +```python +>>> image = replicate.models.get("afiaka87/clip-guided-diffusion".predict(prompt="avocado armchair") +>>> upscaled_image = replicate.models.get("jingyunliang/swinir").predict(image=image) +``` + +Run a model and get its output while it's running: + +```python +model = replicate.models.get("pixray/text2image") +for image in model.predict(prompt="san francisco sunset"): + display(image) +``` + +You can start a model and run it in the background: + +```python +>>> prediction = replicate.predictions.create( +... version="kvfrans/clipdraw", +... input={"prompt":"Watercolor painting of an underwater submarine"}) + +>>> prediction + + +>>> prediction.status +Prediction.STATUS_RUNNING + +>>> prediction.logs +["something happened"] + +>>> dict(prediction) +{"id": "...", "status": "running", ...} + +>>> prediction.reload() +>>> prediction.logs +["something happened", "another thing happened"] + +>>> prediction.wait() + +>>> prediction.status +Prediction.STATUS_SUCCESSFUL + +>>> prediction.output + +``` + +You can list all the predictions you've run: + +``` +>>> replicate.predictions.list() +[, ] +``` + +## Install + +```bash +pip install replicate +``` diff --git a/replicate/__init__.py b/replicate/__init__.py new file mode 100644 index 0000000..fc06fd4 --- /dev/null +++ b/replicate/__init__.py @@ -0,0 +1,6 @@ +from .client import Client + +default_client = Client() +models = default_client.models +predictions = default_client.predictions +versions = default_client.versions diff --git a/replicate/base_model.py b/replicate/base_model.py new file mode 100644 index 0000000..56e639c --- /dev/null +++ b/replicate/base_model.py @@ -0,0 +1,32 @@ +from typing import ForwardRef +import pydantic + + +Client = ForwardRef("Client") +Collection = ForwardRef("Collection") + + +class BaseModel(pydantic.BaseModel): + """ + A base class for representing a single object on the server. + """ + + _client: Client = pydantic.PrivateAttr() + _collection: Collection = pydantic.PrivateAttr() + + def __init__(self, attrs=None, client=None, collection=None): + super().__init__(**attrs) + + #: A client pointing at the server that this object is on. + self._client = client + + #: The collection that this model is part of. + self._collection = collection + + def reload(self): + """ + Load this object from the server again. + """ + new_model = self._collection.get(self.id) + for k, v in new_model.dict().items(): + setattr(self, k, v) diff --git a/replicate/client.py b/replicate/client.py new file mode 100644 index 0000000..627f471 --- /dev/null +++ b/replicate/client.py @@ -0,0 +1,46 @@ +import os +import requests + +from replicate.model import ModelCollection +from replicate.prediction import PredictionCollection +from replicate.version import VersionCollection + + +class Client: + def __init__(self, api_token=None) -> None: + super().__init__() + self.api_token = api_token + if self.api_token is None: + self.api_token = os.environ.get("REPLICATE_API_TOKEN") + + self.base_url = "https://api.replicate.com" + + # TODO: make thread safe + self.session = requests.Session() + + def _get(self, path: str, **kwargs): + if "headers" not in kwargs: + kwargs["headers"] = {} + kwargs["headers"].update(self._headers()) + return self.session.get(self.base_url + path, **kwargs) + + def _post(self, path: str, **kwargs): + if "headers" not in kwargs: + kwargs["headers"] = {} + kwargs["headers"].update(self._headers()) + return self.session.post(self.base_url + path, **kwargs) + + def _headers(self): + return {"Authorization": f"Token {self.api_token}"} + + @property + def models(self) -> ModelCollection: + return ModelCollection(client=self) + + @property + def predictions(self) -> PredictionCollection: + return PredictionCollection(client=self) + + @property + def versions(self) -> VersionCollection: + return VersionCollection(client=self) diff --git a/replicate/collection.py b/replicate/collection.py new file mode 100644 index 0000000..01a37a6 --- /dev/null +++ b/replicate/collection.py @@ -0,0 +1,35 @@ +from replicate.base_model import BaseModel + + +class Collection: + """ + A base class for representing all objects of a particular type on the + server. + """ + + model: BaseModel = None + + def __init__(self, client=None): + self._client = client + + def list(self): + raise NotImplementedError + + def get(self, key): + raise NotImplementedError + + def create(self, attrs=None): + raise NotImplementedError + + def prepare_model(self, attrs): + """ + Create a model from a set of attributes. + """ + if isinstance(attrs, BaseModel): + attrs._client = self._client + attrs._collection = self + return attrs + elif isinstance(attrs, dict): + return self.model(attrs=attrs, client=self._client, collection=self) + else: + raise Exception("Can't create %s from %s" % (self.model.__name__, attrs)) diff --git a/replicate/exceptions.py b/replicate/exceptions.py new file mode 100644 index 0000000..26253e0 --- /dev/null +++ b/replicate/exceptions.py @@ -0,0 +1,6 @@ +class ReplicateException(Exception): + pass + + +class ModelError(ReplicateException): + pass diff --git a/replicate/model.py b/replicate/model.py new file mode 100644 index 0000000..efbad66 --- /dev/null +++ b/replicate/model.py @@ -0,0 +1,30 @@ +from typing import List + +from replicate.base_model import BaseModel +from replicate.collection import Collection +from replicate.exceptions import ReplicateException +from replicate.version import Version + + +class Model(BaseModel): + username: str + name: str + + def predict(self, *args, **kwargs): + versions = self._client.versions.list(self) + if not versions: + raise ReplicateException( + "No versions found for model %s/%s" % (self.username, self.name) + ) + latest_version = versions[0] + return latest_version.predict(*args, **kwargs) + + +class ModelCollection(Collection): + model = Model + + def get(self, name: str) -> Model: + # TODO: fetch model from server + # TODO: support permanent IDs + username, name = name.split("/") + return self.prepare_model({"username": username, "name": name}) diff --git a/replicate/prediction.py b/replicate/prediction.py new file mode 100644 index 0000000..d28cf5f --- /dev/null +++ b/replicate/prediction.py @@ -0,0 +1,72 @@ +import time +from typing import Any, Dict, Iterator, List, Optional + +from replicate.base_model import BaseModel +from replicate.collection import Collection +from replicate.exceptions import ModelError, ReplicateException +from replicate.version import Version + + +class Prediction(BaseModel): + id: str + error: Optional[str] + input: Optional[Dict[str, Any]] + logs: Optional[str] + output: Optional[Any] + status: str + version: Optional[Version] + + def wait(self): + """Wait for prediction to finish.""" + while self.status not in ["succeeded", "failed"]: + time.sleep(0.1) + self.reload() + + def output_iterator(self) -> Iterator[Any]: + # TODO: check output is list + previous_output = self.output or [] + while self.status not in ["succeeded", "failed"]: + if self.status == "failed": + raise ModelError(self.error) + output = self.output or [] + new_output = output[len(previous_output) :] + for output in new_output: + yield output + previous_output = output + time.sleep(0.1) + self.reload() + output = self.output or [] + new_output = output[len(previous_output) :] + for output in new_output: + yield output + + +class PredictionCollection(Collection): + model = Prediction + + def create(self, version: Version, input: Dict[str, Any]) -> Prediction: + resp = self._client._post( + "/v1/predictions", json={"version": version.id, "input": input} + ) + resp.raise_for_status() + obj = resp.json() + obj["version"] = version + return self.prepare_model(obj) + + def get(self, id: str) -> Prediction: + resp = self._client._get(f"/v1/predictions/{id}") + resp.raise_for_status() + obj = resp.json() + # HACK: resolve this? make it lazy somehow? + del obj["version"] + return self.prepare_model(obj) + + def list(self) -> List[Prediction]: + resp = self._client._get(f"/v1/predictions") + resp.raise_for_status() + # TODO: paginate + predictions = resp.json()["results"] + for prediction in predictions: + # HACK: resolve this? make it lazy somehow? + del prediction["version"] + return [self.prepare_model(obj) for obj in predictions] diff --git a/replicate/version.py b/replicate/version.py new file mode 100644 index 0000000..ace9dd4 --- /dev/null +++ b/replicate/version.py @@ -0,0 +1,51 @@ +import datetime +from typing import Any, Iterator, List + +from replicate.base_model import BaseModel +from replicate.collection import Collection +from replicate.exceptions import ModelError + + +class Version(BaseModel): + id: str + created_at: datetime.datetime + cog_version: str + openapi_schema: Any + + def predict(self, **kwargs) -> Any | Iterator[Any]: + # TODO: support args + prediction = self._client.predictions.create(version=self, input=kwargs) + # Return an iterator of the output + # FIXME: might just be a list, not an iterator. I wonder if we should differentiate? + if ( + self.openapi_schema["components"]["schemas"]["Output"].get("type") + == "array" + ): + return prediction.output_iterator() + + prediction.wait() + if prediction.status == "failed": + raise ModelError(prediction.error) + return prediction.output + + +class VersionCollection(Collection): + model = Version + + # doesn't exist yet + def get(self, id: str) -> Version: + """ + Get a specific version. + """ + resp = self._client._get(f"/v1/versions/{id}") + resp.raise_for_status() + return self.prepare_model(resp.json()) + + # HACK: model should be a property, or something, and get attached to the version + def list(self, model) -> List[Version]: + """ + Return a list of all versions for a model. + """ + resp = self._client._get(f"/v1/models/{model.username}/{model.name}/versions") + resp.raise_for_status() + return [self.prepare_model(obj) for obj in resp.json()] diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b12b043 --- /dev/null +++ b/setup.py @@ -0,0 +1,16 @@ +# !/usr/bin/env python + +from distutils.core import setup + +setup( + name="replicate", + packages=["replicate"], + version="0.1.0", + description="Python client for Replicate", + author="Replicate, Inc.", + license="BSD", + url="https://github.com/replicate/replicate-python", + python_requires=">=3.6", + install_requires=["requests", "pydantic"], + classifiers=[], +)