From 7dcab2acac3e31a82ff071c38a2a87fba3a84c8f Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Sat, 5 Aug 2023 12:35:55 -0500 Subject: [PATCH 1/2] Add typing to `_tools/gen_exports.py` --- trio/_tools/gen_exports.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index a5d8529b53..0907d3b6f8 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -3,12 +3,19 @@ Code generation script for class methods to be exported as public API """ +from __future__ import annotations + import argparse import ast import os import sys +from collections.abc import Generator, Iterable from pathlib import Path from textwrap import indent +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import TypeGuard import astor @@ -36,7 +43,7 @@ """ -def is_function(node): +def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: """Check if the AST node is either a function or an async function """ @@ -45,17 +52,18 @@ def is_function(node): return False -def is_public(node): +def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: """Check if the AST node has a _public decorator""" - if not is_function(node): - return False - for decorator in node.decorator_list: - if isinstance(decorator, ast.Name) and decorator.id == "_public": - return True + if is_function(node): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == "_public": + return True return False -def get_public_methods(tree): +def get_public_methods( + tree: ast.AST, +) -> Generator[ast.FunctionDef | ast.AsyncFunctionDef, None, None]: """Return a list of methods marked as public. The function walks the given tree and extracts all objects that are functions which are marked @@ -66,7 +74,7 @@ def get_public_methods(tree): yield node -def create_passthrough_args(funcdef): +def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> str: """Given a function definition, create a string that represents taking all the arguments from the function, and passing them through to another invocation of the same function. @@ -130,7 +138,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: return "\n\n".join(generated) -def matches_disk_files(new_files): +def matches_disk_files(new_files: dict[str, str]) -> bool: for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False @@ -141,7 +149,7 @@ def matches_disk_files(new_files): return True -def process(sources_and_lookups, *, do_test): +def process(sources_and_lookups: Iterable[tuple[Path, str]], *, do_test: bool) -> None: new_files = {} for source_path, lookup_path in sources_and_lookups: print("Scanning:", source_path) @@ -164,7 +172,7 @@ def process(sources_and_lookups, *, do_test): # This is in fact run in CI, but only in the formatting check job, which # doesn't collect coverage. -def main(): # pragma: no cover +def main() -> None: # pragma: no cover parser = argparse.ArgumentParser( description="Generate python code for public api wrappers" ) From 323dbd150af2e7147288010f3f0b227aa4f076fa Mon Sep 17 00:00:00 2001 From: jakkdl