Skip to content
Merged
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ extend-ignore = ['D', 'E', 'W', 'F403', 'F405', 'F821', 'F822']
per-file-ignores = [
'trio/__init__.py: F401',
'trio/_core/__init__.py: F401',
'trio/_core/_generated*.py: F401',
'trio/_core/_tests/test_multierror_scripts/*: F401',
'trio/abc.py: F401',
'trio/lowlevel.py: F401',
Expand Down Expand Up @@ -51,6 +50,10 @@ disallow_untyped_defs = false
module = [
"trio._abc",
"trio._core._entry_queue",
"trio._core._generated_io_epoll",
"trio._core._generated_io_kqueue",
"trio._core._io_epoll",
"trio._core._io_kqueue",
"trio._core._local",
"trio._core._unbounded_queue",
"trio._core._thread_cache",
Expand Down
12 changes: 7 additions & 5 deletions trio/_core/_generated_instrumentation.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 13 additions & 7 deletions trio/_core/_generated_io_epoll.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 26 additions & 10 deletions trio/_core/_generated_io_kqueue.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 9 additions & 4 deletions trio/_core/_generated_io_windows.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 7 additions & 5 deletions trio/_core/_generated_run.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 37 additions & 29 deletions trio/_core/_io_epoll.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,38 @@
from __future__ import annotations

import select
import sys
from collections import defaultdict
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, DefaultDict, Literal

import attr

from .. import _core
from ._io_common import wake_all
from ._run import _public
from ._run import Task, _public
from ._wakeup_socketpair import WakeupSocketpair

if TYPE_CHECKING:
from socket import socket

from .._core import Abort, RaiseCancelT


@attr.s(slots=True, eq=False)
class EpollWaiters:
read_task: Task | None = attr.ib(default=None)
write_task: Task | None = attr.ib(default=None)
current_flags: int = attr.ib(default=0)


assert not TYPE_CHECKING or sys.platform == "linux"


@attr.s(slots=True, eq=False, frozen=True)
class _EpollStatistics:
tasks_waiting_read = attr.ib()
tasks_waiting_write = attr.ib()
backend = attr.ib(default="epoll")
tasks_waiting_read: int = attr.ib()
tasks_waiting_write: int = attr.ib()
backend: Literal["epoll"] = attr.ib(init=False, default="epoll")


# Some facts about epoll
Expand Down Expand Up @@ -178,28 +193,21 @@ class _EpollStatistics:
# wanted to about how epoll works.


@attr.s(slots=True, eq=False)
class EpollWaiters:
read_task = attr.ib(default=None)
write_task = attr.ib(default=None)
current_flags = attr.ib(default=0)


@attr.s(slots=True, eq=False, hash=False)
class EpollIOManager:
_epoll = attr.ib(factory=select.epoll)
_epoll: select.epoll = attr.ib(factory=select.epoll)
# {fd: EpollWaiters}
_registered = attr.ib(
factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters]
_registered: DefaultDict[int, EpollWaiters] = attr.ib(
factory=lambda: defaultdict(EpollWaiters)
)
_force_wakeup = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd = attr.ib(default=None)
_force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair)
_force_wakeup_fd: int | None = attr.ib(default=None)

def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()

def statistics(self):
def statistics(self) -> _EpollStatistics:
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._registered.values():
Expand All @@ -212,24 +220,24 @@ def statistics(self):
tasks_waiting_write=tasks_waiting_write,
)

def close(self):
def close(self) -> None:
self._epoll.close()
self._force_wakeup.close()

def force_wakeup(self):
def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()

# Return value must be False-y IFF the timeout expired, NOT if any I/O
# happened or force_wakeup was called. Otherwise it can be anything; gets
# passed straight through to process_events.
def get_events(self, timeout):
def get_events(self, timeout: float) -> list[tuple[int, int]]:
# max_events must be > 0 or epoll gets cranky
# accessing self._registered from a thread looks dangerous, but it's
# OK because it doesn't matter if our value is a little bit off.
max_events = max(1, len(self._registered))
return self._epoll.poll(timeout, max_events)

def process_events(self, events):
def process_events(self, events: list[tuple[int, int]]) -> None:
for fd, flags in events:
if fd == self._force_wakeup_fd:
self._force_wakeup.drain()
Expand All @@ -248,7 +256,7 @@ def process_events(self, events):
waiters.read_task = None
self._update_registrations(fd)

def _update_registrations(self, fd):
def _update_registrations(self, fd: int) -> None:
waiters = self._registered[fd]
wanted_flags = 0
if waiters.read_task is not None:
Expand Down Expand Up @@ -277,7 +285,7 @@ def _update_registrations(self, fd):
if not wanted_flags:
del self._registered[fd]

async def _epoll_wait(self, fd, attr_name):
async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
Expand All @@ -288,23 +296,23 @@ async def _epoll_wait(self, fd, attr_name):
setattr(waiters, attr_name, _core.current_task())
self._update_registrations(fd)

def abort(_):
def abort(_: RaiseCancelT) -> Abort:
setattr(waiters, attr_name, None)
self._update_registrations(fd)
return _core.Abort.SUCCEEDED

await _core.wait_task_rescheduled(abort)

@_public
async def wait_readable(self, fd):
async def wait_readable(self, fd: int | socket) -> None:
await self._epoll_wait(fd, "read_task")

@_public
async def wait_writable(self, fd):
async def wait_writable(self, fd: int | socket) -> None:
await self._epoll_wait(fd, "write_task")

@_public
def notify_closing(self, fd):
def notify_closing(self, fd: int | socket) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
Expand Down
Loading