diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 72bec4a4..a2992dae 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -23,9 +23,15 @@ jobs: python-version: ${{ matrix.python-version }} cache: "pip" - name: Install dependencies - run: python -m pip install -r requirements.txt -r requirements-dev.txt . + run: | + python -m pip install -r requirements.txt -r requirements-dev.txt . + yes | python -m mypy --install-types replicate || true + - name: Lint - run: python -m ruff . + run: | + python -m mypy replicate + python -m ruff . + python -m black --check . - name: Test run: python -m pytest diff --git a/pyproject.toml b/pyproject.toml index 24400d23..46518389 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,13 @@ license = { file = "LICENSE" } authors = [{ name = "Replicate, Inc." }] requires-python = ">=3.8" dependencies = ["packaging", "pydantic>1", "requests>2"] -optional-dependencies = { dev = ["black", "pytest", "responses", "ruff"] } +optional-dependencies = { dev = [ + "black", + "mypy", + "pytest", + "responses", + "ruff", +] } [project.urls] homepage = "https://replicate.com" @@ -31,11 +37,21 @@ select = [ "BLE", # flake8-blind-except "FBT", # flake8-boolean-trap "B", # flake8-bugbear + "ANN", # flake8-annotations ] ignore = [ - "E501", # Line too long - "S113", # Probable use of requests call without timeout + "E501", # Line too long + "S113", # Probable use of requests call without timeout + "ANN001", # Missing type annotation for function argument + "ANN002", # Missing type annotation for `*args` + "ANN003", # Missing type annotation for `**kwargs` + "ANN101", # Missing type annotation for self in method + "ANN102", # Missing type annotation for cls in classmethod ] [tool.ruff.per-file-ignores] -"tests/*" = ["S101", "S106"] +"tests/*" = [ + "S101", # Use of assert + "S106", # Possible use of hard-coded password function arguments + "ANN201", # Missing return type annotation for public function +] diff --git a/replicate/base_model.py b/replicate/base_model.py index 164ca045..b3cf1d48 100644 --- a/replicate/base_model.py +++ b/replicate/base_model.py @@ -1,9 +1,10 @@ -from typing import ForwardRef +from typing import TYPE_CHECKING -import pydantic +if TYPE_CHECKING: + from replicate.client import Client + from replicate.collection import Collection -Client = ForwardRef("Client") -Collection = ForwardRef("Collection") +import pydantic class BaseModel(pydantic.BaseModel): @@ -11,10 +12,12 @@ class BaseModel(pydantic.BaseModel): A base class for representing a single object on the server. """ - _client: Client = pydantic.PrivateAttr() - _collection: Collection = pydantic.PrivateAttr() + id: str + + _client: "Client" = pydantic.PrivateAttr() + _collection: "Collection" = pydantic.PrivateAttr() - def reload(self): + def reload(self) -> None: """ Load this object from the server again. """ diff --git a/replicate/client.py b/replicate/client.py index 75409787..4196d6aa 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -1,7 +1,7 @@ import os import re from json import JSONDecodeError -from typing import Any, Iterator, Union +from typing import Any, Dict, Iterator, Optional, Union import requests from requests.adapters import HTTPAdapter, Retry @@ -14,7 +14,7 @@ class Client: - def __init__(self, api_token=None) -> None: + def __init__(self, api_token: Optional[str] = None) -> None: super().__init__() # Client is instantiated at import time, so do as little as possible. # This includes resolving environment variables -- they might be set programmatically. @@ -30,7 +30,7 @@ def __init__(self, api_token=None) -> None: total=5, backoff_factor=2, # Only retry 500s on GET so we don't unintionally mutute data - method_whitelist=["GET"], + allowed_methods=["GET"], # https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors status_forcelist=[ 429, @@ -54,14 +54,14 @@ def __init__(self, api_token=None) -> None: write_retries = Retry( total=5, backoff_factor=2, - method_whitelist=["POST", "PUT"], + allowed_methods=["POST", "PUT"], # Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data status_forcelist=[429], ) self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries)) self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries)) - def _request(self, method: str, path: str, **kwargs): + def _request(self, method: str, path: str, **kwargs) -> requests.Response: # from requests.Session if method in ["GET", "OPTIONS"]: kwargs.setdefault("allow_redirects", True) @@ -81,13 +81,13 @@ def _request(self, method: str, path: str, **kwargs): raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}") return resp - def _headers(self): + def _headers(self) -> Dict[str, str]: return { "Authorization": f"Token {self._api_token()}", "User-Agent": f"replicate-python@{__version__}", } - def _api_token(self): + def _api_token(self) -> str: token = self.api_token # Evaluate lazily in case environment variable is set with dotenv, or something if token is None: @@ -112,7 +112,7 @@ def predictions(self) -> PredictionCollection: def trainings(self) -> TrainingCollection: return TrainingCollection(client=self) - def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]: + def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: """ Run a model in the format owner/name:version. """ diff --git a/replicate/collection.py b/replicate/collection.py index 94766aae..92e7a88a 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -1,38 +1,54 @@ +import abc +from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast + +if TYPE_CHECKING: + from replicate.client import Client + from replicate.base_model import BaseModel +Model = TypeVar("Model", bound=BaseModel) + -class Collection: +class Collection(abc.ABC, Generic[Model]): """ A base class for representing all objects of a particular type on the server. """ - model: BaseModel = None - - def __init__(self, client=None): + def __init__(self, client: "Client") -> None: self._client = client - def list(self): - raise NotImplementedError + @abc.abstractproperty + def model(self) -> Model: + pass + + @abc.abstractmethod + def list(self) -> List[Model]: + pass - def get(self, key): - raise NotImplementedError + @abc.abstractmethod + def get(self, key: str) -> Model: + pass - def create(self, attrs=None): - raise NotImplementedError + @abc.abstractmethod + def create(self, **kwargs) -> Model: + pass - def prepare_model(self, attrs): + def prepare_model(self, attrs: Union[Model, Dict]) -> Model: """ Create a model from a set of attributes. """ if isinstance(attrs, BaseModel): attrs._client = self._client attrs._collection = self - return attrs - elif isinstance(attrs, dict): + return cast(Model, attrs) + elif ( + isinstance(attrs, dict) and self.model is not None and callable(self.model) + ): model = self.model(**attrs) model._client = self._client model._collection = self return model else: - raise Exception(f"Can't create {self.model.__name__} from {attrs}") + name = self.model.__name__ if hasattr(self.model, "__name__") else "model" + raise Exception(f"Can't create {name} from {attrs}") diff --git a/replicate/files.py b/replicate/files.py index c93d411f..55a6612c 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -2,11 +2,12 @@ import io import mimetypes import os +from typing import Optional import requests -def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: +def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str: """ Lifted straight from cog.files """ diff --git a/replicate/json.py b/replicate/json.py index bab3b4d8..cd0b864e 100644 --- a/replicate/json.py +++ b/replicate/json.py @@ -11,7 +11,9 @@ has_numpy = False -def encode_json(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any: +def encode_json( + obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401 +) -> Any: # noqa: ANN401 """ Returns a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json. """ diff --git a/replicate/model.py b/replicate/model.py index 5d240955..d6b32fcd 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,3 +1,5 @@ +from typing import Dict, List, Union + from replicate.base_model import BaseModel from replicate.collection import Collection from replicate.exceptions import ReplicateException @@ -8,21 +10,34 @@ class Model(BaseModel): username: str name: str - def predict(self, *args, **kwargs): + def predict(self, *args, **kwargs) -> None: raise ReplicateException( "The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme" ) @property - def versions(self): + def versions(self) -> VersionCollection: return VersionCollection(client=self._client, model=self) class ModelCollection(Collection): model = Model + def list(self) -> List[Model]: + raise NotImplementedError() + def get(self, name: str) -> Model: # TODO: fetch model from server # TODO: support permanent IDs username, name = name.split("/") return self.prepare_model({"username": username, "name": name}) + + def create(self, **kwargs) -> Model: + raise NotImplementedError() + + def prepare_model(self, attrs: Union[Model, Dict]) -> Model: + if isinstance(attrs, BaseModel): + attrs.id = f"{attrs.username}/{attrs.name}" + elif isinstance(attrs, dict): + attrs["id"] = f"{attrs['username']}/{attrs['name']}" + return super().prepare_model(attrs) diff --git a/replicate/prediction.py b/replicate/prediction.py index 0e4d2349..a69217d8 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -21,7 +21,7 @@ class Prediction(BaseModel): created_at: Optional[str] completed_at: Optional[str] - def wait(self): + def wait(self) -> None: """Wait for prediction to finish.""" while self.status not in ["succeeded", "failed", "canceled"]: time.sleep(self._client.poll_interval) @@ -47,7 +47,7 @@ def output_iterator(self) -> Iterator[Any]: for output in new_output: yield output - def cancel(self): + def cancel(self) -> None: """Cancel a currently running prediction""" self._client._request("POST", f"/v1/predictions/{self.id}/cancel") @@ -55,13 +55,30 @@ def cancel(self): class PredictionCollection(Collection): model = Prediction - def create( + def list(self) -> List[Prediction]: + resp = self._client._request("GET", "/v1/predictions") + # TODO: paginate + predictions = resp.json()["results"] + for prediction in predictions: + # HACK: resolve this? make it lazy somehow? + del prediction["version"] + return [self.prepare_model(obj) for obj in predictions] + + def get(self, id: str) -> Prediction: + resp = self._client._request("GET", f"/v1/predictions/{id}") + obj = resp.json() + # HACK: resolve this? make it lazy somehow? + del obj["version"] + return self.prepare_model(obj) + + def create( # type: ignore self, version: Version, input: Dict[str, Any], webhook: Optional[str] = None, webhook_completed: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, + **kwargs, ) -> Prediction: input = encode_json(input, upload_file=upload_file) body = { @@ -83,19 +100,3 @@ def create( obj = resp.json() obj["version"] = version return self.prepare_model(obj) - - def get(self, id: str) -> Prediction: - resp = self._client._request("GET", f"/v1/predictions/{id}") - obj = resp.json() - # HACK: resolve this? make it lazy somehow? - del obj["version"] - return self.prepare_model(obj) - - def list(self) -> List[Prediction]: - resp = self._client._request("GET", "/v1/predictions") - # TODO: paginate - predictions = resp.json()["results"] - for prediction in predictions: - # HACK: resolve this? make it lazy somehow? - del prediction["version"] - return [self.prepare_model(obj) for obj in predictions] diff --git a/replicate/schema.py b/replicate/schema.py index 34642cc3..a48d2351 100644 --- a/replicate/schema.py +++ b/replicate/schema.py @@ -3,12 +3,15 @@ # TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth? -def version_has_no_array_type(cog_version): +def version_has_no_array_type(cog_version: str) -> bool: """Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward""" return version.parse(cog_version) < version.parse("0.3.9") -def make_schema_backwards_compatible(schema, version): +def make_schema_backwards_compatible( + schema: dict, + version: str, +) -> dict: """A place to add backwards compatibility logic for our openapi schema""" # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type if version_has_no_array_type(version): diff --git a/replicate/training.py b/replicate/training.py index bf43a465..b60c9613 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -21,7 +21,7 @@ class Training(BaseModel): status: str version: str - def cancel(self): + def cancel(self) -> None: """Cancel a running training""" self._client._request("POST", f"/v1/trainings/{self.id}/cancel") @@ -29,13 +29,25 @@ def cancel(self): class TrainingCollection(Collection): model = Training - def create( + def list(self) -> List[Training]: + raise NotImplementedError() + + def get(self, id: str) -> Training: + resp = self._client._request( + "GET", + f"/v1/trainings/{id}", + ) + obj = resp.json() + return self.prepare_model(obj) + + def create( # type: ignore self, version: str, input: Dict[str, Any], destination: str, webhook: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, + **kwargs, ) -> Training: input = encode_json(input, upload_file=upload_file) body = { @@ -66,11 +78,3 @@ def create( ) obj = resp.json() return self.prepare_model(obj) - - def get(self, id: str) -> Training: - resp = self._client._request( - "GET", - f"/v1/trainings/{id}", - ) - obj = resp.json() - return self.prepare_model(obj) diff --git a/replicate/version.py b/replicate/version.py index 02a5fa5d..d4ed9108 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -1,6 +1,11 @@ import datetime import warnings -from typing import Any, Iterator, List, Union +from typing import TYPE_CHECKING, Any, Iterator, List, Union + +if TYPE_CHECKING: + from replicate.client import Client + from replicate.model import Model + from replicate.base_model import BaseModel from replicate.collection import Collection @@ -12,7 +17,7 @@ class Version(BaseModel): id: str created_at: datetime.datetime cog_version: str - openapi_schema: Any + openapi_schema: dict def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: warnings.warn( @@ -36,7 +41,7 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: raise ModelError(prediction.error) return prediction.output - def get_transformed_schema(self): + def get_transformed_schema(self) -> dict: schema = self.openapi_schema schema = make_schema_backwards_compatible(schema, self.cog_version) return schema @@ -45,7 +50,7 @@ def get_transformed_schema(self): class VersionCollection(Collection): model = Version - def __init__(self, client, model): + def __init__(self, client: "Client", model: "Model") -> None: super().__init__(client=client) self._model = model @@ -59,6 +64,9 @@ def get(self, id: str) -> Version: ) return self.prepare_model(resp.json()) + def create(self, **kwargs) -> Version: + raise NotImplementedError() + def list(self) -> List[Version]: """ Return a list of all versions for a model. diff --git a/requirements-dev.txt b/requirements-dev.txt index c3865e8e..e23d31d9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.11 # by the following command: # -# pip-compile --extra=dev --output-file=requirements-dev.txt pyproject.toml +# pip-compile --extra=dev --output-file=requirements-dev.txt --resolver=backtracking pyproject.toml # attrs==22.2.0 # via pytest @@ -18,8 +18,12 @@ idna==3.4 # via requests iniconfig==2.0.0 # via pytest +mypy==1.2.0 + # via replicate (pyproject.toml) mypy-extensions==1.0.0 - # via black + # via + # black + # mypy packaging==23.0 # via # black @@ -48,7 +52,9 @@ ruff==0.0.261 types-pyyaml==6.0.12.9 # via responses typing-extensions==4.5.0 - # via pydantic + # via + # mypy + # pydantic urllib3==1.26.15 # via # requests