Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,6 +2258,50 @@ def test_bio_read_write_data(self):
self.assertEqual(buf, b'foo\n')
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)

def test_bulk_nonblocking_read(self):
# 65536 bytes divide up into 4 TLS records (16 KB each)
# In nonblocking mode, we should be able to read all four in a single
# drop of the GIL.
size = 65536
trips = []

client_context, server_context, hostname = testing_context()
server = ThreadedEchoServer(context=server_context, chatty=False,
buffer_size=size)
with server:
sock = socket.create_connection((HOST, server.port))
sock.settimeout(0.0)
s = client_context.wrap_socket(sock, server_hostname=hostname,
do_handshake_on_connect=False)

with s:
while True:
try:
s.do_handshake()
break
except ssl.SSLWantReadError:
select.select([s], [], [])
except ssl.SSLWantWriteError:
select.select([], [s], [])

s.send(b'\x00' * size)

select.select([s], [], [])

while size > 0:
try:
count = len(s.recv(size))
except ssl.SSLWantReadError:
select.select([s], [], [])
# Give the sender some more time to complete sending.
time.sleep(0.01)
else:
if count > 16384:
return
size -= count

raise AssertionError("All TLS reads were smaller than 16KB")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise AssertionError("All TLS reads were smaller than 16KB")
self.fail("All TLS reads were smaller than 16KB")



class NetworkedTests(unittest.TestCase):

Expand Down Expand Up @@ -2316,7 +2360,7 @@ class ConnectionHandler(threading.Thread):
with and without the SSL wrapper around the socket connection, so
that we can test the STARTTLS functionality."""

def __init__(self, server, connsock, addr):
def __init__(self, server, connsock, addr, buffer_size):
self.server = server
self.running = False
self.sock = connsock
Expand All @@ -2325,6 +2369,7 @@ def __init__(self, server, connsock, addr):
self.sslconn = None
threading.Thread.__init__(self)
self.daemon = True
self.buffer_size = buffer_size

def wrap_conn(self):
try:
Expand Down Expand Up @@ -2382,9 +2427,9 @@ def wrap_conn(self):

def read(self):
if self.sslconn:
return self.sslconn.read()
return self.sslconn.read(self.buffer_size)
else:
return self.sock.recv(1024)
return self.sock.recv(self.buffer_size)

def write(self, bytes):
if self.sslconn:
Expand Down Expand Up @@ -2505,8 +2550,8 @@ def run(self):
def __init__(self, certificate=None, ssl_version=None,
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
alpn_protocols=None,
ciphers=None, context=None):
alpn_protocols=None, ciphers=None, context=None,
buffer_size=1024):
if context:
self.context = context
else:
Expand Down Expand Up @@ -2535,6 +2580,7 @@ def __init__(self, certificate=None, ssl_version=None,
self.conn_errors = []
threading.Thread.__init__(self)
self.daemon = True
self.buffer_size = buffer_size

def __enter__(self):
self.start(threading.Event())
Expand Down Expand Up @@ -2562,7 +2608,8 @@ def run(self):
if support.verbose and self.chatty:
sys.stdout.write(' server: new connection from '
+ repr(connaddr) + '\n')
handler = self.ConnectionHandler(self, newconn, connaddr)
handler = self.ConnectionHandler(self, newconn, connaddr,
self.buffer_size)
handler.start()
handler.join()
except TimeoutError:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
When reading from a nonblocking TLS socket, drop the GIL once to read up to
the entire buffer. Previously we would read at most one TLS record (16 KB).
Patch by Josh Snyder.
25 changes: 19 additions & 6 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -2334,10 +2334,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1,
PyObject *dest = NULL;
char *mem;
size_t count = 0;
size_t got = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
size_t got = 0;
size_t readbytes;

int retval;
int sockstate;
_PySSLError err;
int nonblocking;
int nonblocking = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to initialize the variable with 0?

PySocketSockObject *sock = GET_SOCKET(self);
_PyTime_t timeout, deadline = 0;
int has_timeout;
Expand Down Expand Up @@ -2397,11 +2398,23 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1,

do {
PySSL_BEGIN_ALLOW_THREADS
retval = SSL_read_ex(self->ssl, mem, len, &count);
do {
retval = SSL_read_ex(self->ssl, mem + got, len, &count);
if(retval <= 0) {
break;
}

got += count;
len -= count;
} while(nonblocking && len > 0);
Comment on lines +2402 to +2409
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got and count are confusing variable names. I suggest that you keep count as counter for the read() function call and introduce a new variable for SSL_read_ex() call.

Is it a good idea to have a potentially infinite loop here that ignores signals and timeouts? This smells like DoS vunlerable to be happen.

Suggested change
retval = SSL_read_ex(self->ssl, mem + got, len, &count);
if(retval <= 0) {
break;
}
got += count;
len -= count;
} while(nonblocking && len > 0);
retval = SSL_read_ex(self->ssl, mem + got, len, &readbytes);
if (retval <= 0) {
break;
}
count += readbytes;
len -= readbytes;
} while (nonblocking && len > 0);

err = _PySSL_errno(retval == 0, self->ssl, retval);
PySSL_END_ALLOW_THREADS
self->err = err;

if(got > 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if(got > 0) {
if (count > 0) {

break;
}

if (PyErr_CheckSignals())
goto error;

Expand All @@ -2415,7 +2428,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1,
} else if (err.ssl == SSL_ERROR_ZERO_RETURN &&
SSL_get_shutdown(self->ssl) == SSL_RECEIVED_SHUTDOWN)
{
count = 0;
got = 0;
goto done;
}
else
Expand All @@ -2431,7 +2444,7 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1,
} while (err.ssl == SSL_ERROR_WANT_READ ||
err.ssl == SSL_ERROR_WANT_WRITE);

if (retval == 0) {
if (got == 0) {
PySSL_SetError(self, retval, __FILE__, __LINE__);
goto error;
}
Expand All @@ -2441,11 +2454,11 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, int len, int group_right_1,
done:
Py_XDECREF(sock);
if (!group_right_1) {
_PyBytes_Resize(&dest, count);
_PyBytes_Resize(&dest, got);
return dest;
}
else {
return PyLong_FromSize_t(count);
return PyLong_FromSize_t(got);
}

error:
Expand Down