Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 35 additions & 59 deletions src/trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import contextlib
import contextvars
import functools
import inspect
import queue as stdlib_queue
import threading
Expand Down Expand Up @@ -83,7 +82,7 @@ class ThreadPlaceholder:
class Run(Generic[RetT]):
afn: Callable[..., Awaitable[RetT]] = attr.ib()
args: tuple[object, ...] = attr.ib()
context: contextvars.Context = attr.ib()
context: contextvars.Context = attr.ib(init=False, factory=contextvars.copy_context)
queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib(
init=False, factory=stdlib_queue.SimpleQueue
)
Expand Down Expand Up @@ -136,14 +135,14 @@ def in_trio_thread() -> None:
class RunSync(Generic[RetT]):
fn: Callable[..., RetT] = attr.ib()
args: tuple[object, ...] = attr.ib()
context: contextvars.Context = attr.ib()
context: contextvars.Context = attr.ib(init=False, factory=contextvars.copy_context)
queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib(
init=False, factory=stdlib_queue.SimpleQueue
)

@disable_ki_protection
def unprotected_fn(self) -> RetT:
ret = self.fn(*self.args)
ret = self.context.run(self.fn, *self.args)

if inspect.iscoroutine(ret):
# Manually close coroutine to avoid RuntimeWarnings
Expand All @@ -156,9 +155,7 @@ def unprotected_fn(self) -> RetT:
return ret

def run_sync(self) -> None:
# Two paramspecs + overload is a bit too hard for mypy to handle. Tell it what to infer.
runner: Callable[[Callable[[], RetT]], RetT] = self.context.run
result = outcome.capture(runner, self.unprotected_fn)
result = outcome.capture(self.unprotected_fn)
self.queue.put_nowait(result)

def run_in_host_task(self, token: TrioToken) -> None:
Expand Down Expand Up @@ -303,7 +300,8 @@ async def to_thread_run_sync( # type: ignore[misc]
)
abandon_on_cancel = cancellable
# raise early if abandon_on_cancel.__bool__ raises
abandon_on_cancel = bool(abandon_on_cancel)
# and give a new name to ensure mypy knows it's never None
abandon_bool = bool(abandon_on_cancel)
if limiter is None:
limiter = current_default_thread_limiter()

Expand Down Expand Up @@ -340,18 +338,12 @@ def do_release_then_return_result() -> RetT:
thread_name = f"{getattr(sync_fn, '__name__', None)} from {trio.lowlevel.current_task().name}"

def worker_fn() -> RetT:
# Trio doesn't use current_async_library_cvar, but if someone
# else set it, it would now shine through since
# snifio.thread_local isn't set in the new thread. Make sure
# the new thread sees that it's not running in async context.
current_async_library_cvar.set(None)

PARENT_TASK_DATA.token = current_trio_token
PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel
PARENT_TASK_DATA.abandon_on_cancel = abandon_bool
PARENT_TASK_DATA.cancel_register = cancel_register
PARENT_TASK_DATA.task_register = task_register
try:
ret = sync_fn(*args)
ret = context.run(sync_fn, *args)

if inspect.iscoroutine(ret):
# Manually close coroutine to avoid RuntimeWarnings
Expand All @@ -369,29 +361,30 @@ def worker_fn() -> RetT:
del PARENT_TASK_DATA.task_register

context = contextvars.copy_context()
# Partial confuses type checkers, coerce to a callable.
contextvars_aware_worker_fn: Callable[[], RetT] = functools.partial(context.run, worker_fn) # type: ignore[assignment]
# Trio doesn't use current_async_library_cvar, but if someone
# else set it, it would now shine through since
# sniffio.thread_local isn't set in the new thread. Make sure
# the new thread sees that it's not running in async context.
context.run(current_async_library_cvar.set, None)

def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None:
# The entire run finished, so the task we're trying to contact is
# If the entire run finished, the task we're trying to contact is
# certainly long gone -- it must have been cancelled and abandoned
# us.
# us. Just ignore the error in this case.
with contextlib.suppress(trio.RunFinishedError):
current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result)

await limiter.acquire_on_behalf_of(placeholder)
try:
start_thread_soon(
contextvars_aware_worker_fn, deliver_worker_fn_result, thread_name
)
start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name)
except:
limiter.release_on_behalf_of(placeholder)
raise

def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:
# fill so from_thread_check_cancelled can raise
cancel_register[0] = raise_cancel
if abandon_on_cancel:
if abandon_bool:
# empty so report_back_in_trio_thread_fn cannot reschedule
task_register[0] = None
return trio.lowlevel.Abort.SUCCEEDED
Expand Down Expand Up @@ -446,31 +439,29 @@ def from_thread_check_cancelled() -> None:
"""
try:
raise_cancel = PARENT_TASK_DATA.cancel_register[0]
except AttributeError as exc:
except AttributeError:
raise RuntimeError(
"this thread wasn't created by Trio, can't check for cancellation"
) from exc
) from None
if raise_cancel is not None:
raise_cancel()


def _check_token(trio_token: TrioToken | None) -> TrioToken:
"""Raise a RuntimeError if this function is called within a trio run.

Avoids deadlock by making sure we're not called from inside a context
that we might be waiting for and blocking it.
"""

if trio_token is not None and not isinstance(trio_token, TrioToken):
raise RuntimeError("Passed kwarg trio_token is not of type TrioToken")
def _send_message_to_trio(
trio_token: TrioToken | None, message_to_trio: Run[RetT] | RunSync[RetT]
) -> RetT:
"""Shared logic of from_thread functions"""
token_provided = trio_token is not None

if trio_token is None:
if not token_provided:
try:
trio_token = PARENT_TASK_DATA.token
except AttributeError:
raise RuntimeError(
"this thread wasn't created by Trio, pass kwarg trio_token=..."
) from None
elif not isinstance(trio_token, TrioToken):
raise RuntimeError("Passed kwarg trio_token is not of type TrioToken")

# Avoid deadlock by making sure we're not called from Trio thread
try:
Expand All @@ -480,7 +471,12 @@ def _check_token(trio_token: TrioToken | None) -> TrioToken:
else:
raise RuntimeError("this is a blocking function; call it from a thread")

return trio_token
if token_provided or PARENT_TASK_DATA.abandon_on_cancel:
message_to_trio.run_in_system_nursery(trio_token)
else:
message_to_trio.run_in_host_task(trio_token)

return message_to_trio.queue.get().unwrap()


def from_thread_run(
Expand Down Expand Up @@ -523,17 +519,7 @@ def from_thread_run(
maybe to avoid the cancellation context of a corresponding
`trio.to_thread.run_sync` task.
"""
token_provided = trio_token is not None
trio_token = _check_token(trio_token)

message_to_trio = Run(afn, args, contextvars.copy_context())

if token_provided or PARENT_TASK_DATA.abandon_on_cancel:
message_to_trio.run_in_system_nursery(trio_token)
else:
message_to_trio.run_in_host_task(trio_token)

return message_to_trio.queue.get().unwrap()
return _send_message_to_trio(trio_token, Run(afn, args))


def from_thread_run_sync(
Expand Down Expand Up @@ -571,14 +557,4 @@ def from_thread_run_sync(
maybe to avoid the cancellation context of a corresponding
`trio.to_thread.run_sync` task.
"""
token_provided = trio_token is not None
trio_token = _check_token(trio_token)

message_to_trio = RunSync(fn, args, contextvars.copy_context())

if token_provided or PARENT_TASK_DATA.abandon_on_cancel:
message_to_trio.run_in_system_nursery(trio_token)
else:
message_to_trio.run_in_host_task(trio_token)

return message_to_trio.queue.get().unwrap()
return _send_message_to_trio(trio_token, RunSync(fn, args))