Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"python.formatting.provider": "black",
"editor.formatOnSave": true
}
73 changes: 73 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Replicate Python client

This is a Python client for Replicate. It lets you run models from your Python code or Jupyter notebook, and do various other things on Replicate.

You can run a model and get its output:

```python
>>> import replicate

>>> model = replicate.models.get("bfirsh/resnet")
>>> model.predict(open("mystery.jpg"))
[('n02123597', 'Siamese_cat', 0.88293666), ('n02123394', 'Persian_cat', 0.09810519), ('n02123045', 'tabby', 0.0057580653)]
```

You can run a model and feed the output into another model:

```python
>>> image = replicate.models.get("afiaka87/clip-guided-diffusion".predict(prompt="avocado armchair")
>>> upscaled_image = replicate.models.get("jingyunliang/swinir").predict(image=image)
```

Run a model and get its output while it's running:

```python
model = replicate.models.get("pixray/text2image")
for image in model.predict(prompt="san francisco sunset"):
display(image)
```

You can start a model and run it in the background:

```python
>>> prediction = replicate.predictions.create(
... version="kvfrans/clipdraw",
... input={"prompt":"Watercolor painting of an underwater submarine"})

>>> prediction
<Prediction 38a73e57ddb9 on kvfrans/clipdraw:8b0ba5ab4d85>

>>> prediction.status
Prediction.STATUS_RUNNING

>>> prediction.logs
["something happened"]

>>> dict(prediction)
{"id": "...", "status": "running", ...}

>>> prediction.reload()
>>> prediction.logs
["something happened", "another thing happened"]

>>> prediction.wait()

>>> prediction.status
Prediction.STATUS_SUCCESSFUL

>>> prediction.output
<file: output.png>
```

You can list all the predictions you've run:

```
>>> replicate.predictions.list()
[<Prediction: 8b0ba5ab4d85>, <Prediction: 494900564e8c>]
```

## Install

```bash
pip install replicate
```
6 changes: 6 additions & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .client import Client

default_client = Client()
models = default_client.models
predictions = default_client.predictions
versions = default_client.versions
32 changes: 32 additions & 0 deletions replicate/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import ForwardRef
import pydantic


Client = ForwardRef("Client")
Collection = ForwardRef("Collection")


class BaseModel(pydantic.BaseModel):
"""
A base class for representing a single object on the server.
"""

_client: Client = pydantic.PrivateAttr()
_collection: Collection = pydantic.PrivateAttr()

def __init__(self, attrs=None, client=None, collection=None):
super().__init__(**attrs)

#: A client pointing at the server that this object is on.
self._client = client

#: The collection that this model is part of.
self._collection = collection

def reload(self):
"""
Load this object from the server again.
"""
new_model = self._collection.get(self.id)
for k, v in new_model.dict().items():
setattr(self, k, v)
46 changes: 46 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import requests

from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
from replicate.version import VersionCollection


class Client:
def __init__(self, api_token=None) -> None:
super().__init__()
self.api_token = api_token
if self.api_token is None:
self.api_token = os.environ.get("REPLICATE_API_TOKEN")

self.base_url = "https://api.replicate.com"

# TODO: make thread safe
self.session = requests.Session()

def _get(self, path: str, **kwargs):
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"].update(self._headers())
return self.session.get(self.base_url + path, **kwargs)

def _post(self, path: str, **kwargs):
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"].update(self._headers())
return self.session.post(self.base_url + path, **kwargs)

def _headers(self):
return {"Authorization": f"Token {self.api_token}"}

@property
def models(self) -> ModelCollection:
return ModelCollection(client=self)

@property
def predictions(self) -> PredictionCollection:
return PredictionCollection(client=self)

@property
def versions(self) -> VersionCollection:
return VersionCollection(client=self)
35 changes: 35 additions & 0 deletions replicate/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from replicate.base_model import BaseModel


class Collection:
"""
A base class for representing all objects of a particular type on the
server.
"""

model: BaseModel = None

def __init__(self, client=None):
self._client = client

def list(self):
raise NotImplementedError

def get(self, key):
raise NotImplementedError

def create(self, attrs=None):
raise NotImplementedError

def prepare_model(self, attrs):
"""
Create a model from a set of attributes.
"""
if isinstance(attrs, BaseModel):
attrs._client = self._client
attrs._collection = self
return attrs
elif isinstance(attrs, dict):
return self.model(attrs=attrs, client=self._client, collection=self)
else:
raise Exception("Can't create %s from %s" % (self.model.__name__, attrs))
6 changes: 6 additions & 0 deletions replicate/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class ReplicateException(Exception):
pass


class ModelError(ReplicateException):
pass
30 changes: 30 additions & 0 deletions replicate/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import List

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ReplicateException
from replicate.version import Version


class Model(BaseModel):
username: str
name: str

def predict(self, *args, **kwargs):
versions = self._client.versions.list(self)
if not versions:
raise ReplicateException(
"No versions found for model %s/%s" % (self.username, self.name)
)
latest_version = versions[0]
return latest_version.predict(*args, **kwargs)


class ModelCollection(Collection):
model = Model

def get(self, name: str) -> Model:
# TODO: fetch model from server
# TODO: support permanent IDs
username, name = name.split("/")
return self.prepare_model({"username": username, "name": name})
72 changes: 72 additions & 0 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import time
from typing import Any, Dict, Iterator, List, Optional

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ModelError, ReplicateException
from replicate.version import Version


class Prediction(BaseModel):
id: str
error: Optional[str]
input: Optional[Dict[str, Any]]
logs: Optional[str]
output: Optional[Any]
status: str
version: Optional[Version]

def wait(self):
"""Wait for prediction to finish."""
while self.status not in ["succeeded", "failed"]:
time.sleep(0.1)
self.reload()

def output_iterator(self) -> Iterator[Any]:
# TODO: check output is list
previous_output = self.output or []
while self.status not in ["succeeded", "failed"]:
if self.status == "failed":
raise ModelError(self.error)
output = self.output or []
new_output = output[len(previous_output) :]
for output in new_output:
yield output
previous_output = output
time.sleep(0.1)
self.reload()
output = self.output or []
new_output = output[len(previous_output) :]
for output in new_output:
yield output


class PredictionCollection(Collection):
model = Prediction

def create(self, version: Version, input: Dict[str, Any]) -> Prediction:
resp = self._client._post(
"/v1/predictions", json={"version": version.id, "input": input}
)
resp.raise_for_status()
obj = resp.json()
obj["version"] = version
return self.prepare_model(obj)

def get(self, id: str) -> Prediction:
resp = self._client._get(f"/v1/predictions/{id}")
resp.raise_for_status()
obj = resp.json()
# HACK: resolve this? make it lazy somehow?
del obj["version"]
return self.prepare_model(obj)

def list(self) -> List[Prediction]:
resp = self._client._get(f"/v1/predictions")
resp.raise_for_status()
# TODO: paginate
predictions = resp.json()["results"]
for prediction in predictions:
# HACK: resolve this? make it lazy somehow?
del prediction["version"]
return [self.prepare_model(obj) for obj in predictions]
51 changes: 51 additions & 0 deletions replicate/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import datetime
from typing import Any, Iterator, List

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.exceptions import ModelError


class Version(BaseModel):
id: str
created_at: datetime.datetime
cog_version: str
openapi_schema: Any

def predict(self, **kwargs) -> Any | Iterator[Any]:
# TODO: support args
prediction = self._client.predictions.create(version=self, input=kwargs)
# Return an iterator of the output
# FIXME: might just be a list, not an iterator. I wonder if we should differentiate?
if (
self.openapi_schema["components"]["schemas"]["Output"].get("type")
== "array"
):
return prediction.output_iterator()

prediction.wait()
if prediction.status == "failed":
raise ModelError(prediction.error)
return prediction.output


class VersionCollection(Collection):
model = Version

# doesn't exist yet
def get(self, id: str) -> Version:
"""
Get a specific version.
"""
resp = self._client._get(f"/v1/versions/{id}")
resp.raise_for_status()
return self.prepare_model(resp.json())

# HACK: model should be a property, or something, and get attached to the version
def list(self, model) -> List[Version]:
"""
Return a list of all versions for a model.
"""
resp = self._client._get(f"/v1/models/{model.username}/{model.name}/versions")
resp.raise_for_status()
return [self.prepare_model(obj) for obj in resp.json()]
16 changes: 16 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# !/usr/bin/env python

from distutils.core import setup

setup(
name="replicate",
packages=["replicate"],
version="0.1.0",
description="Python client for Replicate",
author="Replicate, Inc.",
license="BSD",
url="https://github.com/replicate/replicate-python",
python_requires=">=3.6",
install_requires=["requests", "pydantic"],
classifiers=[],
)