diff --git a/pyproject.toml b/pyproject.toml index 90bc98f64f..0900f3a7d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ module = [ "trio._ki", "trio._socket", "trio._sync", + "trio._tools.gen_exports", "trio._util", ] disallow_incomplete_defs = true diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index a5d8529b53..bae7e4f69d 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -3,12 +3,19 @@ Code generation script for class methods to be exported as public API """ +from __future__ import annotations + import argparse import ast import os import sys +from collections.abc import Iterable, Iterator from pathlib import Path from textwrap import indent +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import TypeGuard import astor @@ -36,7 +43,7 @@ """ -def is_function(node): +def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: """Check if the AST node is either a function or an async function """ @@ -45,17 +52,18 @@ def is_function(node): return False -def is_public(node): +def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: """Check if the AST node has a _public decorator""" - if not is_function(node): - return False - for decorator in node.decorator_list: - if isinstance(decorator, ast.Name) and decorator.id == "_public": - return True + if is_function(node): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == "_public": + return True return False -def get_public_methods(tree): +def get_public_methods( + tree: ast.AST, +) -> Iterator[ast.FunctionDef | ast.AsyncFunctionDef]: """Return a list of methods marked as public. The function walks the given tree and extracts all objects that are functions which are marked @@ -66,7 +74,7 @@ def get_public_methods(tree): yield node -def create_passthrough_args(funcdef): +def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> str: """Given a function definition, create a string that represents taking all the arguments from the function, and passing them through to another invocation of the same function. @@ -130,7 +138,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: return "\n\n".join(generated) -def matches_disk_files(new_files): +def matches_disk_files(new_files: dict[str, str]) -> bool: for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False @@ -141,7 +149,7 @@ def matches_disk_files(new_files): return True -def process(sources_and_lookups, *, do_test): +def process(sources_and_lookups: Iterable[tuple[Path, str]], *, do_test: bool) -> None: new_files = {} for source_path, lookup_path in sources_and_lookups: print("Scanning:", source_path) @@ -164,7 +172,7 @@ def process(sources_and_lookups, *, do_test): # This is in fact run in CI, but only in the formatting check job, which # doesn't collect coverage. -def main(): # pragma: no cover +def main() -> None: # pragma: no cover parser = argparse.ArgumentParser( description="Generate python code for public api wrappers" )