Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions trio/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
import typing as t
from abc import ABCMeta
from functools import update_wrapper
from types import TracebackType
from types import AsyncGeneratorType, TracebackType

import trio

if t.TYPE_CHECKING:
from typing_extensions import Self

CallT = t.TypeVar("CallT", bound=t.Callable[..., t.Any])


Expand Down Expand Up @@ -61,7 +64,7 @@
signal_raise = getattr(_lib, "raise")
else:

def signal_raise(signum):
def signal_raise(signum: int) -> None:
signal.pthread_kill(threading.get_ident(), signum)


Expand All @@ -73,7 +76,7 @@ def signal_raise(signum):
# Trying to use signal out of the main thread will fail, so we can then
# reliably check if this is the main thread without relying on a
# potentially modified threading.
def is_main_thread():
def is_main_thread() -> bool:
"""Attempt to reliably check if we are in the main thread."""
try:
signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT))
Expand All @@ -86,8 +89,14 @@ def is_main_thread():
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes. Returns coroutine object.
######
def coroutine_or_error(async_fn, *args):
def _return_value_looks_like_wrong_library(value):
def coroutine_or_error(
async_fn: collections.abc.Callable[
..., collections.abc.Coroutine[t.Any, t.Any, t.Any]
]
| collections.abc.Coroutine[t.Any, t.Any, t.Any],
*args: t.Any,
) -> collections.abc.Coroutine[t.Any, t.Any, t.Any]:
def _return_value_looks_like_wrong_library(value: object) -> bool:
# Returned by legacy @asyncio.coroutine functions, which includes
# a surprising proportion of asyncio builtins.
if isinstance(value, collections.abc.Generator):
Expand All @@ -103,7 +112,7 @@ def _return_value_looks_like_wrong_library(value):
return False

try:
coro = async_fn(*args)
coro = async_fn(*args) # type: ignore[operator] # Coroutine not callable as intended

except TypeError:
# Give good error for: nursery.start_soon(trio.sleep(1))
Expand Down Expand Up @@ -183,11 +192,13 @@ class ConflictDetector:

"""

def __init__(self, msg):
__slots__ = ("_msg", "_held")

def __init__(self, msg: str) -> None:
self._msg = msg
self._held = False

def __enter__(self):
def __enter__(self) -> None:
if self._held:
raise trio.BusyResourceError(self._msg)
else:
Expand Down Expand Up @@ -224,10 +235,10 @@ def decorator(func: CallT) -> CallT:
return decorator


def fixup_module_metadata(module_name, namespace):
def fixup_module_metadata(module_name: str, namespace: dict[str, object]) -> None:
seen_ids = set()

def fix_one(qualname, name, obj):
def fix_one(qualname: str, name: str, obj: object) -> None:
# avoid infinite recursion (relevant when using
# typing.Generic, for example)
if id(obj) in seen_ids:
Expand All @@ -242,7 +253,7 @@ def fix_one(qualname, name, obj):
# rewriting these.
if hasattr(obj, "__name__") and "." not in obj.__name__:
obj.__name__ = name
obj.__qualname__ = qualname
obj.__qualname__ = qualname # type: ignore[attr-defined] # object doesn't have __qualname__ attribute
if isinstance(obj, type):
for attr_name, attr_value in obj.__dict__.items():
fix_one(objname + "." + attr_name, attr_name, attr_value)
Expand All @@ -269,14 +280,14 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[
but at least it becomes possible to write those.
"""

def __init__(self, fn):
def __init__(self, fn: collections.abc.Callable[..., t.Any]) -> None:
update_wrapper(self, fn)
self._fn = fn

def __call__(self, *args, **kwargs):
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
return self._fn(*args, **kwargs)

def __getitem__(self, _):
def __getitem__(self, _: t.Any) -> Self:
return self


Expand All @@ -295,6 +306,8 @@ class SomeClass(metaclass=Final):
- TypeError if a subclass is created
"""

__slots__ = ()

def __new__(
cls, name: str, bases: tuple[type, ...], cls_namespace: dict[str, object]
) -> Final:
Expand Down Expand Up @@ -329,6 +342,8 @@ class SomeClass(metaclass=NoPublicConstructor):
- TypeError if a subclass or an instance is created.
"""

__slots__ = ()

def __call__(cls, *args: object, **kwargs: object) -> None:
raise TypeError(
f"{cls.__module__}.{cls.__qualname__} has no public constructor"
Expand All @@ -338,7 +353,7 @@ def _create(cls: t.Type[T], *args: object, **kwargs: object) -> T:
return super().__call__(*args, **kwargs) # type: ignore


def name_asyncgen(agen):
def name_asyncgen(agen: AsyncGeneratorType[t.Any, t.Any]) -> str:
"""Return the fully-qualified name of the async generator function
that produced the async generator iterator *agen*.
"""
Expand Down