From 7506f4e3b3d483fe4b2cd90c9aec322b162376f5 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Mon, 21 Aug 2017 16:56:28 -0700 Subject: [PATCH 1/3] [wip] Sketch of how shared tasks might work See gh-266 This is surprisingly interesting and tricky. --- trio/_shared_task.py | 118 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 trio/_shared_task.py diff --git a/trio/_shared_task.py b/trio/_shared_task.py new file mode 100644 index 0000000000..c16bdab559 --- /dev/null +++ b/trio/_shared_task.py @@ -0,0 +1,118 @@ +__all__ = ["SharedTaskRegistry"] + + +# Here's some cleverness to normalize out functools.partial usage, important +# b/c otherwise there's no way to pass kwargs without having to specify a key= +# manually. +# +# XX should we also do signature cleverness to normalize stuff like +# def f(x): ... +# and treat these the same: +# (f, (1,), {}) +# (f, (), {"x": 1}) +# ? This is less important b/c we can document that if you want magic key +# generation then you should be careful to make your matching calls obviously +# matching. 
+def _unpack_call(fn, args, kwargs): + if isinstance(fn, functools.partial): + inner_fn, inner_args, inner_kwargs = _call_to_key( + fn.func, fn.args, fn.kwargs + ) + fn = inner_fn + args = (*inner_args, *args) + kwargs = {**inner_kwargs, **kwargs} + return fn, args, kwargs + + +def call_to_hashable_key(fn, args): + fn, args, kwargs = _unpack_call(fn, args, {}) + return (fn, args, tuple(sorted(kwargs.items()))) + + +@attr.s +class SharedTask: + registry = attr.ib() + key = attr.ib() + cancel_scope = attr.ib(default=None) + # Needed to work around a race condition, where we realize we want to + # cancel the child before it's even created the cancel scope + cancelled_early = attr.ib(default=False) + # Reference count + waiter_count = attr.ib(default=0) + # Reporting back + finished = attr.ib(default=attr.Factory(trio.Event)) + result = attr.ib(default=None) + + # This runs in system task context, so it has KI protection enabled and + # any exceptions will crash the whole program. + async def run(self, async_fn, args): + + async def cancellable_runner(): + with trio.open_cancel_scope() as cancel_scope: + self.cancel_scope = cancel_scope + if self.cancelled_early: + self.cancel_scope.cancel() + return await ki_unprotected_runner() + + @trio.hazmat.disable_ki_protection + async def ki_unprotected_runner(): + return await async_fn(*args) + + self.result = await Result.acapture(cancellable_runner) + self.finished.set() + if self.registry._tasks.get(self.key) is self: + del self.registry._tasks[self.key] + + +@attr.s(slots=True, frozen=True, hash=False, cmp=False, repr=False) +class SharedTaskRegistry: + _tasks = attr.ib(default=attr.Factory(dict)) + + @trio.hazmat.enable_ki_protection + async def run(self, async_fn, *args, key=None): + if key is None: + key = call_to_hashable_key(async_fn, args) + + if key not in self._tasks: + shared_task = SharedTask(self, key) + self._tasks[key] = shared_task + trio.hazmat.spawn_system_task(shared_task.run, async_fn, args) + + 
shared_task = self._tasks[key] + shared_task.waiter_count += 1 + + try: + await shared_task.finished.wait() + except: + # Cancelled, or some bug + shared_task.waiter_count -= 1 + if shared_task.waiter_count == 0: + # Make sure any incoming calls to run() start a new task + del self._tasks[key] + + # Cancel the child, while working around the race condition + if shared_task.cancel_scope is None: + shared_task.cancelled_early = True + else: + shared_task.cancel_scope.cancel() + + with trio.open_cancel_scope(shield=True) as cancel_scope: + await shared_task.finished() + # Some possibilities: + # - they raised Cancelled. The cancellation we injected is + # absorbed internally, though, so this can only happen + # if a cancellation came from outside. The only way a + # system task can see this is if the whole system is + # going down, so it's OK to re-raise that -- any scope + # that includes a system task includes all the code in + # trio, including us. + # - they raise some other error: we should propagate + # - they return nothing (most common, b/c cancelled was + # raised and then + if not shared_task.cancel_scope.cancelled_caught: + return shared_task.result.unwrap() + else: + shared_task.result.unwrap() + raise + + return shared_task.result.unwrap() From 40211e0a2e41f7a2955337024b8cad72eebb6492 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Jan 2025 21:16:20 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- trio/_shared_task.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trio/_shared_task.py b/trio/_shared_task.py index c16bdab559..d1602290e3 100644 --- a/trio/_shared_task.py +++ b/trio/_shared_task.py @@ -15,9 +15,7 @@ # matching. 
def _unpack_call(fn, args, kwargs): if isinstance(fn, functools.partial): - inner_fn, inner_args, inner_kwargs = _call_to_key( - fn.func, fn.args, fn.kwargs - ) + inner_fn, inner_args, inner_kwargs = _call_to_key(fn.func, fn.args, fn.kwargs) fn = inner_fn args = (*inner_args, *args) kwargs = {**inner_kwargs, **kwargs} From d5327ca9cf79ae3cdc1f453a884d5681b1728ddc Mon Sep 17 00:00:00 2001 From: CoolCat467 <52022020+CoolCat467@users.noreply.github.com> Date: Tue, 14 Jan 2025 16:56:08 -0600 Subject: [PATCH 3/3] Move to `src` and add type annotations --- {trio => src/trio}/_shared_task.py | 124 ++++++++++++++++++++++------- 1 file changed, 95 insertions(+), 29 deletions(-) rename {trio => src/trio}/_shared_task.py (51%) diff --git a/trio/_shared_task.py b/src/trio/_shared_task.py similarity index 51% rename from trio/_shared_task.py rename to src/trio/_shared_task.py index d1602290e3..1d6bffe08e 100644 --- a/trio/_shared_task.py +++ b/src/trio/_shared_task.py @@ -1,3 +1,26 @@ +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Generic, TypeVar, cast + +import attr +import outcome + +from trio._core._ki import disable_ki_protection, enable_ki_protection +from trio._core._run import CancelScope, spawn_system_task +from trio._sync import Event + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from typing_extensions import ParamSpec, TypeVarTuple, Unpack + + PS = ParamSpec("PS") + PosArgT = TypeVarTuple("PosArgT") + +RetT = TypeVar("RetT") + + __all__ = ["SharedTaskRegistry"] @@ -13,70 +36,112 @@ # ? This is less important b/c we can document that if you want magic key # generation then you should be careful to make your matching calls obviously # matching. 
-def _unpack_call(fn, args, kwargs): +def _unpack_call( + fn: Callable[PS, RetT], + args: PS.args, + kwargs: PS.kwargs | dict[str, object], +) -> tuple[Callable[PS, RetT], PS.args, PS.kwargs | dict[str, object]]: if isinstance(fn, functools.partial): - inner_fn, inner_args, inner_kwargs = _call_to_key(fn.func, fn.args, fn.kwargs) + inner_fn, inner_args, inner_kwargs = _unpack_call(fn.func, fn.args, fn.keywords) fn = inner_fn args = (*inner_args, *args) kwargs = {**inner_kwargs, **kwargs} return fn, args, kwargs -def call_to_hashable_key(fn, args): +def call_to_hashable_key( + fn: Callable[[Unpack[PosArgT]], RetT], + args: tuple[Unpack[PosArgT]], +) -> tuple[ + Callable[[Unpack[PosArgT]], RetT], + tuple[Unpack[PosArgT]], + tuple[tuple[str, object], ...], +]: fn, args, kwargs = _unpack_call(fn, args, {}) return (fn, args, tuple(sorted(kwargs.items()))) +class BaseSharedTask: + __slots__ = () + + @attr.s -class SharedTask: - registry = attr.ib() - key = attr.ib() - cancel_scope = attr.ib(default=None) +class SharedTask(BaseSharedTask, Generic["Unpack[PosArgT]", RetT]): + registry: SharedTaskRegistry = attr.ib() + key: tuple[ + Callable[[Unpack[PosArgT]], Awaitable[RetT]], + tuple[Unpack[PosArgT]], + tuple[tuple[str, object], ...], + ] = attr.ib() + cancel_scope: CancelScope | None = attr.ib(default=None) # Needed to work around a race condition, where we realize we want to # cancel the child before it's even created the cancel scope - cancelled_early = attr.ib(default=False) + cancelled_early: bool = attr.ib(default=False) # Reference count - waiter_count = attr.ib(default=0) + waiter_count: int = attr.ib(default=0) # Reporting back - finished = attr.ib(default=attr.Factory(trio.Event)) - result = attr.ib(default=None) + finished: Event = attr.ib(default=attr.Factory(Event)) + result: outcome.Value[RetT] | outcome.Error = attr.ib(default=None) # This runs in system task context, so it has KI protection enabled and # any exceptions will crash the whole program. 
- async def run(self, async_fn, args): + async def run( + self, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]], + args: tuple[Unpack[PosArgT]], + ) -> None: + @disable_ki_protection + async def ki_unprotected_runner() -> RetT: + return await async_fn(*args) - async def cancellable_runner(): - with trio.open_cancel_scope() as cancel_scope: + async def cancellable_runner() -> RetT: + with CancelScope() as cancel_scope: self.cancel_scope = cancel_scope if self.cancelled_early: self.cancel_scope.cancel() return await ki_unprotected_runner() + return cast("RetT", None)  # reached when cancel_scope absorbs Cancelled; waiters re-raise it - @trio.hazmat.disable_ki_protection - async def ki_unprotected_runner(): - return await async_fn(*args) - - self.result = await Result.acapture(cancellable_runner) + self.result = await outcome.acapture(cancellable_runner) self.finished.set() if self.registry._tasks.get(self.key) is self: del self.registry._tasks[self.key] @attr.s(slots=True, frozen=True, hash=False, cmp=False, repr=False) -class SharedTaskRegistry: - _tasks = attr.ib(default=attr.Factory(dict)) - - @trio.hazmat.enable_ki_protection - async def run(self, async_fn, *args, key=None): +class SharedTaskRegistry: # type: ignore[misc] + _tasks: dict[ # type: ignore[misc] + tuple[ + Callable[..., Awaitable[object]], + tuple[object, ...], + tuple[tuple[str, object], ...], + ], + BaseSharedTask, + ] = attr.ib(default=attr.Factory(dict)) + + @enable_ki_protection + async def run( + self, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]], + *args: Unpack[PosArgT], + key: ( + tuple[ + Callable[[Unpack[PosArgT]], Awaitable[RetT]], + tuple[Unpack[PosArgT]], + tuple[tuple[str, object], ...], + ] + | None + ) = None, + ) -> RetT: if key is None: key = call_to_hashable_key(async_fn, args) if key not in self._tasks: - shared_task = SharedTask(self, key) + shared_task = SharedTask["Unpack[PosArgT]", RetT](self, key) self._tasks[key] = shared_task - trio.hazmat.spawn_system_task(shared_task.run, async_fn, args) + 
spawn_system_task(shared_task.run, async_fn, args) - shared_task = self._tasks[key] + shared_task = cast("SharedTask[Unpack[PosArgT], RetT]", self._tasks[key]) shared_task.waiter_count += 1 try: @@ -94,8 +159,8 @@ async def run(self, async_fn, *args, key=None): else: shared_task.cancel_scope.cancel() - with trio.open_cancel_scope(shield=True) as cancel_scope: - await shared_task.finished() + with CancelScope(shield=True): + await shared_task.finished.wait() # Some possibilities: # - they raised Cancelled. The cancellation we injected is # absorbed internally, though, so this can only happen @@ -107,6 +172,7 @@ async def run(self, async_fn, *args, key=None): # - they raise some other error: we should propagate # - they return nothing (most common, b/c cancelled was # raised and then + assert shared_task.cancel_scope is not None if not shared_task.cancel_scope.cancelled_caught: return shared_task.result.unwrap() else: