diff --git a/Include/cpython/pyerrors.h b/Include/cpython/pyerrors.h index 0d9cc9922f7368..60c3433180205d 100644 --- a/Include/cpython/pyerrors.h +++ b/Include/cpython/pyerrors.h @@ -134,6 +134,7 @@ PyAPI_FUNC(PyObject *) _PyErr_TrySetFromCause( int PySignal_SetWakeupFd(int fd); PyAPI_FUNC(int) _PyErr_CheckSignals(void); +PyAPI_FUNC(int) _PyErr_CheckSignalsTrippedNoGil(void); /* Support for adding program text to SyntaxErrors */ diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index d4eb2d2e81fe0f..ba3681a74474f0 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -2120,6 +2120,49 @@ def test_bio_read_write_data(self): self.assertEqual(buf, b'foo\n') self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + def test_bulk_nonblocking_read(self): + # 65536 bytes divide up into 4 TLS records (16 KB each) + # In nonblocking mode, we should be able to read all four in a single + # drop of the GIL. + size = 65536 + trips = [] + + client_context, server_context, hostname = testing_context() + server = ThreadedEchoServer(context=server_context, chatty=False, + buffer_size=size) + with server: + sock = socket.create_connection((HOST, server.port)) + sock.settimeout(0.0) + s = client_context.wrap_socket(sock, server_hostname=hostname, + do_handshake_on_connect=False) + + with s: + while True: + try: + s.do_handshake() + break + except ssl.SSLWantReadError: + select.select([s], [], []) + except ssl.SSLWantWriteError: + select.select([], [s], []) + + s.send(b'\x00' * size) + + select.select([s], [], []) + + while size > 0: + try: + count = len(s.recv(size)) + except ssl.SSLWantReadError: + select.select([s], [], []) + # Give the sender some more time to complete sending. + time.sleep(0.01) + else: + if count > 16384: + return + size -= count + + self.fail("All TLS reads were smaller than 16KB") @support.requires_resource('network') class NetworkedTests(unittest.TestCase): @@ -2179,7 +2222,7 @@ class ConnectionHandler(threading.Thread): with and without the SSL wrapper around the socket connection, so that we can test the STARTTLS functionality.""" - def __init__(self, server, connsock, addr): + def __init__(self, server, connsock, addr, buffer_size): self.server = server self.running = False self.sock = connsock @@ -2188,6 +2231,7 @@ def __init__(self, server, connsock, addr): self.sslconn = None threading.Thread.__init__(self) self.daemon = True + self.buffer_size = buffer_size def wrap_conn(self): try: @@ -2253,9 +2297,9 @@ def wrap_conn(self): def read(self): if self.sslconn: - return self.sslconn.read() + return self.sslconn.read(self.buffer_size) else: - return self.sock.recv(1024) + return self.sock.recv(self.buffer_size) def write(self, bytes): if self.sslconn: @@ -2374,7 +2418,8 @@ def __init__(self, certificate=None, ssl_version=None, certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, alpn_protocols=None, - ciphers=None, context=None): + ciphers=None, context=None, + buffer_size=1024): if context: self.context = context else: @@ -2403,6 +2448,7 @@ def __init__(self, certificate=None, ssl_version=None, self.conn_errors = [] threading.Thread.__init__(self) self.daemon = True + self.buffer_size = buffer_size def __enter__(self): self.start(threading.Event()) @@ -2430,7 +2476,8 @@ def run(self): if support.verbose and self.chatty: sys.stdout.write(' server: new connection from ' + repr(connaddr) + '\n') - handler = self.ConnectionHandler(self, newconn, connaddr) + handler = self.ConnectionHandler(self, newconn, connaddr, + self.buffer_size) handler.start() handler.join() except TimeoutError as e: diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 28112317bc289e..b9a3c5cc4ed557 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -310,6 +310,7 @@ typedef struct { SSL *ssl; PySSLContext *ctx; /* weakref to SSL context */ char shutdown_seen_zero; + char deferred_empty_reads; enum py_ssl_server_or_client socket_type; PyObject *owner; /* Python level "owner" passed to servername callback */ PyObject *server_hostname; @@ -360,7 +361,7 @@ class _ssl.SSLSession "PySSLSession *" "get_state_type(type)->PySSLSession_Type" #include "clinic/_ssl.c.h" -static int PySSL_select(PySocketSockObject *s, int writing, _PyTime_t timeout); +static int PySSL_select(PySocketSockObject *s, int writing, _PyTime_t timeout, int gil_held); static int PySSL_set_owner(PySSLSocket *, PyObject *, void *); static int PySSL_set_session(PySSLSocket *, PyObject *, void *); @@ -663,6 +664,11 @@ PySSL_SetError(PySSLSocket *sslsock, int ret, const char *filename, int lineno) if (ERR_GET_LIB(e) == ERR_LIB_SSL && ERR_GET_REASON(e) == SSL_R_CERTIFICATE_VERIFY_FAILED) { type = state->PySSLCertVerificationErrorObject; +#ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING + } else if (ERR_GET_LIB(e) == ERR_LIB_SSL && + ERR_GET_REASON(e) == SSL_R_UNEXPECTED_EOF_WHILE_READING) { + type = state->PySSLEOFErrorObject; +#endif } break; } @@ -804,6 +810,7 @@ newPySSLSocket(PySSLContext *sslctx, PySocketSockObject *sock, self->Socket = NULL; self->ctx = (PySSLContext*)Py_NewRef(sslctx); self->shutdown_seen_zero = 0; + self->deferred_empty_reads = 0; self->owner = NULL; self->server_hostname = NULL; self->err = err; @@ -961,9 +968,9 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self) timeout = _PyDeadline_Get(deadline); if (err.ssl == SSL_ERROR_WANT_READ) { - sockstate = PySSL_select(sock, 0, timeout); + sockstate = PySSL_select(sock, 0, timeout, 1); } else if (err.ssl == SSL_ERROR_WANT_WRITE) { - sockstate = PySSL_select(sock, 1, timeout); + sockstate = PySSL_select(sock, 1, timeout, 1); } else { sockstate = SOCKET_OPERATION_OK; } @@ -2217,7 +2224,7 @@ PySSL_dealloc(PySSLSocket *self) */ static int -PySSL_select(PySocketSockObject *s, int writing, _PyTime_t timeout) +PySSL_select(PySocketSockObject *s, int writing, _PyTime_t timeout, int gil_held) { int rc; #ifdef HAVE_POLL @@ -2226,6 +2233,8 @@ PySSL_select(PySocketSockObject *s, int writing, _PyTime_t timeout) #else int nfds; fd_set fds; + fd_set *rfds = NULL; + fd_set *wfds = NULL; struct timeval tv; #endif @@ -2249,13 +2258,17 @@ PySSL_select(PySocketSockObject *s, int writing, _PyTime_t timeout) pollfd.fd = s->sock_fd; pollfd.events = writing ? POLLOUT : POLLIN; - /* timeout is in seconds, poll() uses milliseconds */ + /* timeout is in nanoseconds, poll() uses milliseconds */ ms = (int)_PyTime_AsMilliseconds(timeout, _PyTime_ROUND_CEILING); assert(ms <= INT_MAX); - PySSL_BEGIN_ALLOW_THREADS - rc = poll(&pollfd, 1, (int)ms); - PySSL_END_ALLOW_THREADS + if (gil_held) { + PySSL_BEGIN_ALLOW_THREADS + rc = poll(&pollfd, 1, (int)ms); + PySSL_END_ALLOW_THREADS + } else { + rc = poll(&pollfd, 1, (int)ms); + } #else /* Guard against socket too large for select*/ if (!_PyIsSelectable_fd(s->sock_fd)) @@ -2266,14 +2279,17 @@ PySSL_select(PySocketSockObject *s, int writing, _PyTime_t timeout) FD_ZERO(&fds); FD_SET(s->sock_fd, &fds); - /* Wait until the socket becomes ready */ - PySSL_BEGIN_ALLOW_THREADS nfds = Py_SAFE_DOWNCAST(s->sock_fd+1, SOCKET_T, int); - if (writing) - rc = select(nfds, NULL, &fds, NULL, &tv); - else - rc = select(nfds, &fds, NULL, NULL, &tv); - PySSL_END_ALLOW_THREADS + rfds = writing ? NULL : &fds; + wfds = writing ? &fds : NULL; + /* Wait until the socket becomes ready */ + if (gil_held) { + PySSL_BEGIN_ALLOW_THREADS + rc = select(nfds, rfds, wfds, NULL, &tv); + PySSL_END_ALLOW_THREADS + } else { + rc = select(nfds, rfds, wfds, NULL, &tv); + } #endif /* Return SOCKET_TIMED_OUT on timeout, SOCKET_OPERATION_OK otherwise @@ -2299,10 +2315,9 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b) int retval; int sockstate; _PySSLError err; - int nonblocking; PySocketSockObject *sock = GET_SOCKET(self); _PyTime_t timeout, deadline = 0; - int has_timeout; + unsigned int signalled = 0; if (sock != NULL) { if (((PyObject*)sock) == Py_None) { @@ -2314,71 +2329,73 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b) Py_INCREF(sock); } - if (sock != NULL) { - /* just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); - } + PySSL_BEGIN_ALLOW_THREADS timeout = GET_SOCKET_TIMEOUT(sock); - has_timeout = (timeout > 0); - if (has_timeout) { + if (timeout > 0) { deadline = _PyDeadline_Init(timeout); } - sockstate = PySSL_select(sock, 1, timeout); - if (sockstate == SOCKET_HAS_TIMED_OUT) { - PyErr_SetString(PyExc_TimeoutError, - "The write operation timed out"); - goto error; - } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { - PyErr_SetString(get_state_sock(self)->PySSLErrorObject, - "Underlying socket has been closed."); - goto error; - } else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { - PyErr_SetString(get_state_sock(self)->PySSLErrorObject, - "Underlying socket too large for select()."); - goto error; + if (sock != NULL) { + /* just in case the blocking state of the socket has been changed */ + BIO_set_nbio(SSL_get_rbio(self->ssl), deadline > 0); + BIO_set_nbio(SSL_get_wbio(self->ssl), deadline > 0); } - do { - PySSL_BEGIN_ALLOW_THREADS + sockstate = SOCKET_OPERATION_OK; + while (sockstate == SOCKET_OPERATION_OK) { retval = SSL_write_ex(self->ssl, b->buf, (size_t)b->len, &count); err = _PySSL_errno(retval == 0, self->ssl, retval); - PySSL_END_ALLOW_THREADS self->err = err; - if (PyErr_CheckSignals()) - goto error; + if (retval > 0) { + /* write complete */ + break; + } - if (has_timeout) { + if (_PyErr_CheckSignalsTrippedNoGil()) { + Py_BLOCK_THREADS; + signalled = PyErr_CheckSignals(); + Py_UNBLOCK_THREADS; + if (signalled) { + break; + } + } + + if (deadline > 0) { timeout = _PyDeadline_Get(deadline); } if (err.ssl == SSL_ERROR_WANT_READ) { - sockstate = PySSL_select(sock, 0, timeout); + sockstate = PySSL_select(sock, 0, timeout, 0); } else if (err.ssl == SSL_ERROR_WANT_WRITE) { - sockstate = PySSL_select(sock, 1, timeout); + sockstate = PySSL_select(sock, 1, timeout, 0); } else { - sockstate = SOCKET_OPERATION_OK; - } - - if (sockstate == SOCKET_HAS_TIMED_OUT) { - PyErr_SetString(PyExc_TimeoutError, - "The write operation timed out"); - goto error; - } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { - PyErr_SetString(get_state_sock(self)->PySSLErrorObject, - "Underlying socket has been closed."); - goto error; - } else if (sockstate == SOCKET_IS_NONBLOCKING) { break; } - } while (err.ssl == SSL_ERROR_WANT_READ || - err.ssl == SSL_ERROR_WANT_WRITE); + } + PySSL_END_ALLOW_THREADS + + if (signalled) { + goto error; + } + + if (sockstate == SOCKET_HAS_TIMED_OUT) { + PyErr_SetString(PyExc_TimeoutError, + "The write operation timed out"); + goto error; + } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { + PyErr_SetString(get_state_sock(self)->PySSLErrorObject, + "Underlying socket has been closed."); + goto error; + } else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { + PyErr_SetString(get_state_sock(self)->PySSLErrorObject, + "Underlying socket too large for select()."); + goto error; + } Py_XDECREF(sock); + if (retval == 0) return PySSL_SetError(self, retval, __FILE__, __LINE__); if (PySSL_ChainExceptions(self) < 0) @@ -2436,11 +2453,9 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, size_t count = 0; int retval; int sockstate; - _PySSLError err; - int nonblocking; PySocketSockObject *sock = GET_SOCKET(self); _PyTime_t timeout, deadline = 0; - int has_timeout; + unsigned int signalled = 0; if (!group_right_1 && len < 0) { PyErr_SetString(PyExc_ValueError, "size should not be negative"); @@ -2477,67 +2492,133 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, goto error; } if (len == 0) { - count = 0; goto done; } } } - if (sock != NULL) { - /* just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + if (self->deferred_empty_reads) { + count = 0; + self->deferred_empty_reads -= 1; + goto done; } + PySSL_BEGIN_ALLOW_THREADS + timeout = GET_SOCKET_TIMEOUT(sock); - has_timeout = (timeout > 0); - if (has_timeout) + if (timeout > 0) { deadline = _PyDeadline_Init(timeout); + } - do { - PySSL_BEGIN_ALLOW_THREADS - retval = SSL_read_ex(self->ssl, mem, (size_t)len, &count); - err = _PySSL_errno(retval == 0, self->ssl, retval); - PySSL_END_ALLOW_THREADS + if (sock) { + /* just in case the blocking state of the socket has been changed */ + BIO_set_nbio(SSL_get_rbio(self->ssl), deadline > 0); + BIO_set_nbio(SSL_get_wbio(self->ssl), deadline > 0); + } + + sockstate = SOCKET_OPERATION_OK; + while (sockstate == SOCKET_OPERATION_OK) { + size_t bytes_read = 0; + retval = SSL_read_ex(self->ssl, mem + count, (size_t)len - count, &bytes_read); + _PySSLError err = _PySSL_errno(retval == 0, self->ssl, retval); self->err = err; - if (PyErr_CheckSignals()) - goto error; + if (_PyErr_CheckSignalsTrippedNoGil()) { + Py_BLOCK_THREADS + signalled = PyErr_CheckSignals(); + Py_UNBLOCK_THREADS + if (signalled) { + break; + } + } - if (has_timeout) { + if (retval > 0) { + count += bytes_read; + if (!bytes_read || count >= (size_t)len) { + /* read complete */ + break; + } + if (deadline && _PyDeadline_Get(deadline) <= 0) { + sockstate = SOCKET_HAS_TIMED_OUT; + break; + } + if (!SSL_has_pending(self->ssl)) { + /* HACK: timeout of 1 to make sure we don't immediately fail */ + sockstate = PySSL_select(sock, 0, 1, 0); + if (sockstate == SOCKET_HAS_TIMED_OUT) { + /* nothing else right now, so return what we have */ + sockstate = SOCKET_OPERATION_OK; + break; + } + } + /* could be an error or more data */ + continue; + } + + if (err.ssl == SSL_ERROR_ZERO_RETURN && SSL_get_shutdown(self->ssl) == SSL_RECEIVED_SHUTDOWN) { + self->deferred_empty_reads += 1; + retval = 1; + break; + } + + /* See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html#BUGS + for more details about this EOF condition */ +#ifdef SSL_R_UNEXPECTED_EOF_WHILE_READING + unsigned long e; + if (err.ssl == SSL_ERROR_SSL && + ERR_GET_LIB((e = ERR_peek_last_error())) == ERR_LIB_SSL && + ERR_GET_REASON(e) == SSL_R_UNEXPECTED_EOF_WHILE_READING) { +#else + if (err.ssl == SSL_ERROR_SYSCALL && !err.c) { +#endif + self->deferred_empty_reads += 1; + retval = 1; + break; + } + + if (deadline > 0) { timeout = _PyDeadline_Get(deadline); + if (timeout < 0) { + sockstate = SOCKET_HAS_TIMED_OUT; + break; + } } if (err.ssl == SSL_ERROR_WANT_READ) { - sockstate = PySSL_select(sock, 0, timeout); + sockstate = PySSL_select(sock, 0, timeout, 0); } else if (err.ssl == SSL_ERROR_WANT_WRITE) { - sockstate = PySSL_select(sock, 1, timeout); - } else if (err.ssl == SSL_ERROR_ZERO_RETURN && - SSL_get_shutdown(self->ssl) == SSL_RECEIVED_SHUTDOWN) - { - count = 0; - goto done; - } - else - sockstate = SOCKET_OPERATION_OK; - - if (sockstate == SOCKET_HAS_TIMED_OUT) { - PyErr_SetString(PyExc_TimeoutError, - "The read operation timed out"); - goto error; - } else if (sockstate == SOCKET_IS_NONBLOCKING) { + sockstate = PySSL_select(sock, 1, timeout, 0); + } else { break; } - } while (err.ssl == SSL_ERROR_WANT_READ || - err.ssl == SSL_ERROR_WANT_WRITE); + } + PySSL_END_ALLOW_THREADS + + if (signalled) { + goto error; + } + + if (sockstate == SOCKET_HAS_TIMED_OUT) { + PyErr_SetString(PyExc_TimeoutError, + "The read operation timed out"); + goto error; + } else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { + PyErr_SetString(get_state_sock(self)->PySSLErrorObject, + "Underlying socket has been closed."); + goto error; + } else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { + PyErr_SetString(get_state_sock(self)->PySSLErrorObject, + "Underlying socket too large for select()."); + goto error; + } if (retval == 0) { PySSL_SetError(self, retval, __FILE__, __LINE__); goto error; } - if (self->exc_type != NULL) + if (self->exc_type != NULL) { goto error; + } done: Py_XDECREF(sock); @@ -2555,6 +2636,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, if (!group_right_1) Py_XDECREF(dest); return NULL; + } /*[clinic input] @@ -2633,9 +2715,9 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self) /* Possibly retry shutdown until timeout or failure */ if (err.ssl == SSL_ERROR_WANT_READ) - sockstate = PySSL_select(sock, 0, timeout); + sockstate = PySSL_select(sock, 0, timeout, 1); else if (err.ssl == SSL_ERROR_WANT_WRITE) - sockstate = PySSL_select(sock, 1, timeout); + sockstate = PySSL_select(sock, 1, timeout, 1); else break; diff --git a/Modules/signalmodule.c b/Modules/signalmodule.c index 0e472e1ee4f9dd..381786caeebbf8 100644 --- a/Modules/signalmodule.c +++ b/Modules/signalmodule.c @@ -1780,6 +1780,14 @@ PyErr_CheckSignals(void) } +int +_PyErr_CheckSignalsTrippedNoGil() +{ + _Py_CHECK_EMSCRIPTEN_SIGNALS(); + return _Py_atomic_load(&is_tripped) ? 1 : 0; +} + + /* Declared in cpython/pyerrors.h */ int _PyErr_CheckSignalsTstate(PyThreadState *tstate)