Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ module = [
"trio/_core/_generated_io_windows",
"trio/_core/_io_windows",

"trio/_signals",

# internal
"trio/_windows_pipes",

Expand Down Expand Up @@ -93,7 +91,6 @@ module = [
"trio/_tests/test_highlevel_ssl_helpers",
"trio/_tests/test_path",
"trio/_tests/test_scheduler_determinism",
"trio/_tests/test_signals",
"trio/_tests/test_socket",
"trio/_tests/test_ssl",
"trio/_tests/test_subprocess",
Expand Down
44 changes: 30 additions & 14 deletions trio/_signals.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

import signal
from collections import OrderedDict
from collections.abc import AsyncIterator, Callable, Generator, Iterable
from contextlib import contextmanager
from types import FrameType
from typing import TYPE_CHECKING

import trio

from ._util import ConflictDetector, is_main_thread, signal_raise

if TYPE_CHECKING:
from typing_extensions import Self

# Discussion of signal handling strategies:
#
# - On Windows signals barely exist. There are no options; signal handlers are
Expand Down Expand Up @@ -43,7 +51,10 @@


@contextmanager
def _signal_handler(signals, handler):
def _signal_handler(
signals: Iterable[int],
handler: Callable[[int, FrameType | None], object] | int | signal.Handlers | None,
) -> Generator[None, None, None]:
original_handlers = {}
try:
for signum in set(signals):
Expand All @@ -55,31 +66,31 @@ def _signal_handler(signals, handler):


class SignalReceiver:
def __init__(self):
def __init__(self) -> None:
# {signal num: None}
self._pending = OrderedDict()
self._pending: OrderedDict[int, None] = OrderedDict()
self._lot = trio.lowlevel.ParkingLot()
self._conflict_detector = ConflictDetector(
"only one task can iterate on a signal receiver at a time"
)
self._closed = False

def _add(self, signum):
def _add(self, signum: int) -> None:
if self._closed:
signal_raise(signum)
else:
self._pending[signum] = None
self._lot.unpark()

def _redeliver_remaining(self):
def _redeliver_remaining(self) -> None:
# First make sure that any signals still in the delivery pipeline will
# get redelivered
self._closed = True

# And then redeliver any that are sitting in pending. This is done
# using a weird recursive construct to make sure we process everything
# even if some of the handlers raise exceptions.
def deliver_next():
def deliver_next() -> None:
if self._pending:
signum, _ = self._pending.popitem(last=False)
try:
Expand All @@ -89,14 +100,10 @@ def deliver_next():

deliver_next()

# Helper for tests, not public or otherwise used
def _pending_signal_count(self):
return len(self._pending)

def __aiter__(self):
def __aiter__(self) -> Self:
return self

async def __anext__(self):
async def __anext__(self) -> int:
if self._closed:
raise RuntimeError("open_signal_receiver block already exited")
# In principle it would be possible to support multiple concurrent
Expand All @@ -111,8 +118,17 @@ async def __anext__(self):
return signum


def get_pending_signal_count(rec: AsyncIterator[int]) -> int:
"""Helper for tests, not public or otherwise used."""
# open_signal_receiver() always produces SignalReceiver, this should not fail.
assert isinstance(rec, SignalReceiver)
return len(rec._pending)


@contextmanager
def open_signal_receiver(*signals):
def open_signal_receiver(
*signals: signal.Signals | int,
) -> Generator[AsyncIterator[int], None, None]:
"""A context manager for catching signals.

Entering this context manager starts listening for the given signals and
Expand Down Expand Up @@ -158,7 +174,7 @@ def open_signal_receiver(*signals):
token = trio.lowlevel.current_trio_token()
queue = SignalReceiver()

def handler(signum, _):
def handler(signum: int, frame: FrameType | None) -> None:
token.run_sync_soon(queue._add, signum, idempotent=True)

try:
Expand Down
46 changes: 25 additions & 21 deletions trio/_tests/test_signals.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from __future__ import annotations

import signal
from types import FrameType
from typing import NoReturn

import pytest

import trio

from .. import _core
from .._signals import _signal_handler, open_signal_receiver
from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver
from .._util import signal_raise


async def test_open_signal_receiver():
async def test_open_signal_receiver() -> None:
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL) as receiver:
# Raise it a few times, to exercise signal coalescing, both at the
Expand All @@ -22,18 +26,18 @@ async def test_open_signal_receiver():
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert receiver._pending_signal_count() == 0
assert get_pending_signal_count(receiver) == 0
signal_raise(signal.SIGILL)
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert receiver._pending_signal_count() == 0
assert get_pending_signal_count(receiver) == 0
with pytest.raises(RuntimeError):
await receiver.__anext__()
assert signal.getsignal(signal.SIGILL) is orig


async def test_open_signal_receiver_restore_handler_after_one_bad_signal():
async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None:
orig = signal.getsignal(signal.SIGILL)
with pytest.raises(ValueError):
with open_signal_receiver(signal.SIGILL, 1234567):
Expand All @@ -42,30 +46,30 @@ async def test_open_signal_receiver_restore_handler_after_one_bad_signal():
assert signal.getsignal(signal.SIGILL) is orig


async def test_open_signal_receiver_empty_fail():
async def test_open_signal_receiver_empty_fail() -> None:
with pytest.raises(TypeError, match="No signals were provided"):
with open_signal_receiver():
pass


async def test_open_signal_receiver_restore_handler_after_duplicate_signal():
async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None:
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
pass
# Still restored correctly
assert signal.getsignal(signal.SIGILL) is orig


async def test_catch_signals_wrong_thread():
async def naughty():
async def test_catch_signals_wrong_thread() -> None:
async def naughty() -> None:
with open_signal_receiver(signal.SIGINT):
pass # pragma: no cover

with pytest.raises(RuntimeError):
await trio.to_thread.run_sync(trio.run, naughty)


async def test_open_signal_receiver_conflict():
async def test_open_signal_receiver_conflict() -> None:
with pytest.raises(trio.BusyResourceError):
with open_signal_receiver(signal.SIGILL) as receiver:
async with trio.open_nursery() as nursery:
Expand All @@ -75,14 +79,14 @@ async def test_open_signal_receiver_conflict():

# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
# processed.
async def wait_run_sync_soon_idempotent_queue_barrier():
async def wait_run_sync_soon_idempotent_queue_barrier() -> None:
ev = trio.Event()
token = _core.current_trio_token()
token.run_sync_soon(ev.set, idempotent=True)
await ev.wait()


async def test_open_signal_receiver_no_starvation():
async def test_open_signal_receiver_no_starvation() -> None:
# Set up a situation where there are always 2 pending signals available to
# report, and make sure that instead of getting the same signal reported
# over and over, it alternates between reporting both of them.
Expand All @@ -101,8 +105,8 @@ async def test_open_signal_receiver_no_starvation():
assert got in [signal.SIGILL, signal.SIGFPE]
assert got != previous
previous = got
# Clear out the last signal so it doesn't get redelivered
while receiver._pending_signal_count() != 0:
# Clear out the last signal so that it doesn't get redelivered
while get_pending_signal_count(receiver) != 0:
await receiver.__anext__()
except: # pragma: no cover
# If there's an unhandled exception above, then exiting the
Expand All @@ -113,10 +117,10 @@ async def test_open_signal_receiver_no_starvation():
traceback.print_exc()


async def test_catch_signals_race_condition_on_exit():
delivered_directly = set()
async def test_catch_signals_race_condition_on_exit() -> None:
delivered_directly: set[int] = set()

def direct_handler(signo, frame):
def direct_handler(signo: int, frame: FrameType | None) -> None:
delivered_directly.add(signo)

print(1)
Expand All @@ -138,7 +142,7 @@ def direct_handler(signo, frame):
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 2
assert get_pending_signal_count(receiver) == 2
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
delivered_directly.clear()

Expand All @@ -156,12 +160,12 @@ def direct_handler(signo, frame):
with open_signal_receiver(signal.SIGILL) as receiver:
signal_raise(signal.SIGILL)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 1
assert get_pending_signal_count(receiver) == 1
# test passes if the process reaches this point without dying

# Check exception chaining if there are multiple exception-raising
# handlers
def raise_handler(signum, _):
def raise_handler(signum: int, frame: FrameType | None) -> NoReturn:
raise RuntimeError(signum)

with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
Expand All @@ -170,7 +174,7 @@ def raise_handler(signum, _):
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert receiver._pending_signal_count() == 2
assert get_pending_signal_count(receiver) == 2
exc = excinfo.value
signums = {exc.args[0]}
assert isinstance(exc.__context__, RuntimeError)
Expand Down