From bf6b23c3911b7569a5652c00e6b5e776477adb2e Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 15:42:19 +0100 Subject: [PATCH 01/22] python: add trivial type casts to make typechecker happy Signed-off-by: Yorick van Pelt --- python/cog/command/ast_openapi_schema.py | 10 ++++++---- python/cog/json.py | 6 ++---- python/cog/predictor.py | 16 ++++++++++------ python/cog/server/runner.py | 7 ++++--- python/cog/server/webhook.py | 12 +++++------- python/cog/types.py | 3 ++- 6 files changed, 29 insertions(+), 25 deletions(-) diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index 7e886122ea..6b3cc5a139 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -1,6 +1,8 @@ import ast import json import sys +import types +import typing from pathlib import Path try: @@ -320,7 +322,7 @@ def get_value(node: ast.AST) -> "int | float | complex | str | list": if isinstance(node, (ast.List, ast.Tuple)): return [get_value(e) for e in node.elts] if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): - return -get_value(node.operand) + return -typing.cast(int | float | complex, get_value(node.operand)) raise ValueError("Unexpected node type", type(node)) @@ -344,7 +346,7 @@ def get_call_name(call: ast.Call) -> str: raise ValueError("Unexpected node type", type(call), ast.unparse(call)) -def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | Ellipsis]]": +def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisType]]": """Parse argument, default pairs from a file with a predict function""" predict = find(tree, "predict") assert isinstance(predict, ast.FunctionDef) @@ -404,8 +406,8 @@ def resolve_name(node: ast.expr) -> str: if isinstance(node, ast.Name): return node.id if isinstance(node, ast.Index): - # depricated, but needed for py3.8 - return resolve_name(node.value) + # deprecated, but needed for py3.8 + return 
resolve_name(node.value) # type: ignore if isinstance(node, ast.Attribute): return node.attr if isinstance(node, ast.Subscript): diff --git a/python/cog/json.py b/python/cog/json.py index 0e51ba8597..843572eea0 100644 --- a/python/cog/json.py +++ b/python/cog/json.py @@ -29,11 +29,9 @@ def make_encodeable(obj: Any) -> Any: return obj.isoformat() try: import numpy as np # type: ignore - - has_numpy = True except ImportError: - has_numpy = False - if has_numpy: + pass + else: if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 5af98281ae..da9233edd8 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -15,13 +15,14 @@ Optional, Type, Union, + cast, ) from unittest.mock import patch try: from typing import get_args, get_origin except ImportError: # Python < 3.8 - from typing_compat import get_args, get_origin + from typing_compat import get_args, get_origin # type: ignore import yaml from pydantic import BaseModel, Field, create_model @@ -76,18 +77,20 @@ def run_setup(predictor: BasePredictor) -> None: # TODO: Cog{File,Path}.validate(...) methods accept either "real" # paths/files or URLs to those things. In future we can probably tidy this # up a little bit. + # TODO: CogFile/CogPath should have subclasses for each of the subtypes if weights_url: if weights_type == CogFile: - weights = CogFile.validate(weights_url) + weights = cast(CogFile, CogFile.validate(weights_url)) elif weights_type == CogPath: - weights = CogPath.validate(weights_url) + # TODO: So this can be a url. evil! 
+ weights = cast(CogPath, CogPath.validate(weights_url)) else: raise ValueError( f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported" ) elif os.path.exists(weights_path): if weights_type == CogFile: - weights = open(weights_path, "rb") + weights = cast(CogFile, open(weights_path, "rb")) elif weights_type == CogPath: weights = CogPath(weights_path) else: @@ -220,7 +223,7 @@ def validate_input_type(type: Type, name: str) -> None: f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types." ) elif type not in ALLOWED_INPUT_TYPES: - if hasattr(type, "__origin__") and (type.__origin__ is Union or type.__origin__ is list): + if get_origin(type) in (Union, List): for t in get_args(type): validate_input_type(t, name) else: @@ -305,6 +308,7 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: """ predict = get_predict(predictor) signature = inspect.signature(predict) + OutputType: Type[BaseModel] if signature.return_annotation is inspect.Signature.empty: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. 
@@ -327,7 +331,7 @@ def predict( if get_origin(OutputType) is Iterator: # Annotated allows us to attach Field annotations to the list, which we use to mark that this is an iterator # https://pydantic-docs.helpmanual.io/usage/schema/#typingannotated-fields - OutputType = Annotated[List[get_args(OutputType)[0]], Field(**{"x-cog-array-type": "iterator"})] # type: ignore + OutputType: Type[BaseModel] = Annotated[List[get_args(OutputType)[0]], Field(**{"x-cog-array-type": "iterator"})] # type: ignore if not hasattr(OutputType, "__name__") or OutputType.__name__ != "Output": # Wrap the type in a model called "Output" so it is a consistent name in the OpenAPI schema diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 35c27cfb85..43f3274a1b 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -3,7 +3,7 @@ import traceback from asyncio import Task from datetime import datetime, timezone -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, cast import requests import structlog @@ -101,8 +101,9 @@ def predict( event_handler = create_event_handler(prediction, upload_url=upload_url) def handle_cleanup(_: Task) -> None: - if hasattr(prediction.input, "cleanup"): - prediction.input.cleanup() + input = cast(Any, prediction.input) + if hasattr(input, "cleanup"): + input.cleanup() self._response = event_handler.response coro = predict( diff --git a/python/cog/server/webhook.py b/python/cog/server/webhook.py index e55f6c910a..55a51dca61 100644 --- a/python/cog/server/webhook.py +++ b/python/cog/server/webhook.py @@ -13,14 +13,12 @@ def _get_version() -> str: - use_importlib = True try: - from importlib.metadata import version - except ImportError: - use_importlib = False - - try: - if use_importlib: + try: + from importlib.metadata import version + except ImportError: + pass + else: return version("cog") import pkg_resources diff --git a/python/cog/types.py 
b/python/cog/types.py index aa61275b3b..c8f20749fe 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -4,7 +4,8 @@ import pathlib import shutil import tempfile -import urllib +import urllib.parse +import urllib.request from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union import requests From 38313b97ed2dac562595f3d8c65034724061a03b Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 15:42:52 +0100 Subject: [PATCH 02/22] worker.cancel: don't crash on missing child Signed-off-by: Yorick van Pelt --- python/cog/server/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index c51d3c62d6..790eea361d 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -88,7 +88,7 @@ def terminate(self) -> None: self._child.join() def cancel(self) -> None: - if self._allow_cancel and self._child.is_alive(): + if self._allow_cancel and self._child.is_alive() and self._child.pid is not None: os.kill(self._child.pid, signal.SIGUSR1) self._allow_cancel = False From 109a90027bd48fa656067415ca3aefceda279b4f Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 15:43:37 +0100 Subject: [PATCH 03/22] webhook_caller_filtered: expects as filter set() Signed-off-by: Yorick van Pelt --- python/cog/server/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 43f3274a1b..35b8475b96 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -155,7 +155,7 @@ def create_event_handler( webhook_sender = None if webhook is not None: - webhook_sender = webhook_caller_filtered(webhook, events_filter) + webhook_sender = webhook_caller_filtered(webhook, set(events_filter)) file_uploader = None if upload_url is not None: From bd0bffd3b8aacf9eaad7abfeadf28620eed4380e Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 15:44:03 
+0100 Subject: [PATCH 04/22] cog.server.http: don't crash when cpu_count couldn't be determined Signed-off-by: Yorick van Pelt --- python/cog/server/http.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 2fe8e2f29a..8e055b087b 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -361,12 +361,15 @@ def _signal_set_event(signum: Any, frame: Any) -> None: config = load_config() - threads = args.threads + threads: Optional[int] = args.threads if threads is None: if config.get("build", {}).get("gpu", False): threads = 1 else: threads = os.cpu_count() + if threads is None: + log.warn("Unable to determine cpu count, defaulting to 1 thread") + threads = 1 shutdown_event = threading.Event() app = create_app( From b96e171af0edb5ed23e0ec7f8973e4d75311aca3 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 15:45:09 +0100 Subject: [PATCH 05/22] ast_openapi_schema: deal with bytes() defaults Signed-off-by: Yorick van Pelt --- python/cog/command/ast_openapi_schema.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index 6b3cc5a139..b93fd1047c 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -310,7 +310,7 @@ def find(obj: ast.AST, name: str) -> ast.AST: return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name) -def get_value(node: ast.AST) -> "int | float | complex | str | list": +def get_value(node: ast.AST) -> "int | float | complex | str | list | bytes": """Return the value of constant or list of constants""" if isinstance(node, ast.Constant): return node.value @@ -355,12 +355,22 @@ def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisT defaults = [...] 
* (len(args) - len(predict.args.defaults)) + predict.args.defaults return list(zip(args, defaults)) +def _decode_val(val: "int | float | complex | str | list | bytes") -> "int | float | complex | str | list": + if isinstance(val, bytes): + return val.decode("utf-8") + else: + return val def parse_assignment(assignment: ast.AST) -> "tuple[str | None, dict | None]": """Parse an assignment into an OpenAPI object property""" if isinstance(assignment, ast.AnnAssign): assert isinstance(assignment.target, ast.Name) # shouldn't be an Attribute - default = {"default": get_value(assignment.value)} if assignment.value else {} + default = {} + if assignment.value: + try: + default = {"default": _decode_val(get_value(assignment.value))} + except UnicodeDecodeError: + pass return assignment.target.id, { "title": assignment.target.id.replace("_", " ").title(), "type": OPENAPI_TYPES[get_annotation(assignment.annotation)], From e7c273f1c1da66e3f60a18adde4ca7b4b6293012 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 15:45:56 +0100 Subject: [PATCH 06/22] cog.server.http: change PredictionRequest to make pyright happy Signed-off-by: Yorick van Pelt --- python/cog/server/http.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 8e055b087b..2a12183c12 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -74,7 +74,8 @@ def create_app( InputType = get_input_type(predictor) OutputType = get_output_type(predictor) - PredictionRequest = schema.PredictionRequest.with_types(input_type=InputType) + class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)): + pass PredictionResponse = schema.PredictionResponse.with_types( input_type=InputType, output_type=OutputType ) From 2bf8de9845b22fda8b42d805cbc73c4d26e28d39 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 16:12:57 +0100 Subject: [PATCH 07/22] cog.predictor: more possible 
types Signed-off-by: Yorick van Pelt --- python/cog/predictor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/cog/predictor.py b/python/cog/predictor.py index da9233edd8..b4e93aa17c 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterator from pathlib import Path +from types import UnionType from typing import ( Any, Callable, @@ -223,7 +224,7 @@ def validate_input_type(type: Type, name: str) -> None: f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types." ) elif type not in ALLOWED_INPUT_TYPES: - if get_origin(type) in (Union, List): + if get_origin(type) in (Union, List, UnionType, list): for t in get_args(type): validate_input_type(t, name) else: From 9946ffc79c0b91df1bcac33799e2be4cfcf31f20 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 16:26:04 +0100 Subject: [PATCH 08/22] Fix: UnionType is new in 3.10 Signed-off-by: Yorick van Pelt --- python/cog/predictor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cog/predictor.py b/python/cog/predictor.py index b4e93aa17c..24eb26e45e 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterator from pathlib import Path -from types import UnionType +import types from typing import ( Any, Callable, @@ -224,7 +224,7 @@ def validate_input_type(type: Type, name: str) -> None: f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types." 
) elif type not in ALLOWED_INPUT_TYPES: - if get_origin(type) in (Union, List, UnionType, list): + if get_origin(type) in (Union, List, list) or (hasattr(types, "UnionType") and get_origin(type) is types.UnionType): for t in get_args(type): validate_input_type(t, name) else: From 0b2340531fb79c570f327833f113b4773870d1fc Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Tue, 21 Nov 2023 16:42:44 +0100 Subject: [PATCH 09/22] Fix: | syntax is only available in 3.10 Signed-off-by: Yorick van Pelt --- python/cog/command/ast_openapi_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index b93fd1047c..b147e32c3f 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -322,7 +322,7 @@ def get_value(node: ast.AST) -> "int | float | complex | str | list | bytes": if isinstance(node, (ast.List, ast.Tuple)): return [get_value(e) for e in node.elts] if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): - return -typing.cast(int | float | complex, get_value(node.operand)) + return -typing.cast(typing.Union[int, float, complex], get_value(node.operand)) raise ValueError("Unexpected node type", type(node)) From 0ba9855298350edd1c6526d51dfbf61397d00377 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Wed, 22 Nov 2023 12:10:33 +0100 Subject: [PATCH 10/22] cog.types: fix untyped definition Signed-off-by: Yorick van Pelt --- python/cog/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cog/types.py b/python/cog/types.py index c8f20749fe..f503b7d78b 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -261,11 +261,11 @@ def validate(cls, value: Any) -> Iterator: return value -def _len_bytes(s, encoding="utf-8") -> int: +def _len_bytes(s: str, encoding: str="utf-8") -> int: return len(s.encode(encoding)) -def _truncate_filename_bytes(s, length, encoding="utf-8") -> 
str: +def _truncate_filename_bytes(s: str, length: int, encoding: str="utf-8") -> str: """ Truncate a filename to at most `length` bytes, preserving file extension and avoiding text encoding corruption from truncation. From ab66b2a002ebf76b14198a6888ed1fcb261a4979 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Wed, 22 Nov 2023 12:10:44 +0100 Subject: [PATCH 11/22] Add tool.pyright config to pyproject.toml Signed-off-by: Yorick van Pelt --- pyproject.toml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index edad0944d4..fe9ed1491f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,20 @@ disallow_untyped_defs = true no_implicit_optional = false exclude = ["python/tests/"] +[tool.pyright] +# TODO: remove this and bring the codebase inline with the current default +strictParameterNoneValue = false +# legacy behavior, fixed in PEP688 +disableBytesTypePromotions = true +include = ["python"] +exclude = ["python/tests"] +reportMissingParameterType = "error" +reportUnknownLambdaType = "error" +reportUnneccesaryIsInstance = "warning" +reportUnneccesaryComparison = "warning" +reportUnnecessaryContains = "warning" +reportUnusedExpression = "warning" + [tool.setuptools] package-dir = { "" = "python" } From 6264eaf1c92bc88c79e071f3887a4f6306a558aa Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Wed, 22 Nov 2023 14:44:04 +0100 Subject: [PATCH 12/22] python: error on missing type arguments, add them everywhere Signed-off-by: Yorick van Pelt --- pyproject.toml | 1 + python/cog/command/ast_openapi_schema.py | 64 +++++++++++++++--------- python/cog/predictor.py | 18 +++---- python/cog/schema.py | 4 +- python/cog/server/http.py | 23 +++++++-- python/cog/server/response_throttler.py | 3 +- python/cog/server/runner.py | 6 +-- python/cog/server/webhook.py | 6 +-- python/cog/types.py | 2 +- 9 files changed, 81 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fe9ed1491f..4f72aa79ed 
100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ reportUnknownLambdaType = "error" reportUnneccesaryIsInstance = "warning" reportUnneccesaryComparison = "warning" reportUnnecessaryContains = "warning" +reportMissingTypeArgument = "error" reportUnusedExpression = "warning" [tool.setuptools] diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index b147e32c3f..daa4e56bf8 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -309,8 +309,25 @@ def find(obj: ast.AST, name: str) -> ast.AST: """Find a particular named node in a tree""" return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name) +if typing.TYPE_CHECKING: + AstVal: "typing.TypeAlias" = int | float | complex | str | list["AstVal"] | bytes | None + AstValNoBytes: "typing.TypeAlias" = int | float | str | list["AstValNoBytes"] + JSONObject: "typing.TypeAlias" = int | float | str | list["JSONObject"] | "JSONDict" | None + JSONDict: "typing.TypeAlias" = dict[str, "JSONObject"] -def get_value(node: ast.AST) -> "int | float | complex | str | list | bytes": + +def toSerializable(val: "AstVal") -> "JSONObject": + if isinstance(val, bytes): + return val.decode("utf-8") + elif isinstance(val, list): + return [toSerializable(x) for x in val] + elif isinstance(val, complex): + msg = "complex inputs are not supported" + raise ValueError(msg) + else: + return val + +def get_value(node: ast.AST) -> "AstVal": """Return the value of constant or list of constants""" if isinstance(node, ast.Constant): return node.value @@ -355,20 +372,14 @@ def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisT defaults = [...] 
* (len(args) - len(predict.args.defaults)) + predict.args.defaults return list(zip(args, defaults)) -def _decode_val(val: "int | float | complex | str | list | bytes") -> "int | float | complex | str | list": - if isinstance(val, bytes): - return val.decode("utf-8") - else: - return val - -def parse_assignment(assignment: ast.AST) -> "tuple[str | None, dict | None]": +def parse_assignment(assignment: ast.AST) -> "None | tuple[str, JSONObject]": """Parse an assignment into an OpenAPI object property""" if isinstance(assignment, ast.AnnAssign): assert isinstance(assignment.target, ast.Name) # shouldn't be an Attribute default = {} if assignment.value: try: - default = {"default": _decode_val(get_value(assignment.value))} + default = {"default": toSerializable(get_value(assignment.value))} except UnicodeDecodeError: pass return assignment.target.id, { @@ -378,21 +389,21 @@ def parse_assignment(assignment: ast.AST) -> "tuple[str | None, dict | None]": } if isinstance(assignment, ast.Assign): if len(assignment.targets) == 1 and isinstance(assignment.targets[0], ast.Name): - value = get_value(assignment.value) + value = toSerializable(get_value(assignment.value)) return assignment.targets[0].id, { "title": assignment.targets[0].id.replace("_", " ").title(), "type": OPENAPI_TYPES[type(value).__name__], "default": value, } raise ValueError("Unexpected assignment", assignment) - return None, None + return None -def parse_class(classdef: ast.AST) -> dict: +def parse_class(classdef: ast.AST) -> "JSONDict": """Parse a class definition into an OpenAPI object""" assert isinstance(classdef, ast.ClassDef) properties = { - key: property for key, property in map(parse_assignment, classdef.body) if key + assignment[0]: assignment[1] for assignment in map(parse_assignment, classdef.body) if assignment } return { "title": classdef.name, @@ -425,7 +436,7 @@ def resolve_name(node: ast.expr) -> str: raise ValueError("Unexpected node type", type(node), ast.unparse(node)) -def 
parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[dict, dict]": +def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[JSONDict, JSONDict]": predict = find(tree, fn) if not isinstance(predict, ast.FunctionDef): raise ValueError("Could not find predict function") @@ -471,7 +482,7 @@ def predict( format = {"format": "uri"} if name in ("Path", "File") else {} return {}, {"title": "Output", "type": OPENAPI_TYPES.get(name, name), **format} # it must be a custom object - schema = {name: parse_class(find(tree, name))} + schema: "JSONDict" = {name: parse_class(find(tree, name))} return schema, { "title": "Output", "$ref": f"#/components/schemas/{name}", @@ -481,24 +492,29 @@ def predict( KEPT_ATTRS = ("description", "default", "ge", "le", "max_length", "min_length", "regex") -def extract_info(code: str) -> dict: +def extract_info(code: str) -> "JSONDict": """Parse the schemas from a file with a predict function""" tree = ast.parse(code) inputs = {"title": "Input", "type": "object", "properties": {}} required: "list[str]" = [] - schemas: "dict[str, dict]" = {} + schemas: "JSONDict" = {} for arg, default in parse_args(tree): if arg.arg == "self": continue if isinstance(default, ast.Call) and get_call_name(default) == "Input": - kws = {kw.arg: get_value(kw.value) for kw in default.keywords} + kws = {} + for kw in default.keywords: + if kw.arg is None: + msg = "unknown argument for Input" + raise ValueError(msg) + kws[kw.arg] = toSerializable(get_value(kw.value)) elif isinstance(default, (ast.Constant, ast.List, ast.Tuple, ast.Str, ast.Num)): - kws = {"default": get_value(default)} # could be None + kws = {"default": toSerializable(get_value(default))} # could be None elif default == ...: # no default kws = {} else: raise ValueError("Unexpected default value", default) - input: dict = {"x-order": len(inputs["properties"])} + input: "JSONDict" = {"x-order": len(inputs["properties"])} # need to handle other types? 
arg_type = OPENAPI_TYPES.get(get_annotation(arg.annotation), "string") if get_annotation(arg.annotation) in ("Path", "File"): @@ -525,18 +541,20 @@ def extract_info(code: str) -> dict: inputs["required"] = required # List[Path], list[Path], str, Iterator[str], MyOutput, Output return_schema, output = parse_return_annotation(tree, "predict") - schema = json.loads(BASE_SCHEMA) + schema: "JSONDict" = json.loads(BASE_SCHEMA) components = { "Input": inputs, "Output": output, **schemas, **return_schema, } - schema["components"]["schemas"].update(components) + # trust me, typechecker, I know BASE_SCHEMA + x: "JSONDict" = schema["components"]["schemas"] # type: ignore + x.update(components) return schema -def extract_file(fname: "str | Path") -> dict: +def extract_file(fname: "str | Path") -> "JSONObject": return extract_info(open(fname, encoding="utf-8").read()) diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 24eb26e45e..3e041f839b 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -4,10 +4,10 @@ import io import os.path import sys +import types from abc import ABC, abstractmethod from collections.abc import Iterator from pathlib import Path -import types from typing import ( Any, Callable, @@ -44,7 +44,7 @@ Path as CogPath, ) -ALLOWED_INPUT_TYPES = [str, int, float, bool, CogFile, CogPath] +ALLOWED_INPUT_TYPES: List[Type[Any]] = [str, int, float, bool, CogFile, CogPath] class BasePredictor(ABC): @@ -104,7 +104,7 @@ def run_setup(predictor: BasePredictor) -> None: predictor.setup(weights=weights) -def get_weights_type(setup_function: Callable) -> Optional[Any]: +def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]: signature = inspect.signature(setup_function) if "weights" not in signature.parameters: return None @@ -118,7 +118,7 @@ def get_weights_type(setup_function: Callable) -> Optional[Any]: def run_prediction( - predictor: BasePredictor, inputs: Dict[Any, Any], cleanup_functions: List[Callable] + 
predictor: BasePredictor, inputs: Dict[Any, Any], cleanup_functions: List[Callable[[], None]], ) -> Any: """ Run the predictor on the inputs, and append resulting paths @@ -213,18 +213,18 @@ def cleanup(self) -> None: pass -def get_predict(predictor: Any) -> Callable: +def get_predict(predictor: Any) -> Callable[..., Any]: if hasattr(predictor, "predict"): return predictor.predict return predictor -def validate_input_type(type: Type, name: str) -> None: +def validate_input_type(type: Type[Any], name: str) -> None: if type is inspect.Signature.empty: raise TypeError( f"No input type provided for parameter `{name}`. Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types." ) elif type not in ALLOWED_INPUT_TYPES: - if get_origin(type) in (Union, List, list) or (hasattr(types, "UnionType") and get_origin(type) is types.UnionType): + if get_origin(type) in (Union, List, list) or (hasattr(types, "UnionType") and get_origin(type) is types.UnionType): # noqa: E721 for t in get_args(type): validate_input_type(t, name) else: @@ -344,7 +344,7 @@ class Output(BaseModel): return OutputType -def human_readable_type_name(t: Type) -> str: +def human_readable_type_name(t: Type[Any]) -> str: """ Generates a useful-for-humans label for a type. For builtin types, it's just the class name (eg "str" or "int"). For other types, it includes the module (eg "pathlib.Path" or "cog.File"). 
@@ -362,5 +362,5 @@ def human_readable_type_name(t: Type) -> str: return str(t) -def readable_types_list(type_list: List[Type]) -> str: +def readable_types_list(type_list: List[Type[Any]]) -> str: return ", ".join(human_readable_type_name(t) for t in type_list) diff --git a/python/cog/schema.py b/python/cog/schema.py index 22948bd6a8..c11fa77df7 100644 --- a/python/cog/schema.py +++ b/python/cog/schema.py @@ -48,7 +48,7 @@ class PredictionRequest(PredictionBaseModel): ] = WebhookEvent.default_events() @classmethod - def with_types(cls, input_type: t.Type) -> t.Any: + def with_types(cls, input_type: t.Type[t.Any]) -> t.Any: # [compat] Input is implicitly optional -- previous versions of the # Cog HTTP API allowed input to be omitted (e.g. for models that don't # have any inputs). We should consider changing this in future. @@ -74,7 +74,7 @@ class PredictionResponse(PredictionBaseModel): metrics: t.Optional[t.Dict[str, t.Any]] @classmethod - def with_types(cls, input_type: t.Type, output_type: t.Type) -> t.Any: + def with_types(cls, input_type: t.Type[t.Any], output_type: t.Type[t.Any]) -> t.Any: # [compat] Input is implicitly optional -- previous versions of the # Cog HTTP API allowed input to be omitted (e.g. for models that don't # have any inputs). We should consider changing this in future. 
diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 2a12183c12..fae5aa9975 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -9,7 +9,19 @@ import textwrap import threading from enum import Enum, auto, unique -from typing import Any, Callable, Dict, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Optional, + TypeVar, + Union, +) + +if TYPE_CHECKING: + from typing import ParamSpec import structlog import uvicorn @@ -82,9 +94,12 @@ class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType http_semaphore = asyncio.Semaphore(threads) - def limited(f: Callable) -> Callable: + if TYPE_CHECKING: + P = ParamSpec("P") + T = TypeVar("T") + def limited(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]": @functools.wraps(f) - async def wrapped(*args: Any, **kwargs: Any) -> Any: + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: async with http_semaphore: return await f(*args, **kwargs) @@ -313,7 +328,7 @@ def signal_ignore(signum: Any, frame: Any) -> None: log.warn("Got a signal to exit, ignoring it...", signal=signal.Signals(signum).name) -def signal_set_event(event: threading.Event) -> Callable: +def signal_set_event(event: threading.Event) -> Callable[[Any, Any], None]: def _signal_set_event(signum: Any, frame: Any) -> None: event.set() diff --git a/python/cog/server/response_throttler.py b/python/cog/server/response_throttler.py index 8cbeda21d2..41e2ed0312 100644 --- a/python/cog/server/response_throttler.py +++ b/python/cog/server/response_throttler.py @@ -1,4 +1,5 @@ import time +from typing import Any, Dict from ..schema import Status @@ -8,7 +9,7 @@ def __init__(self, response_interval: float) -> None: self.last_sent_response_time = 0.0 self.response_interval = response_interval - def should_send_response(self, response: dict) -> bool: + def should_send_response(self, response: Dict[str, Any]) -> bool: if 
Status.is_terminal(response["status"]): return True diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 35b8475b96..38d4679d15 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -168,7 +168,7 @@ def create_event_handler( return event_handler -def generate_file_uploader(upload_url: str) -> Callable: +def generate_file_uploader(upload_url: str) -> Callable[[Any], Any]: client = _make_file_upload_http_client() def file_uploader(output: Any) -> Any: @@ -184,8 +184,8 @@ class PredictionEventHandler: def __init__( self, p: schema.PredictionResponse, - webhook_sender: Optional[Callable] = None, - file_uploader: Optional[Callable] = None, + webhook_sender: Optional[Callable[[Any, schema.WebhookEvent], None]] = None, + file_uploader: Optional[Callable[[Any], Any]] = None, ) -> None: log.info("starting prediction") self.p = p diff --git a/python/cog/server/webhook.py b/python/cog/server/webhook.py index 55a51dca61..9cd4ea59c8 100644 --- a/python/cog/server/webhook.py +++ b/python/cog/server/webhook.py @@ -37,8 +37,8 @@ def _get_version() -> str: def webhook_caller_filtered( - webhook: str, webhook_events_filter: Set[WebhookEvent] -) -> Callable: + webhook: str, webhook_events_filter: Set[WebhookEvent], +) -> Callable[[Any, WebhookEvent], None]: upstream_caller = webhook_caller(webhook) def caller(response: Any, event: WebhookEvent) -> None: @@ -48,7 +48,7 @@ def caller(response: Any, event: WebhookEvent) -> None: return caller -def webhook_caller(webhook: str) -> Callable: +def webhook_caller(webhook: str) -> Callable[[Any], None]: # TODO: we probably don't need to create new sessions and new throttlers # for every prediction. 
throttler = ResponseThrottler(response_interval=_response_interval) diff --git a/python/cog/types.py b/python/cog/types.py index f503b7d78b..d4a95bdfa7 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -257,7 +257,7 @@ def __get_validators__(cls) -> Iterator[Any]: yield cls.validate @classmethod - def validate(cls, value: Any) -> Iterator: + def validate(cls, value: Iterator[Any]) -> Iterator[Any]: return value From 8dfe4e6bd68f4276e88dc52cbd4fccb0c991dfc8 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Wed, 22 Nov 2023 14:51:40 +0100 Subject: [PATCH 13/22] ast_openapi_schema.extract_info: make type checker actually happy Signed-off-by: Yorick van Pelt --- python/cog/command/ast_openapi_schema.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index daa4e56bf8..bbfd4584a9 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -495,7 +495,8 @@ def predict( def extract_info(code: str) -> "JSONDict": """Parse the schemas from a file with a predict function""" tree = ast.parse(code) - inputs = {"title": "Input", "type": "object", "properties": {}} + properties: "JSONDict" = {} + inputs: "JSONDict" = {"title": "Input", "type": "object", "properties": properties} required: "list[str]" = [] schemas: "JSONDict" = {} for arg, default in parse_args(tree): @@ -514,7 +515,7 @@ def extract_info(code: str) -> "JSONDict": kws = {} else: raise ValueError("Unexpected default value", default) - input: "JSONDict" = {"x-order": len(inputs["properties"])} + input: "JSONDict" = {"x-order": len(properties)} # need to handle other types? 
arg_type = OPENAPI_TYPES.get(get_annotation(arg.annotation), "string") if get_annotation(arg.annotation) in ("Path", "File"): @@ -536,13 +537,13 @@ def extract_info(code: str) -> "JSONDict": else: input["title"] = arg.arg.replace("_", " ").title() input["type"] = arg_type - inputs["properties"][arg.arg] = input # type: ignore + properties[arg.arg] = input if required: - inputs["required"] = required + inputs["required"] = list(required) # List[Path], list[Path], str, Iterator[str], MyOutput, Output return_schema, output = parse_return_annotation(tree, "predict") schema: "JSONDict" = json.loads(BASE_SCHEMA) - components = { + components: "JSONDict" = { "Input": inputs, "Output": output, **schemas, From ddc93333e5c2e678a50589c18a1720c10c96c510 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Wed, 22 Nov 2023 15:01:09 +0100 Subject: [PATCH 14/22] Add typecheck python github action Signed-off-by: Yorick van Pelt --- .github/workflows/ci.yaml | 18 ++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 19 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7f814e7e96..d55aee4bd7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -67,6 +67,24 @@ jobs: env: HYPOTHESIS_PROFILE: ci + typecheck-python: + name: "Typecheck Python" + runs-on: ubuntu-latest + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install Python dependencies + run: | + python -m pip install '.[dev]' + - name: Run typechecking + run: | + python -m pyright + # cannot run this on mac due to licensing issues: https://github.com/actions/virtual-environments/issues/2150 test-integration: name: "Test integration" diff --git a/pyproject.toml b/pyproject.toml index 4f72aa79ed..28d7c965f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ optional-dependencies = { "dev" = [ 'numpy<1.22.0; python_version < "3.8"', 'numpy; python_version >= 
"3.8"', "pillow", + "pyright", "pytest", "pytest-asyncio", "pytest-httpserver", From 5dfd4f3e6ed86e701b240726c430e3d6038d5307 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Wed, 22 Nov 2023 15:03:10 +0100 Subject: [PATCH 15/22] cog.server.http: quote typenames Signed-off-by: Yorick van Pelt --- python/cog/server/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/server/http.py b/python/cog/server/http.py index fae5aa9975..5fe4ad9250 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -99,7 +99,7 @@ class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType T = TypeVar("T") def limited(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]": @functools.wraps(f) - async def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": async with http_semaphore: return await f(*args, **kwargs) From 38f022259bd9f869493fac5d54ab6924a468d72b Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Thu, 23 Nov 2023 17:14:06 +0100 Subject: [PATCH 16/22] python: process some review comments Signed-off-by: Yorick van Pelt --- python/cog/command/ast_openapi_schema.py | 20 ++++++++++---------- python/cog/server/http.py | 5 +---- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index bbfd4584a9..dc1ceee752 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -310,17 +310,17 @@ def find(obj: ast.AST, name: str) -> ast.AST: return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name) if typing.TYPE_CHECKING: - AstVal: "typing.TypeAlias" = int | float | complex | str | list["AstVal"] | bytes | None - AstValNoBytes: "typing.TypeAlias" = int | float | str | list["AstValNoBytes"] - JSONObject: "typing.TypeAlias" = int | float | str | list["JSONObject"] | "JSONDict" | None - 
JSONDict: "typing.TypeAlias" = dict[str, "JSONObject"] + AstVal: "typing.TypeAlias" = "int | float | complex | str | list[AstVal] | bytes | None" + AstValNoBytes: "typing.TypeAlias" = "int | float | str | list[AstValNoBytes]" + JSONObject: "typing.TypeAlias" = "int | float | str | list[JSONObject] | JSONDict | None" + JSONDict: "typing.TypeAlias" = "dict[str, JSONObject]" -def toSerializable(val: "AstVal") -> "JSONObject": +def to_serializable(val: "AstVal") -> "JSONObject": if isinstance(val, bytes): return val.decode("utf-8") elif isinstance(val, list): - return [toSerializable(x) for x in val] + return [to_serializable(x) for x in val] elif isinstance(val, complex): msg = "complex inputs are not supported" raise ValueError(msg) @@ -379,7 +379,7 @@ def parse_assignment(assignment: ast.AST) -> "None | tuple[str, JSONObject]": default = {} if assignment.value: try: - default = {"default": toSerializable(get_value(assignment.value))} + default = {"default": to_serializable(get_value(assignment.value))} except UnicodeDecodeError: pass return assignment.target.id, { @@ -389,7 +389,7 @@ def parse_assignment(assignment: ast.AST) -> "None | tuple[str, JSONObject]": } if isinstance(assignment, ast.Assign): if len(assignment.targets) == 1 and isinstance(assignment.targets[0], ast.Name): - value = toSerializable(get_value(assignment.value)) + value = to_serializable(get_value(assignment.value)) return assignment.targets[0].id, { "title": assignment.targets[0].id.replace("_", " ").title(), "type": OPENAPI_TYPES[type(value).__name__], @@ -508,9 +508,9 @@ def extract_info(code: str) -> "JSONDict": if kw.arg is None: msg = "unknown argument for Input" raise ValueError(msg) - kws[kw.arg] = toSerializable(get_value(kw.value)) + kws[kw.arg] = to_serializable(get_value(kw.value)) elif isinstance(default, (ast.Constant, ast.List, ast.Tuple, ast.Str, ast.Num)): - kws = {"default": toSerializable(get_value(default))} # could be None + kws = {"default": 
to_serializable(get_value(default))} # could be None elif default == ...: # no default kws = {} else: diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 5fe4ad9250..99b61934be 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -382,10 +382,7 @@ def _signal_set_event(signum: Any, frame: Any) -> None: if config.get("build", {}).get("gpu", False): threads = 1 else: - threads = os.cpu_count() - if threads is None: - log.warn("Unable to determine cpu count, defaulting to 1 thread") - threads = 1 + threads = max(1, len(os.sched_getaffinity(0))) shutdown_event = threading.Event() app = create_app( From 2c8a65c76103a54f65e969ecff2b73ff94de73e2 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Thu, 23 Nov 2023 17:49:44 +0100 Subject: [PATCH 17/22] async runner/http: fix types Signed-off-by: Yorick van Pelt --- python/cog/server/http.py | 8 ++++---- python/cog/server/runner.py | 20 +++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 99b61934be..11fc1e9e70 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -59,7 +59,7 @@ class Health(Enum): def create_app( config: Dict[str, Any], - shutdown_event: Optional[threading.Event], + shutdown_event: Optional[asyncio.Event], threads: int = 1, upload_url: Optional[str] = None, mode: str = "predict", @@ -328,7 +328,7 @@ def signal_ignore(signum: Any, frame: Any) -> None: log.warn("Got a signal to exit, ignoring it...", signal=signal.Signals(signum).name) -def signal_set_event(event: threading.Event) -> Callable[[Any, Any], None]: +def signal_set_event(event: threading.Event | asyncio.Event) -> Callable[[Any, Any], None]: def _signal_set_event(signum: Any, frame: Any) -> None: event.set() @@ -384,7 +384,7 @@ def _signal_set_event(signum: Any, frame: Any) -> None: else: threads = max(1, len(os.sched_getaffinity(0))) - shutdown_event = threading.Event() + shutdown_event = 
asyncio.Event() app = create_app( config=config, shutdown_event=shutdown_event, @@ -416,7 +416,7 @@ def _signal_set_event(signum: Any, frame: Any) -> None: s.start() try: - shutdown_event.wait() + asyncio.run(shutdown_event.wait()) except KeyboardInterrupt: pass diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 38d4679d15..9b827d34a1 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -3,7 +3,8 @@ import traceback from asyncio import Task from datetime import datetime, timezone -from typing import Any, Callable, Dict, Optional, Tuple, cast +from typing import Any, Callable, Optional, Tuple, cast +import typing # TypeAlias, py3.10 import requests import structlog @@ -33,6 +34,7 @@ class RunnerBusyError(Exception): class UnknownPredictionError(Exception): pass +PredictionTask: "typing.TypeAlias" = Task[schema.PredictionResponse] class PredictionRunner: def __init__( @@ -43,7 +45,7 @@ def __init__( upload_url: Optional[str] = None, ) -> None: self._response: Optional[schema.PredictionResponse] = None - self._result: Optional[Task] = None + self._result: "Optional[PredictionTask]" = None self._worker = Worker(predictor_ref=predictor_ref) self._should_cancel = asyncio.Event() @@ -51,8 +53,8 @@ def __init__( self._shutdown_event = shutdown_event self._upload_url = upload_url - def make_error_handler(self, activity: str) -> Callable: - def handle_error(task: Task) -> None: + def make_error_handler(self, activity: str) -> Callable[[PredictionTask], None]: + def handle_error(task: PredictionTask) -> None: exc = task.exception() if not exc: return @@ -68,7 +70,7 @@ def handle_error(task: Task) -> None: return handle_error - def setup(self) -> "Task[dict[str, Any]]": + def setup(self) -> "Task[schema.PredictionResponse]": if self.is_busy(): raise RunnerBusyError() self._result = asyncio.create_task(setup(worker=self._worker)) @@ -100,7 +102,7 @@ def predict( upload_url = self._upload_url if upload else None event_handler = 
create_event_handler(prediction, upload_url=upload_url) - def handle_cleanup(_: Task) -> None: + def handle_cleanup(_: PredictionTask) -> None: input = cast(Any, prediction.input) if hasattr(input, "cleanup"): input.cleanup() @@ -274,7 +276,7 @@ def _upload_files(self, output: Any) -> Any: raise FileUploadError("Got error trying to upload output files") from error -async def setup(*, worker: Worker) -> Dict[str, Any]: +async def setup(*, worker: Worker) -> schema.PredictionResponse: logs = [] status = None started_at = datetime.now(tz=timezone.utc) @@ -304,12 +306,12 @@ async def setup(*, worker: Worker) -> Dict[str, Any]: probes = ProbeHelper() probes.ready() - return { + return schema.PredictionResponse(**{ "logs": "".join(logs), "status": status, "started_at": started_at, "completed_at": completed_at, - } + }) async def predict( From 4c9b8d9c3776d72a837ffaaa8f807fd81988edc0 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Fri, 24 Nov 2023 11:50:21 +0100 Subject: [PATCH 18/22] python: mypy -> pyright everywhere Signed-off-by: Yorick van Pelt --- .github/workflows/ci.yaml | 3 --- Makefile | 4 ++-- pyproject.toml | 8 -------- python/cog/files.py | 3 ++- 4 files changed, 4 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d55aee4bd7..f726b2f4dd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -37,7 +37,6 @@ jobs: - name: Install Python dependencies run: | python -m pip install '.[dev]' - yes | python -m mypy --install-types replicate || true - name: Build run: make cog - name: Test @@ -61,7 +60,6 @@ jobs: - name: Install Python dependencies run: | python -m pip install '.[dev]' - yes | python -m mypy --install-types replicate || true - name: Test run: make test-python env: @@ -100,7 +98,6 @@ jobs: - name: Install Python dependencies run: | python -m pip install '.[dev]' - yes | python -m mypy --install-types replicate || true - name: Test run: make test-integration diff --git 
a/Makefile b/Makefile index a4378cb36f..8edf9857d2 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ GOARCH := $(shell $(GO) env GOARCH) PYTHON := python PYTEST := $(PYTHON) -m pytest -MYPY := $(PYTHON) -m mypy +PYRIGHT := $(PYTHON) -m pyright RUFF := $(PYTHON) -m ruff default: all @@ -94,7 +94,7 @@ lint-go: .PHONY: lint-python lint-python: $(RUFF) python/cog - $(MYPY) python/cog + $(PYRIGHT) .PHONY: lint lint: lint-go lint-python diff --git a/pyproject.toml b/pyproject.toml index 28d7c965f2..43eefd3425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ optional-dependencies = { "dev" = [ "httpx", 'hypothesis<6.80.0; python_version < "3.8"', 'hypothesis; python_version >= "3.8"', - "mypy", 'numpy<1.22.0; python_version < "3.8"', 'numpy; python_version >= "3.8"', "pillow", @@ -49,13 +48,6 @@ dynamic = ["version"] [tool.setuptools_scm] write_to = "python/cog/_version.py" -[tool.mypy] -plugins = "pydantic.mypy" -disallow_untyped_defs = true -# TODO: remove this and bring the codebase inline with the current mypy default -no_implicit_optional = false -exclude = ["python/tests/"] - [tool.pyright] # TODO: remove this and bring the codebase inline with the current default strictParameterNoneValue = false diff --git a/python/cog/files.py b/python/cog/files.py index 4477f22fdc..2ca8cd383a 100644 --- a/python/cog/files.py +++ b/python/cog/files.py @@ -23,7 +23,8 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: b = b.encode("utf-8") encoded_body = base64.b64encode(b) if getattr(fh, "name", None): - # despite doing a getattr check here, mypy complains that io.IOBase has no attribute name + # despite doing a getattr check here, pyright complains that io.IOBase has no attribute name + # TODO: switch to typing.IO[]? 
mime_type = mimetypes.guess_type(fh.name)[0] # type: ignore else: mime_type = "application/octet-stream" From d18c755268fbdda0b955b77f80f0031d33d594ce Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Fri, 24 Nov 2023 13:21:28 +0100 Subject: [PATCH 19/22] pyproject: unneccesary -> unnecessary Signed-off-by: Yorick van Pelt --- pyproject.toml | 6 +++--- python/cog/server/http.py | 2 +- python/cog/server/runner.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 43eefd3425..07ff5161d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,9 +57,9 @@ include = ["python"] exclude = ["python/tests"] reportMissingParameterType = "error" reportUnknownLambdaType = "error" -reportUnneccesaryIsInstance = "warning" -reportUnneccesaryComparison = "warning" -reportUnnecessaryContains = "warning" +reportUnnecessaryIsInstance = "warning" +reportUnnecessaryComparison = "warning" +reportUnnecessaryContains = "warning" reportMissingTypeArgument = "error" reportUnusedExpression = "warning" diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 11fc1e9e70..e7ce5db990 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -191,7 +191,7 @@ async def predict_idempotent( return await _predict(request=request, respond_async=respond_async) async def _predict( - *, request: PredictionRequest, respond_async: bool = False + *, request: Optional[PredictionRequest], respond_async: bool = False ) -> Response: # [compat] If no body is supplied, assume that this model can be run # with empty input. 
This will throw a ValidationError if that's not diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 9b827d34a1..3ebe339e60 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -397,7 +397,7 @@ async def _predict( else: event_handler.set_output(event.payload) - elif isinstance(event, Done): + elif isinstance(event, Done): # pyright: ignore reportUnnecessaryIsinstance if event.canceled: event_handler.canceled() elif event.error: @@ -405,7 +405,7 @@ async def _predict( else: event_handler.succeeded() - else: + else: # shouldn't happen, exhausted the type log.warn("received unexpected event from worker", data=event) return event_handler.response From c8f36366bc716191120e7b50d0e55958a887ce98 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Fri, 24 Nov 2023 13:21:56 +0100 Subject: [PATCH 20/22] resolve threading|asyncio.Event difference the other way Signed-off-by: Yorick van Pelt --- python/cog/server/http.py | 8 ++++---- python/cog/server/runner.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/cog/server/http.py b/python/cog/server/http.py index e7ce5db990..c087c1bc72 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -59,7 +59,7 @@ class Health(Enum): def create_app( config: Dict[str, Any], - shutdown_event: Optional[asyncio.Event], + shutdown_event: Optional[threading.Event], threads: int = 1, upload_url: Optional[str] = None, mode: str = "predict", @@ -328,7 +328,7 @@ def signal_ignore(signum: Any, frame: Any) -> None: log.warn("Got a signal to exit, ignoring it...", signal=signal.Signals(signum).name) -def signal_set_event(event: threading.Event | asyncio.Event) -> Callable[[Any, Any], None]: +def signal_set_event(event: threading.Event) -> Callable[[Any, Any], None]: def _signal_set_event(signum: Any, frame: Any) -> None: event.set() @@ -384,7 +384,7 @@ def _signal_set_event(signum: Any, frame: Any) -> None: else: threads = max(1, 
len(os.sched_getaffinity(0))) - shutdown_event = asyncio.Event() + shutdown_event = threading.Event() app = create_app( config=config, shutdown_event=shutdown_event, @@ -416,7 +416,7 @@ def _signal_set_event(signum: Any, frame: Any) -> None: s.start() try: - asyncio.run(shutdown_event.wait()) + shutdown_event.wait() except KeyboardInterrupt: pass diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 3ebe339e60..bc960ef2ce 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -1,10 +1,11 @@ import asyncio import io +import threading import traceback +import typing # TypeAlias, py3.10 from asyncio import Task from datetime import datetime, timezone from typing import Any, Callable, Optional, Tuple, cast -import typing # TypeAlias, py3.10 import requests import structlog @@ -41,7 +42,7 @@ def __init__( self, *, predictor_ref: str, - shutdown_event: Optional[asyncio.Event], + shutdown_event: Optional[threading.Event], upload_url: Optional[str] = None, ) -> None: self._response: Optional[schema.PredictionResponse] = None From 9aca73b704761d339365c7ca519317e574651590 Mon Sep 17 00:00:00 2001 From: Yorick van Pelt Date: Fri, 24 Nov 2023 13:42:17 +0100 Subject: [PATCH 21/22] python: make tests pass for PredictionResponse Signed-off-by: Yorick van Pelt --- python/cog/server/http.py | 13 ++++++++++--- python/cog/server/runner.py | 19 +++++++++++++------ python/tests/server/test_runner.py | 8 ++++---- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/python/cog/server/http.py b/python/cog/server/http.py index c087c1bc72..4b003ebc50 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -56,6 +56,13 @@ class Health(Enum): BUSY = auto() SETUP_FAILED = auto() +class State: + health: Health + setup_result: "Optional[asyncio.Task[schema.PredictionResponse]]" + setup_result_payload: Optional[schema.PredictionResponse] + +class MyFastAPI(FastAPI): + state: State def create_app( config: Dict[str, Any], @@ 
-63,8 +70,8 @@ def create_app( threads: int = 1, upload_url: Optional[str] = None, mode: str = "predict", -) -> FastAPI: - app = FastAPI( +) -> MyFastAPI: + app = MyFastAPI( title="Cog", # TODO: mention model name? # version=None # TODO ) @@ -266,7 +273,7 @@ async def _check_setup_result() -> Any: # this can raise CancelledError result = app.state.setup_result.result() - if result["status"] == schema.Status.SUCCEEDED: + if result.status == schema.Status.SUCCEEDED: app.state.health = Health.READY else: app.state.health = Health.SETUP_FAILED diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index bc960ef2ce..c7664779d6 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -307,12 +307,19 @@ async def setup(*, worker: Worker) -> schema.PredictionResponse: probes = ProbeHelper() probes.ready() - return schema.PredictionResponse(**{ - "logs": "".join(logs), - "status": status, - "started_at": started_at, - "completed_at": completed_at, - }) + return schema.PredictionResponse( + input={}, + output=None, + id=None, + version=None, + created_at=None, + started_at=started_at, + completed_at=completed_at, + logs="".join(logs), + error=None, + metrics=None, + status=status + ) async def predict( diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index 2a2ae9a045..b0151b9f54 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -48,10 +48,10 @@ async def test_prediction_runner_setup(): try: result = await runner.setup() - assert result["status"] == Status.SUCCEEDED - assert result["logs"] == "" - assert isinstance(result["started_at"], datetime) - assert isinstance(result["completed_at"], datetime) + assert result.status == Status.SUCCEEDED + assert result.logs == "" + assert isinstance(result.started_at, datetime) + assert isinstance(result.completed_at, datetime) finally: runner.shutdown() From 068f6a7e1c4f89f32342583763622242d16814c4 Mon Sep 17 00:00:00 2001 
From: Yorick van Pelt Date: Fri, 24 Nov 2023 13:56:41 +0100 Subject: [PATCH 22/22] Fix test in python 3.7 Signed-off-by: Yorick van Pelt --- python/cog/server/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index c7664779d6..19b3fc0c61 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -35,7 +35,7 @@ class RunnerBusyError(Exception): class UnknownPredictionError(Exception): pass -PredictionTask: "typing.TypeAlias" = Task[schema.PredictionResponse] +PredictionTask: "typing.TypeAlias" = "Task[schema.PredictionResponse]" class PredictionRunner: def __init__(