diff --git a/MANIFEST.in b/MANIFEST.in index e2fd4c157f..0a78367389 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,6 +2,7 @@ include LICENSE LICENSE.MIT LICENSE.APACHE2 include README.rst include CODE_OF_CONDUCT.md CONTRIBUTING.md include test-requirements.txt +include trio/py.typed recursive-include trio/tests/test_ssl_certs *.pem recursive-include docs * prune docs/build diff --git a/trio/_core/_run.py b/trio/_core/_run.py index fcb36c32c5..836ce3c434 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -18,7 +18,12 @@ from contextvars import copy_context from math import inf from time import perf_counter -from typing import Callable, TYPE_CHECKING +from typing import Any, Awaitable, Callable, TYPE_CHECKING, TypeVar + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec from sniffio import current_async_library_cvar @@ -801,7 +806,7 @@ class NurseryManager: """ @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self) -> "Nursery": self._scope = CancelScope() self._scope.__enter__() self._nursery = Nursery._create(current_task(), self._scope) @@ -840,7 +845,7 @@ def __exit__(self): # pragma: no cover assert False, """Never called, but should be defined""" -def open_nursery(): +def open_nursery() -> NurseryManager: """Returns an async context manager which must be used to create a new `Nursery`. @@ -851,6 +856,10 @@ def open_nursery(): return NurseryManager() +T_Retval = TypeVar("T_Retval") +T_ParamSpec = ParamSpec("T_ParamSpec") + + class Nursery(metaclass=NoPublicConstructor): """A context which may be used to spawn (or cancel) child tasks. @@ -957,6 +966,17 @@ def aborted(raise_cancel): # (see test_nursery_cancel_doesnt_create_cyclic_garbage) del self._pending_excs + def soonify( + self, + async_fn: Callable[T_ParamSpec, Awaitable[Any]], + name: str = None, + ) -> Callable[T_ParamSpec, None]: + def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> None: + partial_f = functools.partial(async_fn, *args, **kwargs) + self.start_soon(partial_f, name=name) + + return wrapper + def start_soon(self, async_fn, *args, name=None): """Creates a child task, scheduling ``await async_fn(*args)``. diff --git a/trio/_threads.py b/trio/_threads.py index 356eb0d2d1..98d538b910 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -1,8 +1,11 @@ # coding: utf-8 +import functools import threading import queue as stdlib_queue from itertools import count +from typing import Awaitable, TypeVar, Callable, Optional +import sys import attr import inspect @@ -10,6 +13,11 @@ import trio +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + from ._sync import CapacityLimiter from ._core import ( enable_ki_protection, @@ -55,6 +63,27 @@ class ThreadPlaceholder: name = attr.ib() +T_Retval = TypeVar("T_Retval") +T_ParamSpec = ParamSpec("T_ParamSpec") + + +def to_thread_asyncify( + sync_fn: Callable[T_ParamSpec, T_Retval], + *, + cancellable: bool = False, + limiter: Optional[CapacityLimiter] = None, +) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: + async def wrapper( + *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs + ) -> T_Retval: + partial_f = functools.partial(sync_fn, *args, **kwargs) + return await to_thread_run_sync( + partial_f, cancellable=cancellable, limiter=limiter + ) + + return wrapper + + @enable_ki_protection async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None): """Convert a blocking operation into an async operation using a thread. @@ -238,6 +267,16 @@ def _run_fn_as_system_task(cb, fn, *args, trio_token=None): return q.get().unwrap() +def from_thread_syncify( + afn: Callable[T_ParamSpec, Awaitable[T_Retval]], +) -> Callable[T_ParamSpec, T_Retval]: + def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: + partial_f = functools.partial(afn, *args, **kwargs) + return from_thread_run(partial_f) + + return wrapper + + def from_thread_run(afn, *args, trio_token=None): """Run the given async function in the parent Trio thread, blocking until it is complete. diff --git a/trio/from_thread.py b/trio/from_thread.py index 296a5a89ea..42859c3ba9 100644 --- a/trio/from_thread.py +++ b/trio/from_thread.py @@ -4,4 +4,5 @@ """ from ._threads import from_thread_run as run +from ._threads import from_thread_syncify as syncify from ._threads import from_thread_run_sync as run_sync diff --git a/trio/py.typed b/trio/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/trio/to_thread.py b/trio/to_thread.py index 6eec7b36c7..548ee7fad0 100644 --- a/trio/to_thread.py +++ b/trio/to_thread.py @@ -1,2 +1,3 @@ from ._threads import to_thread_run_sync as run_sync +from ._threads import to_thread_asyncify as asyncify from ._threads import current_default_thread_limiter