diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7f814e7e96..f726b2f4dd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -37,7 +37,6 @@ jobs: - name: Install Python dependencies run: | python -m pip install '.[dev]' - yes | python -m mypy --install-types replicate || true - name: Build run: make cog - name: Test @@ -61,12 +60,29 @@ jobs: - name: Install Python dependencies run: | python -m pip install '.[dev]' - yes | python -m mypy --install-types replicate || true - name: Test run: make test-python env: HYPOTHESIS_PROFILE: ci + typecheck-python: + name: "Typecheck Python" + runs-on: ubuntu-latest + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install Python dependencies + run: | + python -m pip install '.[dev]' + - name: Run typechecking + run: | + python -m pyright + # cannot run this on mac due to licensing issues: https://github.com/actions/virtual-environments/issues/2150 test-integration: name: "Test integration" @@ -82,7 +98,6 @@ jobs: - name: Install Python dependencies run: | python -m pip install '.[dev]' - yes | python -m mypy --install-types replicate || true - name: Test run: make test-integration diff --git a/Makefile b/Makefile index a4378cb36f..8edf9857d2 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ GOARCH := $(shell $(GO) env GOARCH) PYTHON := python PYTEST := $(PYTHON) -m pytest -MYPY := $(PYTHON) -m mypy +PYRIGHT := $(PYTHON) -m pyright RUFF := $(PYTHON) -m ruff default: all @@ -94,7 +94,7 @@ lint-go: .PHONY: lint-python lint-python: $(RUFF) python/cog - $(MYPY) python/cog + $(PYRIGHT) .PHONY: lint lint: lint-go lint-python diff --git a/pyproject.toml b/pyproject.toml index edad0944d4..07ff5161d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,10 +30,10 @@ optional-dependencies = { "dev" = [ "httpx", 'hypothesis<6.80.0; python_version < "3.8"', 'hypothesis; python_version >= 
"3.8"', - "mypy", 'numpy<1.22.0; python_version < "3.8"', 'numpy; python_version >= "3.8"', "pillow", + "pyright", "pytest", "pytest-asyncio", "pytest-httpserver", @@ -48,12 +48,20 @@ dynamic = ["version"] [tool.setuptools_scm] write_to = "python/cog/_version.py" -[tool.mypy] -plugins = "pydantic.mypy" -disallow_untyped_defs = true -# TODO: remove this and bring the codebase inline with the current mypy default -no_implicit_optional = false -exclude = ["python/tests/"] +[tool.pyright] +# TODO: remove this and bring the codebase inline with the current default +strictParameterNoneValue = false +# legacy behavior, fixed in PEP688 +disableBytesTypePromotions = true +include = ["python"] +exclude = ["python/tests"] +reportMissingParameterType = "error" +reportUnknownLambdaType = "error" +reportUnnecessaryIsInstance = "warning" +reportUnnecessaryComparison = "warning" +reportUnnecessaryContains = "warning" +reportMissingTypeArgument = "error" +reportUnusedExpression = "warning" [tool.setuptools] package-dir = { "" = "python" } diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index 7e886122ea..dc1ceee752 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -1,6 +1,8 @@ import ast import json import sys +import types +import typing from pathlib import Path try: @@ -307,8 +309,25 @@ def find(obj: ast.AST, name: str) -> ast.AST: """Find a particular named node in a tree""" return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name) +if typing.TYPE_CHECKING: + AstVal: "typing.TypeAlias" = "int | float | complex | str | list[AstVal] | bytes | None" + AstValNoBytes: "typing.TypeAlias" = "int | float | str | list[AstValNoBytes]" + JSONObject: "typing.TypeAlias" = "int | float | str | list[JSONObject] | JSONDict | None" + JSONDict: "typing.TypeAlias" = "dict[str, JSONObject]" + + +def to_serializable(val: "AstVal") -> "JSONObject": + if isinstance(val, bytes): 
+ return val.decode("utf-8") + elif isinstance(val, list): + return [to_serializable(x) for x in val] + elif isinstance(val, complex): + msg = "complex inputs are not supported" + raise ValueError(msg) + else: + return val -def get_value(node: ast.AST) -> "int | float | complex | str | list": +def get_value(node: ast.AST) -> "AstVal": """Return the value of constant or list of constants""" if isinstance(node, ast.Constant): return node.value @@ -320,7 +339,7 @@ def get_value(node: ast.AST) -> "int | float | complex | str | list": if isinstance(node, (ast.List, ast.Tuple)): return [get_value(e) for e in node.elts] if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): - return -get_value(node.operand) + return -typing.cast(typing.Union[int, float, complex], get_value(node.operand)) raise ValueError("Unexpected node type", type(node)) @@ -344,7 +363,7 @@ def get_call_name(call: ast.Call) -> str: raise ValueError("Unexpected node type", type(call), ast.unparse(call)) -def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | Ellipsis]]": +def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisType]]": """Parse argument, default pairs from a file with a predict function""" predict = find(tree, "predict") assert isinstance(predict, ast.FunctionDef) @@ -353,12 +372,16 @@ def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | Ellipsis]]": defaults = [...] 
* (len(args) - len(predict.args.defaults)) + predict.args.defaults return list(zip(args, defaults)) - -def parse_assignment(assignment: ast.AST) -> "tuple[str | None, dict | None]": +def parse_assignment(assignment: ast.AST) -> "None | tuple[str, JSONObject]": """Parse an assignment into an OpenAPI object property""" if isinstance(assignment, ast.AnnAssign): assert isinstance(assignment.target, ast.Name) # shouldn't be an Attribute - default = {"default": get_value(assignment.value)} if assignment.value else {} + default = {} + if assignment.value: + try: + default = {"default": to_serializable(get_value(assignment.value))} + except UnicodeDecodeError: + pass return assignment.target.id, { "title": assignment.target.id.replace("_", " ").title(), "type": OPENAPI_TYPES[get_annotation(assignment.annotation)], @@ -366,21 +389,21 @@ def parse_assignment(assignment: ast.AST) -> "tuple[str | None, dict | None]": } if isinstance(assignment, ast.Assign): if len(assignment.targets) == 1 and isinstance(assignment.targets[0], ast.Name): - value = get_value(assignment.value) + value = to_serializable(get_value(assignment.value)) return assignment.targets[0].id, { "title": assignment.targets[0].id.replace("_", " ").title(), "type": OPENAPI_TYPES[type(value).__name__], "default": value, } raise ValueError("Unexpected assignment", assignment) - return None, None + return None -def parse_class(classdef: ast.AST) -> dict: +def parse_class(classdef: ast.AST) -> "JSONDict": """Parse a class definition into an OpenAPI object""" assert isinstance(classdef, ast.ClassDef) properties = { - key: property for key, property in map(parse_assignment, classdef.body) if key + assignment[0]: assignment[1] for assignment in map(parse_assignment, classdef.body) if assignment } return { "title": classdef.name, @@ -404,8 +427,8 @@ def resolve_name(node: ast.expr) -> str: if isinstance(node, ast.Name): return node.id if isinstance(node, ast.Index): - # depricated, but needed for py3.8 - return 
resolve_name(node.value) + # deprecated, but needed for py3.8 + return resolve_name(node.value) # type: ignore if isinstance(node, ast.Attribute): return node.attr if isinstance(node, ast.Subscript): @@ -413,7 +436,7 @@ def resolve_name(node: ast.expr) -> str: raise ValueError("Unexpected node type", type(node), ast.unparse(node)) -def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[dict, dict]": +def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[JSONDict, JSONDict]": predict = find(tree, fn) if not isinstance(predict, ast.FunctionDef): raise ValueError("Could not find predict function") @@ -459,7 +482,7 @@ def predict( format = {"format": "uri"} if name in ("Path", "File") else {} return {}, {"title": "Output", "type": OPENAPI_TYPES.get(name, name), **format} # it must be a custom object - schema = {name: parse_class(find(tree, name))} + schema: "JSONDict" = {name: parse_class(find(tree, name))} return schema, { "title": "Output", "$ref": f"#/components/schemas/{name}", @@ -469,24 +492,30 @@ def predict( KEPT_ATTRS = ("description", "default", "ge", "le", "max_length", "min_length", "regex") -def extract_info(code: str) -> dict: +def extract_info(code: str) -> "JSONDict": """Parse the schemas from a file with a predict function""" tree = ast.parse(code) - inputs = {"title": "Input", "type": "object", "properties": {}} + properties: "JSONDict" = {} + inputs: "JSONDict" = {"title": "Input", "type": "object", "properties": properties} required: "list[str]" = [] - schemas: "dict[str, dict]" = {} + schemas: "JSONDict" = {} for arg, default in parse_args(tree): if arg.arg == "self": continue if isinstance(default, ast.Call) and get_call_name(default) == "Input": - kws = {kw.arg: get_value(kw.value) for kw in default.keywords} + kws = {} + for kw in default.keywords: + if kw.arg is None: + msg = "unknown argument for Input" + raise ValueError(msg) + kws[kw.arg] = to_serializable(get_value(kw.value)) elif isinstance(default, 
(ast.Constant, ast.List, ast.Tuple, ast.Str, ast.Num)): - kws = {"default": get_value(default)} # could be None + kws = {"default": to_serializable(get_value(default))} # could be None elif default == ...: # no default kws = {} else: raise ValueError("Unexpected default value", default) - input: dict = {"x-order": len(inputs["properties"])} + input: "JSONDict" = {"x-order": len(properties)} # need to handle other types? arg_type = OPENAPI_TYPES.get(get_annotation(arg.annotation), "string") if get_annotation(arg.annotation) in ("Path", "File"): @@ -508,23 +537,25 @@ def extract_info(code: str) -> dict: else: input["title"] = arg.arg.replace("_", " ").title() input["type"] = arg_type - inputs["properties"][arg.arg] = input # type: ignore + properties[arg.arg] = input if required: - inputs["required"] = required + inputs["required"] = list(required) # List[Path], list[Path], str, Iterator[str], MyOutput, Output return_schema, output = parse_return_annotation(tree, "predict") - schema = json.loads(BASE_SCHEMA) - components = { + schema: "JSONDict" = json.loads(BASE_SCHEMA) + components: "JSONDict" = { "Input": inputs, "Output": output, **schemas, **return_schema, } - schema["components"]["schemas"].update(components) + # trust me, typechecker, I know BASE_SCHEMA + x: "JSONDict" = schema["components"]["schemas"] # type: ignore + x.update(components) return schema -def extract_file(fname: "str | Path") -> dict: +def extract_file(fname: "str | Path") -> "JSONObject": return extract_info(open(fname, encoding="utf-8").read()) diff --git a/python/cog/files.py b/python/cog/files.py index 4477f22fdc..2ca8cd383a 100644 --- a/python/cog/files.py +++ b/python/cog/files.py @@ -23,7 +23,8 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: b = b.encode("utf-8") encoded_body = base64.b64encode(b) if getattr(fh, "name", None): - # despite doing a getattr check here, mypy complains that io.IOBase has no attribute name + # despite doing a getattr check here, 
pyright complains that io.IOBase has no attribute name + # TODO: switch to typing.IO[]? mime_type = mimetypes.guess_type(fh.name)[0] # type: ignore else: mime_type = "application/octet-stream" diff --git a/python/cog/json.py b/python/cog/json.py index 0e51ba8597..843572eea0 100644 --- a/python/cog/json.py +++ b/python/cog/json.py @@ -29,11 +29,9 @@ def make_encodeable(obj: Any) -> Any: return obj.isoformat() try: import numpy as np # type: ignore - - has_numpy = True except ImportError: - has_numpy = False - if has_numpy: + pass + else: if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): diff --git a/python/cog/predictor.py b/python/cog/predictor.py index 5af98281ae..3e041f839b 100644 --- a/python/cog/predictor.py +++ b/python/cog/predictor.py @@ -4,6 +4,7 @@ import io import os.path import sys +import types from abc import ABC, abstractmethod from collections.abc import Iterator from pathlib import Path @@ -15,13 +16,14 @@ Optional, Type, Union, + cast, ) from unittest.mock import patch try: from typing import get_args, get_origin except ImportError: # Python < 3.8 - from typing_compat import get_args, get_origin + from typing_compat import get_args, get_origin # type: ignore import yaml from pydantic import BaseModel, Field, create_model @@ -42,7 +44,7 @@ Path as CogPath, ) -ALLOWED_INPUT_TYPES = [str, int, float, bool, CogFile, CogPath] +ALLOWED_INPUT_TYPES: List[Type[Any]] = [str, int, float, bool, CogFile, CogPath] class BasePredictor(ABC): @@ -76,18 +78,20 @@ def run_setup(predictor: BasePredictor) -> None: # TODO: Cog{File,Path}.validate(...) methods accept either "real" # paths/files or URLs to those things. In future we can probably tidy this # up a little bit. 
+ # TODO: CogFile/CogPath should have subclasses for each of the subtypes if weights_url: if weights_type == CogFile: - weights = CogFile.validate(weights_url) + weights = cast(CogFile, CogFile.validate(weights_url)) elif weights_type == CogPath: - weights = CogPath.validate(weights_url) + # TODO: So this can be a url. evil! + weights = cast(CogPath, CogPath.validate(weights_url)) else: raise ValueError( f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported" ) elif os.path.exists(weights_path): if weights_type == CogFile: - weights = open(weights_path, "rb") + weights = cast(CogFile, open(weights_path, "rb")) elif weights_type == CogPath: weights = CogPath(weights_path) else: @@ -100,7 +104,7 @@ def run_setup(predictor: BasePredictor) -> None: predictor.setup(weights=weights) -def get_weights_type(setup_function: Callable) -> Optional[Any]: +def get_weights_type(setup_function: Callable[[Any], None]) -> Optional[Any]: signature = inspect.signature(setup_function) if "weights" not in signature.parameters: return None @@ -114,7 +118,7 @@ def get_weights_type(setup_function: Callable) -> Optional[Any]: def run_prediction( - predictor: BasePredictor, inputs: Dict[Any, Any], cleanup_functions: List[Callable] + predictor: BasePredictor, inputs: Dict[Any, Any], cleanup_functions: List[Callable[[], None]], ) -> Any: """ Run the predictor on the inputs, and append resulting paths @@ -209,18 +213,18 @@ def cleanup(self) -> None: pass -def get_predict(predictor: Any) -> Callable: +def get_predict(predictor: Any) -> Callable[..., Any]: if hasattr(predictor, "predict"): return predictor.predict return predictor -def validate_input_type(type: Type, name: str) -> None: +def validate_input_type(type: Type[Any], name: str) -> None: if type is inspect.Signature.empty: raise TypeError( f"No input type provided for parameter `{name}`. 
Supported input types are: {readable_types_list(ALLOWED_INPUT_TYPES)}, or a Union or List of those types." ) elif type not in ALLOWED_INPUT_TYPES: - if hasattr(type, "__origin__") and (type.__origin__ is Union or type.__origin__ is list): + if get_origin(type) in (Union, List, list) or (hasattr(types, "UnionType") and get_origin(type) is types.UnionType): # noqa: E721 for t in get_args(type): validate_input_type(t, name) else: @@ -305,6 +309,7 @@ def get_output_type(predictor: BasePredictor) -> Type[BaseModel]: """ predict = get_predict(predictor) signature = inspect.signature(predict) + OutputType: Type[BaseModel] if signature.return_annotation is inspect.Signature.empty: raise TypeError( """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. @@ -327,7 +332,7 @@ def predict( if get_origin(OutputType) is Iterator: # Annotated allows us to attach Field annotations to the list, which we use to mark that this is an iterator # https://pydantic-docs.helpmanual.io/usage/schema/#typingannotated-fields - OutputType = Annotated[List[get_args(OutputType)[0]], Field(**{"x-cog-array-type": "iterator"})] # type: ignore + OutputType: Type[BaseModel] = Annotated[List[get_args(OutputType)[0]], Field(**{"x-cog-array-type": "iterator"})] # type: ignore if not hasattr(OutputType, "__name__") or OutputType.__name__ != "Output": # Wrap the type in a model called "Output" so it is a consistent name in the OpenAPI schema @@ -339,7 +344,7 @@ class Output(BaseModel): return OutputType -def human_readable_type_name(t: Type) -> str: +def human_readable_type_name(t: Type[Any]) -> str: """ Generates a useful-for-humans label for a type. For builtin types, it's just the class name (eg "str" or "int"). For other types, it includes the module (eg "pathlib.Path" or "cog.File"). 
@@ -357,5 +362,5 @@ def human_readable_type_name(t: Type) -> str: return str(t) -def readable_types_list(type_list: List[Type]) -> str: +def readable_types_list(type_list: List[Type[Any]]) -> str: return ", ".join(human_readable_type_name(t) for t in type_list) diff --git a/python/cog/schema.py b/python/cog/schema.py index 22948bd6a8..c11fa77df7 100644 --- a/python/cog/schema.py +++ b/python/cog/schema.py @@ -48,7 +48,7 @@ class PredictionRequest(PredictionBaseModel): ] = WebhookEvent.default_events() @classmethod - def with_types(cls, input_type: t.Type) -> t.Any: + def with_types(cls, input_type: t.Type[t.Any]) -> t.Any: # [compat] Input is implicitly optional -- previous versions of the # Cog HTTP API allowed input to be omitted (e.g. for models that don't # have any inputs). We should consider changing this in future. @@ -74,7 +74,7 @@ class PredictionResponse(PredictionBaseModel): metrics: t.Optional[t.Dict[str, t.Any]] @classmethod - def with_types(cls, input_type: t.Type, output_type: t.Type) -> t.Any: + def with_types(cls, input_type: t.Type[t.Any], output_type: t.Type[t.Any]) -> t.Any: # [compat] Input is implicitly optional -- previous versions of the # Cog HTTP API allowed input to be omitted (e.g. for models that don't # have any inputs). We should consider changing this in future. 
diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 2fe8e2f29a..4b003ebc50 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -9,7 +9,19 @@ import textwrap import threading from enum import Enum, auto, unique -from typing import Any, Callable, Dict, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Optional, + TypeVar, + Union, +) + +if TYPE_CHECKING: + from typing import ParamSpec import structlog import uvicorn @@ -44,6 +56,13 @@ class Health(Enum): BUSY = auto() SETUP_FAILED = auto() +class State: + health: Health + setup_result: "Optional[asyncio.Task[schema.PredictionResponse]]" + setup_result_payload: Optional[schema.PredictionResponse] + +class MyFastAPI(FastAPI): + state: State def create_app( config: Dict[str, Any], @@ -51,8 +70,8 @@ def create_app( threads: int = 1, upload_url: Optional[str] = None, mode: str = "predict", -) -> FastAPI: - app = FastAPI( +) -> MyFastAPI: + app = MyFastAPI( title="Cog", # TODO: mention model name? 
# version=None # TODO ) @@ -74,16 +93,20 @@ def create_app( InputType = get_input_type(predictor) OutputType = get_output_type(predictor) - PredictionRequest = schema.PredictionRequest.with_types(input_type=InputType) + class PredictionRequest(schema.PredictionRequest.with_types(input_type=InputType)): + pass PredictionResponse = schema.PredictionResponse.with_types( input_type=InputType, output_type=OutputType ) http_semaphore = asyncio.Semaphore(threads) - def limited(f: Callable) -> Callable: + if TYPE_CHECKING: + P = ParamSpec("P") + T = TypeVar("T") + def limited(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]": @functools.wraps(f) - async def wrapped(*args: Any, **kwargs: Any) -> Any: + async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": async with http_semaphore: return await f(*args, **kwargs) @@ -175,7 +198,7 @@ async def predict_idempotent( return await _predict(request=request, respond_async=respond_async) async def _predict( - *, request: PredictionRequest, respond_async: bool = False + *, request: Optional[PredictionRequest], respond_async: bool = False ) -> Response: # [compat] If no body is supplied, assume that this model can be run # with empty input. 
This will throw a ValidationError if that's not @@ -250,7 +273,7 @@ async def _check_setup_result() -> Any: # this can raise CancelledError result = app.state.setup_result.result() - if result["status"] == schema.Status.SUCCEEDED: + if result.status == schema.Status.SUCCEEDED: app.state.health = Health.READY else: app.state.health = Health.SETUP_FAILED @@ -312,7 +335,7 @@ def signal_ignore(signum: Any, frame: Any) -> None: log.warn("Got a signal to exit, ignoring it...", signal=signal.Signals(signum).name) -def signal_set_event(event: threading.Event) -> Callable: +def signal_set_event(event: threading.Event) -> Callable[[Any, Any], None]: def _signal_set_event(signum: Any, frame: Any) -> None: event.set() @@ -361,12 +384,12 @@ def _signal_set_event(signum: Any, frame: Any) -> None: config = load_config() - threads = args.threads + threads: Optional[int] = args.threads if threads is None: if config.get("build", {}).get("gpu", False): threads = 1 else: - threads = os.cpu_count() + threads = max(1, len(os.sched_getaffinity(0))) shutdown_event = threading.Event() app = create_app( diff --git a/python/cog/server/response_throttler.py b/python/cog/server/response_throttler.py index 8cbeda21d2..41e2ed0312 100644 --- a/python/cog/server/response_throttler.py +++ b/python/cog/server/response_throttler.py @@ -1,4 +1,5 @@ import time +from typing import Any, Dict from ..schema import Status @@ -8,7 +9,7 @@ def __init__(self, response_interval: float) -> None: self.last_sent_response_time = 0.0 self.response_interval = response_interval - def should_send_response(self, response: dict) -> bool: + def should_send_response(self, response: Dict[str, Any]) -> bool: if Status.is_terminal(response["status"]): return True diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index 35c27cfb85..19b3fc0c61 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -1,9 +1,11 @@ import asyncio import io +import threading import traceback +import typing 
# TypeAlias, py3.10 from asyncio import Task from datetime import datetime, timezone -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, cast import requests import structlog @@ -33,17 +35,18 @@ class RunnerBusyError(Exception): class UnknownPredictionError(Exception): pass +PredictionTask: "typing.TypeAlias" = "Task[schema.PredictionResponse]" class PredictionRunner: def __init__( self, *, predictor_ref: str, - shutdown_event: Optional[asyncio.Event], + shutdown_event: Optional[threading.Event], upload_url: Optional[str] = None, ) -> None: self._response: Optional[schema.PredictionResponse] = None - self._result: Optional[Task] = None + self._result: "Optional[PredictionTask]" = None self._worker = Worker(predictor_ref=predictor_ref) self._should_cancel = asyncio.Event() @@ -51,8 +54,8 @@ def __init__( self._shutdown_event = shutdown_event self._upload_url = upload_url - def make_error_handler(self, activity: str) -> Callable: - def handle_error(task: Task) -> None: + def make_error_handler(self, activity: str) -> Callable[[PredictionTask], None]: + def handle_error(task: PredictionTask) -> None: exc = task.exception() if not exc: return @@ -68,7 +71,7 @@ def handle_error(task: Task) -> None: return handle_error - def setup(self) -> "Task[dict[str, Any]]": + def setup(self) -> "Task[schema.PredictionResponse]": if self.is_busy(): raise RunnerBusyError() self._result = asyncio.create_task(setup(worker=self._worker)) @@ -100,9 +103,10 @@ def predict( upload_url = self._upload_url if upload else None event_handler = create_event_handler(prediction, upload_url=upload_url) - def handle_cleanup(_: Task) -> None: - if hasattr(prediction.input, "cleanup"): - prediction.input.cleanup() + def handle_cleanup(_: PredictionTask) -> None: + input = cast(Any, prediction.input) + if hasattr(input, "cleanup"): + input.cleanup() self._response = event_handler.response coro = predict( @@ -154,7 +158,7 @@ def 
create_event_handler( webhook_sender = None if webhook is not None: - webhook_sender = webhook_caller_filtered(webhook, events_filter) + webhook_sender = webhook_caller_filtered(webhook, set(events_filter)) file_uploader = None if upload_url is not None: @@ -167,7+171,7 @@ def create_event_handler( return event_handler -def generate_file_uploader(upload_url: str) -> Callable: +def generate_file_uploader(upload_url: str) -> Callable[[Any], Any]: client = _make_file_upload_http_client() def file_uploader(output: Any) -> Any: @@ -183,8 +187,8 @@ class PredictionEventHandler: def __init__( self, p: schema.PredictionResponse, - webhook_sender: Optional[Callable] = None, - file_uploader: Optional[Callable] = None, + webhook_sender: Optional[Callable[[Any, schema.WebhookEvent], None]] = None, + file_uploader: Optional[Callable[[Any], Any]] = None, ) -> None: log.info("starting prediction") self.p = p @@ -273,7 +277,7 @@ def _upload_files(self, output: Any) -> Any: raise FileUploadError("Got error trying to upload output files") from error -async def setup(*, worker: Worker) -> Dict[str, Any]: +async def setup(*, worker: Worker) -> schema.PredictionResponse: logs = [] status = None started_at = datetime.now(tz=timezone.utc) @@ -303,12 +307,19 @@ async def setup(*, worker: Worker) -> Dict[str, Any]: probes = ProbeHelper() probes.ready() - return { - "logs": "".join(logs), - "status": status, - "started_at": started_at, - "completed_at": completed_at, - } + return schema.PredictionResponse( + input={}, + output=None, + id=None, + version=None, + created_at=None, + started_at=started_at, + completed_at=completed_at, + logs="".join(logs), + error=None, + metrics=None, + status=status + ) async def predict( @@ -394,7 +405,7 @@ async def _predict( else: event_handler.set_output(event.payload) - elif isinstance(event, Done): + elif isinstance(event, Done): # pyright: ignore[reportUnnecessaryIsInstance] if event.canceled: event_handler.canceled() elif event.error: @@ -402,7 +413,7 
@@ async def _predict( else: event_handler.succeeded() - else: + else: # shouldn't happen, exhausted the type log.warn("received unexpected event from worker", data=event) return event_handler.response diff --git a/python/cog/server/webhook.py b/python/cog/server/webhook.py index e55f6c910a..9cd4ea59c8 100644 --- a/python/cog/server/webhook.py +++ b/python/cog/server/webhook.py @@ -13,14 +13,12 @@ def _get_version() -> str: - use_importlib = True try: - from importlib.metadata import version - except ImportError: - use_importlib = False - - try: - if use_importlib: + try: + from importlib.metadata import version + except ImportError: + pass + else: return version("cog") import pkg_resources @@ -39,8 +37,8 @@ def _get_version() -> str: def webhook_caller_filtered( - webhook: str, webhook_events_filter: Set[WebhookEvent] -) -> Callable: + webhook: str, webhook_events_filter: Set[WebhookEvent], +) -> Callable[[Any, WebhookEvent], None]: upstream_caller = webhook_caller(webhook) def caller(response: Any, event: WebhookEvent) -> None: @@ -50,7 +48,7 @@ def caller(response: Any, event: WebhookEvent) -> None: return caller -def webhook_caller(webhook: str) -> Callable: +def webhook_caller(webhook: str) -> Callable[[Any], None]: # TODO: we probably don't need to create new sessions and new throttlers # for every prediction. 
throttler = ResponseThrottler(response_interval=_response_interval) diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index c51d3c62d6..790eea361d 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -88,7 +88,7 @@ def terminate(self) -> None: self._child.join() def cancel(self) -> None: - if self._allow_cancel and self._child.is_alive(): + if self._allow_cancel and self._child.is_alive() and self._child.pid is not None: os.kill(self._child.pid, signal.SIGUSR1) self._allow_cancel = False diff --git a/python/cog/types.py b/python/cog/types.py index aa61275b3b..d4a95bdfa7 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -4,7 +4,8 @@ import pathlib import shutil import tempfile -import urllib +import urllib.parse +import urllib.request from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union import requests @@ -256,15 +257,15 @@ def __get_validators__(cls) -> Iterator[Any]: yield cls.validate @classmethod - def validate(cls, value: Any) -> Iterator: + def validate(cls, value: Iterator[Any]) -> Iterator[Any]: return value -def _len_bytes(s, encoding="utf-8") -> int: +def _len_bytes(s: str, encoding: str="utf-8") -> int: return len(s.encode(encoding)) -def _truncate_filename_bytes(s, length, encoding="utf-8") -> str: +def _truncate_filename_bytes(s: str, length: int, encoding: str="utf-8") -> str: """ Truncate a filename to at most `length` bytes, preserving file extension and avoiding text encoding corruption from truncation. 
diff --git a/python/tests/server/test_runner.py b/python/tests/server/test_runner.py index 2a2ae9a045..b0151b9f54 100644 --- a/python/tests/server/test_runner.py +++ b/python/tests/server/test_runner.py @@ -48,10 +48,10 @@ async def test_prediction_runner_setup(): try: result = await runner.setup() - assert result["status"] == Status.SUCCEEDED - assert result["logs"] == "" - assert isinstance(result["started_at"], datetime) - assert isinstance(result["completed_at"], datetime) + assert result.status == Status.SUCCEEDED + assert result.logs == "" + assert isinstance(result.started_at, datetime) + assert isinstance(result.completed_at, datetime) finally: runner.shutdown()