From b7e91ff54cc2a3cbde0964a169844e6cf233cd3d Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Wed, 6 Sep 2023 15:27:19 +1000 Subject: [PATCH 1/4] Add types to _signals and test_signals --- trio/_signals.py | 33 +++++++++++++++++++++++---------- trio/_tests/test_signals.py | 16 ++++++++-------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/trio/_signals.py b/trio/_signals.py index fe2bde946e..326cf95a5e 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -1,11 +1,19 @@ +from __future__ import annotations + import signal from collections import OrderedDict +from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager +from types import FrameType +from typing import TYPE_CHECKING import trio from ._util import ConflictDetector, is_main_thread, signal_raise +if TYPE_CHECKING: + from typing_extensions import Self + # Discussion of signal handling strategies: # # - On Windows signals barely exist. There are no options; signal handlers are @@ -43,7 +51,10 @@ @contextmanager -def _signal_handler(signals, handler): +def _signal_handler( + signals: Iterable[int], + handler: Callable[[int, FrameType | None], object] | int | signal.Handlers | None, +) -> Generator[None, None, None]: original_handlers = {} try: for signum in set(signals): @@ -55,23 +66,23 @@ def _signal_handler(signals, handler): class SignalReceiver: - def __init__(self): + def __init__(self) -> None: # {signal num: None} - self._pending = OrderedDict() + self._pending: OrderedDict[int, None] = OrderedDict() self._lot = trio.lowlevel.ParkingLot() self._conflict_detector = ConflictDetector( "only one task can iterate on a signal receiver at a time" ) self._closed = False - def _add(self, signum): + def _add(self, signum: int) -> None: if self._closed: signal_raise(signum) else: self._pending[signum] = None self._lot.unpark() - def _redeliver_remaining(self): + def _redeliver_remaining(self) -> None: # First make sure that any signals still in the delivery pipeline will # get redelivered self._closed = True @@ -90,13 +101,13 @@ def deliver_next(): deliver_next() # Helper for tests, not public or otherwise used - def _pending_signal_count(self): + def _pending_signal_count(self) -> int: return len(self._pending) - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> int: if self._closed: raise RuntimeError("open_signal_receiver block already exited") # In principle it would be possible to support multiple concurrent @@ -112,7 +123,9 @@ async def __anext__(self): @contextmanager -def open_signal_receiver(*signals): +def open_signal_receiver( + *signals: signal.Signals | int, +) -> Generator[SignalReceiver, None, None]: """A context manager for catching signals. Entering this context manager starts listening for the given signals and @@ -158,7 +171,7 @@ def open_signal_receiver(*signals): token = trio.lowlevel.current_trio_token() queue = SignalReceiver() - def handler(signum, _): + def handler(signum: int, frame: FrameType | None) -> None: token.run_sync_soon(queue._add, signum, idempotent=True) try: diff --git a/trio/_tests/test_signals.py b/trio/_tests/test_signals.py index 313cce259f..2d81cd9d9d 100644 --- a/trio/_tests/test_signals.py +++ b/trio/_tests/test_signals.py @@ -9,7 +9,7 @@ from .._util import signal_raise -async def test_open_signal_receiver(): +async def test_open_signal_receiver() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL) as receiver: # Raise it a few times, to exercise signal coalescing, both at the @@ -33,7 +33,7 @@ async def test_open_signal_receiver(): assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): +async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None: orig = signal.getsignal(signal.SIGILL) with pytest.raises(ValueError): with open_signal_receiver(signal.SIGILL, 1234567): @@ -42,13 +42,13 @@ async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): assert signal.getsignal(signal.SIGILL) is orig -async def test_open_signal_receiver_empty_fail(): +async def test_open_signal_receiver_empty_fail() -> None: with pytest.raises(TypeError, match="No signals were provided"): with open_signal_receiver(): pass -async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): +async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL, signal.SIGILL): pass @@ -56,7 +56,7 @@ async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): assert signal.getsignal(signal.SIGILL) is orig -async def test_catch_signals_wrong_thread(): +async def test_catch_signals_wrong_thread() -> None: async def naughty(): with open_signal_receiver(signal.SIGINT): pass # pragma: no cover @@ -65,7 +65,7 @@ async def naughty(): await trio.to_thread.run_sync(trio.run, naughty) -async def test_open_signal_receiver_conflict(): +async def test_open_signal_receiver_conflict() -> None: with pytest.raises(trio.BusyResourceError): with open_signal_receiver(signal.SIGILL) as receiver: async with trio.open_nursery() as nursery: @@ -75,14 +75,14 @@ async def test_open_signal_receiver_conflict(): # Blocks until all previous calls to run_sync_soon(idempotent=True) have been # processed. -async def wait_run_sync_soon_idempotent_queue_barrier(): +async def wait_run_sync_soon_idempotent_queue_barrier() -> None: ev = trio.Event() token = _core.current_trio_token() token.run_sync_soon(ev.set, idempotent=True) await ev.wait() -async def test_open_signal_receiver_no_starvation(): +async def test_open_signal_receiver_no_starvation() -> None: # Set up a situation where there are always 2 pending signals available to # report, and make sure that instead of getting the same signal reported # over and over, it alternates between reporting both of them. From 4b8190cfbc1462a1d78a0dffc41048a775d13ad6 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Wed, 6 Sep 2023 15:23:46 +1000 Subject: [PATCH 2/4] Hide SignalReceiver from the API. Convert this helper to a function to allow that. --- trio/_signals.py | 15 +++++++++------ trio/_tests/test_signals.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/trio/_signals.py b/trio/_signals.py index 326cf95a5e..ddf6297d05 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -2,7 +2,7 @@ import signal from collections import OrderedDict -from collections.abc import Callable, Generator, Iterable +from collections.abc import AsyncIterator, Callable, Generator, Iterable from contextlib import contextmanager from types import FrameType from typing import TYPE_CHECKING @@ -100,10 +100,6 @@ def deliver_next(): deliver_next() - # Helper for tests, not public or otherwise used - def _pending_signal_count(self) -> int: - return len(self._pending) - def __aiter__(self) -> Self: return self @@ -122,10 +118,17 @@ async def __anext__(self) -> int: return signum +def get_pending_signal_count(rec: AsyncIterator[int]) -> int: + """Helper for tests, not public or otherwise used.""" + # open_signal_receiver() always produces SignalReceiver, so cast. + assert isinstance(rec, SignalReceiver) + return len(rec._pending) + + @contextmanager def open_signal_receiver( *signals: signal.Signals | int, -) -> Generator[SignalReceiver, None, None]: +) -> Generator[AsyncIterator[int], None, None]: """A context manager for catching signals. Entering this context manager starts listening for the given signals and diff --git a/trio/_tests/test_signals.py b/trio/_tests/test_signals.py index 2d81cd9d9d..ed158d8935 100644 --- a/trio/_tests/test_signals.py +++ b/trio/_tests/test_signals.py @@ -5,7 +5,7 @@ import trio from .. import _core -from .._signals import _signal_handler, open_signal_receiver +from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver from .._util import signal_raise @@ -22,12 +22,12 @@ async def test_open_signal_receiver() -> None: async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break - assert receiver._pending_signal_count() == 0 + assert get_pending_signal_count(receiver) == 0 signal_raise(signal.SIGILL) async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break - assert receiver._pending_signal_count() == 0 + assert get_pending_signal_count(receiver) == 0 with pytest.raises(RuntimeError): await receiver.__anext__() assert signal.getsignal(signal.SIGILL) is orig @@ -101,8 +101,8 @@ async def test_open_signal_receiver_no_starvation() -> None: assert got in [signal.SIGILL, signal.SIGFPE] assert got != previous previous = got - # Clear out the last signal so it doesn't get redelivered - while receiver._pending_signal_count() != 0: + # Clear out the last signal so that it doesn't get redelivered + while get_pending_signal_count(receiver) != 0: await receiver.__anext__() except: # pragma: no cover # If there's an unhandled exception above, then exiting the @@ -138,7 +138,7 @@ def direct_handler(signo, frame): signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() - assert receiver._pending_signal_count() == 2 + assert get_pending_signal_count(receiver) == 2 assert delivered_directly == {signal.SIGILL, signal.SIGFPE} delivered_directly.clear() @@ -156,7 +156,7 @@ def direct_handler(signo, frame): with open_signal_receiver(signal.SIGILL) as receiver: signal_raise(signal.SIGILL) await wait_run_sync_soon_idempotent_queue_barrier() - assert receiver._pending_signal_count() == 1 + assert get_pending_signal_count(receiver) == 1 # test passes if the process reaches this point without dying # Check exception chaining if there are multiple exception-raising @@ -170,7 +170,7 @@ def raise_handler(signum, _): signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() - assert receiver._pending_signal_count() == 2 + assert get_pending_signal_count(receiver) == 2 exc = excinfo.value signums = {exc.args[0]} assert isinstance(exc.__context__, RuntimeError) From cc5ba31aeb5b4d388d9f3465723f770a25e79d53 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Thu, 7 Sep 2023 12:07:14 +1000 Subject: [PATCH 3/4] Enable strict type checking for _signals and test_signals --- pyproject.toml | 3 --- trio/_signals.py | 2 +- trio/_tests/test_signals.py | 14 +++++++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17dd2aa1b7..267c81557d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,8 +64,6 @@ module = [ "trio/_core/_generated_io_windows", "trio/_core/_io_windows", -"trio/_signals", - # internal "trio/_windows_pipes", @@ -102,7 +100,6 @@ module = [ "trio/_tests/test_highlevel_ssl_helpers", "trio/_tests/test_path", "trio/_tests/test_scheduler_determinism", -"trio/_tests/test_signals", "trio/_tests/test_socket", "trio/_tests/test_ssl", "trio/_tests/test_subprocess", diff --git a/trio/_signals.py b/trio/_signals.py index ddf6297d05..451ed00aab 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -90,7 +90,7 @@ def _redeliver_remaining(self) -> None: # And then redeliver any that are sitting in pending. This is done # using a weird recursive construct to make sure we process everything # even if some of the handlers raise exceptions. - def deliver_next(): + def deliver_next() -> None: if self._pending: signum, _ = self._pending.popitem(last=False) try: diff --git a/trio/_tests/test_signals.py b/trio/_tests/test_signals.py index ed158d8935..1e42239e35 100644 --- a/trio/_tests/test_signals.py +++ b/trio/_tests/test_signals.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import signal +from types import FrameType +from typing import NoReturn import pytest @@ -57,7 +61,7 @@ async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> async def test_catch_signals_wrong_thread() -> None: - async def naughty(): + async def naughty() -> None: with open_signal_receiver(signal.SIGINT): pass # pragma: no cover @@ -113,10 +117,10 @@ async def test_open_signal_receiver_no_starvation() -> None: traceback.print_exc() -async def test_catch_signals_race_condition_on_exit(): - delivered_directly = set() +async def test_catch_signals_race_condition_on_exit() -> None: + delivered_directly: set[int] = set() - def direct_handler(signo, frame): + def direct_handler(signo: int, frame: FrameType | None) -> None: delivered_directly.add(signo) print(1) @@ -161,7 +165,7 @@ def direct_handler(signo, frame): # Check exception chaining if there are multiple exception-raising # handlers - def raise_handler(signum, _): + def raise_handler(signum: int, frame: FrameType | None) -> NoReturn: raise RuntimeError(signum) with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler): From 7f27398d205a7dc6c3327f2d7811ac040dc18347 Mon Sep 17 00:00:00 2001 From: Spencer Brown Date: Tue, 12 Sep 2023 18:28:17 +1000 Subject: [PATCH 4/4] Tweak comment wording --- trio/_signals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_signals.py b/trio/_signals.py index 451ed00aab..283c3a44a8 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -120,7 +120,7 @@ async def __anext__(self) -> int: def get_pending_signal_count(rec: AsyncIterator[int]) -> int: """Helper for tests, not public or otherwise used.""" - # open_signal_receiver() always produces SignalReceiver, so cast. + # open_signal_receiver() always produces SignalReceiver, this should not fail. assert isinstance(rec, SignalReceiver) return len(rec._pending)