Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
fbae13d
wip
jakkdl Aug 23, 2023
1a38b7e
wip2
jakkdl Aug 25, 2023
201da1a
hide methods in SocketType behind a TYPE_CHECKING guard, update _io_k…
jakkdl Aug 28, 2023
6cd3da7
Merge remote-tracking branch 'origin/master' into typing_sockettype
jakkdl Aug 28, 2023
c8fc73c
remove unneeded pragma: no cover
jakkdl Aug 29, 2023
e3fc6ae
fix default parameters
jakkdl Aug 29, 2023
15e8351
Merge remote-tracking branch 'origin/master' into typing_sockettype
jakkdl Sep 2, 2023
d60d1a1
raise NotImplementedError instead of methods not being defined at
jakkdl Sep 2, 2023
b213602
remove docs reference
jakkdl Sep 2, 2023
7eb099d
experimentally changing SocketType to be a generic
jakkdl Sep 2, 2023
ab66937
fix url
jakkdl Sep 5, 2023
304fd32
fix old imports
jakkdl Sep 5, 2023
c36958c
Merge remote-tracking branch 'origin/master' into typing_sockettype
jakkdl Sep 16, 2023
920d663
fix verify_types
jakkdl Sep 16, 2023
a3169cc
reverse making SocketType generic, CI fixes
jakkdl Sep 16, 2023
b9bbc06
fix verifytypes
jakkdl Sep 16, 2023
76ab994
fix suggestions from teamspen
jakkdl Sep 20, 2023
ea36fe8
Merge remote-tracking branch 'origin/master' into typing_sockettype
jakkdl Sep 20, 2023
0c294ea
Update `verify_types`
CoolCat467 Sep 23, 2023
dc09e7e
clarifying comments after a5rocks review
jakkdl Sep 26, 2023
11cfd4e
Merge remote-tracking branch 'origin/master' into typing_sockettype
jakkdl Sep 26, 2023
d934724
fix verify_types ............
jakkdl Sep 26, 2023
7335184
add @_public to exclude_lines in coveragerc
jakkdl Sep 26, 2023
618fdbf
Revert "add @_public to exclude_lines in coveragerc"
jakkdl Sep 26, 2023
a20a909
'pragma: no cover' an added type-converting line
jakkdl Sep 28, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ exclude_lines =
if t.TYPE_CHECKING:
@overload
class .*\bProtocol\b.*\):
raise NotImplementedError

partial_branches =
pragma: no branch
Expand Down
2 changes: 0 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@
("py:class", "sync function"),
# why aren't these found in stdlib?
("py:class", "types.FrameType"),
# TODO: temporary type
("py:class", "_SocketType"),
# these are not defined in https://docs.python.org/3/objects.inv
("py:class", "socket.AddressFamily"),
("py:class", "socket.SocketKind"),
Expand Down
7 changes: 0 additions & 7 deletions docs/source/reference-io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -504,13 +504,6 @@ Socket objects
* :meth:`~socket.socket.set_inheritable`
* :meth:`~socket.socket.get_inheritable`

The internal SocketType
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: _SocketType
..
TODO: adding `:members:` here gives error due to overload+_wraps on `sendto`
TODO: rewrite ... all of the above when fixing _SocketType vs SocketType


.. currentmodule:: trio

Expand Down
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,11 @@ module = [
"trio/_tests/test_exports",
"trio/_tests/test_file_io",
"trio/_tests/test_highlevel_generic",
"trio/_tests/test_highlevel_open_tcp_listeners",
"trio/_tests/test_highlevel_open_tcp_stream",
"trio/_tests/test_highlevel_open_unix_stream",
"trio/_tests/test_highlevel_serve_listeners",
"trio/_tests/test_highlevel_socket",
"trio/_tests/test_highlevel_ssl_helpers",
"trio/_tests/test_path",
"trio/_tests/test_scheduler_determinism",
"trio/_tests/test_socket",
"trio/_tests/test_ssl",
"trio/_tests/test_subprocess",
"trio/_tests/test_sync",
Expand Down
10 changes: 5 additions & 5 deletions trio/_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing_extensions import Self

# both of these introduce circular imports if outside a TYPE_CHECKING guard
from ._socket import _SocketType
from ._socket import SocketType
from .lowlevel import Task


Expand Down Expand Up @@ -211,10 +211,10 @@ class SocketFactory(metaclass=ABCMeta):
@abstractmethod
def socket(
self,
family: socket.AddressFamily | int | None = None,
type: socket.SocketKind | int | None = None,
proto: int | None = None,
) -> _SocketType:
family: socket.AddressFamily | int = socket.AF_INET,
type: socket.SocketKind | int = socket.SOCK_STREAM,
proto: int = 0,
) -> SocketType:
"""Create and return a socket object.

Your socket object must inherit from :class:`trio.socket.SocketType`,
Comment on lines 211 to 220
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want these types to be | None, as that affects the signature of all subclasses of SocketFactory. But this is a change in behaviour, so it might be better to just do a type: ignore

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think having them default to the values they would have became originally is a good thing

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine. I approach type hints as not bound by semver (i.e. can narrow them or expand them as long as we're not passing something with a new type in) and subclasses can still default to None. Hopefully that approach is sane...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the issue is not the type hint, but the default value being changed from None -> socket.AF_INET etc etc.
although thinking about it a bit more, I'm having trouble coming up with any code that actually changes in behaviour from this. If you're subclassing SocketFactory and defining your own SocketFactory.socket, the default values in SocketFactory.socket shouldn't modify the behaviour of MyClass.socket in any way .. I think?

Expand Down
13 changes: 8 additions & 5 deletions trio/_core/_generated_io_epoll.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions trio/_core/_generated_io_kqueue.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 5 additions & 6 deletions trio/_core/_io_epoll.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
from ._wakeup_socketpair import WakeupSocketpair

if TYPE_CHECKING:
from socket import socket

from typing_extensions import TypeAlias

from .._core import Abort, RaiseCancelT
from .._file_io import _HasFileNo


@attr.s(slots=True, eq=False)
Expand Down Expand Up @@ -290,7 +289,7 @@ def _update_registrations(self, fd: int) -> None:
if not wanted_flags:
del self._registered[fd]

async def _epoll_wait(self, fd: int | socket, attr_name: str) -> None:
async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
Expand All @@ -309,15 +308,15 @@ def abort(_: RaiseCancelT) -> Abort:
await _core.wait_task_rescheduled(abort)

@_public
async def wait_readable(self, fd: int | socket) -> None:
async def wait_readable(self, fd: int | _HasFileNo) -> None:
await self._epoll_wait(fd, "read_task")

@_public
async def wait_writable(self, fd: int | socket) -> None:
async def wait_writable(self, fd: int | _HasFileNo) -> None:
await self._epoll_wait(fd, "write_task")

@_public
def notify_closing(self, fd: int | socket) -> None:
def notify_closing(self, fd: int | _HasFileNo) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
Expand Down
11 changes: 5 additions & 6 deletions trio/_core/_io_kqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
from ._wakeup_socketpair import WakeupSocketpair

if TYPE_CHECKING:
from socket import socket

from typing_extensions import TypeAlias

from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
from .._file_io import _HasFileNo

assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")

Expand Down Expand Up @@ -149,7 +148,7 @@ def abort(raise_cancel: RaiseCancelT) -> Abort:
# wait_task_rescheduled does not have its return type typed
return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return]

async def _wait_common(self, fd: int | socket, filter: int) -> None:
async def _wait_common(self, fd: int | _HasFileNo, filter: int) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
Expand Down Expand Up @@ -181,15 +180,15 @@ def abort(_: RaiseCancelT) -> Abort:
await self.wait_kevent(fd, filter, abort)

@_public
async def wait_readable(self, fd: int | socket) -> None:
async def wait_readable(self, fd: int | _HasFileNo) -> None:
await self._wait_common(fd, select.KQ_FILTER_READ)

@_public
async def wait_writable(self, fd: int | socket) -> None:
async def wait_writable(self, fd: int | _HasFileNo) -> None:
await self._wait_common(fd, select.KQ_FILTER_WRITE)

@_public
def notify_closing(self, fd: int | socket) -> None:
def notify_closing(self, fd: int | _HasFileNo) -> None:
if not isinstance(fd, int):
fd = fd.fileno()

Expand Down
26 changes: 13 additions & 13 deletions trio/_dtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,26 @@
from OpenSSL.SSL import Context
from typing_extensions import Self, TypeAlias

from trio.socket import Address, _SocketType
from trio.socket import SocketType

MAX_UDP_PACKET_SIZE = 65527


def packet_header_overhead(sock: _SocketType) -> int:
def packet_header_overhead(sock: SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 28
else:
return 48


def worst_case_mtu(sock: _SocketType) -> int:
def worst_case_mtu(sock: SocketType) -> int:
if sock.family == trio.socket.AF_INET:
return 576 - packet_header_overhead(sock)
else:
return 1280 - packet_header_overhead(sock)


def best_guess_mtu(sock: _SocketType) -> int:
def best_guess_mtu(sock: SocketType) -> int:
return 1500 - packet_header_overhead(sock)


Expand Down Expand Up @@ -563,7 +563,7 @@ def _signable(*fields: bytes) -> bytes:


def _make_cookie(
key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes
key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes
) -> bytes:
assert len(salt) == SALT_BYTES
assert len(key) == KEY_BYTES
Expand All @@ -581,7 +581,7 @@ def _make_cookie(


def valid_cookie(
key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes
key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes
) -> bool:
if len(cookie) > SALT_BYTES:
salt = cookie[:SALT_BYTES]
Expand All @@ -603,7 +603,7 @@ def valid_cookie(


def challenge_for(
key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes
key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes
) -> bytes:
salt = os.urandom(SALT_BYTES)
tick = _current_cookie_tick()
Expand Down Expand Up @@ -664,7 +664,7 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:


async def handle_client_hello_untrusted(
endpoint: DTLSEndpoint, address: Address, packet: bytes
endpoint: DTLSEndpoint, address: Any, packet: bytes
) -> None:
if endpoint._listening_context is None:
return
Expand Down Expand Up @@ -739,7 +739,7 @@ async def handle_client_hello_untrusted(


async def dtls_receive_loop(
endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType
endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType
) -> None:
try:
while True:
Expand Down Expand Up @@ -829,7 +829,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):

"""

def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context):
def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context):
self.endpoint = endpoint
self.peer_address = peer_address
self._packets_dropped_in_trio = 0
Expand Down Expand Up @@ -1180,7 +1180,7 @@ class DTLSEndpoint:

"""

def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10):
# We do this lazily on first construction, so only people who actually use DTLS
# have to install PyOpenSSL.
global SSL
Expand All @@ -1191,7 +1191,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
if socket.type != trio.socket.SOCK_DGRAM:
raise ValueError("DTLS requires a SOCK_DGRAM socket")
self._initialized = True
self.socket: _SocketType = socket
self.socket: SocketType = socket

self.incoming_packets_buffer = incoming_packets_buffer
self._token = trio.lowlevel.current_trio_token()
Expand All @@ -1200,7 +1200,7 @@ def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10):
# as a peer provides a valid cookie, we can immediately tear down the
# old connection.
# {remote address: DTLSChannel}
self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary()
self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary()
self._listening_context: Context | None = None
self._listening_key: bytes | None = None
self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
Expand Down
5 changes: 3 additions & 2 deletions trio/_highlevel_open_tcp_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,12 @@ async def serve_tcp(
connect to it to check that it's working properly, you can use something
like::

from trio import SocketListener, SocketStream
from trio.testing import open_stream_to_socket_listener

async with trio.open_nursery() as nursery:
listeners = await nursery.start(serve_tcp, handler, 0)
client_stream = await open_stream_to_socket_listener(listeners[0])
listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
client_stream: SocketStream = await open_stream_to_socket_listener(listeners[0])

# Then send and receive data on 'client_stream', for example:
await client_stream.send_all(b"GET / HTTP/1.0\\r\\n\\r\\n")
Expand Down
Loading