Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ module = [
"trio._ki",
"trio._socket",
"trio._sync",
"trio._tools.gen_exports",
"trio._util",
]
disallow_incomplete_defs = true
Expand Down
32 changes: 20 additions & 12 deletions trio/_tools/gen_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
Code generation script for class methods
to be exported as public API
"""
from __future__ import annotations

import argparse
import ast
import os
import sys
from collections.abc import Iterable, Iterator
from pathlib import Path
from textwrap import indent
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing_extensions import TypeGuard

import astor

Expand Down Expand Up @@ -36,7 +43,7 @@
"""


def is_function(node):
def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Since there's a bunch of ast.FunctionDef | ast.AsyncFunctionDef I created a TypeAlias for it, and while it does make the signatures a bit easier to parse I'm not even sure myself it's worth it.

"""Check if the AST node is either a function
or an async function
"""
Expand All @@ -45,17 +52,18 @@ def is_function(node):
return False


def is_public(node):
def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
"""Check if the AST node has a _public decorator"""
if not is_function(node):
return False
for decorator in node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id == "_public":
return True
if is_function(node):
for decorator in node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id == "_public":
return True
return False


def get_public_methods(tree):
def get_public_methods(
tree: ast.AST,
) -> Iterator[ast.FunctionDef | ast.AsyncFunctionDef]:
"""Return a list of methods marked as public.
The function walks the given tree and extracts
all objects that are functions which are marked
Expand All @@ -66,7 +74,7 @@ def get_public_methods(tree):
yield node


def create_passthrough_args(funcdef):
def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
"""Given a function definition, create a string that represents taking all
the arguments from the function, and passing them through to another
invocation of the same function.
Expand Down Expand Up @@ -130,7 +138,7 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
return "\n\n".join(generated)


def matches_disk_files(new_files):
def matches_disk_files(new_files: dict[str, str]) -> bool:
for new_path, new_source in new_files.items():
if not os.path.exists(new_path):
return False
Expand All @@ -141,7 +149,7 @@ def matches_disk_files(new_files):
return True


def process(sources_and_lookups, *, do_test):
def process(sources_and_lookups: Iterable[tuple[Path, str]], *, do_test: bool) -> None:
new_files = {}
for source_path, lookup_path in sources_and_lookups:
print("Scanning:", source_path)
Expand All @@ -164,7 +172,7 @@ def process(sources_and_lookups, *, do_test):

# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
def main(): # pragma: no cover
def main() -> None: # pragma: no cover
parser = argparse.ArgumentParser(
description="Generate python code for public api wrappers"
)
Expand Down