diff --git a/docs/source/conf.py b/docs/source/conf.py index 6045ffd828..52872d0cb4 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -87,6 +87,7 @@ def setup(app): intersphinx_mapping = { "python": ('https://docs.python.org/3', None), "outcome": ('https://outcome.readthedocs.io/en/latest/', None), + "pyopenssl": ('https://www.pyopenssl.org/en/stable/', None), } autodoc_member_order = "bysource" diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index b18983e272..a3291ef2ae 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -258,6 +258,52 @@ you call them before the handshake completes: .. autoexception:: NeedHandshakeError +Datagram TLS support +~~~~~~~~~~~~~~~~~~~~ + +Trio also has support for Datagram TLS (DTLS), which is like TLS but +for unreliable UDP connections. This can be useful for applications +where TCP's reliable in-order delivery is problematic, like +teleconferencing, latency-sensitive games, and VPNs. + +Currently, using DTLS with Trio requires PyOpenSSL. We hope to +eventually allow the use of the stdlib `ssl` module as well, but +unfortunately that's not yet possible. + +.. warning:: Note that PyOpenSSL is in many ways lower-level than the + `ssl` module – in particular, it currently **HAS NO BUILT-IN + MECHANISM TO VALIDATE CERTIFICATES**. We *strongly* recommend that + you use the `service-identity + `__ library to validate + hostnames and certificates. + +.. autoclass:: DTLSEndpoint + + .. automethod:: connect + + .. automethod:: serve + + .. automethod:: close + +.. autoclass:: DTLSChannel + :show-inheritance: + + .. automethod:: do_handshake + + .. automethod:: send + + .. automethod:: receive + + .. automethod:: close + + .. automethod:: aclose + + .. automethod:: set_ciphertext_mtu + + .. automethod:: get_cleartext_mtu + + .. automethod:: statistics + .. 
module:: trio.socket Low-level networking with :mod:`trio.socket` diff --git a/newsfragments/2010.feature.rst b/newsfragments/2010.feature.rst new file mode 100644 index 0000000000..f99687652f --- /dev/null +++ b/newsfragments/2010.feature.rst @@ -0,0 +1,4 @@ +Added support for `Datagram TLS +`__, +for secure communication over UDP. Currently requires `PyOpenSSL +`__. diff --git a/test-requirements.in b/test-requirements.in index 3478604218..186b13392e 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -3,8 +3,8 @@ pytest >= 5.0 # for faulthandler in core pytest-cov >= 2.6.0 # ipython 7.x is the last major version supporting Python 3.7 ipython < 7.32 # for the IPython traceback integration tests -pyOpenSSL # for the ssl tests -trustme # for the ssl tests +pyOpenSSL >= 22.0.0 # for the ssl + DTLS tests +trustme # for the ssl + DTLS tests pylint # for pylint finding all symbols tests jedi # for jedi code completion tests cryptography>=36.0.0 # 35.0.0 is transitive but fails @@ -12,6 +12,7 @@ cryptography>=36.0.0 # 35.0.0 is transitive but fails # Tools black; implementation_name == "cpython" mypy; implementation_name == "cpython" +types-pyOpenSSL; implementation_name == "cpython" flake8 astor # code generation diff --git a/test-requirements.txt b/test-requirements.txt index 17e6697e83..5aabe871de 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -134,10 +134,21 @@ traitlets==5.3.0 # matplotlib-inline trustme==0.9.0 # via -r test-requirements.in +types-cryptography==3.3.14 + # via types-pyopenssl +types-enum34==1.1.8 + # via types-cryptography +types-ipaddress==1.0.7 + # via types-cryptography +types-pyopenssl==21.0.3 ; implementation_name == "cpython" + # via -r test-requirements.in typing-extensions==4.3.0 ; implementation_name == "cpython" # via # -r test-requirements.in + # astroid + # black # mypy + # pylint wcwidth==0.2.5 # via prompt-toolkit wrapt==1.14.1 diff --git a/trio/__init__.py b/trio/__init__.py index a50ec33310..b35fa076b3 
100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -74,6 +74,8 @@ from ._ssl import SSLStream, SSLListener, NeedHandshakeError +from ._dtls import DTLSEndpoint, DTLSChannel + from ._highlevel_serve_listeners import serve_listeners from ._highlevel_open_tcp_stream import open_tcp_stream diff --git a/trio/_dtls.py b/trio/_dtls.py new file mode 100644 index 0000000000..910637455a --- /dev/null +++ b/trio/_dtls.py @@ -0,0 +1,1276 @@ +# Implementation of DTLS 1.2, using pyopenssl +# https://datatracker.ietf.org/doc/html/rfc6347 +# +# OpenSSL's APIs for DTLS are extremely awkward and limited, which forces us to jump +# through a *lot* of hoops and implement important chunks of the protocol ourselves. +# Hopefully they fix this before implementing DTLS 1.3, because it's a very different +# protocol, and it's probably impossible to pull tricks like we do here. + +import struct +import hmac +import os +import enum +from itertools import count +import weakref +import errno +import warnings + +import attr + +import trio +from trio._util import NoPublicConstructor, Final + +MAX_UDP_PACKET_SIZE = 65527 + + +def packet_header_overhead(sock): + if sock.family == trio.socket.AF_INET: + return 28 + else: + return 48 + + +def worst_case_mtu(sock): + if sock.family == trio.socket.AF_INET: + return 576 - packet_header_overhead(sock) + else: + return 1280 - packet_header_overhead(sock) + + +def best_guess_mtu(sock): + return 1500 - packet_header_overhead(sock) + + +# There are a bunch of different RFCs that define these codes, so for a +# comprehensive collection look here: +# https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml +class ContentType(enum.IntEnum): + change_cipher_spec = 20 + alert = 21 + handshake = 22 + application_data = 23 + heartbeat = 24 + + +class HandshakeType(enum.IntEnum): + hello_request = 0 + client_hello = 1 + server_hello = 2 + hello_verify_request = 3 + new_session_ticket = 4 + end_of_early_data = 4 + encrypted_extensions = 8 + 
certificate = 11 + server_key_exchange = 12 + certificate_request = 13 + server_hello_done = 14 + certificate_verify = 15 + client_key_exchange = 16 + finished = 20 + certificate_url = 21 + certificate_status = 22 + supplemental_data = 23 + key_update = 24 + compressed_certificate = 25 + ekt_key = 26 + message_hash = 254 + + +class ProtocolVersion: + DTLS10 = bytes([254, 255]) + DTLS12 = bytes([254, 253]) + + +EPOCH_MASK = 0xFFFF << (6 * 8) + + +# Conventions: +# - All functions that handle network data end in _untrusted. +# - All functions end in _untrusted MUST make sure that bad data from the +# network cannot *only* cause BadPacket to be raised. No IndexError or +# struct.error or whatever. +class BadPacket(Exception): + pass + + +# This checks that the DTLS 'epoch' field is 0, which is true iff we're in the +# initial handshake. It doesn't check the ContentType, because not all +# handshake messages have ContentType==handshake -- for example, +# ChangeCipherSpec is used during the handshake but has its own ContentType. +# +# Cannot fail. +def part_of_handshake_untrusted(packet): + # If the packet is too short, then slicing will successfully return a + # short string, which will necessarily fail to match. + return packet[3:5] == b"\x00\x00" + + +# Cannot fail +def is_client_hello_untrusted(packet): + try: + return ( + packet[0] == ContentType.handshake + and packet[13] == HandshakeType.client_hello + ) + except IndexError: + # Invalid DTLS record + return False + + +# DTLS records are: +# - 1 byte content type +# - 2 bytes version +# - 8 bytes epoch+seqno +# Technically this is 2 bytes epoch then 6 bytes seqno, but we treat it as +# a single 8-byte integer, where epoch changes are represented as jumping +# forward by 2**(6*8). 
+# - 2 bytes payload length (unsigned big-endian) +# - payload +RECORD_HEADER = struct.Struct("!B2sQH") + + +def to_hex(data: bytes) -> str: # pragma: no cover + return data.hex() + + +@attr.frozen +class Record: + content_type: int + version: bytes = attr.ib(repr=to_hex) + epoch_seqno: int + payload: bytes = attr.ib(repr=to_hex) + + +def records_untrusted(packet): + i = 0 + while i < len(packet): + try: + ct, version, epoch_seqno, payload_len = RECORD_HEADER.unpack_from(packet, i) + # Marked as no-cover because at time of writing, this code is unreachable + # (records_untrusted only gets called on packets that are either trusted or that + # have passed is_client_hello_untrusted, which filters out short packets) + except struct.error as exc: # pragma: no cover + raise BadPacket("invalid record header") from exc + i += RECORD_HEADER.size + payload = packet[i : i + payload_len] + if len(payload) != payload_len: + raise BadPacket("short record") + i += payload_len + yield Record(ct, version, epoch_seqno, payload) + + +def encode_record(record): + header = RECORD_HEADER.pack( + record.content_type, + record.version, + record.epoch_seqno, + len(record.payload), + ) + return header + record.payload + + +# Handshake messages are: +# - 1 byte message type +# - 3 bytes total message length +# - 2 bytes message sequence number +# - 3 bytes fragment offset +# - 3 bytes fragment length +HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s") + + +@attr.frozen +class HandshakeFragment: + msg_type: int + msg_len: int + msg_seq: int + frag_offset: int + frag_len: int + frag: bytes = attr.ib(repr=to_hex) + + +def decode_handshake_fragment_untrusted(payload): + # Raises BadPacket if decoding fails + try: + ( + msg_type, + msg_len_bytes, + msg_seq, + frag_offset_bytes, + frag_len_bytes, + ) = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload) + except struct.error as exc: + raise BadPacket("bad handshake message header") from exc + # 'struct' doesn't have built-in support for 24-bit 
integers, so we + # have to do it by hand. These can't fail. + msg_len = int.from_bytes(msg_len_bytes, "big") + frag_offset = int.from_bytes(frag_offset_bytes, "big") + frag_len = int.from_bytes(frag_len_bytes, "big") + frag = payload[HANDSHAKE_MESSAGE_HEADER.size :] + if len(frag) != frag_len: + raise BadPacket("handshake fragment length doesn't match record length") + return HandshakeFragment( + msg_type, + msg_len, + msg_seq, + frag_offset, + frag_len, + frag, + ) + + +def encode_handshake_fragment(hsf): + hs_header = HANDSHAKE_MESSAGE_HEADER.pack( + hsf.msg_type, + hsf.msg_len.to_bytes(3, "big"), + hsf.msg_seq, + hsf.frag_offset.to_bytes(3, "big"), + hsf.frag_len.to_bytes(3, "big"), + ) + return hs_header + hsf.frag + + +def decode_client_hello_untrusted(packet): + # Raises BadPacket if parsing fails + # Returns (record epoch_seqno, cookie from the packet, data that should be + # hashed into cookie) + try: + # ClientHello has to be the first record in the packet + record = next(records_untrusted(packet)) + # no-cover because at time of writing, this is unreachable: + # decode_client_hello_untrusted is only called on packets that have passed + # is_client_hello_untrusted, which confirms the content type. + if record.content_type != ContentType.handshake: # pragma: no cover + raise BadPacket("not a handshake record") + fragment = decode_handshake_fragment_untrusted(record.payload) + if fragment.msg_type != HandshakeType.client_hello: + raise BadPacket("not a ClientHello") + # ClientHello can't be fragmented, because reassembly requires holding + # per-connection state, and we refuse to allocate per-connection state + # until after we get a valid ClientHello. 
+ if fragment.frag_offset != 0: + raise BadPacket("fragmented ClientHello") + if fragment.frag_len != fragment.msg_len: + raise BadPacket("fragmented ClientHello") + + # As per RFC 6347: + # + # When responding to a HelloVerifyRequest, the client MUST use the + # same parameter values (version, random, session_id, cipher_suites, + # compression_method) as it did in the original ClientHello. The + # server SHOULD use those values to generate its cookie and verify that + # they are correct upon cookie receipt. + # + # However, the record-layer framing can and will change (e.g. the + # second ClientHello will have a new record-layer sequence number). So + # we need to pull out the handshake message alone, discarding the + # record-layer stuff, and then we're going to hash all of it *except* + # the cookie. + + body = fragment.frag + # ClientHello is: + # + # - 2 bytes client_version + # - 32 bytes random + # - 1 byte session_id length + # - session_id + # - 1 byte cookie length + # - cookie + # - everything else + # + # So to find the cookie, we need to figure out how long the + # session_id is and skip past it. + session_id_len = body[2 + 32] + cookie_len_offset = 2 + 32 + 1 + session_id_len + cookie_len = body[cookie_len_offset] + + cookie_start = cookie_len_offset + 1 + cookie_end = cookie_start + cookie_len + + before_cookie = body[:cookie_len_offset] + cookie = body[cookie_start:cookie_end] + after_cookie = body[cookie_end:] + + if len(cookie) != cookie_len: + raise BadPacket("short cookie") + return (record.epoch_seqno, cookie, before_cookie + after_cookie) + + except (struct.error, IndexError) as exc: + raise BadPacket("bad ClientHello") from exc + + +@attr.frozen +class HandshakeMessage: + record_version: bytes = attr.ib(repr=to_hex) + msg_type: HandshakeType + msg_seq: int + body: bytearray = attr.ib(repr=to_hex) + + +# ChangeCipherSpec is part of the handshake, but it's not a "handshake +# message" and can't be fragmented the same way. Sigh. 
+@attr.frozen +class PseudoHandshakeMessage: + record_version: bytes = attr.ib(repr=to_hex) + content_type: int + payload: bytes = attr.ib(repr=to_hex) + + +# The final record in a handshake is Finished, which is encrypted, can't be fragmented +# (at least by us), and keeps its record number (because it's in a new epoch). So we +# just pass it through unchanged. (Fortunately, the payload is only a single hash value, +# so the largest it will ever be is 64 bytes for a 512-bit hash. Which is small enough +# that it never requires fragmenting to fit into a UDP packet.) +@attr.frozen +class OpaqueHandshakeMessage: + record: Record + + +# This takes a raw outgoing handshake volley that openssl generated, and +# reconstructs the handshake messages inside it, so that we can repack them +# into records while retransmitting. So the data ought to be well-behaved -- +# it's not coming from the network. +def decode_volley_trusted(volley): + messages = [] + messages_by_seq = {} + for record in records_untrusted(volley): + # ChangeCipherSpec isn't a handshake message, so it can't be fragmented. + # Handshake messages with epoch > 0 are encrypted, so we can't fragment them + # either. Fortunately, ChangeCipherSpec has a 1 byte payload, and the only + # encrypted handshake message is Finished, whose payload is a single hash value + # -- so 32 bytes for SHA-256, 64 for SHA-512, etc. Neither is going to be so + # large that it has to be fragmented to fit into a single packet. 
+ if record.epoch_seqno & EPOCH_MASK: + messages.append(OpaqueHandshakeMessage(record)) + elif record.content_type in (ContentType.change_cipher_spec, ContentType.alert): + messages.append( + PseudoHandshakeMessage( + record.version, record.content_type, record.payload + ) + ) + else: + assert record.content_type == ContentType.handshake + fragment = decode_handshake_fragment_untrusted(record.payload) + msg_type = HandshakeType(fragment.msg_type) + if fragment.msg_seq not in messages_by_seq: + msg = HandshakeMessage( + record.version, + msg_type, + fragment.msg_seq, + bytearray(fragment.msg_len), + ) + messages.append(msg) + messages_by_seq[fragment.msg_seq] = msg + else: + msg = messages_by_seq[fragment.msg_seq] + assert msg.msg_type == fragment.msg_type + assert msg.msg_seq == fragment.msg_seq + assert len(msg.body) == fragment.msg_len + + msg.body[ + fragment.frag_offset : fragment.frag_offset + fragment.frag_len + ] = fragment.frag + + return messages + + +class RecordEncoder: + def __init__(self): + self._record_seq = count() + + def set_first_record_number(self, n): + self._record_seq = count(n) + + def encode_volley(self, messages, mtu): + packets = [] + packet = bytearray() + for message in messages: + if isinstance(message, OpaqueHandshakeMessage): + encoded = encode_record(message.record) + if mtu - len(packet) - len(encoded) <= 0: + packets.append(packet) + packet = bytearray() + packet += encoded + assert len(packet) <= mtu + elif isinstance(message, PseudoHandshakeMessage): + space = mtu - len(packet) - RECORD_HEADER.size - len(message.payload) + if space <= 0: + packets.append(packet) + packet = bytearray() + packet += RECORD_HEADER.pack( + message.content_type, + message.record_version, + next(self._record_seq), + len(message.payload), + ) + packet += message.payload + assert len(packet) <= mtu + else: + msg_len_bytes = len(message.body).to_bytes(3, "big") + frag_offset = 0 + frags_encoded = 0 + # If message.body is empty, then we still want to 
encode it in one + # fragment, not zero. + while frag_offset < len(message.body) or not frags_encoded: + space = ( + mtu + - len(packet) + - RECORD_HEADER.size + - HANDSHAKE_MESSAGE_HEADER.size + ) + if space <= 0: + packets.append(packet) + packet = bytearray() + continue + frag = message.body[frag_offset : frag_offset + space] + frag_offset_bytes = frag_offset.to_bytes(3, "big") + frag_len_bytes = len(frag).to_bytes(3, "big") + frag_offset += len(frag) + + packet += RECORD_HEADER.pack( + ContentType.handshake, + message.record_version, + next(self._record_seq), + HANDSHAKE_MESSAGE_HEADER.size + len(frag), + ) + + packet += HANDSHAKE_MESSAGE_HEADER.pack( + message.msg_type, + msg_len_bytes, + message.msg_seq, + frag_offset_bytes, + frag_len_bytes, + ) + + packet += frag + + frags_encoded += 1 + assert len(packet) <= mtu + + if packet: + packets.append(packet) + + return packets + + +# This bit requires implementing a bona fide cryptographic protocol, so even though it's +# a simple one let's take a moment to discuss the design. +# +# Our goal is to force new incoming handshakes that claim to be coming from a +# given ip:port to prove that they can also receive packets sent to that +# ip:port. (There's nothing in UDP to stop someone from forging the return +# address, and it's often used for stuff like DoS reflection attacks, where +# an attacker tries to trick us into sending data at some innocent victim.) +# For more details, see: +# +# https://datatracker.ietf.org/doc/html/rfc6347#section-4.2.1 +# +# To do this, when we receive an initial ClientHello, we calculate a magic +# cookie, and send it back as a HelloVerifyRequest. Then the client sends us a +# second ClientHello, this time with the magic cookie included, and after we +# check that this cookie is valid we go ahead and start the handshake proper. 
+# +# So the magic cookie needs the following properties: +# - No-one can forge it without knowing our secret key +# - It ensures that the ip, port, and ClientHello contents from the response +# match those in the challenge +# - It expires after a short-ish period (so that if an attacker manages to steal one, it +# won't be useful for long) +# - It doesn't require storing any peer-specific state on our side +# +# To do that, we take the ip/port/ClientHello data and compute an HMAC of them, using a +# secret key we generate on startup. We also include: +# +# - The current time (using Trio's clock), rounded to the nearest 30 seconds +# - A random salt +# +# Then the cookie is the salt and the HMAC digest concatenated together. +# +# When verifying a cookie, we use the salt + new ip/port/ClientHello data to recompute +# the HMAC digest, for both the current time and the current time minus 30 seconds, and +# if either of them match, we consider the cookie good. +# +# Including the rounded-off time like this means that each cookie is good for at least +# 30 seconds, and possibly as much as 60 seconds. +# +# The salt is probably not necessary -- I'm pretty sure that all it does is make it hard +# for an attacker to figure out when our clock ticks over a 30 second boundary. Which is +# probably pretty harmless? But it's easier to add the salt than to convince myself that +# it's *completely* harmless, so, salt it is. + +COOKIE_REFRESH_INTERVAL = 30 # seconds +KEY_BYTES = 32 +COOKIE_HASH = "sha256" +SALT_BYTES = 8 +# 32 bytes was the maximum cookie length in DTLS 1.0. DTLS 1.2 raised it to 255. I doubt +# there are any DTLS 1.0 implementations still in the wild, but really 32 bytes is +# plenty, and it also gets rid of a confusing warning in Wireshark output. +# +# We truncate the cookie to 32 bytes, of which 8 bytes is salt, so that leaves 24 bytes +# of truncated HMAC = 192 bit security, which is still massive overkill. (TCP uses 32 +# *bits* for this.) 
HMAC truncation is explicitly noted as safe in RFC 2104: +# https://datatracker.ietf.org/doc/html/rfc2104#section-5 +COOKIE_LENGTH = 32 + + +def _current_cookie_tick(): + return int(trio.current_time() / COOKIE_REFRESH_INTERVAL) + + +# Simple deterministic and invertible serializer -- i.e., a useful tool for converting +# structured data into something we can cryptographically sign. +def _signable(*fields): + out = [] + for field in fields: + out.append(struct.pack("!Q", len(field))) + out.append(field) + return b"".join(out) + + +def _make_cookie(key, salt, tick, address, client_hello_bits): + assert len(salt) == SALT_BYTES + assert len(key) == KEY_BYTES + + signable_data = _signable( + salt, + struct.pack("!Q", tick), + # address is a mix of strings and ints, and variable length, so pack + # it into a single nested field + _signable(*(str(part).encode() for part in address)), + client_hello_bits, + ) + + return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] + + +def valid_cookie(key, cookie, address, client_hello_bits): + if len(cookie) > SALT_BYTES: + salt = cookie[:SALT_BYTES] + + tick = _current_cookie_tick() + + cur_cookie = _make_cookie(key, salt, tick, address, client_hello_bits) + old_cookie = _make_cookie( + key, salt, max(tick - 1, 0), address, client_hello_bits + ) + + # I doubt using a short-circuiting 'or' here would leak any meaningful + # information, but why risk it when '|' is just as easy. + return hmac.compare_digest(cookie, cur_cookie) | hmac.compare_digest( + cookie, old_cookie + ) + else: + return False + + +def challenge_for(key, address, epoch_seqno, client_hello_bits): + salt = os.urandom(SALT_BYTES) + tick = _current_cookie_tick() + cookie = _make_cookie(key, salt, tick, address, client_hello_bits) + + # HelloVerifyRequest body is: + # - 2 bytes version + # - length-prefixed cookie + # + # The DTLS 1.2 spec says that for this message specifically we should use + # the DTLS 1.0 version. 
+ # + # (It also says the opposite of that, but that part is a mistake: + # https://www.rfc-editor.org/errata/eid4103 + # ). + # + # And I guess we use this for both the message-level and record-level + # ProtocolVersions, since we haven't negotiated anything else yet? + body = ProtocolVersion.DTLS10 + bytes([len(cookie)]) + cookie + + # RFC says have to copy the client's record number + # Errata says it should be handshake message number + # Openssl copies back record sequence number, and always sets message seq + # number 0. So I guess we'll follow openssl. + hs = HandshakeFragment( + msg_type=HandshakeType.hello_verify_request, + msg_len=len(body), + msg_seq=0, + frag_offset=0, + frag_len=len(body), + frag=body, + ) + payload = encode_handshake_fragment(hs) + + packet = encode_record( + Record(ContentType.handshake, ProtocolVersion.DTLS10, epoch_seqno, payload) + ) + return packet + + +class _Queue: + def __init__(self, incoming_packets_buffer): + self.s, self.r = trio.open_memory_channel(incoming_packets_buffer) + + +def _read_loop(read_fn): + chunks = [] + while True: + try: + chunk = read_fn(2**14) # max TLS record size + except SSL.WantReadError: + break + chunks.append(chunk) + return b"".join(chunks) + + +async def handle_client_hello_untrusted(endpoint, address, packet): + if endpoint._listening_context is None: + return + + try: + epoch_seqno, cookie, bits = decode_client_hello_untrusted(packet) + except BadPacket: + return + + if endpoint._listening_key is None: + endpoint._listening_key = os.urandom(KEY_BYTES) + + if not valid_cookie(endpoint._listening_key, cookie, address, bits): + challenge_packet = challenge_for( + endpoint._listening_key, address, epoch_seqno, bits + ) + try: + async with endpoint._send_lock: + await endpoint.socket.sendto(challenge_packet, address) + except (OSError, trio.ClosedResourceError): + pass + else: + # We got a real, valid ClientHello! 
+ stream = DTLSChannel._create(endpoint, address, endpoint._listening_context) + # Our HelloVerifyRequest had some sequence number. We need our future sequence + # numbers to be larger than it, so our peer knows that our future records aren't + # stale/duplicates. But, we don't know what this sequence number was. What we do + # know is: + # - the HelloVerifyRequest seqno was copied from the initial ClientHello + # - the new ClientHello has a higher seqno than the initial ClientHello + # So, if we copy the new ClientHello's seqno into our first real handshake + # record and increment from there, that should work. + stream._record_encoder.set_first_record_number(epoch_seqno) + # Process the ClientHello + try: + stream._ssl.bio_write(packet) + stream._ssl.DTLSv1_listen() + except SSL.Error: + # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello + # after all. + return + + # Some old versions of OpenSSL have a bug with memory BIOs, where DTLSv1_listen + # consumes the ClientHello out of the BIO, but then do_handshake expects the + # ClientHello to still be in there. In particular, this is known to affect the + # OpenSSL v1.1.1 that ships with Ubuntu 18.04 (but not the one that ships with + # Ubuntu 20.04). To work around this, we deliver a second copy of the + # ClientHello after DTLSv1_listen has completed. This is safe to do + # unconditionally, because on newer versions of OpenSSL, the second ClientHello + # is treated as a duplicate packet, which is a normal thing that can happen over + # UDP. For more details, see: + # + # https://github.com/pyca/pyopenssl/blob/e84e7b57d1838de70ab7a27089fbee78ce0d2106/tests/test_ssl.py#L4226-L4293 + # + # This was fixed in v1.1.1a, and all later versions. So maybe in 2024 or so we + # can delete this. The fix landed in OpenSSL master as 079ef6bd534d2, and then + # was backported to the 1.1.1 branch as d1bfd8076e28. 
+ stream._ssl.bio_write(packet) + + # Check if we have an existing association + old_stream = endpoint._streams.get(address) + if old_stream is not None: + if old_stream._client_hello == (cookie, bits): + # ...This was just a duplicate of the last ClientHello, so never mind. + return + else: + # Ok, this *really is* a new handshake; the old stream should go away. + old_stream._set_replaced() + stream._client_hello = (cookie, bits) + endpoint._streams[address] = stream + endpoint._incoming_connections_q.s.send_nowait(stream) + + +async def dtls_receive_loop(endpoint_ref, sock): + try: + while True: + try: + packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE) + except OSError as exc: + if exc.errno == errno.ECONNRESET: + # Windows only: "On a UDP-datagram socket [ECONNRESET] + # indicates a previous send operation resulted in an ICMP Port + # Unreachable message" -- https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom + # + # This is totally useless -- there's nothing we can do with this + # information. So we just ignore it and retry the recv. + continue + else: + raise + endpoint = endpoint_ref() + try: + if endpoint is None: + return + if is_client_hello_untrusted(packet): + await handle_client_hello_untrusted(endpoint, address, packet) + elif address in endpoint._streams: + stream = endpoint._streams[address] + if stream._did_handshake and part_of_handshake_untrusted(packet): + # The peer just sent us more handshake messages, that aren't a + # ClientHello, and we thought the handshake was done. Some of + # the packets that we sent to finish the handshake must have + # gotten lost. So re-send them. We do this directly here instead + # of just putting it into the queue and letting the receiver do + # it, because there's no guarantee that anyone is reading from + # the queue, because we think the handshake is done! 
+ await stream._resend_final_volley() + else: + try: + stream._q.s.send_nowait(packet) + except trio.WouldBlock: + stream._packets_dropped_in_trio += 1 + else: + # Drop packet + pass + finally: + del endpoint + except trio.ClosedResourceError: + # socket was closed + return + except OSError as exc: + if exc.errno in (errno.EBADF, errno.ENOTSOCK): + # socket was closed + return + else: # pragma: no cover + # ??? shouldn't happen + raise + + +@attr.frozen +class DTLSChannelStatistics: + incoming_packets_dropped_in_trio: int + + +class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): + """A DTLS connection. + + This class has no public constructor – you get instances by calling + `DTLSEndpoint.serve` or `~DTLSEndpoint.connect`. + + .. attribute:: endpoint + + The `DTLSEndpoint` that this connection is using. + + .. attribute:: peer_address + + The IP/port of the remote peer that this connection is associated with. + + """ + + def __init__(self, endpoint, peer_address, ctx): + self.endpoint = endpoint + self.peer_address = peer_address + self._packets_dropped_in_trio = 0 + self._client_hello = None + self._did_handshake = False + # These are mandatory for all DTLS connections. OP_NO_QUERY_MTU is required to + # stop openssl from trying to query the memory BIO's MTU and then breaking, and + # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to + # support and isn't useful anyway -- especially for DTLS where it's equivalent + # to just performing a new handshake. + ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) + self._ssl = SSL.Connection(ctx) + self._handshake_mtu = None + # This calls self._ssl.set_ciphertext_mtu, which is important, because if you + # don't call it then openssl doesn't work. 
+ self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) + self._replaced = False + self._closed = False + self._q = _Queue(endpoint.incoming_packets_buffer) + self._handshake_lock = trio.Lock() + self._record_encoder = RecordEncoder() + + def _set_replaced(self): + self._replaced = True + # Any packets we already received could maybe possibly still be processed, but + # there are no more coming. So we close this on the sender side. + self._q.s.close() + + def _check_replaced(self): + if self._replaced: + raise trio.BrokenResourceError( + "peer tore down this connection to start a new one" + ) + + # XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU + # estimate + + # XX should we send close-notify when closing? It seems particularly pointless for + # DTLS where packets are all independent and can be lost anyway. We do at least need + # to handle receiving it properly though, which might be easier if we send it... + + def close(self): + """Close this connection. + + `DTLSChannel`\\s don't actually own any OS-level resources – the + socket is owned by the `DTLSEndpoint`, not the individual connections. So + you don't really *have* to call this. But it will interrupt any other tasks + calling `receive` with a `ClosedResourceError`, and cause future attempts to use + this connection to fail. + + You can also use this object as a synchronous or asynchronous context manager. + + """ + if self._closed: + return + self._closed = True + if self.endpoint._streams.get(self.peer_address) is self: + del self.endpoint._streams[self.peer_address] + # Will wake any tasks waiting on self._q.get with a + # ClosedResourceError + self._q.r.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + async def aclose(self): + """Close this connection, but asynchronously. + + This is included to satisfy the `trio.abc.Channel` contract. It's + identical to `close`, but async. 
+ + """ + self.close() + await trio.lowlevel.checkpoint() + + async def _send_volley(self, volley_messages): + packets = self._record_encoder.encode_volley( + volley_messages, self._handshake_mtu + ) + for packet in packets: + async with self.endpoint._send_lock: + await self.endpoint.socket.sendto(packet, self.peer_address) + + async def _resend_final_volley(self): + await self._send_volley(self._final_volley) + + async def do_handshake(self, *, initial_retransmit_timeout=1.0): + """Perform the handshake. + + Calling this is optional – if you don't, then it will be automatically called + the first time you call `send` or `receive`. But calling it explicitly can be + useful in case you want to control the retransmit timeout, use a cancel scope to + place an overall timeout on the handshake, or catch errors from the handshake + specifically. + + It's safe to call this multiple times, or call it simultaneously from multiple + tasks – the first call will perform the handshake, and the rest will be no-ops. + + Args: + + initial_retransmit_timeout (float): Since UDP is an unreliable protocol, it's + possible that some of the packets we send during the handshake will get + lost. To handle this, DTLS uses a timer to automatically retransmit + handshake packets that don't receive a response. This lets you set the + timeout we use to detect packet loss. Ideally, it should be set to ~1.5 + times the round-trip time to your peer, but 1 second is a reasonable + default. There's `some useful guidance here + `__. + + This is the *initial* timeout, because if packets keep being lost then Trio + will automatically back off to longer values, to avoid overloading the + network. 
+ + """ + async with self._handshake_lock: + if self._did_handshake: + return + + timeout = initial_retransmit_timeout + volley_messages = [] + volley_failed_sends = 0 + + def read_volley(): + volley_bytes = _read_loop(self._ssl.bio_read) + new_volley_messages = decode_volley_trusted(volley_bytes) + if ( + new_volley_messages + and volley_messages + and isinstance(new_volley_messages[0], HandshakeMessage) + and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq + ): + # openssl decided to retransmit; discard because we handle + # retransmits ourselves + return [] + else: + return new_volley_messages + + # If we're a client, we send the initial volley. If we're a server, then + # the initial ClientHello has already been inserted into self._ssl's + # read BIO. So either way, we start by generating a new volley. + try: + self._ssl.do_handshake() + except SSL.WantReadError: + pass + volley_messages = read_volley() + # If we don't have messages to send in our initial volley, then something + # has gone very wrong. (I'm not sure this can actually happen without an + # error from OpenSSL, but we check just in case.) + if not volley_messages: # pragma: no cover + raise SSL.Error("something wrong with peer's ClientHello") + + while True: + # -- at this point, we need to either send or re-send a volley -- + assert volley_messages + self._check_replaced() + await self._send_volley(volley_messages) + # -- then this is where we wait for a reply -- + self.endpoint._ensure_receive_loop() + with trio.move_on_after(timeout) as cscope: + async for packet in self._q.r: + self._ssl.bio_write(packet) + try: + self._ssl.do_handshake() + # We ignore generic SSL.Error here, because you can get those + # from random invalid packets + except (SSL.WantReadError, SSL.Error): + pass + else: + # No exception -> the handshake is done, and we can + # switch into data transfer mode. + self._did_handshake = True + # Might be empty, but that's ok -- we'll just send no + # packets. 
+ self._final_volley = read_volley() + await self._send_volley(self._final_volley) + return + maybe_volley = read_volley() + if maybe_volley: + if ( + isinstance(maybe_volley[0], PseudoHandshakeMessage) + and maybe_volley[0].content_type == ContentType.alert + ): + # we're sending an alert (e.g. due to a corrupted + # packet). We want to send it once, but don't save it to + # retransmit -- keep the last volley as the current + # volley. + await self._send_volley(maybe_volley) + else: + # We managed to get all of the peer's volley and + # generate a new one ourselves! break out of the 'for' + # loop and restart the timer. + volley_messages = maybe_volley + # "Implementations SHOULD retain the current timer value + # until a transmission without loss occurs, at which + # time the value may be reset to the initial value." + if volley_failed_sends == 0: + timeout = initial_retransmit_timeout + volley_failed_sends = 0 + break + else: + assert self._replaced + self._check_replaced() + if cscope.cancelled_caught: + # Timeout expired. Double timeout for backoff, with a limit of 60 + # seconds (this matches what openssl does, and also the + # recommendation in draft-ietf-tls-dtls13). + timeout = min(2 * timeout, 60.0) + volley_failed_sends += 1 + if volley_failed_sends == 2: + # We tried sending this twice and they both failed. Maybe our + # PMTU estimate is wrong? Let's try dropping it to the minimum + # and hope that helps. 
+ self._handshake_mtu = min( + self._handshake_mtu, worst_case_mtu(self.endpoint.socket) + ) + + async def send(self, data): + """Send a packet of data, securely.""" + + if self._closed: + raise trio.ClosedResourceError + if not data: + raise ValueError("openssl doesn't support sending empty DTLS packets") + if not self._did_handshake: + await self.do_handshake() + self._check_replaced() + self._ssl.write(data) + async with self.endpoint._send_lock: + await self.endpoint.socket.sendto( + _read_loop(self._ssl.bio_read), self.peer_address + ) + + async def receive(self): + """Fetch the next packet of data from this connection's peer, waiting if + necessary. + + This is safe to call from multiple tasks simultaneously, in case you have some + reason to do that. And more importantly, it's cancellation-safe, meaning that + cancelling a call to `receive` will never cause a packet to be lost or corrupt + the underlying connection. + + """ + if not self._did_handshake: + await self.do_handshake() + # If the packet isn't really valid, then openssl can decode it to the empty + # string (e.g. b/c it's a late-arriving handshake packet, or a duplicate copy of + # a data packet). Skip over these instead of returning them. + while True: + try: + packet = await self._q.r.receive() + except trio.EndOfChannel: + assert self._replaced + self._check_replaced() + self._ssl.bio_write(packet) + cleartext = _read_loop(self._ssl.read) + if cleartext: + return cleartext + + def set_ciphertext_mtu(self, new_mtu): + """Tells Trio the `largest amount of data that can be sent in a single packet to + this peer `__. + + Trio doesn't actually enforce this limit – if you pass a huge packet to `send`, + then we'll dutifully encrypt it and attempt to send it. But calling this method + does have two useful effects: + + - If called before the handshake is performed, then Trio will automatically + fragment handshake messages to fit within the given MTU. 
It also might + fragment them even smaller, if it detects signs of packet loss, so setting + this should never be necessary to make a successful connection. But, the + packet loss detection only happens after multiple timeouts have expired, so if + you have reason to believe that a smaller MTU is required, then you can set + this to skip those timeouts and establish the connection more quickly. + + - It changes the value returned from `get_cleartext_mtu`. So if you have some + kind of estimate of the network-level MTU, then you can use this to figure out + how much overhead DTLS will need for hashes/padding/etc., and how much space + you have left for your application data. + + The MTU here is measuring the largest UDP *payload* you think can be sent, the + amount of encrypted data that can be handed to the operating system in a single + call to `send`. It should *not* include IP/UDP headers. Note that OS estimates + of the MTU often are link-layer MTUs, so you have to subtract off 28 bytes on + IPv4 and 48 bytes on IPv6 to get the ciphertext MTU. + + By default, Trio assumes an MTU of 1472 bytes on IPv4, and 1452 bytes on IPv6, + which correspond to the common Ethernet MTU of 1500 bytes after accounting for + IP/UDP overhead. + + """ + self._handshake_mtu = new_mtu + self._ssl.set_ciphertext_mtu(new_mtu) + + def get_cleartext_mtu(self): + """Returns the largest number of bytes that you can pass in a single call to + `send` while still fitting within the network-level MTU. + + See `set_ciphertext_mtu` for more details. + + """ + if not self._did_handshake: + raise trio.NeedHandshakeError + return self._ssl.get_cleartext_mtu() + + def statistics(self): + """Returns an object with statistics about this connection. 
+ + Currently this has only one attribute: + + - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of + incoming packets from this peer that Trio successfully received from the + network, but then got dropped because the internal channel buffer was full. If + this is non-zero, then you might want to call ``receive`` more often, or use a + larger ``incoming_packets_buffer``, or just not worry about it because your + UDP-based protocol should be able to handle the occasional lost packet, right? + + """ + return DTLSChannelStatistics(self._packets_dropped_in_trio) + + +class DTLSEndpoint(metaclass=Final): + """A DTLS endpoint. + + A single UDP socket can handle arbitrarily many DTLS connections simultaneously, + acting as a client or server as needed. A `DTLSEndpoint` object holds a UDP socket + and manages these connections, which are represented as `DTLSChannel` objects. + + Args: + socket: (trio.socket.SocketType): A ``SOCK_DGRAM`` socket. If you want to accept + incoming connections in server mode, then you should probably bind the socket to + some known port. + incoming_packets_buffer (int): Each `DTLSChannel` using this socket has its own + buffer that holds incoming packets until you call `~DTLSChannel.receive` to read + them. This lets you adjust the size of this buffer. `~DTLSChannel.statistics` + lets you check if the buffer has overflowed. + + .. attribute:: socket + incoming_packets_buffer + + Both constructor arguments are also exposed as attributes, in case you need to + access them later. + + """ + + def __init__(self, socket, *, incoming_packets_buffer=10): + # We do this lazily on first construction, so only people who actually use DTLS + # have to install PyOpenSSL. 
+ global SSL + from OpenSSL import SSL + + self.socket = None # for __del__, in case the next line raises + if socket.type != trio.socket.SOCK_DGRAM: + raise ValueError("DTLS requires a SOCK_DGRAM socket") + self.socket = socket + + self.incoming_packets_buffer = incoming_packets_buffer + self._token = trio.lowlevel.current_trio_token() + # We don't need to track handshaking vs non-handshake connections + # separately. We only keep one connection per remote address; as soon + # as a peer provides a valid cookie, we can immediately tear down the + # old connection. + # {remote address: DTLSChannel} + self._streams = weakref.WeakValueDictionary() + self._listening_context = None + self._listening_key = None + self._incoming_connections_q = _Queue(float("inf")) + self._send_lock = trio.Lock() + self._closed = False + self._receive_loop_spawned = False + + def _ensure_receive_loop(self): + # We have to spawn this lazily, because on Windows it will immediately error out + # if the socket isn't already bound -- which for clients might not happen until + # after we send our first packet. + if not self._receive_loop_spawned: + trio.lowlevel.spawn_system_task( + dtls_receive_loop, weakref.ref(self), self.socket + ) + self._receive_loop_spawned = True + + def __del__(self): + # Do nothing if this object was never fully constructed + if self.socket is None: + return + # Close the socket in Trio context (if our Trio context still exists), so that + # the background task gets notified about the closure and can exit. + if not self._closed: + try: + self._token.run_sync_soon(self.close) + except RuntimeError: + pass + # Do this last, because it might raise an exception + warnings.warn( + f"unclosed DTLS endpoint {self!r}", ResourceWarning, source=self + ) + + def close(self): + """Close this socket, and all associated DTLS connections. + + This object can also be used as a context manager. 
+ + """ + self._closed = True + self.socket.close() + for stream in list(self._streams.values()): + stream.close() + self._incoming_connections_q.s.close() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _check_closed(self): + if self._closed: + raise trio.ClosedResourceError + + async def serve( + self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED + ): + """Listen for incoming connections, and spawn a handler for each using an + internal nursery. + + Similar to `~trio.serve_tcp`, this function never returns until cancelled, or + the `DTLSEndpoint` is closed and all handlers have exited. + + Usage commonly looks like:: + + async def handler(dtls_channel): + ... + + async with trio.open_nursery() as nursery: + await nursery.start(dtls_endpoint.serve, ssl_context, handler) + # ... do other things here ... + + The ``dtls_channel`` passed into the handler function has already performed the + "cookie exchange" part of the DTLS handshake, so the peer address is + trustworthy. But the actual cryptographic handshake doesn't happen until you + start using it, giving you a chance for any last minute configuration, and the + option to catch and handle handshake errors. + + Args: + ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for + incoming connections. + async_fn: The handler function that will be invoked for each incoming + connection. + + """ + self._check_closed() + if self._listening_context is not None: + raise trio.BusyResourceError("another task is already listening") + try: + self.socket.getsockname() + except OSError: + raise RuntimeError("DTLS socket must be bound before it can serve") + self._ensure_receive_loop() + # We do cookie verification ourselves, so tell OpenSSL not to worry about it. + # (See also _inject_client_hello_untrusted.) 
+ ssl_context.set_cookie_verify_callback(lambda *_: True) + try: + self._listening_context = ssl_context + task_status.started() + + async def handler_wrapper(stream): + with stream: + await async_fn(stream, *args) + + async with trio.open_nursery() as nursery: + async for stream in self._incoming_connections_q.r: # pragma: no branch + nursery.start_soon(handler_wrapper, stream) + finally: + self._listening_context = None + + def connect(self, address, ssl_context): + """Initiate an outgoing DTLS connection. + + Notice that this is a synchronous method. That's because it doesn't actually + initiate any I/O – it just sets up a `DTLSChannel` object. The actual handshake + doesn't occur until you start using the `DTLSChannel`. This gives you a chance + to do further configuration first, like setting MTU etc. + + Args: + address: The address to connect to. Usually a (host, port) tuple, like + ``("127.0.0.1", 12345)``. + ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for + this connection. + + Returns: + DTLSChannel + + """ + # it would be nice if we could detect when 'address' is our own endpoint (a + # loopback connection), because that can't work + # but I don't see how to do it reliably + self._check_closed() + channel = DTLSChannel._create(self, address, ssl_context) + channel._ssl.set_connect_state() + old_channel = self._streams.get(address) + if old_channel is not None: + old_channel._set_replaced() + self._streams[address] = channel + return channel diff --git a/trio/_socket.py b/trio/_socket.py index 886f5614f6..bcff1ee9e7 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -349,6 +349,83 @@ async def wrapper(self, *args, **kwargs): return wrapper +# Helpers to work with the (hostname, port) language that Python uses for socket +# addresses everywhere. Split out into a standalone function so it can be reused by +# FakeNet. 
+ +# Take an address in Python's representation, and returns a new address in +# the same representation, but with names resolved to numbers, +# etc. +# +# local=True means that the address is being used with bind() or similar +# local=False means that the address is being used with connect() or sendto() or +# similar. +# +# NOTE: this function does not always checkpoint +async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, local): + # Do some pre-checking (or exit early for non-IP sockets) + if family == _stdlib_socket.AF_INET: + if not isinstance(address, tuple) or not len(address) == 2: + raise ValueError("address should be a (host, port) tuple") + elif family == _stdlib_socket.AF_INET6: + if not isinstance(address, tuple) or not 2 <= len(address) <= 4: + raise ValueError( + "address should be a (host, port, [flowinfo, [scopeid]]) tuple" + ) + elif family == _stdlib_socket.AF_UNIX: + # unwrap path-likes + return os.fspath(address) + else: + return address + + # -- From here on we know we have IPv4 or IPV6 -- + host, port, *_ = address + # Fast path for the simple case: already-resolved IP address, + # already-resolved port. This is particularly important for UDP, since + # every sendto call goes through here. + if isinstance(port, int): + try: + _stdlib_socket.inet_pton(family, address[0]) + except (OSError, TypeError): + pass + else: + return address + # Special cases to match the stdlib, see gh-277 + if host == "": + host = None + if host == "": + host = "255.255.255.255" + flags = 0 + if local: + flags |= _stdlib_socket.AI_PASSIVE + # Since we always pass in an explicit family here, AI_ADDRCONFIG + # doesn't add any value -- if we have no ipv6 connectivity and are + # working with an ipv6 socket, then things will break soon enough! And + # if we do enable it, then it makes it impossible to even run tests + # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has + # no ipv6. 
+ # flags |= AI_ADDRCONFIG + if family == _stdlib_socket.AF_INET6 and not ipv6_v6only: + flags |= _stdlib_socket.AI_V4MAPPED + gai_res = await getaddrinfo(host, port, family, type, proto, flags) + # AFAICT from the spec it's not possible for getaddrinfo to return an + # empty list. + assert len(gai_res) >= 1 + # Address is the last item in the first entry + (*_, normed), *_ = gai_res + # The above ignored any flowid and scopeid in the passed-in address, + # so restore them if present: + if family == _stdlib_socket.AF_INET6: + normed = list(normed) + assert len(normed) == 4 + if len(address) >= 3: + normed[2] = address[2] + if len(address) >= 4: + normed[3] = address[3] + normed = tuple(normed) + return normed + + class SocketType: def __init__(self): raise TypeError( @@ -444,7 +521,7 @@ def close(self): self._sock.close() async def bind(self, address): - address = await self._resolve_local_address_nocp(address) + address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") and self.family == _stdlib_socket.AF_UNIX @@ -480,89 +557,21 @@ def is_readable(self): async def wait_writable(self): await _core.wait_writable(self._sock) - ################################################################ - # Address handling - ################################################################ - - # Take an address in Python's representation, and returns a new address in - # the same representation, but with names resolved to numbers, - # etc. 
- # - # NOTE: this function does not always checkpoint - async def _resolve_address_nocp(self, address, flags): - # Do some pre-checking (or exit early for non-IP sockets) - if self._sock.family == _stdlib_socket.AF_INET: - if not isinstance(address, tuple) or not len(address) == 2: - raise ValueError("address should be a (host, port) tuple") - elif self._sock.family == _stdlib_socket.AF_INET6: - if not isinstance(address, tuple) or not 2 <= len(address) <= 4: - raise ValueError( - "address should be a (host, port, [flowinfo, [scopeid]]) tuple" - ) - elif self._sock.family == _stdlib_socket.AF_UNIX: - # unwrap path-likes - return os.fspath(address) + async def _resolve_address_nocp(self, address, *, local): + if self.family == _stdlib_socket.AF_INET6: + ipv6_v6only = self._sock.getsockopt( + IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY + ) else: - return address - - # -- From here on we know we have IPv4 or IPV6 -- - host, port, *_ = address - # Fast path for the simple case: already-resolved IP address, - # already-resolved port. This is particularly important for UDP, since - # every sendto call goes through here. - if isinstance(port, int): - try: - _stdlib_socket.inet_pton(self._sock.family, address[0]) - except (OSError, TypeError): - pass - else: - return address - # Special cases to match the stdlib, see gh-277 - if host == "": - host = None - if host == "": - host = "255.255.255.255" - # Since we always pass in an explicit family here, AI_ADDRCONFIG - # doesn't add any value -- if we have no ipv6 connectivity and are - # working with an ipv6 socket, then things will break soon enough! And - # if we do enable it, then it makes it impossible to even run tests - # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has - # no ipv6. 
- # flags |= AI_ADDRCONFIG - if self._sock.family == _stdlib_socket.AF_INET6: - if not self._sock.getsockopt(IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY): - flags |= _stdlib_socket.AI_V4MAPPED - gai_res = await getaddrinfo( - host, port, self._sock.family, self.type, self._sock.proto, flags + ipv6_v6only = False + return await _resolve_address_nocp( + self.type, + self.family, + self.proto, + ipv6_v6only=ipv6_v6only, + address=address, + local=local, ) - # AFAICT from the spec it's not possible for getaddrinfo to return an - # empty list. - assert len(gai_res) >= 1 - # Address is the last item in the first entry - (*_, normed), *_ = gai_res - # The above ignored any flowid and scopeid in the passed-in address, - # so restore them if present: - if self._sock.family == _stdlib_socket.AF_INET6: - normed = list(normed) - assert len(normed) == 4 - if len(address) >= 3: - normed[2] = address[2] - if len(address) >= 4: - normed[3] = address[3] - normed = tuple(normed) - return normed - - # Returns something appropriate to pass to bind() - # - # NOTE: this function does not always checkpoint - async def _resolve_local_address_nocp(self, address): - return await self._resolve_address_nocp(address, _stdlib_socket.AI_PASSIVE) - - # Returns something appropriate to pass to connect()/sendto()/sendmsg() - # - # NOTE: this function does not always checkpoint - async def _resolve_remote_address_nocp(self, address): - return await self._resolve_address_nocp(address, 0) async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): # We have to reconcile two conflicting goals: @@ -617,7 +626,7 @@ async def connect(self, address): # notification. This means it isn't really cancellable... we close the # socket if cancelled, to avoid confusion. 
try: - address = await self._resolve_remote_address_nocp(address) + address = await self._resolve_address_nocp(address, local=False) async with _try_sync(): # An interesting puzzle: can a non-blocking connect() return EINTR # (= raise InterruptedError)? PEP 475 specifically left this as @@ -741,7 +750,7 @@ async def sendto(self, *args): # args is: data[, flags], address) # and kwargs are not accepted args = list(args) - args[-1] = await self._resolve_remote_address_nocp(args[-1]) + args[-1] = await self._resolve_address_nocp(args[-1], local=False) return await self._nonblocking_helper( _stdlib_socket.socket.sendto, args, {}, _core.wait_writable ) @@ -766,7 +775,7 @@ async def sendmsg(self, *args): # and kwargs are not accepted if len(args) == 4 and args[-1] is not None: args = list(args) - args[-1] = await self._resolve_remote_address_nocp(args[-1]) + args[-1] = await self._resolve_address_nocp(args[-1], local=False) return await self._nonblocking_helper( _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable ) diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py new file mode 100644 index 0000000000..f0ea927734 --- /dev/null +++ b/trio/testing/_fake_net.py @@ -0,0 +1,400 @@ +# This should eventually be cleaned up and become public, but for right now I'm just +# implementing enough to test DTLS. 
+ +# TODO: +# - user-defined routers +# - TCP +# - UDP broadcast + +import trio +import attr +import ipaddress +from collections import deque +import errno +import os +from typing import Union, List, Optional +import enum +from contextlib import contextmanager + +from trio._util import Final, NoPublicConstructor + +IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + + +def _family_for(ip: IPAddress) -> int: + if isinstance(ip, ipaddress.IPv4Address): + return trio.socket.AF_INET + elif isinstance(ip, ipaddress.IPv6Address): + return trio.socket.AF_INET6 + assert False # pragma: no cover + + +def _wildcard_ip_for(family: int) -> IPAddress: + if family == trio.socket.AF_INET: + return ipaddress.ip_address("0.0.0.0") + elif family == trio.socket.AF_INET6: + return ipaddress.ip_address("::") + else: + assert False + + +def _localhost_ip_for(family: int) -> IPAddress: + if family == trio.socket.AF_INET: + return ipaddress.ip_address("127.0.0.1") + elif family == trio.socket.AF_INET6: + return ipaddress.ip_address("::1") + else: + assert False + + +def _fake_err(code): + raise OSError(code, os.strerror(code)) + + +def _scatter(data, buffers): + written = 0 + for buf in buffers: + next_piece = data[written : written + len(buf)] + with memoryview(buf) as mbuf: + mbuf[: len(next_piece)] = next_piece + written += len(next_piece) + if written == len(data): + break + return written + + +@attr.frozen +class UDPEndpoint: + ip: IPAddress + port: int + + def as_python_sockaddr(self): + sockaddr = (self.ip.compressed, self.port) + if isinstance(self.ip, ipaddress.IPv6Address): + sockaddr += (0, 0) + return sockaddr + + @classmethod + def from_python_sockaddr(cls, sockaddr): + ip, port = sockaddr[:2] + return cls(ip=ipaddress.ip_address(ip), port=port) + + +@attr.frozen +class UDPBinding: + local: UDPEndpoint + + +@attr.frozen +class UDPPacket: + source: UDPEndpoint + destination: UDPEndpoint + payload: bytes = attr.ib(repr=lambda p: p.hex()) + + def reply(self, 
payload):
+        return UDPPacket(
+            source=self.destination, destination=self.source, payload=payload
+        )
+
+
+@attr.frozen
+class FakeSocketFactory(trio.abc.SocketFactory):
+    fake_net: "FakeNet"
+
+    def socket(self, family: int, type: int, proto: int) -> "FakeSocket":
+        return FakeSocket._create(self.fake_net, family, type, proto)
+
+
+@attr.frozen
+class FakeHostnameResolver(trio.abc.HostnameResolver):
+    fake_net: "FakeNet"
+
+    async def getaddrinfo(
+        self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0
+    ):
+        raise NotImplementedError("FakeNet doesn't do fake DNS yet")
+
+    async def getnameinfo(self, sockaddr, flags: int):
+        raise NotImplementedError("FakeNet doesn't do fake DNS yet")
+
+
+class FakeNet(metaclass=Final):
+    def __init__(self):
+        # When we need to pick an arbitrary unique ip address/port, use these:
+        self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts()
+        self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts()
+        self._auto_port_iter = iter(range(50000, 65535))
+
+        self._bound = {}  # Dict[UDPBinding, FakeSocket]; "Dict" isn't imported
+
+        self.route_packet = None
+
+    def _bind(self, binding: UDPBinding, socket: "FakeSocket") -> None:
+        if binding in self._bound:
+            _fake_err(errno.EADDRINUSE)
+        self._bound[binding] = socket
+
+    def enable(self) -> None:
+        trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
+        trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))
+
+    def send_packet(self, packet) -> None:
+        if self.route_packet is None:
+            self.deliver_packet(packet)
+        else:
+            self.route_packet(packet)
+
+    def deliver_packet(self, packet) -> None:
+        binding = UDPBinding(local=packet.destination)
+        if binding in self._bound:
+            self._bound[binding]._deliver_packet(packet)
+        else:
+            # No valid destination, so drop it
+            pass
+
+
+class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
+    def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int):
+        self._fake_net = fake_net
+
+        if not
family:
+            family = trio.socket.AF_INET
+        if not type:
+            type = trio.socket.SOCK_STREAM
+
+        if family not in (trio.socket.AF_INET, trio.socket.AF_INET6):
+            raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}")
+        if type != trio.socket.SOCK_DGRAM:
+            raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")
+
+        self.family = family
+        self.type = type
+        self.proto = proto
+
+        self._closed = False
+
+        self._packet_sender, self._packet_receiver = trio.open_memory_channel(
+            float("inf")
+        )
+
+        # This is the source-of-truth for what port etc. this socket is bound to
+        self._binding: Optional[UDPBinding] = None
+
+    def _check_closed(self):
+        if self._closed:
+            _fake_err(errno.EBADF)
+
+    def close(self):
+        # breakpoint()
+        if self._closed:
+            return
+        self._closed = True
+        if self._binding is not None:
+            del self._fake_net._bound[self._binding]
+        self._packet_receiver.close()
+
+    async def _resolve_address_nocp(self, address, *, local):
+        return await trio._socket._resolve_address_nocp(
+            self.type,
+            self.family,
+            self.proto,
+            address=address,
+            ipv6_v6only=False,
+            local=local,
+        )
+
+    def _deliver_packet(self, packet: UDPPacket):
+        try:
+            self._packet_sender.send_nowait(packet)
+        except trio.BrokenResourceError:
+            # sending to a closed socket -- UDP packets get dropped
+            pass
+
+    ################################################################
+    # Actual IO operation implementations
+    ################################################################
+
+    async def bind(self, addr):
+        self._check_closed()
+        if self._binding is not None:
+            _fake_err(errno.EINVAL)
+        await trio.lowlevel.checkpoint()
+        ip_str, port = await self._resolve_address_nocp(addr, local=True)
+        ip = ipaddress.ip_address(ip_str)
+        assert _family_for(ip) == self.family
+        # We convert binds to INET_ANY into binds to localhost
+        if ip == ipaddress.ip_address("0.0.0.0"):
+            ip = ipaddress.ip_address("127.0.0.1")
+        elif ip == ipaddress.ip_address("::"):
+            ip =
ipaddress.ip_address("::1") + if port == 0: + port = next(self._fake_net._auto_port_iter) + binding = UDPBinding(local=UDPEndpoint(ip, port)) + self._fake_net._bind(binding, self) + self._binding = binding + + async def connect(self, peer): + raise NotImplementedError("FakeNet does not (yet) support connected sockets") + + async def sendmsg(self, *args): + self._check_closed() + ancdata = [] + flags = 0 + address = None + if len(args) == 1: + (buffers,) = args + elif len(args) == 2: + buffers, address = args + elif len(args) == 3: + buffers, flags, address = args + elif len(args) == 4: + buffers, ancdata, flags, address = args + else: + raise TypeError("wrong number of arguments") + + await trio.lowlevel.checkpoint() + + if address is not None: + address = await self._resolve_address_nocp(address, local=False) + if ancdata: + raise NotImplementedError("FakeNet doesn't support ancillary data") + if flags: + raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}") + + if address is None: + _fake_err(errno.ENOTCONN) + + destination = UDPEndpoint.from_python_sockaddr(address) + + if self._binding is None: + await self.bind((_wildcard_ip_for(self.family).compressed, 0)) + + payload = b"".join(buffers) + + packet = UDPPacket( + source=self._binding.local, + destination=destination, + payload=payload, + ) + + self._fake_net.send_packet(packet) + + return len(payload) + + async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): + if ancbufsize != 0: + raise NotImplementedError("FakeNet doesn't support ancillary data") + if flags != 0: + raise NotImplementedError("FakeNet doesn't support any recv flags") + + self._check_closed() + + ancdata = [] + msg_flags = 0 + + packet = await self._packet_receiver.receive() + address = packet.source.as_python_sockaddr() + written = _scatter(packet.payload, buffers) + if written < len(packet.payload): + msg_flags |= trio.socket.MSG_TRUNC + return written, ancdata, msg_flags, address + + 
################################################################ + # Simple state query stuff + ################################################################ + + def getsockname(self): + self._check_closed() + if self._binding is not None: + return self._binding.local.as_python_sockaddr() + elif self.family == trio.socket.AF_INET: + return ("0.0.0.0", 0) + else: + assert self.family == trio.socket.AF_INET6 + return ("::", 0) + + def getpeername(self): + self._check_closed() + if self._binding is not None: + if self._binding.remote is not None: + return self._binding.remote.as_python_sockaddr() + _fake_err(errno.ENOTCONN) + + def getsockopt(self, level, item): + self._check_closed() + raise OSError(f"FakeNet doesn't implement getsockopt({level}, {item})") + + def setsockopt(self, level, item, value): + self._check_closed() + + if (level, item) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY): + if not value: + raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True") + + raise OSError(f"FakeNet doesn't implement setsockopt({level}, {item}, ...)") + + ################################################################ + # Various boilerplate and trivial stubs + ################################################################ + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.close() + + async def send(self, data, flags=0): + return await self.sendto(data, flags, None) + + async def sendto(self, *args): + if len(args) == 2: + data, address = args + flags = 0 + elif len(args) == 3: + data, flags, address = args + else: + raise TypeError("wrong number of arguments") + return await self.sendmsg([data], [], flags, address) + + async def recv(self, bufsize, flags=0): + data, address = await self.recvfrom(bufsize, flags) + return data + + async def recv_into(self, buf, nbytes=0, flags=0): + got_bytes, address = await self.recvfrom_into(buf, nbytes, flags) + return got_bytes + + async def recvfrom(self, bufsize, flags=0): + data, 
ancdata, msg_flags, address = await self.recvmsg(bufsize, flags) + return data, address + + async def recvfrom_into(self, buf, nbytes=0, flags=0): + if nbytes != 0 and nbytes != len(buf): + raise NotImplementedError("partial recvfrom_into") + got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + [buf], 0, flags + ) + return got_nbytes, address + + async def recvmsg(self, bufsize, ancbufsize=0, flags=0): + buf = bytearray(bufsize) + got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + [buf], ancbufsize, flags + ) + return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) + + def fileno(self): + raise NotImplementedError("can't get fileno() for FakeNet sockets") + + def detach(self): + raise NotImplementedError("can't detach() a FakeNet socket") + + def get_inheritable(self): + return False + + def set_inheritable(self, inheritable): + if inheritable: + raise NotImplementedError("FakeNet can't make inheritable sockets") + + def share(self, process_id): + raise NotImplementedError("FakeNet can't share sockets") diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py new file mode 100644 index 0000000000..8968d9a601 --- /dev/null +++ b/trio/tests/test_dtls.py @@ -0,0 +1,867 @@ +import pytest +import trio +import trio.testing +from trio import DTLSEndpoint +import random +import attr +from async_generator import asynccontextmanager +from itertools import count + +import trustme +from OpenSSL import SSL + +from trio.testing._fake_net import FakeNet +from .._core.tests.tutil import slow, binds_ipv6, gc_collect_harder + +ca = trustme.CA() +server_cert = ca.issue_cert("example.com") + +server_ctx = SSL.Context(SSL.DTLS_METHOD) +server_cert.configure_cert(server_ctx) + +client_ctx = SSL.Context(SSL.DTLS_METHOD) +ca.configure_trust(client_ctx) + + +parametrize_ipv6 = pytest.mark.parametrize( + "ipv6", [False, pytest.param(True, marks=binds_ipv6)], ids=["ipv4", "ipv6"] +) + + +def endpoint(**kwargs): + ipv6 = kwargs.pop("ipv6", False) 
+ if ipv6: + family = trio.socket.AF_INET6 + else: + family = trio.socket.AF_INET + sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) + return DTLSEndpoint(sock, **kwargs) + + +@asynccontextmanager +async def dtls_echo_server(*, autocancel=True, mtu=None, ipv6=False): + with endpoint(ipv6=ipv6) as server: + if ipv6: + localhost = "::1" + else: + localhost = "127.0.0.1" + await server.socket.bind((localhost, 0)) + async with trio.open_nursery() as nursery: + + async def echo_handler(dtls_channel): + print( + f"echo handler started: " + f"server {dtls_channel.endpoint.socket.getsockname()} " + f"client {dtls_channel.peer_address}" + ) + if mtu is not None: + dtls_channel.set_ciphertext_mtu(mtu) + try: + print("server starting do_handshake") + await dtls_channel.do_handshake() + print("server finished do_handshake") + async for packet in dtls_channel: + print(f"echoing {packet} -> {dtls_channel.peer_address}") + await dtls_channel.send(packet) + except trio.BrokenResourceError: # pragma: no cover + print("echo handler channel broken") + + await nursery.start(server.serve, server_ctx, echo_handler) + + yield server, server.socket.getsockname() + + if autocancel: + nursery.cancel_scope.cancel() + + +@parametrize_ipv6 +async def test_smoke(ipv6): + async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address): + with endpoint(ipv6=ipv6) as client_endpoint: + client_channel = client_endpoint.connect(address, client_ctx) + with pytest.raises(trio.NeedHandshakeError): + client_channel.get_cleartext_mtu() + + await client_channel.do_handshake() + await client_channel.send(b"hello") + assert await client_channel.receive() == b"hello" + await client_channel.send(b"goodbye") + assert await client_channel.receive() == b"goodbye" + + with pytest.raises(ValueError): + await client_channel.send(b"") + + client_channel.set_ciphertext_mtu(1234) + cleartext_mtu_1234 = client_channel.get_cleartext_mtu() + client_channel.set_ciphertext_mtu(4321) + assert 
client_channel.get_cleartext_mtu() > cleartext_mtu_1234 + client_channel.set_ciphertext_mtu(1234) + assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234 + + +@slow +async def test_handshake_over_terrible_network(autojump_clock): + HANDSHAKES = 1000 + r = random.Random(0) + fn = FakeNet() + fn.enable() + + async with dtls_echo_server() as (_, address): + async with trio.open_nursery() as nursery: + + async def route_packet(packet): + while True: + op = r.choices( + ["deliver", "drop", "dupe", "delay"], + weights=[0.7, 0.1, 0.1, 0.1], + )[0] + print(f"{packet.source} -> {packet.destination}: {op}") + if op == "drop": + return + elif op == "dupe": + fn.send_packet(packet) + elif op == "delay": + await trio.sleep(r.random() * 3) + # I wanted to test random packet corruption too, but it turns out + # openssl has a bug in the following scenario: + # + # - client sends ClientHello + # - server sends HelloVerifyRequest with cookie -- but cookie is + # invalid b/c either the ClientHello or HelloVerifyRequest was + # corrupted + # - client re-sends ClientHello with invalid cookie + # - server replies with new HelloVerifyRequest and correct cookie + # + # At this point, the client *should* switch to the new, valid + # cookie. But OpenSSL doesn't; it stubbornly insists on re-sending + # the original, invalid cookie over and over. In theory we could + # work around this by detecting cookie changes and starting over + # with a whole new SSL object, but (a) it doesn't seem worth it, (b) + # when I tried then I ran into another issue where OpenSSL got stuck + # in an infinite loop sending alerts over and over, which I didn't + # dig into because see (a). 
+ # + # elif op == "distort": + # payload = bytearray(packet.payload) + # payload[r.randrange(len(payload))] ^= 1 << r.randrange(8) + # packet = attr.evolve(packet, payload=payload) + else: + assert op == "deliver" + print( + f"{packet.source} -> {packet.destination}: delivered {packet.payload.hex()}" + ) + fn.deliver_packet(packet) + break + + def route_packet_wrapper(packet): + try: + nursery.start_soon(route_packet, packet) + except RuntimeError: # pragma: no cover + # We're exiting the nursery, so any remaining packets can just get + # dropped + pass + + fn.route_packet = route_packet_wrapper + + for i in range(HANDSHAKES): + print("#" * 80) + print("#" * 80) + print("#" * 80) + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + print("client starting do_handshake") + await client.do_handshake() + print("client finished do_handshake") + msg = str(i).encode() + # Make multiple attempts to send data, because the network might + # drop it + while True: + with trio.move_on_after(10) as cscope: + await client.send(msg) + assert await client.receive() == msg + if not cscope.cancelled_caught: + break + + +async def test_implicit_handshake(): + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + + # Implicit handshake + await client.send(b"xyz") + assert await client.receive() == b"xyz" + + +async def test_full_duplex(): + # Tests simultaneous send/receive, and also multiple methods implicitly invoking + # do_handshake simultaneously. 
with endpoint() as server_endpoint, endpoint() as client_endpoint: + await server_endpoint.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as server_nursery: + + async def handler(channel): + async with trio.open_nursery() as nursery: + nursery.start_soon(channel.send, b"from server") + nursery.start_soon(channel.receive) + + await server_nursery.start(server_endpoint.serve, server_ctx, handler) + + client = client_endpoint.connect( + server_endpoint.socket.getsockname(), client_ctx + ) + async with trio.open_nursery() as nursery: + nursery.start_soon(client.send, b"from client") + nursery.start_soon(client.receive) + + server_nursery.cancel_scope.cancel() + + +async def test_channel_closing(): + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake() + client.close() + + with pytest.raises(trio.ClosedResourceError): + await client.send(b"abc") + with pytest.raises(trio.ClosedResourceError): + await client.receive() + + # close is idempotent + client.close() + # can also aclose + await client.aclose() + + +async def test_serve_exits_cleanly_on_close(): + async with dtls_echo_server(autocancel=False) as (server_endpoint, address): + server_endpoint.close() + # Testing that the nursery exits even without being cancelled + # close is idempotent + server_endpoint.close() + + +async def test_client_multiplex(): + async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2): + with endpoint() as client_endpoint: + client1 = client_endpoint.connect(address1, client_ctx) + client2 = client_endpoint.connect(address2, client_ctx) + + await client1.send(b"abc") + await client2.send(b"xyz") + assert await client2.receive() == b"xyz" + assert await client1.receive() == b"abc" + + client_endpoint.close() + + with pytest.raises(trio.ClosedResourceError): + await client1.send(b"xxx") + with 
pytest.raises(trio.ClosedResourceError): + await client2.receive() + with pytest.raises(trio.ClosedResourceError): + client_endpoint.connect(address1, client_ctx) + + async with trio.open_nursery() as nursery: + with pytest.raises(trio.ClosedResourceError): + + async def null_handler(_): # pragma: no cover + pass + + await nursery.start(client_endpoint.serve, server_ctx, null_handler) + + +async def test_dtls_over_dgram_only(): + with trio.socket.socket() as s: + with pytest.raises(ValueError): + DTLSEndpoint(s) + + +async def test_double_serve(): + async def null_handler(_): # pragma: no cover + pass + + with endpoint() as server_endpoint: + await server_endpoint.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as nursery: + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + with pytest.raises(trio.BusyResourceError): + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + + nursery.cancel_scope.cancel() + + async with trio.open_nursery() as nursery: + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + nursery.cancel_scope.cancel() + + +async def test_connect_to_non_server(autojump_clock): + fn = FakeNet() + fn.enable() + with endpoint() as client1, endpoint() as client2: + await client1.socket.bind(("127.0.0.1", 0)) + # This should just time out + with trio.move_on_after(100) as cscope: + channel = client2.connect(client1.socket.getsockname(), client_ctx) + await channel.do_handshake() + assert cscope.cancelled_caught + + +async def test_incoming_buffer_overflow(autojump_clock): + fn = FakeNet() + fn.enable() + for buffer_size in [10, 20]: + async with dtls_echo_server() as (_, address): + with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint: + assert client_endpoint.incoming_packets_buffer == buffer_size + client = client_endpoint.connect(address, client_ctx) + for i in range(buffer_size + 15): + await client.send(str(i).encode()) + await trio.sleep(1) + stats = 
client.statistics() + assert stats.incoming_packets_dropped_in_trio == 15 + for i in range(buffer_size): + assert await client.receive() == str(i).encode() + await client.send(b"buffer clear now") + assert await client.receive() == b"buffer clear now" + + +async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): + fn = FakeNet() + fn.enable() + + from trio._dtls import ( + Record, + encode_record, + HandshakeFragment, + encode_handshake_fragment, + ContentType, + HandshakeType, + ProtocolVersion, + ) + + client_hello = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=10, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ), + ) + ) + + client_hello_extended = client_hello + b"\x00" + client_hello_short = client_hello[:-1] + # cuts off in middle of handshake message header + client_hello_really_short = client_hello[:14] + client_hello_corrupt_record_len = bytearray(client_hello) + client_hello_corrupt_record_len[11] = 0xFF + + client_hello_fragmented = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=20, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ), + ) + ) + + client_hello_trailing_data_in_record = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=20, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ) + + b"\x00", + ) + ) + + handshake_empty = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=b"", + ) + ) + + 
client_hello_truncated_in_cookie = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=bytes(2 + 32 + 1) + b"\xff", + ) + ) + + async with dtls_echo_server() as (_, address): + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock: + for bad_packet in [ + b"", + b"xyz", + client_hello_extended, + client_hello_short, + client_hello_really_short, + client_hello_corrupt_record_len, + client_hello_fragmented, + client_hello_trailing_data_in_record, + handshake_empty, + client_hello_truncated_in_cookie, + ]: + await sock.sendto(bad_packet, address) + await trio.sleep(1) + + +async def test_invalid_cookie_rejected(autojump_clock): + fn = FakeNet() + fn.enable() + + from trio._dtls import decode_client_hello_untrusted, BadPacket + + with trio.CancelScope() as cscope: + + # the first 11 bytes of ClientHello aren't protected by the cookie, so only test + # corrupting bytes after that. + offset_to_corrupt = count(11) + + def route_packet(packet): + try: + _, cookie, _ = decode_client_hello_untrusted(packet.payload) + except BadPacket: + pass + else: + if len(cookie) != 0: + # this is a challenge response packet + # let's corrupt the next offset so the handshake should fail + payload = bytearray(packet.payload) + offset = next(offset_to_corrupt) + if offset >= len(payload): + # We've tried all offsets. Clamp offset to the end of the + # payload, and terminate the test. 
+ offset = len(payload) - 1 + cscope.cancel() + payload[offset] ^= 0x01 + packet = attr.evolve(packet, payload=payload) + + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + while True: + with endpoint() as client: + channel = client.connect(address, client_ctx) + await channel.do_handshake() + assert cscope.cancelled_caught + + +async def test_client_cancels_handshake_and_starts_new_one(autojump_clock): + # if a client disappears during the handshake, and then starts a new handshake from + # scratch, then the first handler's channel should fail, and a new handler get + # started + fn = FakeNet() + fn.enable() + + with endpoint() as server, endpoint() as client: + await server.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as nursery: + first_time = True + + async def handler(channel): + nonlocal first_time + if first_time: + first_time = False + print("handler: first time, cancelling connect") + connect_cscope.cancel() + await trio.sleep(0.5) + print("handler: handshake should fail now") + with pytest.raises(trio.BrokenResourceError): + await channel.do_handshake() + else: + print("handler: not first time, sending hello") + await channel.send(b"hello") + + await nursery.start(server.serve, server_ctx, handler) + + print("client: starting first connect") + with trio.CancelScope() as connect_cscope: + channel = client.connect(server.socket.getsockname(), client_ctx) + await channel.do_handshake() + assert connect_cscope.cancelled_caught + + print("client: starting second connect") + channel = client.connect(server.socket.getsockname(), client_ctx) + assert await channel.receive() == b"hello" + + # Give handlers a chance to finish + await trio.sleep(10) + nursery.cancel_scope.cancel() + + +async def test_swap_client_server(): + with endpoint() as a, endpoint() as b: + await a.socket.bind(("127.0.0.1", 0)) + await b.socket.bind(("127.0.0.1", 0)) + + async def echo_handler(channel): + async 
for packet in channel: + await channel.send(packet) + + async def crashing_echo_handler(channel): + with pytest.raises(trio.BrokenResourceError): + await echo_handler(channel) + + async with trio.open_nursery() as nursery: + await nursery.start(a.serve, server_ctx, crashing_echo_handler) + await nursery.start(b.serve, server_ctx, echo_handler) + + b_to_a = b.connect(a.socket.getsockname(), client_ctx) + await b_to_a.send(b"b as client") + assert await b_to_a.receive() == b"b as client" + + a_to_b = a.connect(b.socket.getsockname(), client_ctx) + await a_to_b.do_handshake() + with pytest.raises(trio.BrokenResourceError): + await b_to_a.send(b"association broken") + await a_to_b.send(b"a as client") + assert await a_to_b.receive() == b"a as client" + + nursery.cancel_scope.cancel() + + +@slow +async def test_openssl_retransmit_doesnt_break_stuff(): + # can't use autojump_clock here, because the point of the test is to wait for + # openssl's built-in retransmit timer to expire, which is hard-coded to use + # wall-clock time. 
+ fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + if blackholed: + print("dropped packet", packet) + return + print("delivered packet", packet) + # packets.append( + # scapy.all.IP( + # src=packet.source.ip.compressed, dst=packet.destination.ip.compressed + # ) + # / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port) + # / packet.payload + # ) + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (server_endpoint, address): + with endpoint() as client_endpoint: + async with trio.open_nursery() as nursery: + + async def connecter(): + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake(initial_retransmit_timeout=1.5) + await client.send(b"hi") + assert await client.receive() == b"hi" + + nursery.start_soon(connecter) + + # openssl's default timeout is 1 second, so this ensures that it thinks + # the timeout has expired + await trio.sleep(1.1) + # disable blackholing and send a garbage packet to wake up openssl so it + # notices the timeout has expired + blackholed = False + await server_endpoint.socket.sendto( + b"xxx", client_endpoint.socket.getsockname() + ) + # now the client task should finish connecting and exit cleanly + + # scapy.all.wrpcap("/tmp/trace.pcap", packets) + + +async def test_initial_retransmit_timeout_configuration(autojump_clock): + fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + nonlocal blackholed + if blackholed: + blackholed = False + else: + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + for t in [1, 2, 4]: + with endpoint() as client: + before = trio.current_time() + blackholed = True + channel = client.connect(address, client_ctx) + await channel.do_handshake(initial_retransmit_timeout=t) + after = trio.current_time() + assert after - before == t + + +async def test_explicit_tiny_mtu_is_respected(): + # 
ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to + # be larger than that. (300 is still smaller than any real network though.) + MTU = 300 + + fn = FakeNet() + fn.enable() + + def route_packet(packet): + print(f"delivering {packet}") + print(f"payload size: {len(packet.payload)}") + assert len(packet.payload) <= MTU + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server(mtu=MTU) as (server, address): + with endpoint() as client: + channel = client.connect(address, client_ctx) + channel.set_ciphertext_mtu(MTU) + await channel.do_handshake() + await channel.send(b"hi") + assert await channel.receive() == b"hi" + + +@parametrize_ipv6 +async def test_handshake_handles_minimum_network_mtu(ipv6, autojump_clock): + # Fake network that has the minimum allowable MTU for whatever protocol we're using. + fn = FakeNet() + fn.enable() + + if ipv6: + mtu = 1280 - 48 + else: + mtu = 576 - 28 + + def route_packet(packet): + if len(packet.payload) > mtu: + print(f"dropping {packet}") + else: + print(f"delivering {packet}") + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + # See if we can successfully do a handshake -- some of the volleys will get dropped, + # and the retransmit logic should detect this and back off the MTU to something + # smaller until it succeeds. 
+ async with dtls_echo_server(ipv6=ipv6) as (_, address): + with endpoint(ipv6=ipv6) as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + # the handshake mtu backoff shouldn't affect the return value from + # get_cleartext_mtu, b/c that's under the user's control via + # set_ciphertext_mtu + client.set_ciphertext_mtu(9999) + await client.send(b"xyz") + assert await client.receive() == b"xyz" + assert client.get_cleartext_mtu() > 9000 # as vegeta said + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_system_task_cleaned_up_on_gc(): + before_tasks = trio.lowlevel.current_statistics().tasks_living + + # We put this into a sub-function so that everything automatically becomes garbage + # when the frame exits. For some reason just doing 'del e' wasn't enough on pypy + # with coverage enabled -- I think we were hitting this bug: + # https://foss.heptapod.net/pypy/pypy/-/issues/3656 + async def start_and_forget_endpoint(): + e = endpoint() + + # This connection/handshake attempt can't succeed. The only purpose is to force + # the endpoint to set up a receive loop. 
+ with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: + await s.bind(("127.0.0.1", 0)) + c = e.connect(s.getsockname(), client_ctx) + async with trio.open_nursery() as nursery: + nursery.start_soon(c.do_handshake) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + during_tasks = trio.lowlevel.current_statistics().tasks_living + return during_tasks + + with pytest.warns(ResourceWarning): + during_tasks = await start_and_forget_endpoint() + await trio.testing.wait_all_tasks_blocked() + gc_collect_harder() + + await trio.testing.wait_all_tasks_blocked() + + after_tasks = trio.lowlevel.current_statistics().tasks_living + assert before_tasks < during_tasks + assert before_tasks == after_tasks + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_gc_before_system_task_starts(): + e = endpoint() + + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + await trio.testing.wait_all_tasks_blocked() + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_gc_as_packet_received(): + fn = FakeNet() + fn.enable() + + e = endpoint() + await e.socket.bind(("127.0.0.1", 0)) + e._ensure_receive_loop() + + await trio.testing.wait_all_tasks_blocked() + + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: + await s.sendto(b"xxx", e.socket.getsockname()) + # At this point, the endpoint's receive loop has been marked runnable because it + # just received a packet; closing the endpoint socket won't interrupt that. But by + # the time it wakes up to process the packet, the endpoint will be gone. + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +def test_gc_after_trio_exits(): + async def main(): + # We use fakenet just to make sure no real sockets can leak out of the test + # case - on pypy somehow the socket was outliving the gc_collect_harder call + # below. 
Since the test is just making sure DTLSEndpoint.__del__ doesn't explode + # when called after trio exits, it doesn't need a real socket. + fn = FakeNet() + fn.enable() + return endpoint() + + e = trio.run(main) + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + +async def test_already_closed_socket_doesnt_crash(): + with endpoint() as e: + # We close the socket before checkpointing, so the socket will already be closed + # when the system task starts up + e.socket.close() + # Now give it a chance to start up, and hopefully not crash + await trio.testing.wait_all_tasks_blocked() + + +async def test_socket_closed_while_processing_clienthello(autojump_clock): + fn = FakeNet() + fn.enable() + + # Check what happens if the socket is discovered to be closed when sending a + # HelloVerifyRequest, since that has its own sending logic + async with dtls_echo_server() as (server, address): + + def route_packet(packet): + fn.deliver_packet(packet) + server.socket.close() + + fn.route_packet = route_packet + + with endpoint() as client_endpoint: + with trio.move_on_after(10): + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake() + + +async def test_association_replaced_while_handshake_running(autojump_clock): + fn = FakeNet() + fn.enable() + + def route_packet(packet): + pass + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + c1 = client_endpoint.connect(address, client_ctx) + async with trio.open_nursery() as nursery: + + async def doomed_handshake(): + with pytest.raises(trio.BrokenResourceError): + await c1.do_handshake() + + nursery.start_soon(doomed_handshake) + + await trio.sleep(10) + + client_endpoint.connect(address, client_ctx) + + +async def test_association_replaced_before_handshake_starts(): + fn = FakeNet() + fn.enable() + + # This test shouldn't send any packets + def route_packet(packet): # pragma: no cover + assert False + + 
fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + c1 = client_endpoint.connect(address, client_ctx) + client_endpoint.connect(address, client_ctx) + with pytest.raises(trio.BrokenResourceError): + await c1.do_handshake() + + +async def test_send_to_closed_local_port(): + # On Windows, sending a UDP packet to a closed local port can cause a weird + # ECONNRESET error later, inside the receive task. Make sure we're handling it + # properly. + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + async with trio.open_nursery() as nursery: + for i in range(1, 10): + channel = client_endpoint.connect(("127.0.0.1", i), client_ctx) + nursery.start_soon(channel.do_handshake) + channel = client_endpoint.connect(address, client_ctx) + await channel.send(b"xxx") + assert await channel.receive() == b"xxx" + nursery.cancel_scope.cancel() diff --git a/trio/tests/test_fakenet.py b/trio/tests/test_fakenet.py new file mode 100644 index 0000000000..bc691c9db5 --- /dev/null +++ b/trio/tests/test_fakenet.py @@ -0,0 +1,44 @@ +import pytest + +import trio +from trio.testing._fake_net import FakeNet + + +def fn(): + fn = FakeNet() + fn.enable() + return fn + + +async def test_basic_udp(): + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + await s1.bind(("127.0.0.1", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + assert data == b"xyz" + assert addr == s2.getsockname() + await s1.sendto(b"abc", s2.getsockname()) + data, addr = await s2.recvfrom(10) + assert data == b"abc" + assert addr == s1.getsockname() + + +async def test_msg_trunc(): + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s1.bind(("127.0.0.1", 0)) + await 
s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + + +async def test_basic_tcp(): + fn() + with pytest.raises(NotImplementedError): + trio.socket.socket() diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index d891041ab2..1fa3721f91 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -495,21 +495,20 @@ def assert_eq(actual, expected): with tsocket.socket(family=socket_type) as sock: # For some reason the stdlib special-cases "" to pass NULL to - # getaddrinfo They also error out on None, but whatever, None is much + # getaddrinfo. They also error out on None, but whatever, None is much # more consistent, so we accept it too. for null in [None, ""]: - got = await sock._resolve_local_address_nocp((null, 80)) + got = await sock._resolve_address_nocp((null, 80), local=True) assert_eq(got, (addrs.bind_all, 80)) - got = await sock._resolve_remote_address_nocp((null, 80)) + got = await sock._resolve_address_nocp((null, 80), local=False) assert_eq(got, (addrs.localhost, 80)) # AI_PASSIVE only affects the wildcard address, so for everything else - # _resolve_local_address_nocp and _resolve_remote_address_nocp should - # work the same: - for resolver in ["_resolve_local_address_nocp", "_resolve_remote_address_nocp"]: + # local=True/local=False should work the same: + for local in [False, True]: async def res(*args): - return await getattr(sock, resolver)(*args) + return await sock._resolve_address_nocp(*args, local=local) assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: @@ -560,7 +559,10 @@ async def res(*args): except (AttributeError, OSError): pass else: - assert await getattr(netlink_sock, resolver)("asdf") == "asdf" + assert ( + await netlink_sock._resolve_address_nocp("asdf", local=local) + == "asdf" + ) netlink_sock.close() with pytest.raises(ValueError): @@ -721,16 +723,16 @@ def connect(self, *args, **kwargs): await sock.connect(("127.0.0.1", 2)) -async def 
test_resolve_remote_address_exception_closes_socket(): +async def test_resolve_address_exception_in_connect_closes_socket(): # Here we are testing issue 247, any cancellation will leave the socket closed with _core.CancelScope() as cancel_scope: with tsocket.socket() as sock: - async def _resolve_remote_address_nocp(self, *args, **kwargs): + async def _resolve_address_nocp(self, *args, **kwargs): cancel_scope.cancel() await _core.checkpoint() - sock._resolve_remote_address_nocp = _resolve_remote_address_nocp + sock._resolve_address_nocp = _resolve_address_nocp with assert_checkpoints(): with pytest.raises(_core.Cancelled): await sock.connect("")