diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index d816767caf..e8696975b7 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1546,6 +1546,8 @@ Exceptions and warnings .. autoexception:: ResourceBusyError +.. autoexception:: ClosedResourceError + .. autoexception:: RunFinishedError .. autoexception:: TrioInternalError diff --git a/docs/source/reference-hazmat.rst b/docs/source/reference-hazmat.rst index cfc13519b7..094f5901cc 100644 --- a/docs/source/reference-hazmat.rst +++ b/docs/source/reference-hazmat.rst @@ -151,6 +151,28 @@ All environments provide the following functions: :raises trio.ResourceBusyError: if another task is already waiting for the given socket to become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_socket_close` while this + function is still working. + + +.. function:: notify_socket_close(sock) + + Notifies Trio's internal I/O machinery that you are about to close + a socket. + + This causes any operations currently waiting for this socket to + immediately raise :exc:`~trio.ClosedResourceError`. + + This does *not* actually close the socket. Generally when closing a + socket, you should first call this function, and then close the + socket. + + The given object *must* be exactly of type :func:`socket.socket`, + nothing else. + + :raises TypeError: + if the given object is not of type :func:`socket.socket`. Unix-specific API @@ -174,6 +196,9 @@ Unix-like systems provide the following functions: :raises trio.ResourceBusyError: if another task is already waiting for the given fd to become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_fd_close` while this + function is still working. .. function:: wait_writable(fd) @@ -192,6 +217,21 @@ Unix-like systems provide the following functions: :raises trio.ResourceBusyError: if another task is already waiting for the given fd to become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_fd_close` while this + function is still working. + +.. function:: notify_fd_close(fd) + + Notifies Trio's internal I/O machinery that you are about to close + a file descriptor. + + This causes any operations currently waiting for this file + descriptor to immediately raise :exc:`~trio.ClosedResourceError`. + + This does *not* actually close the file descriptor. Generally when + closing a file descriptor, you should first call this function, and + then actually close it. Kqueue-specific API diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 9dda5b74cc..3ac4d819f2 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -148,8 +148,6 @@ Abstract base classes .. autoexception:: BrokenStreamError -.. autoexception:: ClosedStreamError - .. currentmodule:: trio.abc .. autoclass:: trio.abc.Listener @@ -158,9 +156,6 @@ Abstract base classes .. currentmodule:: trio -.. autoexception:: ClosedListenerError - - Generic stream tools ~~~~~~~~~~~~~~~~~~~~ diff --git a/newsfragments/36.feature.rst b/newsfragments/36.feature.rst new file mode 100644 index 0000000000..0a09579a3f --- /dev/null +++ b/newsfragments/36.feature.rst @@ -0,0 +1,18 @@ +Suppose one task is blocked trying to use a resource – for example, +reading from a socket – and while it's doing this, another task closes +the resource. Previously, this produced undefined behavior. Now, +closing a resource causes pending operations on that resource to +terminate immediately with a :exc:`ClosedResourceError`. + +``ClosedStreamError`` and ``ClosedListenerError`` are now aliases for +:exc:`ClosedResourceError`, and deprecated. + +For this to work, Trio needs to know when a resource has been closed. +To facilitate this, new functions have been added: +:func:`trio.hazmat.notify_fd_close` and +:func:`trio.hazmat.notify_socket_close`. If you're using Trio's +built-in wrappers like :cls:`~trio.SocketStream` or +:mod:`trio.socket`, then you don't need to worry about this, but if +you're using the low-level functions like +:func:`trio.hazmat.wait_readable`, you should make sure to call these +functions at appropriate times. diff --git a/trio/__init__.py b/trio/__init__.py index 22e9de6e62..7ea2df21eb 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -66,6 +66,24 @@ from . import ssl # Not imported by default: testing +_deprecate.enable_attribute_deprecations(__name__) +__deprecated_attributes__ = { + "ClosedStreamError": + _deprecate.DeprecatedAttribute( + ClosedResourceError, + "0.5.0", + issue=36, + instead=ClosedResourceError + ), + "ClosedListenerError": + _deprecate.DeprecatedAttribute( + ClosedResourceError, + "0.5.0", + issue=36, + instead=ClosedResourceError + ), +} + _deprecate.enable_attribute_deprecations(hazmat.__name__) # Temporary hack to make sure _result is loaded, just during the deprecation diff --git a/trio/_abc.py b/trio/_abc.py index 6816227d3f..0ec724e335 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -260,6 +260,10 @@ async def aclose(self): If the resource is already closed, then this method should silently succeed. + Once this method completes, any other pending or future operations on + this resource should generally raise :exc:`~trio.ClosedResourceError`, + unless there's a good reason to do otherwise. + See also: :func:`trio.aclose_forcefully`. """ @@ -297,7 +301,9 @@ async def send_all(self, data): :meth:`HalfCloseableStream.send_eof` on this stream. trio.BrokenStreamError: if something has gone wrong, and the stream is broken. - trio.ClosedStreamError: if you already closed this stream object. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`send_all` is running. Most low-level operations in trio provide a guarantee: if they raise :exc:`trio.Cancelled`, this means that they had no effect, so the @@ -328,7 +334,9 @@ async def wait_send_all_might_not_block(self): :meth:`HalfCloseableStream.send_eof` on this stream. trio.BrokenStreamError: if something has gone wrong, and the stream is broken. - trio.ClosedStreamError: if you already closed this stream object. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`wait_send_all_might_not_block` is running. Note: @@ -402,7 +410,9 @@ async def receive_some(self, max_bytes): :meth:`receive_some` on the same stream at the same time. trio.BrokenStreamError: if something has gone wrong, and the stream is broken. - trio.ClosedStreamError: if you already closed this stream object. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`receive_some` is running. """ @@ -470,7 +480,9 @@ async def send_eof(self): :meth:`send_eof` on this stream. trio.BrokenStreamError: if something has gone wrong, and the stream is broken. - trio.ClosedStreamError: if you already closed this stream object. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`send_eof` is running. """ @@ -494,7 +506,9 @@ async def accept(self): Raises: trio.ResourceBusyError: if two tasks attempt to call :meth:`accept` on the same listener at the same time. - trio.ClosedListenerError: if you already closed this listener. + trio.ClosedResourceError: if you previously closed this listener + object, or if another task closes this listener object while + :meth:`accept` is running. Note that there is no ``BrokenListenerError``, because for listeners there is no general condition of "the network/remote peer broke the diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index 8ed18b1275..25779eef04 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -49,4 +49,11 @@ async def wait_socket_writable(sock): raise TypeError("need a socket") await wait_writable(sock) - __all__ += ["wait_socket_readable", "wait_socket_writable"] + def notify_socket_close(sock): + if type(sock) != _stdlib_socket.socket: + raise TypeError("need a socket") + notify_fd_close(sock) + + __all__ += [ + "wait_socket_readable", "wait_socket_writable", "notify_socket_close" + ] diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py index 8307608d51..eea9297b33 100644 --- a/trio/_core/_exceptions.py +++ b/trio/_core/_exceptions.py @@ -7,6 +7,7 @@ "WouldBlock", "Cancelled", "ResourceBusyError", + "ClosedResourceError", ] @@ -103,3 +104,16 @@ class ResourceBusyError(Exception): the data get scrambled. """ + + +class ClosedResourceError(Exception): + """Raised when attempting to use a resource after it has been closed. + + Note that "closed" here means that *your* code closed the resource, + generally by calling a method with a name like ``close`` or ``aclose``, or + by exiting a context manager. If a problem arises elsewhere – for example, + because of a network failure, or because a remote peer closed their end of + a connection – then that should be indicated by a different exception + class, like :exc:`BrokenStreamError` or an :exc:`OSError` subclass. + + """ diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index da605dd701..23f987c9a8 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -1,5 +1,6 @@ import select import attr +import outcome from .. import _core from . import _public @@ -122,3 +123,26 @@ async def wait_readable(self, fd): @_public async def wait_writable(self, fd): await self._epoll_wait(fd, "write_task") + + @_public + def notify_fd_close(self, fd): + if not isinstance(fd, int): + fd = fd.fileno() + if fd not in self._registered: + return + + waiters = self._registered[fd] + + def interrupt(task): + exc = _core.ClosedResourceError("another task closed this fd") + _core.reschedule(task, outcome.Error(exc)) + + if waiters.write_task is not None: + interrupt(waiters.write_task) + waiters.write_task = None + + if waiters.read_task is not None: + interrupt(waiters.read_task) + waiters.read_task = None + + self._update_registrations(fd, True) diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 79a606a1a2..14cf810b58 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -133,3 +133,28 @@ async def wait_readable(self, fd): @_public async def wait_writable(self, fd): await self._wait_common(fd, select.KQ_FILTER_WRITE) + + @_public + def notify_fd_close(self, fd): + if not isinstance(fd, int): + fd = fd.fileno() + + for filter in [select.KQ_FILTER_READ, select.KQ_FILTER_WRITE]: + key = (fd, filter) + receiver = self._registered.get(key) + + if receiver is None: + continue + + if type(receiver) is _core.Task: + event = select.kevent(fd, filter, select.KQ_EV_DELETE) + self._kqueue.control([event], 0) + exc = _core.ClosedResourceError("another task closed this fd") + _core.reschedule(receiver, outcome.Error(exc)) + del self._registered[key] + else: + # XX this is an interesting example of a case where being able + # to close a queue would be useful... + raise NotImplementedError( + "can't close an fd that monitor_kevent is using" + ) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index cf4d36da2e..90a79c032d 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -382,6 +382,19 @@ async def wait_socket_readable(self, sock): async def wait_socket_writable(self, sock): await self._wait_socket("write", sock) + @_public + def notify_socket_close(self, sock): + if type(sock) is not stdlib_socket.socket: + raise TypeError("need a stdlib socket") + + for mode in ["read", "write"]: + if sock in self._socket_waiters[mode]: + task = self._socket_waiters[mode].pop(sock) + exc = _core.ClosedResourceError( + "another task closed this socket" + ) + _core.reschedule(task, outcome.Error(exc)) + # This has cffi-isms in it and is untested... but it demonstrates the # logic we'll want when we start actually using overlapped I/O. # diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py index c19c0924b5..dad67ab8fc 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/tests/test_io.py @@ -38,6 +38,7 @@ def socketpair(): wait_readable_options = [_core.wait_socket_readable] wait_writable_options = [_core.wait_socket_writable] +notify_close_options = [_core.notify_socket_close] if hasattr(_core, "wait_readable"): wait_readable_options.append(_core.wait_readable) @@ -52,23 +53,38 @@ async def wait_writable_fd(fileobj): return await _core.wait_writable(fileobj.fileno()) wait_writable_options.append(wait_writable_fd) +if hasattr(_core, "notify_fd_close"): + notify_close_options.append(_core.notify_fd_close) -# Decorators that feed in different settings for wait_readable / wait_writable. -# Note that if you use both decorators on the same test, it will run all -# N**2 *combinations* + def notify_fd_close_rawfd(fileobj): + _core.notify_fd_close(fileobj.fileno()) + + notify_close_options.append(notify_fd_close_rawfd) + +# Decorators that feed in different settings for wait_readable / wait_writable +# / notify_close. +# Note that if you use all three decorators on the same test, it will run all +# N**3 *combinations* read_socket_test = pytest.mark.parametrize( "wait_readable", wait_readable_options, ids=lambda fn: fn.__name__ ) write_socket_test = pytest.mark.parametrize( "wait_writable", wait_writable_options, ids=lambda fn: fn.__name__ ) +notify_close_test = pytest.mark.parametrize( + "notify_close", notify_close_options, ids=lambda fn: fn.__name__ +) async def test_wait_socket_type_checking(socketpair): a, b = socketpair # wait_socket_* accept actual socket objects, only - for sock_fn in [_core.wait_socket_readable, _core.wait_socket_writable]: + for sock_fn in [ + _core.wait_socket_readable, + _core.wait_socket_writable, + _core.notify_socket_close, + ]: with pytest.raises(TypeError): await sock_fn(a.fileno()) @@ -179,6 +195,31 @@ async def test_double_write(socketpair, wait_writable): nursery.cancel_scope.cancel() +@read_socket_test +@write_socket_test +@notify_close_test +async def test_interrupted_by_close( + socketpair, wait_readable, wait_writable, notify_close +): + a, b = socketpair + + async def reader(): + with pytest.raises(_core.ClosedResourceError): + await wait_readable(a) + + async def writer(): + with pytest.raises(_core.ClosedResourceError): + await wait_writable(a) + + fill_socket(a) + + async with _core.open_nursery() as nursery: + nursery.start_soon(reader) + nursery.start_soon(writer) + await wait_all_tasks_blocked() + notify_close(a) + + @read_socket_test @write_socket_test async def test_socket_simultaneous_read_write( diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index fd92abc3bd..bda261f3dc 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -6,8 +6,6 @@ __all__ = [ "aclose_forcefully", "BrokenStreamError", - "ClosedStreamError", - "ClosedListenerError", "StapledStream", ] @@ -50,7 +48,7 @@ class BrokenStreamError(Exception): the remote side has already closed the connection. You *don't* get this error if *you* closed the stream – in that case you - get :class:`ClosedStreamError`. + get :class:`ClosedResourceError`. This exception's ``__cause__`` attribute will often contain more information about the underlying error. @@ -59,36 +57,6 @@ class BrokenStreamError(Exception): pass -class ClosedStreamError(Exception): - """Raised when an attempt to use a stream fails because the stream was - already closed locally. - - You *only* get this error if *your* code closed the stream object you're - attempting to use by calling - :meth:`~trio.abc.AsyncResource.aclose` or - similar. (:meth:`~trio.abc.SendStream.send_all` might also raise this if - you already called :meth:`~trio.abc.HalfCloseableStream.send_eof`.) - Therefore this exception generally indicates a bug in your code. - - If a problem arises elsewhere, for example due to a network failure or a - misbehaving peer, then you get :class:`BrokenStreamError` instead. - - """ - pass - - -class ClosedListenerError(Exception): - """Raised when an attempt to use a listener fails because it was already - closed locally. - - You *only* get this error if *your* code closed the stream object you're - attempting to use by calling :meth:`~trio.abc.AsyncResource.aclose` or - similar. Therefore this exception generally indicates a bug in your code. - - """ - pass - - @attr.s(cmp=False, hash=False) class StapledStream(HalfCloseableStream): """This class `staples `__ diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 25a25f1576..831ab303d9 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -7,9 +7,7 @@ from . import socket as tsocket from ._util import ConflictDetector from .abc import HalfCloseableStream, Listener -from ._highlevel_generic import ( - ClosedStreamError, BrokenStreamError, ClosedListenerError -) +from ._highlevel_generic import BrokenStreamError __all__ = ["SocketStream", "SocketListener"] @@ -27,7 +25,9 @@ def _translate_socket_errors_to_stream_errors(): yield except OSError as exc: if exc.errno in _closed_stream_errnos: - raise ClosedStreamError("this socket was already closed") from None + raise _core.ClosedResourceError( + "this socket was already closed" + ) from None else: raise BrokenStreamError( "socket connection broken: {}".format(exc) @@ -102,7 +102,9 @@ def __init__(self, socket): async def send_all(self, data): if self.socket.did_shutdown_SHUT_WR: await _core.checkpoint() - raise ClosedStreamError("can't send data after sending EOF") + raise _core.ClosedResourceError( + "can't send data after sending EOF" + ) with self._send_conflict_detector.sync: with _translate_socket_errors_to_stream_errors(): with memoryview(data) as data: @@ -118,7 +120,7 @@ async def send_all(self, data): async def wait_send_all_might_not_block(self): async with self._send_conflict_detector: if self.socket.fileno() == -1: - raise ClosedStreamError + raise _core.ClosedResourceError with _translate_socket_errors_to_stream_errors(): await self.socket.wait_writable() @@ -362,7 +364,7 @@ async def accept(self): Raises: OSError: if the underlying call to ``accept`` raises an unexpected error. - ClosedListenerError: if you already closed the socket. + ClosedResourceError: if you already closed the socket. This method handles routine errors like ``ECONNABORTED``, but passes other errors on to its caller. In particular, it does *not* make any @@ -375,7 +377,7 @@ async def accept(self): sock, _ = await self.socket.accept() except OSError as exc: if exc.errno in _closed_stream_errnos: - raise ClosedListenerError + raise _core.ClosedResourceError if exc.errno not in _ignorable_accept_errnos: raise else: diff --git a/trio/_socket.py b/trio/_socket.py index 9c3a13bd08..eba29366b3 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -415,7 +415,6 @@ def __init__(self, sock): "getsockopt", "setsockopt", "listen", - "close", "share", } @@ -461,6 +460,10 @@ def dup(self): """ return _SocketType(self._sock.dup()) + def close(self): + _core.notify_socket_close(self._sock) + self._sock.close() + async def bind(self, address): await _core.checkpoint() address = await self._resolve_local_address(address) diff --git a/trio/_ssl.py b/trio/_ssl.py index 5e6d2615fa..80f56714f1 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -155,9 +155,7 @@ from . import _core from .abc import Stream, Listener -from ._highlevel_generic import ( - BrokenStreamError, ClosedStreamError, aclose_forcefully -) +from ._highlevel_generic import BrokenStreamError, aclose_forcefully from . import _sync from ._util import ConflictDetector @@ -413,7 +411,7 @@ def _check_status(self): elif self._state is _State.BROKEN: raise BrokenStreamError elif self._state is _State.CLOSED: - raise ClosedStreamError + raise _core.ClosedResourceError else: # pragma: no cover assert False @@ -774,11 +772,15 @@ async def aclose(self): # letting that happen. But if you start seeing it, then hopefully # this will give you a little head start on tracking it down, # because whoa did this puzzle us at the 2017 PyCon sprints. + # + # Also, if someone else is blocked in send/receive, then we aren't + # going to be able to do a clean shutdown. If that happens, we'll + # just do an unclean shutdown. try: await self._retry( self._ssl_object.unwrap, ignore_want_read=True ) - except BrokenStreamError: + except (BrokenStreamError, _core.ResourceBusyError): pass except: # Failure! Kill the stream and move on. diff --git a/trio/_toplevel_core_reexports.py b/trio/_toplevel_core_reexports.py index 27a64e7098..c5754e518d 100644 --- a/trio/_toplevel_core_reexports.py +++ b/trio/_toplevel_core_reexports.py @@ -19,6 +19,7 @@ "WouldBlock", "Cancelled", "ResourceBusyError", + "ClosedResourceError", "MultiError", "run", "open_nursery", diff --git a/trio/hazmat.py b/trio/hazmat.py index f474ea41e2..e5a417ceb8 100644 --- a/trio/hazmat.py +++ b/trio/hazmat.py @@ -20,14 +20,16 @@ "add_instrument", "current_clock", "current_statistics", - "wait_writable", - "wait_readable", "ParkingLot", "UnboundedQueue", "RunLocal", "RunVar", + "wait_writable", + "wait_readable", + "notify_fd_close", "wait_socket_readable", "wait_socket_writable", + "notify_socket_close", "TrioToken", "current_trio_token", # kqueue symbols diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 9bff4df35b..9c245aa5bd 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -4,9 +4,7 @@ import random from .. import _core -from .._highlevel_generic import ( - BrokenStreamError, ClosedStreamError, aclose_forcefully -) +from .._highlevel_generic import BrokenStreamError, aclose_forcefully from .._abc import SendStream, ReceiveStream, Stream, HalfCloseableStream from ._checkpoints import assert_checkpoints @@ -151,8 +149,8 @@ async def expect_broken_stream_on_send(): with _assert_raises(BrokenStreamError): await do_send_all(b"x" * 100) - # r closed -> ClosedStreamError on the receive side - with _assert_raises(ClosedStreamError): + # r closed -> ClosedResourceError on the receive side + with _assert_raises(_core.ClosedResourceError): await do_receive_some(4096) # we can close the same stream repeatedly, it's fine @@ -162,12 +160,12 @@ async def expect_broken_stream_on_send(): # closing the sender side await do_aclose(s) - # now trying to send raises ClosedStreamError - with _assert_raises(ClosedStreamError): + # now trying to send raises ClosedResourceError + with _assert_raises(_core.ClosedResourceError): await do_send_all(b"x" * 100) # ditto for wait_send_all_might_not_block - with _assert_raises(ClosedStreamError): + with _assert_raises(_core.ClosedResourceError): with assert_checkpoints(): await s.wait_send_all_might_not_block() @@ -202,13 +200,13 @@ async def receive_send_then_close(): while True: await do_send_all(b"x" * 100) - with _assert_raises(ClosedStreamError): + with _assert_raises(_core.ClosedResourceError): await do_receive_some(4096) async with _ForceCloseBoth(await stream_maker()) as (s, r): await aclose_forcefully(s) - with _assert_raises(ClosedStreamError): + with _assert_raises(_core.ClosedResourceError): await do_send_all(b"123") # after the sender does a forceful close, the receiver might either @@ -229,10 +227,10 @@ async def receive_send_then_close(): scope.cancel() await s.aclose() - with _assert_raises(ClosedStreamError): + with _assert_raises(_core.ClosedResourceError): await do_send_all(b"123") - with _assert_raises(ClosedStreamError): + with _assert_raises(_core.ClosedResourceError): await do_receive_some(4096) # Check that we can still gracefully close a stream after an operation has @@ -260,6 +258,19 @@ async def expect_cancelled(afn, *args): nursery.start_soon(do_aclose, s) nursery.start_soon(do_aclose, r) + # Check that if a task is blocked in receive_some, then closing the + # receive stream causes it to wake up. + async with _ForceCloseBoth(await stream_maker()) as (s, r): + + async def receive_expecting_closed(): + with _assert_raises(_core.ClosedResourceError): + await r.receive_some(10) + + async with _core.open_nursery() as nursery: + nursery.start_soon(receive_expecting_closed) + await _core.wait_all_tasks_blocked() + await aclose_forcefully(r) + # check wait_send_all_might_not_block, if we can if clogged_stream_maker is not None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): @@ -341,6 +352,24 @@ async def receiver(): except BrokenStreamError: pass + # Check that if a task is blocked in a send-side method, then closing + # the send stream causes it to wake up. + async def close_soon(s): + await _core.wait_all_tasks_blocked() + await aclose_forcefully(s) + + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + async with _core.open_nursery() as nursery: + nursery.start_soon(close_soon, s) + with _assert_raises(_core.ClosedResourceError): + await s.send_all(b"xyzzy") + + async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): + async with _core.open_nursery() as nursery: + nursery.start_soon(close_soon, s) + with _assert_raises(_core.ClosedResourceError): + await s.wait_send_all_might_not_block() + async def check_two_way_stream(stream_maker, clogged_stream_maker): """Perform a number of generic tests on a custom two-way stream @@ -448,7 +477,7 @@ async def expect_x_then_eof(r): nursery.start_soon(expect_x_then_eof, s2) # now sending is disallowed - with _assert_raises(ClosedStreamError): + with _assert_raises(_core.ClosedResourceError): await s1.send_all(b"y") # but we can do send_eof again diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index f711aaadb6..8caede6c0f 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -1,9 +1,7 @@ import operator from .. import _core -from .._highlevel_generic import ( - ClosedStreamError, BrokenStreamError, StapledStream -) +from .._highlevel_generic import BrokenStreamError, StapledStream from .. import _util from ..abc import SendStream, ReceiveStream @@ -31,13 +29,21 @@ def __init__(self): "another task is already fetching data" ) + # This object treats "close" as being like closing the send side of a + # channel: so after close(), calling put() raises ClosedResourceError, and + # calling the get() variants drains the buffer and then returns an empty + # bytearray. def close(self): self._closed = True self._lot.unpark_all() + def close_and_wipe(self): + self._data = bytearray() + self.close() + def put(self, data): if self._closed: - raise ClosedStreamError("virtual connection closed") + raise _core.ClosedResourceError("virtual connection closed") self._data += data self._lot.unpark_all() @@ -144,6 +150,14 @@ def close(self): (if any). """ + # XXX should this cancel any pending calls to the send_all_hook and + # wait_send_all_might_not_block_hook? Those are the only places where + # send_all and wait_send_all_might_not_block can be blocked. + # + # The way we set things up, send_all_hook is memory_stream_pump, and + # wait_send_all_might_not_block_hook is unset. memory_stream_pump is + # synchronous. So normally, send_all and wait_send_all_might_not_block + # cannot block at all. self._outgoing.close() if self.close_hook is not None: self.close_hook() @@ -222,23 +236,25 @@ async def receive_some(self, max_bytes): if max_bytes is None: raise TypeError("max_bytes must not be None") if self._closed: - raise ClosedStreamError + raise _core.ClosedResourceError if self.receive_some_hook is not None: await self.receive_some_hook() - return await self._incoming.get(max_bytes) + # self._incoming's closure state tracks whether we got an EOF. + # self._closed tracks whether we, ourselves, are closed. + # self.close() sends an EOF to wake us up and sets self._closed, + # so after we wake up we have to check self._closed again. + data = await self._incoming.get(max_bytes) + if self._closed: + raise _core.ClosedResourceError + return data def close(self): """Discards any pending data from the internal buffer, and marks this stream as closed. """ - # discard any pending data self._closed = True - try: - self._incoming.get_nowait() - except _core.WouldBlock: - pass - self._incoming.close() + self._incoming.close_and_wipe() if self.close_hook is not None: self.close_hook() @@ -292,7 +308,7 @@ def memory_stream_pump( memory_recieve_stream.put_eof() else: memory_recieve_stream.put_data(data) - except ClosedStreamError: + except _core.ClosedResourceError: raise BrokenStreamError("MemoryReceiveStream was closed") return True @@ -445,12 +461,17 @@ def __init__(self): def _something_happened(self): self._waiters.unpark_all() + # Always wakes up when one side is closed, because everyone always reacts + # to that. async def _wait_for(self, fn): - while not fn(): + while True: + if fn(): + break + if self._sender_closed or self._receiver_closed: + break await self._waiters.park() def close_sender(self): - # close while send_all is in progress is undefined self._sender_closed = True self._something_happened() @@ -461,27 +482,27 @@ def close_receiver(self): async def send_all(self, data): async with self._send_conflict_detector: if self._sender_closed: - raise ClosedStreamError + raise _core.ClosedResourceError if self._receiver_closed: raise BrokenStreamError assert not self._data self._data += data self._something_happened() - await self._wait_for( - lambda: not self._data or self._receiver_closed - ) + await self._wait_for(lambda: not self._data) + if self._sender_closed: + raise _core.ClosedResourceError if self._data and self._receiver_closed: raise BrokenStreamError async def wait_send_all_might_not_block(self): async with self._send_conflict_detector: if self._sender_closed: - raise ClosedStreamError + raise _core.ClosedResourceError if self._receiver_closed: return - await self._wait_for( - lambda: self._receiver_waiting or self._receiver_closed - ) + await self._wait_for(lambda: self._receiver_waiting) + if self._sender_closed: + raise _core.ClosedResourceError async def receive_some(self, max_bytes): async with self._receive_conflict_detector: @@ -491,14 +512,16 @@ async def receive_some(self, max_bytes): raise ValueError("max_bytes must be >= 1") # State validation if self._receiver_closed: - raise ClosedStreamError + raise _core.ClosedResourceError # Wake wait_send_all_might_not_block and wait for data self._receiver_waiting = True self._something_happened() try: - await self._wait_for(lambda: self._data or self._sender_closed) + await self._wait_for(lambda: self._data) finally: self._receiver_waiting = False + if self._receiver_closed: + raise _core.ClosedResourceError # Get data, possibly waking send_all if self._data: got = self._data[:max_bytes] diff --git a/trio/tests/test_highlevel_socket.py b/trio/tests/test_highlevel_socket.py index aafa9330e0..7c01aab70d 100644 --- a/trio/tests/test_highlevel_socket.py +++ b/trio/tests/test_highlevel_socket.py @@ -8,7 +8,6 @@ from ..testing import ( check_half_closeable_stream, wait_all_tasks_blocked, assert_checkpoints ) -from .._highlevel_generic import ClosedListenerError from .._highlevel_socket import * from .. import socket as tsocket @@ -185,7 +184,7 @@ async def test_SocketListener(): await listener.aclose() with assert_checkpoints(): - with pytest.raises(ClosedListenerError): + with pytest.raises(_core.ClosedResourceError): await listener.accept() client_sock.close() @@ -203,7 +202,7 @@ async def test_SocketListener_socket_closed_underfoot(): # SocketListener gives correct error with assert_checkpoints(): - with pytest.raises(ClosedListenerError): + with pytest.raises(_core.ClosedResourceError): await listener.accept() diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index efb40391e4..fdcb8dfa6b 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -839,3 +839,33 @@ async def check_AF_UNIX(path): except FileNotFoundError: # MacOS doesn't support abstract filenames with the leading NUL byte pass + + +async def test_interrupted_by_close(): + a_stdlib, b_stdlib = stdlib_socket.socketpair() + with a_stdlib, b_stdlib: + a_stdlib.setblocking(False) + + data = b"x" * 99999 + + try: + while True: + a_stdlib.send(data) + except BlockingIOError: + pass + + a = tsocket.from_stdlib_socket(a_stdlib) + + async def sender(): + with pytest.raises(_core.ClosedResourceError): + await a.send(data) + + async def receiver(): + with pytest.raises(_core.ClosedResourceError): + await a.recv(1) + + async with _core.open_nursery() as nursery: + nursery.start_soon(sender) + nursery.start_soon(receiver) + await wait_all_tasks_blocked() + a.close() diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index 60e42cb29c..2b5a083143 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -14,9 +14,8 @@ import trio from .. import _core from .._highlevel_socket import SocketStream, SocketListener -from .._highlevel_generic import ( - BrokenStreamError, ClosedStreamError, aclose_forcefully -) +from .._highlevel_generic import BrokenStreamError, aclose_forcefully +from .._core import ClosedResourceError from .._highlevel_open_tcp_stream import open_tcp_stream from .. import ssl as tssl from .. import socket as tsocket @@ -853,7 +852,7 @@ async def server_closer(): nursery.start_soon(server_closer) # closing the SSLStream also closes its transport - with pytest.raises(ClosedStreamError): + with pytest.raises(ClosedResourceError): await client_transport.send_all(b"123") # once closed, it's OK to close again @@ -864,21 +863,21 @@ async def server_closer(): # Trying to send more data does not work with assert_checkpoints(): - with pytest.raises(ClosedStreamError): + with pytest.raises(ClosedResourceError): await server_ssl.send_all(b"123") # And once the connection is has been closed *locally*, then instead of # getting empty bytestrings we get a proper error with assert_checkpoints(): - with pytest.raises(ClosedStreamError): + with pytest.raises(ClosedResourceError): await client_ssl.receive_some(10) == b"" with assert_checkpoints(): - with pytest.raises(ClosedStreamError): + with pytest.raises(ClosedResourceError): await client_ssl.unwrap() with assert_checkpoints(): - with pytest.raises(ClosedStreamError): + with pytest.raises(ClosedResourceError): await client_ssl.do_handshake() # Check that a graceful close *before* handshaking gives a clean EOF on diff --git a/trio/tests/test_testing.py b/trio/tests/test_testing.py index d1d8adf0e7..5f4a6e84e1 100644 --- a/trio/tests/test_testing.py +++ b/trio/tests/test_testing.py @@ -9,7 +9,7 @@ from .._core.tests.tutil import have_ipv6 from .. import sleep from .. import _core -from .._highlevel_generic import ClosedStreamError, aclose_forcefully +from .._highlevel_generic import aclose_forcefully from ..testing import * from ..testing._check_streams import _assert_raises from ..testing._memory_streams import _UnboundedByteQueue @@ -477,7 +477,7 @@ async def getter(expect): # Closing ubq.close() - with pytest.raises(ClosedStreamError): + with pytest.raises(_core.ClosedResourceError): ubq.put(b"---") assert ubq.get_nowait(10) == b"" @@ -547,7 +547,7 @@ async def do_send_all_count_resourcebusy(): assert await mss.get_data() == b"xxx" assert await mss.get_data() == b"" - with pytest.raises(ClosedStreamError): + with pytest.raises(_core.ClosedResourceError): await do_send_all(b"---") # hooks @@ -620,7 +620,7 @@ async def do_receive_some(max_bytes): assert await do_receive_some(10) == b"" assert await do_receive_some(10) == b"" - with pytest.raises(ClosedStreamError): + with pytest.raises(_core.ClosedResourceError): mrs.put_data(b"---") async def receive_some_hook(): @@ -649,7 +649,7 @@ def close_hook(): await mrs2.aclose() assert record == ["closed"] - with pytest.raises(ClosedStreamError): + with pytest.raises(_core.ClosedResourceError): await mrs2.receive_some(10) @@ -657,19 +657,19 @@ async def test_MemoryRecvStream_closing(): mrs = MemoryReceiveStream() # close with no pending data mrs.close() - with pytest.raises(ClosedStreamError): + with pytest.raises(_core.ClosedResourceError): assert await mrs.receive_some(10) == b"" # repeated closes ok mrs.close() # put_data now fails - with pytest.raises(ClosedStreamError): + with pytest.raises(_core.ClosedResourceError): mrs.put_data(b"123") mrs2 = MemoryReceiveStream() # close with pending data mrs2.put_data(b"xyz") mrs2.close() - with pytest.raises(ClosedStreamError): + with pytest.raises(_core.ClosedResourceError): await mrs2.receive_some(10) @@ -706,36 +706,27 @@ async def test_memory_stream_one_way_pair(): await s.send_all(b"123") assert await r.receive_some(10) == b"123" - # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook - async def sender(): - await wait_all_tasks_blocked() - await s.send_all(b"abc") - async def receiver(expected): assert await r.receive_some(10) == expected + # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook async with _core.open_nursery() as nursery: nursery.start_soon(receiver, b"abc") - nursery.start_soon(sender) - - # And this fails if we don't pump from close_hook - async def aclose_after_all_tasks_blocked(): await wait_all_tasks_blocked() - await s.aclose() + await s.send_all(b"abc") + # And this fails if we don't pump from close_hook async with _core.open_nursery() as nursery: nursery.start_soon(receiver, b"") - nursery.start_soon(aclose_after_all_tasks_blocked) + await wait_all_tasks_blocked() + await s.aclose() s, r = memory_stream_one_way_pair() - async def close_after_all_tasks_blocked(): - await wait_all_tasks_blocked() - s.close() - async with _core.open_nursery() as nursery: nursery.start_soon(receiver, b"") - nursery.start_soon(close_after_all_tasks_blocked) + await wait_all_tasks_blocked() + s.close() s, r = memory_stream_one_way_pair()