diff --git a/Doc/library/ssl.rst b/Doc/library/ssl.rst index 08824feeb3958f..9e80546c45558b 100644 --- a/Doc/library/ssl.rst +++ b/Doc/library/ssl.rst @@ -1315,6 +1315,20 @@ SSL sockets also have the following additional methods and attributes: .. versionadded:: 3.2 +.. attribute:: SSLSocket.eager_recv + + If set to ``True``, a call to :meth:`~socket.socket.recv()` or + :meth:`~socket.socket.recv_into()` on a + :ref:`non-blocking ` TLS socket + will drop the GIL once to read the entire buffer instead of reading at most + one TLS record (16 KB). + + .. note:: + Reading the entire buffer can include the TLS EOF segment, which will + close the TLS layer without raising :exc:`SSLEOFError`. + + .. versionadded:: 3.12 + .. attribute:: SSLSocket.server_side A boolean which is ``True`` for server-side sockets and ``False`` for diff --git a/Lib/ssl.py b/Lib/ssl.py index 1d5873726441e4..afebd432258807 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -836,6 +836,15 @@ def session_reused(self): """Was the client session reused during handshake""" return self._sslobj.session_reused + @property + def eager_recv(self): + """If data is read from the socket eagerly, ignoring possible TLS EOF packets.""" + return self._sslobj.eager_recv + + @eager_recv.setter + def eager_recv(self, eager_recv): + self._sslobj.eager_recv = eager_recv + @property def server_side(self): """Whether this is a server-side socket.""" @@ -1044,6 +1053,17 @@ def session_reused(self): if self._sslobj is not None: return self._sslobj.session_reused + @property + @_sslcopydoc + def eager_recv(self): + if self._sslobj is not None: + return self._sslobj.eager_recv + + @eager_recv.setter + def eager_recv(self, eager_recv): + if self._sslobj is not None: + self._sslobj.eager_recv = eager_recv + def dup(self): raise NotImplementedError("Can't dup() %s instances" % self.__class__.__name__) diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index e926fc5e88e584..cf828583541f88 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -2118,6 +2118,49 @@ def test_bio_read_write_data(self): self.assertEqual(buf, b'foo\n') self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + def test_bulk_nonblocking_read(self): + # 65536 bytes divide up into 4 TLS records (16 KB each) + # In nonblocking mode, we should be able to read all four in a single + # drop of the GIL. + size = 65536 + + client_context, server_context, hostname = testing_context() + server = ThreadedEchoServer(context=server_context, chatty=False, + buffer_size=size) + with server: + sock = socket.create_connection((HOST, server.port)) + sock.settimeout(0.0) + s = client_context.wrap_socket(sock, server_hostname=hostname, + do_handshake_on_connect=False) + s.eager_recv = True + with s: + while True: + try: + s.do_handshake() + break + except ssl.SSLWantReadError: + select.select([s], [], []) + except ssl.SSLWantWriteError: + select.select([], [s], []) + + s.send(b'\x00' * size) + + select.select([s], [], []) + + while size > 0: + try: + count = len(s.recv(size)) + except ssl.SSLWantReadError: + select.select([s], [], []) + # Give the sender some more time to complete sending. + time.sleep(0.01) + else: + if count > 16384: + return + size -= count + + self.fail("All TLS reads were smaller than 16KB") + @support.requires_resource('network') class NetworkedTests(unittest.TestCase): @@ -2177,7 +2220,7 @@ class ConnectionHandler(threading.Thread): with and without the SSL wrapper around the socket connection, so that we can test the STARTTLS functionality.""" - def __init__(self, server, connsock, addr): + def __init__(self, server, connsock, addr, buffer_size): self.server = server self.running = False self.sock = connsock @@ -2186,6 +2229,7 @@ def __init__(self, server, connsock, addr): self.sslconn = None threading.Thread.__init__(self) self.daemon = True + self.buffer_size = buffer_size def wrap_conn(self): try: @@ -2251,9 +2295,9 @@ def wrap_conn(self): def read(self): if self.sslconn: - return self.sslconn.read() + return self.sslconn.read(self.buffer_size) else: - return self.sock.recv(1024) + return self.sock.recv(self.buffer_size) def write(self, bytes): if self.sslconn: @@ -2371,8 +2415,8 @@ def run(self): def __init__(self, certificate=None, ssl_version=None, certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, - alpn_protocols=None, - ciphers=None, context=None): + alpn_protocols=None, ciphers=None, context=None, + buffer_size=1024): if context: self.context = context else: @@ -2401,6 +2445,7 @@ def __init__(self, certificate=None, ssl_version=None, self.conn_errors = [] threading.Thread.__init__(self) self.daemon = True + self.buffer_size = buffer_size def __enter__(self): self.start(threading.Event()) @@ -2428,7 +2473,8 @@ def run(self): if support.verbose and self.chatty: sys.stdout.write(' server: new connection from ' + repr(connaddr) + '\n') - handler = self.ConnectionHandler(self, newconn, connaddr) + handler = self.ConnectionHandler(self, newconn, connaddr, + self.buffer_size) handler.start() handler.join() except TimeoutError as e: diff --git a/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst b/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst new file mode 100644 index 00000000000000..1bfad3f755e3f1 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst @@ -0,0 +1,3 @@ +Added :attr:`ssl.SSLSocket.eager_recv`, if enabled a :ref:`non-blocking ` +TLS socket will drop the GIL once to read up to the entire buffer instead of reading at +most TLS record (16 KB). Patch by Josh Snyder and Safihre. diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 591eb91dd0f340..5afc6e55b94af7 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -305,6 +305,7 @@ typedef struct { PyObject *Socket; /* weakref to socket on which we're layered */ SSL *ssl; PySSLContext *ctx; /* weakref to SSL context */ + int eager_recv; char shutdown_seen_zero; enum py_ssl_server_or_client socket_type; PyObject *owner; /* Python level "owner" passed to servername callback */ @@ -799,6 +800,7 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, self->ssl = NULL; self->Socket = NULL; self->ctx = (PySSLContext*)Py_NewRef(sslctx); + self->eager_recv = 0; self->shutdown_seen_zero = 0; self->owner = NULL; self->server_hostname = NULL; @@ -2118,6 +2120,22 @@ static int PySSL_set_context(PySSLSocket *self, PyObject *value, return 0; } +static PyObject * +PySSL_get_eager_recv(PySSLSocket *self, void *c) +{ + return PyBool_FromLong(self->eager_recv); +} + +static int +PySSL_set_eager_recv(PySSLSocket *self, PyObject *arg, void *c) +{ + int eager_recv; + if (!PyArg_Parse(arg, "p", &eager_recv)) + return -1; + self->eager_recv = eager_recv; + return 0; +} + PyDoc_STRVAR(PySSL_set_context_doc, "_setter_context(ctx)\n\ \ @@ -2430,10 +2448,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, PyObject *dest = NULL; char *mem; size_t count = 0; + size_t readbytes = 0; int retval; int sockstate; _PySSLError err; - int nonblocking; + int nonblocking = 0; PySocketSockObject *sock = GET_SOCKET(self); _PyTime_t timeout, deadline = 0; int has_timeout; @@ -2493,11 +2512,22 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, do { PySSL_BEGIN_ALLOW_THREADS - retval = SSL_read_ex(self->ssl, mem, (size_t)len, &count); + do { + retval = SSL_read_ex(self->ssl, mem + count, len, &readbytes); + if (retval <= 0) { + break; + } + count += readbytes; + len -= readbytes; + } while (nonblocking && self->eager_recv && len > 0); err = _PySSL_errno(retval == 0, self->ssl, retval); PySSL_END_ALLOW_THREADS self->err = err; + if (count > 0) { + break; + } + if (PyErr_CheckSignals()) goto error; @@ -2528,7 +2558,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, } while (err.ssl == SSL_ERROR_WANT_READ || err.ssl == SSL_ERROR_WANT_WRITE); - if (retval == 0) { + if (count == 0) { PySSL_SetError(self, retval, __FILE__, __LINE__); goto error; } @@ -2877,6 +2907,8 @@ PyDoc_STRVAR(PySSL_get_session_reused_doc, static PyGetSetDef ssl_getsetlist[] = { {"context", (getter) PySSL_get_context, (setter) PySSL_set_context, PySSL_set_context_doc}, + {"eager_recv", (getter) PySSL_get_eager_recv, + (setter) PySSL_set_eager_recv, NULL}, {"server_side", (getter) PySSL_get_server_side, NULL, PySSL_get_server_side_doc}, {"server_hostname", (getter) PySSL_get_server_hostname, NULL,