From 6097f6dba7882b254fbf11c0b12916f9ec85ef90 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 10 Apr 2023 06:15:17 -0700 Subject: [PATCH 01/15] Enable flake8-annotations check Signed-off-by: Mattt Zmuda --- pyproject.toml | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 24400d23..54bc27e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,11 +31,20 @@ select = [ "BLE", # flake8-blind-except "FBT", # flake8-boolean-trap "B", # flake8-bugbear + "ANN", # flake8-annotations ] ignore = [ - "E501", # Line too long - "S113", # Probable use of requests call without timeout + "E501", # Line too long + "S113", # Probable use of requests call without timeout + "ANN002", # Missing type annotation for `*args` + "ANN003", # Missing type annotation for `**kwargs` + "ANN101", # Missing type annotation for self in method + "ANN102", # Missing type annotation for cls in classmethod ] [tool.ruff.per-file-ignores] -"tests/*" = ["S101", "S106"] +"tests/*" = [ + "S101", # Use of assert + "S106", # Possible use of hard-coded password function arguments + "ANN201", # Missing return type annotation for public function +] From 5a272db095d1ea00db62a59abcfdfb2b5c3a590c Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 10 Apr 2023 06:15:36 -0700 Subject: [PATCH 02/15] Fix annotations and type checking errors Signed-off-by: Mattt Zmuda --- replicate/base_model.py | 15 ++++++++------- replicate/client.py | 12 ++++++------ replicate/collection.py | 35 ++++++++++++++++++++++++----------- replicate/files.py | 3 ++- replicate/model.py | 10 ++++++++-- replicate/prediction.py | 37 +++++++++++++++++++------------------ replicate/training.py | 22 +++++++++++++--------- replicate/version.py | 14 +++++++++++--- 8 files changed, 91 insertions(+), 57 deletions(-) diff --git a/replicate/base_model.py b/replicate/base_model.py index 164ca045..44e0a20c 100644 --- a/replicate/base_model.py +++ b/replicate/base_model.py @@ -1,9 +1,10 @@ -from typing import ForwardRef +from typing import TYPE_CHECKING -import pydantic +if TYPE_CHECKING: + from replicate.client import Client + from replicate.collection import Collection -Client = ForwardRef("Client") -Collection = ForwardRef("Collection") +import pydantic class BaseModel(pydantic.BaseModel): @@ -11,10 +12,10 @@ class BaseModel(pydantic.BaseModel): A base class for representing a single object on the server. """ - _client: Client = pydantic.PrivateAttr() - _collection: Collection = pydantic.PrivateAttr() + _client: "Client" = pydantic.PrivateAttr() + _collection: "Collection" = pydantic.PrivateAttr() - def reload(self): + def reload(self) -> None: """ Load this object from the server again. """ diff --git a/replicate/client.py b/replicate/client.py index 75409787..385f4a9a 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -1,7 +1,7 @@ import os import re from json import JSONDecodeError -from typing import Any, Iterator, Union +from typing import Any, Iterator, Optional, Union import requests from requests.adapters import HTTPAdapter, Retry @@ -14,7 +14,7 @@ class Client: - def __init__(self, api_token=None) -> None: + def __init__(self, api_token: Optional[str] = None) -> None: super().__init__() # Client is instantiated at import time, so do as little as possible. # This includes resolving environment variables -- they might be set programmatically. @@ -61,7 +61,7 @@ def __init__(self, api_token=None) -> None: self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries)) self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries)) - def _request(self, method: str, path: str, **kwargs): + def _request(self, method: str, path: str, **kwargs) -> requests.Response: # from requests.Session if method in ["GET", "OPTIONS"]: kwargs.setdefault("allow_redirects", True) @@ -81,13 +81,13 @@ def _request(self, method: str, path: str, **kwargs): raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}") return resp - def _headers(self): + def _headers(self) -> dict[str, str]: return { "Authorization": f"Token {self._api_token()}", "User-Agent": f"replicate-python@{__version__}", } - def _api_token(self): + def _api_token(self) -> str: token = self.api_token # Evaluate lazily in case environment variable is set with dotenv, or something if token is None: @@ -112,7 +112,7 @@ def predictions(self) -> PredictionCollection: def trainings(self) -> TrainingCollection: return TrainingCollection(client=self) - def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]: + def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: """ Run a model in the format owner/name:version. """ diff --git a/replicate/collection.py b/replicate/collection.py index 94766aae..587aaedd 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -1,27 +1,40 @@ +import abc +from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar + +if TYPE_CHECKING: + from replicate.client import Client + from replicate.base_model import BaseModel +Model = TypeVar("Model", BaseModel, Any) + -class Collection: +class Collection(abc.ABC, Generic[Model]): """ A base class for representing all objects of a particular type on the server. """ - model: BaseModel = None - - def __init__(self, client=None): + def __init__(self, client: "Client") -> None: self._client = client - def list(self): - raise NotImplementedError + @abc.abstractproperty + def model(self) -> Model: + pass + + @abc.abstractmethod + def list(self) -> List[Model]: + pass - def get(self, key): - raise NotImplementedError + @abc.abstractmethod + def get(self, key: str) -> Model: + pass - def create(self, attrs=None): - raise NotImplementedError + @abc.abstractmethod + def create(self, **kwargs) -> Model: + pass - def prepare_model(self, attrs): + def prepare_model(self, attrs: Model | dict) -> Model: """ Create a model from a set of attributes. """ diff --git a/replicate/files.py b/replicate/files.py index c93d411f..55a6612c 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -2,11 +2,12 @@ import io import mimetypes import os +from typing import Optional import requests -def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: +def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str: """ Lifted straight from cog.files """ diff --git a/replicate/model.py b/replicate/model.py index 5d240955..6ddbd037 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -8,21 +8,27 @@ class Model(BaseModel): username: str name: str - def predict(self, *args, **kwargs): + def predict(self, *args, **kwargs) -> None: raise ReplicateException( "The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme" ) @property - def versions(self): + def versions(self) -> VersionCollection: return VersionCollection(client=self._client, model=self) class ModelCollection(Collection): model = Model + def list(self) -> list[Model]: + raise NotImplementedError() + def get(self, name: str) -> Model: # TODO: fetch model from server # TODO: support permanent IDs username, name = name.split("/") return self.prepare_model({"username": username, "name": name}) + + def create(self, **kwargs) -> Model: + raise NotImplementedError() diff --git a/replicate/prediction.py b/replicate/prediction.py index 0e4d2349..a36778e7 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -21,7 +21,7 @@ class Prediction(BaseModel): created_at: Optional[str] completed_at: Optional[str] - def wait(self): + def wait(self) -> None: """Wait for prediction to finish.""" while self.status not in ["succeeded", "failed", "canceled"]: time.sleep(self._client.poll_interval) @@ -47,7 +47,7 @@ def output_iterator(self) -> Iterator[Any]: for output in new_output: yield output - def cancel(self): + def cancel(self) -> None: """Cancel a currently running prediction""" self._client._request("POST", f"/v1/predictions/{self.id}/cancel") @@ -55,6 +55,22 @@ def cancel(self): class PredictionCollection(Collection): model = Prediction + def list(self) -> List[Prediction]: + resp = self._client._request("GET", "/v1/predictions") + # TODO: paginate + predictions = resp.json()["results"] + for prediction in predictions: + # HACK: resolve this? make it lazy somehow? + del prediction["version"] + return [self.prepare_model(obj) for obj in predictions] + + def get(self, id: str) -> Prediction: + resp = self._client._request("GET", f"/v1/predictions/{id}") + obj = resp.json() + # HACK: resolve this? make it lazy somehow? + del obj["version"] + return self.prepare_model(obj) + def create( self, version: Version, @@ -62,6 +78,7 @@ def create( webhook: Optional[str] = None, webhook_completed: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, + **kwargs, ) -> Prediction: input = encode_json(input, upload_file=upload_file) body = { @@ -83,19 +100,3 @@ def create( obj = resp.json() obj["version"] = version return self.prepare_model(obj) - - def get(self, id: str) -> Prediction: - resp = self._client._request("GET", f"/v1/predictions/{id}") - obj = resp.json() - # HACK: resolve this? make it lazy somehow? - del obj["version"] - return self.prepare_model(obj) - - def list(self) -> List[Prediction]: - resp = self._client._request("GET", "/v1/predictions") - # TODO: paginate - predictions = resp.json()["results"] - for prediction in predictions: - # HACK: resolve this? make it lazy somehow? - del prediction["version"] - return [self.prepare_model(obj) for obj in predictions] diff --git a/replicate/training.py b/replicate/training.py index bf43a465..4d310f96 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -21,7 +21,7 @@ class Training(BaseModel): status: str version: str - def cancel(self): + def cancel(self) -> None: """Cancel a running training""" self._client._request("POST", f"/v1/trainings/{self.id}/cancel") @@ -29,6 +29,17 @@ def cancel(self): class TrainingCollection(Collection): model = Training + def list(self) -> List[Training]: + raise NotImplementedError() + + def get(self, id: str) -> Training: + resp = self._client._request( + "GET", + f"/v1/trainings/{id}", + ) + obj = resp.json() + return self.prepare_model(obj) + def create( self, version: str, @@ -36,6 +47,7 @@ def create( destination: str, webhook: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, + **kwargs, ) -> Training: input = encode_json(input, upload_file=upload_file) body = { @@ -66,11 +78,3 @@ def create( ) obj = resp.json() return self.prepare_model(obj) - - def get(self, id: str) -> Training: - resp = self._client._request( - "GET", - f"/v1/trainings/{id}", - ) - obj = resp.json() - return self.prepare_model(obj) diff --git a/replicate/version.py b/replicate/version.py index 02a5fa5d..3fa27a92 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -1,6 +1,11 @@ import datetime import warnings -from typing import Any, Iterator, List, Union +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union + +if TYPE_CHECKING: + from replicate.client import Client + from replicate.model import Model + from replicate.base_model import BaseModel from replicate.collection import Collection @@ -36,7 +41,7 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: raise ModelError(prediction.error) return prediction.output - def get_transformed_schema(self): + def get_transformed_schema(self) -> Any: schema = self.openapi_schema schema = make_schema_backwards_compatible(schema, self.cog_version) return schema @@ -45,7 +50,7 @@ def get_transformed_schema(self): class VersionCollection(Collection): model = Version - def __init__(self, client, model): + def __init__(self, client: Optional["Client"], model: "Model") -> None: super().__init__(client=client) self._model = model @@ -59,6 +64,9 @@ def get(self, id: str) -> Version: ) return self.prepare_model(resp.json()) + def create(self, **kwargs) -> Version: + raise NotImplementedError() + def list(self) -> List[Version]: """ Return a list of all versions for a model. From fc04ef26a2221cfbc0cf24e5d2b554ec944139a9 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 10 Apr 2023 06:16:07 -0700 Subject: [PATCH 03/15] Ignore ANN001 in tests Signed-off-by: Mattt Zmuda --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 54bc27e5..0628d0c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ select = [ ignore = [ "E501", # Line too long "S113", # Probable use of requests call without timeout + "ANN001", # Missing type annotation for function argument "ANN002", # Missing type annotation for `*args` "ANN003", # Missing type annotation for `**kwargs` "ANN101", # Missing type annotation for self in method From 6338bdc037e675c050db9a5aaf860a2a38dab700 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 10 Apr 2023 06:21:09 -0700 Subject: [PATCH 04/15] Fix remaining annotation warnings Signed-off-by: Mattt Zmuda --- replicate/collection.py | 2 +- replicate/json.py | 4 +++- replicate/schema.py | 7 +++++-- replicate/version.py | 8 ++++---- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/replicate/collection.py b/replicate/collection.py index 587aaedd..b38e2a05 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -1,5 +1,5 @@ import abc -from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, List, TypeVar if TYPE_CHECKING: from replicate.client import Client diff --git a/replicate/json.py b/replicate/json.py index bab3b4d8..cd0b864e 100644 --- a/replicate/json.py +++ b/replicate/json.py @@ -11,7 +11,9 @@ has_numpy = False -def encode_json(obj: Any, upload_file: Callable[[io.IOBase], str]) -> Any: +def encode_json( + obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401 +) -> Any: # noqa: ANN401 """ Returns a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json. """ diff --git a/replicate/schema.py b/replicate/schema.py index 34642cc3..a48d2351 100644 --- a/replicate/schema.py +++ b/replicate/schema.py @@ -3,12 +3,15 @@ # TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth? -def version_has_no_array_type(cog_version): +def version_has_no_array_type(cog_version: str) -> bool: """Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward""" return version.parse(cog_version) < version.parse("0.3.9") -def make_schema_backwards_compatible(schema, version): +def make_schema_backwards_compatible( + schema: dict, + version: str, +) -> dict: """A place to add backwards compatibility logic for our openapi schema""" # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type if version_has_no_array_type(version): diff --git a/replicate/version.py b/replicate/version.py index 3fa27a92..d4ed9108 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -1,6 +1,6 @@ import datetime import warnings -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union +from typing import TYPE_CHECKING, Any, Iterator, List, Union if TYPE_CHECKING: from replicate.client import Client @@ -17,7 +17,7 @@ class Version(BaseModel): id: str created_at: datetime.datetime cog_version: str - openapi_schema: Any + openapi_schema: dict def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: warnings.warn( @@ -41,7 +41,7 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: raise ModelError(prediction.error) return prediction.output - def get_transformed_schema(self) -> Any: + def get_transformed_schema(self) -> dict: schema = self.openapi_schema schema = make_schema_backwards_compatible(schema, self.cog_version) return schema @@ -50,7 +50,7 @@ def get_transformed_schema(self) -> Any: class VersionCollection(Collection): model = Version - def __init__(self, client: Optional["Client"], model: "Model") -> None: + def __init__(self, client: "Client", model: "Model") -> None: super().__init__(client=client) self._model = model From 3fd91837bde0b8dc6a8b5031389601d7b9ee75e7 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 10 Apr 2023 06:39:04 -0700 Subject: [PATCH 05/15] Add missing id property to BaseModel Signed-off-by: Mattt Zmuda --- replicate/base_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/replicate/base_model.py b/replicate/base_model.py index 44e0a20c..b3cf1d48 100644 --- a/replicate/base_model.py +++ b/replicate/base_model.py @@ -12,6 +12,8 @@ class BaseModel(pydantic.BaseModel): A base class for representing a single object on the server. """ + id: str + _client: "Client" = pydantic.PrivateAttr() _collection: "Collection" = pydantic.PrivateAttr() From 3d11d07c7c3cf447eb3e0bde013d5ec6fc9820f3 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 10:26:41 -0700 Subject: [PATCH 06/15] Fix type checking for collection model Signed-off-by: Mattt Zmuda --- replicate/collection.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/replicate/collection.py b/replicate/collection.py index b38e2a05..3e5290f8 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -1,12 +1,12 @@ import abc -from typing import TYPE_CHECKING, Any, Generic, List, TypeVar +from typing import TYPE_CHECKING, Generic, List, TypeVar, cast if TYPE_CHECKING: from replicate.client import Client from replicate.base_model import BaseModel -Model = TypeVar("Model", BaseModel, Any) +Model = TypeVar("Model", bound=BaseModel) class Collection(abc.ABC, Generic[Model]): @@ -41,11 +41,14 @@ def prepare_model(self, attrs: Model | dict) -> Model: if isinstance(attrs, BaseModel): attrs._client = self._client attrs._collection = self - return attrs - elif isinstance(attrs, dict): + return cast(Model, attrs) + elif ( + isinstance(attrs, dict) and self.model is not None and callable(self.model) + ): model = self.model(**attrs) model._client = self._client model._collection = self return model else: - raise Exception(f"Can't create {self.model.__name__} from {attrs}") + name = self.model.__name__ if hasattr(self.model, "__name__") else "model" + raise Exception(f"Can't create {name} from {attrs}") From 8a4681506bd0268514ad71505ba91885891318fe Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 10:37:56 -0700 Subject: [PATCH 07/15] Add mypy as development dependency Signed-off-by: Mattt Zmuda --- pyproject.toml | 8 +++++++- requirements-dev.txt | 12 +++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0628d0c8..46518389 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,13 @@ license = { file = "LICENSE" } authors = [{ name = "Replicate, Inc." }] requires-python = ">=3.8" dependencies = ["packaging", "pydantic>1", "requests>2"] -optional-dependencies = { dev = ["black", "pytest", "responses", "ruff"] } +optional-dependencies = { dev = [ + "black", + "mypy", + "pytest", + "responses", + "ruff", +] } [project.urls] homepage = "https://replicate.com" diff --git a/requirements-dev.txt b/requirements-dev.txt index c3865e8e..e23d31d9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.11 # by the following command: # -# pip-compile --extra=dev --output-file=requirements-dev.txt pyproject.toml +# pip-compile --extra=dev --output-file=requirements-dev.txt --resolver=backtracking pyproject.toml # attrs==22.2.0 # via pytest @@ -18,8 +18,12 @@ idna==3.4 # via requests iniconfig==2.0.0 # via pytest +mypy==1.2.0 + # via replicate (pyproject.toml) mypy-extensions==1.0.0 - # via black + # via + # black + # mypy packaging==23.0 # via # black @@ -48,7 +52,9 @@ ruff==0.0.261 types-pyyaml==6.0.12.9 # via responses typing-extensions==4.5.0 - # via pydantic + # via + # mypy + # pydantic urllib3==1.26.15 # via # requests From d4492eef2731840a2ca7ce366928a5e4e6503ae5 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 10:38:27 -0700 Subject: [PATCH 08/15] Ignore type checking for overloaded create methods Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 2 +- replicate/training.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index a36778e7..a69217d8 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -71,7 +71,7 @@ def get(self, id: str) -> Prediction: del obj["version"] return self.prepare_model(obj) - def create( + def create( # type: ignore self, version: Version, input: Dict[str, Any], diff --git a/replicate/training.py b/replicate/training.py index 4d310f96..b60c9613 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -40,7 +40,7 @@ def get(self, id: str) -> Training: obj = resp.json() return self.prepare_model(obj) - def create( + def create( # type: ignore self, version: str, input: Dict[str, Any], From 71a1a46fd1f23923060c11b214c84d3a8d060239 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 10:43:53 -0700 Subject: [PATCH 09/15] Run black and mypy in lint step Signed-off-by: Mattt Zmuda --- .github/workflows/ci.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 72bec4a4..7037436b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,10 @@ jobs: - name: Install dependencies run: python -m pip install -r requirements.txt -r requirements-dev.txt . - name: Lint - run: python -m ruff . + run: | + python -m ruff . + python -m black --check . + python -m mypy replicate - name: Test run: python -m pytest From 85ae6a4143400b8566c16a4623b7e6cb1943b7fe Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 10:50:49 -0700 Subject: [PATCH 10/15] Install mypy types in CI workflow Signed-off-by: Mattt Zmuda --- .github/workflows/ci.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7037436b..937a2df5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -23,12 +23,14 @@ jobs: python-version: ${{ matrix.python-version }} cache: "pip" - name: Install dependencies - run: python -m pip install -r requirements.txt -r requirements-dev.txt . + run: | + python -m pip install -r requirements.txt -r requirements-dev.txt . + python -m mypy --install-types - name: Lint run: | + python -m mypy replicate python -m ruff . python -m black --check . - python -m mypy replicate - name: Test run: python -m pytest From 7c50b892a3fd047206840b0416126228a6374f01 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 10:51:18 -0700 Subject: [PATCH 11/15] Use List and Union from typing Signed-off-by: Mattt Zmuda --- replicate/collection.py | 4 ++-- replicate/model.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/replicate/collection.py b/replicate/collection.py index 3e5290f8..92e7a88a 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -1,5 +1,5 @@ import abc -from typing import TYPE_CHECKING, Generic, List, TypeVar, cast +from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast if TYPE_CHECKING: from replicate.client import Client @@ -34,7 +34,7 @@ def get(self, key: str) -> Model: def create(self, **kwargs) -> Model: pass - def prepare_model(self, attrs: Model | dict) -> Model: + def prepare_model(self, attrs: Union[Model, Dict]) -> Model: """ Create a model from a set of attributes. """ diff --git a/replicate/model.py b/replicate/model.py index 6ddbd037..85b4328f 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,3 +1,5 @@ +from typing import List + from replicate.base_model import BaseModel from replicate.collection import Collection from replicate.exceptions import ReplicateException @@ -21,7 +23,7 @@ def versions(self) -> VersionCollection: class ModelCollection(Collection): model = Model - def list(self) -> list[Model]: + def list(self) -> List[Model]: raise NotImplementedError() def get(self, name: str) -> Model: From 7a9fb4f256e63f5e34f1394f547369f5aa1adbb4 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 10:56:02 -0700 Subject: [PATCH 12/15] Fix CI setup for mypy Signed-off-by: Mattt Zmuda --- .github/workflows/ci.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 937a2df5..a2992dae 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,8 @@ jobs: - name: Install dependencies run: | python -m pip install -r requirements.txt -r requirements-dev.txt . - python -m mypy --install-types + yes | python -m mypy --install-types replicate || true + - name: Lint run: | python -m mypy replicate From ccd9004dc9ed5d41b650946f5e25bed512b094e4 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 13:38:27 -0700 Subject: [PATCH 13/15] Fix pydantic validation error due to missing id field in model Signed-off-by: Mattt Zmuda --- replicate/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/replicate/model.py b/replicate/model.py index 85b4328f..cd8728fe 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List, Union from replicate.base_model import BaseModel from replicate.collection import Collection @@ -34,3 +34,7 @@ def get(self, name: str) -> Model: def create(self, **kwargs) -> Model: raise NotImplementedError() + + def prepare_model(self, attrs: Union[Model, Dict]) -> Model: + attrs["id"] = f"{attrs['username']}/{attrs['name']}" + return super().prepare_model(attrs) From a2d541b1bdb89d4841e2a231cb02863aa3461caf Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 13:45:50 -0700 Subject: [PATCH 14/15] Fix urllib3 DeprecationWarning about using 'method_whitelist' with Retry Signed-off-by: Mattt Zmuda --- replicate/client.py | 4 ++-- replicate/model.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 385f4a9a..0b44f8b0 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -30,7 +30,7 @@ def __init__(self, api_token: Optional[str] = None) -> None: total=5, backoff_factor=2, # Only retry 500s on GET so we don't unintionally mutute data - method_whitelist=["GET"], + allowed_methods=["GET"], # https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors status_forcelist=[ 429, @@ -54,7 +54,7 @@ def __init__(self, api_token: Optional[str] = None) -> None: write_retries = Retry( total=5, backoff_factor=2, - method_whitelist=["POST", "PUT"], + allowed_methods=["POST", "PUT"], # Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data status_forcelist=[429], ) diff --git a/replicate/model.py b/replicate/model.py index cd8728fe..d6b32fcd 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -36,5 +36,8 @@ def create(self, **kwargs) -> Model: raise NotImplementedError() def prepare_model(self, attrs: Union[Model, Dict]) -> Model: - attrs["id"] = f"{attrs['username']}/{attrs['name']}" + if isinstance(attrs, BaseModel): + attrs.id = f"{attrs.username}/{attrs.name}" + elif isinstance(attrs, dict): + attrs["id"] = f"{attrs['username']}/{attrs['name']}" return super().prepare_model(attrs) From 01cedbe34f76a14e3f9db5839b05139926a2b69a Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 11 Apr 2023 13:59:15 -0700 Subject: [PATCH 15/15] error: "dict" is not subscriptable, use "typing.Dict" instead Signed-off-by: Mattt Zmuda --- replicate/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 0b44f8b0..4196d6aa 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -1,7 +1,7 @@ import os import re from json import JSONDecodeError -from typing import Any, Iterator, Optional, Union +from typing import Any, Dict, Iterator, Optional, Union import requests from requests.adapters import HTTPAdapter, Retry @@ -81,7 +81,7 @@ def _request(self, method: str, path: str, **kwargs) -> requests.Response: raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}") return resp - def _headers(self) -> dict[str, str]: + def _headers(self) -> Dict[str, str]: return { "Authorization": f"Token {self._api_token()}", "User-Agent": f"replicate-python@{__version__}",