diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index fc74ccd613..333cb0537b 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1823,6 +1823,25 @@ to spawn a child thread, and then use a :ref:`memory channel .. literalinclude:: reference-core/from-thread-example.py +.. note:: + + The ``from_thread.run*`` functions reuse the host task that called + :func:`trio.to_thread.run_sync` to run your provided function, as long as you're + using the default ``cancellable=False`` so Trio can be sure that the task will remain + around to perform the work. If you pass ``cancellable=True`` at the outset, or if + you provide a :class:`~trio.lowlevel.TrioToken` when calling back in to Trio, your + functions will be executed in a new system task. Therefore, the + :func:`~trio.lowlevel.current_task`, :func:`current_effective_deadline`, or other + task-tree specific values may differ depending on keyword argument values. + +You can also use :func:`trio.from_thread.check_cancelled` to check for cancellation from +a thread that was spawned by :func:`trio.to_thread.run_sync`. If the call to +:func:`~trio.to_thread.run_sync` was cancelled (even if ``cancellable=False``!), then +:func:`~trio.from_thread.check_cancelled` will raise :func:`trio.Cancelled`. +It's like ``trio.from_thread.run(trio.sleep, 0)``, but much faster. + +.. autofunction:: trio.from_thread.check_cancelled + Threads and task-local storage ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/newsfragments/2392.feature.rst b/newsfragments/2392.feature.rst new file mode 100644 index 0000000000..985d3235af --- /dev/null +++ b/newsfragments/2392.feature.rst @@ -0,0 +1,5 @@ +If called from a thread spawned by `trio.to_thread.run_sync`, `trio.from_thread.run` and +`trio.from_thread.run_sync` now reuse the task and cancellation status of the host task; +this means that context variables and cancel scopes naturally propagate 'through' +threads spawned by Trio. You can also use `trio.from_thread.check_cancelled` +to efficiently check for cancellation without reentering the Trio thread. diff --git a/trio/_tests/test_threads.py b/trio/_tests/test_threads.py index 9e448a4d38..a151c03077 100644 --- a/trio/_tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -13,13 +13,12 @@ import pytest import sniffio -from trio._core import TrioToken, current_trio_token - -from .. import CapacityLimiter, Event, _core, sleep +from .. import CapacityLimiter, Event, _core, fail_after, sleep, sleep_forever from .._core._tests.test_ki import ki_self from .._core._tests.tutil import buggy_pypy_asyncgens from .._threads import ( current_default_thread_limiter, + from_thread_check_cancelled, from_thread_run, from_thread_run_sync, to_thread_run_sync, @@ -645,7 +644,7 @@ async def async_fn(): # pragma: no cover def thread_fn(): from_thread_run_sync(async_fn) - with pytest.raises(TypeError, match="expected a sync function"): + with pytest.raises(TypeError, match="expected a synchronous function"): await to_thread_run_sync(thread_fn) @@ -810,25 +809,32 @@ def test_from_thread_run_during_shutdown(): save = [] record = [] - async def agen(): + async def agen(token): try: yield finally: - with pytest.raises(_core.RunFinishedError), _core.CancelScope(shield=True): - await to_thread_run_sync(from_thread_run, sleep, 0) - record.append("ok") - - async def main(): - save.append(agen()) + with _core.CancelScope(shield=True): + try: + await to_thread_run_sync( + partial(from_thread_run, sleep, 0, trio_token=token) + ) + except _core.RunFinishedError: + record.append("finished") + else: + record.append("clean") + + async def main(use_system_task): + save.append(agen(_core.current_trio_token() if use_system_task else None)) await save[-1].asend(None) - _core.run(main) - assert record == ["ok"] + _core.run(main, True) # System nursery will be closed and raise RunFinishedError + _core.run(main, False) # host task will be rescheduled as normal + assert record == ["finished", "clean"] async def test_trio_token_weak_referenceable(): - token = current_trio_token() - assert isinstance(token, TrioToken) + token = _core.current_trio_token() + assert isinstance(token, _core.TrioToken) weak_reference = weakref.ref(token) assert token is weak_reference() @@ -842,3 +848,170 @@ def __bool__(self): with pytest.raises(NotImplementedError): await to_thread_run_sync(int, cancellable=BadBool()) + + +async def test_from_thread_reuses_task(): + task = _core.current_task() + + async def async_current_task(): + return _core.current_task() + + assert task is await to_thread_run_sync(from_thread_run_sync, _core.current_task) + assert task is await to_thread_run_sync(from_thread_run, async_current_task) + + +async def test_recursive_to_thread(): + tid = None + + def get_tid_then_reenter(): + nonlocal tid + tid = threading.get_ident() + return from_thread_run(to_thread_run_sync, threading.get_ident) + + assert tid != await to_thread_run_sync(get_tid_then_reenter) + + +async def test_from_thread_host_cancelled(): + queue = stdlib_queue.Queue() + + def sync_check(): + from_thread_run_sync(cancel_scope.cancel) + try: + from_thread_run_sync(bool) + except _core.Cancelled: # pragma: no cover + queue.put(True) # sync functions don't raise Cancelled + else: + queue.put(False) + + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync(sync_check) + + assert not cancel_scope.cancelled_caught + assert not queue.get_nowait() + + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync(sync_check, cancellable=True) + + assert cancel_scope.cancelled_caught + assert not await to_thread_run_sync(partial(queue.get, timeout=1)) + + async def no_checkpoint(): + return True + + def async_check(): + from_thread_run_sync(cancel_scope.cancel) + try: + assert from_thread_run(no_checkpoint) + except _core.Cancelled: # pragma: no cover + queue.put(True) # async functions raise Cancelled at checkpoints + else: + queue.put(False) + + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync(async_check) + + assert not cancel_scope.cancelled_caught + assert not queue.get_nowait() + + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync(async_check, cancellable=True) + + assert cancel_scope.cancelled_caught + assert not await to_thread_run_sync(partial(queue.get, timeout=1)) + + async def async_time_bomb(): + cancel_scope.cancel() + with fail_after(10): + await sleep_forever() + + with _core.CancelScope() as cancel_scope: + await to_thread_run_sync(from_thread_run, async_time_bomb) + + assert cancel_scope.cancelled_caught + + +async def test_from_thread_check_cancelled(): + q = stdlib_queue.Queue() + + async def child(cancellable, scope): + with scope: + record.append("start") + try: + return await to_thread_run_sync(f, cancellable=cancellable) + except _core.Cancelled: + record.append("cancel") + raise + finally: + record.append("exit") + + def f(): + try: + from_thread_check_cancelled() + except _core.Cancelled: # pragma: no cover, test failure path + q.put("Cancelled") + else: + q.put("Not Cancelled") + ev.wait() + return from_thread_check_cancelled() + + # Base case: nothing cancelled so we shouldn't see cancels anywhere + record = [] + ev = threading.Event() + async with _core.open_nursery() as nursery: + nursery.start_soon(child, False, _core.CancelScope()) + await wait_all_tasks_blocked() + assert record[0] == "start" + assert q.get(timeout=1) == "Not Cancelled" + ev.set() + # implicit assertion, Cancelled not raised via nursery + assert record[1] == "exit" + + # cancellable=False case: a cancel will pop out but be handled by + # the appropriate cancel scope + record = [] + ev = threading.Event() + scope = _core.CancelScope() # Nursery cancel scope gives false positives + async with _core.open_nursery() as nursery: + nursery.start_soon(child, False, scope) + await wait_all_tasks_blocked() + assert record[0] == "start" + assert q.get(timeout=1) == "Not Cancelled" + scope.cancel() + ev.set() + assert scope.cancelled_caught + assert "cancel" in record + assert record[-1] == "exit" + + # cancellable=True case: slightly different thread behavior needed + # check thread is cancelled "soon" after abandonment + def f(): # noqa: F811 + ev.wait() + try: + from_thread_check_cancelled() + except _core.Cancelled: + q.put("Cancelled") + else: # pragma: no cover, test failure path + q.put("Not Cancelled") + + record = [] + ev = threading.Event() + scope = _core.CancelScope() + async with _core.open_nursery() as nursery: + nursery.start_soon(child, True, scope) + await wait_all_tasks_blocked() + assert record[0] == "start" + scope.cancel() + ev.set() + assert scope.cancelled_caught + assert "cancel" in record + assert record[-1] == "exit" + assert q.get(timeout=1) == "Cancelled" + + +async def test_from_thread_check_cancelled_raises_in_foreign_threads(): + with pytest.raises(RuntimeError): + from_thread_check_cancelled() + q = stdlib_queue.Queue() + _core.start_thread_soon(from_thread_check_cancelled, lambda _: q.put(_)) + with pytest.raises(RuntimeError): + q.get(timeout=1).unwrap() diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json index 40238f367e..b05766ca5e 100644 --- a/trio/_tests/verify_types_darwin.json +++ b/trio/_tests/verify_types_darwin.json @@ -40,7 +40,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 630, + "withKnownType": 631, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json index 6a8a3933d9..0f660f68b6 100644 --- a/trio/_tests/verify_types_linux.json +++ b/trio/_tests/verify_types_linux.json @@ -28,7 +28,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 627, + "withKnownType": 628, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json index a983537cd4..3107706623 100644 --- a/trio/_tests/verify_types_windows.json +++ b/trio/_tests/verify_types_windows.json @@ -64,7 +64,7 @@ ], "exportedSymbolCounts": { "withAmbiguousType": 0, - "withKnownType": 630, + "withKnownType": 631, "withUnknownType": 0 }, "ignoreUnknownTypesFromImports": true, diff --git a/trio/_threads.py b/trio/_threads.py index 24905cfbde..7649587751 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -7,7 +7,7 @@ import threading from collections.abc import Awaitable, Callable from itertools import count -from typing import TypeVar +from typing import Generic, TypeVar import attr import outcome @@ -27,16 +27,19 @@ from ._util import coroutine_or_error RetT = TypeVar("RetT") -Ret2T = TypeVar("Ret2T") -class _TokenLocal(threading.local): - """Global due to Threading API, thread local storage for trio token.""" +class _ParentTaskData(threading.local): + """Global due to Threading API, thread local storage for data related to the + parent task of native Trio threads.""" token: TrioToken + abandon_on_cancel: bool + cancel_register: list[RaiseCancelT | None] + task_register: list[trio.lowlevel.Task | None] -TOKEN_LOCAL = _TokenLocal() +PARENT_TASK_DATA = _ParentTaskData() _limiter_local: RunVar[CapacityLimiter] = RunVar("limiter") # I pulled this number out of the air; it isn't based on anything. Probably we @@ -70,6 +73,103 @@ class ThreadPlaceholder: name: str = attr.ib() +# Types for the to_thread_run_sync message loop +@attr.s(frozen=True, eq=False) +class Run(Generic[RetT]): + afn: Callable[..., Awaitable[RetT]] = attr.ib() + args: tuple[object, ...] = attr.ib() + context: contextvars.Context = attr.ib() + queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( + init=False, factory=stdlib_queue.SimpleQueue + ) + + @disable_ki_protection + async def unprotected_afn(self) -> RetT: + coro = coroutine_or_error(self.afn, *self.args) + return await coro + + async def run(self) -> None: + # we use extra checkpoints to pick up and reset any context changes + task = trio.lowlevel.current_task() + old_context = task.context + task.context = self.context.copy() + try: + await trio.lowlevel.cancel_shielded_checkpoint() + result = await outcome.acapture(self.unprotected_afn) + self.queue.put_nowait(result) + finally: + task.context = old_context + await trio.lowlevel.cancel_shielded_checkpoint() + + async def run_system(self) -> None: + result = await outcome.acapture(self.unprotected_afn) + self.queue.put_nowait(result) + + def run_in_host_task(self, token: TrioToken) -> None: + task_register = PARENT_TASK_DATA.task_register + + def in_trio_thread() -> None: + task = task_register[0] + assert task is not None, "guaranteed by abandon_on_cancel semantics" + trio.lowlevel.reschedule(task, outcome.Value(self)) + + token.run_sync_soon(in_trio_thread) + + def run_in_system_nursery(self, token: TrioToken) -> None: + def in_trio_thread() -> None: + try: + trio.lowlevel.spawn_system_task( + self.run_system, name=self.afn, context=self.context + ) + except RuntimeError: # system nursery is closed + self.queue.put_nowait( + outcome.Error(trio.RunFinishedError("system nursery is closed")) + ) + + token.run_sync_soon(in_trio_thread) + + +@attr.s(frozen=True, eq=False) +class RunSync(Generic[RetT]): + fn: Callable[..., RetT] = attr.ib() + args: tuple[object, ...] = attr.ib() + context: contextvars.Context = attr.ib() + queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib( + init=False, factory=stdlib_queue.SimpleQueue + ) + + @disable_ki_protection + def unprotected_fn(self) -> RetT: + ret = self.fn(*self.args) + + if inspect.iscoroutine(ret): + # Manually close coroutine to avoid RuntimeWarnings + ret.close() + raise TypeError( + "Trio expected a synchronous function, but {!r} appears to be " + "asynchronous".format(getattr(self.fn, "__qualname__", self.fn)) + ) + + return ret + + def run_sync(self) -> None: + result = outcome.capture(self.context.run, self.unprotected_fn) + self.queue.put_nowait(result) + + def run_in_host_task(self, token: TrioToken) -> None: + task_register = PARENT_TASK_DATA.task_register + + def in_trio_thread() -> None: + task = task_register[0] + assert task is not None, "guaranteed by abandon_on_cancel semantics" + trio.lowlevel.reschedule(task, outcome.Value(self)) + + token.run_sync_soon(in_trio_thread) + + def run_in_system_nursery(self, token: TrioToken) -> None: + token.run_sync_soon(self.run_sync) + + @enable_ki_protection # Decorator used on function with Coroutine[Any, Any, RetT] async def to_thread_run_sync( # type: ignore[misc] sync_fn: Callable[..., RetT], @@ -162,7 +262,7 @@ async def to_thread_run_sync( # type: ignore[misc] """ await trio.lowlevel.checkpoint_if_cancelled() - cancellable = bool(cancellable) # raise early if cancellable.__bool__ raises + abandon_on_cancel = bool(cancellable) # raise early if cancellable.__bool__ raises if limiter is None: limiter = current_default_thread_limiter() @@ -170,6 +270,9 @@ async def to_thread_run_sync( # type: ignore[misc] # for the result – or None if this function was cancelled and we should # discard the result. task_register: list[trio.lowlevel.Task | None] = [trio.lowlevel.current_task()] + # Holds a reference to the raise_cancel function provided if a cancellation + # is attempted against this task - or None if no such delivery has happened. + cancel_register: list[RaiseCancelT | None] = [None] # type: ignore[assignment] name = f"trio.to_thread.run_sync-{next(_thread_counter)}" placeholder = ThreadPlaceholder(name) @@ -188,7 +291,7 @@ def do_release_then_return_result() -> RetT: result = outcome.capture(do_release_then_return_result) if task_register[0] is not None: - trio.lowlevel.reschedule(task_register[0], result) + trio.lowlevel.reschedule(task_register[0], outcome.Value(result)) current_trio_token = trio.lowlevel.current_trio_token() @@ -202,7 +305,10 @@ def worker_fn() -> RetT: # the new thread sees that it's not running in async context. current_async_library_cvar.set(None) - TOKEN_LOCAL.token = current_trio_token + PARENT_TASK_DATA.token = current_trio_token + PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel + PARENT_TASK_DATA.cancel_register = cancel_register + PARENT_TASK_DATA.task_register = task_register try: ret = sync_fn(*args) @@ -216,7 +322,10 @@ def worker_fn() -> RetT: return ret finally: - del TOKEN_LOCAL.token + del PARENT_TASK_DATA.token + del PARENT_TASK_DATA.abandon_on_cancel + del PARENT_TASK_DATA.cancel_register + del PARENT_TASK_DATA.task_register context = contextvars.copy_context() # Partial confuses type checkers, coerce to a callable. @@ -240,37 +349,77 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: limiter.release_on_behalf_of(placeholder) raise - def abort(_: RaiseCancelT) -> trio.lowlevel.Abort: - if cancellable: + def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: + # fill so from_thread_check_cancelled can raise + cancel_register[0] = raise_cancel + if abandon_on_cancel: + # empty so report_back_in_trio_thread_fn cannot reschedule task_register[0] = None return trio.lowlevel.Abort.SUCCEEDED else: return trio.lowlevel.Abort.FAILED - # wait_task_rescheduled return value cannot be typed - return await trio.lowlevel.wait_task_rescheduled(abort) # type: ignore[no-any-return] + while True: + # wait_task_rescheduled return value cannot be typed + msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[ + object + ] = await trio.lowlevel.wait_task_rescheduled(abort) + if isinstance(msg_from_thread, outcome.Outcome): + return msg_from_thread.unwrap() # type: ignore[no-any-return] + elif isinstance(msg_from_thread, Run): + await msg_from_thread.run() + elif isinstance(msg_from_thread, RunSync): + msg_from_thread.run_sync() + else: # pragma: no cover, internal debugging guard TODO: use assert_never + raise TypeError( + "trio.to_thread.run_sync received unrecognized thread message {!r}." + "".format(msg_from_thread) + ) + del msg_from_thread -# We use two typevars here, because cb can transform from one to the other any way it likes. -def _run_fn_as_system_task( - cb: Callable[ - [ - stdlib_queue.SimpleQueue[outcome.Outcome[Ret2T]], - Callable[..., RetT], - tuple[object, ...], - ], - object, - ], - fn: Callable[..., RetT], - *args: object, - context: contextvars.Context, - trio_token: TrioToken | None = None, - # Outcome isn't typed, so Ret2T is used only in the return type. -) -> Ret2T: # type: ignore[type-var] - """Helper function for from_thread.run and from_thread.run_sync. +def from_thread_check_cancelled() -> None: + """Raise `trio.Cancelled` if the associated Trio task entered a cancelled status. - Since this internally uses TrioToken.run_sync_soon, all warnings about - raised exceptions canceling all tasks should be noted. + Only applicable to threads spawned by `trio.to_thread.run_sync`. Poll to allow + ``cancellable=False`` threads to raise :exc:`~trio.Cancelled` at a suitable + place, or to end abandoned ``cancellable=True`` threads sooner than they may + otherwise. + + Raises: + Cancelled: If the corresponding call to `trio.to_thread.run_sync` has had a + delivery of cancellation attempted against it, regardless of the value of + ``cancellable`` supplied as an argument to it. + RuntimeError: If this thread is not spawned from `trio.to_thread.run_sync`. + + .. note:: + + To be precise, :func:`~trio.from_thread.check_cancelled` checks whether the task + running :func:`trio.to_thread.run_sync` has ever been cancelled since the last + time it was running a :func:`trio.from_thread.run` or :func:`trio.from_thread.run_sync` + function. It may raise `trio.Cancelled` even if a cancellation occurred that was + later hidden by a modification to `trio.CancelScope.shield` between the cancelled + `~trio.CancelScope` and :func:`trio.to_thread.run_sync`. This differs from the + behavior of normal Trio checkpoints, which raise `~trio.Cancelled` only if the + cancellation is still active when the checkpoint executes. The distinction here is + *exceedingly* unlikely to be relevant to your application, but we mention it + for completeness. + """ + try: + raise_cancel = PARENT_TASK_DATA.cancel_register[0] + except AttributeError: + raise RuntimeError( + "this thread wasn't created by Trio, can't check for cancellation" + ) + if raise_cancel is not None: + raise_cancel() + + +def _check_token(trio_token: TrioToken | None) -> TrioToken: + """Raise a RuntimeError if this function is called within a trio run. + + Avoids deadlock by making sure we're not called from inside a context + that we might be waiting for and blocking it. """ if trio_token is not None and not isinstance(trio_token, TrioToken): @@ -278,7 +427,7 @@ def _run_fn_as_system_task( if trio_token is None: try: - trio_token = TOKEN_LOCAL.token + trio_token = PARENT_TASK_DATA.token except AttributeError: raise RuntimeError( "this thread wasn't created by Trio, pass kwarg trio_token=..." @@ -292,9 +441,7 @@ def _run_fn_as_system_task( else: raise RuntimeError("this is a blocking function; call it from a thread") - q: stdlib_queue.SimpleQueue[outcome.Outcome[Ret2T]] = stdlib_queue.SimpleQueue() - trio_token.run_sync_soon(context.run, cb, q, fn, args) - return q.get().unwrap() # type: ignore[no-any-return] # Until outcome is typed + return trio_token def from_thread_run( @@ -315,63 +462,46 @@ def from_thread_run( RunFinishedError: if the corresponding call to :func:`trio.run` has already completed, or if the run has started its final cleanup phase and can no longer spawn new system tasks. - Cancelled: if the corresponding call to :func:`trio.run` completes - while ``afn(*args)`` is running, then ``afn`` is likely to raise - :exc:`trio.Cancelled`, and this will propagate out into + Cancelled: If the original call to :func:`trio.to_thread.run_sync` is cancelled + (if *trio_token* is None) or the call to :func:`trio.run` completes + (if *trio_token* is not None) while ``afn(*args)`` is running, + then *afn* is likely to raise + completes while ``afn(*args)`` is running, then ``afn`` is likely to raise + :exc:`trio.Cancelled`. RuntimeError: if you try calling this from inside the Trio thread, - which would otherwise cause a deadlock. - AttributeError: if no ``trio_token`` was provided, and we can't infer - one from context. + which would otherwise cause a deadlock, or if no ``trio_token`` was + provided, and we can't infer one from context. TypeError: if ``afn`` is not an asynchronous function. - **Locating a Trio Token**: There are two ways to specify which + **Locating a TrioToken**: There are two ways to specify which `trio.run` loop to reenter: - Spawn this thread from `trio.to_thread.run_sync`. Trio will - automatically capture the relevant Trio token and use it when you - want to re-enter Trio. + automatically capture the relevant Trio token and use it + to re-enter the same Trio task. - Pass a keyword argument, ``trio_token`` specifying a specific `trio.run` loop to re-enter. This is useful in case you have a "foreign" thread, spawned using some other framework, and still want - to enter Trio. + to enter Trio, or if you want to use a new system task to call ``afn``, + maybe to avoid the cancellation context of a corresponding + `trio.to_thread.run_sync` task. """ + token_provided = trio_token is not None + trio_token = _check_token(trio_token) - def callback( - q: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]], - afn: Callable[..., Awaitable[RetT]], - args: tuple[object, ...], - ) -> None: - @disable_ki_protection - async def unprotected_afn() -> RetT: - coro = coroutine_or_error(afn, *args) - return await coro + message_to_trio = Run(afn, args, contextvars.copy_context()) - async def await_in_trio_thread_task() -> None: - q.put_nowait(await outcome.acapture(unprotected_afn)) + if token_provided or PARENT_TASK_DATA.abandon_on_cancel: + message_to_trio.run_in_system_nursery(trio_token) + else: + message_to_trio.run_in_host_task(trio_token) - context = contextvars.copy_context() - try: - trio.lowlevel.spawn_system_task( - await_in_trio_thread_task, name=afn, context=context - ) - except RuntimeError: # system nursery is closed - q.put_nowait( - outcome.Error(trio.RunFinishedError("system nursery is closed")) - ) - - context = contextvars.copy_context() - return _run_fn_as_system_task( - callback, - afn, - *args, - context=context, - trio_token=trio_token, - ) + return message_to_trio.queue.get().unwrap() # type: ignore[no-any-return] def from_thread_run_sync( fn: Callable[..., RetT], - *args: tuple[object, ...], + *args: object, trio_token: TrioToken | None = None, ) -> RetT: """Run the given sync function in the parent Trio thread, blocking until it @@ -387,12 +517,11 @@ def from_thread_run_sync( RunFinishedError: if the corresponding call to `trio.run` has already completed. RuntimeError: if you try calling this from inside the Trio thread, - which would otherwise cause a deadlock. - AttributeError: if no ``trio_token`` was provided, and we can't infer - one from context. + which would otherwise cause a deadlock or if no ``trio_token`` was + provided, and we can't infer one from context. TypeError: if ``fn`` is an async function. - **Locating a Trio Token**: There are two ways to specify which + **Locating a TrioToken**: There are two ways to specify which `trio.run` loop to reenter: - Spawn this thread from `trio.to_thread.run_sync`. Trio will @@ -401,37 +530,18 @@ def from_thread_run_sync( - Pass a keyword argument, ``trio_token`` specifying a specific `trio.run` loop to re-enter. This is useful in case you have a "foreign" thread, spawned using some other framework, and still want - to enter Trio. + to enter Trio, or if you want to use a new system task to call ``fn``, + maybe to avoid the cancellation context of a corresponding + `trio.to_thread.run_sync` task. """ + token_provided = trio_token is not None + trio_token = _check_token(trio_token) - def callback( - q: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]], - fn: Callable[..., RetT], - args: tuple[object, ...], - ) -> None: - @disable_ki_protection - def unprotected_fn() -> RetT: - ret = fn(*args) + message_to_trio = RunSync(fn, args, contextvars.copy_context()) - if inspect.iscoroutine(ret): - # Manually close coroutine to avoid RuntimeWarnings - ret.close() - raise TypeError( - "Trio expected a sync function, but {!r} appears to be " - "asynchronous".format(getattr(fn, "__qualname__", fn)) - ) - - return ret - - res = outcome.capture(unprotected_fn) - q.put_nowait(res) - - context = contextvars.copy_context() + if token_provided or PARENT_TASK_DATA.abandon_on_cancel: + message_to_trio.run_in_system_nursery(trio_token) + else: + message_to_trio.run_in_host_task(trio_token) - return _run_fn_as_system_task( - callback, - fn, - *args, - context=context, - trio_token=trio_token, - ) + return message_to_trio.queue.get().unwrap() # type: ignore[no-any-return] diff --git a/trio/from_thread.py b/trio/from_thread.py index e6f7b2495e..0de0023941 100644 --- a/trio/from_thread.py +++ b/trio/from_thread.py @@ -4,7 +4,11 @@ """ -from ._threads import from_thread_run as run, from_thread_run_sync as run_sync +from ._threads import ( + from_thread_check_cancelled as check_cancelled, + from_thread_run as run, + from_thread_run_sync as run_sync, +) # need to use __all__ for pyright --verifytypes to see re-exports when renaming them -__all__ = ["run", "run_sync"] +__all__ = ["check_cancelled", "run", "run_sync"]