diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index ae66c3e7d4a56c..02d34863dbe965 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -2258,6 +2258,50 @@ def test_bio_read_write_data(self): self.assertEqual(buf, b'foo\n') self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + def test_bulk_nonblocking_read(self): + # 65536 bytes divide up into 4 TLS records (16 KB each) + # In nonblocking mode, we should be able to read all four in a single + # drop of the GIL. + size = 65536 + trips = [] + + client_context, server_context, hostname = testing_context() + server = ThreadedEchoServer(context=server_context, chatty=False, + buffer_size=size) + with server: + sock = socket.create_connection((HOST, server.port)) + sock.settimeout(0.0) + s = client_context.wrap_socket(sock, server_hostname=hostname, + do_handshake_on_connect=False) + + with s: + while True: + try: + s.do_handshake() + break + except ssl.SSLWantReadError: + select.select([s], [], []) + except ssl.SSLWantWriteError: + select.select([], [s], []) + + s.send(b'\x00' * size) + + select.select([s], [], []) + + while size > 0: + try: + count = len(s.recv(size)) + except ssl.SSLWantReadError: + select.select([s], [], []) + # Give the sender some more time to complete sending. + time.sleep(0.01) + else: + if count > 16384: + return + size -= count + + raise AssertionError("All TLS reads were smaller than 16KB") + class NetworkedTests(unittest.TestCase): @@ -2316,7 +2360,7 @@ class ConnectionHandler(threading.Thread): with and without the SSL wrapper around the socket connection, so that we can test the STARTTLS functionality.""" - def __init__(self, server, connsock, addr): + def __init__(self, server, connsock, addr, buffer_size): self.server = server self.running = False self.sock = connsock @@ -2325,6 +2369,7 @@ def __init__(self, server, connsock, addr): self.sslconn = None threading.Thread.__init__(self) self.daemon = True + self.buffer_size = buffer_size def wrap_conn(self): try: @@ -2382,9 +2427,9 @@ def wrap_conn(self): def read(self): if self.sslconn: - return self.sslconn.read() + return self.sslconn.read(self.buffer_size) else: - return self.sock.recv(1024) + return self.sock.recv(self.buffer_size) def write(self, bytes): if self.sslconn: @@ -2505,8 +2550,8 @@ def run(self): def __init__(self, certificate=None, ssl_version=None, certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, - alpn_protocols=None, - ciphers=None, context=None): + alpn_protocols=None, ciphers=None, context=None, + buffer_size=1024): if context: self.context = context else: @@ -2535,6 +2580,7 @@ def __init__(self, certificate=None, ssl_version=None, self.conn_errors = [] threading.Thread.__init__(self) self.daemon = True + self.buffer_size = buffer_size def __enter__(self): self.start(threading.Event()) @@ -2562,7 +2608,8 @@ def run(self): if support.verbose and self.chatty: sys.stdout.write(' server: new connection from ' + repr(connaddr) + '\n') - handler = self.ConnectionHandler(self, newconn, connaddr) + handler = self.ConnectionHandler(self, newconn, connaddr, + self.buffer_size) handler.start() handler.join() except TimeoutError: diff --git a/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst b/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst new file mode 100644 index 00000000000000..be6ad12d4711b1 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-04-19-15-53-03.bpo-37355.3pie1n.rst @@ -0,0 +1,3 @@ +When reading from a nonblocking TLS socket, drop the GIL once to read up to +the entire buffer. Previously we would read at most one TLS record (16 KB). +Patch by Josh Snyder. diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 4b84014d008c12..97e3a3ad37199d 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -2334,10 +2334,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1, PyObject *dest = NULL; char *mem; size_t count = 0; + size_t got = 0; int retval; int sockstate; _PySSLError err; - int nonblocking; + int nonblocking = 0; PySocketSockObject *sock = GET_SOCKET(self); _PyTime_t timeout, deadline = 0; int has_timeout; @@ -2397,11 +2398,23 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1, do { PySSL_BEGIN_ALLOW_THREADS - retval = SSL_read_ex(self->ssl, mem, len, &count); + do { + retval = SSL_read_ex(self->ssl, mem + got, len, &count); + if(retval <= 0) { + break; + } + + got += count; + len -= count; + } while(nonblocking && len > 0); err = _PySSL_errno(retval == 0, self->ssl, retval); PySSL_END_ALLOW_THREADS self->err = err; + if(got > 0) { + break; + } + if (PyErr_CheckSignals()) goto error; @@ -2415,7 +2428,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1, } else if (err.ssl == SSL_ERROR_ZERO_RETURN && SSL_get_shutdown(self->ssl) == SSL_RECEIVED_SHUTDOWN) { - count = 0; + got = 0; goto done; } else @@ -2431,7 +2444,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1, } while (err.ssl == SSL_ERROR_WANT_READ || err.ssl == SSL_ERROR_WANT_WRITE); - if (retval == 0) { + if (got == 0) { PySSL_SetError(self, retval, __FILE__, __LINE__); goto error; } @@ -2441,11 +2454,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1, done: Py_XDECREF(sock); if (!group_right_1) { - _PyBytes_Resize(&dest, count); + _PyBytes_Resize(&dest, got); return dest; } else { - return PyLong_FromSize_t(count); + return PyLong_FromSize_t(got); } error: