diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6045ffd828..52872d0cb4 100755
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -87,6 +87,7 @@ def setup(app):
intersphinx_mapping = {
"python": ('https://docs.python.org/3', None),
"outcome": ('https://outcome.readthedocs.io/en/latest/', None),
+ "pyopenssl": ('https://www.pyopenssl.org/en/stable/', None),
}
autodoc_member_order = "bysource"
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index b18983e272..a3291ef2ae 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -258,6 +258,52 @@ you call them before the handshake completes:
.. autoexception:: NeedHandshakeError
+Datagram TLS support
+~~~~~~~~~~~~~~~~~~~~
+
+Trio also has support for Datagram TLS (DTLS), which is like TLS but
+for unreliable UDP connections. This can be useful for applications
+where TCP's reliable in-order delivery is problematic, like
+teleconferencing, latency-sensitive games, and VPNs.
+
+Currently, using DTLS with Trio requires PyOpenSSL. We hope to
+eventually allow the use of the stdlib `ssl` module as well, but
+unfortunately that's not yet possible.
+
+.. warning:: Note that PyOpenSSL is in many ways lower-level than the
+ `ssl` module – in particular, it currently **HAS NO BUILT-IN
+ MECHANISM TO VALIDATE CERTIFICATES**. We *strongly* recommend that
+ you use the `service-identity
+ <https://pypi.org/project/service-identity/>`__ library to validate
+ hostnames and certificates.
+
+.. autoclass:: DTLSEndpoint
+
+ .. automethod:: connect
+
+ .. automethod:: serve
+
+ .. automethod:: close
+
+.. autoclass:: DTLSChannel
+ :show-inheritance:
+
+ .. automethod:: do_handshake
+
+ .. automethod:: send
+
+ .. automethod:: receive
+
+ .. automethod:: close
+
+ .. automethod:: aclose
+
+ .. automethod:: set_ciphertext_mtu
+
+ .. automethod:: get_cleartext_mtu
+
+ .. automethod:: statistics
+
.. module:: trio.socket
Low-level networking with :mod:`trio.socket`
diff --git a/newsfragments/2010.feature.rst b/newsfragments/2010.feature.rst
new file mode 100644
index 0000000000..f99687652f
--- /dev/null
+++ b/newsfragments/2010.feature.rst
@@ -0,0 +1,4 @@
+Added support for `Datagram TLS
+<https://en.wikipedia.org/wiki/Datagram_Transport_Layer_Security>`__,
+for secure communication over UDP. Currently requires `PyOpenSSL
+<https://pypi.org/project/pyOpenSSL/>`__.
diff --git a/test-requirements.in b/test-requirements.in
index 3478604218..186b13392e 100644
--- a/test-requirements.in
+++ b/test-requirements.in
@@ -3,8 +3,8 @@ pytest >= 5.0 # for faulthandler in core
pytest-cov >= 2.6.0
# ipython 7.x is the last major version supporting Python 3.7
ipython < 7.32 # for the IPython traceback integration tests
-pyOpenSSL # for the ssl tests
-trustme # for the ssl tests
+pyOpenSSL >= 22.0.0 # for the ssl + DTLS tests
+trustme # for the ssl + DTLS tests
pylint # for pylint finding all symbols tests
jedi # for jedi code completion tests
cryptography>=36.0.0 # 35.0.0 is transitive but fails
@@ -12,6 +12,7 @@ cryptography>=36.0.0 # 35.0.0 is transitive but fails
# Tools
black; implementation_name == "cpython"
mypy; implementation_name == "cpython"
+types-pyOpenSSL; implementation_name == "cpython"
flake8
astor # code generation
diff --git a/test-requirements.txt b/test-requirements.txt
index 17e6697e83..5aabe871de 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -134,10 +134,21 @@ traitlets==5.3.0
# matplotlib-inline
trustme==0.9.0
# via -r test-requirements.in
+types-cryptography==3.3.14
+ # via types-pyopenssl
+types-enum34==1.1.8
+ # via types-cryptography
+types-ipaddress==1.0.7
+ # via types-cryptography
+types-pyopenssl==21.0.3 ; implementation_name == "cpython"
+ # via -r test-requirements.in
typing-extensions==4.3.0 ; implementation_name == "cpython"
# via
# -r test-requirements.in
+ # astroid
+ # black
# mypy
+ # pylint
wcwidth==0.2.5
# via prompt-toolkit
wrapt==1.14.1
diff --git a/trio/__init__.py b/trio/__init__.py
index a50ec33310..b35fa076b3 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -74,6 +74,8 @@
from ._ssl import SSLStream, SSLListener, NeedHandshakeError
+from ._dtls import DTLSEndpoint, DTLSChannel
+
from ._highlevel_serve_listeners import serve_listeners
from ._highlevel_open_tcp_stream import open_tcp_stream
diff --git a/trio/_dtls.py b/trio/_dtls.py
new file mode 100644
index 0000000000..910637455a
--- /dev/null
+++ b/trio/_dtls.py
@@ -0,0 +1,1276 @@
+# Implementation of DTLS 1.2, using pyopenssl
+# https://datatracker.ietf.org/doc/html/rfc6347
+#
+# OpenSSL's APIs for DTLS are extremely awkward and limited, which forces us to jump
+# through a *lot* of hoops and implement important chunks of the protocol ourselves.
+# Hopefully they fix this before implementing DTLS 1.3, because it's a very different
+# protocol, and it's probably impossible to pull tricks like we do here.
+
+import struct
+import hmac
+import os
+import enum
+from itertools import count
+import weakref
+import errno
+import warnings
+
+import attr
+
+import trio
+from trio._util import NoPublicConstructor, Final
+
+MAX_UDP_PACKET_SIZE = 65527
+
+
+def packet_header_overhead(sock):
+ if sock.family == trio.socket.AF_INET:
+ return 28
+ else:
+ return 48
+
+
+def worst_case_mtu(sock):
+ if sock.family == trio.socket.AF_INET:
+ return 576 - packet_header_overhead(sock)
+ else:
+ return 1280 - packet_header_overhead(sock)
+
+
+def best_guess_mtu(sock):
+ return 1500 - packet_header_overhead(sock)
+
+
+# There are a bunch of different RFCs that define these codes, so for a
+# comprehensive collection look here:
+# https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml
+class ContentType(enum.IntEnum):
+ change_cipher_spec = 20
+ alert = 21
+ handshake = 22
+ application_data = 23
+ heartbeat = 24
+
+
+class HandshakeType(enum.IntEnum):
+ hello_request = 0
+ client_hello = 1
+ server_hello = 2
+ hello_verify_request = 3
+ new_session_ticket = 4
+ end_of_early_data = 5
+ encrypted_extensions = 8
+ certificate = 11
+ server_key_exchange = 12
+ certificate_request = 13
+ server_hello_done = 14
+ certificate_verify = 15
+ client_key_exchange = 16
+ finished = 20
+ certificate_url = 21
+ certificate_status = 22
+ supplemental_data = 23
+ key_update = 24
+ compressed_certificate = 25
+ ekt_key = 26
+ message_hash = 254
+
+
+class ProtocolVersion:
+ DTLS10 = bytes([254, 255])
+ DTLS12 = bytes([254, 253])
+
+
+EPOCH_MASK = 0xFFFF << (6 * 8)
+
+
+# Conventions:
+# - All functions that handle network data end in _untrusted.
+# - All functions ending in _untrusted MUST make sure that bad data from the
+# network can *only* cause BadPacket to be raised. No IndexError or
+# struct.error or whatever.
+class BadPacket(Exception):
+ pass
+
+
+# This checks that the DTLS 'epoch' field is 0, which is true iff we're in the
+# initial handshake. It doesn't check the ContentType, because not all
+# handshake messages have ContentType==handshake -- for example,
+# ChangeCipherSpec is used during the handshake but has its own ContentType.
+#
+# Cannot fail.
+def part_of_handshake_untrusted(packet):
+ # If the packet is too short, then slicing will successfully return a
+ # short string, which will necessarily fail to match.
+ return packet[3:5] == b"\x00\x00"
+
+
+# Cannot fail
+def is_client_hello_untrusted(packet):
+ try:
+ return (
+ packet[0] == ContentType.handshake
+ and packet[13] == HandshakeType.client_hello
+ )
+ except IndexError:
+ # Invalid DTLS record
+ return False
+
+
+# DTLS records are:
+# - 1 byte content type
+# - 2 bytes version
+# - 8 bytes epoch+seqno
+# Technically this is 2 bytes epoch then 6 bytes seqno, but we treat it as
+# a single 8-byte integer, where epoch changes are represented as jumping
+# forward by 2**(6*8).
+# - 2 bytes payload length (unsigned big-endian)
+# - payload
+RECORD_HEADER = struct.Struct("!B2sQH")
+
+
+def to_hex(data: bytes) -> str: # pragma: no cover
+ return data.hex()
+
+
+@attr.frozen
+class Record:
+ content_type: int
+ version: bytes = attr.ib(repr=to_hex)
+ epoch_seqno: int
+ payload: bytes = attr.ib(repr=to_hex)
+
+
+def records_untrusted(packet):
+ i = 0
+ while i < len(packet):
+ try:
+ ct, version, epoch_seqno, payload_len = RECORD_HEADER.unpack_from(packet, i)
+ # Marked as no-cover because at time of writing, this code is unreachable
+ # (records_untrusted only gets called on packets that are either trusted or that
+ # have passed is_client_hello_untrusted, which filters out short packets)
+ except struct.error as exc: # pragma: no cover
+ raise BadPacket("invalid record header") from exc
+ i += RECORD_HEADER.size
+ payload = packet[i : i + payload_len]
+ if len(payload) != payload_len:
+ raise BadPacket("short record")
+ i += payload_len
+ yield Record(ct, version, epoch_seqno, payload)
+
+
+def encode_record(record):
+ header = RECORD_HEADER.pack(
+ record.content_type,
+ record.version,
+ record.epoch_seqno,
+ len(record.payload),
+ )
+ return header + record.payload
+
+
+# Handshake messages are:
+# - 1 byte message type
+# - 3 bytes total message length
+# - 2 bytes message sequence number
+# - 3 bytes fragment offset
+# - 3 bytes fragment length
+HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s")
+
+
+@attr.frozen
+class HandshakeFragment:
+ msg_type: int
+ msg_len: int
+ msg_seq: int
+ frag_offset: int
+ frag_len: int
+ frag: bytes = attr.ib(repr=to_hex)
+
+
+def decode_handshake_fragment_untrusted(payload):
+ # Raises BadPacket if decoding fails
+ try:
+ (
+ msg_type,
+ msg_len_bytes,
+ msg_seq,
+ frag_offset_bytes,
+ frag_len_bytes,
+ ) = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload)
+ except struct.error as exc:
+ raise BadPacket("bad handshake message header") from exc
+ # 'struct' doesn't have built-in support for 24-bit integers, so we
+ # have to do it by hand. These can't fail.
+ msg_len = int.from_bytes(msg_len_bytes, "big")
+ frag_offset = int.from_bytes(frag_offset_bytes, "big")
+ frag_len = int.from_bytes(frag_len_bytes, "big")
+ frag = payload[HANDSHAKE_MESSAGE_HEADER.size :]
+ if len(frag) != frag_len:
+ raise BadPacket("handshake fragment length doesn't match record length")
+ return HandshakeFragment(
+ msg_type,
+ msg_len,
+ msg_seq,
+ frag_offset,
+ frag_len,
+ frag,
+ )
+
+
+def encode_handshake_fragment(hsf):
+ hs_header = HANDSHAKE_MESSAGE_HEADER.pack(
+ hsf.msg_type,
+ hsf.msg_len.to_bytes(3, "big"),
+ hsf.msg_seq,
+ hsf.frag_offset.to_bytes(3, "big"),
+ hsf.frag_len.to_bytes(3, "big"),
+ )
+ return hs_header + hsf.frag
+
+
+def decode_client_hello_untrusted(packet):
+ # Raises BadPacket if parsing fails
+ # Returns (record epoch_seqno, cookie from the packet, data that should be
+ # hashed into cookie)
+ try:
+ # ClientHello has to be the first record in the packet
+ record = next(records_untrusted(packet))
+ # no-cover because at time of writing, this is unreachable:
+ # decode_client_hello_untrusted is only called on packets that have passed
+ # is_client_hello_untrusted, which confirms the content type.
+ if record.content_type != ContentType.handshake: # pragma: no cover
+ raise BadPacket("not a handshake record")
+ fragment = decode_handshake_fragment_untrusted(record.payload)
+ if fragment.msg_type != HandshakeType.client_hello:
+ raise BadPacket("not a ClientHello")
+ # ClientHello can't be fragmented, because reassembly requires holding
+ # per-connection state, and we refuse to allocate per-connection state
+ # until after we get a valid ClientHello.
+ if fragment.frag_offset != 0:
+ raise BadPacket("fragmented ClientHello")
+ if fragment.frag_len != fragment.msg_len:
+ raise BadPacket("fragmented ClientHello")
+
+ # As per RFC 6347:
+ #
+ # When responding to a HelloVerifyRequest, the client MUST use the
+ # same parameter values (version, random, session_id, cipher_suites,
+ # compression_method) as it did in the original ClientHello. The
+ # server SHOULD use those values to generate its cookie and verify that
+ # they are correct upon cookie receipt.
+ #
+ # However, the record-layer framing can and will change (e.g. the
+ # second ClientHello will have a new record-layer sequence number). So
+ # we need to pull out the handshake message alone, discarding the
+ # record-layer stuff, and then we're going to hash all of it *except*
+ # the cookie.
+
+ body = fragment.frag
+ # ClientHello is:
+ #
+ # - 2 bytes client_version
+ # - 32 bytes random
+ # - 1 byte session_id length
+ # - session_id
+ # - 1 byte cookie length
+ # - cookie
+ # - everything else
+ #
+ # So to find the cookie, we need to figure out how long the
+ # session_id is and skip past it.
+ session_id_len = body[2 + 32]
+ cookie_len_offset = 2 + 32 + 1 + session_id_len
+ cookie_len = body[cookie_len_offset]
+
+ cookie_start = cookie_len_offset + 1
+ cookie_end = cookie_start + cookie_len
+
+ before_cookie = body[:cookie_len_offset]
+ cookie = body[cookie_start:cookie_end]
+ after_cookie = body[cookie_end:]
+
+ if len(cookie) != cookie_len:
+ raise BadPacket("short cookie")
+ return (record.epoch_seqno, cookie, before_cookie + after_cookie)
+
+ except (struct.error, IndexError) as exc:
+ raise BadPacket("bad ClientHello") from exc
+
+
+@attr.frozen
+class HandshakeMessage:
+ record_version: bytes = attr.ib(repr=to_hex)
+ msg_type: HandshakeType
+ msg_seq: int
+ body: bytearray = attr.ib(repr=to_hex)
+
+
+# ChangeCipherSpec is part of the handshake, but it's not a "handshake
+# message" and can't be fragmented the same way. Sigh.
+@attr.frozen
+class PseudoHandshakeMessage:
+ record_version: bytes = attr.ib(repr=to_hex)
+ content_type: int
+ payload: bytes = attr.ib(repr=to_hex)
+
+
+# The final record in a handshake is Finished, which is encrypted, can't be fragmented
+# (at least by us), and keeps its record number (because it's in a new epoch). So we
+# just pass it through unchanged. (Fortunately, the payload is only a single hash value,
+# so the largest it will ever be is 64 bytes for a 512-bit hash. Which is small enough
+# that it never requires fragmenting to fit into a UDP packet.)
+@attr.frozen
+class OpaqueHandshakeMessage:
+ record: Record
+
+
+# This takes a raw outgoing handshake volley that openssl generated, and
+# reconstructs the handshake messages inside it, so that we can repack them
+# into records while retransmitting. So the data ought to be well-behaved --
+# it's not coming from the network.
+def decode_volley_trusted(volley):
+ messages = []
+ messages_by_seq = {}
+ for record in records_untrusted(volley):
+ # ChangeCipherSpec isn't a handshake message, so it can't be fragmented.
+ # Handshake messages with epoch > 0 are encrypted, so we can't fragment them
+ # either. Fortunately, ChangeCipherSpec has a 1 byte payload, and the only
+ # encrypted handshake message is Finished, whose payload is a single hash value
+ # -- so 32 bytes for SHA-256, 64 for SHA-512, etc. Neither is going to be so
+ # large that it has to be fragmented to fit into a single packet.
+ if record.epoch_seqno & EPOCH_MASK:
+ messages.append(OpaqueHandshakeMessage(record))
+ elif record.content_type in (ContentType.change_cipher_spec, ContentType.alert):
+ messages.append(
+ PseudoHandshakeMessage(
+ record.version, record.content_type, record.payload
+ )
+ )
+ else:
+ assert record.content_type == ContentType.handshake
+ fragment = decode_handshake_fragment_untrusted(record.payload)
+ msg_type = HandshakeType(fragment.msg_type)
+ if fragment.msg_seq not in messages_by_seq:
+ msg = HandshakeMessage(
+ record.version,
+ msg_type,
+ fragment.msg_seq,
+ bytearray(fragment.msg_len),
+ )
+ messages.append(msg)
+ messages_by_seq[fragment.msg_seq] = msg
+ else:
+ msg = messages_by_seq[fragment.msg_seq]
+ assert msg.msg_type == fragment.msg_type
+ assert msg.msg_seq == fragment.msg_seq
+ assert len(msg.body) == fragment.msg_len
+
+ msg.body[
+ fragment.frag_offset : fragment.frag_offset + fragment.frag_len
+ ] = fragment.frag
+
+ return messages
+
+
+class RecordEncoder:
+ def __init__(self):
+ self._record_seq = count()
+
+ def set_first_record_number(self, n):
+ self._record_seq = count(n)
+
+ def encode_volley(self, messages, mtu):
+ packets = []
+ packet = bytearray()
+ for message in messages:
+ if isinstance(message, OpaqueHandshakeMessage):
+ encoded = encode_record(message.record)
+ if mtu - len(packet) - len(encoded) <= 0:
+ packets.append(packet)
+ packet = bytearray()
+ packet += encoded
+ assert len(packet) <= mtu
+ elif isinstance(message, PseudoHandshakeMessage):
+ space = mtu - len(packet) - RECORD_HEADER.size - len(message.payload)
+ if space <= 0:
+ packets.append(packet)
+ packet = bytearray()
+ packet += RECORD_HEADER.pack(
+ message.content_type,
+ message.record_version,
+ next(self._record_seq),
+ len(message.payload),
+ )
+ packet += message.payload
+ assert len(packet) <= mtu
+ else:
+ msg_len_bytes = len(message.body).to_bytes(3, "big")
+ frag_offset = 0
+ frags_encoded = 0
+ # If message.body is empty, then we still want to encode it in one
+ # fragment, not zero.
+ while frag_offset < len(message.body) or not frags_encoded:
+ space = (
+ mtu
+ - len(packet)
+ - RECORD_HEADER.size
+ - HANDSHAKE_MESSAGE_HEADER.size
+ )
+ if space <= 0:
+ packets.append(packet)
+ packet = bytearray()
+ continue
+ frag = message.body[frag_offset : frag_offset + space]
+ frag_offset_bytes = frag_offset.to_bytes(3, "big")
+ frag_len_bytes = len(frag).to_bytes(3, "big")
+ frag_offset += len(frag)
+
+ packet += RECORD_HEADER.pack(
+ ContentType.handshake,
+ message.record_version,
+ next(self._record_seq),
+ HANDSHAKE_MESSAGE_HEADER.size + len(frag),
+ )
+
+ packet += HANDSHAKE_MESSAGE_HEADER.pack(
+ message.msg_type,
+ msg_len_bytes,
+ message.msg_seq,
+ frag_offset_bytes,
+ frag_len_bytes,
+ )
+
+ packet += frag
+
+ frags_encoded += 1
+ assert len(packet) <= mtu
+
+ if packet:
+ packets.append(packet)
+
+ return packets
+
+
+# This bit requires implementing a bona fide cryptographic protocol, so even though it's
+# a simple one let's take a moment to discuss the design.
+#
+# Our goal is to force new incoming handshakes that claim to be coming from a
+# given ip:port to prove that they can also receive packets sent to that
+# ip:port. (There's nothing in UDP to stop someone from forging the return
+# address, and it's often used for stuff like DoS reflection attacks, where
+# an attacker tries to trick us into sending data at some innocent victim.)
+# For more details, see:
+#
+# https://datatracker.ietf.org/doc/html/rfc6347#section-4.2.1
+#
+# To do this, when we receive an initial ClientHello, we calculate a magic
+# cookie, and send it back as a HelloVerifyRequest. Then the client sends us a
+# second ClientHello, this time with the magic cookie included, and after we
+# check that this cookie is valid we go ahead and start the handshake proper.
+#
+# So the magic cookie needs the following properties:
+# - No-one can forge it without knowing our secret key
+# - It ensures that the ip, port, and ClientHello contents from the response
+# match those in the challenge
+# - It expires after a short-ish period (so that if an attacker manages to steal one, it
+# won't be useful for long)
+# - It doesn't require storing any peer-specific state on our side
+#
+# To do that, we take the ip/port/ClientHello data and compute an HMAC of them, using a
+# secret key we generate on startup. We also include:
+#
+# - The current time (using Trio's clock), rounded to the nearest 30 seconds
+# - A random salt
+#
+# Then the cookie is the salt and the HMAC digest concatenated together.
+#
+# When verifying a cookie, we use the salt + new ip/port/ClientHello data to recompute
+# the HMAC digest, for both the current time and the current time minus 30 seconds, and
+# if either of them match, we consider the cookie good.
+#
+# Including the rounded-off time like this means that each cookie is good for at least
+# 30 seconds, and possibly as much as 60 seconds.
+#
+# The salt is probably not necessary -- I'm pretty sure that all it does is make it hard
+# for an attacker to figure out when our clock ticks over a 30 second boundary. Which is
+# probably pretty harmless? But it's easier to add the salt than to convince myself that
+# it's *completely* harmless, so, salt it is.
+
+COOKIE_REFRESH_INTERVAL = 30 # seconds
+KEY_BYTES = 32
+COOKIE_HASH = "sha256"
+SALT_BYTES = 8
+# 32 bytes was the maximum cookie length in DTLS 1.0. DTLS 1.2 raised it to 255. I doubt
+# there are any DTLS 1.0 implementations still in the wild, but really 32 bytes is
+# plenty, and it also gets rid of a confusing warning in Wireshark output.
+#
+# We truncate the cookie to 32 bytes, of which 8 bytes is salt, so that leaves 24 bytes
+# of truncated HMAC = 192 bit security, which is still massive overkill. (TCP uses 32
+# *bits* for this.) HMAC truncation is explicitly noted as safe in RFC 2104:
+# https://datatracker.ietf.org/doc/html/rfc2104#section-5
+COOKIE_LENGTH = 32
+
+
+def _current_cookie_tick():
+ return int(trio.current_time() / COOKIE_REFRESH_INTERVAL)
+
+
+# Simple deterministic and invertible serializer -- i.e., a useful tool for converting
+# structured data into something we can cryptographically sign.
+def _signable(*fields):
+ out = []
+ for field in fields:
+ out.append(struct.pack("!Q", len(field)))
+ out.append(field)
+ return b"".join(out)
+
+
+def _make_cookie(key, salt, tick, address, client_hello_bits):
+ assert len(salt) == SALT_BYTES
+ assert len(key) == KEY_BYTES
+
+ signable_data = _signable(
+ salt,
+ struct.pack("!Q", tick),
+ # address is a mix of strings and ints, and variable length, so pack
+ # it into a single nested field
+ _signable(*(str(part).encode() for part in address)),
+ client_hello_bits,
+ )
+
+ return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH]
+
+
+def valid_cookie(key, cookie, address, client_hello_bits):
+ if len(cookie) > SALT_BYTES:
+ salt = cookie[:SALT_BYTES]
+
+ tick = _current_cookie_tick()
+
+ cur_cookie = _make_cookie(key, salt, tick, address, client_hello_bits)
+ old_cookie = _make_cookie(
+ key, salt, max(tick - 1, 0), address, client_hello_bits
+ )
+
+ # I doubt using a short-circuiting 'or' here would leak any meaningful
+ # information, but why risk it when '|' is just as easy.
+ return hmac.compare_digest(cookie, cur_cookie) | hmac.compare_digest(
+ cookie, old_cookie
+ )
+ else:
+ return False
+
+
+def challenge_for(key, address, epoch_seqno, client_hello_bits):
+ salt = os.urandom(SALT_BYTES)
+ tick = _current_cookie_tick()
+ cookie = _make_cookie(key, salt, tick, address, client_hello_bits)
+
+ # HelloVerifyRequest body is:
+ # - 2 bytes version
+ # - length-prefixed cookie
+ #
+ # The DTLS 1.2 spec says that for this message specifically we should use
+ # the DTLS 1.0 version.
+ #
+ # (It also says the opposite of that, but that part is a mistake:
+ # https://www.rfc-editor.org/errata/eid4103
+ # ).
+ #
+ # And I guess we use this for both the message-level and record-level
+ # ProtocolVersions, since we haven't negotiated anything else yet?
+ body = ProtocolVersion.DTLS10 + bytes([len(cookie)]) + cookie
+
+ # RFC says have to copy the client's record number
+ # Errata says it should be handshake message number
+ # Openssl copies back record sequence number, and always sets message seq
+ # number 0. So I guess we'll follow openssl.
+ hs = HandshakeFragment(
+ msg_type=HandshakeType.hello_verify_request,
+ msg_len=len(body),
+ msg_seq=0,
+ frag_offset=0,
+ frag_len=len(body),
+ frag=body,
+ )
+ payload = encode_handshake_fragment(hs)
+
+ packet = encode_record(
+ Record(ContentType.handshake, ProtocolVersion.DTLS10, epoch_seqno, payload)
+ )
+ return packet
+
+
+class _Queue:
+ def __init__(self, incoming_packets_buffer):
+ self.s, self.r = trio.open_memory_channel(incoming_packets_buffer)
+
+
+def _read_loop(read_fn):
+ chunks = []
+ while True:
+ try:
+ chunk = read_fn(2**14) # max TLS record size
+ except SSL.WantReadError:
+ break
+ chunks.append(chunk)
+ return b"".join(chunks)
+
+
+async def handle_client_hello_untrusted(endpoint, address, packet):
+ if endpoint._listening_context is None:
+ return
+
+ try:
+ epoch_seqno, cookie, bits = decode_client_hello_untrusted(packet)
+ except BadPacket:
+ return
+
+ if endpoint._listening_key is None:
+ endpoint._listening_key = os.urandom(KEY_BYTES)
+
+ if not valid_cookie(endpoint._listening_key, cookie, address, bits):
+ challenge_packet = challenge_for(
+ endpoint._listening_key, address, epoch_seqno, bits
+ )
+ try:
+ async with endpoint._send_lock:
+ await endpoint.socket.sendto(challenge_packet, address)
+ except (OSError, trio.ClosedResourceError):
+ pass
+ else:
+ # We got a real, valid ClientHello!
+ stream = DTLSChannel._create(endpoint, address, endpoint._listening_context)
+ # Our HelloRetryRequest had some sequence number. We need our future sequence
+ # numbers to be larger than it, so our peer knows that our future records aren't
+ # stale/duplicates. But, we don't know what this sequence number was. What we do
+ # know is:
+ # - the HelloRetryRequest seqno was copied from the initial ClientHello
+ # - the new ClientHello has a higher seqno than the initial ClientHello
+ # So, if we copy the new ClientHello's seqno into our first real handshake
+ # record and increment from there, that should work.
+ stream._record_encoder.set_first_record_number(epoch_seqno)
+ # Process the ClientHello
+ try:
+ stream._ssl.bio_write(packet)
+ stream._ssl.DTLSv1_listen()
+ except SSL.Error:
+ # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello
+ # after all.
+ return
+
+ # Some old versions of OpenSSL have a bug with memory BIOs, where DTLSv1_listen
+ # consumes the ClientHello out of the BIO, but then do_handshake expects the
+ # ClientHello to still be in there (but not the one that ships with Ubuntu
+ # 20.04). In particular, this is known to affect the OpenSSL v1.1.1 that ships
+ # with Ubuntu 18.04. To work around this, we deliver a second copy of the
+ # ClientHello after DTLSv1_listen has completed. This is safe to do
+ # unconditionally, because on newer versions of OpenSSL, the second ClientHello
+ # is treated as a duplicate packet, which is a normal thing that can happen over
+ # UDP. For more details, see:
+ #
+ # https://github.com/pyca/pyopenssl/blob/e84e7b57d1838de70ab7a27089fbee78ce0d2106/tests/test_ssl.py#L4226-L4293
+ #
+ # This was fixed in v1.1.1a, and all later versions. So maybe in 2024 or so we
+ # can delete this. The fix landed in OpenSSL master as 079ef6bd534d2, and then
+ # was backported to the 1.1.1 branch as d1bfd8076e28.
+ stream._ssl.bio_write(packet)
+
+ # Check if we have an existing association
+ old_stream = endpoint._streams.get(address)
+ if old_stream is not None:
+ if old_stream._client_hello == (cookie, bits):
+ # ...This was just a duplicate of the last ClientHello, so never mind.
+ return
+ else:
+ # Ok, this *really is* a new handshake; the old stream should go away.
+ old_stream._set_replaced()
+ stream._client_hello = (cookie, bits)
+ endpoint._streams[address] = stream
+ endpoint._incoming_connections_q.s.send_nowait(stream)
+
+
+async def dtls_receive_loop(endpoint_ref, sock):
+ try:
+ while True:
+ try:
+ packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE)
+ except OSError as exc:
+ if exc.errno == errno.ECONNRESET:
+ # Windows only: "On a UDP-datagram socket [ECONNRESET]
+ # indicates a previous send operation resulted in an ICMP Port
+ # Unreachable message" -- https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom
+ #
+ # This is totally useless -- there's nothing we can do with this
+ # information. So we just ignore it and retry the recv.
+ continue
+ else:
+ raise
+ endpoint = endpoint_ref()
+ try:
+ if endpoint is None:
+ return
+ if is_client_hello_untrusted(packet):
+ await handle_client_hello_untrusted(endpoint, address, packet)
+ elif address in endpoint._streams:
+ stream = endpoint._streams[address]
+ if stream._did_handshake and part_of_handshake_untrusted(packet):
+ # The peer just sent us more handshake messages, that aren't a
+ # ClientHello, and we thought the handshake was done. Some of
+ # the packets that we sent to finish the handshake must have
+ # gotten lost. So re-send them. We do this directly here instead
+ # of just putting it into the queue and letting the receiver do
+ # it, because there's no guarantee that anyone is reading from
+ # the queue, because we think the handshake is done!
+ await stream._resend_final_volley()
+ else:
+ try:
+ stream._q.s.send_nowait(packet)
+ except trio.WouldBlock:
+ stream._packets_dropped_in_trio += 1
+ else:
+ # Drop packet
+ pass
+ finally:
+ del endpoint
+ except trio.ClosedResourceError:
+ # socket was closed
+ return
+ except OSError as exc:
+ if exc.errno in (errno.EBADF, errno.ENOTSOCK):
+ # socket was closed
+ return
+ else: # pragma: no cover
+ # ??? shouldn't happen
+ raise
+
+
+@attr.frozen
+class DTLSChannelStatistics:
+ incoming_packets_dropped_in_trio: int
+
+
+class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
+ """A DTLS connection.
+
+ This class has no public constructor – you get instances by calling
+ `DTLSEndpoint.serve` or `~DTLSEndpoint.connect`.
+
+ .. attribute:: endpoint
+
+ The `DTLSEndpoint` that this connection is using.
+
+ .. attribute:: peer_address
+
+ The IP/port of the remote peer that this connection is associated with.
+
+ """
+
+ def __init__(self, endpoint, peer_address, ctx):
+ self.endpoint = endpoint
+ self.peer_address = peer_address
+ self._packets_dropped_in_trio = 0
+ self._client_hello = None
+ self._did_handshake = False
+ # These are mandatory for all DTLS connections. OP_NO_QUERY_MTU is required to
+ # stop openssl from trying to query the memory BIO's MTU and then breaking, and
+ # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to
+ # support and isn't useful anyway -- especially for DTLS where it's equivalent
+ # to just performing a new handshake.
+ ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION)
+ self._ssl = SSL.Connection(ctx)
+ self._handshake_mtu = None
+ # This calls self._ssl.set_ciphertext_mtu, which is important, because if you
+ # don't call it then openssl doesn't work.
+ self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket))
+ self._replaced = False
+ self._closed = False
+ self._q = _Queue(endpoint.incoming_packets_buffer)
+ self._handshake_lock = trio.Lock()
+ self._record_encoder = RecordEncoder()
+
+ def _set_replaced(self):
+ self._replaced = True
+ # Any packets we already received could maybe possibly still be processed, but
+ # there are no more coming. So we close this on the sender side.
+ self._q.s.close()
+
+ def _check_replaced(self):
+ if self._replaced:
+ raise trio.BrokenResourceError(
+ "peer tore down this connection to start a new one"
+ )
+
+ # XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU
+ # estimate
+
+ # XX should we send close-notify when closing? It seems particularly pointless for
+ # DTLS where packets are all independent and can be lost anyway. We do at least need
+ # to handle receiving it properly though, which might be easier if we send it...
+
+ def close(self):
+ """Close this connection.
+
+ `DTLSChannel`\\s don't actually own any OS-level resources – the
+ socket is owned by the `DTLSEndpoint`, not the individual connections. So
+ you don't really *have* to call this. But it will interrupt any other tasks
+ calling `receive` with a `ClosedResourceError`, and cause future attempts to use
+ this connection to fail.
+
+ You can also use this object as a synchronous or asynchronous context manager.
+
+ """
+ if self._closed:
+ return
+ self._closed = True
+ if self.endpoint._streams.get(self.peer_address) is self:
+ del self.endpoint._streams[self.peer_address]
+ # Will wake any tasks waiting on self._q.get with a
+ # ClosedResourceError
+ self._q.r.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ self.close()
+
    async def aclose(self):
        """Close this connection, but asynchronously.

        This is included to satisfy the `trio.abc.Channel` contract. It's
        identical to `close`, but async.

        """
        self.close()
        # close() is synchronous, so checkpoint explicitly to keep the usual
        # Trio guarantee that every async method checkpoints.
        await trio.lowlevel.checkpoint()
+
    async def _send_volley(self, volley_messages):
        # Fragment the handshake messages to fit within the current handshake
        # MTU, then transmit the resulting packets one at a time.
        packets = self._record_encoder.encode_volley(
            volley_messages, self._handshake_mtu
        )
        for packet in packets:
            # The endpoint-wide lock serializes sendto() calls from all
            # channels sharing this socket.
            async with self.endpoint._send_lock:
                await self.endpoint.socket.sendto(packet, self.peer_address)
+
    async def _resend_final_volley(self):
        # Re-send the last volley of the handshake -- presumably because the
        # peer retransmitted theirs, meaning ours got lost. (Caller is not
        # visible in this file chunk; confirm against the endpoint's receive
        # loop.)
        await self._send_volley(self._final_volley)
+
    async def do_handshake(self, *, initial_retransmit_timeout=1.0):
        """Perform the handshake.

        Calling this is optional – if you don't, then it will be automatically called
        the first time you call `send` or `receive`. But calling it explicitly can be
        useful in case you want to control the retransmit timeout, use a cancel scope to
        place an overall timeout on the handshake, or catch errors from the handshake
        specifically.

        It's safe to call this multiple times, or call it simultaneously from multiple
        tasks – the first call will perform the handshake, and the rest will be no-ops.

        Args:

          initial_retransmit_timeout (float): Since UDP is an unreliable protocol, it's
            possible that some of the packets we send during the handshake will get
            lost. To handle this, DTLS uses a timer to automatically retransmit
            handshake packets that don't receive a response. This lets you set the
            timeout we use to detect packet loss. Ideally, it should be set to ~1.5
            times the round-trip time to your peer, but 1 second is a reasonable
            default.

            This is the *initial* timeout, because if packets keep being lost then Trio
            will automatically back off to longer values, to avoid overloading the
            network.

        """
        # The lock makes concurrent calls safe: the first caller does the
        # work, later callers see _did_handshake and return immediately.
        async with self._handshake_lock:
            if self._did_handshake:
                return

            timeout = initial_retransmit_timeout
            volley_messages = []
            volley_failed_sends = 0

            def read_volley():
                # Drain everything openssl wants to transmit and decode it
                # into a list of handshake messages ("a volley").
                volley_bytes = _read_loop(self._ssl.bio_read)
                new_volley_messages = decode_volley_trusted(volley_bytes)
                if (
                    new_volley_messages
                    and volley_messages
                    and isinstance(new_volley_messages[0], HandshakeMessage)
                    and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq
                ):
                    # openssl decided to retransmit; discard because we handle
                    # retransmits ourselves
                    return []
                else:
                    return new_volley_messages

            # If we're a client, we send the initial volley. If we're a server, then
            # the initial ClientHello has already been inserted into self._ssl's
            # read BIO. So either way, we start by generating a new volley.
            try:
                self._ssl.do_handshake()
            except SSL.WantReadError:
                pass
            volley_messages = read_volley()
            # If we don't have messages to send in our initial volley, then something
            # has gone very wrong. (I'm not sure this can actually happen without an
            # error from OpenSSL, but we check just in case.)
            if not volley_messages:  # pragma: no cover
                raise SSL.Error("something wrong with peer's ClientHello")

            while True:
                # -- at this point, we need to either send or re-send a volley --
                assert volley_messages
                self._check_replaced()
                await self._send_volley(volley_messages)
                # -- then this is where we wait for a reply --
                self.endpoint._ensure_receive_loop()
                with trio.move_on_after(timeout) as cscope:
                    async for packet in self._q.r:
                        self._ssl.bio_write(packet)
                        try:
                            self._ssl.do_handshake()
                        # We ignore generic SSL.Error here, because you can get those
                        # from random invalid packets
                        except (SSL.WantReadError, SSL.Error):
                            pass
                        else:
                            # No exception -> the handshake is done, and we can
                            # switch into data transfer mode.
                            self._did_handshake = True
                            # Might be empty, but that's ok -- we'll just send no
                            # packets.
                            self._final_volley = read_volley()
                            await self._send_volley(self._final_volley)
                            return
                        maybe_volley = read_volley()
                        if maybe_volley:
                            if (
                                isinstance(maybe_volley[0], PseudoHandshakeMessage)
                                and maybe_volley[0].content_type == ContentType.alert
                            ):
                                # we're sending an alert (e.g. due to a corrupted
                                # packet). We want to send it once, but don't save it to
                                # retransmit -- keep the last volley as the current
                                # volley.
                                await self._send_volley(maybe_volley)
                            else:
                                # We managed to get all of the peer's volley and
                                # generate a new one ourselves! break out of the 'for'
                                # loop and restart the timer.
                                volley_messages = maybe_volley
                                # "Implementations SHOULD retain the current timer value
                                # until a transmission without loss occurs, at which
                                # time the value may be reset to the initial value."
                                if volley_failed_sends == 0:
                                    timeout = initial_retransmit_timeout
                                volley_failed_sends = 0
                                break
                    else:
                        # for-else: the incoming channel was closed, which only
                        # happens when the peer replaced this connection.
                        assert self._replaced
                        self._check_replaced()
                if cscope.cancelled_caught:
                    # Timeout expired. Double timeout for backoff, with a limit of 60
                    # seconds (this matches what openssl does, and also the
                    # recommendation in draft-ietf-tls-dtls13).
                    timeout = min(2 * timeout, 60.0)
                    volley_failed_sends += 1
                    if volley_failed_sends == 2:
                        # We tried sending this twice and they both failed. Maybe our
                        # PMTU estimate is wrong? Let's try dropping it to the minimum
                        # and hope that helps.
                        self._handshake_mtu = min(
                            self._handshake_mtu, worst_case_mtu(self.endpoint.socket)
                        )
+
    async def send(self, data):
        """Send a packet of data, securely.

        Args:
          data (bytes): The cleartext payload; must be non-empty.

        Raises:
          trio.ClosedResourceError: if this channel was closed.
          trio.BrokenResourceError: if the peer replaced this connection with
            a new handshake.
          ValueError: if ``data`` is empty.

        """
        if self._closed:
            raise trio.ClosedResourceError
        if not data:
            raise ValueError("openssl doesn't support sending empty DTLS packets")
        if not self._did_handshake:
            await self.do_handshake()
        self._check_replaced()
        # Encrypt through the memory BIO, then ship the resulting ciphertext
        # as a single UDP datagram.
        self._ssl.write(data)
        async with self.endpoint._send_lock:
            await self.endpoint.socket.sendto(
                _read_loop(self._ssl.bio_read), self.peer_address
            )
+
    async def receive(self):
        """Fetch the next packet of data from this connection's peer, waiting if
        necessary.

        This is safe to call from multiple tasks simultaneously, in case you have some
        reason to do that. And more importantly, it's cancellation-safe, meaning that
        cancelling a call to `receive` will never cause a packet to be lost or corrupt
        the underlying connection.

        """
        if not self._did_handshake:
            await self.do_handshake()
        # If the packet isn't really valid, then openssl can decode it to the empty
        # string (e.g. b/c it's a late-arriving handshake packet, or a duplicate copy of
        # a data packet). Skip over these instead of returning them.
        while True:
            try:
                packet = await self._q.r.receive()
            except trio.EndOfChannel:
                # The sender side is only closed via _set_replaced(), so this
                # path always raises BrokenResourceError; the assert documents
                # that invariant.
                assert self._replaced
                self._check_replaced()
            self._ssl.bio_write(packet)
            cleartext = _read_loop(self._ssl.read)
            if cleartext:
                return cleartext
+
    def set_ciphertext_mtu(self, new_mtu):
        """Tells Trio the largest amount of data that can be sent in a single
        packet to this peer (the path MTU).

        Trio doesn't actually enforce this limit – if you pass a huge packet to `send`,
        then we'll dutifully encrypt it and attempt to send it. But calling this method
        does have two useful effects:

        - If called before the handshake is performed, then Trio will automatically
          fragment handshake messages to fit within the given MTU. It also might
          fragment them even smaller, if it detects signs of packet loss, so setting
          this should never be necessary to make a successful connection. But, the
          packet loss detection only happens after multiple timeouts have expired, so if
          you have reason to believe that a smaller MTU is required, then you can set
          this to skip those timeouts and establish the connection more quickly.

        - It changes the value returned from `get_cleartext_mtu`. So if you have some
          kind of estimate of the network-level MTU, then you can use this to figure out
          how much overhead DTLS will need for hashes/padding/etc., and how much space
          you have left for your application data.

        The MTU here is measuring the largest UDP *payload* you think can be sent, the
        amount of encrypted data that can be handed to the operating system in a single
        call to `send`. It should *not* include IP/UDP headers. Note that OS estimates
        of the MTU often are link-layer MTUs, so you have to subtract off 28 bytes on
        IPv4 and 48 bytes on IPv6 to get the ciphertext MTU.

        By default, Trio assumes an MTU of 1472 bytes on IPv4, and 1452 bytes on IPv6,
        which correspond to the common Ethernet MTU of 1500 bytes after accounting for
        IP/UDP overhead.

        """
        # Remember the value for our own handshake fragmentation, and also
        # tell openssl so it can size its records accordingly.
        self._handshake_mtu = new_mtu
        self._ssl.set_ciphertext_mtu(new_mtu)
+
    def get_cleartext_mtu(self):
        """Returns the largest number of bytes that you can pass in a single call to
        `send` while still fitting within the network-level MTU.

        See `set_ciphertext_mtu` for more details.

        Raises:
          trio.NeedHandshakeError: if the handshake hasn't been performed yet
            (openssl can only compute this value after the handshake).

        """
        if not self._did_handshake:
            raise trio.NeedHandshakeError
        return self._ssl.get_cleartext_mtu()
+
    def statistics(self):
        """Returns an object with statistics about this connection.

        Currently this has only one attribute:

        - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
          incoming packets from this peer that Trio successfully received from the
          network, but then got dropped because the internal channel buffer was full. If
          this is non-zero, then you might want to call ``receive`` more often, or use a
          larger ``incoming_packets_buffer``, or just not worry about it because your
          UDP-based protocol should be able to handle the occasional lost packet, right?

        """
        return DTLSChannelStatistics(self._packets_dropped_in_trio)
+
+
class DTLSEndpoint(metaclass=Final):
    """A DTLS endpoint.

    A single UDP socket can handle arbitrarily many DTLS connections simultaneously,
    acting as a client or server as needed. A `DTLSEndpoint` object holds a UDP socket
    and manages these connections, which are represented as `DTLSChannel` objects.

    Args:
      socket (trio.socket.SocketType): A ``SOCK_DGRAM`` socket. If you want to accept
        incoming connections in server mode, then you should probably bind the socket to
        some known port.
      incoming_packets_buffer (int): Each `DTLSChannel` using this socket has its own
        buffer that holds incoming packets until you call `~DTLSChannel.receive` to read
        them. This lets you adjust the size of this buffer. `~DTLSChannel.statistics`
        lets you check if the buffer has overflowed.

    .. attribute:: socket
                   incoming_packets_buffer

       Both constructor arguments are also exposed as attributes, in case you need to
       access them later.

    """

    def __init__(self, socket, *, incoming_packets_buffer=10):
        # We do this lazily on first construction, so only people who actually use DTLS
        # have to install PyOpenSSL.
        global SSL
        from OpenSSL import SSL

        self.socket = None  # for __del__, in case the next line raises
        if socket.type != trio.socket.SOCK_DGRAM:
            raise ValueError("DTLS requires a SOCK_DGRAM socket")
        self.socket = socket

        self.incoming_packets_buffer = incoming_packets_buffer
        # Saved so __del__ can schedule close() back onto the Trio run loop.
        self._token = trio.lowlevel.current_trio_token()
        # We don't need to track handshaking vs non-handshake connections
        # separately. We only keep one connection per remote address; as soon
        # as a peer provides a valid cookie, we can immediately tear down the
        # old connection.
        # {remote address: DTLSChannel}
        self._streams = weakref.WeakValueDictionary()
        self._listening_context = None
        self._listening_key = None
        self._incoming_connections_q = _Queue(float("inf"))
        self._send_lock = trio.Lock()
        self._closed = False
        self._receive_loop_spawned = False

    def _ensure_receive_loop(self):
        # We have to spawn this lazily, because on Windows it will immediately error out
        # if the socket isn't already bound -- which for clients might not happen until
        # after we send our first packet.
        if not self._receive_loop_spawned:
            trio.lowlevel.spawn_system_task(
                dtls_receive_loop, weakref.ref(self), self.socket
            )
            self._receive_loop_spawned = True

    def __del__(self):
        # Do nothing if this object was never fully constructed
        if self.socket is None:
            return
        # Close the socket in Trio context (if our Trio context still exists), so that
        # the background task gets notified about the closure and can exit.
        if not self._closed:
            try:
                self._token.run_sync_soon(self.close)
            except RuntimeError:
                # The Trio run that created us has already exited.
                pass
            # Do this last, because it might raise an exception
            warnings.warn(
                f"unclosed DTLS endpoint {self!r}", ResourceWarning, source=self
            )

    def close(self):
        """Close this socket, and all associated DTLS connections.

        This object can also be used as a context manager.

        """
        self._closed = True
        self.socket.close()
        # Copy the values, since each stream.close() mutates _streams.
        for stream in list(self._streams.values()):
            stream.close()
        self._incoming_connections_q.s.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def _check_closed(self):
        # Fail fast on a closed endpoint; used by serve() and connect().
        if self._closed:
            raise trio.ClosedResourceError

    async def serve(
        self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED
    ):
        """Listen for incoming connections, and spawn a handler for each using an
        internal nursery.

        Similar to `~trio.serve_tcp`, this function never returns until cancelled, or
        the `DTLSEndpoint` is closed and all handlers have exited.

        Usage commonly looks like::

            async def handler(dtls_channel):
                ...

            async with trio.open_nursery() as nursery:
                await nursery.start(dtls_endpoint.serve, ssl_context, handler)
                # ... do other things here ...

        The ``dtls_channel`` passed into the handler function has already performed the
        "cookie exchange" part of the DTLS handshake, so the peer address is
        trustworthy. But the actual cryptographic handshake doesn't happen until you
        start using it, giving you a chance for any last minute configuration, and the
        option to catch and handle handshake errors.

        Args:
          ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
            incoming connections.
          async_fn: The handler function that will be invoked for each incoming
            connection.

        """
        self._check_closed()
        if self._listening_context is not None:
            raise trio.BusyResourceError("another task is already listening")
        try:
            self.socket.getsockname()
        except OSError:
            raise RuntimeError("DTLS socket must be bound before it can serve")
        self._ensure_receive_loop()
        # We do cookie verification ourselves, so tell OpenSSL not to worry about it.
        # (See also _inject_client_hello_untrusted.)
        ssl_context.set_cookie_verify_callback(lambda *_: True)
        try:
            self._listening_context = ssl_context
            task_status.started()

            async def handler_wrapper(stream):
                # Guarantee the channel is closed when its handler exits.
                with stream:
                    await async_fn(stream, *args)

            async with trio.open_nursery() as nursery:
                async for stream in self._incoming_connections_q.r:  # pragma: no branch
                    nursery.start_soon(handler_wrapper, stream)
        finally:
            self._listening_context = None

    def connect(self, address, ssl_context):
        """Initiate an outgoing DTLS connection.

        Notice that this is a synchronous method. That's because it doesn't actually
        initiate any I/O – it just sets up a `DTLSChannel` object. The actual handshake
        doesn't occur until you start using the `DTLSChannel`. This gives you a chance
        to do further configuration first, like setting MTU etc.

        Args:
          address: The address to connect to. Usually a (host, port) tuple, like
            ``("127.0.0.1", 12345)``.
          ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
            this connection.

        Returns:
          DTLSChannel

        """
        # it would be nice if we could detect when 'address' is our own endpoint (a
        # loopback connection), because that can't work
        # but I don't see how to do it reliably
        self._check_closed()
        channel = DTLSChannel._create(self, address, ssl_context)
        channel._ssl.set_connect_state()
        old_channel = self._streams.get(address)
        if old_channel is not None:
            # A second connect() to the same peer replaces the old channel.
            old_channel._set_replaced()
        self._streams[address] = channel
        return channel
diff --git a/trio/_socket.py b/trio/_socket.py
index 886f5614f6..bcff1ee9e7 100644
--- a/trio/_socket.py
+++ b/trio/_socket.py
@@ -349,6 +349,83 @@ async def wrapper(self, *args, **kwargs):
return wrapper
+# Helpers to work with the (hostname, port) language that Python uses for socket
+# addresses everywhere. Split out into a standalone function so it can be reused by
+# FakeNet.
+
+# Take an address in Python's representation, and returns a new address in
+# the same representation, but with names resolved to numbers,
+# etc.
+#
+# local=True means that the address is being used with bind() or similar
+# local=False means that the address is being used with connect() or sendto() or
+# similar.
+#
+# NOTE: this function does not always checkpoint
async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, local):
    """Resolve a Python-style socket address to a fully-numeric form.

    Args:
      type, family, proto: the socket's type/family/protocol, passed through
        to getaddrinfo.
      ipv6_v6only (bool): whether the socket has IPV6_V6ONLY set; if not,
        AI_V4MAPPED is used so IPv4 names resolve on an IPv6 socket.
      address: the address in Python's (host, port[, flowinfo, scopeid])
        representation (or a path-like for AF_UNIX).
      local (bool): True if the address will be used with bind(), False for
        connect()/sendto() -- controls AI_PASSIVE.

    NOTE: this function does not always checkpoint.
    """
    # Do some pre-checking (or exit early for non-IP sockets)
    if family == _stdlib_socket.AF_INET:
        if not isinstance(address, tuple) or not len(address) == 2:
            raise ValueError("address should be a (host, port) tuple")
    elif family == _stdlib_socket.AF_INET6:
        if not isinstance(address, tuple) or not 2 <= len(address) <= 4:
            raise ValueError(
                "address should be a (host, port, [flowinfo, [scopeid]]) tuple"
            )
    elif family == _stdlib_socket.AF_UNIX:
        # unwrap path-likes
        return os.fspath(address)
    else:
        return address

    # -- From here on we know we have IPv4 or IPV6 --
    host, port, *_ = address
    # Fast path for the simple case: already-resolved IP address,
    # already-resolved port. This is particularly important for UDP, since
    # every sendto call goes through here.
    if isinstance(port, int):
        try:
            _stdlib_socket.inet_pton(family, address[0])
        except (OSError, TypeError):
            pass
        else:
            return address
    # Special cases to match the stdlib, see gh-277
    if host == "":
        host = None
    # bugfix: this previously tested host == "" again (dead code, since the
    # branch above already mapped "" to None); the stdlib's special hostname
    # for broadcast is "<broadcast>"
    if host == "<broadcast>":
        host = "255.255.255.255"
    flags = 0
    if local:
        flags |= _stdlib_socket.AI_PASSIVE
    # Since we always pass in an explicit family here, AI_ADDRCONFIG
    # doesn't add any value -- if we have no ipv6 connectivity and are
    # working with an ipv6 socket, then things will break soon enough! And
    # if we do enable it, then it makes it impossible to even run tests
    # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has
    # no ipv6.
    # flags |= AI_ADDRCONFIG
    if family == _stdlib_socket.AF_INET6 and not ipv6_v6only:
        flags |= _stdlib_socket.AI_V4MAPPED
    gai_res = await getaddrinfo(host, port, family, type, proto, flags)
    # AFAICT from the spec it's not possible for getaddrinfo to return an
    # empty list.
    assert len(gai_res) >= 1
    # Address is the last item in the first entry
    (*_, normed), *_ = gai_res
    # The above ignored any flowid and scopeid in the passed-in address,
    # so restore them if present:
    if family == _stdlib_socket.AF_INET6:
        normed = list(normed)
        assert len(normed) == 4
        if len(address) >= 3:
            normed[2] = address[2]
        if len(address) >= 4:
            normed[3] = address[3]
        normed = tuple(normed)
    return normed
+
+
class SocketType:
def __init__(self):
raise TypeError(
@@ -444,7 +521,7 @@ def close(self):
self._sock.close()
async def bind(self, address):
- address = await self._resolve_local_address_nocp(address)
+ address = await self._resolve_address_nocp(address, local=True)
if (
hasattr(_stdlib_socket, "AF_UNIX")
and self.family == _stdlib_socket.AF_UNIX
@@ -480,89 +557,21 @@ def is_readable(self):
async def wait_writable(self):
await _core.wait_writable(self._sock)
- ################################################################
- # Address handling
- ################################################################
-
- # Take an address in Python's representation, and returns a new address in
- # the same representation, but with names resolved to numbers,
- # etc.
- #
- # NOTE: this function does not always checkpoint
- async def _resolve_address_nocp(self, address, flags):
- # Do some pre-checking (or exit early for non-IP sockets)
- if self._sock.family == _stdlib_socket.AF_INET:
- if not isinstance(address, tuple) or not len(address) == 2:
- raise ValueError("address should be a (host, port) tuple")
- elif self._sock.family == _stdlib_socket.AF_INET6:
- if not isinstance(address, tuple) or not 2 <= len(address) <= 4:
- raise ValueError(
- "address should be a (host, port, [flowinfo, [scopeid]]) tuple"
- )
- elif self._sock.family == _stdlib_socket.AF_UNIX:
- # unwrap path-likes
- return os.fspath(address)
+ async def _resolve_address_nocp(self, address, *, local):
+ if self.family == _stdlib_socket.AF_INET6:
+ ipv6_v6only = self._sock.getsockopt(
+ IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY
+ )
else:
- return address
-
- # -- From here on we know we have IPv4 or IPV6 --
- host, port, *_ = address
- # Fast path for the simple case: already-resolved IP address,
- # already-resolved port. This is particularly important for UDP, since
- # every sendto call goes through here.
- if isinstance(port, int):
- try:
- _stdlib_socket.inet_pton(self._sock.family, address[0])
- except (OSError, TypeError):
- pass
- else:
- return address
- # Special cases to match the stdlib, see gh-277
- if host == "":
- host = None
- if host == "":
- host = "255.255.255.255"
- # Since we always pass in an explicit family here, AI_ADDRCONFIG
- # doesn't add any value -- if we have no ipv6 connectivity and are
- # working with an ipv6 socket, then things will break soon enough! And
- # if we do enable it, then it makes it impossible to even run tests
- # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has
- # no ipv6.
- # flags |= AI_ADDRCONFIG
- if self._sock.family == _stdlib_socket.AF_INET6:
- if not self._sock.getsockopt(IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY):
- flags |= _stdlib_socket.AI_V4MAPPED
- gai_res = await getaddrinfo(
- host, port, self._sock.family, self.type, self._sock.proto, flags
+ ipv6_v6only = False
+ return await _resolve_address_nocp(
+ self.type,
+ self.family,
+ self.proto,
+ ipv6_v6only=ipv6_v6only,
+ address=address,
+ local=local,
)
- # AFAICT from the spec it's not possible for getaddrinfo to return an
- # empty list.
- assert len(gai_res) >= 1
- # Address is the last item in the first entry
- (*_, normed), *_ = gai_res
- # The above ignored any flowid and scopeid in the passed-in address,
- # so restore them if present:
- if self._sock.family == _stdlib_socket.AF_INET6:
- normed = list(normed)
- assert len(normed) == 4
- if len(address) >= 3:
- normed[2] = address[2]
- if len(address) >= 4:
- normed[3] = address[3]
- normed = tuple(normed)
- return normed
-
- # Returns something appropriate to pass to bind()
- #
- # NOTE: this function does not always checkpoint
- async def _resolve_local_address_nocp(self, address):
- return await self._resolve_address_nocp(address, _stdlib_socket.AI_PASSIVE)
-
- # Returns something appropriate to pass to connect()/sendto()/sendmsg()
- #
- # NOTE: this function does not always checkpoint
- async def _resolve_remote_address_nocp(self, address):
- return await self._resolve_address_nocp(address, 0)
async def _nonblocking_helper(self, fn, args, kwargs, wait_fn):
# We have to reconcile two conflicting goals:
@@ -617,7 +626,7 @@ async def connect(self, address):
# notification. This means it isn't really cancellable... we close the
# socket if cancelled, to avoid confusion.
try:
- address = await self._resolve_remote_address_nocp(address)
+ address = await self._resolve_address_nocp(address, local=False)
async with _try_sync():
# An interesting puzzle: can a non-blocking connect() return EINTR
# (= raise InterruptedError)? PEP 475 specifically left this as
@@ -741,7 +750,7 @@ async def sendto(self, *args):
# args is: data[, flags], address)
# and kwargs are not accepted
args = list(args)
- args[-1] = await self._resolve_remote_address_nocp(args[-1])
+ args[-1] = await self._resolve_address_nocp(args[-1], local=False)
return await self._nonblocking_helper(
_stdlib_socket.socket.sendto, args, {}, _core.wait_writable
)
@@ -766,7 +775,7 @@ async def sendmsg(self, *args):
# and kwargs are not accepted
if len(args) == 4 and args[-1] is not None:
args = list(args)
- args[-1] = await self._resolve_remote_address_nocp(args[-1])
+ args[-1] = await self._resolve_address_nocp(args[-1], local=False)
return await self._nonblocking_helper(
_stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable
)
diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py
new file mode 100644
index 0000000000..f0ea927734
--- /dev/null
+++ b/trio/testing/_fake_net.py
@@ -0,0 +1,400 @@
+# This should eventually be cleaned up and become public, but for right now I'm just
+# implementing enough to test DTLS.
+
+# TODO:
+# - user-defined routers
+# - TCP
+# - UDP broadcast
+
import enum
import errno
import ipaddress
import os
from collections import deque
from contextlib import contextmanager
from typing import Dict, List, Optional, Union

import attr
import trio

from trio._util import Final, NoPublicConstructor
+
+IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
+
+
def _family_for(ip: IPAddress) -> int:
    """Return the trio.socket address family matching *ip*'s IP version."""
    family_by_type = [
        (ipaddress.IPv4Address, trio.socket.AF_INET),
        (ipaddress.IPv6Address, trio.socket.AF_INET6),
    ]
    for address_cls, family in family_by_type:
        if isinstance(ip, address_cls):
            return family
    assert False  # pragma: no cover
+
+
def _wildcard_ip_for(family: int) -> IPAddress:
    """Return the unspecified ("any") address for the given socket family."""
    wildcards = {
        trio.socket.AF_INET: "0.0.0.0",
        trio.socket.AF_INET6: "::",
    }
    assert family in wildcards
    return ipaddress.ip_address(wildcards[family])
+
+
def _localhost_ip_for(family: int) -> IPAddress:
    """Return the loopback address for the given socket family."""
    loopbacks = {
        trio.socket.AF_INET: "127.0.0.1",
        trio.socket.AF_INET6: "::1",
    }
    assert family in loopbacks
    return ipaddress.ip_address(loopbacks[family])
+
+
+def _fake_err(code):
+ raise OSError(code, os.strerror(code))
+
+
+def _scatter(data, buffers):
+ written = 0
+ for buf in buffers:
+ next_piece = data[written : written + len(buf)]
+ with memoryview(buf) as mbuf:
+ mbuf[: len(next_piece)] = next_piece
+ written += len(next_piece)
+ if written == len(data):
+ break
+ return written
+
+
@attr.frozen
class UDPEndpoint:
    # One side of a fake UDP flow: an IP address plus a port number.
    ip: IPAddress
    port: int

    def as_python_sockaddr(self):
        # Convert to the (host, port[, flowinfo, scopeid]) tuple convention
        # used by the stdlib socket module.
        sockaddr = (self.ip.compressed, self.port)
        if isinstance(self.ip, ipaddress.IPv6Address):
            sockaddr += (0, 0)
        return sockaddr

    @classmethod
    def from_python_sockaddr(cls, sockaddr):
        # Inverse of as_python_sockaddr; any extra IPv6 fields are ignored.
        ip, port = sockaddr[:2]
        return cls(ip=ipaddress.ip_address(ip), port=port)
+
+
@attr.frozen
class UDPBinding:
    # The local endpoint a FakeSocket is bound to; used as the key in
    # FakeNet._bound.
    local: UDPEndpoint
+
+
@attr.frozen
class UDPPacket:
    # A single fake UDP datagram in flight.
    source: UDPEndpoint
    destination: UDPEndpoint
    # repr shows the payload as hex, to keep debug output readable
    payload: bytes = attr.ib(repr=lambda p: p.hex())

    def reply(self, payload):
        # Build a packet going back the other way, carrying a new payload.
        return UDPPacket(
            source=self.destination, destination=self.source, payload=payload
        )
+
+
@attr.frozen
class FakeSocketFactory(trio.abc.SocketFactory):
    # Plugs into trio.socket.set_custom_socket_factory so that
    # trio.socket.socket() calls produce FakeSockets tied to this FakeNet.
    fake_net: "FakeNet"

    def socket(self, family: int, type: int, proto: int) -> "FakeSocket":
        return FakeSocket._create(self.fake_net, family, type, proto)
+
+
@attr.frozen
class FakeHostnameResolver(trio.abc.HostnameResolver):
    # Placeholder resolver installed alongside FakeSocketFactory; DNS is not
    # implemented yet, so any name lookup fails loudly instead of hitting the
    # real network.
    fake_net: "FakeNet"

    async def getaddrinfo(
        self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0
    ):
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")

    async def getnameinfo(self, sockaddr, flags: int):
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")
+
+
class FakeNet(metaclass=Final):
    """An in-memory fake IP network for tests.

    While enabled, sockets created through trio.socket exchange packets
    through this object instead of touching the real network.
    """

    def __init__(self):
        # When we need to pick an arbitrary unique ip address/port, use these:
        self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts()
        # bugfix: this line previously re-assigned _auto_ipv4_iter, clobbering
        # the IPv4 iterator above and leaving no IPv6 iterator at all
        self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts()
        self._auto_port_iter = iter(range(50000, 65535))

        # Every currently-bound socket, keyed by its local binding.
        self._bound: Dict[UDPBinding, FakeSocket] = {}

        # Optional hook: when set, outgoing packets are handed to it instead
        # of being delivered directly (it can drop them, or call
        # deliver_packet itself).
        self.route_packet = None

    def _bind(self, binding: UDPBinding, socket: "FakeSocket") -> None:
        # Register a socket under its binding, mimicking EADDRINUSE on clash.
        if binding in self._bound:
            _fake_err(errno.EADDRINUSE)
        self._bound[binding] = socket

    def enable(self) -> None:
        """Install this FakeNet as the source of sockets and DNS for this Trio run."""
        trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
        trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))

    def send_packet(self, packet) -> None:
        # Entry point for outgoing packets: either hand off to the
        # user-supplied router, or deliver directly.
        if self.route_packet is None:
            self.deliver_packet(packet)
        else:
            self.route_packet(packet)

    def deliver_packet(self, packet) -> None:
        # Hand the packet to whichever socket is bound to its destination.
        binding = UDPBinding(local=packet.destination)
        if binding in self._bound:
            self._bound[binding]._deliver_packet(packet)
        else:
            # No valid destination, so drop it
            pass
+
+
+class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
    def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int):
        self._fake_net = fake_net

        # Match the stdlib's defaults when 0/unspecified is passed.
        if not family:
            family = trio.socket.AF_INET
        if not type:
            type = trio.socket.SOCK_STREAM

        if family not in (trio.socket.AF_INET, trio.socket.AF_INET6):
            raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}")
        if type != trio.socket.SOCK_DGRAM:
            raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")

        self.family = family
        self.type = type
        self.proto = proto

        self._closed = False

        # Unbounded in-memory queue of UDPPackets delivered to this socket.
        self._packet_sender, self._packet_receiver = trio.open_memory_channel(
            float("inf")
        )

        # This is the source-of-truth for what port etc. this socket is bound to
        self._binding: Optional[UDPBinding] = None
+
    def _check_closed(self):
        # Mimic the EBADF you'd get from the OS for a closed file descriptor.
        if self._closed:
            _fake_err(errno.EBADF)
+
+ def close(self):
+ # breakpoint()
+ if self._closed:
+ return
+ self._closed = True
+ if self._binding is not None:
+ del self._fake_net._bound[self._binding]
+ self._packet_receiver.close()
+
    async def _resolve_address_nocp(self, address, *, local):
        # Delegate to trio's shared resolver helper.
        # NOTE(review): this passes ipv6_v6only=False even though setsockopt()
        # below claims FakeNet always has IPV6_V6ONLY=True -- confirm which is
        # intended.
        return await trio._socket._resolve_address_nocp(
            self.type,
            self.family,
            self.proto,
            address=address,
            ipv6_v6only=False,
            local=local,
        )
+
    def _deliver_packet(self, packet: UDPPacket):
        # Called by FakeNet when a packet addressed to our binding arrives.
        try:
            self._packet_sender.send_nowait(packet)
        except trio.BrokenResourceError:
            # sending to a closed socket -- UDP packets get dropped
            pass
+
+ ################################################################
+ # Actual IO operation implementations
+ ################################################################
+
+ async def bind(self, addr):
+ self._check_closed()
+ if self._binding is not None:
+ _fake_error(errno.EINVAL)
+ await trio.lowlevel.checkpoint()
+ ip_str, port = await self._resolve_address_nocp(addr, local=True)
+ ip = ipaddress.ip_address(ip_str)
+ assert _family_for(ip) == self.family
+ # We convert binds to INET_ANY into binds to localhost
+ if ip == ipaddress.ip_address("0.0.0.0"):
+ ip = ipaddress.ip_address("127.0.0.1")
+ elif ip == ipaddress.ip_address("::"):
+ ip = ipaddress.ip_address("::1")
+ if port == 0:
+ port = next(self._fake_net._auto_port_iter)
+ binding = UDPBinding(local=UDPEndpoint(ip, port))
+ self._fake_net._bind(binding, self)
+ self._binding = binding
+
    async def connect(self, peer):
        # Connected-UDP semantics aren't implemented yet; this also means a
        # remote endpoint is never recorded for getpeername().
        raise NotImplementedError("FakeNet does not (yet) support connected sockets")
+
    async def sendmsg(self, *args):
        """Send one datagram; accepts the stdlib sendmsg() positional forms.

        Only plain sends are implemented: ancillary data and nonzero flags
        are rejected. An unbound socket is implicitly bound to an
        auto-assigned localhost port first, like the OS would.
        """
        self._check_closed()
        ancdata = []
        flags = 0
        address = None
        # Positional overloads. NOTE(review): the two-argument form here is
        # (buffers, address), whereas the stdlib's second positional argument
        # is ancdata -- confirm this divergence is intentional.
        if len(args) == 1:
            (buffers,) = args
        elif len(args) == 2:
            buffers, address = args
        elif len(args) == 3:
            buffers, flags, address = args
        elif len(args) == 4:
            buffers, ancdata, flags, address = args
        else:
            raise TypeError("wrong number of arguments")

        await trio.lowlevel.checkpoint()

        if address is not None:
            address = await self._resolve_address_nocp(address, local=False)
        if ancdata:
            raise NotImplementedError("FakeNet doesn't support ancillary data")
        if flags:
            raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}")

        if address is None:
            # A datagram send with no destination needs a connected socket,
            # which FakeNet doesn't support.
            _fake_err(errno.ENOTCONN)

        destination = UDPEndpoint.from_python_sockaddr(address)

        if self._binding is None:
            # Implicit bind on first send, mirroring OS behavior.
            await self.bind((_wildcard_ip_for(self.family).compressed, 0))

        payload = b"".join(buffers)

        packet = UDPPacket(
            source=self._binding.local,
            destination=destination,
            payload=payload,
        )

        self._fake_net.send_packet(packet)

        return len(payload)
+
    async def recvmsg_into(self, buffers, ancbufsize=0, flags=0):
        """Receive one datagram, scattered into *buffers*.

        Returns ``(nbytes, ancdata, msg_flags, address)`` like the stdlib;
        ``MSG_TRUNC`` is set in ``msg_flags`` if the datagram didn't fit.
        """
        if ancbufsize != 0:
            raise NotImplementedError("FakeNet doesn't support ancillary data")
        if flags != 0:
            raise NotImplementedError("FakeNet doesn't support any recv flags")

        self._check_closed()

        ancdata = []
        msg_flags = 0

        # Blocks until a packet arrives (or raises once the socket closes).
        packet = await self._packet_receiver.receive()
        address = packet.source.as_python_sockaddr()
        written = _scatter(packet.payload, buffers)
        if written < len(packet.payload):
            # The tail of an oversized datagram is discarded, like real UDP.
            msg_flags |= trio.socket.MSG_TRUNC
        return written, ancdata, msg_flags, address
+
+ ################################################################
+ # Simple state query stuff
+ ################################################################
+
    def getsockname(self):
        # Bound sockets report their binding; unbound ones report the
        # family's wildcard address with port 0, like the stdlib.
        self._check_closed()
        if self._binding is not None:
            return self._binding.local.as_python_sockaddr()
        elif self.family == trio.socket.AF_INET:
            return ("0.0.0.0", 0)
        else:
            assert self.family == trio.socket.AF_INET6
            return ("::", 0)
+
+ def getpeername(self):
+ self._check_closed()
+ if self._binding is not None:
+ if self._binding.remote is not None:
+ return self._binding.remote.as_python_sockaddr()
+ _fake_err(errno.ENOTCONN)
+
    def getsockopt(self, level, item):
        # No options are readable yet; every query fails loudly.
        self._check_closed()
        raise OSError(f"FakeNet doesn't implement getsockopt({level}, {item})")
+
+ def setsockopt(self, level, item, value):
+ self._check_closed()
+
+ if (level, item) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY):
+ if not value:
+ raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")
+
+ raise OSError(f"FakeNet doesn't implement setsockopt({level}, {item}, ...)")
+
+ ################################################################
+ # Various boilerplate and trivial stubs
+ ################################################################
+
    def __enter__(self):
        # Context-manager support: the socket is closed on exit.
        return self

    def __exit__(self, *exc_info):
        self.close()
+
    async def send(self, data, flags=0):
        # send() needs a connected socket; passing address=None makes
        # sendmsg raise ENOTCONN, since FakeNet can't connect yet.
        return await self.sendto(data, flags, None)

    async def sendto(self, *args):
        # Accepts (data, address) or (data, flags, address), like the stdlib.
        if len(args) == 2:
            data, address = args
            flags = 0
        elif len(args) == 3:
            data, flags, address = args
        else:
            raise TypeError("wrong number of arguments")
        return await self.sendmsg([data], [], flags, address)
+
    async def recv(self, bufsize, flags=0):
        # All recv* variants funnel down to recvmsg_into.
        data, address = await self.recvfrom(bufsize, flags)
        return data

    async def recv_into(self, buf, nbytes=0, flags=0):
        got_bytes, address = await self.recvfrom_into(buf, nbytes, flags)
        return got_bytes

    async def recvfrom(self, bufsize, flags=0):
        # NOTE(review): flags is passed into recvmsg's ancbufsize parameter
        # here; harmless since both must be 0, but worth confirming.
        data, ancdata, msg_flags, address = await self.recvmsg(bufsize, flags)
        return data, address

    async def recvfrom_into(self, buf, nbytes=0, flags=0):
        # Partial reads into a larger buffer aren't supported.
        if nbytes != 0 and nbytes != len(buf):
            raise NotImplementedError("partial recvfrom_into")
        got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into(
            [buf], 0, flags
        )
        return got_nbytes, address

    async def recvmsg(self, bufsize, ancbufsize=0, flags=0):
        # Allocate a buffer and trim it to the bytes actually received.
        buf = bytearray(bufsize)
        got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into(
            [buf], ancbufsize, flags
        )
        return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address)
+
    def fileno(self):
        # Fake sockets have no OS-level handle behind them.
        raise NotImplementedError("can't get fileno() for FakeNet sockets")

    def detach(self):
        raise NotImplementedError("can't detach() a FakeNet socket")

    def get_inheritable(self):
        return False

    def set_inheritable(self, inheritable):
        # Only the no-op (non-inheritable) setting is supported.
        if inheritable:
            raise NotImplementedError("FakeNet can't make inheritable sockets")

    def share(self, process_id):
        raise NotImplementedError("FakeNet can't share sockets")
diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py
new file mode 100644
index 0000000000..8968d9a601
--- /dev/null
+++ b/trio/tests/test_dtls.py
@@ -0,0 +1,867 @@
+import pytest
+import trio
+import trio.testing
+from trio import DTLSEndpoint
+import random
+import attr
+from async_generator import asynccontextmanager
+from itertools import count
+
+import trustme
+from OpenSSL import SSL
+
+from trio.testing._fake_net import FakeNet
+from .._core.tests.tutil import slow, binds_ipv6, gc_collect_harder
+
+ca = trustme.CA()
+server_cert = ca.issue_cert("example.com")
+
+server_ctx = SSL.Context(SSL.DTLS_METHOD)
+server_cert.configure_cert(server_ctx)
+
+client_ctx = SSL.Context(SSL.DTLS_METHOD)
+ca.configure_trust(client_ctx)
+
+
+parametrize_ipv6 = pytest.mark.parametrize(
+ "ipv6", [False, pytest.param(True, marks=binds_ipv6)], ids=["ipv4", "ipv6"]
+)
+
+
+def endpoint(**kwargs):
+ ipv6 = kwargs.pop("ipv6", False)
+ if ipv6:
+ family = trio.socket.AF_INET6
+ else:
+ family = trio.socket.AF_INET
+ sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family)
+ return DTLSEndpoint(sock, **kwargs)
+
+
+@asynccontextmanager
+async def dtls_echo_server(*, autocancel=True, mtu=None, ipv6=False):
+ with endpoint(ipv6=ipv6) as server:
+ if ipv6:
+ localhost = "::1"
+ else:
+ localhost = "127.0.0.1"
+ await server.socket.bind((localhost, 0))
+ async with trio.open_nursery() as nursery:
+
+ async def echo_handler(dtls_channel):
+ print(
+ f"echo handler started: "
+ f"server {dtls_channel.endpoint.socket.getsockname()} "
+ f"client {dtls_channel.peer_address}"
+ )
+ if mtu is not None:
+ dtls_channel.set_ciphertext_mtu(mtu)
+ try:
+ print("server starting do_handshake")
+ await dtls_channel.do_handshake()
+ print("server finished do_handshake")
+ async for packet in dtls_channel:
+ print(f"echoing {packet} -> {dtls_channel.peer_address}")
+ await dtls_channel.send(packet)
+ except trio.BrokenResourceError: # pragma: no cover
+ print("echo handler channel broken")
+
+ await nursery.start(server.serve, server_ctx, echo_handler)
+
+ yield server, server.socket.getsockname()
+
+ if autocancel:
+ nursery.cancel_scope.cancel()
+
+
+@parametrize_ipv6
+async def test_smoke(ipv6):
+ async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address):
+ with endpoint(ipv6=ipv6) as client_endpoint:
+ client_channel = client_endpoint.connect(address, client_ctx)
+ with pytest.raises(trio.NeedHandshakeError):
+ client_channel.get_cleartext_mtu()
+
+ await client_channel.do_handshake()
+ await client_channel.send(b"hello")
+ assert await client_channel.receive() == b"hello"
+ await client_channel.send(b"goodbye")
+ assert await client_channel.receive() == b"goodbye"
+
+ with pytest.raises(ValueError):
+ await client_channel.send(b"")
+
+ client_channel.set_ciphertext_mtu(1234)
+ cleartext_mtu_1234 = client_channel.get_cleartext_mtu()
+ client_channel.set_ciphertext_mtu(4321)
+ assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234
+ client_channel.set_ciphertext_mtu(1234)
+ assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234
+
+
+@slow
+async def test_handshake_over_terrible_network(autojump_clock):
+ HANDSHAKES = 1000
+ r = random.Random(0)
+ fn = FakeNet()
+ fn.enable()
+
+ async with dtls_echo_server() as (_, address):
+ async with trio.open_nursery() as nursery:
+
+ async def route_packet(packet):
+ while True:
+ op = r.choices(
+ ["deliver", "drop", "dupe", "delay"],
+ weights=[0.7, 0.1, 0.1, 0.1],
+ )[0]
+ print(f"{packet.source} -> {packet.destination}: {op}")
+ if op == "drop":
+ return
+ elif op == "dupe":
+ fn.send_packet(packet)
+ elif op == "delay":
+ await trio.sleep(r.random() * 3)
+ # I wanted to test random packet corruption too, but it turns out
+ # openssl has a bug in the following scenario:
+ #
+ # - client sends ClientHello
+ # - server sends HelloVerifyRequest with cookie -- but cookie is
+ # invalid b/c either the ClientHello or HelloVerifyRequest was
+ # corrupted
+ # - client re-sends ClientHello with invalid cookie
+ # - server replies with new HelloVerifyRequest and correct cookie
+ #
+ # At this point, the client *should* switch to the new, valid
+ # cookie. But OpenSSL doesn't; it stubbornly insists on re-sending
+ # the original, invalid cookie over and over. In theory we could
+ # work around this by detecting cookie changes and starting over
+ # with a whole new SSL object, but (a) it doesn't seem worth it, (b)
+ # when I tried then I ran into another issue where OpenSSL got stuck
+ # in an infinite loop sending alerts over and over, which I didn't
+ # dig into because see (a).
+ #
+ # elif op == "distort":
+ # payload = bytearray(packet.payload)
+ # payload[r.randrange(len(payload))] ^= 1 << r.randrange(8)
+ # packet = attr.evolve(packet, payload=payload)
+ else:
+ assert op == "deliver"
+ print(
+ f"{packet.source} -> {packet.destination}: delivered {packet.payload.hex()}"
+ )
+ fn.deliver_packet(packet)
+ break
+
+ def route_packet_wrapper(packet):
+ try:
+ nursery.start_soon(route_packet, packet)
+ except RuntimeError: # pragma: no cover
+ # We're exiting the nursery, so any remaining packets can just get
+ # dropped
+ pass
+
+ fn.route_packet = route_packet_wrapper
+
+ for i in range(HANDSHAKES):
+ print("#" * 80)
+ print("#" * 80)
+ print("#" * 80)
+ with endpoint() as client_endpoint:
+ client = client_endpoint.connect(address, client_ctx)
+ print("client starting do_handshake")
+ await client.do_handshake()
+ print("client finished do_handshake")
+ msg = str(i).encode()
+ # Make multiple attempts to send data, because the network might
+ # drop it
+ while True:
+ with trio.move_on_after(10) as cscope:
+ await client.send(msg)
+ assert await client.receive() == msg
+ if not cscope.cancelled_caught:
+ break
+
+
+async def test_implicit_handshake():
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ client = client_endpoint.connect(address, client_ctx)
+
+ # Implicit handshake
+ await client.send(b"xyz")
+ assert await client.receive() == b"xyz"
+
+
+async def test_full_duplex():
+ # Tests simultaneous send/receive, and also multiple methods implicitly invoking
+ # do_handshake simultaneously.
+ with endpoint() as server_endpoint, endpoint() as client_endpoint:
+ await server_endpoint.socket.bind(("127.0.0.1", 0))
+ async with trio.open_nursery() as server_nursery:
+
+ async def handler(channel):
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(channel.send, b"from server")
+ nursery.start_soon(channel.receive)
+
+ await server_nursery.start(server_endpoint.serve, server_ctx, handler)
+
+ client = client_endpoint.connect(
+ server_endpoint.socket.getsockname(), client_ctx
+ )
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(client.send, b"from client")
+ nursery.start_soon(client.receive)
+
+ server_nursery.cancel_scope.cancel()
+
+
+async def test_channel_closing():
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ client = client_endpoint.connect(address, client_ctx)
+ await client.do_handshake()
+ client.close()
+
+ with pytest.raises(trio.ClosedResourceError):
+ await client.send(b"abc")
+ with pytest.raises(trio.ClosedResourceError):
+ await client.receive()
+
+ # close is idempotent
+ client.close()
+ # can also aclose
+ await client.aclose()
+
+
+async def test_serve_exits_cleanly_on_close():
+ async with dtls_echo_server(autocancel=False) as (server_endpoint, address):
+ server_endpoint.close()
+ # Testing that the nursery exits even without being cancelled
+ # close is idempotent
+ server_endpoint.close()
+
+
+async def test_client_multiplex():
+ async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2):
+ with endpoint() as client_endpoint:
+ client1 = client_endpoint.connect(address1, client_ctx)
+ client2 = client_endpoint.connect(address2, client_ctx)
+
+ await client1.send(b"abc")
+ await client2.send(b"xyz")
+ assert await client2.receive() == b"xyz"
+ assert await client1.receive() == b"abc"
+
+ client_endpoint.close()
+
+ with pytest.raises(trio.ClosedResourceError):
+ await client1.send(b"xxx")
+ with pytest.raises(trio.ClosedResourceError):
+ await client2.receive()
+ with pytest.raises(trio.ClosedResourceError):
+ client_endpoint.connect(address1, client_ctx)
+
+ async with trio.open_nursery() as nursery:
+ with pytest.raises(trio.ClosedResourceError):
+
+ async def null_handler(_): # pragma: no cover
+ pass
+
+ await nursery.start(client_endpoint.serve, server_ctx, null_handler)
+
+
+async def test_dtls_over_dgram_only():
+ with trio.socket.socket() as s:
+ with pytest.raises(ValueError):
+ DTLSEndpoint(s)
+
+
+async def test_double_serve():
+ async def null_handler(_): # pragma: no cover
+ pass
+
+ with endpoint() as server_endpoint:
+ await server_endpoint.socket.bind(("127.0.0.1", 0))
+ async with trio.open_nursery() as nursery:
+ await nursery.start(server_endpoint.serve, server_ctx, null_handler)
+ with pytest.raises(trio.BusyResourceError):
+ await nursery.start(server_endpoint.serve, server_ctx, null_handler)
+
+ nursery.cancel_scope.cancel()
+
+ async with trio.open_nursery() as nursery:
+ await nursery.start(server_endpoint.serve, server_ctx, null_handler)
+ nursery.cancel_scope.cancel()
+
+
+async def test_connect_to_non_server(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+ with endpoint() as client1, endpoint() as client2:
+ await client1.socket.bind(("127.0.0.1", 0))
+ # This should just time out
+ with trio.move_on_after(100) as cscope:
+ channel = client2.connect(client1.socket.getsockname(), client_ctx)
+ await channel.do_handshake()
+ assert cscope.cancelled_caught
+
+
+async def test_incoming_buffer_overflow(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+ for buffer_size in [10, 20]:
+ async with dtls_echo_server() as (_, address):
+ with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint:
+ assert client_endpoint.incoming_packets_buffer == buffer_size
+ client = client_endpoint.connect(address, client_ctx)
+ for i in range(buffer_size + 15):
+ await client.send(str(i).encode())
+ await trio.sleep(1)
+ stats = client.statistics()
+ assert stats.incoming_packets_dropped_in_trio == 15
+ for i in range(buffer_size):
+ assert await client.receive() == str(i).encode()
+ await client.send(b"buffer clear now")
+ assert await client.receive() == b"buffer clear now"
+
+
+async def test_server_socket_doesnt_crash_on_garbage(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ from trio._dtls import (
+ Record,
+ encode_record,
+ HandshakeFragment,
+ encode_handshake_fragment,
+ ContentType,
+ HandshakeType,
+ ProtocolVersion,
+ )
+
+ client_hello = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=encode_handshake_fragment(
+ HandshakeFragment(
+ msg_type=HandshakeType.client_hello,
+ msg_len=10,
+ msg_seq=0,
+ frag_offset=0,
+ frag_len=10,
+ frag=bytes(10),
+ )
+ ),
+ )
+ )
+
+ client_hello_extended = client_hello + b"\x00"
+ client_hello_short = client_hello[:-1]
+ # cuts off in middle of handshake message header
+ client_hello_really_short = client_hello[:14]
+ client_hello_corrupt_record_len = bytearray(client_hello)
+ client_hello_corrupt_record_len[11] = 0xFF
+
+ client_hello_fragmented = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=encode_handshake_fragment(
+ HandshakeFragment(
+ msg_type=HandshakeType.client_hello,
+ msg_len=20,
+ msg_seq=0,
+ frag_offset=0,
+ frag_len=10,
+ frag=bytes(10),
+ )
+ ),
+ )
+ )
+
+ client_hello_trailing_data_in_record = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=encode_handshake_fragment(
+ HandshakeFragment(
+ msg_type=HandshakeType.client_hello,
+ msg_len=20,
+ msg_seq=0,
+ frag_offset=0,
+ frag_len=10,
+ frag=bytes(10),
+ )
+ )
+ + b"\x00",
+ )
+ )
+
+ handshake_empty = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=b"",
+ )
+ )
+
+ client_hello_truncated_in_cookie = encode_record(
+ Record(
+ content_type=ContentType.handshake,
+ version=ProtocolVersion.DTLS10,
+ epoch_seqno=0,
+ payload=bytes(2 + 32 + 1) + b"\xff",
+ )
+ )
+
+ async with dtls_echo_server() as (_, address):
+ with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock:
+ for bad_packet in [
+ b"",
+ b"xyz",
+ client_hello_extended,
+ client_hello_short,
+ client_hello_really_short,
+ client_hello_corrupt_record_len,
+ client_hello_fragmented,
+ client_hello_trailing_data_in_record,
+ handshake_empty,
+ client_hello_truncated_in_cookie,
+ ]:
+ await sock.sendto(bad_packet, address)
+ await trio.sleep(1)
+
+
+async def test_invalid_cookie_rejected(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ from trio._dtls import decode_client_hello_untrusted, BadPacket
+
+ with trio.CancelScope() as cscope:
+
+ # the first 11 bytes of ClientHello aren't protected by the cookie, so only test
+ # corrupting bytes after that.
+ offset_to_corrupt = count(11)
+
+ def route_packet(packet):
+ try:
+ _, cookie, _ = decode_client_hello_untrusted(packet.payload)
+ except BadPacket:
+ pass
+ else:
+ if len(cookie) != 0:
+ # this is a challenge response packet
+ # let's corrupt the next offset so the handshake should fail
+ payload = bytearray(packet.payload)
+ offset = next(offset_to_corrupt)
+ if offset >= len(payload):
+ # We've tried all offsets. Clamp offset to the end of the
+ # payload, and terminate the test.
+ offset = len(payload) - 1
+ cscope.cancel()
+ payload[offset] ^= 0x01
+ packet = attr.evolve(packet, payload=payload)
+
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (_, address):
+ while True:
+ with endpoint() as client:
+ channel = client.connect(address, client_ctx)
+ await channel.do_handshake()
+ assert cscope.cancelled_caught
+
+
+async def test_client_cancels_handshake_and_starts_new_one(autojump_clock):
+ # if a client disappears during the handshake, and then starts a new handshake from
+ # scratch, then the first handler's channel should fail, and a new handler get
+ # started
+ fn = FakeNet()
+ fn.enable()
+
+ with endpoint() as server, endpoint() as client:
+ await server.socket.bind(("127.0.0.1", 0))
+ async with trio.open_nursery() as nursery:
+ first_time = True
+
+ async def handler(channel):
+ nonlocal first_time
+ if first_time:
+ first_time = False
+ print("handler: first time, cancelling connect")
+ connect_cscope.cancel()
+ await trio.sleep(0.5)
+ print("handler: handshake should fail now")
+ with pytest.raises(trio.BrokenResourceError):
+ await channel.do_handshake()
+ else:
+ print("handler: not first time, sending hello")
+ await channel.send(b"hello")
+
+ await nursery.start(server.serve, server_ctx, handler)
+
+ print("client: starting first connect")
+ with trio.CancelScope() as connect_cscope:
+ channel = client.connect(server.socket.getsockname(), client_ctx)
+ await channel.do_handshake()
+ assert connect_cscope.cancelled_caught
+
+ print("client: starting second connect")
+ channel = client.connect(server.socket.getsockname(), client_ctx)
+ assert await channel.receive() == b"hello"
+
+ # Give handlers a chance to finish
+ await trio.sleep(10)
+ nursery.cancel_scope.cancel()
+
+
+async def test_swap_client_server():
+ with endpoint() as a, endpoint() as b:
+ await a.socket.bind(("127.0.0.1", 0))
+ await b.socket.bind(("127.0.0.1", 0))
+
+ async def echo_handler(channel):
+ async for packet in channel:
+ await channel.send(packet)
+
+ async def crashing_echo_handler(channel):
+ with pytest.raises(trio.BrokenResourceError):
+ await echo_handler(channel)
+
+ async with trio.open_nursery() as nursery:
+ await nursery.start(a.serve, server_ctx, crashing_echo_handler)
+ await nursery.start(b.serve, server_ctx, echo_handler)
+
+ b_to_a = b.connect(a.socket.getsockname(), client_ctx)
+ await b_to_a.send(b"b as client")
+ assert await b_to_a.receive() == b"b as client"
+
+ a_to_b = a.connect(b.socket.getsockname(), client_ctx)
+ await a_to_b.do_handshake()
+ with pytest.raises(trio.BrokenResourceError):
+ await b_to_a.send(b"association broken")
+ await a_to_b.send(b"a as client")
+ assert await a_to_b.receive() == b"a as client"
+
+ nursery.cancel_scope.cancel()
+
+
+@slow
+async def test_openssl_retransmit_doesnt_break_stuff():
+ # can't use autojump_clock here, because the point of the test is to wait for
+ # openssl's built-in retransmit timer to expire, which is hard-coded to use
+ # wall-clock time.
+ fn = FakeNet()
+ fn.enable()
+
+ blackholed = True
+
+ def route_packet(packet):
+ if blackholed:
+ print("dropped packet", packet)
+ return
+ print("delivered packet", packet)
+ # packets.append(
+ # scapy.all.IP(
+ # src=packet.source.ip.compressed, dst=packet.destination.ip.compressed
+ # )
+ # / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port)
+ # / packet.payload
+ # )
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (server_endpoint, address):
+ with endpoint() as client_endpoint:
+ async with trio.open_nursery() as nursery:
+
+ async def connecter():
+ client = client_endpoint.connect(address, client_ctx)
+ await client.do_handshake(initial_retransmit_timeout=1.5)
+ await client.send(b"hi")
+ assert await client.receive() == b"hi"
+
+ nursery.start_soon(connecter)
+
+ # openssl's default timeout is 1 second, so this ensures that it thinks
+ # the timeout has expired
+ await trio.sleep(1.1)
+ # disable blackholing and send a garbage packet to wake up openssl so it
+ # notices the timeout has expired
+ blackholed = False
+ await server_endpoint.socket.sendto(
+ b"xxx", client_endpoint.socket.getsockname()
+ )
+ # now the client task should finish connecting and exit cleanly
+
+ # scapy.all.wrpcap("/tmp/trace.pcap", packets)
+
+
+async def test_initial_retransmit_timeout_configuration(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ blackholed = True
+
+ def route_packet(packet):
+ nonlocal blackholed
+ if blackholed:
+ blackholed = False
+ else:
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (_, address):
+ for t in [1, 2, 4]:
+ with endpoint() as client:
+ before = trio.current_time()
+ blackholed = True
+ channel = client.connect(address, client_ctx)
+ await channel.do_handshake(initial_retransmit_timeout=t)
+ after = trio.current_time()
+ assert after - before == t
+
+
+async def test_explicit_tiny_mtu_is_respected():
+ # ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to
+ # be larger than that. (300 is still smaller than any real network though.)
+ MTU = 300
+
+ fn = FakeNet()
+ fn.enable()
+
+ def route_packet(packet):
+ print(f"delivering {packet}")
+ print(f"payload size: {len(packet.payload)}")
+ assert len(packet.payload) <= MTU
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server(mtu=MTU) as (server, address):
+ with endpoint() as client:
+ channel = client.connect(address, client_ctx)
+ channel.set_ciphertext_mtu(MTU)
+ await channel.do_handshake()
+ await channel.send(b"hi")
+ assert await channel.receive() == b"hi"
+
+
+@parametrize_ipv6
+async def test_handshake_handles_minimum_network_mtu(ipv6, autojump_clock):
+ # Fake network that has the minimum allowable MTU for whatever protocol we're using.
+ fn = FakeNet()
+ fn.enable()
+
+ if ipv6:
+ mtu = 1280 - 48
+ else:
+ mtu = 576 - 28
+
+ def route_packet(packet):
+ if len(packet.payload) > mtu:
+ print(f"dropping {packet}")
+ else:
+ print(f"delivering {packet}")
+ fn.deliver_packet(packet)
+
+ fn.route_packet = route_packet
+
+ # See if we can successfully do a handshake -- some of the volleys will get dropped,
+ # and the retransmit logic should detect this and back off the MTU to something
+ # smaller until it succeeds.
+ async with dtls_echo_server(ipv6=ipv6) as (_, address):
+ with endpoint(ipv6=ipv6) as client_endpoint:
+ client = client_endpoint.connect(address, client_ctx)
+ # the handshake mtu backoff shouldn't affect the return value from
+ # get_cleartext_mtu, b/c that's under the user's control via
+ # set_ciphertext_mtu
+ client.set_ciphertext_mtu(9999)
+ await client.send(b"xyz")
+ assert await client.receive() == b"xyz"
+ assert client.get_cleartext_mtu() > 9000 # as vegeta said
+
+
+@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
+async def test_system_task_cleaned_up_on_gc():
+ before_tasks = trio.lowlevel.current_statistics().tasks_living
+
+ # We put this into a sub-function so that everything automatically becomes garbage
+ # when the frame exits. For some reason just doing 'del e' wasn't enough on pypy
+ # with coverage enabled -- I think we were hitting this bug:
+ # https://foss.heptapod.net/pypy/pypy/-/issues/3656
+ async def start_and_forget_endpoint():
+ e = endpoint()
+
+ # This connection/handshake attempt can't succeed. The only purpose is to force
+ # the endpoint to set up a receive loop.
+ with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
+ await s.bind(("127.0.0.1", 0))
+ c = e.connect(s.getsockname(), client_ctx)
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(c.do_handshake)
+ await trio.testing.wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ during_tasks = trio.lowlevel.current_statistics().tasks_living
+ return during_tasks
+
+ with pytest.warns(ResourceWarning):
+ during_tasks = await start_and_forget_endpoint()
+ await trio.testing.wait_all_tasks_blocked()
+ gc_collect_harder()
+
+ await trio.testing.wait_all_tasks_blocked()
+
+ after_tasks = trio.lowlevel.current_statistics().tasks_living
+ assert before_tasks < during_tasks
+ assert before_tasks == after_tasks
+
+
+@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
+async def test_gc_before_system_task_starts():
+ e = endpoint()
+
+ with pytest.warns(ResourceWarning):
+ del e
+ gc_collect_harder()
+
+ await trio.testing.wait_all_tasks_blocked()
+
+
+@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
+async def test_gc_as_packet_received():
+ fn = FakeNet()
+ fn.enable()
+
+ e = endpoint()
+ await e.socket.bind(("127.0.0.1", 0))
+ e._ensure_receive_loop()
+
+ await trio.testing.wait_all_tasks_blocked()
+
+ with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
+ await s.sendto(b"xxx", e.socket.getsockname())
+ # At this point, the endpoint's receive loop has been marked runnable because it
+ # just received a packet; closing the endpoint socket won't interrupt that. But by
+ # the time it wakes up to process the packet, the endpoint will be gone.
+ with pytest.warns(ResourceWarning):
+ del e
+ gc_collect_harder()
+
+
+@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
+def test_gc_after_trio_exits():
+ async def main():
+ # We use fakenet just to make sure no real sockets can leak out of the test
+ # case - on pypy somehow the socket was outliving the gc_collect_harder call
+ # below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode
+ # when called after trio exits, it doesn't need a real socket.
+ fn = FakeNet()
+ fn.enable()
+ return endpoint()
+
+ e = trio.run(main)
+ with pytest.warns(ResourceWarning):
+ del e
+ gc_collect_harder()
+
+
+async def test_already_closed_socket_doesnt_crash():
+ with endpoint() as e:
+ # We close the socket before checkpointing, so the socket will already be closed
+ # when the system task starts up
+ e.socket.close()
+ # Now give it a chance to start up, and hopefully not crash
+ await trio.testing.wait_all_tasks_blocked()
+
+
+async def test_socket_closed_while_processing_clienthello(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ # Check what happens if the socket is discovered to be closed when sending a
+ # HelloVerifyRequest, since that has its own sending logic
+ async with dtls_echo_server() as (server, address):
+
+ def route_packet(packet):
+ fn.deliver_packet(packet)
+ server.socket.close()
+
+ fn.route_packet = route_packet
+
+ with endpoint() as client_endpoint:
+ with trio.move_on_after(10):
+ client = client_endpoint.connect(address, client_ctx)
+ await client.do_handshake()
+
+
+async def test_association_replaced_while_handshake_running(autojump_clock):
+ fn = FakeNet()
+ fn.enable()
+
+ def route_packet(packet):
+ pass
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ c1 = client_endpoint.connect(address, client_ctx)
+ async with trio.open_nursery() as nursery:
+
+ async def doomed_handshake():
+ with pytest.raises(trio.BrokenResourceError):
+ await c1.do_handshake()
+
+ nursery.start_soon(doomed_handshake)
+
+ await trio.sleep(10)
+
+ client_endpoint.connect(address, client_ctx)
+
+
+async def test_association_replaced_before_handshake_starts():
+ fn = FakeNet()
+ fn.enable()
+
+ # This test shouldn't send any packets
+ def route_packet(packet): # pragma: no cover
+ assert False
+
+ fn.route_packet = route_packet
+
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ c1 = client_endpoint.connect(address, client_ctx)
+ client_endpoint.connect(address, client_ctx)
+ with pytest.raises(trio.BrokenResourceError):
+ await c1.do_handshake()
+
+
+async def test_send_to_closed_local_port():
+ # On Windows, sending a UDP packet to a closed local port can cause a weird
+ # ECONNRESET error later, inside the receive task. Make sure we're handling it
+ # properly.
+ async with dtls_echo_server() as (_, address):
+ with endpoint() as client_endpoint:
+ async with trio.open_nursery() as nursery:
+ for i in range(1, 10):
+ channel = client_endpoint.connect(("127.0.0.1", i), client_ctx)
+ nursery.start_soon(channel.do_handshake)
+ channel = client_endpoint.connect(address, client_ctx)
+ await channel.send(b"xxx")
+ assert await channel.receive() == b"xxx"
+ nursery.cancel_scope.cancel()
diff --git a/trio/tests/test_fakenet.py b/trio/tests/test_fakenet.py
new file mode 100644
index 0000000000..bc691c9db5
--- /dev/null
+++ b/trio/tests/test_fakenet.py
@@ -0,0 +1,44 @@
+import pytest
+
+import trio
+from trio.testing._fake_net import FakeNet
+
+
+def fn():
+ fn = FakeNet()
+ fn.enable()
+ return fn
+
+
+async def test_basic_udp():
+ fn()
+ s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
+ s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
+
+ await s1.bind(("127.0.0.1", 0))
+ ip, port = s1.getsockname()
+ assert ip == "127.0.0.1"
+ assert port != 0
+ await s2.sendto(b"xyz", s1.getsockname())
+ data, addr = await s1.recvfrom(10)
+ assert data == b"xyz"
+ assert addr == s2.getsockname()
+ await s1.sendto(b"abc", s2.getsockname())
+ data, addr = await s2.recvfrom(10)
+ assert data == b"abc"
+ assert addr == s1.getsockname()
+
+
+async def test_msg_trunc():
+ fn()
+ s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
+ s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
+ await s1.bind(("127.0.0.1", 0))
+ await s2.sendto(b"xyz", s1.getsockname())
+ data, addr = await s1.recvfrom(10)
+
+
+async def test_basic_tcp():
+ fn()
+ with pytest.raises(NotImplementedError):
+ trio.socket.socket()
diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py
index d891041ab2..1fa3721f91 100644
--- a/trio/tests/test_socket.py
+++ b/trio/tests/test_socket.py
@@ -495,21 +495,20 @@ def assert_eq(actual, expected):
with tsocket.socket(family=socket_type) as sock:
# For some reason the stdlib special-cases "" to pass NULL to
- # getaddrinfo They also error out on None, but whatever, None is much
+ # getaddrinfo. They also error out on None, but whatever, None is much
# more consistent, so we accept it too.
for null in [None, ""]:
- got = await sock._resolve_local_address_nocp((null, 80))
+ got = await sock._resolve_address_nocp((null, 80), local=True)
assert_eq(got, (addrs.bind_all, 80))
- got = await sock._resolve_remote_address_nocp((null, 80))
+ got = await sock._resolve_address_nocp((null, 80), local=False)
assert_eq(got, (addrs.localhost, 80))
# AI_PASSIVE only affects the wildcard address, so for everything else
- # _resolve_local_address_nocp and _resolve_remote_address_nocp should
- # work the same:
- for resolver in ["_resolve_local_address_nocp", "_resolve_remote_address_nocp"]:
+ # local=True/local=False should work the same:
+ for local in [False, True]:
async def res(*args):
- return await getattr(sock, resolver)(*args)
+ return await sock._resolve_address_nocp(*args, local=local)
assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80))
if v6:
@@ -560,7 +559,10 @@ async def res(*args):
except (AttributeError, OSError):
pass
else:
- assert await getattr(netlink_sock, resolver)("asdf") == "asdf"
+ assert (
+ await netlink_sock._resolve_address_nocp("asdf", local=local)
+ == "asdf"
+ )
netlink_sock.close()
with pytest.raises(ValueError):
@@ -721,16 +723,16 @@ def connect(self, *args, **kwargs):
await sock.connect(("127.0.0.1", 2))
-async def test_resolve_remote_address_exception_closes_socket():
+async def test_resolve_address_exception_in_connect_closes_socket():
# Here we are testing issue 247, any cancellation will leave the socket closed
with _core.CancelScope() as cancel_scope:
with tsocket.socket() as sock:
- async def _resolve_remote_address_nocp(self, *args, **kwargs):
+ async def _resolve_address_nocp(self, *args, **kwargs):
cancel_scope.cancel()
await _core.checkpoint()
- sock._resolve_remote_address_nocp = _resolve_remote_address_nocp
+ sock._resolve_address_nocp = _resolve_address_nocp
with assert_checkpoints():
with pytest.raises(_core.Cancelled):
await sock.connect("")