From cd0ec9f43080f1c9f0e5a9df65e8a7a271ff15e6 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Thu, 24 Jun 2021 09:16:43 -0700 Subject: [PATCH 01/47] First draft --- newsfragments/2010.feature.rst | 4 + trio/_dtls.py | 854 +++++++++++++++++++++++++++++++++ 2 files changed, 858 insertions(+) create mode 100644 newsfragments/2010.feature.rst create mode 100644 trio/_dtls.py diff --git a/newsfragments/2010.feature.rst b/newsfragments/2010.feature.rst new file mode 100644 index 0000000000..f99687652f --- /dev/null +++ b/newsfragments/2010.feature.rst @@ -0,0 +1,4 @@ +Added support for `Datagram TLS +`__, +for secure communication over UDP. Currently requires `PyOpenSSL +`__. diff --git a/trio/_dtls.py b/trio/_dtls.py new file mode 100644 index 0000000000..b825fea5d6 --- /dev/null +++ b/trio/_dtls.py @@ -0,0 +1,854 @@ +# https://datatracker.ietf.org/doc/html/rfc6347 + +# XX: figure out what to do about the pyopenssl dependency +# Maybe the toplevel __init__.py should use __getattr__ trickery to load all +# the DTLS code lazily? 

import struct
import hmac
import os
import io
import enum
from itertools import count
import weakref

import attr
from OpenSSL import SSL

import trio
from trio._util import NoPublicConstructor

# Largest possible UDP payload (65535 - 8-byte UDP header).
MAX_UDP_PACKET_SIZE = 65527

# There are a bunch of different RFCs that define these codes, so for a
# comprehensive collection look here:
# https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml
class ContentType(enum.IntEnum):
    change_cipher_spec = 20
    alert = 21
    handshake = 22
    application_data = 23
    heartbeat = 24


class HandshakeType(enum.IntEnum):
    hello_request = 0
    client_hello = 1
    server_hello = 2
    hello_verify_request = 3
    new_session_ticket = 4
    # BUGFIX: end_of_early_data is 5 in the IANA registry (RFC 8446); the
    # draft had 4, which silently made it an alias of new_session_ticket.
    end_of_early_data = 5
    encrypted_extensions = 8
    certificate = 11
    server_key_exchange = 12
    certificate_request = 13
    server_hello_done = 14
    certificate_verify = 15
    client_key_exchange = 16
    finished = 20
    certificate_url = 21
    certificate_status = 22
    supplemental_data = 23
    key_update = 24
    compressed_certificate = 25
    ekt_key = 26
    message_hash = 254


class ProtocolVersion:
    # DTLS versions are the one's-complement of the TLS version they mirror.
    DTLS10 = bytes([254, 255])
    DTLS12 = bytes([254, 253])


# Mask selecting the 2-byte epoch out of the combined 8-byte epoch+seqno field.
EPOCH_MASK = 0xffff << (6 * 8)


# Conventions:
# - All functions that handle network data end in _untrusted.
# - All functions ending in _untrusted MUST make sure that bad data from the
#   network can *only* cause BadPacket to be raised. No IndexError or
#   struct.error or whatever.
class BadPacket(Exception):
    pass


# This checks that the DTLS 'epoch' field is 0, which is true iff we're in the
# initial handshake. It doesn't check the ContentType, because not all
# handshake messages have ContentType==handshake -- for example,
# ChangeCipherSpec is used during the handshake but has its own ContentType.
#
# Cannot fail.
def part_of_handshake_untrusted(packet):
    # If the packet is too short, then slicing will successfully return a
    # short string, which will necessarily fail to match.
    return packet[3:5] == b"\x00\x00"


# Cannot fail.
def is_client_hello_untrusted(packet):
    try:
        return (
            packet[0] == ContentType.handshake
            and packet[13] == HandshakeType.client_hello
        )
    except IndexError:
        # Too short to be a valid DTLS record
        return False


# DTLS records are:
# - 1 byte content type
# - 2 bytes version
# - 8 bytes epoch+seqno
#   Technically this is 2 bytes epoch then 6 bytes seqno, but we treat it as
#   a single 8-byte integer, where epoch changes are represented as jumping
#   forward by 2**(6*8).
# - 2 bytes payload length (unsigned big-endian)
# - payload
RECORD_HEADER = struct.Struct("!B2sQH")


@attr.frozen
class Record:
    content_type: int
    version: bytes
    epoch_seqno: int
    payload: bytes


def records_untrusted(packet):
    """Iterate over the DTLS records packed into one UDP packet.

    Raises BadPacket on truncated/malformed framing.
    """
    i = 0
    while i < len(packet):
        try:
            ct, version, epoch_seqno, payload_len = RECORD_HEADER.unpack_from(packet, i)
        except struct.error as exc:
            raise BadPacket("invalid record header") from exc
        i += RECORD_HEADER.size
        payload = packet[i : i + payload_len]
        if len(payload) != payload_len:
            raise BadPacket("short record")
        i += payload_len
        yield Record(ct, version, epoch_seqno, payload)


def encode_record(record):
    header = RECORD_HEADER.pack(
        record.content_type,
        record.version,
        record.epoch_seqno,
        len(record.payload),
    )
    return header + record.payload


# Handshake messages are:
# - 1 byte message type
# - 3 bytes total message length
# - 2 bytes message sequence number
# - 3 bytes fragment offset
# - 3 bytes fragment length
HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s")


@attr.frozen
class HandshakeFragment:
    msg_type: int
    # BUGFIX: the draft had a stray trailing comma after this annotation,
    # which is a syntax error inside a class body.
    msg_len: int
    msg_seq: int
    frag_offset: int
    frag_len: int
    frag: bytes


def decode_handshake_fragment_untrusted(payload):
    # Raises BadPacket if decoding fails
    try:
        (
            msg_type,
            msg_len_bytes,
            msg_seq,
            frag_offset_bytes,
            frag_len_bytes,
        ) = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload)
    except struct.error as exc:
        raise BadPacket("bad handshake message header") from exc
    # 'struct' doesn't have built-in support for 24-bit integers, so we
    # have to do it by hand. These can't fail.
    msg_len = int.from_bytes(msg_len_bytes, "big")
    frag_offset = int.from_bytes(frag_offset_bytes, "big")
    frag_len = int.from_bytes(frag_len_bytes, "big")
    frag = payload[HANDSHAKE_MESSAGE_HEADER.size:]
    if len(frag) != frag_len:
        raise BadPacket("short fragment")
    return HandshakeFragment(
        msg_type,
        msg_len,
        msg_seq,
        frag_offset,
        frag_len,
        frag,
    )


def encode_handshake_fragment(hsf):
    hs_header = HANDSHAKE_MESSAGE_HEADER.pack(
        hsf.msg_type,
        hsf.msg_len.to_bytes(3, "big"),
        hsf.msg_seq,
        hsf.frag_offset.to_bytes(3, "big"),
        hsf.frag_len.to_bytes(3, "big"),
    )
    # BUGFIX: HandshakeFragment's payload attribute is .frag; the draft
    # referenced a nonexistent .body.
    return hs_header + hsf.frag


def decode_client_hello_untrusted(packet):
    # Raises BadPacket if parsing fails
    # Returns (record epoch_seqno, cookie from the packet, data that should be
    # hashed into cookie)
    try:
        # ClientHello has to be the first record in the packet
        record = next(records_untrusted(packet))
        if record.content_type != ContentType.handshake:
            raise BadPacket("not a handshake record")
        fragment = decode_handshake_fragment_untrusted(record.payload)
        if fragment.msg_type != HandshakeType.client_hello:
            raise BadPacket("not a ClientHello")
        # ClientHello can't be fragmented, because reassembly requires holding
        # per-connection state, and we refuse to allocate per-connection state
        # until after we get a valid ClientHello.
        if fragment.frag_offset != 0:
            raise BadPacket("fragmented ClientHello")
        if fragment.frag_len != fragment.msg_len:
            raise BadPacket("fragmented ClientHello")

        # As per RFC 6347:
        #
        #   When responding to a HelloVerifyRequest, the client MUST use the
        #   same parameter values (version, random, session_id, cipher_suites,
        #   compression_method) as it did in the original ClientHello. The
        #   server SHOULD use those values to generate its cookie and verify
        #   that they are correct upon cookie receipt.
        #
        # However, the record-layer framing can and will change (e.g. the
        # second ClientHello will have a new record-layer sequence number). So
        # we need to pull out the handshake message alone, discarding the
        # record-layer stuff, and then we're going to hash all of it *except*
        # the cookie.

        body = fragment.frag
        # ClientHello is:
        #
        # - 2 bytes client_version
        # - 32 bytes random
        # - 1 byte session_id length
        # - session_id
        # - 1 byte cookie length
        # - cookie
        # - everything else
        #
        # So to find the cookie, we need to figure out how long the
        # session_id is and skip past it.
        session_id_len = body[2 + 32]
        cookie_len_offset = 2 + 32 + 1 + session_id_len
        cookie_len = body[cookie_len_offset]

        cookie_start = cookie_len_offset + 1
        cookie_end = cookie_start + cookie_len

        # BUGFIX: exclude the cookie *length byte* from the signed bits too.
        # It's 0 in the first ClientHello and non-zero in the second, so
        # including it would make the two digests never match.
        before_cookie = body[:cookie_len_offset]
        cookie = body[cookie_start:cookie_end]
        after_cookie = body[cookie_end:]

        if len(cookie) != cookie_len:
            raise BadPacket("short cookie")
        return (record.epoch_seqno, cookie, before_cookie + after_cookie)

    except (struct.error, IndexError) as exc:
        raise BadPacket("bad ClientHello") from exc


@attr.frozen
class HandshakeMessage:
    record_version: bytes
    msg_type: HandshakeType
    msg_seq: int
    body: bytearray


# ChangeCipherSpec is part of the handshake, but it's not a "handshake
# message" and can't be fragmented the same way. Sigh.
@attr.frozen
class PseudoHandshakeMessage:
    record_version: bytes
    content_type: int
    payload: bytes


# This takes a raw outgoing handshake volley that openssl generated, and
# reconstructs the handshake messages inside it, so that we can repack them
# into records while retransmitting. So the data ought to be well-behaved --
# it's not coming from the network.
# The final record in a handshake is Finished, which is encrypted, can't be
# re-fragmented by us, and keeps its record number (because it's in a new
# epoch). So we just pass the whole record through unchanged on retransmit.
# (Fortunately its payload is a single hash value, small enough to never need
# fragmenting to fit in a UDP packet.)
@attr.frozen
class OpaqueHandshakeMessage:
    record: Record


def decode_volley_trusted(volley):
    messages = []
    messages_by_seq = {}
    for record in records_untrusted(volley):
        # BUGFIX: records with epoch > 0 are encrypted, so their payloads
        # can't be parsed as handshake fragments -- pass them through opaque.
        if record.epoch_seqno & EPOCH_MASK:
            messages.append(OpaqueHandshakeMessage(record))
        elif record.content_type == ContentType.change_cipher_spec:
            messages.append(
                PseudoHandshakeMessage(
                    record.version, record.content_type, record.payload
                )
            )
        else:
            assert record.content_type == ContentType.handshake
            fragment = decode_handshake_fragment_untrusted(record.payload)
            msg_type = HandshakeType(fragment.msg_type)
            if fragment.msg_seq not in messages_by_seq:
                msg = HandshakeMessage(
                    record.version,
                    msg_type,
                    fragment.msg_seq,
                    bytearray(fragment.msg_len),
                )
                messages.append(msg)
                messages_by_seq[fragment.msg_seq] = msg
            else:
                # BUGFIX: was messages_by_seq[msg_seq] (undefined name)
                msg = messages_by_seq[fragment.msg_seq]
                assert msg.msg_type == fragment.msg_type
                assert msg.msg_seq == fragment.msg_seq
                assert len(msg.body) == fragment.msg_len

            msg.body[
                fragment.frag_offset : fragment.frag_offset + fragment.frag_len
            ] = fragment.frag

    return messages


class RecordEncoder:
    """Assigns record-layer sequence numbers and packs messages into packets."""

    # BUGFIX: call sites construct RecordEncoder() with no argument, so the
    # starting sequence number defaults to 0.
    def __init__(self, first_record_seq=0):
        self._record_seq = count(first_record_seq)

    def skip_first_record_number(self):
        # Server side: record 0 was consumed by the HelloVerifyRequest, so the
        # handshake proper starts at record 1.
        assert next(self._record_seq) == 0

    def encode_volley(self, messages, mtu):
        packets = []
        packet = bytearray()
        for message in messages:
            if isinstance(message, OpaqueHandshakeMessage):
                encoded = encode_record(message.record)
                if mtu - len(packet) - len(encoded) <= 0:
                    packets.append(packet)
                    packet = bytearray()
                packet += encoded
                assert len(packet) <= mtu
            elif isinstance(message, PseudoHandshakeMessage):
                space = mtu - len(packet) - RECORD_HEADER.size - len(message.payload)
                if space <= 0:
                    packets.append(packet)
                    packet = bytearray()
                packet += RECORD_HEADER.pack(
                    message.content_type,
                    message.record_version,
                    next(self._record_seq),
                    len(message.payload),
                )
                # BUGFIX: was 'packet += payload' (undefined name)
                packet += message.payload
                assert len(packet) <= mtu
            else:
                # BUGFIX: HandshakeMessage has no .msg_len attribute; the
                # total length is just len(message.body).
                msg_len_bytes = len(message.body).to_bytes(3, "big")
                frag_offset = 0
                frags_encoded = 0
                # Even an empty body must be encoded as one fragment, not zero.
                while frag_offset < len(message.body) or not frags_encoded:
                    space = (
                        mtu
                        - len(packet)
                        - RECORD_HEADER.size
                        - HANDSHAKE_MESSAGE_HEADER.size
                    )
                    if space <= 0:
                        packets.append(packet)
                        packet = bytearray()
                        continue
                    frag = message.body[frag_offset : frag_offset + space]
                    frag_offset_bytes = frag_offset.to_bytes(3, "big")
                    frag_len_bytes = len(frag).to_bytes(3, "big")
                    # BUGFIX: the draft never advanced frag_offset, so this
                    # loop never terminated.
                    frag_offset += len(frag)

                    packet += RECORD_HEADER.pack(
                        ContentType.handshake,
                        message.record_version,
                        next(self._record_seq),
                        HANDSHAKE_MESSAGE_HEADER.size + len(frag),
                    )
                    packet += HANDSHAKE_MESSAGE_HEADER.pack(
                        message.msg_type,
                        msg_len_bytes,
                        message.msg_seq,
                        frag_offset_bytes,
                        frag_len_bytes,
                    )
                    packet += frag
                    frags_encoded += 1
                    assert len(packet) <= mtu

        if packet:
            packets.append(packet)

        return packets


# This bit requires implementing a bona fide cryptographic protocol, so even
# though it's a simple one let's take a moment to discuss the design.
#
# Our goal is to force new incoming handshakes that claim to be coming from a
# given ip:port to prove that they can also receive packets sent to that
# ip:port. (There's nothing in UDP to stop someone from forging the return
# address, and it's often used for stuff like DoS reflection attacks, where
# an attacker tries to trick us into sending data at some innocent victim.)
# For more details, see:
#
# https://datatracker.ietf.org/doc/html/rfc6347#section-4.2.1
#
# To do this, when we receive an initial ClientHello, we calculate a magic
# cookie, and send it back as a HelloVerifyRequest. Then the client sends us a
# second ClientHello, this time with the magic cookie included, and after we
# check that this cookie is valid we go ahead and start the handshake proper.
#
# So the magic cookie needs the following properties:
# - No-one can forge it without knowing our secret key
# - It ensures that the ip, port, and ClientHello contents from the response
#   match those in the challenge
# - It expires after a short-ish period (so that if an attacker manages to
#   steal one, it won't be useful for long)
# - It doesn't require storing any peer-specific state on our side
#
# To do that, we take the ip/port/ClientHello data and compute an HMAC of
# them, using a secret key we generate on startup. We also include:
#
# - The current time (using Trio's clock), rounded to the nearest 30 seconds
# - A random salt
#
# Then the cookie is the salt + the HMAC digest.
#
# When verifying a cookie, we use the salt + new ip/port/ClientHello data to
# recompute the HMAC digest, for both the current time and the current time
# minus 30 seconds, and if either of them match, we consider the cookie good.
#
# Including the rounded-off time like this means that each cookie is good for
# at least 30 seconds, and possibly as much as 60 seconds.
#
# The salt is probably not necessary -- I'm pretty sure that all it does is
# make it hard for an attacker to figure out when our clock ticks over a 30
# second boundary. Which is probably pretty harmless? But it's easier to add
# the salt than to convince myself that it's *completely* harmless, so, salt
# it is.

COOKIE_REFRESH_INTERVAL = 30  # seconds
KEY = None  # lazily generated secret key; see _make_cookie
KEY_BYTES = 8
COOKIE_HASH = "sha256"
SALT_BYTES = 8


def _current_cookie_tick():
    return int(trio.current_time() / COOKIE_REFRESH_INTERVAL)


# Simple deterministic and invertible serializer -- i.e., a useful tool for
# converting structured data into something we can cryptographically sign.
def _signable(*fields):
    """Length-prefix each field and concatenate, so the encoding is unambiguous."""
    out = []
    for field in fields:
        # BUGFIX: the struct module's function is pack(), not encode().
        out.append(struct.pack("!Q", len(field)))
        out.append(field)
    return b"".join(out)


def _make_cookie(salt, tick, address, client_hello_bits):
    assert len(salt) == SALT_BYTES

    global KEY
    if KEY is None:
        # Generate the secret key lazily, on first use.
        KEY = os.urandom(KEY_BYTES)

    signable_data = _signable(
        salt,
        # BUGFIX: struct.pack, not struct.encode
        struct.pack("!Q", tick),
        # address is a mix of strings and ints, and variable length, so pack
        # it into a single nested field
        _signable(*(str(part).encode() for part in address)),
        client_hello_bits,
    )

    return salt + hmac.digest(KEY, signable_data, COOKIE_HASH)


def valid_cookie(cookie, address, client_hello_bits):
    """Check a cookie echoed back by a client against the current/previous tick."""
    if len(cookie) > SALT_BYTES:
        salt = cookie[:SALT_BYTES]

        # BUGFIX: the draft used 'tick' without ever computing it.
        tick = _current_cookie_tick()

        cur_cookie = _make_cookie(salt, tick, address, client_hello_bits)
        old_cookie = _make_cookie(salt, tick - 1, address, client_hello_bits)

        # I doubt a short-circuiting 'or' here would leak any meaningful
        # timing information, but why risk it when '|' is just as easy.
        return (
            hmac.compare_digest(cookie, cur_cookie)
            | hmac.compare_digest(cookie, old_cookie)
        )
    else:
        return False


def challenge_for(address, epoch_seqno, client_hello_bits):
    """Build a HelloVerifyRequest packet carrying a fresh cookie."""
    salt = os.urandom(SALT_BYTES)
    tick = _current_cookie_tick()
    cookie = _make_cookie(salt, tick, address, client_hello_bits)

    # HelloVerifyRequest body is:
    # - 2 bytes version
    # - length-prefixed cookie
    #
    # The DTLS 1.2 spec says that for this message specifically we should use
    # the DTLS 1.0 version.
    #
    # (It also says the opposite of that, but that part is a mistake:
    #    https://www.rfc-editor.org/errata/eid4103
    # ).
    #
    # And I guess we use this for both the message-level and record-level
    # ProtocolVersions, since we haven't negotiated anything else yet?
    body = ProtocolVersion.DTLS10 + bytes([len(cookie)]) + cookie

    # The RFC says we have to copy the client's record number; the errata says
    # it should be the handshake message number. OpenSSL copies back the
    # record sequence number and always sets message seq number 0. So I guess
    # we'll follow openssl.
    hs = HandshakeFragment(
        msg_type=HandshakeType.hello_verify_request,
        msg_len=len(body),
        msg_seq=0,
        frag_offset=0,
        frag_len=len(body),
        # BUGFIX: the attribute (and keyword) is 'frag', not 'frag_bytes'.
        frag=body,
    )
    payload = encode_handshake_fragment(hs)

    packet = encode_record(
        Record(ContentType.handshake, ProtocolVersion.DTLS10, epoch_seqno, payload)
    )
    return packet


# when listening, definitely needs background tasks to handle reading+handshaking
# connect() (add_peer()?) should probably handle handshake directly
# in *theory* if only had client-mode, then could have multiple tasks calling
# add_peer() simultaneously
#
# alternatively: add_peer/listen as sync functions to express intentions, plus
# user is expected to regularly be pumping the receive loop. Handshakes don't
# progress unless you're doing this. Each packet receive call runs through the
# whole DTLS state machine.
#
# Problem: in this model, handshake timeouts are a pain in the butt. (Need to
# track them in some kind of manual timer wheel etc.) We don't want to write
# an explicit state machine; we want to take advantage of trio's tools.
#
# Therefore, handshakes should have a host task.
#
# And if handshakes have a host task, then what?
#
# main thing with handshakes is that we might want to read packets to handle
# them even if the user isn't otherwise reading packets. Our two options are
# to either have handshakes running continuously and potentially drop data
# packets internally, or else to only process handshakes while the user is
# asking for data.
#
# I guess dropping is better? We can have an internal queue size + stats on
# what's dropped?
#
# and if we're going to do our own dropping, then it's ok to have a constant
# reader task running...

# One option would be to have a fully connection-based API. Make a DTLS socket
# endpoint, and then call `await connect(...)` to get a DTLS socket
# connection, or `serve(task_fn)` so `task_fn` runs on each incoming
# handshake. Lifetimes are then clear. (GC shutdown still a bit of a pain, but
# whatever, can handle it.)
#
# Alternatively, lean into the single socket API. Handshakes happen in
# background, peers are identified by magic tokens returned from 'connect' or
# 'receive'. Need some synthetic event for "new client connected" to avoid the
# possibility of clients filling up connection table with immortal nonsense.
# Need way to forget specific peers.
#
# In both cases: what to do when in server mode, and a new connection
# *replaces* an existing connection? (or in mixed client/server mode
# similarly, I guess?) I guess for single-socket API you need a way to get
# these notifications anyway... for server mode maybe the old one gets marked
# closed and a new one is created? so need some way to signal EOF, which isn't
# a thing otherwise. (BrokenResourceError?)
#
# Also for future-DTLS (and QUIC etc.) there's connection migration, where the
# peer address changes. I guess the user doesn't necessarily care about
# getting notifications of this, as long as their connection remains bound to
# the right peer? It does rule out the use of connected UDP sockets though.
+ + +class _Queue: + def __init__(self, incoming_packets_buffer): + self._s, self._r = trio.open_memory_channel(incoming_packets_buffer) + + async def put(self, obj): + await self._s.send(obj) + + async def get(self): + return self._r.receive() + + +def _read_loop(read_fn): + chunks = [] + while True: + try: + chunk = read_fn(2 ** 14) # max TLS record size + except SSL.WantReadError: + break + if not chunk: + break + chunks.append(chunk) + return b"".join(chunks) + + +async def handle_client_hello_untrusted(dtls, address, packet): + if dtls._listening_context is None: + return + + try: + epoch_seqno, cookie, bits = decode_client_hello_untrusted(address, packet) + except BadPacket: + return + + if not valid_cookie(cookie, address, bits): + challenge_packet = challenge_for(address, epoch_seqno, bits) + try: + await dtls.sock.sendto(address, challenge_packet) + except OSError: + pass + else: + stream = DTLSStream(dtls, address, dtls._listening_context) + stream._inject_client_hello(packet) + old_stream = dlts._streams.get(address) + if old_stream is not None: + old_stream._break(RuntimeError("peer started a new DTLS connection")) + dtls._streams[address] = stream + dtls._incoming_connections_q.put_nowait(stream) + + +async def dtls_receive_loop(dtls): + sock = dtls.socket + dtls_ref = weakref.weakref(dtls) + del dtls + while True: + try: + address, packet = await sock.recvfrom() + except ClosedResourceError: + return + except OSError as exc: + dtls = dtls_ref() + if dtls is None: + return + dtls._break(exc) + return + # All of the following is sync, so we can be confident that our + # reference to dtls remains valid. 
+ dtls = dtls_ref() + try: + if dtls is None: + return + if is_client_hello_untrusted(packet): + await handle_client_hello_untrusted(dtls, address, packet) + elif address in dtls._streams: + stream = dtls._streams[address] + if stream._did_handshake and part_of_handshake_untrusted(packet): + # The peer just sent us more handshake messages, that aren't a + # ClientHello, and we thought the handshake was done. Some of the + # packets that we sent to finish the handshake must have gotten + # lost. So re-send them. We do this directly here instead of just + # putting it into the queue, because there's no guarantee that + # anyone is reading from the queue, because we think the handshake + # is done! + await stream._resend_final_volley() + else: + try: + stream._q.put_nowait(packet) + except trio.WouldBlock: + stream.packets_dropped_in_trio += 1 + else: + # Drop packet + pass + finally: + del dtls + + +class DTLSStream(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): + def __init__(self, dtls, peer_address, ctx): + self.dtls = dtls + self.peer_address = peer_address + self.packets_dropped_in_trio = 0 + self._mtu = 1472 # XX + self._did_handshake = False + self._ssl = SSL.Connection(ctx) + self._broken = False + self._closed = False + self._q = Queue(dtls.incoming_packets_buffer) + self._handshake_lock = trio.Lock() + self._record_encoder = RecordEncoder() + + def _break(self, reason: BaseException): + self._broken = True + self._broken_reason = reason + # XX wake things up + + def close(self): + if self._closed: + return + self._closed = True + if self.dtls._streams.get(self.peer_address) is self: + del self.dtls._streams[self.peer_address] + # Will wake any tasks waiting on self._q.get with a + # ClosedResourceError + self._q._r.close() + + async def aclose(self): + self.close() + await trio.lowlevel.checkpoint() + + def _inject_client_hello(self, packet): + stream._ssl.bio_write(packet) + # If we're on the server side, then we already sent record 0 as our 
cookie + # challenge. So we want to start the handshake proper with record 1. + self._record_encoder.skip_first_record_number() + + async def _send_volley(self, volley_messages): + packets = self._record_encoder(volley_messages, self._mtu) + for packet in packets: + await self.dtls.socket.sendto(self.peer_address, packet) + + async def _resend_final_volley(self): + await self._send_volley(self._final_volley) + + async def do_handshake(self): + async with self._handshake_lock: + if self._did_handshake: + return + # If we're a client, we send the initial volley. If we're a server, then + # the initial ClientHello has already been inserted into self._ssl's + # read BIO. So either way, we start by generating a new volley. + try: + self._ssl.do_handshake() + except SSL.WantReadError: + pass + + volley_messages = [] + def read_volley(): + volley_bytes = read_loop(self._ssl.bio_read) + new_volley_messages = decode_volley_trusted(volley_bytes) + if ( + new_volley_messages and volley_messages and + new_volley_messages[0].msg_seq == volley_messages[0].msg_seq + ): + # openssl decided to retransmit; discard because we'll handle this + # ourselves + return [] + else: + return new_volley_messages + + volley_messages = read_volley() + # If we don't have messages to send in our initial volley, then something + # has gone very wrong. (I'm not sure this can actually happen without an + # error from OpenSSL, but let's cover our bases.) + if not volley_messages: + self._break(SSL.Error("something wrong with peer's ClientHello")) + # XX raise + return + + while True: + assert volley_messages + await self._send_volley(volley_messages) + with trio.move_on_after(1) as cscope: + async for packet in self._q._r: + self._ssl.bio_write(packet) + try: + self._ssl.do_handshake() + except SSL.WantReadError: + pass + else: + # No exception -> the handshake is done, and we can + # switch into data transfer mode. 
+ self._did_handshake = True + # Might be empty, but that's ok + self._final_volley = read_volley() + await self._send_volley(self._final_volley) + return + maybe_volley = read_volley() + if maybe_volley: + # We managed to get all of the peer's volley and generate a + # new one ourselves! break out of the 'for' loop and restart + # the timer. + volley_messages = new_volley + break + if cscope.cancelled_caught: + # timeout expired, adjust timeout/mtu + # Good guidance here: https://tlswg.org/dtls13-spec/draft-ietf-tls-dtls13.html#name-timer-values + XX + + async def send(self, data): + if not self._did_handshake: + await self.do_handshake() + self._ssl.write(data) + await self.dtls.socket.sendto(self.peer_address, read_loop(self._ssl.bio_read)) + + async def receive(self): + if not self._did_handshake: + await self.do_handshake() + packet = await self._q.get() + self._ssl.bio_write(packet) + return read_loop(self._ssl.read) + + +class DTLS: + def __init__(self, socket, *, incoming_packets_buffer=10): + if socket.type != trio.socket.SOCK_DGRAM: + raise BadPacket("DTLS requires a SOCK_DGRAM socket") + self.socket = socket + self.incoming_packets_buffer = incoming_packets_buffer + self._token = trio.lowlevel.current_trio_token() + # We don't need to track handshaking vs non-handshake connections + # separately. We only keep one connection per remote address; as soon + # as a peer provides a valid cookie, we can immediately tear down the + # old connection. + # {remote address: DTLSStream} + self._streams = weakref.WeakValueDictionary() + self._listening_context = None + self._incoming_connections_q = Queue(float("inf")) + + trio.lowlevel.spawn_system_task(dtls_receive_loop, self) + + def __del__(self): + # Close the socket in Trio context (if our Trio context still exists), so that + # the background task gets notified about the closure and can exit. 
+ try: + self._token.run_sync_soon(self.socket.close) + except RuntimeError: + pass + + def close(self): + self.socket.close() + for stream in self._streams.values(): + stream.close() + self._incoming_connections_q._s.close() + + async def aclose(self): + self.close() + await trio.lowlevel.checkpoint() + + async def serve(self, ssl_context, async_fn, *args): + if self._listening_context is not None: + raise trio.BusyResourceError("another task is already listening") + try: + self._listening_context = ssl_context + async with trio.open_nursery() as nursery: + async for stream in self._incoming_connections_q._r: + nursery.start_soon(async_fn, stream, *args) + finally: + self._listening_context = None + + def _set_stream_for(self, address, stream): + old_stream = self._streams.get(address) + if old_stream is not None: + old_stream._break(RuntimeError("replaced by a new DTLS association")) + self._streams[address] = stream + + async def connect(self, address, ssl_context): + stream = DTLSStream(self, address, ssl_context) + self._set_stream_for(address, stream) + await stream.do_handshake() + return stream From 8c2fafa98100ad2a5ae038eb8f70092816a0474e Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Fri, 25 Jun 2021 02:55:56 -0700 Subject: [PATCH 02/47] smoke test passing! 
--- trio/_dtls.py | 291 ++++++++++++++++++++++------------------ trio/tests/test_dtls.py | 40 ++++++ 2 files changed, 203 insertions(+), 128 deletions(-) create mode 100644 trio/tests/test_dtls.py diff --git a/trio/_dtls.py b/trio/_dtls.py index b825fea5d6..4b738581e5 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -153,7 +153,7 @@ def encode_record(record): @attr.frozen class HandshakeFragment: msg_type: int - msg_len: int, + msg_len: int msg_seq: int frag_offset: int frag_len: int @@ -179,7 +179,7 @@ def decode_handshake_fragment_untrusted(payload): frag_len = int.from_bytes(frag_len_bytes, "big") frag = payload[HANDSHAKE_MESSAGE_HEADER.size:] if len(frag) != frag_len: - raise BadPacket("short fragment") + raise BadPacket("handshake fragment length doesn't match record length") return HandshakeFragment( msg_type, msg_len, @@ -198,7 +198,7 @@ def encode_handshake_fragment(hsf): hsf.frag_offset.to_bytes(3, "big"), hsf.frag_len.to_bytes(3, "big"), ) - return hs_header + hsf.body + return hs_header + hsf.frag def decode_client_hello_untrusted(packet): @@ -255,7 +255,7 @@ def decode_client_hello_untrusted(packet): cookie_start = cookie_len_offset + 1 cookie_end = cookie_start + cookie_len - before_cookie = body[:cookie_start] + before_cookie = body[:cookie_len_offset] cookie = body[cookie_start:cookie_end] after_cookie = body[cookie_end:] @@ -284,6 +284,16 @@ class PseudoHandshakeMessage: payload: bytes +# The final record in a handshake is Finished, which is encrypted, can't be fragmented +# (at least by us), and keeps its record number (because it's in a new epoch). So we +# just pass it through unchanged. (Fortunately, the payload is only a single hash value, +# so the largest it will ever be is 64 bytes for a 512-bit hash. Which is small enough +# that it never requires fragmenting to fit into a UDP packet. 
+@attr.frozen +class OpaqueHandshakeMessage: + record: Record + + # This takes a raw outgoing handshake volley that openssl generated, and # reconstructs the handshake messages inside it, so that we can repack them # into records while retransmitting. So the data ought to be well-behaved -- @@ -292,9 +302,18 @@ def decode_volley_trusted(volley): messages = [] messages_by_seq = {} for record in records_untrusted(volley): - if record.content_type == ContentType.change_cipher_spec: - messages.append(PseudoHandshakeMessage(record.version, - record.content_type, record.payload)) + # ChangeCipherSpec isn't a handshake message, so it can't be fragmented. + # Handshake messages with epoch > 0 are encrypted, so we can't fragment them + # either. Fortunately, ChangeCipherSpec has a 1 byte payload, and the only + # encrypted handshake message is Finished, whose payload is a single hash value + # -- so 32 bytes for SHA-256, 64 for SHA-512, etc. Neither is going to be so + # large that it has to be fragmented to fit into a single packet. 
+ if record.epoch_seqno & EPOCH_MASK: + messages.append(OpaqueHandshakeMessage(record)) + elif record.content_type == ContentType.change_cipher_spec: + messages.append( + PseudoHandshakeMessage(record.version, record.content_type, record.payload) + ) else: assert record.content_type == ContentType.handshake fragment = decode_handshake_fragment_untrusted(record.payload) @@ -306,7 +325,7 @@ def decode_volley_trusted(volley): messages.append(msg) messages_by_seq[fragment.msg_seq] = msg else: - msg = messages_by_seq[msg_seq] + msg = messages_by_seq[fragment.msg_seq] assert msg.msg_type == fragment.msg_type assert msg.msg_seq == fragment.msg_seq assert len(msg.body) == fragment.msg_len @@ -317,8 +336,8 @@ def decode_volley_trusted(volley): class RecordEncoder: - def __init__(self, first_record_seq): - self._record_seq = count(first_record_seq) + def __init__(self): + self._record_seq = count() def skip_first_record_number(self): assert next(self._record_seq) == 0 @@ -327,7 +346,14 @@ def encode_volley(self, messages, mtu): packets = [] packet = bytearray() for message in messages: - if isinstance(message, PseudoHandshakeMessage): + if isinstance(message, OpaqueHandshakeMessage): + encoded = encode_record(message.record) + if mtu - len(packet) - len(encoded) <= 0: + packets.append(packet) + packet = bytearray() + packet += encoded + assert len(packet) <= mtu + elif isinstance(message, PseudoHandshakeMessage): space = mtu - len(packet) - RECORD_HEADER.size - len(message.payload) if space <= 0: packets.append(packet) @@ -338,12 +364,15 @@ def encode_volley(self, messages, mtu): next(self._record_seq), len(message.payload), ) - packet += payload + packet += message.payload assert len(packet) <= mtu else: - msg_len_bytes = message.msg_len.to_bytes(3, "big") + msg_len_bytes = len(message.body).to_bytes(3, "big") frag_offset = 0 - while frag_offset < len(message.body): + frags_encoded = 0 + # If message.body is empty, then we still want to encode it in one + # fragment, not 
zero. + while frag_offset < len(message.body) or not frags_encoded: space = mtu - len(packet) - RECORD_HEADER.size - HANDSHAKE_MESSAGE_HEADER.size if space <= 0: packets.append(packet) @@ -352,6 +381,7 @@ def encode_volley(self, messages, mtu): frag = message.body[frag_offset:frag_offset + space] frag_offset_bytes = frag_offset.to_bytes(3, "big") frag_len_bytes = len(frag).to_bytes(3, "big") + frag_offset += len(frag) packet += RECORD_HEADER.pack( ContentType.handshake, @@ -370,6 +400,7 @@ def encode_volley(self, messages, mtu): packet += frag + frags_encoded += 1 assert len(packet) <= mtu if packet: @@ -423,6 +454,10 @@ def encode_volley(self, messages, mtu): # probably pretty harmless? But it's easier to add the salt than to convince myself that # it's *completely* harmless, so, salt it is. +# XX maybe the cookie should also sign the *local* address, so you can't take a cookie +# from one socket and use it on another socket on the same trio? or just generate the +# key in each call to 'serve'. + COOKIE_REFRESH_INTERVAL = 30 # seconds KEY = None KEY_BYTES = 8 @@ -434,12 +469,12 @@ def _current_cookie_tick(): return int(trio.current_time() / COOKIE_REFRESH_INTERVAL) -# Simple deterministic, bijective serializer -- i.e., a useful tool for hashing -# structured data. +# Simple deterministic and invertible serializer -- i.e., a useful tool for converting +# structured data into something we can cryptographically sign. 
def _signable(*fields): out = [] for field in fields: - out.append(struct.encode("!Q", len(field))) + out.append(struct.pack("!Q", len(field))) out.append(field) return b"".join(out) @@ -453,7 +488,7 @@ def _make_cookie(salt, tick, address, client_hello_bits): signable_data = _signable( salt, - struct.encode("!Q", tick), + struct.pack("!Q", tick), # address is a mix of strings and ints, and variable length, so pack # it into a single nested field _signable(*(str(part).encode() for part in address)), @@ -467,9 +502,13 @@ def valid_cookie(cookie, address, client_hello_bits): if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] + tick = _current_cookie_tick() + cur_cookie = _make_cookie(salt, tick, address, client_hello_bits) old_cookie = _make_cookie(salt, tick - 1, address, client_hello_bits) + # I doubt using a short-circuiting 'or' here would leak any meaningful + # information, but why risk it when '|' is just as easy. return ( hmac.compare_digest(cookie, cur_cookie) | hmac.compare_digest(cookie, old_cookie) @@ -508,7 +547,7 @@ def challenge_for(address, epoch_seqno, client_hello_bits): msg_seq=0, frag_offset=0, frag_len=len(body), - frag_bytes=body, + frag=body, ) payload = encode_handshake_fragment(hs) @@ -517,71 +556,9 @@ def challenge_for(address, epoch_seqno, client_hello_bits): return packet -# when listening, definitely needs background tasks to handle reading+handshaking -# connect() (add_peer()?) should probably handle handshake directly -# in *theory* if only had client-mode, then could have multiple tasks calling -# add_peer() simultaneously + -# -# alternatively: add_peer/listen as sync functions to express intentions, plus -# user is expected to regularly be pumping the receive loop. Handshakes don't -# progress unless you're doing this. Each packet receive call runs through the -# whole DTLS state machine. -# -# Problem: in this model, handshake timeouts are a pain in the butt. (Need to -# track them in some kind of manual timer wheel etc.). 
We don't -# want to write an explicit state machine; we want to take advantage of trio's -# tools. -# -# Therefore, handshakes should have a host task. -# -# And if handshakes have a host task, then what? -# -# main thing with handshakes is that might want to read packets to handle them -# even if user isn't otherwise reading packets. Our two options are to either -# have handshakes running continuously and potentially drop data packets -# internally, or else to only process handshakes while the user is asking for -# data. -# -# I guess dropping is better? We can have an internal queue size + stats on -# what's dropped? -# -# and if we're going to do our own dropping, then it's ok to have a constant -# reader task running... - -# One option would be to have a fully connection-based API. Make a DTLS socket -# endpoint, and then call `await connect(...)` to get a DTLS socket -# connection, or `serve(task_fn)` so `task_fn` runs on each incoming -# handshake. Lifetimes are then clear. (GC shutdown still a bit of a pain, but -# whatever, can handle it.) -# -# Alternatively, lean into the single socket API. Handshakes happen in -# background, peers are identified by magic tokens returned from 'connect' or -# 'receive'. Need some synthetic event for "new client connected" to avoid the -# possibility of clients filling up connection table with immortal nonsense. -# Need way to forget specific peers. -# -# In both cases: what to do when in server mode, and a new connection -# *replaces* an existing connection? (or in mixed client/server mode -# similarly, I guess?) I guess for single-socket API you need a way to get -# these notifications anyway... for server mode maybe the old one gets marked -# closed and a new one is created? so need some way to signal EOF, which isn't -# a thing otherwise. (BrokenResourceError?) -# -# Also for future-DTLS (and QUIC etc.) there's connection migration, where the -# peer address changes. 
I guess the user doesn't necessarily care about -# getting notifications of this, as long as their connection remains bound to -# the right peer? It does rule out the use of connected UDP sockets though. - - class _Queue: def __init__(self, incoming_packets_buffer): - self._s, self._r = trio.open_memory_channel(incoming_packets_buffer) - - async def put(self, obj): - await self._s.send(obj) - - async def get(self): - return self._r.receive() + self.s, self.r = trio.open_memory_channel(incoming_packets_buffer) def _read_loop(read_fn): @@ -602,43 +579,54 @@ async def handle_client_hello_untrusted(dtls, address, packet): return try: - epoch_seqno, cookie, bits = decode_client_hello_untrusted(address, packet) + epoch_seqno, cookie, bits = decode_client_hello_untrusted(packet) except BadPacket: return if not valid_cookie(cookie, address, bits): challenge_packet = challenge_for(address, epoch_seqno, bits) try: - await dtls.sock.sendto(address, challenge_packet) - except OSError: + async with dtls._send_lock: + await dtls.socket.sendto(challenge_packet, address) + except (OSError, trio.ClosedResourceError): pass else: - stream = DTLSStream(dtls, address, dtls._listening_context) - stream._inject_client_hello(packet) - old_stream = dlts._streams.get(address) + # We got a real, valid ClientHello! + stream = DTLSStream._create(dtls, address, dtls._listening_context) + try: + stream._inject_client_hello_untrusted(packet) + except BadPacket: + # ...or, well, OpenSSL didn't like it, so I guess we didn't. + return + old_stream = dtls._streams.get(address) if old_stream is not None: - old_stream._break(RuntimeError("peer started a new DTLS connection")) + if old_stream._client_hello == packet: + # ...but it's just a duplicate of a packet we got before, so never mind. + return + else: + # Ok, this *really is* a new handshake; the old stream should go away. 
+ old_stream._replaced() dtls._streams[address] = stream - dtls._incoming_connections_q.put_nowait(stream) + dtls._incoming_connections_q.s.send_nowait(stream) async def dtls_receive_loop(dtls): sock = dtls.socket - dtls_ref = weakref.weakref(dtls) + dtls_ref = weakref.ref(dtls) del dtls while True: try: - address, packet = await sock.recvfrom() - except ClosedResourceError: + packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE) + except trio.ClosedResourceError: return except OSError as exc: + # XX need to handle this better + # https://bobobobo.wordpress.com/2009/05/17/udp-an-existing-connection-was-forcibly-closed-by-the-remote-host/ dtls = dtls_ref() if dtls is None: return dtls._break(exc) return - # All of the following is sync, so we can be confident that our - # reference to dtls remains valid. dtls = dtls_ref() try: if dtls is None: @@ -652,13 +640,16 @@ async def dtls_receive_loop(dtls): # ClientHello, and we thought the handshake was done. Some of the # packets that we sent to finish the handshake must have gotten # lost. So re-send them. We do this directly here instead of just - # putting it into the queue, because there's no guarantee that - # anyone is reading from the queue, because we think the handshake - # is done! - await stream._resend_final_volley() + # putting it into the queue and letting the receiver do it, because + # there's no guarantee that anyone is reading from the queue, + # because we think the handshake is done! 
+ try: + await stream._resend_final_volley() + except trio.ClosedResourceError: + return else: try: - stream._q.put_nowait(packet) + stream._q.s.send_nowait(packet) except trio.WouldBlock: stream.packets_dropped_in_trio += 1 else: @@ -674,18 +665,30 @@ def __init__(self, dtls, peer_address, ctx): self.peer_address = peer_address self.packets_dropped_in_trio = 0 self._mtu = 1472 # XX + self._client_hello = None self._did_handshake = False self._ssl = SSL.Connection(ctx) - self._broken = False + # Arbitrary b/c we repack messages anyway, but has to be set + self._ssl.set_ciphertext_mtu(1500) + self._replaced = False self._closed = False - self._q = Queue(dtls.incoming_packets_buffer) + self._q = _Queue(dtls.incoming_packets_buffer) self._handshake_lock = trio.Lock() self._record_encoder = RecordEncoder() - def _break(self, reason: BaseException): - self._broken = True - self._broken_reason = reason - # XX wake things up + def _replaced(self): + self._replaced = True + # Any packets we already received could maybe possibly still be processed, but + # there are no more coming. So we close this on the sender side. + self._q.s.close() + + def _check_replaced(self): + if self._replaced: + raise BrokenResourceError("peer tore down this connection to start a new one") + + # XX expose knobs and queries for MTU information + # XX on systems where we can (maybe just Linux?) 
take advantage of the kernel's PMTU + # estimate def close(self): if self._closed: @@ -695,22 +698,37 @@ def close(self): del self.dtls._streams[self.peer_address] # Will wake any tasks waiting on self._q.get with a # ClosedResourceError - self._q._r.close() + self._q.r.close() async def aclose(self): self.close() await trio.lowlevel.checkpoint() - def _inject_client_hello(self, packet): - stream._ssl.bio_write(packet) + def _inject_client_hello_untrusted(self, packet): + self._client_hello = packet + self._ssl.bio_write(packet) # If we're on the server side, then we already sent record 0 as our cookie # challenge. So we want to start the handshake proper with record 1. self._record_encoder.skip_first_record_number() + # We've already validated this cookie. But, we still have to call DTLSv1_listen + # so OpenSSL thinks that it's verified the cookie. The problem is that + # if you're doing cookie challenges, then the actual ClientHello has msg_seq=1 + # instead of msg_seq=0, and OpenSSL will refuse to process a ClientHello with + # msg_seq=1 unless you've called DTLSv1_listen. It also gets OpenSSL to bump the + # outgoing ServerHello's msg_seq to 1. 
+ try: + self._ssl.DTLSv1_listen() + except SSL.Error: + raise BadPacket async def _send_volley(self, volley_messages): - packets = self._record_encoder(volley_messages, self._mtu) + packets = self._record_encoder.encode_volley(volley_messages, self._mtu) + # XX debug + # decoded = decode_volley_trusted(b"".join(packets)) + # assert decoded == volley_messages for packet in packets: - await self.dtls.socket.sendto(self.peer_address, packet) + async with self.dtls._send_lock: + await self.dtls.socket.sendto(packet, self.peer_address) async def _resend_final_volley(self): await self._send_volley(self._final_volley) @@ -729,7 +747,7 @@ async def do_handshake(self): volley_messages = [] def read_volley(): - volley_bytes = read_loop(self._ssl.bio_read) + volley_bytes = _read_loop(self._ssl.bio_read) new_volley_messages = decode_volley_trusted(volley_bytes) if ( new_volley_messages and volley_messages and @@ -746,15 +764,13 @@ def read_volley(): # has gone very wrong. (I'm not sure this can actually happen without an # error from OpenSSL, but let's cover our bases.) if not volley_messages: - self._break(SSL.Error("something wrong with peer's ClientHello")) - # XX raise - return + raise SSL.Error("something wrong with peer's ClientHello") while True: assert volley_messages await self._send_volley(volley_messages) - with trio.move_on_after(1) as cscope: - async for packet in self._q._r: + with trio.move_on_after(10) as cscope: + async for packet in self._q.r: self._ssl.bio_write(packet) try: self._ssl.do_handshake() @@ -773,25 +789,36 @@ def read_volley(): # We managed to get all of the peer's volley and generate a # new one ourselves! break out of the 'for' loop and restart # the timer. 
- volley_messages = new_volley + volley_messages = maybe_volley break + else: + assert self._replaced + self._check_replaced() if cscope.cancelled_caught: # timeout expired, adjust timeout/mtu # Good guidance here: https://tlswg.org/dtls13-spec/draft-ietf-tls-dtls13.html#name-timer-values XX async def send(self, data): + if self._closed: + raise trio.ClosedResourceError if not self._did_handshake: await self.do_handshake() + self._check_replaced() self._ssl.write(data) - await self.dtls.socket.sendto(self.peer_address, read_loop(self._ssl.bio_read)) + async with self.dtls._send_lock: + await self.dtls.socket.sendto(_read_loop(self._ssl.bio_read), self.peer_address) async def receive(self): if not self._did_handshake: await self.do_handshake() - packet = await self._q.get() + try: + packet = await self._q.r.receive() + except trio.EndOfChannel: + assert self._replaced + self._check_replaced() self._ssl.bio_write(packet) - return read_loop(self._ssl.read) + return _read_loop(self._ssl.read) class DTLS: @@ -808,7 +835,8 @@ def __init__(self, socket, *, incoming_packets_buffer=10): # {remote address: DTLSStream} self._streams = weakref.WeakValueDictionary() self._listening_context = None - self._incoming_connections_q = Queue(float("inf")) + self._incoming_connections_q = _Queue(float("inf")) + self._send_lock = trio.Lock() trio.lowlevel.spawn_system_task(dtls_receive_loop, self) @@ -824,19 +852,25 @@ def close(self): self.socket.close() for stream in self._streams.values(): stream.close() - self._incoming_connections_q._s.close() + self._incoming_connections_q.s.close() - async def aclose(self): + def __enter__(self): + return self + + def __exit__(self, *args): self.close() - await trio.lowlevel.checkpoint() - async def serve(self, ssl_context, async_fn, *args): + async def serve(self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED): if self._listening_context is not None: raise trio.BusyResourceError("another task is already listening") + # We do 
cookie verification ourselves, so tell OpenSSL not to worry about it. + # (See also _inject_client_hello_untrusted.) + ssl_context.set_cookie_verify_callback(lambda *_: True) try: self._listening_context = ssl_context + task_status.started() async with trio.open_nursery() as nursery: - async for stream in self._incoming_connections_q._r: + async for stream in self._incoming_connections_q.r: nursery.start_soon(async_fn, stream, *args) finally: self._listening_context = None @@ -848,7 +882,8 @@ def _set_stream_for(self, address, stream): self._streams[address] = stream async def connect(self, address, ssl_context): - stream = DTLSStream(self, address, ssl_context) + stream = DTLSStream._create(self, address, ssl_context) + stream._ssl.set_connect_state() self._set_stream_for(address, stream) await stream.do_handshake() return stream diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py new file mode 100644 index 0000000000..5e7f70ad55 --- /dev/null +++ b/trio/tests/test_dtls.py @@ -0,0 +1,40 @@ +import trio +from trio._dtls import DTLS + +import trustme +from OpenSSL import SSL + +ca = trustme.CA() +server_cert = ca.issue_cert("example.com") + +server_ctx = SSL.Context(SSL.DTLS_METHOD) +server_cert.configure_cert(server_ctx) + +client_ctx = SSL.Context(SSL.DTLS_METHOD) +ca.configure_trust(client_ctx) + +# XX this should be handled in the real code +server_ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) +client_ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) + +async def test_smoke(): + server_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + with server_sock: + await server_sock.bind(("127.0.0.1", 54321)) + server_dtls = DTLS(server_sock) + + async with trio.open_nursery() as nursery: + + async def handle_client(dtls_stream): + await dtls_stream.do_handshake() + assert await dtls_stream.receive() == b"hello" + await dtls_stream.send(b"goodbye") + + await nursery.start(server_dtls.serve, server_ctx, handle_client) + + 
client_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + client_dtls = DTLS(client_sock) + client = await client_dtls.connect(server_sock.getsockname(), client_ctx) + await client.send(b"hello") + assert await client.receive() == b"goodbye" + nursery.cancel_scope.cancel() From 66595f56393c0d575600b87ac361e43649436576 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Fri, 25 Jun 2021 02:58:27 -0700 Subject: [PATCH 03/47] Move required SSL OP_* settings into the proper place --- trio/_dtls.py | 6 ++++++ trio/tests/test_dtls.py | 4 ---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 4b738581e5..c37c9ba20e 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -667,6 +667,12 @@ def __init__(self, dtls, peer_address, ctx): self._mtu = 1472 # XX self._client_hello = None self._did_handshake = False + # These are mandatory for all DTLS connections. OP_NO_QUERY_MTU is required to + # stop openssl from trying to query the memory BIO's MTU and then breaking, and + # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to + # support and isn't useful anyway -- especially for DTLS where it's equivalent + # to just performing a new handshake. 
+ ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) self._ssl = SSL.Connection(ctx) # Arbitrary b/c we repack messages anyway, but has to be set self._ssl.set_ciphertext_mtu(1500) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 5e7f70ad55..5149e8d56f 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -13,10 +13,6 @@ client_ctx = SSL.Context(SSL.DTLS_METHOD) ca.configure_trust(client_ctx) -# XX this should be handled in the real code -server_ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) -client_ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) - async def test_smoke(): server_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) with server_sock: From 641494a0e6c8308c015d57dae4aea88b8509f97d Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Fri, 25 Jun 2021 05:32:59 -0700 Subject: [PATCH 04/47] All logic implemented, I think (probably not all correct though) --- trio/_dtls.py | 104 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 76 insertions(+), 28 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index c37c9ba20e..f9e9363863 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -20,6 +20,24 @@ MAX_UDP_PACKET_SIZE = 65527 +def packet_header_overhead(sock): + if sock.family == trio.socket.AF_INET: + return 28 + else: + return 48 + + +def worst_case_mtu(sock): + if sock.family == trio.socket.AF_INET: + return 576 - packet_header_overhead(sock) + else: + return 1280 - packet_header_overhead(sock) + + +def best_guess_mtu(sock): + return 1500 - packet_header_overhead(sock) + + # There are a bunch of different RFCs that define these codes, so for a # comprehensive collection look here: # https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml @@ -620,13 +638,16 @@ async def dtls_receive_loop(dtls): except trio.ClosedResourceError: return except OSError as exc: - # XX need to handle this better - # 
https://bobobobo.wordpress.com/2009/05/17/udp-an-existing-connection-was-forcibly-closed-by-the-remote-host/ - dtls = dtls_ref() - if dtls is None: + if exc.errno in (errno.EBADF, errno.ENOTSOCK): + # Socket was closed return - dtls._break(exc) - return + else: + # Some weird error, e.g. apparently some versions of Windows can do + # ECONNRESET here to report that some previous UDP packet got an ICMP + # Port Unreachable: + # https://bobobobo.wordpress.com/2009/05/17/udp-an-existing-connection-was-forcibly-closed-by-the-remote-host/ + # We'll assume that whatever it is, it's a transient problem. + continue dtls = dtls_ref() try: if dtls is None: @@ -664,7 +685,6 @@ def __init__(self, dtls, peer_address, ctx): self.dtls = dtls self.peer_address = peer_address self.packets_dropped_in_trio = 0 - self._mtu = 1472 # XX self._client_hello = None self._did_handshake = False # These are mandatory for all DTLS connections. OP_NO_QUERY_MTU is required to @@ -674,8 +694,10 @@ def __init__(self, dtls, peer_address, ctx): # to just performing a new handshake. ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) self._ssl = SSL.Connection(ctx) - # Arbitrary b/c we repack messages anyway, but has to be set - self._ssl.set_ciphertext_mtu(1500) + self._mtu = None + # This calls self._ssl.set_ciphertext_mtu, which is important, because if you + # don't call it then openssl doesn't work. + self.set_ciphertext_mtu(best_guess_mtu(self.dtls.socket)) self._replaced = False self._closed = False self._q = _Queue(dtls.incoming_packets_buffer) @@ -692,10 +714,20 @@ def _check_replaced(self): if self._replaced: raise BrokenResourceError("peer tore down this connection to start a new one") - # XX expose knobs and queries for MTU information + def set_ciphertext_mtu(self, new_mtu): + self._mtu = new_mtu + self._ssl.set_ciphertext_mtu(new_mtu) + + def get_cleartext_mtu(self): + return self._ssl.get_cleartext_mtu() + # XX on systems where we can (maybe just Linux?) 
take advantage of the kernel's PMTU # estimate + # XX should we send close-notify when closing? It seems particularly pointless for + # DTLS where packets are all independent and can be lost anyway. We do at least need + # to handle receiving it properly though, which might be easier if we send it... + def close(self): if self._closed: return @@ -729,9 +761,6 @@ def _inject_client_hello_untrusted(self, packet): async def _send_volley(self, volley_messages): packets = self._record_encoder.encode_volley(volley_messages, self._mtu) - # XX debug - # decoded = decode_volley_trusted(b"".join(packets)) - # assert decoded == volley_messages for packet in packets: async with self.dtls._send_lock: await self.dtls.socket.sendto(packet, self.peer_address) @@ -739,19 +768,15 @@ async def _send_volley(self, volley_messages): async def _resend_final_volley(self): await self._send_volley(self._final_volley) - async def do_handshake(self): + async def do_handshake(self, *, initial_retransmit_timeout=1.0): async with self._handshake_lock: if self._did_handshake: return - # If we're a client, we send the initial volley. If we're a server, then - # the initial ClientHello has already been inserted into self._ssl's - # read BIO. So either way, we start by generating a new volley. - try: - self._ssl.do_handshake() - except SSL.WantReadError: - pass + timeout = initial_retransmit_timeout volley_messages = [] + volley_failed_sends = 0 + def read_volley(): volley_bytes = _read_loop(self._ssl.bio_read) new_volley_messages = decode_volley_trusted(volley_bytes) @@ -759,22 +784,31 @@ def read_volley(): new_volley_messages and volley_messages and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq ): - # openssl decided to retransmit; discard because we'll handle this - # ourselves + # openssl decided to retransmit; discard because we handle + # retransmits ourselves return [] else: return new_volley_messages + # If we're a client, we send the initial volley. 
If we're a server, then + # the initial ClientHello has already been inserted into self._ssl's + # read BIO. So either way, we start by generating a new volley. + try: + self._ssl.do_handshake() + except SSL.WantReadError: + pass volley_messages = read_volley() # If we don't have messages to send in our initial volley, then something # has gone very wrong. (I'm not sure this can actually happen without an - # error from OpenSSL, but let's cover our bases.) + # error from OpenSSL, but we check just in case.) if not volley_messages: raise SSL.Error("something wrong with peer's ClientHello") while True: + # -- at this point, we need to either send or re-send a volley -- assert volley_messages await self._send_volley(volley_messages) + # -- then this is where we wait for a reply -- with trio.move_on_after(10) as cscope: async for packet in self._q.r: self._ssl.bio_write(packet) @@ -786,7 +820,8 @@ def read_volley(): # No exception -> the handshake is done, and we can # switch into data transfer mode. self._did_handshake = True - # Might be empty, but that's ok + # Might be empty, but that's ok -- we'll just send no + # packets. self._final_volley = read_volley() await self._send_volley(self._final_volley) return @@ -796,14 +831,27 @@ def read_volley(): # new one ourselves! break out of the 'for' loop and restart # the timer. volley_messages = maybe_volley + # "Implementations SHOULD retain the current timer value + # until a transmission without loss occurs, at which time + # the value may be reset to the initial value." + if volley_failed_sends == 0: + timeout = initial_retransmit_timeout + volley_failed_sends = 0 break else: assert self._replaced self._check_replaced() if cscope.cancelled_caught: - # timeout expired, adjust timeout/mtu - # Good guidance here: https://tlswg.org/dtls13-spec/draft-ietf-tls-dtls13.html#name-timer-values - XX + # Timeout expired. 
Double timeout for backoff, with a limit of 60 + # seconds (this matches what openssl does, and also the + # recommendation in draft-ietf-tls-dtls13). + timeout = min(2 * timeout, 60.0) + volley_failed_sends += 1 + if volley_failed_sends == 2: + # We tried sending this twice and they both failed. Maybe our + # PMTU estimate is wrong? Let's try dropping it to the minimum + # and hope that helps. + self.set_ciphertext_mtu(min(self._mtu, worst_case_mtu(self.dtls.socket))) async def send(self, data): if self._closed: From d581e35988e498004218e9e4f3981fb4cf2bfe10 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sat, 26 Jun 2021 02:11:09 -0700 Subject: [PATCH 05/47] Run black --- trio/_dtls.py | 72 ++++++++++++++++++++++++----------------- trio/tests/test_dtls.py | 1 + 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index f9e9363863..4a841a48d4 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -20,6 +20,7 @@ MAX_UDP_PACKET_SIZE = 65527 + def packet_header_overhead(sock): if sock.family == trio.socket.AF_INET: return 28 @@ -78,7 +79,7 @@ class ProtocolVersion: DTLS12 = bytes([254, 253]) -EPOCH_MASK = 0xffff << (6 * 8) +EPOCH_MASK = 0xFFFF << (6 * 8) # Conventions: @@ -151,10 +152,7 @@ def records_untrusted(packet): def encode_record(record): header = RECORD_HEADER.pack( - record.content_type, - record.version, - record.epoch_seqno, - len(record.payload), + record.content_type, record.version, record.epoch_seqno, len(record.payload), ) return header + record.payload @@ -195,17 +193,10 @@ def decode_handshake_fragment_untrusted(payload): msg_len = int.from_bytes(msg_len_bytes, "big") frag_offset = int.from_bytes(frag_offset_bytes, "big") frag_len = int.from_bytes(frag_len_bytes, "big") - frag = payload[HANDSHAKE_MESSAGE_HEADER.size:] + frag = payload[HANDSHAKE_MESSAGE_HEADER.size :] if len(frag) != frag_len: raise BadPacket("handshake fragment length doesn't match record length") - return HandshakeFragment( - msg_type, 
- msg_len, - msg_seq, - frag_offset, - frag_len, - frag, - ) + return HandshakeFragment(msg_type, msg_len, msg_seq, frag_offset, frag_len, frag,) def encode_handshake_fragment(hsf): @@ -330,7 +321,9 @@ def decode_volley_trusted(volley): messages.append(OpaqueHandshakeMessage(record)) elif record.content_type == ContentType.change_cipher_spec: messages.append( - PseudoHandshakeMessage(record.version, record.content_type, record.payload) + PseudoHandshakeMessage( + record.version, record.content_type, record.payload + ) ) else: assert record.content_type == ContentType.handshake @@ -338,7 +331,10 @@ def decode_volley_trusted(volley): msg_type = HandshakeType(fragment.msg_type) if fragment.msg_seq not in messages_by_seq: msg = HandshakeMessage( - record.version, msg_type, fragment.msg_seq, bytearray(fragment.msg_len) + record.version, + msg_type, + fragment.msg_seq, + bytearray(fragment.msg_len), ) messages.append(msg) messages_by_seq[fragment.msg_seq] = msg @@ -348,7 +344,9 @@ def decode_volley_trusted(volley): assert msg.msg_seq == fragment.msg_seq assert len(msg.body) == fragment.msg_len - msg.body[fragment.frag_offset : fragment.frag_offset + fragment.frag_len] = fragment.frag + msg.body[ + fragment.frag_offset : fragment.frag_offset + fragment.frag_len + ] = fragment.frag return messages @@ -391,12 +389,17 @@ def encode_volley(self, messages, mtu): # If message.body is empty, then we still want to encode it in one # fragment, not zero. 
while frag_offset < len(message.body) or not frags_encoded: - space = mtu - len(packet) - RECORD_HEADER.size - HANDSHAKE_MESSAGE_HEADER.size + space = ( + mtu + - len(packet) + - RECORD_HEADER.size + - HANDSHAKE_MESSAGE_HEADER.size + ) if space <= 0: packets.append(packet) packet = bytearray() continue - frag = message.body[frag_offset:frag_offset + space] + frag = message.body[frag_offset : frag_offset + space] frag_offset_bytes = frag_offset.to_bytes(3, "big") frag_len_bytes = len(frag).to_bytes(3, "big") frag_offset += len(frag) @@ -527,9 +530,8 @@ def valid_cookie(cookie, address, client_hello_bits): # I doubt using a short-circuiting 'or' here would leak any meaningful # information, but why risk it when '|' is just as easy. - return ( - hmac.compare_digest(cookie, cur_cookie) - | hmac.compare_digest(cookie, old_cookie) + return hmac.compare_digest(cookie, cur_cookie) | hmac.compare_digest( + cookie, old_cookie ) else: return False @@ -569,8 +571,9 @@ def challenge_for(address, epoch_seqno, client_hello_bits): ) payload = encode_handshake_fragment(hs) - packet = encode_record(Record(ContentType.handshake, - ProtocolVersion.DTLS10, epoch_seqno, payload)) + packet = encode_record( + Record(ContentType.handshake, ProtocolVersion.DTLS10, epoch_seqno, payload) + ) return packet @@ -712,7 +715,9 @@ def _replaced(self): def _check_replaced(self): if self._replaced: - raise BrokenResourceError("peer tore down this connection to start a new one") + raise BrokenResourceError( + "peer tore down this connection to start a new one" + ) def set_ciphertext_mtu(self, new_mtu): self._mtu = new_mtu @@ -781,8 +786,9 @@ def read_volley(): volley_bytes = _read_loop(self._ssl.bio_read) new_volley_messages = decode_volley_trusted(volley_bytes) if ( - new_volley_messages and volley_messages and - new_volley_messages[0].msg_seq == volley_messages[0].msg_seq + new_volley_messages + and volley_messages + and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq ): # openssl 
decided to retransmit; discard because we handle # retransmits ourselves @@ -851,7 +857,9 @@ def read_volley(): # We tried sending this twice and they both failed. Maybe our # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. - self.set_ciphertext_mtu(min(self._mtu, worst_case_mtu(self.dtls.socket))) + self.set_ciphertext_mtu( + min(self._mtu, worst_case_mtu(self.dtls.socket)) + ) async def send(self, data): if self._closed: @@ -861,7 +869,9 @@ async def send(self, data): self._check_replaced() self._ssl.write(data) async with self.dtls._send_lock: - await self.dtls.socket.sendto(_read_loop(self._ssl.bio_read), self.peer_address) + await self.dtls.socket.sendto( + _read_loop(self._ssl.bio_read), self.peer_address + ) async def receive(self): if not self._did_handshake: @@ -914,7 +924,9 @@ def __enter__(self): def __exit__(self, *args): self.close() - async def serve(self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED): + async def serve( + self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED + ): if self._listening_context is not None: raise trio.BusyResourceError("another task is already listening") # We do cookie verification ourselves, so tell OpenSSL not to worry about it. diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 5149e8d56f..783a9e5e8f 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -13,6 +13,7 @@ client_ctx = SSL.Context(SSL.DTLS_METHOD) ca.configure_trust(client_ctx) + async def test_smoke(): server_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) with server_sock: From 4903a6133e446c2b7723ccdb3ae576ed3ee36eb0 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Wed, 30 Jun 2021 23:45:02 -0700 Subject: [PATCH 06/47] Delay importing OpenSSL.SSL until a DTLS object is constructed --- trio/__init__.py | 2 ++ trio/_dtls.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/trio/__init__.py b/trio/__init__.py index a50ec33310..b4b77f7fa9 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -74,6 +74,8 @@ from ._ssl import SSLStream, SSLListener, NeedHandshakeError +from ._dtls import DTLS, DTLSStream + from ._highlevel_serve_listeners import serve_listeners from ._highlevel_open_tcp_stream import open_tcp_stream diff --git a/trio/_dtls.py b/trio/_dtls.py index 4a841a48d4..a3650c3598 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -13,7 +13,6 @@ import weakref import attr -from OpenSSL import SSL import trio from trio._util import NoPublicConstructor @@ -887,6 +886,11 @@ async def receive(self): class DTLS: def __init__(self, socket, *, incoming_packets_buffer=10): + # We do this lazily on first construction, so only people who actually use DTLS + # have to install PyOpenSSL. + global SSL + from OpenSSL import SSL + if socket.type != trio.socket.SOCK_DGRAM: raise BadPacket("DTLS requires a SOCK_DGRAM socket") self.socket = socket From 9304e3860a6df146f5346a228b7a1495f2d59ae8 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Fri, 3 Sep 2021 09:33:42 -0700 Subject: [PATCH 07/47] Close test socket properly --- trio/tests/test_dtls.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 783a9e5e8f..8fcf24a112 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -29,9 +29,9 @@ async def handle_client(dtls_stream): await nursery.start(server_dtls.serve, server_ctx, handle_client) - client_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) - client_dtls = DTLS(client_sock) - client = await client_dtls.connect(server_sock.getsockname(), client_ctx) - await client.send(b"hello") - assert await client.receive() == b"goodbye" - nursery.cancel_scope.cancel() + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as client_sock: + client_dtls = DTLS(client_sock) + client = await client_dtls.connect(server_sock.getsockname(), client_ctx) + await client.send(b"hello") + assert await client.receive() == b"goodbye" + nursery.cancel_scope.cancel() From cd7bf54ec035057e95ae08cece784abe3a6d79e5 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Fri, 3 Sep 2021 09:33:52 -0700 Subject: [PATCH 08/47] mark DTLS class as final --- trio/_dtls.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index a3650c3598..6a0ca3edb2 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -15,7 +15,7 @@ import attr import trio -from trio._util import NoPublicConstructor +from trio._util import NoPublicConstructor, Final MAX_UDP_PACKET_SIZE = 65527 @@ -884,7 +884,7 @@ async def receive(self): return _read_loop(self._ssl.read) -class DTLS: +class DTLS(metaclass=Final): def __init__(self, socket, *, incoming_packets_buffer=10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. From c44f29fccd59acfdb4ce887f397dcf3638e33ce9 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Fri, 3 Sep 2021 09:34:21 -0700 Subject: [PATCH 09/47] Refactor socket address resolution Split out the address resolution code so that fake net can reuse it more easily --- trio/_socket.py | 174 +++++++++++++++++++------------------- trio/tests/test_socket.py | 21 +++-- 2 files changed, 98 insertions(+), 97 deletions(-) diff --git a/trio/_socket.py b/trio/_socket.py index 886f5614f6..3747bfe4d5 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -349,6 +349,82 @@ async def wrapper(self, *args, **kwargs): return wrapper +# Helpers to work with the (hostname, port) language that Python uses for socket +# addresses everywhere. Split out into a standalone function so it can be reused by +# FakeNet. + +# Take an address in Python's representation, and returns a new address in +# the same representation, but with names resolved to numbers, +# etc. +# +# local=True means that the address is being used with bind() or similar +# local=False means that the address is being used with connect() or sendto() or +# similar. +# +# NOTE: this function does not always checkpoint +async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, local): + # Do some pre-checking (or exit early for non-IP sockets) + if family == _stdlib_socket.AF_INET: + if not isinstance(address, tuple) or not len(address) == 2: + raise ValueError("address should be a (host, port) tuple") + elif family == _stdlib_socket.AF_INET6: + if not isinstance(address, tuple) or not 2 <= len(address) <= 4: + raise ValueError( + "address should be a (host, port, [flowinfo, [scopeid]]) tuple" + ) + elif family == _stdlib_socket.AF_UNIX: + # unwrap path-likes + return os.fspath(address) + else: + return address + + # -- From here on we know we have IPv4 or IPV6 -- + host, port, *_ = address + # Fast path for the simple case: already-resolved IP address, + # already-resolved port. This is particularly important for UDP, since + # every sendto call goes through here. 
+ if isinstance(port, int): + try: + _stdlib_socket.inet_pton(family, address[0]) + except (OSError, TypeError): + pass + else: + return address + # Special cases to match the stdlib, see gh-277 + if host == "": + host = None + if host == "": + host = "255.255.255.255" + flags = 0 + if local: + flags |= _stdlib_socket.AI_PASSIVE + # Since we always pass in an explicit family here, AI_ADDRCONFIG + # doesn't add any value -- if we have no ipv6 connectivity and are + # working with an ipv6 socket, then things will break soon enough! And + # if we do enable it, then it makes it impossible to even run tests + # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has + # no ipv6. + # flags |= AI_ADDRCONFIG + if family == _stdlib_socket.AF_INET6 and not ipv6_v6only: + flags |= _stdlib_socket.AI_V4MAPPED + gai_res = await getaddrinfo(host, port, family, type, proto, flags) + # AFAICT from the spec it's not possible for getaddrinfo to return an + # empty list. + assert len(gai_res) >= 1 + # Address is the last item in the first entry + (*_, normed), *_ = gai_res + # The above ignored any flowid and scopeid in the passed-in address, + # so restore them if present: + if family == _stdlib_socket.AF_INET6: + normed = list(normed) + assert len(normed) == 4 + if len(address) >= 3: + normed[2] = address[2] + if len(address) >= 4: + normed[3] = address[3] + normed = tuple(normed) + return normed + class SocketType: def __init__(self): raise TypeError( @@ -444,7 +520,7 @@ def close(self): self._sock.close() async def bind(self, address): - address = await self._resolve_local_address_nocp(address) + address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") and self.family == _stdlib_socket.AF_UNIX @@ -480,89 +556,15 @@ def is_readable(self): async def wait_writable(self): await _core.wait_writable(self._sock) - ################################################################ - # Address handling - 
################################################################ - - # Take an address in Python's representation, and returns a new address in - # the same representation, but with names resolved to numbers, - # etc. - # - # NOTE: this function does not always checkpoint - async def _resolve_address_nocp(self, address, flags): - # Do some pre-checking (or exit early for non-IP sockets) - if self._sock.family == _stdlib_socket.AF_INET: - if not isinstance(address, tuple) or not len(address) == 2: - raise ValueError("address should be a (host, port) tuple") - elif self._sock.family == _stdlib_socket.AF_INET6: - if not isinstance(address, tuple) or not 2 <= len(address) <= 4: - raise ValueError( - "address should be a (host, port, [flowinfo, [scopeid]]) tuple" - ) - elif self._sock.family == _stdlib_socket.AF_UNIX: - # unwrap path-likes - return os.fspath(address) + async def _resolve_address_nocp(self, address, *, local): + if self.family == _stdlib_socket.AF_INET6: + ipv6_v6only = self._sock.getsockopt(IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY) else: - return address - - # -- From here on we know we have IPv4 or IPV6 -- - host, port, *_ = address - # Fast path for the simple case: already-resolved IP address, - # already-resolved port. This is particularly important for UDP, since - # every sendto call goes through here. - if isinstance(port, int): - try: - _stdlib_socket.inet_pton(self._sock.family, address[0]) - except (OSError, TypeError): - pass - else: - return address - # Special cases to match the stdlib, see gh-277 - if host == "": - host = None - if host == "": - host = "255.255.255.255" - # Since we always pass in an explicit family here, AI_ADDRCONFIG - # doesn't add any value -- if we have no ipv6 connectivity and are - # working with an ipv6 socket, then things will break soon enough! And - # if we do enable it, then it makes it impossible to even run tests - # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has - # no ipv6. 
- # flags |= AI_ADDRCONFIG - if self._sock.family == _stdlib_socket.AF_INET6: - if not self._sock.getsockopt(IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY): - flags |= _stdlib_socket.AI_V4MAPPED - gai_res = await getaddrinfo( - host, port, self._sock.family, self.type, self._sock.proto, flags - ) - # AFAICT from the spec it's not possible for getaddrinfo to return an - # empty list. - assert len(gai_res) >= 1 - # Address is the last item in the first entry - (*_, normed), *_ = gai_res - # The above ignored any flowid and scopeid in the passed-in address, - # so restore them if present: - if self._sock.family == _stdlib_socket.AF_INET6: - normed = list(normed) - assert len(normed) == 4 - if len(address) >= 3: - normed[2] = address[2] - if len(address) >= 4: - normed[3] = address[3] - normed = tuple(normed) - return normed - - # Returns something appropriate to pass to bind() - # - # NOTE: this function does not always checkpoint - async def _resolve_local_address_nocp(self, address): - return await self._resolve_address_nocp(address, _stdlib_socket.AI_PASSIVE) - - # Returns something appropriate to pass to connect()/sendto()/sendmsg() - # - # NOTE: this function does not always checkpoint - async def _resolve_remote_address_nocp(self, address): - return await self._resolve_address_nocp(address, 0) + ipv6_v6only = False + return await _resolve_address_nocp( + self.type, self.family, self.proto, + ipv6_v6only=ipv6_v6only, + address=address, local=local) async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): # We have to reconcile two conflicting goals: @@ -617,7 +619,7 @@ async def connect(self, address): # notification. This means it isn't really cancellable... we close the # socket if cancelled, to avoid confusion. 
try: - address = await self._resolve_remote_address_nocp(address) + address = await self._resolve_address_nocp(address, local=False) async with _try_sync(): # An interesting puzzle: can a non-blocking connect() return EINTR # (= raise InterruptedError)? PEP 475 specifically left this as @@ -741,7 +743,7 @@ async def sendto(self, *args): # args is: data[, flags], address) # and kwargs are not accepted args = list(args) - args[-1] = await self._resolve_remote_address_nocp(args[-1]) + args[-1] = await self._resolve_address_nocp(args[-1], local=False) return await self._nonblocking_helper( _stdlib_socket.socket.sendto, args, {}, _core.wait_writable ) @@ -766,7 +768,7 @@ async def sendmsg(self, *args): # and kwargs are not accepted if len(args) == 4 and args[-1] is not None: args = list(args) - args[-1] = await self._resolve_remote_address_nocp(args[-1]) + args[-1] = await self._resolve_address_nocp(args[-1], local=False) return await self._nonblocking_helper( _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable ) diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index d891041ab2..b7c4839981 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -495,21 +495,20 @@ def assert_eq(actual, expected): with tsocket.socket(family=socket_type) as sock: # For some reason the stdlib special-cases "" to pass NULL to - # getaddrinfo They also error out on None, but whatever, None is much + # getaddrinfo. They also error out on None, but whatever, None is much # more consistent, so we accept it too. 
for null in [None, ""]: - got = await sock._resolve_local_address_nocp((null, 80)) + got = await sock._resolve_address_nocp((null, 80), local=True) assert_eq(got, (addrs.bind_all, 80)) - got = await sock._resolve_remote_address_nocp((null, 80)) + got = await sock._resolve_address_nocp((null, 80), local=False) assert_eq(got, (addrs.localhost, 80)) # AI_PASSIVE only affects the wildcard address, so for everything else - # _resolve_local_address_nocp and _resolve_remote_address_nocp should - # work the same: - for resolver in ["_resolve_local_address_nocp", "_resolve_remote_address_nocp"]: + # local=True/local=False should work the same: + for local in [False, True]: async def res(*args): - return await getattr(sock, resolver)(*args) + return await sock._resolve_address_nocp(*args, local=local) assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: @@ -560,7 +559,7 @@ async def res(*args): except (AttributeError, OSError): pass else: - assert await getattr(netlink_sock, resolver)("asdf") == "asdf" + assert await netlink_sock._resolve_address_nocp("asdf", local=local) == "asdf" netlink_sock.close() with pytest.raises(ValueError): @@ -721,16 +720,16 @@ def connect(self, *args, **kwargs): await sock.connect(("127.0.0.1", 2)) -async def test_resolve_remote_address_exception_closes_socket(): +async def test_resolve_address_exception_in_connect_closes_socket(): # Here we are testing issue 247, any cancellation will leave the socket closed with _core.CancelScope() as cancel_scope: with tsocket.socket() as sock: - async def _resolve_remote_address_nocp(self, *args, **kwargs): + async def _resolve_address_nocp(self, *args, **kwargs): cancel_scope.cancel() await _core.checkpoint() - sock._resolve_remote_address_nocp = _resolve_remote_address_nocp + sock._resolve_address_nocp = _resolve_address_nocp with assert_checkpoints(): with pytest.raises(_core.Cancelled): await sock.connect("") From d1866f618117ebb38ac5996a5fc296636fcaed9d Mon Sep 17 00:00:00 2001 
From: "Nathaniel J. Smith" Date: Sat, 4 Sep 2021 18:27:22 -0700 Subject: [PATCH 10/47] Add randomized dtls handshake robustness test --- trio/_dtls.py | 71 ++++--- trio/testing/__init__.py | 2 + trio/testing/_fake_net.py | 388 +++++++++++++++++++++++++++++++++++++ trio/tests/test_dtls.py | 91 ++++++++- trio/tests/test_fakenet.py | 42 ++++ 5 files changed, 566 insertions(+), 28 deletions(-) create mode 100644 trio/testing/_fake_net.py create mode 100644 trio/tests/test_fakenet.py diff --git a/trio/_dtls.py b/trio/_dtls.py index 6a0ca3edb2..07aaec64f5 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -11,6 +11,7 @@ import enum from itertools import count import weakref +import errno import attr @@ -318,7 +319,7 @@ def decode_volley_trusted(volley): # large that it has to be fragmented to fit into a single packet. if record.epoch_seqno & EPOCH_MASK: messages.append(OpaqueHandshakeMessage(record)) - elif record.content_type == ContentType.change_cipher_spec: + elif record.content_type in (ContentType.change_cipher_spec, ContentType.alert): messages.append( PseudoHandshakeMessage( record.version, record.content_type, record.payload @@ -460,7 +461,7 @@ def encode_volley(self, messages, mtu): # - The current time (using Trio's clock), rounded to the nearest 30 seconds # - A random salt # -# Then the cookie the salt + the HMAC digest. +# Then the cookie is the salt and the HMAC digest concatenated together. 
# # When verifying a cookie, we use the salt + new ip/port/ClientHello data to recompute # the HMAC digest, for both the current time and the current time minus 30 seconds, and @@ -525,7 +526,7 @@ def valid_cookie(cookie, address, client_hello_bits): tick = _current_cookie_tick() cur_cookie = _make_cookie(salt, tick, address, client_hello_bits) - old_cookie = _make_cookie(salt, tick - 1, address, client_hello_bits) + old_cookie = _make_cookie(salt, max(tick - 1, 0), address, client_hello_bits) # I doubt using a short-circuiting 'or' here would leak any meaningful # information, but why risk it when '|' is just as easy. @@ -620,12 +621,13 @@ async def handle_client_hello_untrusted(dtls, address, packet): return old_stream = dtls._streams.get(address) if old_stream is not None: - if old_stream._client_hello == packet: + if old_stream._client_hello == (cookie, bits): # ...but it's just a duplicate of a packet we got before, so never mind. return else: # Ok, this *really is* a new handshake; the old stream should go away. - old_stream._replaced() + old_stream._set_replaced() + stream._client_hello = (cookie, bits) dtls._streams[address] = stream dtls._incoming_connections_q.s.send_nowait(stream) @@ -706,7 +708,7 @@ def __init__(self, dtls, peer_address, ctx): self._handshake_lock = trio.Lock() self._record_encoder = RecordEncoder() - def _replaced(self): + def _set_replaced(self): self._replaced = True # Any packets we already received could maybe possibly still be processed, but # there are no more coming. So we close this on the sender side. 
@@ -714,7 +716,7 @@ def _replaced(self): def _check_replaced(self): if self._replaced: - raise BrokenResourceError( + raise trio.BrokenResourceError( "peer tore down this connection to start a new one" ) @@ -747,7 +749,6 @@ async def aclose(self): await trio.lowlevel.checkpoint() def _inject_client_hello_untrusted(self, packet): - self._client_hello = packet self._ssl.bio_write(packet) # If we're on the server side, then we already sent record 0 as our cookie # challenge. So we want to start the handshake proper with record 1. @@ -787,6 +788,7 @@ def read_volley(): if ( new_volley_messages and volley_messages + and isinstance(new_volley_messages[0], HandshakeMessage) and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq ): # openssl decided to retransmit; discard because we handle @@ -819,7 +821,9 @@ def read_volley(): self._ssl.bio_write(packet) try: self._ssl.do_handshake() - except SSL.WantReadError: + # We ignore generic SSL.Error here, because you can get those + # from random invalid packets + except (SSL.WantReadError, SSL.Error): pass else: # No exception -> the handshake is done, and we can @@ -832,17 +836,25 @@ def read_volley(): return maybe_volley = read_volley() if maybe_volley: - # We managed to get all of the peer's volley and generate a - # new one ourselves! break out of the 'for' loop and restart - # the timer. - volley_messages = maybe_volley - # "Implementations SHOULD retain the current timer value - # until a transmission without loss occurs, at which time - # the value may be reset to the initial value." - if volley_failed_sends == 0: - timeout = initial_retransmit_timeout - volley_failed_sends = 0 - break + if (isinstance(maybe_volley[0], PseudoHandshakeMessage) + and maybe_volley[0].content_type == ContentType.alert): + # we're sending an alert (e.g. due to a corrupted + # packet). We want to send it once, but don't save it to + # retransmit -- keep the last volley as the current + # volley. 
+ await self._send_volley(maybe_volley) + else: + # We managed to get all of the peer's volley and + # generate a new one ourselves! break out of the 'for' + # loop and restart the timer. + volley_messages = maybe_volley + # "Implementations SHOULD retain the current timer value + # until a transmission without loss occurs, at which + # time the value may be reset to the initial value." + if volley_failed_sends == 0: + timeout = initial_retransmit_timeout + volley_failed_sends = 0 + break else: assert self._replaced self._check_replaced() @@ -875,13 +887,18 @@ async def send(self, data): async def receive(self): if not self._did_handshake: await self.do_handshake() - try: - packet = await self._q.r.receive() - except trio.EndOfChannel: - assert self._replaced - self._check_replaced() - self._ssl.bio_write(packet) - return _read_loop(self._ssl.read) + while True: + try: + packet = await self._q.r.receive() + except trio.EndOfChannel: + assert self._replaced + self._check_replaced() + # Don't return spurious empty packets because of stray handshake packets + # coming in late + if part_of_handshake_untrusted(packet): + continue + self._ssl.bio_write(packet) + return _read_loop(self._ssl.read) class DTLS(metaclass=Final): diff --git a/trio/testing/__init__.py b/trio/testing/__init__.py index aa15c4743e..cf631f5716 100644 --- a/trio/testing/__init__.py +++ b/trio/testing/__init__.py @@ -24,6 +24,8 @@ from ._network import open_stream_to_socket_listener +from ._fake_net import FakeNet + ################################################################ from .._util import fixup_module_metadata diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py new file mode 100644 index 0000000000..50a233ec71 --- /dev/null +++ b/trio/testing/_fake_net.py @@ -0,0 +1,388 @@ +# This should eventually be cleaned up and become public, but for right now I'm just +# implementing enough to test DTLS. 
+ +# TODO: +# - user-defined routers +# - TCP +# - UDP broadcast + +import trio +import attr +import ipaddress +from collections import deque +import errno +import os +from typing import Union, List, Optional +import enum +from contextlib import contextmanager + +from trio._util import Final, NoPublicConstructor + +IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + +def _family_for(ip: IPAddress) -> int: + if isinstance(ip, ipaddress.IPv4Address): + return trio.socket.AF_INET + elif isinstance(ip, ipaddress.IPv6Address): + return trio.socket.AF_INET6 + assert False # pragma: no cover + + +def _wildcard_ip_for(family: int) -> IPAddress: + if family == trio.socket.AF_INET: + return ipaddress.ip_address("0.0.0.0") + elif family == trio.socket.AF_INET6: + return ipaddress.ip_address("::") + else: + assert False + + +def _localhost_ip_for(family: int) -> IPAddress: + if family == trio.socket.AF_INET: + return ipaddress.ip_address("127.0.0.1") + elif family == trio.socket.AF_INET6: + return ipaddress.ip_address("::1") + else: + assert False + + +def _fake_err(code): + raise OSError(code, os.strerror(code)) + + +def _scatter(data, buffers): + written = 0 + for buf in buffers: + next_piece = data[written:written + len(buf)] + with memoryview(buf) as mbuf: + mbuf[:len(next_piece)] = next_piece + written += len(next_piece) + if written == len(data): + break + return written + + +@attr.frozen +class UDPEndpoint: + ip: IPAddress + port: int + + def as_python_sockaddr(self): + sockaddr = (self.ip.compressed, self.port) + if isinstance(self.ip, ipaddress.IPv6Address): + sockaddr += (0, 0) + return sockaddr + + @classmethod + def from_python_sockaddr(cls, sockaddr): + ip, port = sockaddr[:2] + return cls(ip=ipaddress.ip_address(ip), port=port) + + +@attr.frozen +class UDPBinding: + local: UDPEndpoint + + +@attr.frozen +class UDPPacket: + source: UDPEndpoint + destination: UDPEndpoint + payload: bytes + + def reply(self, payload): + return UDPPacket( + 
source=self.destination, destination=self.source, payload=payload + ) + + +@attr.frozen +class FakeSocketFactory(trio.abc.SocketFactory): + fake_net: "FakeNet" + + def socket(self, family: int, type: int, proto: int) -> "FakeSocket": + return FakeSocket._create(self.fake_net, family, type, proto) + + +@attr.frozen +class FakeHostnameResolver(trio.abc.HostnameResolver): + fake_net: "FakeNet" + + async def getaddrinfo( + self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0 + ): + raise NotImplementedError("FakeNet doesn't do fake DNS yet") + + async def getnameinfo(self, sockaddr, flags: int): + raise NotImplementedError("FakeNet doesn't do fake DNS yet") + + +class FakeNet(metaclass=Final): + def __init__(self): + # When we need to pick an arbitrary unique ip address/port, use these: + self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() + self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() + self._auto_port_iter = iter(range(50000, 65535)) + + self._bound: Dict[UDPBinding, FakeSocket] = {} + + self.route_packet = None + + def _bind(self, binding: UDPBinding, socket: "FakeSocket") -> None: + if binding in self._bound: + _fake_err(errno.EADDRINUSE) + self._bound[binding] = socket + + def enable(self) -> None: + trio.socket.set_custom_socket_factory(FakeSocketFactory(self)) + trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self)) + + def send_packet(self, packet) -> None: + if self.route_packet is None: + self.deliver_packet(packet) + else: + self.route_packet(packet) + + def deliver_packet(self, packet) -> None: + binding = UDPBinding(local=packet.destination) + if binding in self._bound: + self._bound[binding]._deliver_packet(packet) + else: + # No valid destination, so drop it + pass + + +class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor): + def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): + self._fake_net = fake_net + + if not family: + family = 
trio.socket.AF_INET + if not type: + type = trio.socket.SOCK_STREAM + + if family not in (trio.socket.AF_INET, trio.socket.AF_INET6): + raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}") + if type != trio.socket.SOCK_DGRAM: + raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}") + + self.family = family + self.type = type + self.proto = proto + + self._closed = False + + self._packet_sender, self._packet_receiver = trio.open_memory_channel(float("inf")) + + # This is the source-of-truth for what port etc. this socket is bound to + self._binding: Optional[UDPBinding] = None + + def _check_closed(self): + if self._closed: + _fake_err(errno.EBADF) + + def close(self): + #breakpoint() + if self._closed: + return + self._closed = True + if self._binding is not None: + del self._fake_net._bound[self._binding] + self._packet_receiver.close() + + async def _resolve_address_nocp(self, address, *, local): + return await trio._socket._resolve_address_nocp(self.type, self.family, + self.proto, address=address, + ipv6_v6only=False, local=local) + + def _deliver_packet(self, packet: UDPPacket): + try: + self._packet_sender.send_nowait(packet) + except trio.BrokenResourceError: + # sending to a closed socket -- UDP packets get dropped + pass + + ################################################################ + # Actual IO operation implementations + ################################################################ + + async def bind(self, addr): + self._check_closed() + if self._binding is not None: + _fake_error(errno.EINVAL) + await trio.lowlevel.checkpoint() + ip_str, port = await self._resolve_address_nocp(addr, local=True) + ip = ipaddress.ip_address(ip_str) + assert _family_for(ip) == self.family + # We convert binds to INET_ANY into binds to localhost + if ip == ipaddress.ip_address("0.0.0.0"): + ip = ipaddress.ip_address("127.0.0.1") + elif ip == ipaddress.ip_address("::"): + ip = ipaddress.ip_address("::1") + if port == 0: + 
port = next(self._fake_net._auto_port_iter) + binding = UDPBinding(local=UDPEndpoint(ip, port)) + self._fake_net._bind(binding, self) + self._binding = binding + + async def connect(self, peer): + raise NotImplementedError("FakeNet does not (yet) support connected sockets") + + async def sendmsg(self, *args): + self._check_closed() + ancdata = [] + flags = 0 + address = None + if len(args) == 1: + (buffers,) = args + elif len(args) == 2: + buffers, address = args + elif len(args) == 3: + buffers, flags, address = args + elif len(args) == 4: + buffers, ancdata, flags, address = args + else: + raise TypeError("wrong number of arguments") + + await trio.lowlevel.checkpoint() + + if address is not None: + address = await self._resolve_address_nocp(address, local=False) + if ancdata: + raise NotImplementedError("FakeNet doesn't support ancillary data") + if flags: + raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}") + + if address is None: + _fake_err(errno.ENOTCONN) + + destination = UDPEndpoint.from_python_sockaddr(address) + + if self._binding is None: + await self.bind((_wildcard_ip_for(self.family).compressed, 0)) + + payload = b"".join(buffers) + + packet = UDPPacket( + source=self._binding.local, + destination=destination, + payload=payload, + ) + + self._fake_net.send_packet(packet) + + return len(payload) + + async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): + if ancbufsize != 0: + raise NotImplementedError("FakeNet doesn't support ancillary data") + if flags != 0: + raise NotImplementedError("FakeNet doesn't support any recv flags") + + self._check_closed() + + ancdata = [] + msg_flags = 0 + + packet = await self._packet_receiver.receive() + address = packet.source.as_python_sockaddr() + written = _scatter(packet.payload, buffers) + if written < len(packet.payload): + msg_flags |= trio.socket.MSG_TRUNC + return written, ancdata, msg_flags, address + + ################################################################ + # Simple 
state query stuff + ################################################################ + + def getsockname(self): + self._check_closed() + if self._binding is not None: + return self._binding.local.as_python_sockaddr() + elif self.family == trio.socket.AF_INET: + return ("0.0.0.0", 0) + else: + assert self.family == trio.socket.AF_INET6 + return ("::", 0) + + def getpeername(self): + self._check_closed() + if self._binding is not None: + if self._binding.remote is not None: + return self._binding.remote.as_python_sockaddr() + _fake_err(errno.ENOTCONN) + + def getsockopt(self, level, item): + self._check_closed() + raise OSError(f"FakeNet doesn't implement getsockopt({level}, {item})") + + def setsockopt(self, level, item, value): + self._check_closed() + + if (level, item) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY): + if not value: + raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True") + + raise OSError(f"FakeNet doesn't implement setsockopt({level}, {item}, ...)") + + ################################################################ + # Various boilerplate and trivial stubs + ################################################################ + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.close() + + async def send(self, data, flags=0): + return await self.sendto(data, flags, None) + + async def sendto(self, *args): + if len(args) == 2: + data, address = args + flags = 0 + elif len(args) == 3: + data, flags, address = args + else: + raise TypeError("wrong number of arguments") + return await self.sendmsg([data], [], flags, address) + + async def recv(self, bufsize, flags=0): + data, address = await self.recvfrom(bufsize, flags) + return data + + async def recv_into(self, buf, nbytes=0, flags=0): + got_bytes, address = await self.recvfrom_into(buf, nbytes, flags) + return got_bytes + + async def recvfrom(self, bufsize, flags=0): + data, ancdata, msg_flags, address = await self.recvmsg(bufsize, flags) + return 
data, address + + async def recvfrom_into(self, buf, nbytes=0, flags=0): + if nbytes != 0 and nbytes != len(buf): + raise NotImplementedError("partial recvfrom_into") + got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into([buf], 0, flags) + return got_nbytes, address + + async def recvmsg(self, bufsize, ancbufsize=0, flags=0): + buf = bytearray(bufsize) + got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into([buf], ancbufsize, flags) + return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) + + def fileno(self): + raise NotImplementedError("can't get fileno() for FakeNet sockets") + + def detach(self): + raise NotImplementedError("can't detach() a FakeNet socket") + + def get_inheritable(self): + return False + + def set_inheritable(self, inheritable): + if inheritable: + raise NotImplementedError("FakeNet can't make inheritable sockets") + + def share(self, process_id): + raise NotImplementedError("FakeNet can't share sockets") diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 8fcf24a112..5527c63f1e 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -1,5 +1,7 @@ import trio from trio._dtls import DTLS +import random +import attr import trustme from OpenSSL import SSL @@ -17,7 +19,7 @@ async def test_smoke(): server_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) with server_sock: - await server_sock.bind(("127.0.0.1", 54321)) + await server_sock.bind(("127.0.0.1", 0)) server_dtls = DTLS(server_sock) async with trio.open_nursery() as nursery: @@ -35,3 +37,90 @@ async def handle_client(dtls_stream): await client.send(b"hello") assert await client.receive() == b"goodbye" nursery.cancel_scope.cancel() + + +async def test_handshake_over_terrible_network(autojump_clock): + HANDSHAKES = 1000 + r = random.Random(0) + fn = trio.testing.FakeNet() + fn.enable() + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as server_sock: + async with trio.open_nursery() as nursery: + async def 
route_packet(packet): + while True: + op = r.choices(["deliver", "drop", "dupe", "delay"], + weights=[0.7, 0.1, 0.1, 0.1])[0] + print(f"{packet.source} -> {packet.destination}: {op}") + if op == "drop": + return + elif op == "dupe": + fn.send_packet(packet) + elif op == "delay": + await trio.sleep(r.random() * 3) + else: + assert op == "deliver" + print(f"{packet.source} -> {packet.destination}: delivered {packet.payload.hex()}") + fn.deliver_packet(packet) + break + + def route_packet_wrapper(packet): + try: + nursery.start_soon(route_packet, packet) + except RuntimeError: + # We're exiting the nursery, so any remaining packets can just get + # dropped + pass + + fn.route_packet = route_packet_wrapper + + await server_sock.bind(("1.1.1.1", 54321)) + server_dtls = DTLS(server_sock) + + next_client_idx = 0 + next_client_msg_recvd = trio.Event() + + async def handle_client(dtls_stream): + print("handling new client") + try: + await dtls_stream.do_handshake() + while True: + data = await dtls_stream.receive() + print(f"server received plaintext: {data}") + if not data: + continue + assert int(data.decode()) == next_client_idx + next_client_msg_recvd.set() + break + except trio.BrokenResourceError: + # client might have timed out on handshake and started a new one + # so we'll let this task die and let the new task do the check + print("new handshake restarting") + pass + except: + print("server handler saw") + import traceback + traceback.print_exc() + raise + + await nursery.start(server_dtls.serve, server_ctx, handle_client) + + for _ in range(HANDSHAKES): + print("#" * 80) + print("#" * 80) + print("#" * 80) + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as client_sock: + client_dtls = DTLS(client_sock) + client = await client_dtls.connect(server_sock.getsockname(), client_ctx) + while True: + data = str(next_client_idx).encode() + print(f"client sending plaintext: {data}") + await client.send(data) + with trio.move_on_after(10) as cscope: + await 
next_client_msg_recvd.wait() + if not cscope.cancelled_caught: + break + + next_client_idx += 1 + next_client_msg_recvd = trio.Event() + + nursery.cancel_scope.cancel() diff --git a/trio/tests/test_fakenet.py b/trio/tests/test_fakenet.py new file mode 100644 index 0000000000..623976f2be --- /dev/null +++ b/trio/tests/test_fakenet.py @@ -0,0 +1,42 @@ +import pytest + +import trio +from trio.testing import FakeNet + +def fn(): + fn = FakeNet() + fn.enable() + return fn + +async def test_basic_udp(): + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + await s1.bind(("127.0.0.1", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + assert data == b"xyz" + assert addr == s2.getsockname() + await s1.sendto(b"abc", s2.getsockname()) + data, addr = await s2.recvfrom(10) + assert data == b"abc" + assert addr == s1.getsockname() + + +async def test_msg_trunc(): + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s1.bind(("127.0.0.1", 0)) + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + + +async def test_basic_tcp(): + fn() + with pytest.raises(NotImplementedError): + trio.socket.socket() From 7378ce7f6fd6ca90a9045c4454afcc7132aeb278 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Sat, 4 Sep 2021 18:27:59 -0700 Subject: [PATCH 11/47] run black --- trio/_dtls.py | 20 ++++++++++++++++---- trio/_socket.py | 13 ++++++++++--- trio/testing/_fake_net.py | 32 ++++++++++++++++++++++---------- trio/tests/test_dtls.py | 20 +++++++++++++++----- trio/tests/test_fakenet.py | 2 ++ trio/tests/test_socket.py | 5 ++++- 6 files changed, 69 insertions(+), 23 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 07aaec64f5..c0ecb771dd 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -152,7 +152,10 @@ def records_untrusted(packet): def encode_record(record): header = RECORD_HEADER.pack( - record.content_type, record.version, record.epoch_seqno, len(record.payload), + record.content_type, + record.version, + record.epoch_seqno, + len(record.payload), ) return header + record.payload @@ -196,7 +199,14 @@ def decode_handshake_fragment_untrusted(payload): frag = payload[HANDSHAKE_MESSAGE_HEADER.size :] if len(frag) != frag_len: raise BadPacket("handshake fragment length doesn't match record length") - return HandshakeFragment(msg_type, msg_len, msg_seq, frag_offset, frag_len, frag,) + return HandshakeFragment( + msg_type, + msg_len, + msg_seq, + frag_offset, + frag_len, + frag, + ) def encode_handshake_fragment(hsf): @@ -836,8 +846,10 @@ def read_volley(): return maybe_volley = read_volley() if maybe_volley: - if (isinstance(maybe_volley[0], PseudoHandshakeMessage) - and maybe_volley[0].content_type == ContentType.alert): + if ( + isinstance(maybe_volley[0], PseudoHandshakeMessage) + and maybe_volley[0].content_type == ContentType.alert + ): # we're sending an alert (e.g. due to a corrupted # packet). 
We want to send it once, but don't save it to # retransmit -- keep the last volley as the current diff --git a/trio/_socket.py b/trio/_socket.py index 3747bfe4d5..bcff1ee9e7 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -425,6 +425,7 @@ async def _resolve_address_nocp(type, family, proto, *, ipv6_v6only, address, lo normed = tuple(normed) return normed + class SocketType: def __init__(self): raise TypeError( @@ -558,13 +559,19 @@ async def wait_writable(self): async def _resolve_address_nocp(self, address, *, local): if self.family == _stdlib_socket.AF_INET6: - ipv6_v6only = self._sock.getsockopt(IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY) + ipv6_v6only = self._sock.getsockopt( + IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY + ) else: ipv6_v6only = False return await _resolve_address_nocp( - self.type, self.family, self.proto, + self.type, + self.family, + self.proto, ipv6_v6only=ipv6_v6only, - address=address, local=local) + address=address, + local=local, + ) async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): # We have to reconcile two conflicting goals: diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index 50a233ec71..df853cb87b 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -20,6 +20,7 @@ IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + def _family_for(ip: IPAddress) -> int: if isinstance(ip, ipaddress.IPv4Address): return trio.socket.AF_INET @@ -53,9 +54,9 @@ def _fake_err(code): def _scatter(data, buffers): written = 0 for buf in buffers: - next_piece = data[written:written + len(buf)] + next_piece = data[written : written + len(buf)] with memoryview(buf) as mbuf: - mbuf[:len(next_piece)] = next_piece + mbuf[: len(next_piece)] = next_piece written += len(next_piece) if written == len(data): break @@ -109,7 +110,7 @@ class FakeHostnameResolver(trio.abc.HostnameResolver): fake_net: "FakeNet" async def getaddrinfo( - self, host: str, port: Union[int, str], family=0, type=0, proto=0, 
flags=0 + self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0 ): raise NotImplementedError("FakeNet doesn't do fake DNS yet") @@ -172,7 +173,9 @@ def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): self._closed = False - self._packet_sender, self._packet_receiver = trio.open_memory_channel(float("inf")) + self._packet_sender, self._packet_receiver = trio.open_memory_channel( + float("inf") + ) # This is the source-of-truth for what port etc. this socket is bound to self._binding: Optional[UDPBinding] = None @@ -182,7 +185,7 @@ def _check_closed(self): _fake_err(errno.EBADF) def close(self): - #breakpoint() + # breakpoint() if self._closed: return self._closed = True @@ -191,9 +194,14 @@ def close(self): self._packet_receiver.close() async def _resolve_address_nocp(self, address, *, local): - return await trio._socket._resolve_address_nocp(self.type, self.family, - self.proto, address=address, - ipv6_v6only=False, local=local) + return await trio._socket._resolve_address_nocp( + self.type, + self.family, + self.proto, + address=address, + ipv6_v6only=False, + local=local, + ) def _deliver_packet(self, packet: UDPPacket): try: @@ -363,12 +371,16 @@ async def recvfrom(self, bufsize, flags=0): async def recvfrom_into(self, buf, nbytes=0, flags=0): if nbytes != 0 and nbytes != len(buf): raise NotImplementedError("partial recvfrom_into") - got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into([buf], 0, flags) + got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + [buf], 0, flags + ) return got_nbytes, address async def recvmsg(self, bufsize, ancbufsize=0, flags=0): buf = bytearray(bufsize) - got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into([buf], ancbufsize, flags) + got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + [buf], ancbufsize, flags + ) return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) def fileno(self): diff --git a/trio/tests/test_dtls.py 
b/trio/tests/test_dtls.py index 5527c63f1e..61e6fc2f97 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -33,7 +33,9 @@ async def handle_client(dtls_stream): with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as client_sock: client_dtls = DTLS(client_sock) - client = await client_dtls.connect(server_sock.getsockname(), client_ctx) + client = await client_dtls.connect( + server_sock.getsockname(), client_ctx + ) await client.send(b"hello") assert await client.receive() == b"goodbye" nursery.cancel_scope.cancel() @@ -46,10 +48,13 @@ async def test_handshake_over_terrible_network(autojump_clock): fn.enable() with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as server_sock: async with trio.open_nursery() as nursery: + async def route_packet(packet): while True: - op = r.choices(["deliver", "drop", "dupe", "delay"], - weights=[0.7, 0.1, 0.1, 0.1])[0] + op = r.choices( + ["deliver", "drop", "dupe", "delay"], + weights=[0.7, 0.1, 0.1, 0.1], + )[0] print(f"{packet.source} -> {packet.destination}: {op}") if op == "drop": return @@ -59,7 +64,9 @@ async def route_packet(packet): await trio.sleep(r.random() * 3) else: assert op == "deliver" - print(f"{packet.source} -> {packet.destination}: delivered {packet.payload.hex()}") + print( + f"{packet.source} -> {packet.destination}: delivered {packet.payload.hex()}" + ) fn.deliver_packet(packet) break @@ -99,6 +106,7 @@ async def handle_client(dtls_stream): except: print("server handler saw") import traceback + traceback.print_exc() raise @@ -110,7 +118,9 @@ async def handle_client(dtls_stream): print("#" * 80) with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as client_sock: client_dtls = DTLS(client_sock) - client = await client_dtls.connect(server_sock.getsockname(), client_ctx) + client = await client_dtls.connect( + server_sock.getsockname(), client_ctx + ) while True: data = str(next_client_idx).encode() print(f"client sending plaintext: {data}") diff --git a/trio/tests/test_fakenet.py 
b/trio/tests/test_fakenet.py index 623976f2be..4e1c45b45f 100644 --- a/trio/tests/test_fakenet.py +++ b/trio/tests/test_fakenet.py @@ -3,11 +3,13 @@ import trio from trio.testing import FakeNet + def fn(): fn = FakeNet() fn.enable() return fn + async def test_basic_udp(): fn() s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index b7c4839981..1fa3721f91 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -559,7 +559,10 @@ async def res(*args): except (AttributeError, OSError): pass else: - assert await netlink_sock._resolve_address_nocp("asdf", local=local) == "asdf" + assert ( + await netlink_sock._resolve_address_nocp("asdf", local=local) + == "asdf" + ) netlink_sock.close() with pytest.raises(ValueError): From dcc5eae5de83a66f8b8d6c826e676dc73390416a Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sat, 4 Sep 2021 19:22:34 -0700 Subject: [PATCH 12/47] Mark randomized test as slow --- trio/tests/test_dtls.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 61e6fc2f97..24693ef1ed 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -1,3 +1,4 @@ +import pytest import trio from trio._dtls import DTLS import random @@ -6,6 +7,8 @@ import trustme from OpenSSL import SSL +from .._core.tests.tutil import slow + ca = trustme.CA() server_cert = ca.issue_cert("example.com") @@ -41,6 +44,7 @@ async def handle_client(dtls_stream): nursery.cancel_scope.cancel() +@slow async def test_handshake_over_terrible_network(autojump_clock): HANDSHAKES = 1000 r = random.Random(0) From 4a0bf1d029a92ab8658dc641be2628cf9af50e4f Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Sat, 4 Sep 2021 19:25:46 -0700 Subject: [PATCH 13/47] Make FakeNet private for now --- trio/testing/__init__.py | 2 -- trio/tests/test_dtls.py | 3 ++- trio/tests/test_fakenet.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/trio/testing/__init__.py b/trio/testing/__init__.py index cf631f5716..aa15c4743e 100644 --- a/trio/testing/__init__.py +++ b/trio/testing/__init__.py @@ -24,8 +24,6 @@ from ._network import open_stream_to_socket_listener -from ._fake_net import FakeNet - ################################################################ from .._util import fixup_module_metadata diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 24693ef1ed..b2ba10442c 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -7,6 +7,7 @@ import trustme from OpenSSL import SSL +from trio.testing._fake_net import FakeNet from .._core.tests.tutil import slow ca = trustme.CA() @@ -48,7 +49,7 @@ async def handle_client(dtls_stream): async def test_handshake_over_terrible_network(autojump_clock): HANDSHAKES = 1000 r = random.Random(0) - fn = trio.testing.FakeNet() + fn = FakeNet() fn.enable() with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as server_sock: async with trio.open_nursery() as nursery: diff --git a/trio/tests/test_fakenet.py b/trio/tests/test_fakenet.py index 4e1c45b45f..bc691c9db5 100644 --- a/trio/tests/test_fakenet.py +++ b/trio/tests/test_fakenet.py @@ -1,7 +1,7 @@ import pytest import trio -from trio.testing import FakeNet +from trio.testing._fake_net import FakeNet def fn(): From c2cfb1c1729c4fcb77ddeafa823f7be9c75eb23e Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Sat, 4 Sep 2021 20:37:15 -0700 Subject: [PATCH 14/47] make list of tests to write --- trio/tests/test_dtls.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index b2ba10442c..1bdf4ffe8d 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -139,3 +139,24 @@ async def handle_client(dtls_stream): next_client_msg_recvd = trio.Event() nursery.cancel_scope.cancel() + + +# send all kinds of garbage at a server socket +# send hello at a client-only socket +# socket closed at terrible times +# cancelling and restarting a client handshake +# garbage collecting DTLS object without closing it +# incoming packets buffer overflow +# set/get mtu +# closing a DTLSStream +# two simultaneous calls to .do_handshake() +# openssl retransmit +# receive a piece of garbage from the correct source during a handshake (corrupted +# packet, someone being a jerk) -- though can't necessarily tolerate someone sending a +# fake HelloRetryRequest +# implicit handshake on send/receive +# send/receive after closing +# DTLS close +# DTLS on SOCK_STREAM socket +# calling serve twice +# connect() that replaces an existing association (currently totally broken!) From 5a224c1b5d2f08ad118abc31ed669b3667ef1ef3 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Sat, 4 Sep 2021 20:42:34 -0700 Subject: [PATCH 15/47] Rename DTLSStream -> DTLSChannel --- trio/__init__.py | 2 +- trio/_dtls.py | 8 ++++---- trio/tests/test_dtls.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/trio/__init__.py b/trio/__init__.py index b4b77f7fa9..782220231e 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -74,7 +74,7 @@ from ._ssl import SSLStream, SSLListener, NeedHandshakeError -from ._dtls import DTLS, DTLSStream +from ._dtls import DTLS, DTLSChannel from ._highlevel_serve_listeners import serve_listeners diff --git a/trio/_dtls.py b/trio/_dtls.py index c0ecb771dd..585b565351 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -623,7 +623,7 @@ async def handle_client_hello_untrusted(dtls, address, packet): pass else: # We got a real, valid ClientHello! - stream = DTLSStream._create(dtls, address, dtls._listening_context) + stream = DTLSChannel._create(dtls, address, dtls._listening_context) try: stream._inject_client_hello_untrusted(packet) except BadPacket: @@ -694,7 +694,7 @@ async def dtls_receive_loop(dtls): del dtls -class DTLSStream(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): +class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): def __init__(self, dtls, peer_address, ctx): self.dtls = dtls self.peer_address = peer_address @@ -929,7 +929,7 @@ def __init__(self, socket, *, incoming_packets_buffer=10): # separately. We only keep one connection per remote address; as soon # as a peer provides a valid cookie, we can immediately tear down the # old connection. 
- # {remote address: DTLSStream} + # {remote address: DTLSChannel} self._streams = weakref.WeakValueDictionary() self._listening_context = None self._incoming_connections_q = _Queue(float("inf")) @@ -981,7 +981,7 @@ def _set_stream_for(self, address, stream): self._streams[address] = stream async def connect(self, address, ssl_context): - stream = DTLSStream._create(self, address, ssl_context) + stream = DTLSChannel._create(self, address, ssl_context) stream._ssl.set_connect_state() self._set_stream_for(address, stream) await stream.do_handshake() diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 1bdf4ffe8d..2fc124a7b3 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -148,7 +148,7 @@ async def handle_client(dtls_stream): # garbage collecting DTLS object without closing it # incoming packets buffer overflow # set/get mtu -# closing a DTLSStream +# closing a DTLSChannel # two simultaneous calls to .do_handshake() # openssl retransmit # receive a piece of garbage from the correct source during a handshake (corrupted From d94713a3081d5ebda56409dc424c342b6952b5ea Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Sat, 4 Sep 2021 21:06:26 -0700 Subject: [PATCH 16/47] more tests --- trio/tests/test_dtls.py | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 2fc124a7b3..f50986daf1 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -8,7 +8,7 @@ from OpenSSL import SSL from trio.testing._fake_net import FakeNet -from .._core.tests.tutil import slow +from .._core.tests.tutil import slow, can_bind_ipv6 ca = trustme.CA() server_cert = ca.issue_cert("example.com") @@ -20,10 +20,19 @@ ca.configure_trust(client_ctx) -async def test_smoke(): - server_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) +families = [trio.socket.AF_INET] +if can_bind_ipv6: + families.append(trio.socket.AF_INET6) + +@pytest.mark.parametrize("family", families) +async def test_smoke(family): + if family == trio.socket.AF_INET: + localhost = "127.0.0.1" + else: + localhost = "::1" + server_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) with server_sock: - await server_sock.bind(("127.0.0.1", 0)) + await server_sock.bind((localhost, 0)) server_dtls = DTLS(server_sock) async with trio.open_nursery() as nursery: @@ -35,13 +44,21 @@ async def handle_client(dtls_stream): await nursery.start(server_dtls.serve, server_ctx, handle_client) - with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as client_sock: + with trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) as client_sock: client_dtls = DTLS(client_sock) client = await client_dtls.connect( server_sock.getsockname(), client_ctx ) await client.send(b"hello") assert await client.receive() == b"goodbye" + + client.set_ciphertext_mtu(1234) + cleartext_mtu_1234 = client.get_cleartext_mtu() + client.set_ciphertext_mtu(4321) + assert client.get_cleartext_mtu() > cleartext_mtu_1234 + client.set_ciphertext_mtu(1234) + assert client.get_cleartext_mtu() == cleartext_mtu_1234 + 
nursery.cancel_scope.cancel() @@ -141,22 +158,22 @@ async def handle_client(dtls_stream): nursery.cancel_scope.cancel() +# implicit handshake on send/receive +# send/receive after closing +# DTLS close +# DTLS on SOCK_STREAM socket +# incoming packets buffer overflow + # send all kinds of garbage at a server socket # send hello at a client-only socket # socket closed at terrible times -# cancelling and restarting a client handshake +# cancelling a client handshake and then starting a new one # garbage collecting DTLS object without closing it -# incoming packets buffer overflow -# set/get mtu # closing a DTLSChannel # two simultaneous calls to .do_handshake() # openssl retransmit # receive a piece of garbage from the correct source during a handshake (corrupted # packet, someone being a jerk) -- though can't necessarily tolerate someone sending a # fake HelloRetryRequest -# implicit handshake on send/receive -# send/receive after closing -# DTLS close -# DTLS on SOCK_STREAM socket # calling serve twice # connect() that replaces an existing association (currently totally broken!) From def73ede4e82aa0e1fa3d1f4080876e55f8d8caa Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Sat, 4 Sep 2021 21:48:29 -0700 Subject: [PATCH 17/47] more tests --- trio/_dtls.py | 16 +++++- trio/tests/test_dtls.py | 108 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 111 insertions(+), 13 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 585b565351..a72a12cd8e 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -920,8 +920,10 @@ def __init__(self, socket, *, incoming_packets_buffer=10): global SSL from OpenSSL import SSL + self.socket = None # for __del__ if socket.type != trio.socket.SOCK_DGRAM: - raise BadPacket("DTLS requires a SOCK_DGRAM socket") + raise ValueError("DTLS requires a SOCK_DGRAM socket") + self.socket = socket self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() @@ -934,20 +936,24 @@ def __init__(self, socket, *, incoming_packets_buffer=10): self._listening_context = None self._incoming_connections_q = _Queue(float("inf")) self._send_lock = trio.Lock() + self._closed = False trio.lowlevel.spawn_system_task(dtls_receive_loop, self) def __del__(self): # Close the socket in Trio context (if our Trio context still exists), so that # the background task gets notified about the closure and can exit. + if self.socket is None: + return try: self._token.run_sync_soon(self.socket.close) except RuntimeError: pass def close(self): + self._closed = True self.socket.close() - for stream in self._streams.values(): + for stream in list(self._streams.values()): stream.close() self._incoming_connections_q.s.close() @@ -957,9 +963,14 @@ def __enter__(self): def __exit__(self, *args): self.close() + def _check_closed(self): + if self._closed: + raise trio.ClosedResourceError + async def serve( self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED ): + self._check_closed() if self._listening_context is not None: raise trio.BusyResourceError("another task is already listening") # We do cookie verification ourselves, so tell OpenSSL not to worry about it. 
@@ -981,6 +992,7 @@ def _set_stream_for(self, address, stream): self._streams[address] = stream async def connect(self, address, ssl_context): + self._check_closed() stream = DTLSChannel._create(self, address, ssl_context) stream._ssl.set_connect_state() self._set_stream_for(address, stream) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index f50986daf1..d350a539b7 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -3,6 +3,7 @@ from trio._dtls import DTLS import random import attr +from contextlib import asynccontextmanager import trustme from OpenSSL import SSL @@ -37,10 +38,10 @@ async def test_smoke(family): async with trio.open_nursery() as nursery: - async def handle_client(dtls_stream): - await dtls_stream.do_handshake() - assert await dtls_stream.receive() == b"hello" - await dtls_stream.send(b"goodbye") + async def handle_client(dtls_channel): + await dtls_channel.do_handshake() + assert await dtls_channel.receive() == b"hello" + await dtls_channel.send(b"goodbye") await nursery.start(server_dtls.serve, server_ctx, handle_client) @@ -108,12 +109,12 @@ def route_packet_wrapper(packet): next_client_idx = 0 next_client_msg_recvd = trio.Event() - async def handle_client(dtls_stream): + async def handle_client(dtls_channel): print("handling new client") try: - await dtls_stream.do_handshake() + await dtls_channel.do_handshake() while True: - data = await dtls_stream.receive() + data = await dtls_channel.receive() print(f"server received plaintext: {data}") if not data: continue @@ -158,10 +159,95 @@ async def handle_client(dtls_stream): nursery.cancel_scope.cancel() -# implicit handshake on send/receive -# send/receive after closing -# DTLS close -# DTLS on SOCK_STREAM socket +async def test_implicit_handshake(): + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as server_sock: + await server_sock.bind(("127.0.0.1", 0)) + server_dtls = DTLS(server_sock) + + +def dtls(): + sock = 
trio.socket.socket(type=trio.socket.SOCK_DGRAM) + return DTLS(sock) + + +@asynccontextmanager +async def dtls_echo_server(*, autocancel=True): + with dtls() as server: + await server.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as nursery: + async def echo_handler(dtls_channel): + async for packet in dtls_channel: + await dtls_channel.send(packet) + + await nursery.start(server.serve, server_ctx, echo_handler) + + yield server, server.socket.getsockname() + + if autocancel: + nursery.cancel_scope.cancel() + + +async def test_implicit_handshake(): + async with dtls_echo_server() as (_, address): + with dtls() as client_endpoint: + client = await client_endpoint.connect(address, client_ctx) + + # Implicit handshake + await client.send(b"xyz") + assert await client.receive() == b"xyz" + + +async def test_channel_closing(): + async with dtls_echo_server() as (_, address): + with dtls() as client_endpoint: + client = await client_endpoint.connect(address, client_ctx) + client.close() + + with pytest.raises(trio.ClosedResourceError): + await client.send(b"abc") + with pytest.raises(trio.ClosedResourceError): + await client.receive() + + +async def test_serve_exits_cleanly_on_close(): + async with dtls_echo_server(autocancel=False) as (server_endpoint, address): + server_endpoint.close() + # Testing that the nursery exits even without being cancelled + + +async def test_client_multiplex(): + async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2): + with dtls() as client_endpoint: + client1 = await client_endpoint.connect(address1, client_ctx) + client2 = await client_endpoint.connect(address2, client_ctx) + + await client1.send(b"abc") + await client2.send(b"xyz") + assert await client2.receive() == b"xyz" + assert await client1.receive() == b"abc" + + client_endpoint.close() + + with pytest.raises(trio.ClosedResourceError): + await client1.send("xxx") + with pytest.raises(trio.ClosedResourceError): + await client2.receive() 
+ with pytest.raises(trio.ClosedResourceError): + await client_endpoint.connect(address1, client_ctx) + + async with trio.open_nursery() as nursery: + with pytest.raises(trio.ClosedResourceError): + async def null_handler(_): # pragma: no cover + pass + await nursery.start(client_endpoint.serve, server_ctx, null_handler) + + +async def test_dtls_over_dgram_only(): + with trio.socket.socket() as s: + with pytest.raises(ValueError): + DTLS(s) + + # incoming packets buffer overflow # send all kinds of garbage at a server socket From f2600525a103dc0ce8fce81fa0377cc59f53bdad Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sat, 4 Sep 2021 21:51:54 -0700 Subject: [PATCH 18/47] rename DTLS -> DTLSEndpoint --- trio/__init__.py | 2 +- trio/_dtls.py | 2 +- trio/tests/test_dtls.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/trio/__init__.py b/trio/__init__.py index 782220231e..b35fa076b3 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -74,7 +74,7 @@ from ._ssl import SSLStream, SSLListener, NeedHandshakeError -from ._dtls import DTLS, DTLSChannel +from ._dtls import DTLSEndpoint, DTLSChannel from ._highlevel_serve_listeners import serve_listeners diff --git a/trio/_dtls.py b/trio/_dtls.py index a72a12cd8e..7d5f598ec0 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -913,7 +913,7 @@ async def receive(self): return _read_loop(self._ssl.read) -class DTLS(metaclass=Final): +class DTLSEndpoint(metaclass=Final): def __init__(self, socket, *, incoming_packets_buffer=10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. 
diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index d350a539b7..763affa00e 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -1,6 +1,6 @@ import pytest import trio -from trio._dtls import DTLS +from trio import DTLSEndpoint import random import attr from contextlib import asynccontextmanager @@ -34,7 +34,7 @@ async def test_smoke(family): server_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) with server_sock: await server_sock.bind((localhost, 0)) - server_dtls = DTLS(server_sock) + server_dtls = DTLSEndpoint(server_sock) async with trio.open_nursery() as nursery: @@ -46,7 +46,7 @@ async def handle_client(dtls_channel): await nursery.start(server_dtls.serve, server_ctx, handle_client) with trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) as client_sock: - client_dtls = DTLS(client_sock) + client_dtls = DTLSEndpoint(client_sock) client = await client_dtls.connect( server_sock.getsockname(), client_ctx ) @@ -104,7 +104,7 @@ def route_packet_wrapper(packet): fn.route_packet = route_packet_wrapper await server_sock.bind(("1.1.1.1", 54321)) - server_dtls = DTLS(server_sock) + server_dtls = DTLSEndpoint(server_sock) next_client_idx = 0 next_client_msg_recvd = trio.Event() @@ -140,7 +140,7 @@ async def handle_client(dtls_channel): print("#" * 80) print("#" * 80) with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as client_sock: - client_dtls = DTLS(client_sock) + client_dtls = DTLSEndpoint(client_sock) client = await client_dtls.connect( server_sock.getsockname(), client_ctx ) @@ -162,12 +162,12 @@ async def handle_client(dtls_channel): async def test_implicit_handshake(): with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as server_sock: await server_sock.bind(("127.0.0.1", 0)) - server_dtls = DTLS(server_sock) + server_dtls = DTLSEndpoint(server_sock) def dtls(): sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) - return DTLS(sock) + return DTLSEndpoint(sock) @asynccontextmanager @@ 
-245,7 +245,7 @@ async def null_handler(_): # pragma: no cover async def test_dtls_over_dgram_only(): with trio.socket.socket() as s: with pytest.raises(ValueError): - DTLS(s) + DTLSEndpoint(s) # incoming packets buffer overflow From e694e8ee142d78280e690eabedff9dcbc0016813 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sun, 5 Sep 2021 10:53:22 -0700 Subject: [PATCH 19/47] full duplex test (+ racing do_handshake) --- trio/tests/test_dtls.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 763affa00e..000d46cdf1 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -197,6 +197,24 @@ async def test_implicit_handshake(): assert await client.receive() == b"xyz" +async def test_full_duplex(): + with dtls() as server_endpoint, dtls() as client_endpoint: + await server_endpoint.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as server_nursery: + async def handler(channel): + async with trio.open_nursery() as nursery: + nursery.start_soon(channel.send, b"from server") + nursery.start_soon(channel.receive) + + await server_nursery.start(server_endpoint.serve, server_ctx, handler) + + client = await client_endpoint.connect(server_endpoint.socket.getsockname(), client_ctx) + async with trio.open_nursery() as nursery: + nursery.start_soon(client.send, b"from client") + nursery.start_soon(client.receive) + + server_nursery.cancel_scope.cancel() + async def test_channel_closing(): async with dtls_echo_server() as (_, address): with dtls() as client_endpoint: @@ -255,8 +273,6 @@ async def test_dtls_over_dgram_only(): # socket closed at terrible times # cancelling a client handshake and then starting a new one # garbage collecting DTLS object without closing it -# closing a DTLSChannel -# two simultaneous calls to .do_handshake() # openssl retransmit # receive a piece of garbage from the correct source during a handshake (corrupted # packet, 
someone being a jerk) -- though can't necessarily tolerate someone sending a From 17b5f902b6b1f1a14364500023d70c7a7d1a58bf Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sun, 5 Sep 2021 11:06:54 -0700 Subject: [PATCH 20/47] testing testing --- trio/_dtls.py | 25 ++++++++++++++----------- trio/tests/test_dtls.py | 21 ++++++++++++++++++++- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 7d5f598ec0..6544cb0d1f 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -695,8 +695,8 @@ async def dtls_receive_loop(dtls): class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): - def __init__(self, dtls, peer_address, ctx): - self.dtls = dtls + def __init__(self, endpoint, peer_address, ctx): + self.endpoint = endpoint self.peer_address = peer_address self.packets_dropped_in_trio = 0 self._client_hello = None @@ -711,10 +711,10 @@ def __init__(self, dtls, peer_address, ctx): self._mtu = None # This calls self._ssl.set_ciphertext_mtu, which is important, because if you # don't call it then openssl doesn't work. 
- self.set_ciphertext_mtu(best_guess_mtu(self.dtls.socket)) + self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) self._replaced = False self._closed = False - self._q = _Queue(dtls.incoming_packets_buffer) + self._q = _Queue(endpoint.incoming_packets_buffer) self._handshake_lock = trio.Lock() self._record_encoder = RecordEncoder() @@ -748,8 +748,8 @@ def close(self): if self._closed: return self._closed = True - if self.dtls._streams.get(self.peer_address) is self: - del self.dtls._streams[self.peer_address] + if self.endpoint._streams.get(self.peer_address) is self: + del self.endpoint._streams[self.peer_address] # Will wake any tasks waiting on self._q.get with a # ClosedResourceError self._q.r.close() @@ -777,8 +777,8 @@ def _inject_client_hello_untrusted(self, packet): async def _send_volley(self, volley_messages): packets = self._record_encoder.encode_volley(volley_messages, self._mtu) for packet in packets: - async with self.dtls._send_lock: - await self.dtls.socket.sendto(packet, self.peer_address) + async with self.endpoint._send_lock: + await self.endpoint.socket.sendto(packet, self.peer_address) async def _resend_final_volley(self): await self._send_volley(self._final_volley) @@ -881,7 +881,7 @@ def read_volley(): # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. 
self.set_ciphertext_mtu( - min(self._mtu, worst_case_mtu(self.dtls.socket)) + min(self._mtu, worst_case_mtu(self.endpoint.socket)) ) async def send(self, data): @@ -891,8 +891,8 @@ async def send(self, data): await self.do_handshake() self._check_replaced() self._ssl.write(data) - async with self.dtls._send_lock: - await self.dtls.socket.sendto( + async with self.endpoint._send_lock: + await self.endpoint.socket.sendto( _read_loop(self._ssl.bio_read), self.peer_address ) @@ -992,6 +992,9 @@ def _set_stream_for(self, address, stream): self._streams[address] = stream async def connect(self, address, ssl_context): + # it would be nice if we could detect when 'address' is our own endpoint (a + # loopback connection), because that can't work + # but I don't see how to do it reliably self._check_closed() stream = DTLSChannel._create(self, address, ssl_context) stream._ssl.set_connect_state() diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 000d46cdf1..c8ac750a99 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -176,6 +176,9 @@ async def dtls_echo_server(*, autocancel=True): await server.socket.bind(("127.0.0.1", 0)) async with trio.open_nursery() as nursery: async def echo_handler(dtls_channel): + print(f"echo handler started: " + f"server {dtls_channel.endpoint.socket.getsockname()} " + f"client {dtls_channel.peer_address}") async for packet in dtls_channel: await dtls_channel.send(packet) @@ -266,6 +269,23 @@ async def test_dtls_over_dgram_only(): DTLSEndpoint(s) +async def test_double_serve(): + async def null_handler(_): # pragma: no cover + pass + + with dtls() as endpoint: + async with trio.open_nursery() as nursery: + await nursery.start(endpoint.serve, server_ctx, null_handler) + with pytest.raises(trio.BusyResourceError): + await nursery.start(endpoint.serve, server_ctx, null_handler) + + nursery.cancel_scope.cancel() + + async with trio.open_nursery() as nursery: + await nursery.start(endpoint.serve, server_ctx, 
null_handler) + nursery.cancel_scope.cancel() + + # incoming packets buffer overflow # send all kinds of garbage at a server socket @@ -277,5 +297,4 @@ async def test_dtls_over_dgram_only(): # receive a piece of garbage from the correct source during a handshake (corrupted # packet, someone being a jerk) -- though can't necessarily tolerate someone sending a # fake HelloRetryRequest -# calling serve twice # connect() that replaces an existing association (currently totally broken!) From 8a1e2cbb36ce1214311ae1659f88c195d29acbeb Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sun, 5 Sep 2021 12:59:34 -0700 Subject: [PATCH 21/47] Switch DTLSChannel to follow '.statistics()' convention --- trio/_dtls.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 6544cb0d1f..00e8c05ace 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -686,7 +686,7 @@ async def dtls_receive_loop(dtls): try: stream._q.s.send_nowait(packet) except trio.WouldBlock: - stream.packets_dropped_in_trio += 1 + stream._packets_dropped_in_trio += 1 else: # Drop packet pass @@ -694,11 +694,16 @@ async def dtls_receive_loop(dtls): del dtls +@attr.frozen +class DTLSChannelStatistics: + incoming_packets_dropped_in_trio: int + + class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): def __init__(self, endpoint, peer_address, ctx): self.endpoint = endpoint self.peer_address = peer_address - self.packets_dropped_in_trio = 0 + self._packets_dropped_in_trio = 0 self._client_hello = None self._did_handshake = False # These are mandatory for all DTLS connections. 
OP_NO_QUERY_MTU is required to @@ -718,6 +723,9 @@ def __init__(self, endpoint, peer_address, ctx): self._handshake_lock = trio.Lock() self._record_encoder = RecordEncoder() + def statistics(self) -> DTLSChannelStatistics: + return DTLSChannelStatistics(self._packets_dropped_in_trio) + def _set_replaced(self): self._replaced = True # Any packets we already received could maybe possibly still be processed, but From db4b54947f3a606a709a822114cd2041b1e43b1b Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sun, 5 Sep 2021 12:59:49 -0700 Subject: [PATCH 22/47] moar tests --- trio/tests/test_dtls.py | 163 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 158 insertions(+), 5 deletions(-) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index c8ac750a99..f6ac32822d 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -4,6 +4,7 @@ import random import attr from contextlib import asynccontextmanager +from itertools import count import trustme from OpenSSL import SSL @@ -165,9 +166,9 @@ async def test_implicit_handshake(): server_dtls = DTLSEndpoint(server_sock) -def dtls(): +def dtls(**kwargs): sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) - return DTLSEndpoint(sock) + return DTLSEndpoint(sock, **kwargs) @asynccontextmanager @@ -180,6 +181,7 @@ async def echo_handler(dtls_channel): f"server {dtls_channel.endpoint.socket.getsockname()} " f"client {dtls_channel.peer_address}") async for packet in dtls_channel: + print(f"echoing {packet} -> {dtls_channel.peer_address}") await dtls_channel.send(packet) await nursery.start(server.serve, server_ctx, echo_handler) @@ -286,13 +288,164 @@ async def null_handler(_): # pragma: no cover nursery.cancel_scope.cancel() -# incoming packets buffer overflow +async def test_connect_to_non_server(autojump_clock): + fn = FakeNet() + fn.enable() + with dtls() as client1, dtls() as client2: + await client1.socket.bind(("127.0.0.1", 0)) + # This should just time out + with 
trio.move_on_after(100) as cscope: + await client2.connect(client1.socket.getsockname(), client_ctx) + assert cscope.cancelled_caught + + +async def test_incoming_buffer_overflow(autojump_clock): + fn = FakeNet() + fn.enable() + for buffer_size in [10, 20]: + async with dtls_echo_server() as (_, address): + with dtls(incoming_packets_buffer=buffer_size) as client_endpoint: + assert client_endpoint.incoming_packets_buffer == buffer_size + client = await client_endpoint.connect(address, client_ctx) + for i in range(buffer_size + 15): + await client.send(str(i).encode()) + await trio.sleep(1) + stats = client.statistics() + assert stats.incoming_packets_dropped_in_trio == 15 + for i in range(buffer_size): + assert await client.receive() == str(i).encode() + await client.send(b"buffer clear now") + assert await client.receive() == b"buffer clear now" + + +async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): + fn = FakeNet() + fn.enable() + + from trio._dtls import ( + Record, encode_record, HandshakeFragment, encode_handshake_fragment, + ContentType, HandshakeType, ProtocolVersion, + ) + + client_hello = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=10, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ), + ) + ) + + client_hello_extended = client_hello + b"\x00" + client_hello_short = client_hello[:-1] + # cuts off in middle of handshake message header + client_hello_really_short = client_hello[:14] + client_hello_corrupt_record_len = bytearray(client_hello) + client_hello_corrupt_record_len[11] = 0xff + + client_hello_fragmented = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=20, + 
msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ), + ) + ) + + client_hello_trailing_data_in_record = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=20, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ) + b"\x00", + ) + ) + async with dtls_echo_server() as (_, address): + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock: + for bad_packet in [ + b"", + b"xyz", + client_hello_extended, + client_hello_short, + client_hello_really_short, + client_hello_corrupt_record_len, + client_hello_fragmented, + client_hello_trailing_data_in_record, + ]: + await sock.sendto(bad_packet, address) + await trio.sleep(1) + + +async def test_invalid_cookie_rejected(autojump_clock): + fn = FakeNet() + fn.enable() + + from trio._dtls import decode_client_hello_untrusted, BadPacket + + offset_to_corrupt = count() + def route_packet(packet): + try: + _, cookie, _ = decode_client_hello_untrusted(packet.payload) + except BadPacket: + pass + else: + if len(cookie) != 0: + # this is a challenge response packet + # let's corrupt the next offset so the handshake should fail + payload = bytearray(packet.payload) + offset = next(offset_to_corrupt) + if offset >= len(payload): + # We've tried all offsets + # clamp offset to the end of the payload, and tell the client to stop + # trying to connect + offset = len(payload) - 1 + cscope.cancel() + payload[offset] ^= 0x01 + packet = attr.evolve(packet, payload=payload) + + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with trio.CancelScope() as cscope: + while True: + with dtls() as client: + await client.connect(address, client_ctx) + assert cscope.cancelled_caught -# send all kinds of garbage at a server socket -# send hello at a client-only socket # socket closed 
at terrible times # cancelling a client handshake and then starting a new one # garbage collecting DTLS object without closing it + # use fakenet, send a packet to the server, then immediately drop the dtls object and + # run gc before `sock.recvfrom()` can return # openssl retransmit # receive a piece of garbage from the correct source during a handshake (corrupted # packet, someone being a jerk) -- though can't necessarily tolerate someone sending a From 57d75addbe713e4cad9124cc61365d789919e377 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sun, 5 Sep 2021 23:41:22 -0700 Subject: [PATCH 23/47] Fixes + tests - use correct initial seqno on server in case of packet loss - calculate timeouts correctly - moar tests --- trio/_dtls.py | 106 ++++++++------- trio/testing/_fake_net.py | 2 +- trio/tests/test_dtls.py | 277 +++++++++++++++++++++++++++++++------- 3 files changed, 288 insertions(+), 97 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 00e8c05ace..b9eb913482 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -127,12 +127,15 @@ def is_client_hello_untrusted(packet): RECORD_HEADER = struct.Struct("!B2sQH") +hex_repr = attr.ib(repr=lambda data: data.hex()) + + @attr.frozen class Record: content_type: int - version: bytes + version: bytes = hex_repr epoch_seqno: int - payload: bytes + payload: bytes = hex_repr def records_untrusted(packet): @@ -176,7 +179,7 @@ class HandshakeFragment: msg_seq: int frag_offset: int frag_len: int - frag: bytes + frag: bytes = attr.ib(repr=lambda f: f.hex()) def decode_handshake_fragment_untrusted(payload): @@ -288,19 +291,19 @@ def decode_client_hello_untrusted(packet): @attr.frozen class HandshakeMessage: - record_version: bytes + record_version: bytes = hex_repr msg_type: HandshakeType msg_seq: int - body: bytearray + body: bytearray = hex_repr # ChangeCipherSpec is part of the handshake, but it's not a "handshake # message" and can't be fragmented the same way. Sigh. 
@attr.frozen class PseudoHandshakeMessage: - record_version: bytes + record_version: bytes = hex_repr content_type: int - payload: bytes + payload: bytes = hex_repr # The final record in a handshake is Finished, which is encrypted, can't be fragmented @@ -365,8 +368,8 @@ class RecordEncoder: def __init__(self): self._record_seq = count() - def skip_first_record_number(self): - assert next(self._record_seq) == 0 + def set_first_record_number(self, n): + self._record_seq = count(n) def encode_volley(self, messages, mtu): packets = [] @@ -624,15 +627,28 @@ async def handle_client_hello_untrusted(dtls, address, packet): else: # We got a real, valid ClientHello! stream = DTLSChannel._create(dtls, address, dtls._listening_context) + # Our HelloRetryRequest had some sequence number. We need our future sequence + # numbers to be larger than it, so our peer knows that our future records aren't + # stale/duplicates. But, we don't know what this sequence number was. What we do + # know is: + # - the HelloRetryRequest seqno was copied it from the initial ClientHello + # - the new ClientHello has a higher seqno than the initial ClientHello + # So, if we copy the new ClientHello's seqno into our first real handshake + # record and increment from there, that should work. + stream._record_encoder.set_first_record_number(epoch_seqno) + # Process the ClientHello try: - stream._inject_client_hello_untrusted(packet) - except BadPacket: - # ...or, well, OpenSSL didn't like it, so I guess we didn't. + stream._ssl.bio_write(packet) + stream._ssl.DTLSv1_listen() + except SSL.Error: + # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello + # after all. return + # Check if we have an existing association old_stream = dtls._streams.get(address) if old_stream is not None: if old_stream._client_hello == (cookie, bits): - # ...but it's just a duplicate of a packet we got before, so never mind. + # ...This was just a duplicate of the last ClientHello, so never mind. 
return else: # Ok, this *really is* a new handshake; the old stream should go away. @@ -762,26 +778,16 @@ def close(self): # ClosedResourceError self._q.r.close() + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + async def aclose(self): self.close() await trio.lowlevel.checkpoint() - def _inject_client_hello_untrusted(self, packet): - self._ssl.bio_write(packet) - # If we're on the server side, then we already sent record 0 as our cookie - # challenge. So we want to start the handshake proper with record 1. - self._record_encoder.skip_first_record_number() - # We've already validated this cookie. But, we still have to call DTLSv1_listen - # so OpenSSL thinks that it's verified the cookie. The problem is that - # if you're doing cookie challenges, then the actual ClientHello has msg_seq=1 - # instead of msg_seq=0, and OpenSSL will refuse to process a ClientHello with - # msg_seq=1 unless you've called DTLSv1_listen. It also gets OpenSSL to bump the - # outgoing ServerHello's msg_seq to 1. - try: - self._ssl.DTLSv1_listen() - except SSL.Error: - raise BadPacket - async def _send_volley(self, volley_messages): packets = self._record_encoder.encode_volley(volley_messages, self._mtu) for packet in packets: @@ -826,7 +832,7 @@ def read_volley(): # If we don't have messages to send in our initial volley, then something # has gone very wrong. (I'm not sure this can actually happen without an # error from OpenSSL, but we check just in case.) 
- if not volley_messages: + if not volley_messages: # pragma: no cover raise SSL.Error("something wrong with peer's ClientHello") while True: @@ -834,14 +840,14 @@ def read_volley(): assert volley_messages await self._send_volley(volley_messages) # -- then this is where we wait for a reply -- - with trio.move_on_after(10) as cscope: + with trio.move_on_after(timeout) as cscope: async for packet in self._q.r: self._ssl.bio_write(packet) try: self._ssl.do_handshake() # We ignore generic SSL.Error here, because you can get those # from random invalid packets - except (SSL.WantReadError, SSL.Error): + except (SSL.WantReadError, SSL.Error) as exc: pass else: # No exception -> the handshake is done, and we can @@ -987,25 +993,33 @@ async def serve( try: self._listening_context = ssl_context task_status.started() + + async def handler_wrapper(stream): + with stream: + await async_fn(stream, *args) + async with trio.open_nursery() as nursery: - async for stream in self._incoming_connections_q.r: - nursery.start_soon(async_fn, stream, *args) + async for stream in self._incoming_connections_q.r: # pragma: no branch + nursery.start_soon(handler_wrapper, stream) finally: self._listening_context = None - def _set_stream_for(self, address, stream): - old_stream = self._streams.get(address) - if old_stream is not None: - old_stream._break(RuntimeError("replaced by a new DTLS association")) - self._streams[address] = stream - - async def connect(self, address, ssl_context): + async def connect(self, address, ssl_context, *, initial_retransmit_timeout=1.0): # it would be nice if we could detect when 'address' is our own endpoint (a # loopback connection), because that can't work # but I don't see how to do it reliably self._check_closed() - stream = DTLSChannel._create(self, address, ssl_context) - stream._ssl.set_connect_state() - self._set_stream_for(address, stream) - await stream.do_handshake() - return stream + channel = DTLSChannel._create(self, address, ssl_context) + 
channel._ssl.set_connect_state() + old_channel = self._streams.get(address) + if old_channel is not None: + old_channel._set_replaced() + self._streams[address] = channel + try: + await channel.do_handshake( + initial_retransmit_timeout=initial_retransmit_timeout + ) + except: + channel.close() + raise + return channel diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py index df853cb87b..f0ea927734 100644 --- a/trio/testing/_fake_net.py +++ b/trio/testing/_fake_net.py @@ -89,7 +89,7 @@ class UDPBinding: class UDPPacket: source: UDPEndpoint destination: UDPEndpoint - payload: bytes + payload: bytes = attr.ib(repr=lambda p: p.hex()) def reply(self, payload): return UDPPacket( diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index f6ac32822d..bd1842fb25 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -26,6 +26,7 @@ if can_bind_ipv6: families.append(trio.socket.AF_INET6) + @pytest.mark.parametrize("family", families) async def test_smoke(family): if family == trio.socket.AF_INET: @@ -46,7 +47,9 @@ async def handle_client(dtls_channel): await nursery.start(server_dtls.serve, server_ctx, handle_client) - with trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) as client_sock: + with trio.socket.socket( + type=trio.socket.SOCK_DGRAM, family=family + ) as client_sock: client_dtls = DTLSEndpoint(client_sock) client = await client_dtls.connect( server_sock.getsockname(), client_ctx @@ -176,10 +179,13 @@ async def dtls_echo_server(*, autocancel=True): with dtls() as server: await server.socket.bind(("127.0.0.1", 0)) async with trio.open_nursery() as nursery: + async def echo_handler(dtls_channel): - print(f"echo handler started: " - f"server {dtls_channel.endpoint.socket.getsockname()} " - f"client {dtls_channel.peer_address}") + print( + f"echo handler started: " + f"server {dtls_channel.endpoint.socket.getsockname()} " + f"client {dtls_channel.peer_address}" + ) async for packet in dtls_channel: print(f"echoing 
{packet} -> {dtls_channel.peer_address}") await dtls_channel.send(packet) @@ -206,6 +212,7 @@ async def test_full_duplex(): with dtls() as server_endpoint, dtls() as client_endpoint: await server_endpoint.socket.bind(("127.0.0.1", 0)) async with trio.open_nursery() as server_nursery: + async def handler(channel): async with trio.open_nursery() as nursery: nursery.start_soon(channel.send, b"from server") @@ -213,13 +220,16 @@ async def handler(channel): await server_nursery.start(server_endpoint.serve, server_ctx, handler) - client = await client_endpoint.connect(server_endpoint.socket.getsockname(), client_ctx) + client = await client_endpoint.connect( + server_endpoint.socket.getsockname(), client_ctx + ) async with trio.open_nursery() as nursery: nursery.start_soon(client.send, b"from client") nursery.start_soon(client.receive) server_nursery.cancel_scope.cancel() + async def test_channel_closing(): async with dtls_echo_server() as (_, address): with dtls() as client_endpoint: @@ -231,11 +241,18 @@ async def test_channel_closing(): with pytest.raises(trio.ClosedResourceError): await client.receive() + # close is idempotent + client.close() + # can also aclose + await client.aclose() + async def test_serve_exits_cleanly_on_close(): async with dtls_echo_server(autocancel=False) as (server_endpoint, address): server_endpoint.close() # Testing that the nursery exits even without being cancelled + # close is idempotent + server_endpoint.close() async def test_client_multiplex(): @@ -260,8 +277,10 @@ async def test_client_multiplex(): async with trio.open_nursery() as nursery: with pytest.raises(trio.ClosedResourceError): + async def null_handler(_): # pragma: no cover pass + await nursery.start(client_endpoint.serve, server_ctx, null_handler) @@ -323,8 +342,13 @@ async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): fn.enable() from trio._dtls import ( - Record, encode_record, HandshakeFragment, encode_handshake_fragment, - ContentType, HandshakeType, 
ProtocolVersion, + Record, + encode_record, + HandshakeFragment, + encode_handshake_fragment, + ContentType, + HandshakeType, + ProtocolVersion, ) client_hello = encode_record( @@ -350,7 +374,7 @@ async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): # cuts off in middle of handshake message header client_hello_really_short = client_hello[:14] client_hello_corrupt_record_len = bytearray(client_hello) - client_hello_corrupt_record_len[11] = 0xff + client_hello_corrupt_record_len[11] = 0xFF client_hello_fragmented = encode_record( Record( @@ -384,20 +408,21 @@ async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): frag_len=10, frag=bytes(10), ) - ) + b"\x00", + ) + + b"\x00", ) ) async with dtls_echo_server() as (_, address): with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock: for bad_packet in [ - b"", - b"xyz", - client_hello_extended, - client_hello_short, - client_hello_really_short, - client_hello_corrupt_record_len, - client_hello_fragmented, - client_hello_trailing_data_in_record, + b"", + b"xyz", + client_hello_extended, + client_hello_short, + client_hello_really_short, + client_hello_corrupt_record_len, + client_hello_fragmented, + client_hello_trailing_data_in_record, ]: await sock.sendto(bad_packet, address) await trio.sleep(1) @@ -409,45 +434,197 @@ async def test_invalid_cookie_rejected(autojump_clock): from trio._dtls import decode_client_hello_untrusted, BadPacket - offset_to_corrupt = count() - def route_packet(packet): - try: - _, cookie, _ = decode_client_hello_untrusted(packet.payload) - except BadPacket: - pass - else: - if len(cookie) != 0: - # this is a challenge response packet - # let's corrupt the next offset so the handshake should fail - payload = bytearray(packet.payload) - offset = next(offset_to_corrupt) - if offset >= len(payload): - # We've tried all offsets - # clamp offset to the end of the payload, and tell the client to stop - # trying to connect - offset = len(payload) - 1 - 
cscope.cancel() - payload[offset] ^= 0x01 - packet = attr.evolve(packet, payload=payload) + with trio.CancelScope() as cscope: + + offset_to_corrupt = count() + + def route_packet(packet): + try: + _, cookie, _ = decode_client_hello_untrusted(packet.payload) + except BadPacket: + pass + else: + if len(cookie) != 0: + # this is a challenge response packet + # let's corrupt the next offset so the handshake should fail + payload = bytearray(packet.payload) + offset = next(offset_to_corrupt) + if offset >= len(payload): + # We've tried all offsets. Clamp offset to the end of the + # payload, and terminate the test. + offset = len(payload) - 1 + cscope.cancel() + payload[offset] ^= 0x01 + packet = attr.evolve(packet, payload=payload) + + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + while True: + with dtls() as client: + await client.connect(address, client_ctx) + assert cscope.cancelled_caught + + +async def test_client_cancels_handshake_and_starts_new_one(autojump_clock): + # if a client disappears during the handshake, and then starts a new handshake from + # scratch, then the first handler's channel should fail, and a new handler get + # started + fn = FakeNet() + fn.enable() + + with dtls() as server, dtls() as client: + await server.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as nursery: + first_time = True + + async def handler(channel): + nonlocal first_time + if first_time: + first_time = False + print("handler: first time, cancelling connect") + connect_cscope.cancel() + await trio.sleep(0.5) + print("handler: handshake should fail now") + with pytest.raises(trio.BrokenResourceError): + await channel.do_handshake() + else: + print("handler: not first time, sending hello") + await channel.send(b"hello") + + await nursery.start(server.serve, server_ctx, handler) + + print("client: starting first connect") + with trio.CancelScope() as connect_cscope: + await 
client.connect(server.socket.getsockname(), client_ctx) + assert connect_cscope.cancelled_caught + + print("client: starting second connect") + channel = await client.connect(server.socket.getsockname(), client_ctx) + assert await channel.receive() == b"hello" + + nursery.cancel_scope.cancel() + + +async def test_swap_client_server(): + with dtls() as a, dtls() as b: + await a.socket.bind(("127.0.0.1", 0)) + await b.socket.bind(("127.0.0.1", 0)) + + async def echo_handler(channel): + async for packet in channel: + await channel.send(packet) + + async def crashing_echo_handler(channel): + with pytest.raises(trio.BrokenResourceError): + await echo_handler(channel) + + async with trio.open_nursery() as nursery: + await nursery.start(a.serve, server_ctx, crashing_echo_handler) + await nursery.start(b.serve, server_ctx, echo_handler) + + b_to_a = await b.connect(a.socket.getsockname(), client_ctx) + await b_to_a.send(b"b as client") + assert await b_to_a.receive() == b"b as client" + a_to_b = await a.connect(b.socket.getsockname(), client_ctx) + with pytest.raises(trio.BrokenResourceError): + await b_to_a.send(b"association broken") + await a_to_b.send(b"a as client") + assert await a_to_b.receive() == b"a as client" + + nursery.cancel_scope.cancel() + + +@slow +async def test_openssl_retransmit_doesnt_break_stuff(): + # can't use autojump_clock here, because the point of the test is to wait for + # openssl's built-in retransmit timer to expire, which is hard-coded to use + # wall-clock time. 
+ fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + if blackholed: + print("dropped packet", packet) + return + print("delivered packet", packet) + # packets.append( + # scapy.all.IP( + # src=packet.source.ip.compressed, dst=packet.destination.ip.compressed + # ) + # / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port) + # / packet.payload + # ) fn.deliver_packet(packet) fn.route_packet = route_packet + async with dtls_echo_server() as (server_endpoint, address): + with dtls() as client_endpoint: + async with trio.open_nursery() as nursery: + + async def connecter(): + client = await client_endpoint.connect( + address, client_ctx, initial_retransmit_timeout=1.5 + ) + await client.send(b"hi") + assert await client.receive() == b"hi" + + nursery.start_soon(connecter) + + # openssl's default timeout is 1 second, so this ensures that it thinks + # the timeout has expired + await trio.sleep(1.1) + # disable blackholing and send a garbage packet to wake up openssl so it + # notices the timeout has expired + blackholed = False + await server_endpoint.socket.sendto( + b"xxx", client_endpoint.socket.getsockname() + ) + # now the client task should finish connecting and exit cleanly + + # scapy.all.wrpcap("/tmp/trace.pcap", packets) + + +async def test_initial_retransmit_timeout(autojump_clock): + fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + nonlocal blackholed + if blackholed: + blackholed = False + else: + fn.deliver_packet(packet) + + fn.route_packet = route_packet + async with dtls_echo_server() as (_, address): - with trio.CancelScope() as cscope: - while True: - with dtls() as client: - await client.connect(address, client_ctx) - assert cscope.cancelled_caught + for t in [1, 2, 4]: + with dtls() as client: + before = trio.current_time() + blackholed = True + await client.connect(address, client_ctx, initial_retransmit_timeout=t) + after = trio.current_time() + assert after - 
before == t + + +async def test_tiny_mtu(): + async with dtls_echo_server() as (server, address): + with dtls() as client: + pass + # socket closed at terrible times -# cancelling a client handshake and then starting a new one # garbage collecting DTLS object without closing it - # use fakenet, send a packet to the server, then immediately drop the dtls object and - # run gc before `sock.recvfrom()` can return -# openssl retransmit -# receive a piece of garbage from the correct source during a handshake (corrupted -# packet, someone being a jerk) -- though can't necessarily tolerate someone sending a -# fake HelloRetryRequest -# connect() that replaces an existing association (currently totally broken!) +# use fakenet, send a packet to the server, then immediately drop the dtls object and +# run gc before `sock.recvfrom()` can return + +# ...connect() probably shouldn't do the handshake. makes it impossible to set MTU! From f7250c336fcb2bef9147827630e48bab15327822 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Mon, 6 Sep 2021 14:28:15 -0700 Subject: [PATCH 24/47] Cleanup pass on names and cookie crypto --- trio/_dtls.py | 86 ++++++++++++++++++++++------------------- trio/tests/test_dtls.py | 47 ++++++++++++---------- 2 files changed, 73 insertions(+), 60 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index b9eb913482..c1b084519f 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -488,15 +488,19 @@ def encode_volley(self, messages, mtu): # probably pretty harmless? But it's easier to add the salt than to convince myself that # it's *completely* harmless, so, salt it is. -# XX maybe the cookie should also sign the *local* address, so you can't take a cookie -# from one socket and use it on another socket on the same trio? or just generate the -# key in each call to 'serve'. - COOKIE_REFRESH_INTERVAL = 30 # seconds -KEY = None -KEY_BYTES = 8 +KEY_BYTES = 32 COOKIE_HASH = "sha256" SALT_BYTES = 8 +# 32 bytes was the maximum cookie length in DTLS 1.0. 
DTLS 1.2 raised it to 255. I doubt +# there are any DTLS 1.0 implementations still in the wild, but really 32 bytes is +# plenty, and it also gets rid of a confusing warning in Wireshark output. +# +# We truncate the cookie to 32 bytes, of which 8 bytes is salt, so that leaves 24 bytes +# of truncated HMAC = 192 bit security, which is still massive overkill. (TCP uses 32 +# *bits* for this.) HMAC truncation is explicitly noted as safe in RFC 2104: +# https://datatracker.ietf.org/doc/html/rfc2104#section-5 +COOKIE_LENGTH = 32 def _current_cookie_tick(): @@ -513,12 +517,9 @@ def _signable(*fields): return b"".join(out) -def _make_cookie(salt, tick, address, client_hello_bits): +def _make_cookie(key, salt, tick, address, client_hello_bits): assert len(salt) == SALT_BYTES - - global KEY - if KEY is None: - KEY = os.urandom(KEY_BYTES) + assert len(key) == KEY_BYTES signable_data = _signable( salt, @@ -529,17 +530,17 @@ def _make_cookie(salt, tick, address, client_hello_bits): client_hello_bits, ) - return salt + hmac.digest(KEY, signable_data, COOKIE_HASH) + return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] -def valid_cookie(cookie, address, client_hello_bits): +def valid_cookie(key, cookie, address, client_hello_bits): if len(cookie) > SALT_BYTES: salt = cookie[:SALT_BYTES] tick = _current_cookie_tick() - cur_cookie = _make_cookie(salt, tick, address, client_hello_bits) - old_cookie = _make_cookie(salt, max(tick - 1, 0), address, client_hello_bits) + cur_cookie = _make_cookie(key, salt, tick, address, client_hello_bits) + old_cookie = _make_cookie(key, salt, max(tick - 1, 0), address, client_hello_bits) # I doubt using a short-circuiting 'or' here would leak any meaningful # information, but why risk it when '|' is just as easy. 
@@ -550,10 +551,10 @@ def valid_cookie(cookie, address, client_hello_bits): return False -def challenge_for(address, epoch_seqno, client_hello_bits): +def challenge_for(key, address, epoch_seqno, client_hello_bits): salt = os.urandom(SALT_BYTES) tick = _current_cookie_tick() - cookie = _make_cookie(salt, tick, address, client_hello_bits) + cookie = _make_cookie(key, salt, tick, address, client_hello_bits) # HelloVerifyRequest body is: # - 2 bytes version @@ -608,8 +609,8 @@ def _read_loop(read_fn): return b"".join(chunks) -async def handle_client_hello_untrusted(dtls, address, packet): - if dtls._listening_context is None: +async def handle_client_hello_untrusted(endpoint, address, packet): + if endpoint._listening_context is None: return try: @@ -617,16 +618,19 @@ async def handle_client_hello_untrusted(dtls, address, packet): except BadPacket: return - if not valid_cookie(cookie, address, bits): - challenge_packet = challenge_for(address, epoch_seqno, bits) + if endpoint._listening_key is None: + endpoint._listening_key = os.urandom(KEY_BYTES) + + if not valid_cookie(endpoint._listening_key, cookie, address, bits): + challenge_packet = challenge_for(endpoint._listening_key, address, epoch_seqno, bits) try: - async with dtls._send_lock: - await dtls.socket.sendto(challenge_packet, address) + async with endpoint._send_lock: + await endpoint.socket.sendto(challenge_packet, address) except (OSError, trio.ClosedResourceError): pass else: # We got a real, valid ClientHello! - stream = DTLSChannel._create(dtls, address, dtls._listening_context) + stream = DTLSChannel._create(endpoint, address, endpoint._listening_context) # Our HelloRetryRequest had some sequence number. We need our future sequence # numbers to be larger than it, so our peer knows that our future records aren't # stale/duplicates. But, we don't know what this sequence number was. What we do @@ -645,7 +649,7 @@ async def handle_client_hello_untrusted(dtls, address, packet): # after all. 
return # Check if we have an existing association - old_stream = dtls._streams.get(address) + old_stream = endpoint._streams.get(address) if old_stream is not None: if old_stream._client_hello == (cookie, bits): # ...This was just a duplicate of the last ClientHello, so never mind. @@ -654,14 +658,14 @@ async def handle_client_hello_untrusted(dtls, address, packet): # Ok, this *really is* a new handshake; the old stream should go away. old_stream._set_replaced() stream._client_hello = (cookie, bits) - dtls._streams[address] = stream - dtls._incoming_connections_q.s.send_nowait(stream) + endpoint._streams[address] = stream + endpoint._incoming_connections_q.s.send_nowait(stream) -async def dtls_receive_loop(dtls): - sock = dtls.socket - dtls_ref = weakref.ref(dtls) - del dtls +async def dtls_receive_loop(endpoint): + sock = endpoint.socket + endpoint_ref = weakref.ref(endpoint) + del endpoint while True: try: packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE) @@ -678,14 +682,14 @@ async def dtls_receive_loop(dtls): # https://bobobobo.wordpress.com/2009/05/17/udp-an-existing-connection-was-forcibly-closed-by-the-remote-host/ # We'll assume that whatever it is, it's a transient problem. continue - dtls = dtls_ref() + endpoint = endpoint_ref() try: - if dtls is None: + if endpoint is None: return if is_client_hello_untrusted(packet): - await handle_client_hello_untrusted(dtls, address, packet) - elif address in dtls._streams: - stream = dtls._streams[address] + await handle_client_hello_untrusted(endpoint, address, packet) + elif address in endpoint._streams: + stream = endpoint._streams[address] if stream._did_handshake and part_of_handshake_untrusted(packet): # The peer just sent us more handshake messages, that aren't a # ClientHello, and we thought the handshake was done. 
Some of the @@ -707,7 +711,7 @@ async def dtls_receive_loop(dtls): # Drop packet pass finally: - del dtls + del endpoint @attr.frozen @@ -934,7 +938,7 @@ def __init__(self, socket, *, incoming_packets_buffer=10): global SSL from OpenSSL import SSL - self.socket = None # for __del__ + self.socket = None # for __del__, in case the next line raises if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") @@ -948,6 +952,7 @@ def __init__(self, socket, *, incoming_packets_buffer=10): # {remote address: DTLSChannel} self._streams = weakref.WeakValueDictionary() self._listening_context = None + self._listening_key = None self._incoming_connections_q = _Queue(float("inf")) self._send_lock = trio.Lock() self._closed = False @@ -955,10 +960,11 @@ def __init__(self, socket, *, incoming_packets_buffer=10): trio.lowlevel.spawn_system_task(dtls_receive_loop, self) def __del__(self): - # Close the socket in Trio context (if our Trio context still exists), so that - # the background task gets notified about the closure and can exit. + # Do nothing if this object was never fully constructed if self.socket is None: return + # Close the socket in Trio context (if our Trio context still exists), so that + # the background task gets notified about the closure and can exit. 
try: self._token.run_sync_soon(self.socket.close) except RuntimeError: diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index bd1842fb25..7c8a623783 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -169,14 +169,14 @@ async def test_implicit_handshake(): server_dtls = DTLSEndpoint(server_sock) -def dtls(**kwargs): +def endpoint(**kwargs): sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) return DTLSEndpoint(sock, **kwargs) @asynccontextmanager async def dtls_echo_server(*, autocancel=True): - with dtls() as server: + with endpoint() as server: await server.socket.bind(("127.0.0.1", 0)) async with trio.open_nursery() as nursery: @@ -200,7 +200,7 @@ async def echo_handler(dtls_channel): async def test_implicit_handshake(): async with dtls_echo_server() as (_, address): - with dtls() as client_endpoint: + with endpoint() as client_endpoint: client = await client_endpoint.connect(address, client_ctx) # Implicit handshake @@ -209,7 +209,7 @@ async def test_implicit_handshake(): async def test_full_duplex(): - with dtls() as server_endpoint, dtls() as client_endpoint: + with endpoint() as server_endpoint, endpoint() as client_endpoint: await server_endpoint.socket.bind(("127.0.0.1", 0)) async with trio.open_nursery() as server_nursery: @@ -232,7 +232,7 @@ async def handler(channel): async def test_channel_closing(): async with dtls_echo_server() as (_, address): - with dtls() as client_endpoint: + with endpoint() as client_endpoint: client = await client_endpoint.connect(address, client_ctx) client.close() @@ -257,7 +257,7 @@ async def test_serve_exits_cleanly_on_close(): async def test_client_multiplex(): async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2): - with dtls() as client_endpoint: + with endpoint() as client_endpoint: client1 = await client_endpoint.connect(address1, client_ctx) client2 = await client_endpoint.connect(address2, client_ctx) @@ -294,23 +294,23 @@ async def 
test_double_serve(): async def null_handler(_): # pragma: no cover pass - with dtls() as endpoint: + with endpoint() as server_endpoint: async with trio.open_nursery() as nursery: - await nursery.start(endpoint.serve, server_ctx, null_handler) + await nursery.start(server_endpoint.serve, server_ctx, null_handler) with pytest.raises(trio.BusyResourceError): - await nursery.start(endpoint.serve, server_ctx, null_handler) + await nursery.start(server_endpoint.serve, server_ctx, null_handler) nursery.cancel_scope.cancel() async with trio.open_nursery() as nursery: - await nursery.start(endpoint.serve, server_ctx, null_handler) + await nursery.start(server_endpoint.serve, server_ctx, null_handler) nursery.cancel_scope.cancel() async def test_connect_to_non_server(autojump_clock): fn = FakeNet() fn.enable() - with dtls() as client1, dtls() as client2: + with endpoint() as client1, endpoint() as client2: await client1.socket.bind(("127.0.0.1", 0)) # This should just time out with trio.move_on_after(100) as cscope: @@ -323,7 +323,7 @@ async def test_incoming_buffer_overflow(autojump_clock): fn.enable() for buffer_size in [10, 20]: async with dtls_echo_server() as (_, address): - with dtls(incoming_packets_buffer=buffer_size) as client_endpoint: + with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint: assert client_endpoint.incoming_packets_buffer == buffer_size client = await client_endpoint.connect(address, client_ctx) for i in range(buffer_size + 15): @@ -463,7 +463,7 @@ def route_packet(packet): async with dtls_echo_server() as (_, address): while True: - with dtls() as client: + with endpoint() as client: await client.connect(address, client_ctx) assert cscope.cancelled_caught @@ -475,7 +475,7 @@ async def test_client_cancels_handshake_and_starts_new_one(autojump_clock): fn = FakeNet() fn.enable() - with dtls() as server, dtls() as client: + with endpoint() as server, endpoint() as client: await server.socket.bind(("127.0.0.1", 0)) async with 
trio.open_nursery() as nursery: first_time = True @@ -509,7 +509,7 @@ async def handler(channel): async def test_swap_client_server(): - with dtls() as a, dtls() as b: + with endpoint() as a, endpoint() as b: await a.socket.bind(("127.0.0.1", 0)) await b.socket.bind(("127.0.0.1", 0)) @@ -565,7 +565,7 @@ def route_packet(packet): fn.route_packet = route_packet async with dtls_echo_server() as (server_endpoint, address): - with dtls() as client_endpoint: + with endpoint() as client_endpoint: async with trio.open_nursery() as nursery: async def connecter(): @@ -608,7 +608,7 @@ def route_packet(packet): async with dtls_echo_server() as (_, address): for t in [1, 2, 4]: - with dtls() as client: + with endpoint() as client: before = trio.current_time() blackholed = True await client.connect(address, client_ctx, initial_retransmit_timeout=t) @@ -618,13 +618,20 @@ def route_packet(packet): async def test_tiny_mtu(): async with dtls_echo_server() as (server, address): - with dtls() as client: + with endpoint() as client: pass # socket closed at terrible times + # garbage collecting DTLS object without closing it -# use fakenet, send a packet to the server, then immediately drop the dtls object and -# run gc before `sock.recvfrom()` can return +# (use fakenet, send a packet to the server, then immediately drop the dtls object and +# run gc before `sock.recvfrom()` can return) # ...connect() probably shouldn't do the handshake. makes it impossible to set MTU! + +# handshake with fakenet enforcing the minimum mtu (both ipv4 and ipv6) + +# maybe handshake failure should only set _mtu, not the openssl-level mtu? +# (and rename _mtu to "_effective_handshake_mtu" or something to be clear about its +# purpose) From ca2c652ae43feeb88e54ac25655f079282f5d5c4 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Mon, 6 Sep 2021 20:55:18 -0700 Subject: [PATCH 25/47] Take handshake out of connect() and make it sync --- trio/_dtls.py | 24 +++-- trio/tests/test_dtls.py | 230 +++++++++++++++++++++++----------------- 2 files changed, 147 insertions(+), 107 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index c1b084519f..3c22b99c90 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -540,7 +540,9 @@ def valid_cookie(key, cookie, address, client_hello_bits): tick = _current_cookie_tick() cur_cookie = _make_cookie(key, salt, tick, address, client_hello_bits) - old_cookie = _make_cookie(key, salt, max(tick - 1, 0), address, client_hello_bits) + old_cookie = _make_cookie( + key, salt, max(tick - 1, 0), address, client_hello_bits + ) # I doubt using a short-circuiting 'or' here would leak any meaningful # information, but why risk it when '|' is just as easy. @@ -622,7 +624,9 @@ async def handle_client_hello_untrusted(endpoint, address, packet): endpoint._listening_key = os.urandom(KEY_BYTES) if not valid_cookie(endpoint._listening_key, cookie, address, bits): - challenge_packet = challenge_for(endpoint._listening_key, address, epoch_seqno, bits) + challenge_packet = challenge_for( + endpoint._listening_key, address, epoch_seqno, bits + ) try: async with endpoint._send_lock: await endpoint.socket.sendto(challenge_packet, address) @@ -1010,7 +1014,7 @@ async def handler_wrapper(stream): finally: self._listening_context = None - async def connect(self, address, ssl_context, *, initial_retransmit_timeout=1.0): + def connect(self, address, ssl_context): # it would be nice if we could detect when 'address' is our own endpoint (a # loopback connection), because that can't work # but I don't see how to do it reliably @@ -1021,11 +1025,11 @@ async def connect(self, address, ssl_context, *, initial_retransmit_timeout=1.0) if old_channel is not None: old_channel._set_replaced() self._streams[address] = channel - try: - await channel.do_handshake( - 
initial_retransmit_timeout=initial_retransmit_timeout - ) - except: - channel.close() - raise + # try: + # await channel.do_handshake( + # initial_retransmit_timeout=initial_retransmit_timeout + # ) + # except: + # channel.close() + # raise return channel diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 7c8a623783..4f93d3160a 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -5,12 +5,13 @@ import attr from contextlib import asynccontextmanager from itertools import count +import ipaddress import trustme from OpenSSL import SSL from trio.testing._fake_net import FakeNet -from .._core.tests.tutil import slow, can_bind_ipv6 +from .._core.tests.tutil import slow, binds_ipv6 ca = trustme.CA() server_cert = ca.issue_cert("example.com") @@ -22,51 +23,70 @@ ca.configure_trust(client_ctx) -families = [trio.socket.AF_INET] -if can_bind_ipv6: - families.append(trio.socket.AF_INET6) +parametrize_ipv6 = pytest.mark.parametrize( + "ipv6", [False, pytest.param(True, marks=binds_ipv6)], ids=["ipv4", "ipv6"] +) -@pytest.mark.parametrize("family", families) -async def test_smoke(family): - if family == trio.socket.AF_INET: - localhost = "127.0.0.1" +def endpoint(**kwargs): + ipv6 = kwargs.pop("ipv6", False) + if ipv6: + family = trio.socket.AF_INET6 else: - localhost = "::1" - server_sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) - with server_sock: - await server_sock.bind((localhost, 0)) - server_dtls = DTLSEndpoint(server_sock) - - async with trio.open_nursery() as nursery: + family = trio.socket.AF_INET + sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) + return DTLSEndpoint(sock, **kwargs) - async def handle_client(dtls_channel): - await dtls_channel.do_handshake() - assert await dtls_channel.receive() == b"hello" - await dtls_channel.send(b"goodbye") - await nursery.start(server_dtls.serve, server_ctx, handle_client) +@asynccontextmanager +async def dtls_echo_server(*, autocancel=True, 
mtu=None, ipv6=False): + with endpoint(ipv6=ipv6) as server: + if ipv6: + localhost = "::1" + else: + localhost = "127.0.0.1" + await server.socket.bind((localhost, 0)) + async with trio.open_nursery() as nursery: - with trio.socket.socket( - type=trio.socket.SOCK_DGRAM, family=family - ) as client_sock: - client_dtls = DTLSEndpoint(client_sock) - client = await client_dtls.connect( - server_sock.getsockname(), client_ctx + async def echo_handler(dtls_channel): + print( + f"echo handler started: " + f"server {dtls_channel.endpoint.socket.getsockname()} " + f"client {dtls_channel.peer_address}" ) - await client.send(b"hello") - assert await client.receive() == b"goodbye" + if mtu is not None: + dtls_channel.set_ciphertext_mtu(mtu) + async for packet in dtls_channel: + print(f"echoing {packet} -> {dtls_channel.peer_address}") + await dtls_channel.send(packet) - client.set_ciphertext_mtu(1234) - cleartext_mtu_1234 = client.get_cleartext_mtu() - client.set_ciphertext_mtu(4321) - assert client.get_cleartext_mtu() > cleartext_mtu_1234 - client.set_ciphertext_mtu(1234) - assert client.get_cleartext_mtu() == cleartext_mtu_1234 + await nursery.start(server.serve, server_ctx, echo_handler) + + yield server, server.socket.getsockname() + if autocancel: nursery.cancel_scope.cancel() +@parametrize_ipv6 +async def test_smoke(ipv6): + async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address): + with endpoint(ipv6=ipv6) as client_endpoint: + client_channel = client_endpoint.connect(address, client_ctx) + await client_channel.do_handshake() + await client_channel.send(b"hello") + assert await client_channel.receive() == b"hello" + await client_channel.send(b"goodbye") + assert await client_channel.receive() == b"goodbye" + + client_channel.set_ciphertext_mtu(1234) + cleartext_mtu_1234 = client_channel.get_cleartext_mtu() + client_channel.set_ciphertext_mtu(4321) + assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234 + client_channel.set_ciphertext_mtu(1234) + 
assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234 + + @slow async def test_handshake_over_terrible_network(autojump_clock): HANDSHAKES = 1000 @@ -145,9 +165,8 @@ async def handle_client(dtls_channel): print("#" * 80) with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as client_sock: client_dtls = DTLSEndpoint(client_sock) - client = await client_dtls.connect( - server_sock.getsockname(), client_ctx - ) + client = client_dtls.connect(server_sock.getsockname(), client_ctx) + await client.do_handshake() while True: data = str(next_client_idx).encode() print(f"client sending plaintext: {data}") @@ -163,45 +182,10 @@ async def handle_client(dtls_channel): nursery.cancel_scope.cancel() -async def test_implicit_handshake(): - with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as server_sock: - await server_sock.bind(("127.0.0.1", 0)) - server_dtls = DTLSEndpoint(server_sock) - - -def endpoint(**kwargs): - sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM) - return DTLSEndpoint(sock, **kwargs) - - -@asynccontextmanager -async def dtls_echo_server(*, autocancel=True): - with endpoint() as server: - await server.socket.bind(("127.0.0.1", 0)) - async with trio.open_nursery() as nursery: - - async def echo_handler(dtls_channel): - print( - f"echo handler started: " - f"server {dtls_channel.endpoint.socket.getsockname()} " - f"client {dtls_channel.peer_address}" - ) - async for packet in dtls_channel: - print(f"echoing {packet} -> {dtls_channel.peer_address}") - await dtls_channel.send(packet) - - await nursery.start(server.serve, server_ctx, echo_handler) - - yield server, server.socket.getsockname() - - if autocancel: - nursery.cancel_scope.cancel() - - async def test_implicit_handshake(): async with dtls_echo_server() as (_, address): with endpoint() as client_endpoint: - client = await client_endpoint.connect(address, client_ctx) + client = client_endpoint.connect(address, client_ctx) # Implicit handshake await client.send(b"xyz") @@ -209,6 +193,8 @@ 
async def test_implicit_handshake(): async def test_full_duplex(): + # Tests simultaneous send/receive, and also multiple methods implicitly invoking + # do_handshake simultaneously. with endpoint() as server_endpoint, endpoint() as client_endpoint: await server_endpoint.socket.bind(("127.0.0.1", 0)) async with trio.open_nursery() as server_nursery: @@ -220,7 +206,7 @@ async def handler(channel): await server_nursery.start(server_endpoint.serve, server_ctx, handler) - client = await client_endpoint.connect( + client = client_endpoint.connect( server_endpoint.socket.getsockname(), client_ctx ) async with trio.open_nursery() as nursery: @@ -233,7 +219,8 @@ async def handler(channel): async def test_channel_closing(): async with dtls_echo_server() as (_, address): with endpoint() as client_endpoint: - client = await client_endpoint.connect(address, client_ctx) + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake() client.close() with pytest.raises(trio.ClosedResourceError): @@ -258,8 +245,8 @@ async def test_serve_exits_cleanly_on_close(): async def test_client_multiplex(): async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2): with endpoint() as client_endpoint: - client1 = await client_endpoint.connect(address1, client_ctx) - client2 = await client_endpoint.connect(address2, client_ctx) + client1 = client_endpoint.connect(address1, client_ctx) + client2 = client_endpoint.connect(address2, client_ctx) await client1.send(b"abc") await client2.send(b"xyz") @@ -273,7 +260,7 @@ async def test_client_multiplex(): with pytest.raises(trio.ClosedResourceError): await client2.receive() with pytest.raises(trio.ClosedResourceError): - await client_endpoint.connect(address1, client_ctx) + client_endpoint.connect(address1, client_ctx) async with trio.open_nursery() as nursery: with pytest.raises(trio.ClosedResourceError): @@ -314,7 +301,8 @@ async def test_connect_to_non_server(autojump_clock): await 
client1.socket.bind(("127.0.0.1", 0)) # This should just time out with trio.move_on_after(100) as cscope: - await client2.connect(client1.socket.getsockname(), client_ctx) + channel = client2.connect(client1.socket.getsockname(), client_ctx) + await channel.do_handshake() assert cscope.cancelled_caught @@ -325,7 +313,7 @@ async def test_incoming_buffer_overflow(autojump_clock): async with dtls_echo_server() as (_, address): with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint: assert client_endpoint.incoming_packets_buffer == buffer_size - client = await client_endpoint.connect(address, client_ctx) + client = client_endpoint.connect(address, client_ctx) for i in range(buffer_size + 15): await client.send(str(i).encode()) await trio.sleep(1) @@ -464,7 +452,8 @@ def route_packet(packet): async with dtls_echo_server() as (_, address): while True: with endpoint() as client: - await client.connect(address, client_ctx) + channel = client.connect(address, client_ctx) + await channel.do_handshake() assert cscope.cancelled_caught @@ -498,11 +487,12 @@ async def handler(channel): print("client: starting first connect") with trio.CancelScope() as connect_cscope: - await client.connect(server.socket.getsockname(), client_ctx) + channel = client.connect(server.socket.getsockname(), client_ctx) + await channel.do_handshake() assert connect_cscope.cancelled_caught print("client: starting second connect") - channel = await client.connect(server.socket.getsockname(), client_ctx) + channel = client.connect(server.socket.getsockname(), client_ctx) assert await channel.receive() == b"hello" nursery.cancel_scope.cancel() @@ -525,11 +515,12 @@ async def crashing_echo_handler(channel): await nursery.start(a.serve, server_ctx, crashing_echo_handler) await nursery.start(b.serve, server_ctx, echo_handler) - b_to_a = await b.connect(a.socket.getsockname(), client_ctx) + b_to_a = b.connect(a.socket.getsockname(), client_ctx) await b_to_a.send(b"b as client") assert await 
b_to_a.receive() == b"b as client" - a_to_b = await a.connect(b.socket.getsockname(), client_ctx) + a_to_b = a.connect(b.socket.getsockname(), client_ctx) + await a_to_b.do_handshake() with pytest.raises(trio.BrokenResourceError): await b_to_a.send(b"association broken") await a_to_b.send(b"a as client") @@ -569,9 +560,8 @@ def route_packet(packet): async with trio.open_nursery() as nursery: async def connecter(): - client = await client_endpoint.connect( - address, client_ctx, initial_retransmit_timeout=1.5 - ) + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake(initial_retransmit_timeout=1.5) await client.send(b"hi") assert await client.receive() == b"hi" @@ -591,7 +581,7 @@ async def connecter(): # scapy.all.wrpcap("/tmp/trace.pcap", packets) -async def test_initial_retransmit_timeout(autojump_clock): +async def test_initial_retransmit_timeout_configuration(autojump_clock): fn = FakeNet() fn.enable() @@ -611,15 +601,65 @@ def route_packet(packet): with endpoint() as client: before = trio.current_time() blackholed = True - await client.connect(address, client_ctx, initial_retransmit_timeout=t) + channel = client.connect(address, client_ctx) + await channel.do_handshake(initial_retransmit_timeout=t) after = trio.current_time() assert after - before == t -async def test_tiny_mtu(): - async with dtls_echo_server() as (server, address): +async def test_setting_tiny_mtu(): + # ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to + # be larger than that. (300 is still smaller than any real network though.) 
+ MTU = 300 + + fn = FakeNet() + fn.enable() + + def route_packet(packet): + print(f"delivering {packet}") + print(f"payload size: {len(packet.payload)}") + assert len(packet.payload) <= MTU + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server(mtu=MTU) as (server, address): with endpoint() as client: - pass + channel = client.connect(address, client_ctx) + channel.set_ciphertext_mtu(MTU) + await channel.do_handshake() + await channel.send(b"hi") + assert await channel.receive() == b"hi" + + +@parametrize_ipv6 +async def test_tiny_network_mtu(ipv6, autojump_clock): + # Fake network that has the minimum allowable MTU for whatever protocol we're using. + fn = FakeNet() + fn.enable() + + if ipv6: + mtu = 1280 - 48 + else: + mtu = 576 - 28 + + def route_packet(packet): + if len(packet.payload) > mtu: + print(f"dropping {packet}") + else: + print(f"delivering {packet}") + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + # See if we can successfully do a handshake -- some of the volleys will get dropped, + # and the retransmit logic should detect this and back off the MTU to something + # smaller until it succeeds. + async with dtls_echo_server(ipv6=ipv6) as (_, address): + with endpoint(ipv6=ipv6) as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + await client.send(b"xyz") + assert await client.receive() == b"xyz" # socket closed at terrible times @@ -628,10 +668,6 @@ async def test_tiny_mtu(): # (use fakenet, send a packet to the server, then immediately drop the dtls object and # run gc before `sock.recvfrom()` can return) -# ...connect() probably shouldn't do the handshake. makes it impossible to set MTU! - -# handshake with fakenet enforcing the minimum mtu (both ipv4 and ipv6) - # maybe handshake failure should only set _mtu, not the openssl-level mtu? 
# (and rename _mtu to "_effective_handshake_mtu" or something to be clear about its # purpose) From 05c3f88d21acdf90fc83eca110f3e45fcc0e9d84 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Mon, 6 Sep 2021 22:50:35 -0700 Subject: [PATCH 26/47] Cleanups and a few more tests --- trio/_dtls.py | 60 ++++++++--------- trio/tests/test_dtls.py | 139 ++++++++++++++++++++++++---------------- 2 files changed, 115 insertions(+), 84 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 3c22b99c90..2964a629ab 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -1,9 +1,5 @@ # https://datatracker.ietf.org/doc/html/rfc6347 -# XX: figure out what to do about the pyopenssl dependency -# Maybe the toplevel __init__.py should use __getattr__ trickery to load all -# the DTLS code lazily? - import struct import hmac import os @@ -12,6 +8,7 @@ from itertools import count import weakref import errno +import warnings import attr @@ -127,7 +124,7 @@ def is_client_hello_untrusted(packet): RECORD_HEADER = struct.Struct("!B2sQH") -hex_repr = attr.ib(repr=lambda data: data.hex()) +hex_repr = attr.ib(repr=lambda data: data.hex()) # pragma: no cover @attr.frozen @@ -143,7 +140,10 @@ def records_untrusted(packet): while i < len(packet): try: ct, version, epoch_seqno, payload_len = RECORD_HEADER.unpack_from(packet, i) - except struct.error as exc: + # Marked as no-cover because at time of writing, this code is unreachable + # (records_untrusted only gets called on packets that are either trusted or that + # have passed is_client_hello_untrusted, which filters out short packets) + except struct.error as exc: # pragma: no cover raise BadPacket("invalid record header") from exc i += RECORD_HEADER.size payload = packet[i : i + payload_len] @@ -179,7 +179,7 @@ class HandshakeFragment: msg_seq: int frag_offset: int frag_len: int - frag: bytes = attr.ib(repr=lambda f: f.hex()) + frag: bytes = hex_repr def decode_handshake_fragment_untrusted(payload): @@ -605,8 +605,6 @@ def 
_read_loop(read_fn): chunk = read_fn(2 ** 14) # max TLS record size except SSL.WantReadError: break - if not chunk: - break chunks.append(chunk) return b"".join(chunks) @@ -666,10 +664,11 @@ async def handle_client_hello_untrusted(endpoint, address, packet): endpoint._incoming_connections_q.s.send_nowait(stream) -async def dtls_receive_loop(endpoint): - sock = endpoint.socket - endpoint_ref = weakref.ref(endpoint) - del endpoint +async def dtls_receive_loop(endpoint_ref): + try: + sock = endpoint_ref().socket + except AttributeError: + return while True: try: packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE) @@ -909,6 +908,8 @@ def read_volley(): async def send(self, data): if self._closed: raise trio.ClosedResourceError + if not data: + raise ValueError("openssl doesn't support sending empty DTLS packets") if not self._did_handshake: await self.do_handshake() self._check_replaced() @@ -921,18 +922,19 @@ async def send(self, data): async def receive(self): if not self._did_handshake: await self.do_handshake() + # If the packet isn't really valid, then openssl can decode it to the empty + # string (e.g. b/c it's a late-arriving handshake packet, or a duplicate copy of + # a data packet). Skip over these instead of returning them. 
while True: try: packet = await self._q.r.receive() except trio.EndOfChannel: assert self._replaced self._check_replaced() - # Don't return spurious empty packets because of stray handshake packets - # coming in late - if part_of_handshake_untrusted(packet): - continue self._ssl.bio_write(packet) - return _read_loop(self._ssl.read) + cleartext = _read_loop(self._ssl.read) + if cleartext: + return cleartext class DTLSEndpoint(metaclass=Final): @@ -961,7 +963,7 @@ def __init__(self, socket, *, incoming_packets_buffer=10): self._send_lock = trio.Lock() self._closed = False - trio.lowlevel.spawn_system_task(dtls_receive_loop, self) + trio.lowlevel.spawn_system_task(dtls_receive_loop, weakref.ref(self)) def __del__(self): # Do nothing if this object was never fully constructed @@ -969,10 +971,15 @@ def __del__(self): return # Close the socket in Trio context (if our Trio context still exists), so that # the background task gets notified about the closure and can exit. - try: - self._token.run_sync_soon(self.socket.close) - except RuntimeError: - pass + if not self._closed: + try: + self._token.run_sync_soon(self.close) + except RuntimeError: + pass + # Do this last, because it might raise an exception + warnings.warn( + f"unclosed DTLS endpoint {self!r}", ResourceWarning, source=self + ) def close(self): self._closed = True @@ -1025,11 +1032,4 @@ def connect(self, address, ssl_context): if old_channel is not None: old_channel._set_replaced() self._streams[address] = channel - # try: - # await channel.do_handshake( - # initial_retransmit_timeout=initial_retransmit_timeout - # ) - # except: - # channel.close() - # raise return channel diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 4f93d3160a..ad8fb82e7b 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -6,12 +6,13 @@ from contextlib import asynccontextmanager from itertools import count import ipaddress +import warnings import trustme from OpenSSL import SSL from 
trio.testing._fake_net import FakeNet -from .._core.tests.tutil import slow, binds_ipv6 +from .._core.tests.tutil import slow, binds_ipv6, gc_collect_harder ca = trustme.CA() server_cert = ca.issue_cert("example.com") @@ -56,9 +57,15 @@ async def echo_handler(dtls_channel): ) if mtu is not None: dtls_channel.set_ciphertext_mtu(mtu) - async for packet in dtls_channel: - print(f"echoing {packet} -> {dtls_channel.peer_address}") - await dtls_channel.send(packet) + print("server starting do_handshake") + await dtls_channel.do_handshake() + print("server finished do_handshake") + try: + async for packet in dtls_channel: + print(f"echoing {packet} -> {dtls_channel.peer_address}") + await dtls_channel.send(packet) + except trio.BrokenResourceError: + pass await nursery.start(server.serve, server_ctx, echo_handler) @@ -79,6 +86,9 @@ async def test_smoke(ipv6): await client_channel.send(b"goodbye") assert await client_channel.receive() == b"goodbye" + with pytest.raises(ValueError): + await client_channel.send(b"") + client_channel.set_ciphertext_mtu(1234) cleartext_mtu_1234 = client_channel.get_cleartext_mtu() client_channel.set_ciphertext_mtu(4321) @@ -93,7 +103,8 @@ async def test_handshake_over_terrible_network(autojump_clock): r = random.Random(0) fn = FakeNet() fn.enable() - with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as server_sock: + + async with dtls_echo_server() as (_, address): async with trio.open_nursery() as nursery: async def route_packet(packet): @@ -127,60 +138,25 @@ def route_packet_wrapper(packet): fn.route_packet = route_packet_wrapper - await server_sock.bind(("1.1.1.1", 54321)) - server_dtls = DTLSEndpoint(server_sock) - - next_client_idx = 0 - next_client_msg_recvd = trio.Event() - - async def handle_client(dtls_channel): - print("handling new client") - try: - await dtls_channel.do_handshake() - while True: - data = await dtls_channel.receive() - print(f"server received plaintext: {data}") - if not data: - continue - assert 
int(data.decode()) == next_client_idx - next_client_msg_recvd.set() - break - except trio.BrokenResourceError: - # client might have timed out on handshake and started a new one - # so we'll let this task die and let the new task do the check - print("new handshake restarting") - pass - except: - print("server handler saw") - import traceback - - traceback.print_exc() - raise - - await nursery.start(server_dtls.serve, server_ctx, handle_client) - - for _ in range(HANDSHAKES): + for i in range(HANDSHAKES): print("#" * 80) print("#" * 80) print("#" * 80) - with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as client_sock: - client_dtls = DTLSEndpoint(client_sock) - client = client_dtls.connect(server_sock.getsockname(), client_ctx) + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + print("client starting do_handshake") await client.do_handshake() + print("client finished do_handshake") + msg = str(i).encode() + # Make multiple attempts to send data, because the network might + # drop it while True: - data = str(next_client_idx).encode() - print(f"client sending plaintext: {data}") - await client.send(data) with trio.move_on_after(10) as cscope: - await next_client_msg_recvd.wait() + await client.send(msg) + assert await client.receive() == msg if not cscope.cancelled_caught: break - next_client_idx += 1 - next_client_msg_recvd = trio.Event() - - nursery.cancel_scope.cancel() - async def test_implicit_handshake(): async with dtls_echo_server() as (_, address): @@ -662,11 +638,66 @@ def route_packet(packet): assert await client.receive() == b"xyz" -# socket closed at terrible times +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_system_task_cleaned_up_on_gc(): + before_tasks = trio.lowlevel.current_statistics().tasks_living + + e = endpoint() + # Give system task a chance to start up + await trio.testing.wait_all_tasks_blocked() + + during_tasks = 
trio.lowlevel.current_statistics().tasks_living -# garbage collecting DTLS object without closing it -# (use fakenet, send a packet to the server, then immediately drop the dtls object and -# run gc before `sock.recvfrom()` can return) + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + await trio.testing.wait_all_tasks_blocked() + + after_tasks = trio.lowlevel.current_statistics().tasks_living + assert before_tasks < during_tasks + assert before_tasks == after_tasks + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_gc_before_system_task_starts(): + e = endpoint() + + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + await trio.testing.wait_all_tasks_blocked() + + +async def test_already_closed_socket_doesnt_crash(): + with endpoint() as e: + # We close the socket before checkpointing, so the socket will already be closed + # when the system task starts up + e.socket.close() + # Now give it a chance to start up, and hopefully not crash + await trio.testing.wait_all_tasks_blocked() + + +async def test_socket_closed_while_processing_clienthello(autojump_clock): + fn = FakeNet() + fn.enable() + + # Check what happens if the socket is discovered to be closed when sending a + # HelloVerifyRequest, since that has its own sending logic + async with dtls_echo_server() as (server, address): + def route_packet(packet): + fn.deliver_packet(packet) + server.socket.close() + + fn.route_packet = route_packet + + with endpoint() as client_endpoint: + with trio.move_on_after(10): + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake() + +# socket closed at terrible times # maybe handshake failure should only set _mtu, not the openssl-level mtu? # (and rename _mtu to "_effective_handshake_mtu" or something to be clear about its From 2fb2d0df2b612c101db5f35bdfe20c1164919a52 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Mon, 6 Sep 2021 23:12:34 -0700 Subject: [PATCH 27/47] Don't let handshake implicitly overwrite user-specified mtu --- trio/_dtls.py | 14 ++++++++------ trio/tests/test_dtls.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 2964a629ab..3cadbd5a1b 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -736,7 +736,7 @@ def __init__(self, endpoint, peer_address, ctx): # to just performing a new handshake. ctx.set_options(SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION) self._ssl = SSL.Connection(ctx) - self._mtu = None + self._handshake_mtu = None # This calls self._ssl.set_ciphertext_mtu, which is important, because if you # don't call it then openssl doesn't work. self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) @@ -762,7 +762,7 @@ def _check_replaced(self): ) def set_ciphertext_mtu(self, new_mtu): - self._mtu = new_mtu + self._handshake_mtu = new_mtu self._ssl.set_ciphertext_mtu(new_mtu) def get_cleartext_mtu(self): @@ -796,7 +796,9 @@ async def aclose(self): await trio.lowlevel.checkpoint() async def _send_volley(self, volley_messages): - packets = self._record_encoder.encode_volley(volley_messages, self._mtu) + packets = self._record_encoder.encode_volley( + volley_messages, self._handshake_mtu + ) for packet in packets: async with self.endpoint._send_lock: await self.endpoint.socket.sendto(packet, self.peer_address) @@ -901,8 +903,8 @@ def read_volley(): # We tried sending this twice and they both failed. Maybe our # PMTU estimate is wrong? Let's try dropping it to the minimum # and hope that helps. 
- self.set_ciphertext_mtu( - min(self._mtu, worst_case_mtu(self.endpoint.socket)) + self._handshake_mtu = min( + self._handshake_mtu, worst_case_mtu(self.endpoint.socket) ) async def send(self, data): @@ -947,8 +949,8 @@ def __init__(self, socket, *, incoming_packets_buffer=10): self.socket = None # for __del__, in case the next line raises if socket.type != trio.socket.SOCK_DGRAM: raise ValueError("DTLS requires a SOCK_DGRAM socket") - self.socket = socket + self.incoming_packets_buffer = incoming_packets_buffer self._token = trio.lowlevel.current_trio_token() # We don't need to track handshaking vs non-handshake connections diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index ad8fb82e7b..dce1c92954 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -583,7 +583,7 @@ def route_packet(packet): assert after - before == t -async def test_setting_tiny_mtu(): +async def test_explicit_tiny_mtu_is_respected(): # ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to # be larger than that. (300 is still smaller than any real network though.) MTU = 300 @@ -609,7 +609,7 @@ def route_packet(packet): @parametrize_ipv6 -async def test_tiny_network_mtu(ipv6, autojump_clock): +async def test_handshake_handles_minimum_network_mtu(ipv6, autojump_clock): # Fake network that has the minimum allowable MTU for whatever protocol we're using. 
fn = FakeNet() fn.enable() @@ -634,8 +634,13 @@ def route_packet(packet): async with dtls_echo_server(ipv6=ipv6) as (_, address): with endpoint(ipv6=ipv6) as client_endpoint: client = client_endpoint.connect(address, client_ctx) + # the handshake mtu backoff shouldn't affect the return value from + # get_cleartext_mtu, b/c that's under the user's control via + # set_ciphertext_mtu + client.set_ciphertext_mtu(9999) await client.send(b"xyz") assert await client.receive() == b"xyz" + assert client.get_cleartext_mtu() > 9000 # as vegeta said @pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") @@ -686,6 +691,7 @@ async def test_socket_closed_while_processing_clienthello(autojump_clock): # Check what happens if the socket is discovered to be closed when sending a # HelloVerifyRequest, since that has its own sending logic async with dtls_echo_server() as (server, address): + def route_packet(packet): fn.deliver_packet(packet) server.socket.close() @@ -696,9 +702,3 @@ def route_packet(packet): with trio.move_on_after(10): client = client_endpoint.connect(address, client_ctx) await client.do_handshake() - -# socket closed at terrible times - -# maybe handshake failure should only set _mtu, not the openssl-level mtu? -# (and rename _mtu to "_effective_handshake_mtu" or something to be clear about its -# purpose) From 78ecb675f2c9ead68e495db22e7eef7805375380 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Tue, 7 Sep 2021 19:21:21 -0700 Subject: [PATCH 28/47] More cleanup and test coverage --- trio/_dtls.py | 100 +++++++++++++++++--------------- trio/tests/test_dtls.py | 125 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 176 insertions(+), 49 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 3cadbd5a1b..3c990e1af2 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -230,7 +230,10 @@ def decode_client_hello_untrusted(packet): try: # ClientHello has to be the first record in the packet record = next(records_untrusted(packet)) - if record.content_type != ContentType.handshake: + # no-cover because at time of writing, this is unreachable: + # decode_client_hello_untrusted is only called on packets that have passed + # is_client_hello_untrusted, which confirms the content type. + if record.content_type != ContentType.handshake: # pragma: no cover raise BadPacket("not a handshake record") fragment = decode_handshake_fragment_untrusted(record.payload) if fragment.msg_type != HandshakeType.client_hello: @@ -669,52 +672,58 @@ async def dtls_receive_loop(endpoint_ref): sock = endpoint_ref().socket except AttributeError: return - while True: - try: - packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE) - except trio.ClosedResourceError: - return - except OSError as exc: - if exc.errno in (errno.EBADF, errno.ENOTSOCK): - # Socket was closed - return - else: - # Some weird error, e.g. apparently some versions of Windows can do - # ECONNRESET here to report that some previous UDP packet got an ICMP - # Port Unreachable: - # https://bobobobo.wordpress.com/2009/05/17/udp-an-existing-connection-was-forcibly-closed-by-the-remote-host/ - # We'll assume that whatever it is, it's a transient problem. 
- continue - endpoint = endpoint_ref() - try: - if endpoint is None: - return - if is_client_hello_untrusted(packet): - await handle_client_hello_untrusted(endpoint, address, packet) - elif address in endpoint._streams: - stream = endpoint._streams[address] - if stream._did_handshake and part_of_handshake_untrusted(packet): - # The peer just sent us more handshake messages, that aren't a - # ClientHello, and we thought the handshake was done. Some of the - # packets that we sent to finish the handshake must have gotten - # lost. So re-send them. We do this directly here instead of just - # putting it into the queue and letting the receiver do it, because - # there's no guarantee that anyone is reading from the queue, - # because we think the handshake is done! - try: + try: + while True: + try: + packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE) + except OSError as exc: + if exc.errno == errno.ECONNRESET: + # Windows only: "On a UDP-datagram socket [ECONNRESET] + # indicates a previous send operation resulted in an ICMP Port + # Unreachable message" -- https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom + # + # This is totally useless -- there's nothing we can do with this + # information. So we just ignore it and retry the recv. + continue + else: + raise + endpoint = endpoint_ref() + try: + if endpoint is None: + return + if is_client_hello_untrusted(packet): + await handle_client_hello_untrusted(endpoint, address, packet) + elif address in endpoint._streams: + stream = endpoint._streams[address] + if stream._did_handshake and part_of_handshake_untrusted(packet): + # The peer just sent us more handshake messages, that aren't a + # ClientHello, and we thought the handshake was done. Some of + # the packets that we sent to finish the handshake must have + # gotten lost. So re-send them. 
We do this directly here instead + # of just putting it into the queue and letting the receiver do + # it, because there's no guarantee that anyone is reading from + # the queue, because we think the handshake is done! await stream._resend_final_volley() - except trio.ClosedResourceError: - return + else: + try: + stream._q.s.send_nowait(packet) + except trio.WouldBlock: + stream._packets_dropped_in_trio += 1 else: - try: - stream._q.s.send_nowait(packet) - except trio.WouldBlock: - stream._packets_dropped_in_trio += 1 - else: - # Drop packet - pass - finally: - del endpoint + # Drop packet + pass + finally: + del endpoint + except trio.ClosedResourceError: + # socket was closed + return + except OSError as exc: + if exc.errno in (errno.EBADF, errno.ENOTSOCK): + # socket was closed + return + else: # pragma: no cover + # ??? shouldn't happen + raise @attr.frozen @@ -847,6 +856,7 @@ def read_volley(): while True: # -- at this point, we need to either send or re-send a volley -- assert volley_messages + self._check_replaced() await self._send_volley(volley_messages) # -- then this is where we wait for a reply -- with trio.move_on_after(timeout) as cscope: diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index dce1c92954..2d09e469c8 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -64,8 +64,8 @@ async def echo_handler(dtls_channel): async for packet in dtls_channel: print(f"echoing {packet} -> {dtls_channel.peer_address}") await dtls_channel.send(packet) - except trio.BrokenResourceError: - pass + except trio.BrokenResourceError: # pragma: no cover + print("echo handler channel broken") await nursery.start(server.serve, server_ctx, echo_handler) @@ -111,7 +111,7 @@ async def route_packet(packet): while True: op = r.choices( ["deliver", "drop", "dupe", "delay"], - weights=[0.7, 0.1, 0.1, 0.1], + weights=[0.6, 0.1, 0.1, 0.1], )[0] print(f"{packet.source} -> {packet.destination}: {op}") if op == "drop": @@ -120,6 +120,23 @@ async def 
route_packet(packet): fn.send_packet(packet) elif op == "delay": await trio.sleep(r.random() * 3) + # I wanted to test random packet corruption too, but it turns out + # openssl has a bug in the following scenario: + # - client sends ClientHello + # - server sends HelloVerifyRequest with cookie -- but cookie is + # invalid b/c either the ClientHello or HelloVerifyRequest was + # corrupted + # - client re-sends ClientHello with invalid cookie + # - server replies with new HelloVerifyRequest and correct cookie + # + # At this point, the client *should* switch to the new, valid + # cookie. But OpenSSL doesn't; it stubbornly insists on re-sending + # the original, invalid cookie over and over. + # + # elif op == "distort": + # payload = bytearray(packet.payload) + # payload[r.randrange(len(payload))] ^= 1 << r.randrange(8) + # packet = attr.evolve(packet, payload=payload) else: assert op == "deliver" print( @@ -131,7 +148,7 @@ async def route_packet(packet): def route_packet_wrapper(packet): try: nursery.start_soon(route_packet, packet) - except RuntimeError: + except RuntimeError: # pragma: no cover # We're exiting the nursery, so any remaining packets can just get # dropped pass @@ -376,6 +393,25 @@ async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): + b"\x00", ) ) + + handshake_empty = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=b"", + ) + ) + + client_hello_truncated_in_cookie = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=bytes(2 + 32 + 1) + b"\xff", + ) + ) + async with dtls_echo_server() as (_, address): with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock: for bad_packet in [ @@ -387,6 +423,8 @@ async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): client_hello_corrupt_record_len, client_hello_fragmented, client_hello_trailing_data_in_record, + handshake_empty, + 
client_hello_truncated_in_cookie, ]: await sock.sendto(bad_packet, address) await trio.sleep(1) @@ -471,6 +509,8 @@ async def handler(channel): channel = client.connect(server.socket.getsockname(), client_ctx) assert await channel.receive() == b"hello" + # Give handlers a chance to finish + await trio.sleep(10) nursery.cancel_scope.cancel() @@ -675,6 +715,17 @@ async def test_gc_before_system_task_starts(): await trio.testing.wait_all_tasks_blocked() +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +def test_gc_after_trio_exits(): + async def main(): + return endpoint() + + e = trio.run(main) + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + async def test_already_closed_socket_doesnt_crash(): with endpoint() as e: # We close the socket before checkpointing, so the socket will already be closed @@ -702,3 +753,69 @@ def route_packet(packet): with trio.move_on_after(10): client = client_endpoint.connect(address, client_ctx) await client.do_handshake() + + +async def test_association_replaced_while_handshake_running(autojump_clock): + fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + if blackholed: + return + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + c1 = client_endpoint.connect(address, client_ctx) + async with trio.open_nursery() as nursery: + async def doomed_handshake(): + with pytest.raises(trio.BrokenResourceError): + await c1.do_handshake() + + nursery.start_soon(doomed_handshake) + + await trio.sleep(10) + + c2 = client_endpoint.connect(address, client_ctx) + + +async def test_association_replaced_before_handshake_starts(): + fn = FakeNet() + fn.enable() + + # This test shouldn't send any packets + def route_packet(packet): # pragma: no cover + assert False + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: 
+ c1 = client_endpoint.connect(address, client_ctx) + c2 = client_endpoint.connect(address, client_ctx) + with pytest.raises(trio.BrokenResourceError): + await c1.do_handshake() + + +async def test_send_to_closed_local_port(): + # On Windows, sending a UDP packet to a closed local port can cause a weird + # ECONNRESET error later, inside the receive task. Make sure we're handling it + # properly. + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + async with trio.open_nursery() as nursery: + for i in range(1, 10): + channel = client_endpoint.connect(("127.0.0.1", i), client_ctx) + nursery.start_soon(channel.do_handshake) + channel = client_endpoint.connect(address, client_ctx) + await channel.send(b"xxx") + assert await channel.receive() == b"xxx" + nursery.cancel_scope.cancel() + +# can we work around the openssl bug with invalid cookies by rebooting the connection +# when we see a second HelloVerifyRequest? (and then enable packet corruption in the +# torture test?) From 00e5cafa3999e97c73f9253e0cbcc0322d07dc57 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Tue, 7 Sep 2021 20:57:52 -0700 Subject: [PATCH 29/47] comment --- trio/_dtls.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/trio/_dtls.py b/trio/_dtls.py index 3c990e1af2..e531ced810 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -1,4 +1,10 @@ +# Implementation of DTLS 1.2, using pyopenssl # https://datatracker.ietf.org/doc/html/rfc6347 +# +# OpenSSL's APIs for DTLS are extremely awkward and limited, which forces us to jump +# through a *lot* of hoops and implement important chunks of the protocol ourselves. +# Hopefully they fix this before implementing DTLS 1.3, because it's a very different +# protocol, and it's probably impossible to pull tricks like we do here. import struct import hmac From 0a2b055b2862f82d2416fc88d9d5b48c3f386baa Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Tue, 7 Sep 2021 23:39:45 -0700 Subject: [PATCH 30/47] A few more cleanups + add docs --- docs/source/conf.py | 1 + docs/source/reference-io.rst | 46 ++++++++ trio/_dtls.py | 223 ++++++++++++++++++++++++++++++++--- trio/tests/test_dtls.py | 49 +++++--- 4 files changed, 289 insertions(+), 30 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6045ffd828..52872d0cb4 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -87,6 +87,7 @@ def setup(app): intersphinx_mapping = { "python": ('https://docs.python.org/3', None), "outcome": ('https://outcome.readthedocs.io/en/latest/', None), + "pyopenssl": ('https://www.pyopenssl.org/en/stable/', None), } autodoc_member_order = "bysource" diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index b18983e272..a3291ef2ae 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -258,6 +258,52 @@ you call them before the handshake completes: .. autoexception:: NeedHandshakeError +Datagram TLS support +~~~~~~~~~~~~~~~~~~~~ + +Trio also has support for Datagram TLS (DTLS), which is like TLS but +for unreliable UDP connections. This can be useful for applications +where TCP's reliable in-order delivery is problematic, like +teleconferencing, latency-sensitive games, and VPNs. + +Currently, using DTLS with Trio requires PyOpenSSL. We hope to +eventually allow the use of the stdlib `ssl` module as well, but +unfortunately that's not yet possible. + +.. warning:: Note that PyOpenSSL is in many ways lower-level than the + `ssl` module – in particular, it currently **HAS NO BUILT-IN + MECHANISM TO VALIDATE CERTIFICATES**. We *strongly* recommend that + you use the `service-identity + `__ library to validate + hostnames and certificates. + +.. autoclass:: DTLSEndpoint + + .. automethod:: connect + + .. automethod:: serve + + .. automethod:: close + +.. autoclass:: DTLSChannel + :show-inheritance: + + .. automethod:: do_handshake + + .. 
automethod:: send + + .. automethod:: receive + + .. automethod:: close + + .. automethod:: aclose + + .. automethod:: set_ciphertext_mtu + + .. automethod:: get_cleartext_mtu + + .. automethod:: statistics + .. module:: trio.socket Low-level networking with :mod:`trio.socket` diff --git a/trio/_dtls.py b/trio/_dtls.py index e531ced810..8289821ad0 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -673,11 +673,7 @@ async def handle_client_hello_untrusted(endpoint, address, packet): endpoint._incoming_connections_q.s.send_nowait(stream) -async def dtls_receive_loop(endpoint_ref): - try: - sock = endpoint_ref().socket - except AttributeError: - return +async def dtls_receive_loop(endpoint_ref, sock): try: while True: try: @@ -738,6 +734,20 @@ class DTLSChannelStatistics: class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): + """A DTLS connection. + + This class has no public constructor – you get instances by calling + `DTLSEndpoint.serve` or `~DTLSEndpoint.connect`. + + .. attribute:: endpoint + + The `DTLSEndpoint` that this connection is using. + + .. attribute:: peer_address + + The IP/port of the remote peer that this connection is associated with. 
+ + """ def __init__(self, endpoint, peer_address, ctx): self.endpoint = endpoint self.peer_address = peer_address @@ -761,9 +771,6 @@ def __init__(self, endpoint, peer_address, ctx): self._handshake_lock = trio.Lock() self._record_encoder = RecordEncoder() - def statistics(self) -> DTLSChannelStatistics: - return DTLSChannelStatistics(self._packets_dropped_in_trio) - def _set_replaced(self): self._replaced = True # Any packets we already received could maybe possibly still be processed, but @@ -776,13 +783,6 @@ def _check_replaced(self): "peer tore down this connection to start a new one" ) - def set_ciphertext_mtu(self, new_mtu): - self._handshake_mtu = new_mtu - self._ssl.set_ciphertext_mtu(new_mtu) - - def get_cleartext_mtu(self): - return self._ssl.get_cleartext_mtu() - # XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU # estimate @@ -791,6 +791,17 @@ def get_cleartext_mtu(self): # to handle receiving it properly though, which might be easier if we send it... def close(self): + """Close this connection. + + `DTLSChannel`\s don't actually own any OS-level resources – the + socket is owned by the `DTLSEndpoint`, not the individual connections. So + you don't really *have* to call this. But it will interrupt any other tasks + calling `receive` with a `ClosedResourceError`, and cause future attempts to use + this connection to fail. + + You can also use this object as a synchronous or asynchronous context manager. + + """ if self._closed: return self._closed = True @@ -807,6 +818,12 @@ def __exit__(self, *args): self.close() async def aclose(self): + """Close this connection, but asynchronously. + + This is included to satisfy the `trio.abc.Channel` contract. It's + identical to `close`, but async. 
+ + """ self.close() await trio.lowlevel.checkpoint() @@ -822,6 +839,33 @@ async def _resend_final_volley(self): await self._send_volley(self._final_volley) async def do_handshake(self, *, initial_retransmit_timeout=1.0): + """Perform the handshake. + + Calling this is optional – if you don't, then it will be automatically called + the first time you call `send` or `receive`. But calling it explicitly can be + useful in case you want to control the retransmit timeout, use a cancel scope to + place an overall timeout on the handshake, or catch errors from the handshake + specifically. + + It's safe to call this multiple times, or call it simultaneously from multiple + tasks – the first call will perform the handshake, and the rest will be no-ops. + + Args: + + initial_retransmit_timeout (float): Since UDP is an unreliable protocol, it's + possible that some of the packets we send during the handshake will get + lost. To handle this, DTLS uses a timer to automatically retransmit + handshake packets that don't receive a response. This lets you set the + timeout we use to detect packet loss. Ideally, it should be set to ~1.5 + times the round-trip time to your peer, but 1 second is a reasonable + default. There's `some useful guidance here + `__. + + This is the *initial* timeout, because if packets keep being lost then Trio + will automatically back off to longer values, to avoid overloading the + network. + + """ async with self._handshake_lock: if self._did_handshake: return @@ -924,6 +968,10 @@ def read_volley(): ) async def send(self, data): + """Send a packet of data, securely. + + """ + if self._closed: raise trio.ClosedResourceError if not data: @@ -938,6 +986,15 @@ async def send(self, data): ) async def receive(self): + """Fetch the next packet of data from this connection's peer, waiting if + necessary. + + This is safe to call from multiple tasks simultaneously, in case you have some + reason to do that. 
And more importantly, it's cancellation-safe, meaning that + cancelling a call to `receive` will never cause a packet to be lost or corrupt + the underlying connection. + + """ if not self._did_handshake: await self.do_handshake() # If the packet isn't really valid, then openssl can decode it to the empty @@ -954,8 +1011,92 @@ async def receive(self): if cleartext: return cleartext + def set_ciphertext_mtu(self, new_mtu): + """Tells Trio the `largest amount of data that can be sent in a single packet to + this peer `__. + + Trio doesn't actually enforce this limit – if you pass a huge packet to `send`, + then we'll dutifully encrypt it and attempt to send it. But calling this method + does have two useful effects: + + - If called before the handshake is performed, then Trio will automatically + fragment handshake messages to fit within the given MTU. It also might + fragment them even smaller, if it detects signs of packet loss, so setting + this should never be necessary to make a successful connection. But, the + packet loss detection only happens after multiple timeouts have expired, so if + you have reason to believe that a smaller MTU is required, then you can set + this to skip those timeouts and establish the connection more quickly. + + - It changes the value returned from `get_cleartext_mtu`. So if you have some + kind of estimate of the network-level MTU, then you can use this to figure out + how much overhead DTLS will need for hashes/padding/etc., and how much space + you have left for your application data. + + The MTU here is measuring the largest UDP *payload* you think can be sent, the + amount of encrypted data that can be handed to the operating system in a single + call to `send`. It should *not* include IP/UDP headers. Note that OS estimates + of the MTU often are link-layer MTUs, so you have to subtract off 28 bytes on + IPv4 and 48 bytes on IPv6 to get the ciphertext MTU. 
+ + By default, Trio assumes an MTU of 1472 bytes on IPv4, and 1452 bytes on IPv6, + which correspond to the common Ethernet MTU of 1500 bytes after accounting for + IP/UDP overhead. + + """ + self._handshake_mtu = new_mtu + self._ssl.set_ciphertext_mtu(new_mtu) + + def get_cleartext_mtu(self): + """Returns the largest number of bytes that you can pass in a single call to + `send` while still fitting within the network-level MTU. + + See `set_ciphertext_mtu` for more details. + + """ + if not self._did_handshake: + raise trio.NeedHandshakeError + return self._ssl.get_cleartext_mtu() + + def statistics(self): + """Returns an object with statistics about this connection. + + Currently this has only one attribute: + + - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of + incoming packets from this peer that Trio successfully received from the + network, but then got dropped because the internal channel buffer was full. If + this is non-zero, then you might want to call ``receive`` more often, or use a + larger ``incoming_packets_buffer``, or just not worry about it because your + UDP-based protocol should be able to handle the occasional lost packet, right? + + """ + return DTLSChannelStatistics(self._packets_dropped_in_trio) + class DTLSEndpoint(metaclass=Final): + """A DTLS endpoint. + + A single UDP socket can handle arbitrarily many DTLS connections simultaneously, + acting as a client or server as needed. A `DTLSEndpoint` object holds a UDP socket + and manages these connections, which are represented as `DTLSChannel` objects. + + Args: + socket: (trio.socket.SocketType): A ``SOCK_DGRAM`` socket. If you want to accept + incoming connections in server mode, then you should probably bind the socket to + some known port. + incoming_packets_buffer (int): Each `DTLSChannel` using this socket has its own + buffer that holds incoming packets until you call `~DTLSChannel.receive` to read + them. This lets you adjust the size of this buffer. 
`~DTLSChannel.statistics` + lets you check if the buffer has overflowed. + + .. attribute:: socket + incoming_packets_buffer + + Both constructor arguments are also exposed as attributes, in case you need to + access them later. + + """ + def __init__(self, socket, *, incoming_packets_buffer=10): # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. @@ -981,7 +1122,7 @@ def __init__(self, socket, *, incoming_packets_buffer=10): self._send_lock = trio.Lock() self._closed = False - trio.lowlevel.spawn_system_task(dtls_receive_loop, weakref.ref(self)) + trio.lowlevel.spawn_system_task(dtls_receive_loop, weakref.ref(self), self.socket) def __del__(self): # Do nothing if this object was never fully constructed @@ -1000,6 +1141,11 @@ def __del__(self): ) def close(self): + """Close this socket, and all associated DTLS connections. + + This object can also be used as a context manager. + + """ self._closed = True self.socket.close() for stream in list(self._streams.values()): @@ -1019,6 +1165,34 @@ def _check_closed(self): async def serve( self, ssl_context, async_fn, *args, task_status=trio.TASK_STATUS_IGNORED ): + """Listen for incoming connections, and spawn a handler for each using an + internal nursery. + + Similar to `~trio.serve_tcp`, this function never returns until cancelled, or + the `DTLSEndpoint` is closed and all handlers have exited. + + Usage commonly looks like:: + + async def handler(dtls_channel): + ... + + async with trio.open_nursery() as nursery: + await nursery.start(dtls_endpoint.serve, ssl_context, handler) + # ... do other things here ... + + The ``dtls_channel`` passed into the handler function has already performed the + "cookie exchange" part of the DTLS handshake, so the peer address is + trustworthy. 
But the actual cryptographic handshake doesn't happen until you + start using it, giving you a chance for any last minute configuration, and the + option to catch and handle handshake errors. + + Args: + ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for + incoming connections. + async_fn: The handler function that will be invoked for each incoming + connection. + + """ self._check_closed() if self._listening_context is not None: raise trio.BusyResourceError("another task is already listening") @@ -1040,6 +1214,23 @@ async def handler_wrapper(stream): self._listening_context = None def connect(self, address, ssl_context): + """Initiate an outgoing DTLS connection. + + Notice that this is a synchronous method. That's because it doesn't actually + initiate any I/O – it just sets up a `DTLSChannel` object. The actual handshake + doesn't occur until you start using the `DTLSChannel`. This gives you a chance + to do further configuration first, like setting MTU etc. + + Args: + address: The address to connect to. Usually a (host, port) tuple, like + ``("127.0.0.1", 12345)``. + ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for + this connection. 
+ + Returns: + DTLSChannel + + """ # it would be nice if we could detect when 'address' is our own endpoint (a # loopback connection), because that can't work # but I don't see how to do it reliably diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 2d09e469c8..c889d353aa 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -57,10 +57,10 @@ async def echo_handler(dtls_channel): ) if mtu is not None: dtls_channel.set_ciphertext_mtu(mtu) - print("server starting do_handshake") - await dtls_channel.do_handshake() - print("server finished do_handshake") try: + print("server starting do_handshake") + await dtls_channel.do_handshake() + print("server finished do_handshake") async for packet in dtls_channel: print(f"echoing {packet} -> {dtls_channel.peer_address}") await dtls_channel.send(packet) @@ -80,6 +80,9 @@ async def test_smoke(ipv6): async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address): with endpoint(ipv6=ipv6) as client_endpoint: client_channel = client_endpoint.connect(address, client_ctx) + with pytest.raises(trio.NeedHandshakeError): + client_channel.get_cleartext_mtu() + await client_channel.do_handshake() await client_channel.send(b"hello") assert await client_channel.receive() == b"hello" @@ -111,7 +114,7 @@ async def route_packet(packet): while True: op = r.choices( ["deliver", "drop", "dupe", "delay"], - weights=[0.6, 0.1, 0.1, 0.1], + weights=[0.7, 0.1, 0.1, 0.1], )[0] print(f"{packet.source} -> {packet.destination}: {op}") if op == "drop": @@ -122,6 +125,7 @@ async def route_packet(packet): await trio.sleep(r.random() * 3) # I wanted to test random packet corruption too, but it turns out # openssl has a bug in the following scenario: + # # - client sends ClientHello # - server sends HelloVerifyRequest with cookie -- but cookie is # invalid b/c either the ClientHello or HelloVerifyRequest was @@ -131,7 +135,12 @@ async def route_packet(packet): # # At this point, the client *should* switch to the new, 
valid # cookie. But OpenSSL doesn't; it stubbornly insists on re-sending - # the original, invalid cookie over and over. + # the original, invalid cookie over and over. In theory we could + # work around this by detecting cookie changes and starting over + # with a whole new SSL object, but (a) it doesn't seem worth it, (b) + # when I tried then I ran into another issue where OpenSSL got stuck + # in an infinite loop sending alerts over and over, which I didn't + # dig into because see (a). # # elif op == "distort": # payload = bytearray(packet.payload) @@ -715,6 +724,26 @@ async def test_gc_before_system_task_starts(): await trio.testing.wait_all_tasks_blocked() +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_gc_as_packet_received(): + fn = FakeNet() + fn.enable() + + e = endpoint() + await e.socket.bind(("127.0.0.1", 0)) + + await trio.testing.wait_all_tasks_blocked() + + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: + await s.sendto(b"xxx", e.socket.getsockname()) + # At this point, the endpoint's receive loop has been marked runnable because it + # just received a packet; closing the endpoint socket won't interrupt that. But by + # the time it wakes up to process the packet, the endpoint will be gone. 
+ with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + @pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") def test_gc_after_trio_exits(): async def main(): @@ -759,12 +788,8 @@ async def test_association_replaced_while_handshake_running(autojump_clock): fn = FakeNet() fn.enable() - blackholed = True - def route_packet(packet): - if blackholed: - return - fn.deliver_packet(packet) + pass fn.route_packet = route_packet @@ -815,7 +840,3 @@ async def test_send_to_closed_local_port(): await channel.send(b"xxx") assert await channel.receive() == b"xxx" nursery.cancel_scope.cancel() - -# can we work around the openssl bug with invalid cookies by rebooting the connection -# when we see a second HelloVerifyRequest? (and then enable packet corruption in the -# torture test?) From 419c962c21e9aba95d038e0430c4af0748f154ea Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Thu, 30 Sep 2021 08:05:26 -0700 Subject: [PATCH 31/47] Quote literal backslash in string --- trio/_dtls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 8289821ad0..58b24abaa8 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -793,7 +793,7 @@ def _check_replaced(self): def close(self): """Close this connection. - `DTLSChannel`\s don't actually own any OS-level resources – the + `DTLSChannel`\\s don't actually own any OS-level resources – the socket is owned by the `DTLSEndpoint`, not the individual connections. So you don't really *have* to call this. But it will interrupt any other tasks calling `receive` with a `ClosedResourceError`, and cause future attempts to use From 717e46f7a90c52873203a1ec133fe1a98506f885 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Thu, 30 Sep 2021 08:15:01 -0700 Subject: [PATCH 32/47] Defer starting the DTLS receive task until we actually need to receive This works around a problem on windows where recvfrom on an unbound socket errors out. 
--- trio/_dtls.py | 15 ++++++++++++++- trio/tests/test_dtls.py | 17 +++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 58b24abaa8..766a3268e1 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -909,6 +909,7 @@ def read_volley(): self._check_replaced() await self._send_volley(volley_messages) # -- then this is where we wait for a reply -- + self.endpoint._ensure_receive_loop() with trio.move_on_after(timeout) as cscope: async for packet in self._q.r: self._ssl.bio_write(packet) @@ -1121,8 +1122,15 @@ def __init__(self, socket, *, incoming_packets_buffer=10): self._incoming_connections_q = _Queue(float("inf")) self._send_lock = trio.Lock() self._closed = False + self._receive_loop_spawned = False - trio.lowlevel.spawn_system_task(dtls_receive_loop, weakref.ref(self), self.socket) + def _ensure_receive_loop(self): + # We have to spawn this lazily, because on Windows it will immediately error out + # if the socket isn't already bound -- which for clients might not happen until + # after we send our first packet. + if not self._receive_loop_spawned: + trio.lowlevel.spawn_system_task(dtls_receive_loop, weakref.ref(self), self.socket) + self._receive_loop_spawned = True def __del__(self): # Do nothing if this object was never fully constructed @@ -1196,6 +1204,11 @@ async def handler(dtls_channel): self._check_closed() if self._listening_context is not None: raise trio.BusyResourceError("another task is already listening") + try: + self.socket.getsockname() + except OSError: + raise RuntimeError("DTLS socket must be bound before it can serve") + self._ensure_receive_loop() # We do cookie verification ourselves, so tell OpenSSL not to worry about it. # (See also _inject_client_hello_untrusted.) 
ssl_context.set_cookie_verify_callback(lambda *_: True) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index c889d353aa..44f967ca06 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -284,6 +284,7 @@ async def null_handler(_): # pragma: no cover pass with endpoint() as server_endpoint: + await server_endpoint.socket.bind(("127.0.0.1", 0)) async with trio.open_nursery() as nursery: await nursery.start(server_endpoint.serve, server_ctx, null_handler) with pytest.raises(trio.BusyResourceError): @@ -696,9 +697,21 @@ def route_packet(packet): async def test_system_task_cleaned_up_on_gc(): before_tasks = trio.lowlevel.current_statistics().tasks_living + e = endpoint() - # Give system task a chance to start up - await trio.testing.wait_all_tasks_blocked() + + async def force_receive_loop_to_start(): + # This connection/handshake attempt can't succeed. The only purpose is to force + # the endpoint to set up a receive loop. + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: + await s.bind(("127.0.0.1", 0)) + c = e.connect(s.getsockname(), client_ctx) + async with trio.open_nursery() as nursery: + nursery.start_soon(c.do_handshake) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + await force_receive_loop_to_start() during_tasks = trio.lowlevel.current_statistics().tasks_living From 08686f67da8e6750fa9279c140087ad787d0d7fb Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Tue, 2 Nov 2021 21:20:33 -0700 Subject: [PATCH 33/47] Work around bug in Ubuntu 18.04's OpenSSL --- trio/_dtls.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/trio/_dtls.py b/trio/_dtls.py index 766a3268e1..57b2437b78 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -659,6 +659,24 @@ async def handle_client_hello_untrusted(endpoint, address, packet): # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello # after all. 
return + + # Some old versions of OpenSSL have a bug with memory BIOs, where DTLSv1_listen + # consumes the ClientHello out of the BIO, but then do_handshake expects the + # ClientHello to still be in there (but not the one that ships with Ubuntu + # 20.04). In particular, this is known to affect the OpenSSL v1.1.1 that ships + # with Ubuntu 18.04. To work around this, we deliver a second copy of the + # ClientHello after DTLSv1_listen has completed. This is safe to do + # unconditionally, because on newer versions of OpenSSL, the second ClientHello + # is treated as a duplicate packet, which is a normal thing that can happen over + # UDP. For more details, see: + # + # https://github.com/pyca/pyopenssl/blob/e84e7b57d1838de70ab7a27089fbee78ce0d2106/tests/test_ssl.py#L4226-L4293 + # + # This was fixed in v1.1.1a, and all later versions. So maybe in 2024 or so we + # can delete this. The fix landed in OpenSSL master as 079ef6bd534d2, and then + # was backported to the 1.1.1 branch as d1bfd8076e28. + stream._ssl.bio_write(packet) + # Check if we have an existing association old_stream = endpoint._streams.get(address) if old_stream is not None: From e5a4d0d37ced86869e09f74f8bc868940efb3f74 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Tue, 2 Nov 2021 21:21:13 -0700 Subject: [PATCH 34/47] Clean up test Noticed while testing on Ubuntu 18.04's weird old OpenSSL that this test was getting stuck in an infinite loop, even though it passed with more recent OpenSSL. I'm not sure why it was different, exactly, but on closer examination, writing the test this way makes more sense anyway, and now the tests pass on Ubuntu 18.04 too. 
--- trio/tests/test_dtls.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 44f967ca06..305c4c2a19 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -448,7 +448,9 @@ async def test_invalid_cookie_rejected(autojump_clock): with trio.CancelScope() as cscope: - offset_to_corrupt = count() + # the first 11 bytes of ClientHello aren't protected by the cookie, so only test + # corrupting bytes after that. + offset_to_corrupt = count(11) def route_packet(packet): try: From 2ffd8927c0d5ff1f3ef8b87a7739a9bf3520ad49 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Tue, 2 Nov 2021 21:25:43 -0700 Subject: [PATCH 35/47] Temporarily switch branch to pull dev version of openssl, to let CI run --- test-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-requirements.txt b/test-requirements.txt index aa32adab96..a2f050a033 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -97,7 +97,7 @@ pygments==2.10.0 # via ipython pylint==2.12.2 # via -r test-requirements.in -pyopenssl==21.0.0 +pyopenssl @ git+https://github.com/pyca/pyopenssl # via -r test-requirements.in pyparsing==3.0.7 # via packaging From f335dce06e4e8543526055626800af540d5288ed Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Tue, 2 Nov 2021 21:49:17 -0700 Subject: [PATCH 36/47] blacken --- trio/_dtls.py | 9 +++++---- trio/tests/test_dtls.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 57b2437b78..dc653f19ff 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -766,6 +766,7 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): The IP/port of the remote peer that this connection is associated with. 
""" + def __init__(self, endpoint, peer_address, ctx): self.endpoint = endpoint self.peer_address = peer_address @@ -987,9 +988,7 @@ def read_volley(): ) async def send(self, data): - """Send a packet of data, securely. - - """ + """Send a packet of data, securely.""" if self._closed: raise trio.ClosedResourceError @@ -1147,7 +1146,9 @@ def _ensure_receive_loop(self): # if the socket isn't already bound -- which for clients might not happen until # after we send our first packet. if not self._receive_loop_spawned: - trio.lowlevel.spawn_system_task(dtls_receive_loop, weakref.ref(self), self.socket) + trio.lowlevel.spawn_system_task( + dtls_receive_loop, weakref.ref(self), self.socket + ) self._receive_loop_spawned = True def __del__(self): diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 305c4c2a19..eb18d7a1a9 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -699,7 +699,6 @@ def route_packet(packet): async def test_system_task_cleaned_up_on_gc(): before_tasks = trio.lowlevel.current_statistics().tasks_living - e = endpoint() async def force_receive_loop_to_start(): @@ -812,6 +811,7 @@ def route_packet(packet): with endpoint() as client_endpoint: c1 = client_endpoint.connect(address, client_ctx) async with trio.open_nursery() as nursery: + async def doomed_handshake(): with pytest.raises(trio.BrokenResourceError): await c1.do_handshake() From acc0eacdd84c57f8f0b0a14cb50aec27153bbe70 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Wed, 3 Nov 2021 21:28:22 -0700 Subject: [PATCH 37/47] Restore py36 compatibility --- trio/tests/test_dtls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index eb18d7a1a9..0cadad8148 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -3,7 +3,7 @@ from trio import DTLSEndpoint import random import attr -from contextlib import asynccontextmanager +from async_generator import asynccontextmanager from itertools import count import ipaddress import warnings From e3fb2d823f84210a3ab5adb5ede577ad2c1f2a15 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 3 Nov 2021 21:29:01 -0700 Subject: [PATCH 38/47] Maybe this will work? --- test-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-requirements.txt b/test-requirements.txt index a2f050a033..1747ebe512 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -97,7 +97,7 @@ pygments==2.10.0 # via ipython pylint==2.12.2 # via -r test-requirements.in -pyopenssl @ git+https://github.com/pyca/pyopenssl +pyopenssl @ https://github.com/pyca/pyopenssl/archive/refs/heads/main.zip # via -r test-requirements.in pyparsing==3.0.7 # via packaging From 78634986cabb747af99a3579a715cc648b9a7c01 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Wed, 3 Nov 2021 21:38:26 -0700 Subject: [PATCH 39/47] Pacify flake8 --- trio/_dtls.py | 2 +- trio/tests/test_dtls.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index dc653f19ff..5bd0cc7872 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -936,7 +936,7 @@ def read_volley(): self._ssl.do_handshake() # We ignore generic SSL.Error here, because you can get those # from random invalid packets - except (SSL.WantReadError, SSL.Error) as exc: + except (SSL.WantReadError, SSL.Error): pass else: # No exception -> the handshake is done, and we can diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 0cadad8148..68107fe86d 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -820,7 +820,7 @@ async def doomed_handshake(): await trio.sleep(10) - c2 = client_endpoint.connect(address, client_ctx) + client_endpoint.connect(address, client_ctx) async def test_association_replaced_before_handshake_starts(): @@ -836,7 +836,7 @@ def route_packet(packet): # pragma: no cover async with dtls_echo_server() as (_, address): with endpoint() as client_endpoint: c1 = client_endpoint.connect(address, client_ctx) - c2 = client_endpoint.connect(address, client_ctx) + client_endpoint.connect(address, client_ctx) with pytest.raises(trio.BrokenResourceError): await c1.do_handshake() From 6e3aca08fb0e290b19080502ac2aa8b36e669f64 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Wed, 3 Nov 2021 21:47:32 -0700 Subject: [PATCH 40/47] make mypy happy --- test-requirements.in | 1 + test-requirements.txt | 13 ++++--------- trio/_dtls.py | 17 +++++++++-------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/test-requirements.in b/test-requirements.in index 4d15de8e4f..2c75770524 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -12,6 +12,7 @@ cryptography>=36.0.0 # 35.0.0 is transitive but fails # Tools black; implementation_name == "cpython" mypy; implementation_name == "cpython" +types-pyOpenSSL; implementation_name == "cpython" flake8 astor # code generation diff --git a/test-requirements.txt b/test-requirements.txt index 1747ebe512..e932052c31 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with python 3.7 +# This file is autogenerated by pip-compile # To update, run: # -# pip-compile --output-file=test-requirements.txt test-requirements.in +# pip-compile --output-file test-requirements.txt test-requirements.in # astor==0.8.1 # via -r test-requirements.in @@ -25,9 +25,8 @@ click==8.0.3 # via black coverage[toml]==6.0.2 # via pytest-cov -cryptography==36.0.1 +cryptography==35.0.0 # via - # -r test-requirements.in # pyopenssl # trustme decorator==5.1.1 @@ -127,8 +126,6 @@ traitlets==5.1.1 # ipython # matplotlib-inline trustme==0.9.0 - # via -r test-requirements.in -typing-extensions==4.0.1 ; implementation_name == "cpython" # via # -r test-requirements.in # black @@ -137,6 +134,4 @@ wcwidth==0.2.5 # via prompt-toolkit wrapt==1.13.3 # via astroid - -# The following packages are considered to be unsafe in a requirements file: -# setuptools +types-pyOpenSSL diff --git a/trio/_dtls.py b/trio/_dtls.py index 5bd0cc7872..fc3d111931 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -130,15 +130,16 @@ def is_client_hello_untrusted(packet): RECORD_HEADER = struct.Struct("!B2sQH") -hex_repr = attr.ib(repr=lambda data: data.hex()) # 
pragma: no cover +def to_hex(data: bytes) -> str: # pragma: no cover + return data.hex() @attr.frozen class Record: content_type: int - version: bytes = hex_repr + version: bytes = attr.ib(repr=to_hex) epoch_seqno: int - payload: bytes = hex_repr + payload: bytes = attr.ib(repr=to_hex) def records_untrusted(packet): @@ -185,7 +186,7 @@ class HandshakeFragment: msg_seq: int frag_offset: int frag_len: int - frag: bytes = hex_repr + frag: bytes = attr.ib(repr=to_hex) def decode_handshake_fragment_untrusted(payload): @@ -300,19 +301,19 @@ def decode_client_hello_untrusted(packet): @attr.frozen class HandshakeMessage: - record_version: bytes = hex_repr + record_version: bytes = attr.ib(repr=to_hex) msg_type: HandshakeType msg_seq: int - body: bytearray = hex_repr + body: bytearray = attr.ib(repr=to_hex) # ChangeCipherSpec is part of the handshake, but it's not a "handshake # message" and can't be fragmented the same way. Sigh. @attr.frozen class PseudoHandshakeMessage: - record_version: bytes = hex_repr + record_version: bytes = attr.ib(repr=to_hex) content_type: int - payload: bytes = hex_repr + payload: bytes = attr.ib(repr=to_hex) # The final record in a handshake is Finished, which is encrypted, can't be fragmented From 1559b56407050a268b1df5b5dddab3700fe73a3c Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Wed, 3 Nov 2021 22:10:09 -0700 Subject: [PATCH 41/47] more py36 --- trio/_dtls.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index fc3d111931..d5c5e66d94 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -42,6 +42,14 @@ def best_guess_mtu(sock): return 1500 - packet_header_overhead(sock) +try: + from hmac import digest +except ImportError: + # python 3.6 + def digest(key, msg, algorithm): + return hmac.new(key, msg, algorithm).digest() + + # There are a bunch of different RFCs that define these codes, so for a # comprehensive collection look here: # https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml @@ -540,7 +548,7 @@ def _make_cookie(key, salt, tick, address, client_hello_bits): client_hello_bits, ) - return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] + return (salt + digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] def valid_cookie(key, cookie, address, client_hello_bits): From 86ab14d45b653e6ad0397218a8204716216015d5 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Wed, 3 Nov 2021 22:26:26 -0700 Subject: [PATCH 42/47] shut up mypy --- trio/_dtls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index d5c5e66d94..09dd4bf4b2 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -46,7 +46,7 @@ def best_guess_mtu(sock): from hmac import digest except ImportError: # python 3.6 - def digest(key, msg, algorithm): + def digest(key, msg, algorithm): # type: ignore return hmac.new(key, msg, algorithm).digest() From 1714e73c5b711f3ec6e3b3b9e9f2da85c826a319 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Fri, 28 Jan 2022 17:25:33 -0800 Subject: [PATCH 43/47] remove unneeded import --- trio/_dtls.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 09dd4bf4b2..66e19a0367 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -9,7 +9,6 @@ import struct import hmac import os -import io import enum from itertools import count import weakref From 1fc7847a8992d396e29a132e4707e01ac3255218 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Sat, 29 Jan 2022 12:29:39 -0800 Subject: [PATCH 44/47] pyopenssl has released! --- test-requirements.in | 4 ++-- test-requirements.txt | 27 ++++++++++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/test-requirements.in b/test-requirements.in index 2c75770524..6d12a658e2 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -3,8 +3,8 @@ pytest >= 5.0 # for faulthandler in core pytest-cov >= 2.6.0 # ipython 7.x is the last major version supporting Python 3.7 ipython ~= 7.31 # for the IPython traceback integration tests -pyOpenSSL # for the ssl tests -trustme # for the ssl tests +pyOpenSSL >= 22.0.0 # for the ssl + DTLS tests +trustme # for the ssl + DTLS tests pylint # for pylint finding all symbols tests jedi # for jedi code completion tests cryptography>=36.0.0 # 35.0.0 is transitive but fails diff --git a/test-requirements.txt b/test-requirements.txt index e932052c31..54d4add08b 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile +# This file is autogenerated by pip-compile with python 3.9 # To update, run: # -# pip-compile --output-file test-requirements.txt test-requirements.in +# pip-compile test-requirements.in # astor==0.8.1 # via -r test-requirements.in @@ -25,8 +25,9 @@ click==8.0.3 # via black coverage[toml]==6.0.2 # via pytest-cov -cryptography==35.0.0 +cryptography==36.0.1 # via + # -r test-requirements.in # pyopenssl # trustme decorator==5.1.1 @@ -96,7 +97,7 @@ 
pygments==2.10.0 # via ipython pylint==2.12.2 # via -r test-requirements.in -pyopenssl @ https://github.com/pyca/pyopenssl/archive/refs/heads/main.zip +pyopenssl==22.0.0 # via -r test-requirements.in pyparsing==3.0.7 # via packaging @@ -106,8 +107,6 @@ pytest==6.2.5 # pytest-cov pytest-cov==3.0.0 # via -r test-requirements.in -six==1.16.0 - # via pyopenssl sniffio==1.2.0 # via -r test-requirements.in sortedcontainers==2.4.0 @@ -126,12 +125,26 @@ traitlets==5.1.1 # ipython # matplotlib-inline trustme==0.9.0 + # via -r test-requirements.in +types-cryptography==3.3.14 + # via types-pyopenssl +types-enum34==1.1.8 + # via types-cryptography +types-ipaddress==1.0.7 + # via types-cryptography +types-pyopenssl==21.0.3 ; implementation_name == "cpython" + # via -r test-requirements.in +typing-extensions==4.0.1 ; implementation_name == "cpython" # via # -r test-requirements.in + # astroid # black # mypy + # pylint wcwidth==0.2.5 # via prompt-toolkit wrapt==1.13.3 # via astroid -types-pyOpenSSL + +# The following packages are considered to be unsafe in a requirements file: +# setuptools From 10f450645d94e27a39513d6da2171e2da7e6a6e7 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Sat, 29 Jan 2022 12:30:13 -0800 Subject: [PATCH 45/47] Work around pypy gc quirks --- trio/tests/test_dtls.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/trio/tests/test_dtls.py b/trio/tests/test_dtls.py index 68107fe86d..8968d9a601 100644 --- a/trio/tests/test_dtls.py +++ b/trio/tests/test_dtls.py @@ -1,12 +1,11 @@ import pytest import trio +import trio.testing from trio import DTLSEndpoint import random import attr from async_generator import asynccontextmanager from itertools import count -import ipaddress -import warnings import trustme from OpenSSL import SSL @@ -699,9 +698,13 @@ def route_packet(packet): async def test_system_task_cleaned_up_on_gc(): before_tasks = trio.lowlevel.current_statistics().tasks_living - e = endpoint() + # We put this into a sub-function so that everything automatically becomes garbage + # when the frame exits. For some reason just doing 'del e' wasn't enough on pypy + # with coverage enabled -- I think we were hitting this bug: + # https://foss.heptapod.net/pypy/pypy/-/issues/3656 + async def start_and_forget_endpoint(): + e = endpoint() - async def force_receive_loop_to_start(): # This connection/handshake attempt can't succeed. The only purpose is to force # the endpoint to set up a receive loop. 
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: @@ -712,12 +715,12 @@ async def force_receive_loop_to_start(): await trio.testing.wait_all_tasks_blocked() nursery.cancel_scope.cancel() - await force_receive_loop_to_start() - - during_tasks = trio.lowlevel.current_statistics().tasks_living + during_tasks = trio.lowlevel.current_statistics().tasks_living + return during_tasks with pytest.warns(ResourceWarning): - del e + during_tasks = await start_and_forget_endpoint() + await trio.testing.wait_all_tasks_blocked() gc_collect_harder() await trio.testing.wait_all_tasks_blocked() @@ -745,6 +748,7 @@ async def test_gc_as_packet_received(): e = endpoint() await e.socket.bind(("127.0.0.1", 0)) + e._ensure_receive_loop() await trio.testing.wait_all_tasks_blocked() @@ -761,6 +765,12 @@ async def test_gc_as_packet_received(): @pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") def test_gc_after_trio_exits(): async def main(): + # We use fakenet just to make sure no real sockets can leak out of the test + # case - on pypy somehow the socket was outliving the gc_collect_harder call + # below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode + # when called after trio exits, it doesn't need a real socket. + fn = FakeNet() + fn.enable() return endpoint() e = trio.run(main) From 284cf2d9be44f5647c5df3a1662493fbacd6df84 Mon Sep 17 00:00:00 2001 From: "Nathaniel J. 
Smith" Date: Sat, 29 Jan 2022 12:45:24 -0800 Subject: [PATCH 46/47] Drop py36 support --- trio/_dtls.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 66e19a0367..4cd4392fbc 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -41,14 +41,6 @@ def best_guess_mtu(sock): return 1500 - packet_header_overhead(sock) -try: - from hmac import digest -except ImportError: - # python 3.6 - def digest(key, msg, algorithm): # type: ignore - return hmac.new(key, msg, algorithm).digest() - - # There are a bunch of different RFCs that define these codes, so for a # comprehensive collection look here: # https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml @@ -547,7 +539,7 @@ def _make_cookie(key, salt, tick, address, client_hello_bits): client_hello_bits, ) - return (salt + digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] + return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] def valid_cookie(key, cookie, address, client_hello_bits): From 93851cfe5378e44244c345a6ab41a8962a8e693a Mon Sep 17 00:00:00 2001 From: Peter Gessler Date: Wed, 27 Jul 2022 09:01:18 -0500 Subject: [PATCH 47/47] Update _dtls.py --- trio/_dtls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trio/_dtls.py b/trio/_dtls.py index 4cd4392fbc..910637455a 100644 --- a/trio/_dtls.py +++ b/trio/_dtls.py @@ -611,7 +611,7 @@ def _read_loop(read_fn): chunks = [] while True: try: - chunk = read_fn(2 ** 14) # max TLS record size + chunk = read_fn(2**14) # max TLS record size except SSL.WantReadError: break chunks.append(chunk)