From 8e40d2f2717e8d2b115c7306415ddc2cf7201f5a Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 7 Jun 2018 12:24:13 -0400 Subject: [PATCH] bpo-33694: asyncio: Fix race in Proactor's Transport.set_protocol() --- Lib/asyncio/proactor_events.py | 157 ++++++----------- Lib/asyncio/protocols.py | 19 +++ Lib/asyncio/sslproto.py | 21 +-- Lib/test/test_asyncio/test_events.py | 10 +- Lib/test/test_asyncio/test_proactor_events.py | 161 +----------------- Lib/test/test_asyncio/test_sslproto.py | 14 +- .../2018-06-07-12-24-11.bpo-33694.NjSAqI.rst | 4 + 7 files changed, 97 insertions(+), 289 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2018-06-07-12-24-11.bpo-33694.NjSAqI.rst diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 337ed0fb204751..e0de1b61839561 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -161,6 +161,12 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): self._loop_reading_cb = None self._paused = True + self._buffered = False + + self._read_backup_buffer = bytearray(32768) + self._read_buffer = None + self._read_buffer_belongs_to_proto = False + super().__init__(loop, sock, protocol, waiter, extra, server) self._reschedule_on_resume = False @@ -168,18 +174,10 @@ def __init__(self, loop, sock, protocol, waiter=None, self._paused = False def set_protocol(self, protocol): - if isinstance(protocol, protocols.BufferedProtocol): - self._loop_reading_cb = self._loop_reading__get_buffer - else: - self._loop_reading_cb = self._loop_reading__data_received - + self._buffered = isinstance(protocol, protocols.BufferedProtocol) + self._read_buffer_belongs_to_proto = False super().set_protocol(protocol) - if self.is_reading(): - # reset reading callback / buffers / self._read_fut - self.pause_reading() - self.resume_reading() - def is_reading(self): return not self._paused and not self._closing @@ -187,19 +185,6 @@ def pause_reading(self): if self._closing or 
self._paused: return self._paused = True - - if self._read_fut is not None and not self._read_fut.done(): - # TODO: This is an ugly hack to cancel the current read future - # *and* avoid potential race conditions, as read cancellation - # goes through `future.cancel()` and `loop.call_soon()`. - # We then use this special attribute in the reader callback to - # exit *immediately* without doing any cleanup/rescheduling. - self._read_fut.__asyncio_cancelled_on_pause__ = True - - self._read_fut.cancel() - self._read_fut = None - self._reschedule_on_resume = True - if self._loop.get_debug(): logger.debug("%r pauses reading", self) @@ -228,67 +213,6 @@ def _loop_reading__on_eof(self): self.close() def _loop_reading(self, fut=None): - self._loop_reading_cb(fut) - - def _loop_reading__data_received(self, fut): - if (fut is not None and - getattr(fut, '__asyncio_cancelled_on_pause__', False)): - return - - if self._paused: - self._reschedule_on_resume = True - return - - data = None - try: - if fut is not None: - assert self._read_fut is fut or (self._read_fut is None and - self._closing) - self._read_fut = None - if fut.done(): - # deliver data later in "finally" clause - data = fut.result() - else: - # the future will be replaced by next proactor.recv call - fut.cancel() - - if self._closing: - # since close() has been called we ignore any read data - data = None - return - - if data == b'': - # we got end-of-file so no need to reschedule a new read - return - - # reschedule a new read - self._read_fut = self._loop._proactor.recv(self._sock, 32768) - except ConnectionAbortedError as exc: - if not self._closing: - self._fatal_error(exc, 'Fatal read error on pipe transport') - elif self._loop.get_debug(): - logger.debug("Read error on pipe transport while closing", - exc_info=True) - except ConnectionResetError as exc: - self._force_close(exc) - except OSError as exc: - self._fatal_error(exc, 'Fatal read error on pipe transport') - except futures.CancelledError: - if not 
self._closing: - raise - else: - self._read_fut.add_done_callback(self._loop_reading__data_received) - finally: - if data: - self._protocol.data_received(data) - elif data == b'': - self._loop_reading__on_eof() - - def _loop_reading__get_buffer(self, fut): - if (fut is not None and - getattr(fut, '__asyncio_cancelled_on_pause__', False)): - return - if self._paused: self._reschedule_on_resume = True return @@ -324,32 +248,59 @@ def _loop_reading__get_buffer(self, fut): # we got end-of-file so no need to reschedule a new read self._loop_reading__on_eof() else: - try: - self._protocol.buffer_updated(nbytes) - except Exception as exc: - self._fatal_error( - exc, - 'Fatal error: ' - 'protocol.buffer_updated() call failed.') - return + if self._buffered: + try: + if self._read_buffer_belongs_to_proto: + self._protocol.buffer_updated(nbytes) + else: + protocols._feed_data_to_bufferred_proto( + self._protocol, self._read_buffer[:nbytes]) + except Exception as exc: + self._fatal_error( + exc, + 'Fatal error: ' + 'protocol.buffer_updated() call failed.') + return + else: + try: + self._protocol.data_received( + self._read_buffer[:nbytes]) + except Exception as exc: + self._fatal_error( + exc, + 'Fatal error: ' + 'protocol.data_received() call failed.') + return if self._closing or nbytes == 0: # since close() has been called we ignore any read data return - try: - buf = self._protocol.get_buffer(-1) - if not len(buf): - raise RuntimeError('get_buffer() returned an empty buffer') - except Exception as exc: - self._fatal_error( - exc, 'Fatal error: protocol.get_buffer() call failed.') - return + if self._buffered: + try: + self._read_buffer = self._protocol.get_buffer(-1) + if not len(self._read_buffer): + raise RuntimeError('get_buffer() returned an empty buffer') + except Exception as exc: + self._fatal_error( + exc, 'Fatal error: protocol.get_buffer() call failed.') + return + else: + self._read_buffer_belongs_to_proto = True + else: + self._read_buffer = 
self._read_backup_buffer + self._read_buffer_belongs_to_proto = False try: - # schedule a new read - self._read_fut = self._loop._proactor.recv_into(self._sock, buf) - self._read_fut.add_done_callback(self._loop_reading__get_buffer) + # TODO 3.8: Use WSARecv instead of WSARecvInto like libuv + # (see win/tcp.c; uv_process_tcp_read_req function.) + # WSARecv accepts a buffer as its first argument, which + # means we can use it for BufferedProtocol just fine, + # but it's easier to work with (in particular we can + # handle cancellation better.) + self._read_fut = self._loop._proactor.recv_into( + self._sock, self._read_buffer) + self._read_fut.add_done_callback(self._loop_reading) except ConnectionAbortedError as exc: if not self._closing: self._fatal_error(exc, 'Fatal read error on pipe transport') diff --git a/Lib/asyncio/protocols.py b/Lib/asyncio/protocols.py index b8d2e6be552e1e..4d47da387caa3f 100644 --- a/Lib/asyncio/protocols.py +++ b/Lib/asyncio/protocols.py @@ -189,3 +189,22 @@ def pipe_connection_lost(self, fd, exc): def process_exited(self): """Called when subprocess has exited.""" + + +def _feed_data_to_bufferred_proto(proto, data): + data_len = len(data) + while data_len: + buf = proto.get_buffer(data_len) + buf_len = len(buf) + if not buf_len: + raise RuntimeError('get_buffer() returned an empty buffer') + + if buf_len >= data_len: + buf[:data_len] = data + proto.buffer_updated(data_len) + return + else: + buf[:buf_len] = data[:buf_len] + proto.buffer_updated(buf_len) + data = data[buf_len:] + data_len = len(data) diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index fac2ae74e808b8..5578c6f81834ae 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -535,7 +535,7 @@ def data_received(self, data): if chunk: try: if self._app_protocol_is_buffer: - _feed_data_to_bufferred_proto( + protocols._feed_data_to_bufferred_proto( self._app_protocol, chunk) else: self._app_protocol.data_received(chunk) @@ -721,22 +721,3 @@ def 
_abort(self): self._transport.abort() finally: self._finalize() - - -def _feed_data_to_bufferred_proto(proto, data): - data_len = len(data) - while data_len: - buf = proto.get_buffer(data_len) - buf_len = len(buf) - if not buf_len: - raise RuntimeError('get_buffer() returned an empty buffer') - - if buf_len >= data_len: - buf[:data_len] = data - proto.buffer_updated(data_len) - return - else: - buf[:buf_len] = data[:buf_len] - proto.buffer_updated(buf_len) - data = data[buf_len:] - data_len = len(data) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 11cd950df1cedb..60047817c0fc2a 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -2504,10 +2504,12 @@ def test_sendfile_close_peer_in_the_middle_of_receiving(self): self.loop.sendfile(cli_proto.transport, self.file)) self.run_loop(srv_proto.done) - self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), - srv_proto.nbytes) - self.assertTrue(1024 <= self.file.tell() < len(self.DATA), - self.file.tell()) + self.assertLessEqual(1024, srv_proto.nbytes) + self.assertLessEqual(srv_proto.nbytes, len(self.DATA)) + + self.assertLessEqual(1024, self.file.tell()) + self.assertLessEqual(self.file.tell(), len(self.DATA)) + self.assertTrue(cli_proto.transport.is_closing()) def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py index 26588634de04a7..9e5d7ee5d086e0 100644 --- a/Lib/test/test_asyncio/test_proactor_events.py +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -52,49 +52,16 @@ def test_ctor(self): test_utils.run_briefly(self.loop) self.assertIsNone(fut.result()) self.protocol.connection_made(tr) - self.proactor.recv.assert_called_with(self.sock, 32768) + self.proactor.recv_into.assert_called_with(self.sock, tr._read_buffer) def test_loop_reading(self): tr = self.socket_transport() tr._loop_reading() - 
self.loop._proactor.recv.assert_called_with(self.sock, 32768) + self.loop._proactor.recv_into.assert_called_with( + self.sock, tr._read_buffer) self.assertFalse(self.protocol.data_received.called) self.assertFalse(self.protocol.eof_received.called) - def test_loop_reading_data(self): - res = asyncio.Future(loop=self.loop) - res.set_result(b'data') - - tr = self.socket_transport() - tr._read_fut = res - tr._loop_reading(res) - self.loop._proactor.recv.assert_called_with(self.sock, 32768) - self.protocol.data_received.assert_called_with(b'data') - - def test_loop_reading_no_data(self): - res = asyncio.Future(loop=self.loop) - res.set_result(b'') - - tr = self.socket_transport() - self.assertRaises(AssertionError, tr._loop_reading, res) - - tr.close = mock.Mock() - tr._read_fut = res - tr._loop_reading(res) - self.assertFalse(self.loop._proactor.recv.called) - self.assertTrue(self.protocol.eof_received.called) - self.assertTrue(tr.close.called) - - def test_loop_reading_aborted(self): - err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() - - tr = self.socket_transport() - tr._fatal_error = mock.Mock() - tr._loop_reading() - tr._fatal_error.assert_called_with( - err, - 'Fatal read error on pipe transport') - def test_loop_reading_aborted_closing(self): self.loop._proactor.recv.side_effect = ConnectionAbortedError() @@ -104,35 +71,6 @@ def test_loop_reading_aborted_closing(self): tr._loop_reading() self.assertFalse(tr._fatal_error.called) - def test_loop_reading_aborted_is_fatal(self): - self.loop._proactor.recv.side_effect = ConnectionAbortedError() - tr = self.socket_transport() - tr._closing = False - tr._fatal_error = mock.Mock() - tr._loop_reading() - self.assertTrue(tr._fatal_error.called) - - def test_loop_reading_conn_reset_lost(self): - err = self.loop._proactor.recv.side_effect = ConnectionResetError() - - tr = self.socket_transport() - tr._closing = False - tr._fatal_error = mock.Mock() - tr._force_close = mock.Mock() - tr._loop_reading() - 
self.assertFalse(tr._fatal_error.called) - tr._force_close.assert_called_with(err) - - def test_loop_reading_exception(self): - err = self.loop._proactor.recv.side_effect = (OSError()) - - tr = self.socket_transport() - tr._fatal_error = mock.Mock() - tr._loop_reading() - tr._fatal_error.assert_called_with( - err, - 'Fatal read error on pipe transport') - def test_write(self): tr = self.socket_transport() tr._loop_writing = mock.Mock() @@ -335,51 +273,6 @@ def test_write_eof_duplex_pipe(self): tr.write_eof() close_transport(tr) - def test_pause_resume_reading(self): - tr = self.socket_transport() - futures = [] - for msg in [b'data1', b'data2', b'data3', b'data4', b'data5', b'']: - f = asyncio.Future(loop=self.loop) - f.set_result(msg) - futures.append(f) - - self.loop._proactor.recv.side_effect = futures - self.loop._run_once() - self.assertFalse(tr._paused) - self.assertTrue(tr.is_reading()) - self.loop._run_once() - self.protocol.data_received.assert_called_with(b'data1') - self.loop._run_once() - self.protocol.data_received.assert_called_with(b'data2') - - tr.pause_reading() - tr.pause_reading() - self.assertTrue(tr._paused) - self.assertFalse(tr.is_reading()) - for i in range(10): - self.loop._run_once() - self.protocol.data_received.assert_called_with(b'data2') - - tr.resume_reading() - tr.resume_reading() - self.assertFalse(tr._paused) - self.assertTrue(tr.is_reading()) - self.loop._run_once() - self.protocol.data_received.assert_called_with(b'data3') - self.loop._run_once() - self.protocol.data_received.assert_called_with(b'data4') - - tr.pause_reading() - tr.resume_reading() - self.loop.call_exception_handler = mock.Mock() - self.loop._run_once() - self.loop.call_exception_handler.assert_not_called() - self.protocol.data_received.assert_called_with(b'data5') - tr.close() - - self.assertFalse(tr.is_reading()) - - def pause_writing_transport(self, high): tr = self.socket_transport() tr.set_write_buffer_limits(high=high) @@ -527,13 +420,13 @@ def 
test_proto_type_switch(self): tr = self.socket_transport() res = asyncio.Future(loop=self.loop) - res.set_result(b'data') + res.set_result(1) tr = self.socket_transport() tr._read_fut = res + tr._read_buffer = b'a' tr._loop_reading(res) - self.loop._proactor.recv.assert_called_with(self.sock, 32768) - self.protocol.data_received.assert_called_with(b'data') + self.protocol.data_received.assert_called_with(b'a') # switch protocol to a BufferedProtocol @@ -551,38 +444,6 @@ def test_proto_type_switch(self): self.loop._proactor.recv_into.assert_called_with(self.sock, buf) buf_proto.buffer_updated.assert_called_with(4) - def test_proto_buf_switch(self): - tr = self.socket_transport() - test_utils.run_briefly(self.loop) - self.protocol.get_buffer.assert_called_with(-1) - - # switch protocol to *another* BufferedProtocol - - buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol) - buf = bytearray(4) - buf_proto.get_buffer.side_effect = lambda hint: buf - tr._read_fut.done.side_effect = lambda: False - tr.set_protocol(buf_proto) - self.assertFalse(buf_proto.get_buffer.called) - test_utils.run_briefly(self.loop) - buf_proto.get_buffer.assert_called_with(-1) - - def test_buffer_updated_error(self): - transport = self.socket_transport() - transport._fatal_error = mock.Mock() - - self.loop.call_exception_handler = mock.Mock() - self.protocol.buffer_updated.side_effect = LookupError() - - res = asyncio.Future(loop=self.loop) - res.set_result(10) - transport._read_fut = res - transport._loop_reading(res) - - self.assertTrue(transport._fatal_error.called) - self.assertFalse(self.protocol.get_buffer.called) - self.assertTrue(self.protocol.buffer_updated.called) - def test_loop_eof_received_error(self): res = asyncio.Future(loop=self.loop) res.set_result(0) @@ -599,16 +460,6 @@ def test_loop_eof_received_error(self): self.assertTrue(self.protocol.eof_received.called) self.assertTrue(tr._fatal_error.called) - def test_loop_reading_data(self): - res = 
asyncio.Future(loop=self.loop) - res.set_result(4) - - tr = self.socket_transport() - tr._read_fut = res - tr._loop_reading(res) - self.loop._proactor.recv_into.assert_called_with(self.sock, self.buf) - self.protocol.buffer_updated.assert_called_with(4) - def test_loop_reading_no_data(self): res = asyncio.Future(loop=self.loop) res.set_result(0) diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 78ab1eb8223148..66147bfb0d0c4b 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -12,7 +12,7 @@ import asyncio from asyncio import log from asyncio import sslproto -from asyncio import tasks +from asyncio import protocols from test.test_asyncio import utils as test_utils from test.test_asyncio import functional as func_tests @@ -189,28 +189,28 @@ def buffer_updated(self, nsize): for usemv in [False, True]: proto = Proto(1, usemv) - sslproto._feed_data_to_bufferred_proto(proto, b'12345') + protocols._feed_data_to_bufferred_proto(proto, b'12345') self.assertEqual(proto.data, b'12345') proto = Proto(2, usemv) - sslproto._feed_data_to_bufferred_proto(proto, b'12345') + protocols._feed_data_to_bufferred_proto(proto, b'12345') self.assertEqual(proto.data, b'12345') proto = Proto(2, usemv) - sslproto._feed_data_to_bufferred_proto(proto, b'1234') + protocols._feed_data_to_bufferred_proto(proto, b'1234') self.assertEqual(proto.data, b'1234') proto = Proto(4, usemv) - sslproto._feed_data_to_bufferred_proto(proto, b'1234') + protocols._feed_data_to_bufferred_proto(proto, b'1234') self.assertEqual(proto.data, b'1234') proto = Proto(100, usemv) - sslproto._feed_data_to_bufferred_proto(proto, b'12345') + protocols._feed_data_to_bufferred_proto(proto, b'12345') self.assertEqual(proto.data, b'12345') proto = Proto(0, usemv) with self.assertRaisesRegex(RuntimeError, 'empty buffer'): - sslproto._feed_data_to_bufferred_proto(proto, b'12345') + protocols._feed_data_to_bufferred_proto(proto, 
b'12345') def test_start_tls_client_reg_proto_1(self): HELLO_MSG = b'1' * self.PAYLOAD_SIZE diff --git a/Misc/NEWS.d/next/Library/2018-06-07-12-24-11.bpo-33694.NjSAqI.rst b/Misc/NEWS.d/next/Library/2018-06-07-12-24-11.bpo-33694.NjSAqI.rst new file mode 100644 index 00000000000000..903328bcecc155 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-06-07-12-24-11.bpo-33694.NjSAqI.rst @@ -0,0 +1,4 @@ +Refactor Proactor's data-receiving code path. The idea is to always use +recv_into() and to emulate the old Protocol.data_received() path. This way +there is no read cancellation race or data loss in set_protocol() and +there's no performance degradation for either Protocol or BufferedProtocol.