From a7bf339953bd3c7aada50744ac672a43f2058fe0 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Mon, 3 Nov 2025 21:41:37 +0530 Subject: [PATCH 1/2] properly propagate contextvars for server protocols --- Lib/asyncio/base_events.py | 9 ++++-- Lib/asyncio/selector_events.py | 50 +++++++++++++++++----------------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 8cbb71f708537f..a4afd6e7a5e1f6 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -14,6 +14,7 @@ """ import collections +import contextvars import collections.abc import concurrent.futures import errno @@ -289,6 +290,7 @@ def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, self._ssl_shutdown_timeout = ssl_shutdown_timeout self._serving = False self._serving_forever_fut = None + self._context = contextvars.copy_context() def __repr__(self): return f'<{self.__class__.__name__} sockets={self.sockets!r}>' @@ -318,7 +320,7 @@ def _start_serving(self): self._loop._start_serving( self._protocol_factory, sock, self._ssl_context, self, self._backlog, self._ssl_handshake_timeout, - self._ssl_shutdown_timeout) + self._ssl_shutdown_timeout, context=self._context) def get_loop(self): return self._loop @@ -1211,9 +1213,10 @@ async def _create_connection_transport( self, sock, protocol_factory, ssl, server_hostname, server_side=False, ssl_handshake_timeout=None, - ssl_shutdown_timeout=None): + ssl_shutdown_timeout=None, context=None): sock.setblocking(False) + context = context if context is not None else contextvars.copy_context() protocol = protocol_factory() waiter = self.create_future() @@ -1225,7 +1228,7 @@ async def _create_connection_transport( ssl_handshake_timeout=ssl_handshake_timeout, ssl_shutdown_timeout=ssl_shutdown_timeout) else: - transport = self._make_socket_transport(sock, protocol, waiter) + transport = self._make_socket_transport(sock, protocol, waiter, context=context) try: await waiter diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index ff7e16df3c6273..85bbf4c7f360c1 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -67,10 +67,10 @@ def __init__(self, selector=None): self._transports = weakref.WeakValueDictionary() def _make_socket_transport(self, sock, protocol, waiter=None, *, - extra=None, server=None): + extra=None, server=None, context=None): self._ensure_fd_no_transport(sock) return _SelectorSocketTransport(self, sock, protocol, waiter, - extra, server) + extra, server, context=context) def _make_ssl_transport( self, rawsock, protocol, sslcontext, waiter=None, @@ -159,16 +159,16 @@ def _write_to_self(self): def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None): self._add_reader(sock.fileno(), self._accept_connection, protocol_factory, sock, sslcontext, server, backlog, - ssl_handshake_timeout, ssl_shutdown_timeout) + ssl_handshake_timeout, ssl_shutdown_timeout, context) def _accept_connection( self, protocol_factory, sock, sslcontext=None, server=None, backlog=100, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None): # This method is only called once for each event loop tick where the # listening socket has triggered an EVENT_READ. There may be multiple # connections waiting for an .accept() so it is called in a loop. @@ -204,21 +204,21 @@ def _accept_connection( self._start_serving, protocol_factory, sock, sslcontext, server, backlog, ssl_handshake_timeout, - ssl_shutdown_timeout) + ssl_shutdown_timeout, context) else: raise # The event loop will catch, log and ignore it. else: extra = {'peername': addr} accept = self._accept_connection2( protocol_factory, conn, extra, sslcontext, server, - ssl_handshake_timeout, ssl_shutdown_timeout) - self.create_task(accept) + ssl_handshake_timeout, ssl_shutdown_timeout, context=context) + self.create_task(accept, context=context) async def _accept_connection2( self, protocol_factory, conn, extra, sslcontext=None, server=None, ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, - ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT): + ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, context=None): protocol = None transport = None try: @@ -233,7 +233,7 @@ async def _accept_connection2( else: transport = self._make_socket_transport( conn, protocol, waiter=waiter, extra=extra, - server=server) + server=server, context=context) try: await waiter @@ -275,9 +275,9 @@ def _ensure_fd_no_transport(self, fd): f'File descriptor {fd!r} is used by transport ' f'{transport!r}') - def _add_reader(self, fd, callback, *args): + def _add_reader(self, fd, callback, *args, context=None): self._check_closed() - handle = events.Handle(callback, args, self, None) + handle = events.Handle(callback, args, self, context=context) key = self._selector.get_map().get(fd) if key is None: self._selector.register(fd, selectors.EVENT_READ, @@ -770,7 +770,7 @@ class _SelectorTransport(transports._FlowControlMixin, # exception) _sock = None - def __init__(self, loop, sock, protocol, extra=None, server=None): + def __init__(self, loop, sock, protocol, extra=None, server=None, context=None): super().__init__(extra, loop) self._extra['socket'] = trsock.TransportSocket(sock) try: @@ -784,7 +784,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None): self._extra['peername'] = None self._sock = sock self._sock_fd = sock.fileno() - + self._context = context self._protocol_connected = False self.set_protocol(protocol) @@ -866,7 +866,7 @@ def close(self): if not self._buffer: self._conn_lost += 1 self._loop._remove_writer(self._sock_fd) - self._loop.call_soon(self._call_connection_lost, None) + self._loop.call_soon(self._call_connection_lost, None, context=self._context) def __del__(self, _warn=warnings.warn): if self._sock is not None: @@ -899,7 +899,7 @@ def _force_close(self, exc): self._closing = True self._loop._remove_reader(self._sock_fd) self._conn_lost += 1 - self._loop.call_soon(self._call_connection_lost, exc) + self._loop.call_soon(self._call_connection_lost, exc, context=self._context) def _call_connection_lost(self, exc): try: @@ -921,7 +921,7 @@ def get_write_buffer_size(self): def _add_reader(self, fd, callback, *args): if not self.is_reading(): return - self._loop._add_reader(fd, callback, *args) + self._loop._add_reader(fd, callback, *args, context=self._context) class _SelectorSocketTransport(_SelectorTransport): @@ -930,10 +930,10 @@ class _SelectorSocketTransport(_SelectorTransport): _sendfile_compatible = constants._SendfileMode.TRY_NATIVE def __init__(self, loop, sock, protocol, waiter=None, - extra=None, server=None): - + extra=None, server=None, context=None): + assert context is not None self._read_ready_cb = None - super().__init__(loop, sock, protocol, extra, server) + super().__init__(loop, sock, protocol, extra, server, context) self._eof = False self._empty_waiter = None if _HAS_SENDMSG: @@ -945,14 +945,14 @@ def __init__(self, loop, sock, protocol, waiter=None, # decreases the latency (in some cases significantly.) base_events._set_nodelay(self._sock) - self._loop.call_soon(self._protocol.connection_made, self) + self._loop.call_soon(self._protocol.connection_made, self, context=context) # only start reading when connection_made() has been called self._loop.call_soon(self._add_reader, - self._sock_fd, self._read_ready) + self._sock_fd, self._read_ready, context=context) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(futures._set_result_unless_cancelled, - waiter, None) + waiter, None, context=context) def set_protocol(self, protocol): if isinstance(protocol, protocols.BufferedProtocol): @@ -1081,7 +1081,7 @@ def write(self, data): if not data: return # Not all was written; register write handler. - self._loop._add_writer(self._sock_fd, self._write_ready) + self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context) # Add it to the buffer. self._buffer.append(data) @@ -1185,7 +1185,7 @@ def writelines(self, list_of_data): self._write_ready() # If the entire buffer couldn't be written, register a write handler if self._buffer: - self._loop._add_writer(self._sock_fd, self._write_ready) + self._loop._add_writer(self._sock_fd, self._write_ready, context=self._context) self._maybe_pause_protocol() def can_write_eof(self): From 40a8f7c151d78e46bf48e48c5510677d18bafaa6 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Tue, 2 Dec 2025 11:50:42 +0530 Subject: [PATCH 2/2] fix _add_writer --- Lib/asyncio/selector_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 85bbf4c7f360c1..67b2f07c32e4ff 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -309,9 +309,9 @@ def _remove_reader(self, fd): else: return False - def _add_writer(self, fd, callback, *args): + def _add_writer(self, fd, callback, *args, context=None): self._check_closed() - handle = events.Handle(callback, args, self, None) + handle = events.Handle(callback, args, self, context=context) key = self._selector.get_map().get(fd) if key is None: self._selector.register(fd, selectors.EVENT_WRITE,