diff --git a/src/trio/_threads.py b/src/trio/_threads.py index 9051e2f4b5..4e44203464 100644 --- a/src/trio/_threads.py +++ b/src/trio/_threads.py @@ -2,7 +2,6 @@ import contextlib import contextvars -import functools import inspect import queue as stdlib_queue import threading @@ -83,7 +82,7 @@ class ThreadPlaceholder: class Run(Generic[RetT]): afn: Callable[..., Awaitable[RetT]] = attr.ib() args: tuple[object, ...] = attr.ib() - context: contextvars.Context = attr.ib() + context: contextvars.Context = attr.ib(init=False, factory=contextvars.copy_context) queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( init=False, factory=stdlib_queue.SimpleQueue ) @@ -136,14 +135,14 @@ def in_trio_thread() -> None: class RunSync(Generic[RetT]): fn: Callable[..., RetT] = attr.ib() args: tuple[object, ...] = attr.ib() - context: contextvars.Context = attr.ib() + context: contextvars.Context = attr.ib(init=False, factory=contextvars.copy_context) queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( init=False, factory=stdlib_queue.SimpleQueue ) @disable_ki_protection def unprotected_fn(self) -> RetT: - ret = self.fn(*self.args) + ret = self.context.run(self.fn, *self.args) if inspect.iscoroutine(ret): # Manually close coroutine to avoid RuntimeWarnings @@ -156,9 +155,7 @@ def unprotected_fn(self) -> RetT: return ret def run_sync(self) -> None: - # Two paramspecs + overload is a bit too hard for mypy to handle. Tell it what to infer. - runner: Callable[[Callable[[], RetT]], RetT] = self.context.run - result = outcome.capture(runner, self.unprotected_fn) + result = outcome.capture(self.unprotected_fn) self.queue.put_nowait(result) def run_in_host_task(self, token: TrioToken) -> None: @@ -303,7 +300,8 @@ async def to_thread_run_sync( # type: ignore[misc] ) abandon_on_cancel = cancellable # raise early if abandon_on_cancel.__bool__ raises - abandon_on_cancel = bool(abandon_on_cancel) + # and give a new name to ensure mypy knows it's never None + abandon_bool = bool(abandon_on_cancel) if limiter is None: limiter = current_default_thread_limiter() @@ -340,18 +338,12 @@ def do_release_then_return_result() -> RetT: thread_name = f"{getattr(sync_fn, '__name__', None)} from {trio.lowlevel.current_task().name}" def worker_fn() -> RetT: - # Trio doesn't use current_async_library_cvar, but if someone - # else set it, it would now shine through since - # snifio.thread_local isn't set in the new thread. Make sure - # the new thread sees that it's not running in async context. - current_async_library_cvar.set(None) - PARENT_TASK_DATA.token = current_trio_token - PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel + PARENT_TASK_DATA.abandon_on_cancel = abandon_bool PARENT_TASK_DATA.cancel_register = cancel_register PARENT_TASK_DATA.task_register = task_register try: - ret = sync_fn(*args) + ret = context.run(sync_fn, *args) if inspect.iscoroutine(ret): # Manually close coroutine to avoid RuntimeWarnings @@ -369,21 +361,22 @@ def worker_fn() -> RetT: del PARENT_TASK_DATA.task_register context = contextvars.copy_context() - # Partial confuses type checkers, coerce to a callable. - contextvars_aware_worker_fn: Callable[[], RetT] = functools.partial(context.run, worker_fn) # type: ignore[assignment] + # Trio doesn't use current_async_library_cvar, but if someone + # else set it, it would now shine through since + # sniffio.thread_local isn't set in the new thread. Make sure + # the new thread sees that it's not running in async context. + context.run(current_async_library_cvar.set, None) def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: - # The entire run finished, so the task we're trying to contact is + # If the entire run finished, the task we're trying to contact is # certainly long gone -- it must have been cancelled and abandoned - # us. + # us. Just ignore the error in this case. with contextlib.suppress(trio.RunFinishedError): current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result) await limiter.acquire_on_behalf_of(placeholder) try: - start_thread_soon( - contextvars_aware_worker_fn, deliver_worker_fn_result, thread_name - ) + start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name) except: limiter.release_on_behalf_of(placeholder) raise @@ -391,7 +384,7 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: # fill so from_thread_check_cancelled can raise cancel_register[0] = raise_cancel - if abandon_on_cancel: + if abandon_bool: # empty so report_back_in_trio_thread_fn cannot reschedule task_register[0] = None return trio.lowlevel.Abort.SUCCEEDED @@ -446,31 +439,29 @@ def from_thread_check_cancelled() -> None: """ try: raise_cancel = PARENT_TASK_DATA.cancel_register[0] - except AttributeError as exc: + except AttributeError: raise RuntimeError( "this thread wasn't created by Trio, can't check for cancellation" - ) from exc + ) from None if raise_cancel is not None: raise_cancel() -def _check_token(trio_token: TrioToken | None) -> TrioToken: - """Raise a RuntimeError if this function is called within a trio run. - - Avoids deadlock by making sure we're not called from inside a context - that we might be waiting for and blocking it. - """ - - if trio_token is not None and not isinstance(trio_token, TrioToken): - raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") +def _send_message_to_trio( + trio_token: TrioToken | None, message_to_trio: Run[RetT] | RunSync[RetT] +) -> RetT: + """Shared logic of from_thread functions""" + token_provided = trio_token is not None - if trio_token is None: + if not token_provided: try: trio_token = PARENT_TASK_DATA.token except AttributeError: raise RuntimeError( "this thread wasn't created by Trio, pass kwarg trio_token=..." ) from None + elif not isinstance(trio_token, TrioToken): + raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") # Avoid deadlock by making sure we're not called from Trio thread try: @@ -480,7 +471,12 @@ def _check_token(trio_token: TrioToken | None) -> TrioToken: else: raise RuntimeError("this is a blocking function; call it from a thread") - return trio_token + if token_provided or PARENT_TASK_DATA.abandon_on_cancel: + message_to_trio.run_in_system_nursery(trio_token) + else: + message_to_trio.run_in_host_task(trio_token) + + return message_to_trio.queue.get().unwrap() def from_thread_run( @@ -523,17 +519,7 @@ def from_thread_run( maybe to avoid the cancellation context of a corresponding `trio.to_thread.run_sync` task. """ - token_provided = trio_token is not None - trio_token = _check_token(trio_token) - - message_to_trio = Run(afn, args, contextvars.copy_context()) - - if token_provided or PARENT_TASK_DATA.abandon_on_cancel: - message_to_trio.run_in_system_nursery(trio_token) - else: - message_to_trio.run_in_host_task(trio_token) - - return message_to_trio.queue.get().unwrap() + return _send_message_to_trio(trio_token, Run(afn, args)) def from_thread_run_sync( @@ -571,14 +557,4 @@ def from_thread_run_sync( maybe to avoid the cancellation context of a corresponding `trio.to_thread.run_sync` task. """ - token_provided = trio_token is not None - trio_token = _check_token(trio_token) - - message_to_trio = RunSync(fn, args, contextvars.copy_context()) - - if token_provided or PARENT_TASK_DATA.abandon_on_cancel: - message_to_trio.run_in_system_nursery(trio_token) - else: - message_to_trio.run_in_host_task(trio_token) - - return message_to_trio.queue.get().unwrap() + return _send_message_to_trio(trio_token, RunSync(fn, args))