From ca7400819d899d8f87eff242675d7cc43124218a Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 5 Nov 2023 10:33:17 -0500 Subject: [PATCH 1/7] raise from None when transmuting errors --- trio/_threads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/_threads.py b/trio/_threads.py index a1859a928a..d227147f9a 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -443,10 +443,10 @@ def from_thread_check_cancelled() -> None: """ try: raise_cancel = PARENT_TASK_DATA.cancel_register[0] - except AttributeError as exc: + except AttributeError: raise RuntimeError( "this thread wasn't created by Trio, can't check for cancellation" - ) from exc + ) from None if raise_cancel is not None: raise_cancel() From 67551969d409b45d93902898d24d3e66115c6d96 Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 5 Nov 2023 10:34:37 -0500 Subject: [PATCH 2/7] make from_thread funcs thinner and DRYer --- trio/_threads.py | 53 ++++++++++++++++-------------------------------- 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/trio/_threads.py b/trio/_threads.py index d227147f9a..72ba7be1f4 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -80,7 +80,7 @@ class ThreadPlaceholder: class Run(Generic[RetT]): afn: Callable[..., Awaitable[RetT]] = attr.ib() args: tuple[object, ...] = attr.ib() - context: contextvars.Context = attr.ib() + context: contextvars.Context = attr.ib(init=False, factory=contextvars.copy_context) queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( init=False, factory=stdlib_queue.SimpleQueue ) @@ -133,7 +133,7 @@ def in_trio_thread() -> None: class RunSync(Generic[RetT]): fn: Callable[..., RetT] = attr.ib() args: tuple[object, ...] = attr.ib() - context: contextvars.Context = attr.ib() + context: contextvars.Context = attr.ib(init=False, factory=contextvars.copy_context) queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( init=False, factory=stdlib_queue.SimpleQueue ) @@ -451,23 +451,21 @@ def from_thread_check_cancelled() -> None: raise_cancel() -def _check_token(trio_token: TrioToken | None) -> TrioToken: - """Raise a RuntimeError if this function is called within a trio run. - - Avoids deadlock by making sure we're not called from inside a context - that we might be waiting for and blocking it. - """ - - if trio_token is not None and not isinstance(trio_token, TrioToken): - raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") +def _send_message_to_trio( + trio_token: TrioToken | None, message_to_trio: Run[RetT] | RunSync[RetT] +) -> RetT: + """Shared logic of from_thread functions""" + token_provided = trio_token is not None - if trio_token is None: + if not token_provided: try: trio_token = PARENT_TASK_DATA.token except AttributeError: raise RuntimeError( "this thread wasn't created by Trio, pass kwarg trio_token=..." ) from None + elif not isinstance(trio_token, TrioToken): + raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") # Avoid deadlock by making sure we're not called from Trio thread try: @@ -477,7 +475,12 @@ def _check_token(trio_token: TrioToken | None) -> TrioToken: else: raise RuntimeError("this is a blocking function; call it from a thread") - return trio_token + if token_provided or PARENT_TASK_DATA.abandon_on_cancel: + message_to_trio.run_in_system_nursery(trio_token) + else: + message_to_trio.run_in_host_task(trio_token) + + return message_to_trio.queue.get().unwrap() def from_thread_run( @@ -520,17 +523,7 @@ def from_thread_run( maybe to avoid the cancellation context of a corresponding `trio.to_thread.run_sync` task. """ - token_provided = trio_token is not None - trio_token = _check_token(trio_token) - - message_to_trio = Run(afn, args, contextvars.copy_context()) - - if token_provided or PARENT_TASK_DATA.abandon_on_cancel: - message_to_trio.run_in_system_nursery(trio_token) - else: - message_to_trio.run_in_host_task(trio_token) - - return message_to_trio.queue.get().unwrap() + return _send_message_to_trio(trio_token, Run(afn, args)) def from_thread_run_sync( @@ -568,14 +561,4 @@ def from_thread_run_sync( maybe to avoid the cancellation context of a corresponding `trio.to_thread.run_sync` task. """ - token_provided = trio_token is not None - trio_token = _check_token(trio_token) - - message_to_trio = RunSync(fn, args, contextvars.copy_context()) - - if token_provided or PARENT_TASK_DATA.abandon_on_cancel: - message_to_trio.run_in_system_nursery(trio_token) - else: - message_to_trio.run_in_host_task(trio_token) - - return message_to_trio.queue.get().unwrap() + return _send_message_to_trio(trio_token, RunSync(fn, args)) From e461911f3534362138a5fe9b1906dfb50545194f Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 5 Nov 2023 10:37:21 -0500 Subject: [PATCH 3/7] update deliver comment --- trio/_threads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/_threads.py b/trio/_threads.py index 72ba7be1f4..e6b54f9de7 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -370,9 +370,9 @@ def worker_fn() -> RetT: contextvars_aware_worker_fn: Callable[[], RetT] = functools.partial(context.run, worker_fn) # type: ignore[assignment] def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: - # The entire run finished, so the task we're trying to contact is + # If the entire run finished, the task we're trying to contact is # certainly long gone -- it must have been cancelled and abandoned - # us. + # us. Just ignore the error in this case. with contextlib.suppress(trio.RunFinishedError): current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result) From 5422f61aae54fd179c82da0dcb34babf7c5bc95d Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 5 Nov 2023 10:37:40 -0500 Subject: [PATCH 4/7] apply context without type-confusing wrappers --- trio/_threads.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/trio/_threads.py b/trio/_threads.py index e6b54f9de7..9e0ab5caa4 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -2,7 +2,6 @@ import contextlib import contextvars -import functools import inspect import queue as stdlib_queue import threading @@ -337,18 +336,12 @@ def do_release_then_return_result() -> RetT: thread_name = f"{getattr(sync_fn, '__name__', None)} from {trio.lowlevel.current_task().name}" def worker_fn() -> RetT: - # Trio doesn't use current_async_library_cvar, but if someone - # else set it, it would now shine through since - # snifio.thread_local isn't set in the new thread. Make sure - # the new thread sees that it's not running in async context. - current_async_library_cvar.set(None) - PARENT_TASK_DATA.token = current_trio_token PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel PARENT_TASK_DATA.cancel_register = cancel_register PARENT_TASK_DATA.task_register = task_register try: - ret = sync_fn(*args) + ret = context.run(sync_fn, *args) if inspect.iscoroutine(ret): # Manually close coroutine to avoid RuntimeWarnings @@ -366,8 +359,11 @@ def worker_fn() -> RetT: del PARENT_TASK_DATA.task_register context = contextvars.copy_context() - # Partial confuses type checkers, coerce to a callable. - contextvars_aware_worker_fn: Callable[[], RetT] = functools.partial(context.run, worker_fn) # type: ignore[assignment] + # Trio doesn't use current_async_library_cvar, but if someone + # else set it, it would now shine through since + # sniffio.thread_local isn't set in the new thread. Make sure + # the new thread sees that it's not running in async context. + context.run(current_async_library_cvar.set, None) def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: # If the entire run finished, the task we're trying to contact is @@ -378,9 +374,7 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: await limiter.acquire_on_behalf_of(placeholder) try: - start_thread_soon( - contextvars_aware_worker_fn, deliver_worker_fn_result, thread_name - ) + start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name) except: limiter.release_on_behalf_of(placeholder) raise From 7b21b4c1d4ff532d420273fe437c05a389b28eb3 Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 5 Nov 2023 10:38:32 -0500 Subject: [PATCH 5/7] ignore slight type mismatch after 5422f61a, mypy can't tell that abandon_on_cancel cannot be None here --- trio/_threads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_threads.py b/trio/_threads.py index 9e0ab5caa4..a1e2aa2b09 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -337,7 +337,7 @@ def do_release_then_return_result() -> RetT: def worker_fn() -> RetT: PARENT_TASK_DATA.token = current_trio_token - PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel + PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel # type: ignore[assignment] PARENT_TASK_DATA.cancel_register = cancel_register PARENT_TASK_DATA.task_register = task_register try: From 8d8a4752cdbbe11b803baf1d0b0fab2a83446eb6 Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 5 Nov 2023 11:53:01 -0500 Subject: [PATCH 6/7] apply context without chaining paramspecs --- trio/_threads.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/trio/_threads.py b/trio/_threads.py index a1e2aa2b09..3a047968ba 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -139,7 +139,7 @@ class RunSync(Generic[RetT]): @disable_ki_protection def unprotected_fn(self) -> RetT: - ret = self.fn(*self.args) + ret = self.context.run(self.fn, *self.args) if inspect.iscoroutine(ret): # Manually close coroutine to avoid RuntimeWarnings @@ -152,9 +152,7 @@ def unprotected_fn(self) -> RetT: return ret def run_sync(self) -> None: - # Two paramspecs + overload is a bit too hard for mypy to handle. Tell it what to infer. - runner: Callable[[Callable[[], RetT]], RetT] = self.context.run - result = outcome.capture(runner, self.unprotected_fn) + result = outcome.capture(self.unprotected_fn) self.queue.put_nowait(result) def run_in_host_task(self, token: TrioToken) -> None: From 9ec7b25aebc62698bfe60b8bfcfd50d8e61058c6 Mon Sep 17 00:00:00 2001 From: richardsheridan Date: Sun, 5 Nov 2023 20:53:00 -0500 Subject: [PATCH 7/7] assign abandon_on_cancel to a new name for type narrowing --- trio/_threads.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/trio/_threads.py b/trio/_threads.py index 3a047968ba..1e84595cc1 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -297,7 +297,8 @@ async def to_thread_run_sync( # type: ignore[misc] ) abandon_on_cancel = cancellable # raise early if abandon_on_cancel.__bool__ raises - abandon_on_cancel = bool(abandon_on_cancel) + # and give a new name to ensure mypy knows it's never None + abandon_bool = bool(abandon_on_cancel) if limiter is None: limiter = current_default_thread_limiter() @@ -335,7 +336,7 @@ def do_release_then_return_result() -> RetT: def worker_fn() -> RetT: PARENT_TASK_DATA.token = current_trio_token - PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel # type: ignore[assignment] + PARENT_TASK_DATA.abandon_on_cancel = abandon_bool PARENT_TASK_DATA.cancel_register = cancel_register PARENT_TASK_DATA.task_register = task_register try: @@ -380,7 +381,7 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: # fill so from_thread_check_cancelled can raise cancel_register[0] = raise_cancel - if abandon_on_cancel: + if abandon_bool: # empty so report_back_in_trio_thread_fn cannot reschedule task_register[0] = None return trio.lowlevel.Abort.SUCCEEDED