diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 77e78619..72bec4a4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -24,6 +24,8 @@ jobs: cache: "pip" - name: Install dependencies run: python -m pip install -r requirements.txt -r requirements-dev.txt . + - name: Lint + run: python -m ruff . - name: Test run: python -m pytest diff --git a/pyproject.toml b/pyproject.toml index 3376263a..24400d23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = { file = "LICENSE" } authors = [{ name = "Replicate, Inc." }] requires-python = ">=3.8" dependencies = ["packaging", "pydantic>1", "requests>2"] -optional-dependencies = { dev = ["black", "pytest", "responses"] } +optional-dependencies = { dev = ["black", "pytest", "responses", "ruff"] } [project.urls] homepage = "https://replicate.com" @@ -19,3 +19,23 @@ repository = "https://github.com/replicate/replicate-python" [tool.pytest.ini_options] testpaths = "tests/" + +[tool.ruff] +select = [ + "E", # pycodestyle error + "F", # Pyflakes + "I", # isort + "W", # pycodestyle warning + "UP", # pyupgrade + "S", # flake8-bandit + "BLE", # flake8-blind-except + "FBT", # flake8-boolean-trap + "B", # flake8-bugbear +] +ignore = [ + "E501", # Line too long + "S113", # Probable use of requests call without timeout +] + +[tool.ruff.per-file-ignores] +"tests/*" = ["S101", "S106"] diff --git a/replicate/base_model.py b/replicate/base_model.py index 51fadf21..164ca045 100644 --- a/replicate/base_model.py +++ b/replicate/base_model.py @@ -1,6 +1,6 @@ from typing import ForwardRef -import pydantic +import pydantic Client = ForwardRef("Client") Collection = ForwardRef("Collection") diff --git a/replicate/collection.py b/replicate/collection.py index 1b9c1368..94766aae 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -35,4 +35,4 @@ def prepare_model(self, attrs): model._collection = self return model else: - raise Exception("Can't create %s from %s" % (self.model.__name__, attrs)) + raise Exception(f"Can't create {self.model.__name__} from {attrs}") diff --git a/replicate/files.py b/replicate/files.py index 82f70c89..c93d411f 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -15,7 +15,7 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: if output_file_prefix is not None: name = getattr(fh, "name", "output") url = output_file_prefix + os.path.basename(name) - resp = requests.put(url, files={"file": fh}) + resp = requests.put(url, files={"file": fh}, timeout=None) resp.raise_for_status() return url diff --git a/replicate/json.py b/replicate/json.py index c23f40cc..bab3b4d8 100644 --- a/replicate/json.py +++ b/replicate/json.py @@ -3,7 +3,6 @@ from types import GeneratorType from typing import Any, Callable - try: import numpy as np # type: ignore diff --git a/replicate/model.py b/replicate/model.py index e56204b2..5d240955 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,5 +1,3 @@ -from typing import List - from replicate.base_model import BaseModel from replicate.collection import Collection from replicate.exceptions import ReplicateException @@ -12,7 +10,7 @@ class Model(BaseModel): def predict(self, *args, **kwargs): raise ReplicateException( - f"The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme" + "The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme" ) @property diff --git a/replicate/prediction.py b/replicate/prediction.py index 1d5df1a6..0e4d2349 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -3,7 +3,7 @@ from replicate.base_model import BaseModel from replicate.collection import Collection -from replicate.exceptions import ModelError, ReplicateException +from replicate.exceptions import ModelError from replicate.files import upload_file from replicate.json import encode_json from replicate.version import Version @@ -92,7 +92,7 @@ def get(self, id: str) -> Prediction: return self.prepare_model(obj) def list(self) -> List[Prediction]: - resp = self._client._request("GET", f"/v1/predictions") + resp = self._client._request("GET", "/v1/predictions") # TODO: paginate predictions = resp.json()["results"] for prediction in predictions: diff --git a/replicate/training.py b/replicate/training.py index 3cbe920f..bf43a465 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -1,13 +1,11 @@ import re -import time -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, List, Optional from replicate.base_model import BaseModel from replicate.collection import Collection -from replicate.exceptions import ModelError, ReplicateException +from replicate.exceptions import ReplicateException from replicate.files import upload_file from replicate.json import encode_json -from replicate.version import Version class Training(BaseModel): @@ -55,7 +53,7 @@ def create( ) if not match: raise ReplicateException( - f"version must be in format username/model_name:version_id" + "version must be in format username/model_name:version_id" ) username = match.group("username") model_name = match.group("model_name") diff --git a/replicate/version.py b/replicate/version.py index cf3376a8..02a5fa5d 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -18,6 +18,7 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: warnings.warn( "version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.", DeprecationWarning, + stacklevel=1, ) prediction = self._client.predictions.create(version=self, input=kwargs) diff --git a/requirements-dev.txt b/requirements-dev.txt index f10b82a5..c3865e8e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -43,6 +43,8 @@ requests==2.28.2 # responses responses==0.23.1 # via replicate (pyproject.toml) +ruff==0.0.261 + # via replicate (pyproject.toml) types-pyyaml==6.0.12.9 # via responses typing-extensions==4.5.0 diff --git a/tests/test_prediction.py b/tests/test_prediction.py index bbfecb9d..a0d08ae9 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,8 +1,6 @@ import responses from responses import matchers -import replicate - from .factories import create_client, create_version @@ -41,7 +39,7 @@ def test_create_works_with_webhooks(): }, ) - prediction = client.predictions.create( + client.predictions.create( version=version, input={"text": "world"}, webhook="https://example.com/webhook", @@ -156,8 +154,8 @@ def test_async_timings(): ) assert prediction.created_at == "2022-04-26T20:00:40.658234Z" - assert prediction.completed_at == None - assert prediction.output == None + assert prediction.completed_at is None + assert prediction.output is None prediction.wait() assert prediction.created_at == "2022-04-26T20:00:40.658234Z" assert prediction.completed_at == "2022-04-26T20:02:27.648305Z" diff --git a/tests/test_version.py b/tests/test_version.py index 59984224..fb08ec50 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -2,9 +2,10 @@ import pytest import responses -from replicate.exceptions import ModelError from responses import matchers +from replicate.exceptions import ModelError + from .factories import ( create_version, create_version_with_iterator_output,