From abdc65e775975644c504ca03c630330ce0061661 Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Sat, 29 Jul 2023 19:35:33 -0500 Subject: [PATCH] Add typing for `_util.py` --- trio/_util.py | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/trio/_util.py b/trio/_util.py index a87f1fc02c..94d6334162 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -9,10 +9,13 @@ import typing as t from abc import ABCMeta from functools import update_wrapper -from types import TracebackType +from types import AsyncGeneratorType, TracebackType import trio +if t.TYPE_CHECKING: + from typing_extensions import Self + CallT = t.TypeVar("CallT", bound=t.Callable[..., t.Any]) @@ -61,7 +64,7 @@ signal_raise = getattr(_lib, "raise") else: - def signal_raise(signum): + def signal_raise(signum: int) -> None: signal.pthread_kill(threading.get_ident(), signum) @@ -73,7 +76,7 @@ def signal_raise(signum): # Trying to use signal out of the main thread will fail, so we can then # reliably check if this is the main thread without relying on a # potentially modified threading. -def is_main_thread(): +def is_main_thread() -> bool: """Attempt to reliably check if we are in the main thread.""" try: signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT)) @@ -86,8 +89,14 @@ def is_main_thread(): # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. Returns coroutine object. ###### -def coroutine_or_error(async_fn, *args): - def _return_value_looks_like_wrong_library(value): +def coroutine_or_error( + async_fn: collections.abc.Callable[ + ..., collections.abc.Coroutine[t.Any, t.Any, t.Any] + ] + | collections.abc.Coroutine[t.Any, t.Any, t.Any], + *args: t.Any, +) -> collections.abc.Coroutine[t.Any, t.Any, t.Any]: + def _return_value_looks_like_wrong_library(value: object) -> bool: # Returned by legacy @asyncio.coroutine functions, which includes # a surprising proportion of asyncio builtins. if isinstance(value, collections.abc.Generator): @@ -103,7 +112,7 @@ def _return_value_looks_like_wrong_library(value): return False try: - coro = async_fn(*args) + coro = async_fn(*args) # type: ignore[operator] # Coroutine not callable as intended except TypeError: # Give good error for: nursery.start_soon(trio.sleep(1)) @@ -183,11 +192,13 @@ class ConflictDetector: """ - def __init__(self, msg): + __slots__ = ("_msg", "_held") + + def __init__(self, msg: str) -> None: self._msg = msg self._held = False - def __enter__(self): + def __enter__(self) -> None: if self._held: raise trio.BusyResourceError(self._msg) else: @@ -224,10 +235,10 @@ def decorator(func: CallT) -> CallT: return decorator -def fixup_module_metadata(module_name, namespace): +def fixup_module_metadata(module_name: str, namespace: dict[str, object]) -> None: seen_ids = set() - def fix_one(qualname, name, obj): + def fix_one(qualname: str, name: str, obj: object) -> None: # avoid infinite recursion (relevant when using # typing.Generic, for example) if id(obj) in seen_ids: @@ -242,7 +253,7 @@ def fix_one(qualname, name, obj): # rewriting these. if hasattr(obj, "__name__") and "." not in obj.__name__: obj.__name__ = name - obj.__qualname__ = qualname + obj.__qualname__ = qualname # type: ignore[attr-defined] # object doesn't have __qualname__ attribute if isinstance(obj, type): for attr_name, attr_value in obj.__dict__.items(): fix_one(objname + "." + attr_name, attr_name, attr_value) @@ -269,14 +280,14 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ but at least it becomes possible to write those. """ - def __init__(self, fn): + def __init__(self, fn: collections.abc.Callable[..., t.Any]) -> None: update_wrapper(self, fn) self._fn = fn - def __call__(self, *args, **kwargs): + def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: return self._fn(*args, **kwargs) - def __getitem__(self, _): + def __getitem__(self, _: t.Any) -> Self: return self @@ -295,6 +306,8 @@ class SomeClass(metaclass=Final): - TypeError if a subclass is created """ + __slots__ = () + def __new__( cls, name: str, bases: tuple[type, ...], cls_namespace: dict[str, object] ) -> Final: @@ -329,6 +342,8 @@ class SomeClass(metaclass=NoPublicConstructor): - TypeError if a subclass or an instance is created. """ + __slots__ = () + def __call__(cls, *args: object, **kwargs: object) -> None: raise TypeError( f"{cls.__module__}.{cls.__qualname__} has no public constructor" @@ -338,7 +353,7 @@ def _create(cls: t.Type[T], *args: object, **kwargs: object) -> T: return super().__call__(*args, **kwargs) # type: ignore -def name_asyncgen(agen): +def name_asyncgen(agen: AsyncGeneratorType[t.Any, t.Any]) -> str: """Return the fully-qualified name of the async generator function that produced the async generator iterator *agen*. """