diff --git a/.travis.yml b/.travis.yml index 4cd148217a..990f89688b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,6 +18,12 @@ matrix: - os: linux language: generic env: USE_PYPY_NIGHTLY=1 + # 5.7.1 has some minor bugs that are nonetheless annoying/hard to + # work around; let's wait for the next beta before we start + # testing pypy releases. + # - os: linux + # language: generic + # env: USE_PYPY_RELEASE=1 - os: osx language: generic env: MACPYTHON=3.5.3 diff --git a/MANIFEST.in b/MANIFEST.in index 1051aeb219..ca841cc0c8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,5 +2,6 @@ include LICENSE LICENSE.MIT LICENSE.APACHE2 include README.rst include CODE_OF_CONDUCT.md include test-requirements.txt +recursive-include trio/tests/test_ssl_certs *.pem recursive-include docs * prune docs/build diff --git a/ci/travis.sh b/ci/travis.sh index eb8b78cc70..9dcb7cdb27 100755 --- a/ci/travis.sh +++ b/ci/travis.sh @@ -3,7 +3,7 @@ set -ex if [ "$TRAVIS_OS_NAME" = "osx" ]; then - curl -o macpython.pkg https://www.python.org/ftp/python/${MACPYTHON}/python-${MACPYTHON}-macosx10.6.pkg + curl -Lo macpython.pkg https://www.python.org/ftp/python/${MACPYTHON}/python-${MACPYTHON}-macosx10.6.pkg sudo installer -pkg macpython.pkg -target / ls /Library/Frameworks/Python.framework/Versions/*/bin/ if expr "${MACPYTHON}" : 2; then @@ -18,7 +18,7 @@ if [ "$TRAVIS_OS_NAME" = "osx" ]; then fi if [ "$USE_PYPY_NIGHTLY" = "1" ]; then - curl -o pypy.tar.bz2 http://buildbot.pypy.org/nightly/py3.5/pypy-c-jit-latest-linux64.tar.bz2 + curl -Lo pypy.tar.bz2 http://buildbot.pypy.org/nightly/py3.5/pypy-c-jit-latest-linux64.tar.bz2 tar xaf pypy.tar.bz2 # something like "pypy-c-jit-89963-748aa3022295-linux64" PYPY_DIR=$(echo pypy-c-jit-*) @@ -29,6 +29,18 @@ if [ "$USE_PYPY_NIGHTLY" = "1" ]; then source testenv/bin/activate fi +if [ "$USE_PYPY_RELEASE" = "1" ]; then + curl -Lo pypy.tar.bz2 https://bitbucket.org/squeaky/portable-pypy/downloads/pypy3.5-5.7.1-beta-linux_x86_64-portable.tar.bz2 + tar xaf pypy.tar.bz2 + # something like "pypy3.5-5.7.1-beta-linux_x86_64-portable" + PYPY_DIR=$(echo pypy3.5-*) + PYTHON_EXE=$PYPY_DIR/bin/pypy3 + $PYTHON_EXE -m ensurepip + $PYTHON_EXE -m pip install virtualenv + $PYTHON_EXE -m virtualenv testenv + source testenv/bin/activate +fi + pip install -U pip setuptools wheel python setup.py sdist --formats=zip diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index 543d9381f6..ab87dc52c9 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -374,7 +374,7 @@ work that was happening within the scope that was cancelled. Looking at this, you might wonder how you can tell whether the inner block timed out – perhaps you want to do something different, like try a fallback procedure or report a failure to our caller. To make this -easier, :func:`move_on_after`'s ``__enter__`` function returns an +easier, :func:`move_on_after`\´s ``__enter__`` function returns an object representing this cancel scope, which we can use to check whether this scope caught a :exc:`Cancelled` exception:: @@ -1364,6 +1364,9 @@ don't have any special access to trio's internals.) .. autoclass:: Lock :members: +.. autoclass:: StrictFIFOLock + :members: + .. autoclass:: Condition :members: @@ -1534,12 +1537,14 @@ trio's internal scheduling decisions. Exceptions ---------- -.. autoexception:: TrioInternalError - .. autoexception:: Cancelled .. autoexception:: TooSlowError .. autoexception:: WouldBlock +.. autoexception:: ResourceBusyError + .. autoexception:: RunFinishedError + +.. autoexception:: TrioInternalError diff --git a/docs/source/reference-hazmat.rst b/docs/source/reference-hazmat.rst index fe8a1caef7..9ca67e9fa0 100644 --- a/docs/source/reference-hazmat.rst +++ b/docs/source/reference-hazmat.rst @@ -48,7 +48,7 @@ All environments provide the following functions: :raises TypeError: if the given object is not of type :func:`socket.socket`. - :raises RuntimeError: + :raises trio.ResourceBusyError: if another task is already waiting for the given socket to become readable. @@ -62,7 +62,7 @@ All environments provide the following functions: :raises TypeError: if the given object is not of type :func:`socket.socket`. - :raises RuntimeError: + :raises trio.ResourceBusyError: if another task is already waiting for the given socket to become writable. @@ -85,7 +85,7 @@ Unix-like systems provide the following functions: :arg fd: integer file descriptor, or else an object with a ``fileno()`` method - :raises RuntimeError: + :raises trio.ResourceBusyError: if another task is already waiting for the given fd to become readable. @@ -103,7 +103,7 @@ Unix-like systems provide the following functions: :arg fd: integer file descriptor, or else an object with a ``fileno()`` method - :raises RuntimeError: + :raises trio.ResourceBusyError: if another task is already waiting for the given fd to become writable. diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index c49ecd589e..c43313af77 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -1,26 +1,193 @@ +.. currentmodule:: trio + I/O in Trio =========== +.. note:: + + Please excuse our dust! `geocities-construction-worker.gif + `__ + + You're looking at the documentation for trio's development branch, + which is currently about half-way through implementing a proper + high-level networking API. If you want to know how to do networking + in trio *right now*, then you might want to jump down to read about + :mod:`trio.socket`, which is the already-working lower-level + API. Alternatively, you can read on for a (somewhat disorganized) + preview of coming attractions. + +.. _abstract-stream-api: + +The abstract Stream API +----------------------- + +Trio provides a set of abstract base classes that define a standard +interface for unidirectional and bidirectional byte streams. + +Why is this useful? Because it lets you write generic protocol +implementations that can work over arbitrary transports, and easily +create complex transport configurations. Here's some examples: + +* :class:`trio.SocketStream` wraps a raw socket (like a TCP connection + over the network), and converts it to the standard stream interface. + +* :class:`trio.ssl.SSLStream` is a "stream adapter" that can take any + object that implements the :class:`trio.abc.Stream` interface, and + convert it into an encrypted stream. In trio the standard way to + speak SSL over the network is to wrap an + :class:`~trio.ssl.SSLStream` around a :class:`~trio.SocketStream`. + +* If you spawn a subprocess then you can get a + :class:`~trio.abc.SendStream` that lets you write to its stdin, and + a :class:`~trio.abc.ReceiveStream` that lets you read from its + stdout. If for some reason you wanted to speak SSL to a subprocess, + you could use a :class:`StapledStream` to combine its stdin/stdout + into a single bidirectional :class:`~trio.abc.Stream`, and then wrap + that in an :class:`~trio.ssl.SSLStream`:: + + ssl_context = trio.ssl.create_default_context() + ssl_context.check_hostname = False + s = SSLStream(StapledStream(process.stdin, process.stdout), ssl_context) + + [Note: subprocess support is not implemented yet, but that's the + plan. Unless it is implemented, and I forgot to remove this note.] + +* It sometimes happens that you want to connect to an HTTPS server, + but you have to go through a web proxy... and the proxy also uses + HTTPS. So you end up having to do `SSL-on-top-of-SSL + `__. In + trio this is trivial – just wrap your first + :class:`~trio.ssl.SSLStream` in a second + :class:`~trio.ssl.SSLStream`:: + + # Get a raw SocketStream connection to the proxy: + s0 = await open_tcp_stream("proxy", 443) + + # Set up SSL connection to proxy: + s1 = SSLStream(s0, proxy_ssl_context, server_hostname="proxy") + # Request a connection to the website + await s1.send_all(b"CONNECT website:443 / HTTP/1.0\r\n") + await check_CONNECT_response(s1) + + # Set up SSL connection to the real website. Notice that s1 is + # already an SSLStream object, and here we're wrapping a second + # SSLStream object around it. + s2 = SSLStream(s1, website_ssl_context, server_hostname="website") + # Make our request + await s2.send_all("GET /index.html HTTP/1.0\r\n") + ... + +* The :mod:`trio.testing` module provides a set of :ref:`flexible + in-memory stream object implementations `, so if + you have a protocol implementation to test then you can can spawn + two tasks, set up a virtual "socket" connecting them, and then do + things like inject random-but-repeatable delays into the connection. + + +Abstract base classes +~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: trio.abc + +.. autoclass:: trio.abc.AsyncResource + :members: + +.. autoclass:: trio.abc.SendStream + :members: + :show-inheritance: + +.. autoclass:: trio.abc.ReceiveStream + :members: + :show-inheritance: + +.. autoclass:: trio.abc.Stream + :members: + :show-inheritance: + +.. autoclass:: trio.abc.HalfCloseableStream + :members: + :show-inheritance: + +.. currentmodule:: trio + +.. autoexception:: BrokenStreamError + +.. autoexception:: ClosedStreamError + + +Generic stream implementations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Trio currently provides one generic utility class for working with +streams. And if you want to test code that's written against the +streams interface, you should also check out :ref:`testing-streams` in +:mod:`trio.testing`. + +.. autoclass:: StapledStream + :members: + :show-inheritance: + + Sockets and networking ----------------------- +~~~~~~~~~~~~~~~~~~~~~~ + +The high-level network interface is built on top of our stream +abstraction. + +.. autoclass:: SocketStream + :members: + :show-inheritance: + +.. autofunction:: socket_stream_pair + + +SSL / TLS support +~~~~~~~~~~~~~~~~~ + +.. module:: trio.ssl + +The :mod:`trio.ssl` module implements SSL/TLS support for Trio, using +the standard library :mod:`ssl` module. It re-exports most of +:mod:`ssl`\´s API, with the notable exception is +:class:`ssl.SSLContext`, which has unsafe defaults; if you really want +to use :class:`ssl.SSLContext` you can import it from :mod:`ssl`, but +normally you should create your contexts using +:func:`trio.ssl.create_default_context `. + +Instead of using :meth:`ssl.SSLContext.wrap_socket`, though, you +create a :class:`SSLStream`: + +.. autoclass:: SSLStream + :show-inheritance: + :members: + + +Low-level sockets and networking +-------------------------------- .. module:: trio.socket -The :mod:`trio.socket` module provides trio's basic networking API. +The :mod:`trio.socket` module provides trio's basic low-level +networking API. If you're doing ordinary things with stream-oriented +connections over IPv4/IPv6/Unix domain sockets, then you probably want +to stick to the high-level API described above. If you want to use +UDP, or exotic address families like ``AF_BLUETOOTH``, or otherwise +get direct access to all the quirky bits of your system's networking +API, then you're in the right place. -:mod:`trio.socket`\'s top-level exports +:mod:`trio.socket`: top-level exports ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Generally, :mod:`trio.socket`\'s API mirrors that of the standard -library :mod:`socket` module. Most constants (like ``SOL_SOCKET``) and -simple utilities (like :func:`~socket.inet_aton`) are simply -re-exported unchanged. But there are also some differences: +Generally, the API exposed by :mod:`trio.socket` mirrors that of the +standard library :mod:`socket` module. Most constants (like +``SOL_SOCKET``) and simple utilities (like :func:`~socket.inet_aton`) +are simply re-exported unchanged. But there are also some differences: -All functions that return sockets (e.g. :func:`socket.socket`, -:func:`socket.socketpair`, ...) are modified to return trio sockets -instead. In addition, there is a new function to directly convert a -standard library socket into a trio socket: +All functions that return socket objects (e.g. :func:`socket.socket`, +:func:`socket.socketpair`, ...) are modified to return trio socket +objects instead. In addition, there is a new function to directly +convert a standard library socket into a trio socket: .. autofunction:: from_stdlib_socket @@ -128,9 +295,7 @@ Socket objects to port hijacking attacks `__. - 2. ``TCP_NODELAY`` is enabled by default. - - 3. ``IPV6_V6ONLY`` is disabled, i.e., by default on dual-stack + 2. ``IPV6_V6ONLY`` is disabled, i.e., by default on dual-stack hosts a ``AF_INET6`` socket is able to communicate with both IPv4 and IPv6 peers, where the IPv4 peers appear to be in the `"IPv4-mapped" portion of IPv6 address space @@ -143,10 +308,6 @@ Socket objects This makes trio applications behave more consistently across different environments. - 4. On platforms where it's supported (recent Linux and recent - MacOS), ``TCP_NOTSENT_LOWAT`` is enabled with a reasonable - buffer size (currently 16 KiB). - See `issue #72 `__ for discussion of these defaults. @@ -163,12 +324,6 @@ Socket objects `Not implemented yet! `__ - The following methods are *not* provided: - - * :meth:`~socket.socket.send`: This method has confusing semantics - hidden under a friendly name, and makes it too easy to create - subtle bugs. Use :meth:`sendall` instead. - The following methods are identical to their equivalents in :func:`socket.socket`, except async, and the ones that take address arguments require pre-resolved addresses: @@ -180,6 +335,7 @@ Socket objects * :meth:`~socket.socket.recvfrom_into` * :meth:`~socket.socket.recvmsg` (if available) * :meth:`~socket.socket.recvmsg_into` (if available) + * :meth:`~socket.socket.send` * :meth:`~socket.socket.sendto` * :meth:`~socket.socket.sendmsg` (if available) @@ -204,37 +360,6 @@ Socket objects * :meth:`~socket.socket.get_inheritable` -The abstract Stream API ------------------------ - -(this is currently more of a sketch than something actually useful, -`see issue #73 `__) - -.. currentmodule:: trio - -.. autoclass:: AsyncResource - :members: - :undoc-members: - -.. autoclass:: SendStream - :members: - :undoc-members: - -.. autoclass:: RecvStream - :members: - :undoc-members: - -.. autoclass:: Stream - :members: - :undoc-members: - - -TLS support ------------ - -`Not implemented yet! `__ - - Async disk I/O -------------- diff --git a/docs/source/reference-testing.rst b/docs/source/reference-testing.rst index c45e75492c..66c138f1ea 100644 --- a/docs/source/reference-testing.rst +++ b/docs/source/reference-testing.rst @@ -71,6 +71,92 @@ Inter-task ordering .. autofunction:: wait_all_tasks_blocked +.. _testing-streams: + +Streams +------- + +Virtual, controllable streams +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +One particularly challenging problem when testing network protocols is +making sure that your implementation can handle data whose flow gets +broken up in weird ways and arrives with weird timings: localhost +connections tend to be much better behaved than real networks, so if +you only test on localhost then you might get bitten later. To help +you out, trio provides some fully in-memory implementations of the +stream interfaces (see :ref:`abstract-stream-api`), that let you write +all kinds of interestingly evil tests. + +There are a few pieces here, so here's how they fit together: + +:func:`memory_stream_pair` gives you a pair of connected, +bidirectional streams. It's like :func:`socket.socketpair`, but +without any involvement from that pesky operating system and its +networking stack. + +To build a bidirectional stream, :func:`memory_stream_pair` uses +two unidirectional streams. It gets these by calling +:func:`memory_stream_one_way_pair`. + +:func:`memory_stream_one_way_pair`, in turn, is implemented using the +low-ish level classes :class:`MemorySendStream` and +:class:`MemoryReceiveStream`. These are implementations of (you +guessed it) :class:`trio.abc.SendStream` and +:class:`trio.abc.ReceiveStream` that on their own, aren't attached to +anything – "sending" and "receiving" just put data into and get data +out of a private internal buffer that each object owns. They also have +some interesting hooks you can set, that let you customize the +behavior of their methods. This is where you can insert the evil, if +you want it. :func:`memory_stream_one_way_pair` takes advantage of +these hooks in a relatively boring way: it just sets it up so that +when you call ``sendall``, or when you close the send stream, then it +automatically triggers a call to :func:`memory_stream_pump`, which is +a convenience function that takes data out of a +:class:`MemorySendStream`\´s buffer and puts it into a +:class:`MemoryReceiveStream`\´s buffer. But that's just the default – +you can replace this with whatever arbitrary behavior you want. + +Trio also provides some specialized functions for testing completely +**un**\buffered streams: :func:`lockstep_stream_one_way_pair` and +:func:`lockstep_stream_pair`. These aren't customizable, but they do +exhibit an extreme kind of behavior that's otherwise hard to +implement. + + +API details +~~~~~~~~~~~ + +.. autoclass:: MemorySendStream + :members: + +.. autoclass:: MemoryReceiveStream + :members: + +.. autofunction:: memory_stream_pump + +.. autofunction:: memory_stream_one_way_pair + +.. autofunction:: memory_stream_pair + +.. autofunction:: lockstep_stream_one_way_pair + +.. autofunction:: lockstep_stream_pair + + +Testing custom stream implementations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Trio also provides some functions to help you test your custom stream +implementations: + +.. autofunction:: check_one_way_stream + +.. autofunction:: check_two_way_stream + +.. autofunction:: check_half_closeable_stream + + Testing checkpoints -------------------- diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 63f5880346..def13e2f78 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -855,9 +855,13 @@ thing, and then we'll discuss the pieces: .. literalinclude:: tutorial/echo-server-low-level.py :linenos: -The actual echo server implementation should be fairly familiar at -this point. Each incoming connection from an echo client gets handled -by its own dedicated task, running the ``echo_server`` function: +Let's start with ``echo_server``. As we'll see below, each time an +echo client connects, our server will spawn a child task running +``echo_server``; there might be lots of these running at once if lots +of clients are connected. Its job is to read data from its particular +client, and then echo it back. It should be pretty straightforward to +understand, because it uses the same socket functions we saw in the +last section: .. literalinclude:: tutorial/echo-server-low-level.py :linenos: @@ -922,8 +926,8 @@ tasks, we end up with a task tree that looks like: ┆ This lets ``parent`` focus on supervising the children, -``echo_listener`` focus on listening for new connections, each -``echo_server`` call will handle a single client. +``echo_listener`` focus on listening for new connections, and each +``echo_server`` focus on handling a single client. Once we know this trick, the listener code becomes pretty straightforward: diff --git a/notes-to-self/server.crt b/notes-to-self/server.crt new file mode 100644 index 0000000000..9c58d8e65b --- /dev/null +++ b/notes-to-self/server.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBjCCAe4CCQDq+3W9D8C4ejANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE3MDMxOTAzMDk1MVoXDTE4MDMxOTAzMDk1MVowRTELMAkG +A1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0 +IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AOwDDFVh8pvIrhZtIIX6pb3/PO5SM3rWsfoyyHi73GxemIiEHfYEjMKN8Eo10jUv +4G0n8VlrrmuhGR+UuHY6jCxjoCuYWszQhwBZBaeGE24ydtO/RE24yhNsJHPQWXMe +TL4mg1EBjJYXTwNhd7SwgCpkBQ+724ZJg+CmiPuYhVLdvjjUUmwiSbeueyULIPEJ +G1EWkKdU5pYtyyTZoc0x2YEjes3YNWY563yk+RljvidFBMyAX8N3fF4yrCCHDeY6 +UPdpXry/BJcEJm7PY2lMhbL71T6499qKnmSaWyJjm+KqbXSEYXoWDVBBvg5pR9Ia +XSoJ1MTfJ8eYnZDs5mETYDkCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEApaW8WiKA +3yDOUVzgwkeX3HxvxfhxtMTPBmO1M8YgX1yi+URkkKakc6bg3XW1saQrxBkXWwBr +81Atd0tOLwHsC1HPd7Y5Q/1LKZiYFq2Sva6eZfeedRF/0f/SQC+rSvZNI5DIVPS4 +jW/EpyMKIeerIyWeFXz0/NWcYLCDWN6m2iDtR3m98bJcqSdUemLgyR13EAWsaVZ7 +dB6nkwGl9e78SOIHeGYg1Fb0B7IN2Tqw2tO3Xn0mzhvqs65OYuYo4pB0FzxiySAB +q2nrgu6kGhkQw/RQ8QJ5MYjydYqCU0I4Qve1W7RoUxRnJvxJrMuvcdlMeboASKNl +L7YQurFGvAAiZQ== +-----END CERTIFICATE----- diff --git a/notes-to-self/server.csr b/notes-to-self/server.csr new file mode 100644 index 0000000000..f0fbc3829d --- /dev/null +++ b/notes-to-self/server.csr @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIICijCCAXICAQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx +ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN +AQEBBQADggEPADCCAQoCggEBAOwDDFVh8pvIrhZtIIX6pb3/PO5SM3rWsfoyyHi7 +3GxemIiEHfYEjMKN8Eo10jUv4G0n8VlrrmuhGR+UuHY6jCxjoCuYWszQhwBZBaeG +E24ydtO/RE24yhNsJHPQWXMeTL4mg1EBjJYXTwNhd7SwgCpkBQ+724ZJg+CmiPuY +hVLdvjjUUmwiSbeueyULIPEJG1EWkKdU5pYtyyTZoc0x2YEjes3YNWY563yk+Rlj +vidFBMyAX8N3fF4yrCCHDeY6UPdpXry/BJcEJm7PY2lMhbL71T6499qKnmSaWyJj +m+KqbXSEYXoWDVBBvg5pR9IaXSoJ1MTfJ8eYnZDs5mETYDkCAwEAAaAAMA0GCSqG +SIb3DQEBCwUAA4IBAQC+LhkPmCjxk5Nzn743u+7D/YzNhjv8Xv4aGUjjNyspVNso +tlCAWkW2dWo8USvQrMUz5yl6qj6QQlg0QaYfaIiK8pkGz4s+Sh1plz1Eaa7QDK4O +0wmtP6KkJyQW561ZY8sixS1DevKOmsp2Pa9fWU/vqKfzRv85A975XNadp6hkxXd7 +YOZCrSZjTnakpQvKoItvT9Xk7yKP6BI6h/03XORscbW/HyvLGoVLdE80yIkmjSot +3JXxHspT27bWNWhz/Slph3UFaVyOVGXFTAqkLDZ3OISMnuC+q/t38EHYkR1aev/l +4WogCtlWkFZ3bmhmlhJrH/bdTEkM6WopwoC6bczh +-----END CERTIFICATE REQUEST----- diff --git a/notes-to-self/server.key b/notes-to-self/server.key new file mode 100644 index 0000000000..c0ba0b8582 --- /dev/null +++ b/notes-to-self/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEA7AMMVWHym8iuFm0ghfqlvf887lIzetax+jLIeLvcbF6YiIQd +9gSMwo3wSjXSNS/gbSfxWWuua6EZH5S4djqMLGOgK5hazNCHAFkFp4YTbjJ2079E +TbjKE2wkc9BZcx5MviaDUQGMlhdPA2F3tLCAKmQFD7vbhkmD4KaI+5iFUt2+ONRS +bCJJt657JQsg8QkbURaQp1Tmli3LJNmhzTHZgSN6zdg1ZjnrfKT5GWO+J0UEzIBf +w3d8XjKsIIcN5jpQ92levL8ElwQmbs9jaUyFsvvVPrj32oqeZJpbImOb4qptdIRh +ehYNUEG+DmlH0hpdKgnUxN8nx5idkOzmYRNgOQIDAQABAoIBABuus9ij93fsTvcU +b7cnUh95+6ScgatL2W5WXItExbL0WYHRtU3w9K2xRlj9/Rz98536DGYHqlq3d6Hr +qMM9VMm0GcpjQWs6nksdJfujT04inytxCMrw/MrQaWooKwXErQ20qLxsqRfFvh/Q +Y+EOvsm6F5nj1/jlUJGeFv0jw6eXXxH6bqVUVYIaVCpAMB5Sm8caQ4dAI9UESZJv +vuucT24iSyV8vp060L1tNKgRUr5e2CMfbucauZh0nLALPAyu1I07Ce62q9wLLw66 +c2FLHcZBkTGvL0bPe89ttJJuK0jttHV6GQ/OneytezZFxLw1DMsG3VxzbXt2X7AN +noGzrDECgYEA/fnK0xlNir9bTNmOID42GoVUYF6iYWUf//rRlCKsRofSlPjTZDJK +grl/plTBKDE6qDDEkB1mrEkJufqP3slyq66NfkP0NLoo+PFkGSnsbvUvFNYwcYvH +7w2NWo/GvM4DJRqHvrETryBQwQtBJFsq9biWd3+hNCXYrhawKGqbzw0CgYEA7eSa +T6zIdmvszG5x1XzQ3k29GwUA4SLL1YfV2pnLxoMZgW5Q6T+cOCpJdEuL0BXCNelP +gk0gNXNvCzylwVC0BbpefFaJYsWK6gVg1EwDkiZcGx4FnKd0TWYer6RWrZ9cVohT +eNwix9kKVef7chf+2006eE1O8D0UYwZMpGifqt0CgYAKjmtjwtV6QuHkm9ZQeMV+ +7LPJHaXaLn3aAe7cHWTTuamDD6SZsY1vSY6It1Uf+ovZmc1RwCcYWiDRXhzEwdLG +WAcBjImF94bkcgQbF6cAJajDUPPKhGjXAtUxQnCcQGPZEvU5c9rBmLJCk9ktTazH +cdivNtrYdApBkifYRjYbsQKBgDZl0ctqTSSXJTzR/IG+2twalqV5DWxt0oJvXz1v +caNhExH/sczEWOqW8NkA9WWNtC0zvpSjIjxWuwuswJJl6+Rra3OvLhdB6LP+qteg +0ig3UVR6FvptaDDSqy2qvI9TI4A+CChY3jMotC5Ur7C1P/fRvw8HToesz96c8CWg +LvKZAoGAS4VW2YaYVlN14mkKGtHbZq0XWZGursAwYJn6H5jBCjl+89tsFzfCnTER +hZFH+zs281v1bVbCFws8lWrZg8AZh55dG8CcLtuCkTyXJ/aAdlan0+FmXV60+RLP +Z1TyykQG/oDgO1Z+5GrcN2b0FOFaSbH2NRzRlhyOI63yTQi4lT8= +-----END RSA PRIVATE KEY----- diff --git a/notes-to-self/server.orig.key b/notes-to-self/server.orig.key new file mode 100644 index 0000000000..28fac173ff --- /dev/null +++ b/notes-to-self/server.orig.key @@ -0,0 +1,30 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: DES-EDE3-CBC,D9C5B2214855387C + +gYuCZiXsU74IOErbOGOmc1y/BFP1N7UuRO19tidUrq1O6sreJSAVKRibIynAwmXj +p5xvaAnBBIZIH6X7I2vduJgtUeeyvy5yxR98pD6liRKDxFaVD+O1m5IZxSbAs2De +olk4Zlv3YULpbVF6Ud+QuLmgbqfmT+8NVGm4MwRey7Gkj+LEfGrNjpfgLqNRIaUZ +XDPQh9HLZYCsAbz5OeRHwJwawLO74fvWBkFjsQyoWLgJzqZFmt15SyrRufBeKYP6 +oKKemsiW8/A2+i6Rb1vHYOJJ6c9jeeHPJkZSbfNWf4/Z702DAMIisbHmTCQzsUrX +178d2Z7sDKcuDCQ1EInnLRb3YET/V83wGDWyHxWepaHLWHd5S7tFbsqZFsXxIuYM +lcZZVSPsOLnG2SozZK+Tr2RX7jkI4Kmfh0RDgtKBYQQopZjRSFUG2hvvH3EIxVIf +JyUG8AA5RT1J9tkcSJJ5MS40So7i3eyAuZXuYVSkuDai/mu2IUU8vYnwB8a1psU1 +P2CGUj2AFopvMAfSOYIPGHpcIn+lvxuXUdczR/Yikp/BhGT+diJjP68CUsMBdyq7 +pcVmMVyQPVpcsMag3IXGgIAF1v1GhO3zDMd1uXA1lyrHQa6CEah3z+4WFSWwYZ0I +OZz5qM9bnfKoAQesp+xmcZhs8cbrblMRVDkWiPUixxKVJk3eBUsMoa1WYq/2u0ly +EgvNlY39B/3eiLi+k+S6gVGT8a4AP6n4RuxPD4g0A79bI1LpC3xU6vcKV/GyIP69 +t2DHR2q9UDEiRj9DxjuTzxew7eRX8ktD7DhYV06BxrgQIRRiL9MrZRKGuqzXcMP/ +kWY71ioFZJ1ViZkpy7bEsYrpF2XBjGge3We0s2udnrY3r3ogxOjtZiT27e2zEbXD +T59C3gecuEzCSCZ3eQtdcVC9m3RdHMTNNKvqmTVFPgfGOoM5u2gG+rYjhetbpTDB +T5drcEEAcV11DHuokU4tlqOdIWdLuBsK3xgO98JasEr1LyYJT1fnjB+6AbhjfSS2 +p5TPekmSwaZbaBzwfP1xmhINJm388GCROXMkc9iLAWN9npHhssfMAA2WMXqDTgSt +34oUnHgLGmvOm5HzJE/tTR1WP1Rye4nKNLwsbk2x7WXxqcNUPYc+OVmZbsl/R5Gz +3zRHPts01mT/eaSfqj1wkJgpYtDQLPO+V1fc2pDgJmQMYyr7OCLI6I9GJBlB8gVq +aemv0TMi3/eUVyJRaAHxAAi7YMsrSkuKUrsbLfIIRgViaEy+1stFa9iWiHJT0DKJ +0fOqtwcL8OYJURyG/D29yUP5qBJcrFuIYk8uI1wtfDNMeAI4LWoWwMhBLtB6POY+ +a/qmMewFzrGGsR9R0ptwtlplhvJVeArfLYGngnbgBV4vwchjLQTR2RMouZWlwRH9 +NWX6EqsIk/zzYvu+o7sBC2839D3GCPQMmgKqSWwmlf2a76mqZk2duTO9+0v6+e+F +Qc44ndLFE+mEibXkm9PMHvPsXOUdC4KPpugC/aZbn4OCqVd3eSl7k+PZGKZua6IJ +ybhosNzQc4lg25K7iMxRXpK5WrOgEXSAA3kUquDRTWHshpz/Avwbgw== +-----END RSA PRIVATE KEY----- diff --git a/notes-to-self/ssl-handshake/ssl-handshake.py b/notes-to-self/ssl-handshake/ssl-handshake.py new file mode 100644 index 0000000000..81d875be6a --- /dev/null +++ b/notes-to-self/ssl-handshake/ssl-handshake.py @@ -0,0 +1,131 @@ +import ssl +import socket +import threading +from contextlib import contextmanager + +BUFSIZE = 4096 + +server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +server_ctx.load_cert_chain("trio-test-1.pem") + +def _ssl_echo_serve_sync(sock): + try: + wrapped = server_ctx.wrap_socket(sock, server_side=True) + while True: + data = wrapped.recv(BUFSIZE) + if not data: + wrapped.unwrap() + return + wrapped.sendall(data) + except BrokenPipeError: + pass + +@contextmanager +def echo_server_connection(): + client_sock, server_sock = socket.socketpair() + with client_sock, server_sock: + t = threading.Thread( + target=_ssl_echo_serve_sync, args=(server_sock,), daemon=True) + t.start() + + yield client_sock + +class ManuallyWrappedSocket: + def __init__(self, ctx, sock, **kwargs): + self.incoming = ssl.MemoryBIO() + self.outgoing = ssl.MemoryBIO() + self.obj = ctx.wrap_bio(self.incoming, self.outgoing, **kwargs) + self.sock = sock + + def _retry(self, fn, *args): + finished = False + while not finished: + want_read = False + try: + ret = fn(*args) + except ssl.SSLWantReadError: + want_read = True + except ssl.SSLWantWriteError: + # can't happen, but if it did this would be the right way to + # handle it anyway + pass + else: + finished = True + # do any sending + data = self.outgoing.read() + if data: + self.sock.sendall(data) + # do any receiving + if want_read: + data = self.sock.recv(BUFSIZE) + if not data: + self.incoming.write_eof() + else: + self.incoming.write(data) + # then retry if necessary + return ret + + def do_handshake(self): + self._retry(self.obj.do_handshake) + + def recv(self, bufsize): + return self._retry(self.obj.read, bufsize) + + def sendall(self, data): + self._retry(self.obj.write, data) + + def unwrap(self): + self._retry(self.obj.unwrap) + return self.sock + + +def wrap_socket_via_wrap_socket(ctx, sock, **kwargs): + return ctx.wrap_socket(sock, do_handshake_on_connect=False, **kwargs) + +def wrap_socket_via_wrap_bio(ctx, sock, **kwargs): + return ManuallyWrappedSocket(ctx, sock, **kwargs) + + +for wrap_socket in [ + wrap_socket_via_wrap_socket, + wrap_socket_via_wrap_bio, +]: + print("\n--- checking {} ---\n".format(wrap_socket.__name__)) + + print("checking with do_handshake + correct hostname...") + with echo_server_connection() as client_sock: + client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") + wrapped = wrap_socket( + client_ctx, client_sock, server_hostname="trio-test-1.example.org") + wrapped.do_handshake() + wrapped.sendall(b"x") + assert wrapped.recv(1) == b"x" + wrapped.unwrap() + print("...success") + + print("checking with do_handshake + wrong hostname...") + with echo_server_connection() as client_sock: + client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") + wrapped = wrap_socket( + client_ctx, client_sock, server_hostname="trio-test-2.example.org") + try: + wrapped.do_handshake() + except Exception: + print("...got error as expected") + else: + print("??? no error ???") + + print("checking withOUT do_handshake + wrong hostname...") + with echo_server_connection() as client_sock: + client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") + wrapped = wrap_socket( + client_ctx, client_sock, server_hostname="trio-test-2.example.org") + # We forgot to call do_handshake + # But the hostname is wrong so something had better error out... + sent = b"x" + print("sending", sent) + wrapped.sendall(sent) + got = wrapped.recv(1) + print("got:", got) + assert got == sent + print("!!!! successful chat with invalid host! we have been haxored!") diff --git a/notes-to-self/ssl-handshake/trio-test-1.pem b/notes-to-self/ssl-handshake/trio-test-1.pem new file mode 100644 index 0000000000..a0c1b773f9 --- /dev/null +++ b/notes-to-self/ssl-handshake/trio-test-1.pem @@ -0,0 +1,64 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDZ8yz1OrHX7aHp +Erfa1ds8kmYfqYomjgy5wDsGdb8i1gF4uxhHCRDQtNZANVOVXI7R3TMchA1GMxzA +ZYDuBDuEUsqTktbTEBNb4GjOhyMu1fF4dX/tMxf7GB+flTx178eE2exTOZLmSmBa +2laoDVe3CrBAYE7nZtBF630jKKKMsUuIl0CbFRHajpoqM3e3CeCo4KcbBzgujRA3 +AsVV6y5qhMH2zqLkOYaurVUfEkdjqoHFgj1VbjWpkTbrXAxPwW6v/uZK056bHgBg +go03RyWexaPapsF2oUm2JNdSN3z7MP0umKphO2n9icyGt9Bmkm2AKs3dA45VLPXh ++NohluqJAgMBAAECggEARlfWCtAG1ko8F52S+W5MdCBMFawCiq8OLGV+p3cZWYT4 +tJ6uFz81ziaPf+m2MF7POazK8kksf5u/i9k245s6GlseRsL90uE9XknvibjUAinK +5bYGs+fptYDzs+3WtbnOC3LKc5IBd5JJxwjxLwwfY1RvzldHIChu0CJRISfcTsvR +occ8hXdeft7svNymvTuwQd05u1yjzL4RwF8Be76i17j5+jDsrAaUKdxxwGNAyOU7 +OKrUY6G851T6NUGgC19iXAJ1wN9tVGIR5QOs3J/s6dCctnX5tN8Di7prkXCKvVlm +vhpC8XWWG+c3LhS90wmEBvKS0AfUeoPDHxMOLyzKgQKBgQD07lZRO0nsc38+PVaI +NrvlP90Q8OgbwMIC52jmSZK3b5YSh3TrllsbCg6hzUk1SAJsa3qi7B1vq36Fd+rG +LGDRW9xY0cfShLhzqvZWi45zU/RYnEcWHOuXQshLikx1DWUpg2KbLSVT2/lyvzmn +QgM1Te8CSxW5vrBRVfluXoJuEwKBgQDjzLAbwk/wdjITKlQtirtsJEzWi3LGuUrg +Z2kMz+0ztUU5d1oFL9B5xh0CwK8bpK9kYnoVZSy/r5+mGHqyz1eKaDdAXIR13nC0 +g7aZbTZzbt2btvuNZc3NCzRffHF3sCqp8a+oCryHyITjZcA+WYeU8nG0TQ5O8Zgr +Skbo1JGocwKBgQC4jCx1oFqe0pd5afYdREBnB6ul7B63apHEZmBfw+fMV0OYSoAK +Uovq37UOrQMQJmXNE16gC5BSZ8E5B5XaI+3/UVvBgK8zK9VfMd3Sb+yxcPyXF4lo +W/oXSrZoVJgvShyDHv/ZNDb/7KsTjon+QHryWvpPnAuOnON1JXZ/dq6ICQKBgCZF +AukG8esR0EPL/qxP/ECksIvyjWu5QU0F0m4mmFDxiRmoZWUtrTZoBAOsXz6johuZ +N61Ue/oQBSAgSKy1jJ1h+LZFVLOAlSqeXhTUditaWryINyaADdz+nuPTwjQ7Uk+O +nNX8R8P/+eNB+tP+snphaJzDvT2h9NCA//ypiXblAoGAJoLmotPI+P3KIRVzESL0 +DAsVmeijtXE3H+R4nwqUDQbBbFKx0/u2pbON+D5C9llaGiuUp9H+awtwQRYhToeX +CNguwWrcpuhFOCeXDHDWF/0NIZYD2wBMxjF/eUarvoLaT4Gi0yyWh5ExIKOW4bFk +EojUPSJ3gomOUp5bIFcSmSU= +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICrzCCAZcCAQEwDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UECgwMVHJpbyB0ZXN0 +IENBMCAXDTE3MDQwOTEwMDcyMVoYDzIyOTEwMTIyMTAwNzIxWjAiMSAwHgYDVQQD +DBd0cmlvLXRlc3QtMS5leGFtcGxlLm9yZzCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBANnzLPU6sdftoekSt9rV2zySZh+piiaODLnAOwZ1vyLWAXi7GEcJ +ENC01kA1U5VcjtHdMxyEDUYzHMBlgO4EO4RSypOS1tMQE1vgaM6HIy7V8Xh1f+0z +F/sYH5+VPHXvx4TZ7FM5kuZKYFraVqgNV7cKsEBgTudm0EXrfSMoooyxS4iXQJsV +EdqOmiozd7cJ4KjgpxsHOC6NEDcCxVXrLmqEwfbOouQ5hq6tVR8SR2OqgcWCPVVu +NamRNutcDE/Bbq/+5krTnpseAGCCjTdHJZ7Fo9qmwXahSbYk11I3fPsw/S6YqmE7 +af2JzIa30GaSbYAqzd0DjlUs9eH42iGW6okCAwEAATANBgkqhkiG9w0BAQsFAAOC +AQEAlRNA96H88lVnzlpQUYt0pwpoy7B3/CDe8Uvl41thKEfTjb+SIo95F4l+fi+l +jISWSonAYXRMNqymPMXl2ir0NigxfvvrcjggER3khASIs0l1ICwTNTv2a40NnFY6 +ZjTaBeSZ/lAi7191AkENDYvMl3aGhb6kALVIbos4/5LvJYF/UXvQfrjriLWZq/I3 +WkvduU9oSi0EA4Jt9aAhblsgDHMBL0+LU8Nl1tgzy2/NePcJWjzBRQDlF8uxCQ+2 +LesZongKQ+lebS4eYbNs0s810h8hrOEcn7VWn7FfxZRkjeaKIst2FCHmdr5JJgxj +8fw+s7l2UkrNURAJ4IRNQvPB+w== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDBjCCAe6gAwIBAgIJAIUF+wna+nuzMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV +BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MDkxMDA3MjFaGA8yMjkxMDEyMjEwMDcy +MVowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAyhE82Cbq+2c2f9M+vj2f9v+z+0+bMZDUVPSXhBDiRdKubt+K +f9vY+ZH3ze1sm0iNgO6xU3OsDTlzO5z0TpsvEEbs0wgsJDUXD7Y8Fb1zH2jaVCro +Y6KcVfFZvD96zsVCnZy0vMsYJw20iIL0RNCtr17lXWVxd17OoVy91djFD9v/cixu +LRIr+N7pa8BDLUQUO/g0ui9YSC9Wgf67mr93KXKPGwjTHBGdjeZeex198j5aZjZR +lkPH/9g5d3hP7EI0EAIMDVd4dvwNJgZzv+AZINbKLAkQyE9AAm+xQ7qBSvdfAvKq +N/fwaFevmyrxUBcfoQxSpds8njWDb3dQzCn7ywIDAQABo1MwUTAdBgNVHQ4EFgQU +JiilveozF8Qpyy2fS3wV4foVRCswHwYDVR0jBBgwFoAUJiilveozF8Qpyy2fS3wV +4foVRCswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAkcthPjkJ ++npvMKyQtUx7CSVcT4Ar0jHrvfPSg17ipyLv+MhSTIbS2VwhSYXxNmu+oBWKuYUs +BnNxly3+SOcs+dTP3GBMng91SBsz5hhbP4ws8uUtCvYJauzeHbeY67R14RT8Ws/b +mP6HDiybN7zy6LOKGCiz+sCoJqVZG/yBYO87iQsTTyNttgoG27yUSvzP07EQwUa5 +F9dI9Wn+4b5wP2ofMCu3asTbKXjfFbz3w5OkRgpGYhC4jhDdOw/819+01R9//GrM +54Gme03yDAAM7nGihr1Xtld3dp2gLuqv0WgxKBqvG5X+nCbr2WamscAP5qz149vo +y6Hq6P4mm2GmZw== +-----END CERTIFICATE----- diff --git a/notes-to-self/ssl-handshake/trio-test-CA.pem b/notes-to-self/ssl-handshake/trio-test-CA.pem new file mode 100644 index 0000000000..9bf34001b2 --- /dev/null +++ b/notes-to-self/ssl-handshake/trio-test-CA.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBjCCAe6gAwIBAgIJAIUF+wna+nuzMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV +BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MDkxMDA3MjFaGA8yMjkxMDEyMjEwMDcy +MVowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAyhE82Cbq+2c2f9M+vj2f9v+z+0+bMZDUVPSXhBDiRdKubt+K +f9vY+ZH3ze1sm0iNgO6xU3OsDTlzO5z0TpsvEEbs0wgsJDUXD7Y8Fb1zH2jaVCro +Y6KcVfFZvD96zsVCnZy0vMsYJw20iIL0RNCtr17lXWVxd17OoVy91djFD9v/cixu +LRIr+N7pa8BDLUQUO/g0ui9YSC9Wgf67mr93KXKPGwjTHBGdjeZeex198j5aZjZR +lkPH/9g5d3hP7EI0EAIMDVd4dvwNJgZzv+AZINbKLAkQyE9AAm+xQ7qBSvdfAvKq +N/fwaFevmyrxUBcfoQxSpds8njWDb3dQzCn7ywIDAQABo1MwUTAdBgNVHQ4EFgQU +JiilveozF8Qpyy2fS3wV4foVRCswHwYDVR0jBBgwFoAUJiilveozF8Qpyy2fS3wV +4foVRCswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAkcthPjkJ ++npvMKyQtUx7CSVcT4Ar0jHrvfPSg17ipyLv+MhSTIbS2VwhSYXxNmu+oBWKuYUs +BnNxly3+SOcs+dTP3GBMng91SBsz5hhbP4ws8uUtCvYJauzeHbeY67R14RT8Ws/b +mP6HDiybN7zy6LOKGCiz+sCoJqVZG/yBYO87iQsTTyNttgoG27yUSvzP07EQwUa5 +F9dI9Wn+4b5wP2ofMCu3asTbKXjfFbz3w5OkRgpGYhC4jhDdOw/819+01R9//GrM +54Gme03yDAAM7nGihr1Xtld3dp2gLuqv0WgxKBqvG5X+nCbr2WamscAP5qz149vo +y6Hq6P4mm2GmZw== +-----END CERTIFICATE----- diff --git a/notes-to-self/sslobject.py b/notes-to-self/sslobject.py new file mode 100644 index 0000000000..cfac98676e --- /dev/null +++ b/notes-to-self/sslobject.py @@ -0,0 +1,76 @@ +from contextlib import contextmanager +import ssl + +client_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) +client_ctx.check_hostname = False +client_ctx.verify_mode = ssl.CERT_NONE + +cinb = ssl.MemoryBIO() +coutb = ssl.MemoryBIO() +cso = client_ctx.wrap_bio(cinb, coutb) + +server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +server_ctx.load_cert_chain("server.crt", "server.key", "xxxx") +sinb = ssl.MemoryBIO() +soutb = ssl.MemoryBIO() +sso = server_ctx.wrap_bio(sinb, soutb, server_side=True) + +@contextmanager +def expect(etype): + try: + yield + except etype: + pass + else: + raise AssertionError("expected {}".format(etype)) + +with expect(ssl.SSLWantReadError): + cso.do_handshake() +assert not cinb.pending +assert coutb.pending + +with expect(ssl.SSLWantReadError): + sso.do_handshake() +assert not sinb.pending +assert not soutb.pending + +# A trickle is not enough +# sinb.write(coutb.read(1)) +# with expect(ssl.SSLWantReadError): +# cso.do_handshake() +# with expect(ssl.SSLWantReadError): +# sso.do_handshake() + +sinb.write(coutb.read()) +# Now it should be able to respond +with expect(ssl.SSLWantReadError): + sso.do_handshake() +assert soutb.pending + +cinb.write(soutb.read()) +with expect(ssl.SSLWantReadError): + cso.do_handshake() + +sinb.write(coutb.read()) +# server done! +sso.do_handshake() +assert soutb.pending + +# client done! +cinb.write(soutb.read()) +cso.do_handshake() + +cso.write(b"hello") +sinb.write(coutb.read()) +assert sso.read(10) == b"hello" +with expect(ssl.SSLWantReadError): + sso.read(10) + +# cso.write(b"x" * 2 ** 30) +# print(coutb.pending) + +assert not coutb.pending +assert not cinb.pending +sso.do_handshake() +assert not coutb.pending +assert not cinb.pending diff --git a/setup.py b/setup.py index 92ad6c4ddb..9a45bc9550 100644 --- a/setup.py +++ b/setup.py @@ -78,6 +78,9 @@ # https://bitbucket.org/pypa/wheel/issues/181/bdist_wheel-silently-discards-pep-508 #"cffi; os_name == 'nt'", # "cffi is required on windows" ], + # This means, just install *everything* you see under trio/, even if it + # doesn't look like a source file, so long as it appears in MANIFEST.in: + include_package_data=True, # Quirky bdist_wheel-specific way: # https://wheel.readthedocs.io/en/latest/#defining-conditional-dependencies # also supported by pip and setuptools, as long as they're vaguely diff --git a/test-requirements.txt b/test-requirements.txt index d43ae6ee1c..4373f85907 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,3 +1,4 @@ pytest >= 3.0.6 # fixes a bug in handling async def test_* pytest-cov -ipython # for the IPython traceback integration tests +ipython # for the IPython traceback integration tests +pyOpenSSL # for the ssl tests diff --git a/trio/__init__.py b/trio/__init__.py index fc122d55ee..040b2fdf10 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -31,6 +31,7 @@ __all__.append(_symbol) del _symbol, _value + from ._timeouts import * __all__ += _timeouts.__all__ @@ -46,7 +47,11 @@ from ._signals import * __all__ += _signals.__all__ +from ._network import * +__all__ += _network.__all__ + # Imported by default from . import socket from . import abc +from . import ssl # Not imported by default: testing diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py index 589e5edda6..20f5235eb0 100644 --- a/trio/_core/_exceptions.py +++ b/trio/_core/_exceptions.py @@ -2,8 +2,8 @@ # Re-exported __all__ = [ - "TrioInternalError", "RunFinishedError", "WouldBlock", - "Cancelled", "PartialResult", + "TrioInternalError", "RunFinishedError", "WouldBlock", "Cancelled", + "ResourceBusyError", ] class TrioInternalError(Exception): @@ -75,7 +75,13 @@ class Cancelled(BaseException): Cancelled.__module__ = "trio" -@attr.s(slots=True, frozen=True) -class PartialResult: - # XX - bytes_sent = attr.ib() +class ResourceBusyError(Exception): + """Raised when a task attempts to use a resource that some other task is + already using, and this would lead to bugs and nonsense. + + For example, if two tasks try to send data through the same socket at the + same time, trio will raise :class:`ResourceBusyError` instead of letting + the data get scrambled. + + """ +ResourceBusyError.__module__ = "trio" diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 838a68b507..0b08760157 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -98,7 +98,8 @@ async def _epoll_wait(self, fd, attr_name): self._registered[fd] = EpollWaiters() waiters = self._registered[fd] if getattr(waiters, attr_name) is not None: - raise RuntimeError( + await _core.yield_briefly() + raise _core.ResourceBusyError( "another task is already reading / writing this fd") setattr(waiters, attr_name, _core.current_task()) self._update_registrations(fd, currently_registered) diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index 39885275ff..f09870a4d4 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -80,7 +80,7 @@ def current_kqueue(self): def monitor_kevent(self, ident, filter): key = (ident, filter) if key in self._registered: - raise RuntimeError( + raise _core.ResourceBusyError( "attempt to register multiple listeners for same " "ident/filter pair") q = _core.UnboundedQueue() @@ -95,7 +95,8 @@ def monitor_kevent(self, ident, filter): async def wait_kevent(self, ident, filter, abort_func): key = (ident, filter) if key in self._registered: - raise RuntimeError( + await _core.yield_briefly() + raise _core.ResourceBusyError( "attempt to register multiple listeners for same " "ident/filter pair") self._registered[key] = _core.current_task() diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 3820d9e83c..8725a08769 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -294,7 +294,7 @@ async def wait_overlapped(self, handle, lpOverlapped): if isinstance(lpOverlapped, int): lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped) if lpOverlapped in self._overlapped_waiters: - raise RuntimeError( + raise _core.ResourceBusyError( "another task is already waiting on that lpOverlapped") task = _core.current_task() self._overlapped_waiters[lpOverlapped] = task @@ -339,9 +339,11 @@ async def _wait_socket(self, which, sock): # sockets in another thread? And on unix we don't handle this case at # all), but hey, why not. if type(sock) is not stdlib_socket.socket: + await _core.yield_briefly() raise TypeError("need a stdlib socket") if sock in self._socket_waiters[which]: - raise RuntimeError( + await _core.yield_briefly() + raise _core.ResourceBusyError( "another task is already waiting to {} this socket" .format(which)) self._socket_waiters[which][sock] = _core.current_task() diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 263cbc4dee..d22ff10a1a 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -366,14 +366,22 @@ class Task: coro = attr.ib() _runner = attr.ib() name = attr.ib() + # Invariant: + # - for unfinished tasks, result is None + # - for finished tasks, result is a Result object result = attr.ib(default=None) - # tasks start out unscheduled, and unscheduled tasks have None here + # Invariant: + # - for unscheduled tasks, _next_send is None + # - for scheduled tasks, _next_send is a Result object + # Tasks start out unscheduled. _next_send = attr.ib(default=None) _abort_func = attr.ib(default=None) # Task-local values, see _local.py _locals = attr.ib(default=attr.Factory(dict)) + # these are counts of how many cancel/schedule points this task has + # executed, for assert{_no,}_yields # XX maybe these should be exposed as part of a statistics() method? _cancel_points = attr.ib(default=0) _schedule_points = attr.ib(default=0) @@ -955,7 +963,7 @@ def _deliver_ki_cb(self): @_public @_hazmat - async def wait_all_tasks_blocked(self, cushion=0.0): + async def wait_all_tasks_blocked(self, cushion=0.0, tiebreaker=0): """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a @@ -973,7 +981,9 @@ async def wait_all_tasks_blocked(self, cushion=0.0): then the one with the shortest ``cushion`` is the one woken (and the this task becoming unblocked resets the timers for the remaining tasks). If there are multiple tasks that have exactly the same - ``cushion``, then all are woken. + ``cushion``, then the one with the lowest ``tiebreaker`` value is + woken first. And if there are multiple tasks with the same ``cushion`` + and the same ``tiebreaker``, then all are woken. You should also consider :class:`trio.testing.Sequencer`, which provides a more explicit way to control execution ordering within a @@ -1010,7 +1020,7 @@ async def test_lock_fairness(): """ task = current_task() - key = (cushion, id(task)) + key = (cushion, tiebreaker, id(task)) self.waiting_for_idle[key] = task def abort(_): del self.waiting_for_idle[key] @@ -1201,7 +1211,7 @@ def run_impl(runner, async_fn, args): idle_primed = False if runner.waiting_for_idle: - cushion, _ = runner.waiting_for_idle.keys()[0] + cushion, tiebreaker, _ = runner.waiting_for_idle.keys()[0] if cushion < timeout: timeout = cushion idle_primed = True @@ -1229,7 +1239,7 @@ def run_impl(runner, async_fn, args): if not runner.runq and idle_primed: while runner.waiting_for_idle: key, task = runner.waiting_for_idle.peekitem(0) - if key[0] == cushion: + if key[:2] == (cushion, tiebreaker): del runner.waiting_for_idle[key] runner.reschedule(task) else: diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py index 2047e95644..b7d5697f87 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/tests/test_io.py @@ -145,8 +145,9 @@ async def test_double_read(socketpair, wait_readable): async with _core.open_nursery() as nursery: nursery.spawn(wait_readable, a) await wait_all_tasks_blocked() - with pytest.raises(RuntimeError): - await wait_readable(a) + with assert_yields(): + with pytest.raises(_core.ResourceBusyError): + await wait_readable(a) nursery.cancel_scope.cancel() @write_socket_test @@ -158,8 +159,9 @@ async def test_double_write(socketpair, wait_writable): async with _core.open_nursery() as nursery: nursery.spawn(wait_writable, a) await wait_all_tasks_blocked() - with pytest.raises(RuntimeError): - await wait_writable(a) + with assert_yields(): + with pytest.raises(_core.ResourceBusyError): + await wait_writable(a) nursery.cancel_scope.cancel() diff --git a/trio/_network.py b/trio/_network.py new file mode 100644 index 0000000000..e0bf67f3a8 --- /dev/null +++ b/trio/_network.py @@ -0,0 +1,165 @@ +# "High-level" networking interface + +import errno +from contextlib import contextmanager + +from . import _core +from . import socket as tsocket +from ._util import UnLock +from .abc import HalfCloseableStream +from ._streams import ClosedStreamError, BrokenStreamError + +__all__ = ["SocketStream", "socket_stream_pair"] + +_closed_stream_errnos = { + # Unix + errno.EBADF, + # Windows + errno.ENOTSOCK, +} + +@contextmanager +def _translate_socket_errors_to_stream_errors(): + try: + yield + except OSError as exc: + if exc.errno in _closed_stream_errnos: + raise ClosedStreamError( + "this socket was already closed") from None + else: + raise BrokenStreamError( + "socket connection broken: {}".format(exc)) from exc + +class SocketStream(HalfCloseableStream): + """An implementation of the :class:`trio.abc.HalfCloseableStream` + interface based on a raw network socket. + + Args: + sock (trio.socket.SocketType): The trio socket object to wrap. Must have + type ``SOCK_STREAM``, and be connected. + + By default, :class:`SocketStream` enables ``TCP_NODELAY``, and (on + platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with a + reasonable buffer size (currently 16 KiB) – see `issue #72 + `__ for discussion. You can + of course override these defaults by calling :meth:`setsockopt`. + + Once a :class:`SocketStream` object is constructed, it implements the full + :class:`trio.abc.HalfCloseableStream` interface. In addition, it provides + a few extra features: + + .. attribute:: socket + + The :class:`trio.socket.SocketType` object that this stream wraps. + + """ + def __init__(self, sock): + if not isinstance(sock, tsocket.SocketType): + raise TypeError("SocketStream requires trio.socket.SocketType") + if sock._real_type != tsocket.SOCK_STREAM: + raise ValueError("SocketStream requires a SOCK_STREAM socket") + try: + sock.getpeername() + except OSError: + err = ValueError("SocketStream requires a connected socket") + raise err from None + + self.socket = sock + self._send_lock = UnLock( + _core.ResourceBusyError, + "another task is currently sending data on this SocketStream") + + # Socket defaults: + + # Not supported on e.g. unix domain sockets + try: + self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True) + except OSError: + pass + + if hasattr(tsocket, "TCP_NOTSENT_LOWAT"): + try: + # 16 KiB is pretty arbitrary and could probably do with some + # tuning. (Apple is also setting this by default in CFNetwork + # apparently -- I'm curious what value they're using, though I + # couldn't find it online trivially. CFNetwork-129.20 source + # has no mentions of TCP_NOTSENT_LOWAT. This presentation says + # "typically 8 kilobytes": + # http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1 + # ). The theory is that you want it to be bandwidth * + # rescheduling interval. + self.setsockopt( + tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2 ** 14) + except OSError: + pass + + async def send_all(self, data): + if self.socket._did_SHUT_WR: + await _core.yield_briefly() + raise ClosedStreamError("can't send data after sending EOF") + with self._send_lock.sync: + with _translate_socket_errors_to_stream_errors(): + await self.socket.sendall(data) + + async def wait_send_all_might_not_block(self): + async with self._send_lock: + if self.socket.fileno() == -1: + raise ClosedStreamError + with _translate_socket_errors_to_stream_errors(): + await self.socket.wait_writable() + + async def send_eof(self): + async with self._send_lock: + # On MacOS, calling shutdown a second time raises ENOTCONN, but + # send_eof needs to be idempotent. + if self.socket._did_SHUT_WR: + return + with _translate_socket_errors_to_stream_errors(): + self.socket.shutdown(tsocket.SHUT_WR) + + async def receive_some(self, max_bytes): + if max_bytes < 1: + await _core.yield_briefly() + raise ValueError("max_bytes must be >= 1") + with _translate_socket_errors_to_stream_errors(): + return await self.socket.recv(max_bytes) + + def forceful_close(self): + self.socket.close() + + # graceful_close, __aenter__, __aexit__ inherited from HalfCloseableStream + # are OK + + def setsockopt(self, level, option, value): + """Set an option on the underlying socket. + + See :meth:`socket.socket.setsockopt` for details. + + """ + return self.socket.setsockopt(level, option, value) + + def getsockopt(self, level, option, buffersize=0): + """Check the current value of an option on the underlying socket. + + See :meth:`socket.socket.getsockopt` for details. + + """ + # This is to work around + # https://bitbucket.org/pypy/pypy/issues/2561 + # We should be able to drop it when the next PyPy3 beta is released. + if buffersize == 0: + return self.socket.getsockopt(level, option) + else: + return self.socket.getsockopt(level, option, buffersize) + + +def socket_stream_pair(): + """Returns a pair of connected :class:`SocketStream` objects. + + This is a convenience function that uses :func:`socket.socketpair` to + create the sockets, and then converts the result into + :class:`SocketStream`\s. + + """ + left, right = tsocket.socketpair() + return SocketStream(left), SocketStream(right) diff --git a/trio/_ssl.py b/trio/_ssl.py deleted file mode 100644 index f001374a6e..0000000000 --- a/trio/_ssl.py +++ /dev/null @@ -1,33 +0,0 @@ -# Use SSLObject to make a generic wrapper around Stream - -# SSL shutdown: -# - call unwrap() on the SSLSocket/SSLObject -# - this sends the "all done here" SSL message -# - but in many practical applications this is neither sent nor checked for, -# e.g. HTTPS usually ignores it: -# https://security.stackexchange.com/questions/82028/ssl-tls-is-a-server-always-required-to-respond-to-a-close-notify -# BUT it is important in some cases, so should be possible to handle -# properly. -# -# I think the answer is: close is synchronous, and the TLS Stream also has an -# async def unwrap() which sends the close_notify message. -# Possibly we should also default suppress_ragged_eofs to False, unlike the -# stdlib? not sure. - -# XX how closely should we match the stdlib API? -# - maybe suppress_ragged_eofs=False is a better default? -# - maybe check crypto folks for advice? -# - this is also interesting: https://bugs.python.org/issue8108#msg102867 - -# Definitely keep an eye on Cory's TLS API ideas on security-sig etc. - -# from ._stream import Stream - -# class SSLStream(Stream): -# async def unwrap(self): -# # does a clean shutdown (!) by calling SSLObject.unwrap(), sends the -# # resulting close_notify data, and *then* returns the underlying -# # stream. -# XX - -# # XX diff --git a/trio/_streams.py b/trio/_streams.py index 5ac0b58989..f94fdb4d85 100644 --- a/trio/_streams.py +++ b/trio/_streams.py @@ -4,129 +4,129 @@ import attr from . import _core +from .abc import HalfCloseableStream -__all__ = ["AsyncResource", "SendStream", "RecvStream", "Stream"] - -# The close API is a big question here. -# -# Technically, socket close can block if you set various weird -# lingering-related options. This doesn't seem very useful though. -# -# For other kinds of channels, though, the natural implementation definitely -# can block – e.g. TLS wants to send a goodbye message, and if we're tunneling -# over ssh or HTTP/2 e.g. then again closing requires sending some actual -# data. So we need a concept of a blocking close. -# -# BUT, if the other side is uncooperative, we can't necessarily block in -# close. So if a blocking close is cancelled, we need to do some sort of -# forceful cleanup before raising the exception. -# -# (Probably implementing H2-based streams will be a useful forcing function -# here to figure this out.) - -class AsyncResource(metaclass=abc.ABCMeta): - __slots__ = () - - @abc.abstractmethod - def forceful_close(self): - """Force an immediate close of this resource. +__all__ = [ + "BrokenStreamError", "ClosedStreamError", "StapledStream", +] - This will never block, but (depending on the resource in question) it - might be a "rude" shutdown. - """ - pass +class BrokenStreamError(Exception): + """Raised when an attempt to use a stream a stream fails due to external + circumstances. - async def graceful_close(self): - """Close this resource, gracefully. + For example, you might get this if you try to send data on a stream where + the remote side has already closed the connection. - This may block in order to perform a "graceful" shutdown (for example, - sending a message alerting the other side of a connection that it is - about to close). But, if cancelled, then it still *must* close the - underlying resource. + You *don't* get this error if *you* closed the stream – in that case you + get :class:`ClosedStreamError`. - Default implementation is to perform a :meth:`forceful_close` and then - execute a checkpoint. - """ - self.forceful_close() - await _core.yield_briefly() + This exception's ``__cause__`` attribute will often contain more + information about the underlying error. - async def __aenter__(self): - return self + """ + pass - async def __aexit__(self, *args): - await self.graceful_close() -# XX added in 3.6 -if hasattr(contextlib, "AbstractContextManager"): - contextlib.AbstractContextManager.register(AsyncResource) +class ClosedStreamError(Exception): + """Raised when an attempt to use a stream a stream fails because the + stream was already closed locally. -class SendStream(AsyncResource): - __slots__ = () + You *only* get this error if *your* code closed the stream object you're + attempting to use by calling + :meth:`~trio.abc.AsyncResource.graceful_close` or + similar. (:meth:`~trio.abc.SendStream.send_all` might also raise this if + you already called :meth:`~trio.abc.HalfCloseableStream.send_eof`.) + Therefore this exception generally indicates a bug in your code. - @abc.abstractmethod - async def sendall(self, data): - pass + If a problem arises elsewhere, for example due to a network failure or a + misbehaving peer, then you get :class:`BrokenStreamError` instead. - # This is only a hint, because in some cases we don't know (Windows), or - # we have only a noisy signal (TLS). And in the use cases this is included - # to account for, returning before it's actually writable is NBD, it just - # makes them slightly less efficient. - @abc.abstractmethod - async def wait_maybe_writable(self): - pass + """ + pass - @property - @abc.abstractmethod - def can_send_eof(self): - pass - @abc.abstractmethod - def send_eof(self): - pass +@attr.s(slots=True, cmp=False, hash=False) +class StapledStream(HalfCloseableStream): + """This class `staples `__ + together two unidirectional streams to make single bidirectional stream. -class RecvStream(AsyncResource): - __slots__ = () + Args: + send_stream (~trio.abc.SendStream): The stream to use for sending. + receive_stream (~trio.abc.ReceiveStream): The stream to use for + receiving. - @abc.abstractmethod - async def recv(self, max_bytes): - pass + Example: -class Stream(SendStream, RecvStream): - __slots__ = () + A silly way to make a stream that echoes back whatever you write to + it:: - @staticmethod - def staple(send_stream, recv_stream): - return StapledStream(send_stream=send_stream, recv_stream=recv_stream) + sock1, sock2 = trio.socket.socketpair() + echo_stream = StapledStream(SocketStream(sock1), SocketStream(sock2)) + await echo_stream.send_all(b"x") + assert await echo_stream.receive_some(1) == b"x" -@attr.s(slots=True, cmp=False, hash=False) -class StapledStream(Stream): + :class:`StapledStream` objects implement the methods in the + :class:`~trio.abc.HalfCloseableStream` interface. They also have two + additional public attributes: + + .. attribute:: send_stream + + The underlying :class:`~trio.abc.SendStream`. :meth:`send_all` and + :meth:`wait_send_all_might_not_block` are delegated to this object. + + .. attribute:: receive_stream + + The underlying :class:`~trio.abc.ReceiveStream`. :meth:`receive_some` + is delegated to this object. + + """ send_stream = attr.ib() - recv_stream = attr.ib() + receive_stream = attr.ib() - async def sendall(self, data): - return await self.send_stream.sendall(data) + async def send_all(self, data): + """Calls ``self.send_stream.send_all``. - async def wait_maybe_writable(self): - return await self.send_stream.wait_maybe_writable() + """ + return await self.send_stream.send_all(data) - @property - def can_send_eof(self): - return self.send_stream.can_send_eof + async def wait_send_all_might_not_block(self): + """Calls ``self.send_stream.wait_send_all_might_not_block``. - def send_eof(self): - return self.send_stream.send_eof() + """ + return await self.send_stream.wait_send_all_might_not_block() + + async def send_eof(self): + """Shuts down the send side of the stream. + + If ``self.send_stream.send_eof`` exists, then calls it. Otherwise, + calls ``self.send_stream.graceful_close()``. - async def recv(self, max_bytes): - return self.recv_stream.recv(max_bytes) + """ + if hasattr(self.send_stream, "send_eof"): + return await self.send_stream.send_eof() + else: + return await self.send_stream.graceful_close() + + async def receive_some(self, max_bytes): + """Calls ``self.receive_stream.receive_some``. + + """ + return await self.receive_stream.receive_some(max_bytes) def forceful_close(self): + """Calls ``forceful_close`` on both underlying streams. + + """ try: self.send_stream.forceful_close() finally: - self.recv_stream.forceful_close() + self.receive_stream.forceful_close() async def graceful_close(self): + """Calls ``graceful_close`` on both underlying streams. + + """ try: await self.send_stream.graceful_close() finally: - await self.recv_stream.graceful_close() + await self.receive_stream.graceful_close() diff --git a/trio/_sync.py b/trio/_sync.py index e41e4f843b..fd0dcd5691 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -6,7 +6,8 @@ from . import _core from ._util import aiter_compat -__all__ = ["Event", "Semaphore", "Lock", "Condition", "Queue"] +__all__ = [ + "Event", "Semaphore", "Lock", "StrictFIFOLock", "Condition", "Queue"] @attr.s(slots=True, repr=False, cmp=False, hash=False) class Event: @@ -247,7 +248,8 @@ def __repr__(self): else: s1 = "unlocked" s2 = "" - return "<{} trio.Lock object at {:#x}{}>".format(s1, id(self), s2) + return ("<{} {} object at {:#x}{}>" + .format(s1, self.__class__.__name__, id(self), s2)) def locked(self): """Check whether the lock is currently held. @@ -327,6 +329,69 @@ def statistics(self): ) +class StrictFIFOLock(Lock): + """A variant of :class:`Lock` where tasks are guaranteed to acquire the + lock in strict first-come-first-served order. + + An example of when this is useful is if you're implementing something like + :class:`trio.ssl.SSLStream` or an HTTP/2 server using `h2 + `__, where you have multiple concurrent + tasks that are interacting with a shared state machine, and at + unpredictable moments the state machine requests that a chunk of data be + sent over the network. (For example, when using h2 simply reading incoming + data can occasionally `create outgoing data to send + `__.) The challenge is to make + sure that these chunks are sent in the correct order, without being + garbled. + + One option would be to use a regular :class:`Lock`, and wrap it around + every interaction with the state machine:: + + # This approach is sometimes workable but often sub-optimal; see below + async with lock: + state_machine.do_something() + if state_machine.has_data_to_send(): + await conn.sendall(state_machine.get_data_to_send()) + + But this can be problematic. If you're using h2 then *usually* reading + incoming data doesn't create the need to send any data, so we don't want + to force every task that tries to read from the network to sit and wait + a potentially long time for ``sendall`` to finish. And in some situations + this could even potentially cause a deadlock, if the remote peer is + waiting for you to read some data before it accepts the data you're + sending. + + :class:`StrictFIFOLock` provides an alternative. We can rewrite our + example like:: + + # Note: no awaits between when we start using the state machine and + # when we block to take the lock! + state_machine.do_something() + if state_machine.has_data_to_send(): + # Notice that we fetch the data to send out of the state machine + # *before* sleeping, so that other tasks won't see it. + chunk = state_machine.get_data_to_send() + async with strict_fifo_lock: + await conn.sendall(chunk) + + First we do all our interaction with the state machine in a single + scheduling quantum (notice there are no ``await``\s in there), so it's + automatically atomic with respect to other tasks. And then if and only if + we have data to send, we get in line to send it – and + :class:`StrictFIFOLock` guarantees that each task will send its data in + the same order that the state machine generated it. + + Currently, :class:`StrictFIFOLock` is simply an alias for :class:`Lock`, + but (a) this may not always be true in the future, especially if trio ever + implements `more sophisticated scheduling policies + `__, and (b) the above code + is relying on a pretty subtle property of its lock. Using a + :class:`StrictFIFOLock` acts as an executable reminder that you're relying + on this property. + + """ + + @attr.s(frozen=True) class _ConditionStatistics: tasks_waiting = attr.ib() @@ -350,7 +415,7 @@ class Condition: def __init__(self, lock=None): if lock is None: lock = Lock() - if not isinstance(lock, Lock): + if not type(lock) is Lock: raise TypeError("lock must be a trio.Lock") self._lock = lock self._lot = _core.ParkingLot() diff --git a/trio/_util.py b/trio/_util.py index 027ddae1e5..0986e02440 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -6,7 +6,14 @@ import async_generator -__all__ = ["signal_raise", "aitercompat", "acontextmanager"] +# There's a dependency loop here... _core is allowed to use this file (in fact +# it's the *only* file in the main trio/ package it's allowed to use), but +# UnLock needs yield_briefly so it also has to import _core. Possibly we +# should split this file into two: one for true generic low-level utility +# code, and one for higher level helpers? +from . import _core + +__all__ = ["signal_raise", "aiter_compat", "acontextmanager", "UnLock"] # Equivalent to the C function raise(), which Python doesn't wrap if os.name == "nt": @@ -67,7 +74,9 @@ async def __aiter__(*args, **kwargs): # Very much derived from the one in contextlib, by copy/pasting and then -# asyncifying everything. +# asyncifying everything. (Also I dropped the obscure support for using +# context managers as function decorators. It could be re-added; I just +# couldn't be bothered.) # So this is a derivative work licensed under the PSF License, which requires # the following notice: # @@ -138,3 +147,44 @@ def helper(*args, **kwds): # A hint for sphinxcontrib-trio: helper.__returns_acontextmanager__ = True return helper + + +class _UnLockSync: + def __init__(self, exc, *args): + self._exc = exc + self._args = args + self._held = False + + def __enter__(self): + if self._held: + raise self._exc(*self._args) + else: + self._held = True + + def __exit__(self, *args): + self._held = False + + +class UnLock: + """An unnecessary lock. + + Use as an async context manager; if two tasks enter it at the same + time then the second one raises an error. You can use it when there are + two pieces of code that *would* collide and need a lock if they ever were + called at the same time, but that should never happen. + + We use this in particular for things like, making sure that two different + tasks don't call sendall simultaneously on the same stream. + + This executes a checkpoint on entry. That's the only reason it's async. + + """ + def __init__(self, exc, *args): + self.sync = _UnLockSync(exc, *args) + + async def __aenter__(self): + await _core.yield_briefly() + return self.sync.__enter__() + + async def __aexit__(self, *args): + return self.sync.__exit__() diff --git a/trio/abc.py b/trio/abc.py index 6bb7e53785..0973745d08 100644 --- a/trio/abc.py +++ b/trio/abc.py @@ -1,6 +1,11 @@ +import contextlib as _contextlib import abc as _abc +from . import _core -__all__ = ["Clock", "Instrument"] +__all__ = [ + "Clock", "Instrument", "AsyncResource", "SendStream", "ReceiveStream", + "Stream", "HalfCloseableStream", +] class Clock(_abc.ABC): """The interface for custom run loop clocks. @@ -130,3 +135,267 @@ def after_io_wait(self, timeout): whether any I/O was ready. """ + + +# We use ABCMeta instead of ABC, plus setting __slots__=(), so as not to force +# a __dict__ onto subclasses. +class AsyncResource(metaclass=_abc.ABCMeta): + """A standard interface for resources that needs to be cleaned up, and + where that cleanup may require blocking operations. + + This class distinguishes between "graceful" closes, which may perform I/O + and thus block, and a "forceful" close, which cannot. For example, cleanly + shutting down a TLS-encrypted connection requires sending a "goodbye" + message; but if a peer has become non-responsive, then sending this + message might block forever, so we may want to just drop the connection + instead. Therefore the :meth:`graceful_close` method is unusual in that it + should always close the connection (or at least make its best attempt) + *even if it fails*; failure indicates a failure to achieve grace, not a + failure to close the connection. + + Objects that implement this interface can be used as async context + managers, i.e., you can write:: + + async with create_resource() as some_async_resource: + ... + + Entering the context manager is synchronous (not a checkpoint); exiting it + calls :meth:`graceful_close`. The default implementations of + ``__aenter__`` and ``__aexit__`` should be adequate for all subclasses. + + """ + __slots__ = () + + @_abc.abstractmethod + def forceful_close(self): + """Force an immediate close of this resource. + + This will never block, but (depending on the resource in question) it + might be a "rude" shutdown. + + If the resource is already closed, then this method should silently + succeed. + + """ + + async def graceful_close(self): + """Close this resource, gracefully. + + This may block in order to perform a "graceful" shutdown (for example, + sending a "goodbye" message). But, if this fails (e.g., due to being + cancelled), then it still *must* close the underlying resource, + possibly by calling :meth:`forceful_close`. + + If the resource is already closed, then this method should silently + succeed. + + :class:`AsyncResource` provides a default implementation of this + method that's suitable for resources that don't distinguish between + graceful and forceful closure: it simply calls :meth:`forceful_close` + and then executes a checkpoint. + + """ + try: + self.forceful_close() + finally: + await _core.yield_briefly() + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + await self.graceful_close() + + +class SendStream(AsyncResource): + """A standard interface for sending data on a byte stream. + + The underlying stream may be unidirectional, or bidirectional. If it's + bidirectional, then you probably want to also implement + :class:`SendStream`, which makes your object a :class:`Stream`. + + Every :class:`SendStream` also implements the :class:`AsyncResource` + interface. + + """ + __slots__ = () + + @_abc.abstractmethod + async def send_all(self, data): + """Sends the given data through the stream, blocking if necessary. + + Args: + data (bytes, bytearray, or memoryview): The data to send. + + Raises: + trio.ResourceBusyError: if another task is already executing a + :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or + :meth:`HalfCloseableStream.send_eof` on this stream. + + """ + + @_abc.abstractmethod + async def wait_send_all_might_not_block(self): + """Block until it's possible that :meth:`send_all` might not block. + + This method may return early: it's possible that after it returns, + :meth:`send_all` will still block. (In the worst case, if no better + implementation is available, then it might always return immediately + without blocking. It's nice to do better than that when possible, + though.) + + This method **must not** return *late*: if it's possible for + :meth:`send_all` to complete without blocking, then it must + return. When implementing it, err on the side of returning early. + + Raises: + trio.ResourceBusyError: if another task is already executing a + :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or + :meth:`HalfCloseableStream.send_eof` on this stream. + + Note: + + This method is intended to aid in implementing protocols that want + to delay choosing which data to send until the last moment. E.g., + suppose you're working on an implemention of a remote display server + like `VNC + `__, and + the network connection is currently backed up so that if you call + :meth:`send_all` now then it will sit for 0.5 seconds before actually + sending anything. In this case it doesn't make sense to take a + screenshot, then wait 0.5 seconds, and then send it, because the + screen will keep changing while you wait; it's better to wait 0.5 + seconds, then take the screenshot, and then send it, because this + way the data you deliver will be more + up-to-date. Using :meth:`wait_send_all_might_not_block` makes it + possible to implement the better strategy. + + If you use this method, you might also want to read up on + ``TCP_NOTSENT_LOWAT``. + + Further reading: + + * `Prioritization Only Works When There's Pending Data to Prioritize + `__ + + * WWDC 2015: Your App and Next Generation Networks: `slides + `__, + `video and transcript + `__ + + """ + + +class ReceiveStream(AsyncResource): + """A standard interface for receiving data on a byte stream. + + The underlying stream may be unidirectional, or bidirectional. If it's + bidirectional, then you probably want to also implement + :class:`SendStream`, which makes your object a :class:`Stream`. + + Every :class:`ReceiveStream` also implements the :class:`AsyncResource` + interface. + + """ + __slots__ = () + + @_abc.abstractmethod + async def receive_some(self, max_bytes): + """Wait until there is data available on this stream, and then return + at most ``max_bytes`` of it. + + A return value of ``b""`` (an empty bytestring) indicates that the + stream has reached end-of-file. Implementations should be careful that + they return ``b""`` if, and only if, the stream has reached + end-of-file! + + This method will return as soon as any data is available, so it may + return fewer than ``max_bytes`` of data. But it will never return + more. + + Args: + max_bytes (int): The maximum number of bytes to return. Must be + greater than zero. + + Returns: + bytes or bytearray: The data received. + + Raises: + trio.ResourceBusyError: if two tasks attempt to call + :meth:`receive_some` on the same stream at the same time. + trio.BrokenStreamError: if something has gone wrong, and the stream + is broken. + trio.ClosedStreamError: if someone already called one of the close + methods on this stream object. + + """ + + +class Stream(SendStream, ReceiveStream): + """A standard interface for interacting with bidirectional byte streams. + + A :class:`Stream` is an object that implements both the + :class:`SendStream` and :class:`ReceiveStream` interfaces. + + If implementing this interface, you should consider whether you can go one + step further and implement :class:`HalfCloseableStream`. + + """ + __slots__ = () + + +class HalfCloseableStream(Stream): + """This interface extends :class:`Stream` to also allow closing the send + part of the stream without closing the receive part. + + """ + __slots__ = () + + + @_abc.abstractmethod + async def send_eof(self): + """Send an end-of-file indication on this stream, if possible. + + The difference between :meth:`send_eof` and + :meth:`~AsyncResource.graceful_close` is that :meth:`send_eof` is a + *unidirectional* end-of-file indication. After you call this method, + you shouldn't try sending any more data on this stream, and your + remote peer should receive an end-of-file indication (eventually, + after receiving all the data you sent before that). But, they may + continue to send data to you, and you can continue to receive it by + calling :meth:`~ReceiveStream.receive_some`. You can think of it as + calling :meth:`~AsyncResource.graceful_close` on just the + :class:`SendStream` "half" of the stream object (and in fact that's + literally how :class:`trio.StapledStream` implements it). + + Examples: + + * On a socket, this corresponds to ``shutdown(..., SHUT_WR)`` (`man + page `__). + + * The SSH protocol provides the ability to multiplex bidirectional + "channels" on top of a single encrypted connection. A trio + implementation of SSH could expose these channels as + :class:`HalfCloseableStream` objects, and calling :meth:`send_eof` + would send an ``SSH_MSG_CHANNEL_EOF`` request (see `RFC 4254 §5.3 + `__). + + * On an SSL/TLS-encrypted connection, the protocol doesn't provide any + way to do a unidirectional shutdown without closing the connection + entirely, so :class:`~trio.ssl.SSLStream` implements + :class:`Stream`, not :class:`HalfCloseableStream`. + + If an EOF has already been sent, then this method should silently + succeed. + + Raises: + trio.ResourceBusyError: if another task is already executing a + :meth:`~SendStream.send_all`, + :meth:`~SendStream.wait_send_all_might_not_block`, or + :meth:`send_eof` on this stream. + trio.BrokenStreamError: if something has gone wrong, and the stream + is broken. + trio.ClosedStreamError: if someone already called one of the close + methods on this stream object. + + """ diff --git a/trio/socket.py b/trio/socket.py index f383f3f7aa..c5c2660aea 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -1,7 +1,9 @@ from functools import wraps as _wraps, partial as _partial import socket as _stdlib_socket import sys as _sys -import os +import os as _os +from contextlib import contextmanager as _contextmanager +import errno as _errno from . import _core from ._threads import run_in_worker_thread as _run_in_worker_thread @@ -185,6 +187,13 @@ def __init__(self, sock): .format(type(sock).__name__)) self._sock = sock self._sock.setblocking(False) + self._did_SHUT_WR = False + + # Hopefully Python will eventually make something like this public + # (see bpo-21327) but I don't want to make it public myself and then + # find out they picked a different name... this is used internally in + # this file and also elsewhere in trio. + self._real_type = sock.type & _SOCK_TYPE_MASK # Defaults: if self._sock.family == AF_INET6: @@ -200,25 +209,6 @@ def __init__(self, sock): else: self.setsockopt(SOL_SOCKET, SO_REUSEADDR, True) - try: - self.setsockopt(IPPROTO_TCP, TCP_NODELAY, True) - except OSError: - pass - - try: - # 16 KiB is pretty arbitrary and could probably do with some - # tuning. (Apple is also setting this by default in CFNetwork - # apparently -- I'm curious what value they're using, though I - # couldn't find it online trivially. CFNetwork-129.20 source has - # no mentions of TCP_NOTSENT_LOWAT. This presentation says - # "typically 8 kilobytes": - # http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1 - # ). The theory is that you want it to be bandwidth * rescheduling - # interval. - self.setsockopt(IPPROTO_TCP, TCP_NOTSENT_LOWAT, 2 ** 14) - except (NameError, OSError): - pass - ################################################################ # Simple + portable methods and attributes ################################################################ @@ -235,7 +225,7 @@ def __init__(self, sock): _forward = { "detach", "get_inheritable", "set_inheritable", "fileno", "getpeername", "getsockname", "getsockopt", "setsockopt", "listen", - "shutdown", "close", "share", + "close", "share", } def __getattr__(self, name): if name in self._forward: @@ -282,6 +272,16 @@ def bind(self, address): self._check_address(address, require_resolved=True) return self._sock.bind(address) + def shutdown(self, flag): + # no need to worry about return value b/c always returns None: + self._sock.shutdown(flag) + # only do this if the call succeeded: + if flag in [SHUT_WR, SHUT_RDWR]: + self._did_SHUT_WR = True + + async def wait_writable(self): + await _core.wait_socket_writable(self._sock) + ################################################################ # Address handling ################################################################ @@ -315,7 +315,7 @@ def _check_address(self, address, *, require_resolved): _stdlib_socket.getaddrinfo( address[0], address[1], self._sock.family, - self._sock.type & _SOCK_TYPE_MASK, + self._real_type, self._sock.proto, flags=_NUMERIC_ONLY) except gaierror as exc: @@ -352,7 +352,7 @@ async def _resolve_address(self, address, flags): gai_res = await getaddrinfo( address[0], address[1], self._sock.family, - self._sock.type & _SOCK_TYPE_MASK, + self._real_type, self._sock.proto, flags) # AFAICT from the spec it's not possible for getaddrinfo to return an @@ -492,19 +492,54 @@ async def connect(self, address): # notification. This means it isn't really cancellable... async with _try_sync(): self._check_address(address, require_resolved=True) - # For some reason, PEP 475 left InterruptedError as a - # possible error for non-blocking connect - # (specifically). But as far as I know, EINTR always means - # you need to redo the call (with the extremely special - # exception of close() on Linux, but that's unrelated, and - # POSIX is cranky at them about it). If the kernel wanted - # to signal that the connect really was in progress then - # it'd have used EINPROGRESS. So we retry: - while True: - try: - return self._sock.connect(address) - except InterruptedError: - pass + # An interesting puzzle: can a non-blocking connect() return EINTR + # (= raise InterruptedError)? PEP 475 specifically left this as + # the one place where it lets an InterruptedError escape instead + # of automatically retrying. This is based on the idea that EINTR + # from connect means that the connection was already started, and + # will continue in the background. For a blocking connect, this + # sort of makes sense: if it returns EINTR then the connection + # attempt is continuing in the background, and on many system you + # can't then call connect() again because there is already a + # connect happening. See: + # + # http://www.madore.org/~david/computers/connect-intr.html + # + # For a non-blocking connect, it doesn't make as much sense -- + # surely the interrupt didn't happen after we successfully + # initiated the connect and are just waiting for it to complete, + # because a non-blocking connect does not wait! And the spec + # describes the interaction between EINTR/blocking connect, but + # doesn't have anything useful to say about non-blocking connect: + # + # http://pubs.opengroup.org/onlinepubs/007904975/functions/connect.html + # + # So we have a conundrum: if EINTR means that the connect() hasn't + # happened (like it does for essentially every other syscall), + # then InterruptedError should be caught and retried. If EINTR + # means that the connect() has successfully started, then + # InterruptedError should be caught and ignored. Which should we + # do? + # + # In practice, the resolution is probably that non-blocking + # connect simply never returns EINTR, so the question of how to + # handle it is moot. Someone spelunked MacOS/FreeBSD and + # confirmed this is true there: + # + # https://stackoverflow.com/questions/14134440/eintr-and-non-blocking-calls + # + # and exarkun seems to think it's true in general of non-blocking + # calls: + # + # https://twistedmatrix.com/pipermail/twisted-python/2010-September/022864.html + # (and indeed, AFAICT twisted doesn't try to handle + # InterruptedError). + # + # So we don't try to catch InterruptedError. This way if it + # happens, someone will hopefully tell us, and then hopefully we + # can investigate their system to figure out what its semantics + # are. + return self._sock.connect(address) # It raised BlockingIOError, meaning that it's started the # connection attempt. We wait for it to complete: try: @@ -518,7 +553,7 @@ async def connect(self, address): # Okay, the connect finished, but it might have failed: err = self._sock.getsockopt(SOL_SOCKET, SO_ERROR) if err != 0: - raise OSError(err, "Error in connect: " + os.strerror(err)) + raise OSError(err, "Error in connect: " + _os.strerror(err)) ################################################################ # recv @@ -568,7 +603,7 @@ async def connect(self, address): # send ################################################################ - _send = _make_simple_sock_method_wrapper( + send = _make_simple_sock_method_wrapper( "send", _core.wait_socket_writable) ################################################################ @@ -621,22 +656,23 @@ async def sendall(self, data, flags=0): ``flags`` are passed on to ``send``. - If an error occurs or the operation is cancelled, then the resulting - exception will have a ``.partial_result`` attribute with a - ``.bytes_sent`` attribute containing the number of bytes sent. + Most low-level operations in trio provide a guarantee: if they raise + :exc:`trio.Cancelled`, this means that they had no effect, so the + system remains in a known state. This is **not true** for + :meth:`sendall`. If this operation raises :exc:`trio.Cancelled` (or + any other exception for that matter), then it may have sent some, all, + or none of the requested data, and there is no way to know which. """ with memoryview(data) as data: + if not data: + await _core.yield_briefly() + return total_sent = 0 - try: - while data: - sent = await self._send(data, flags) - total_sent += sent - data = data[sent:] - except BaseException as exc: - pr = _core.PartialResult(bytes_sent=total_sent) - exc.partial_result = pr - raise + while total_sent < len(data): + with data[total_sent:] as remaining: + sent = await self.send(remaining, flags) + total_sent += sent ################################################################ # sendfile @@ -705,3 +741,5 @@ async def sendall(self, data, flags=0): # else: # raise OSError("getaddrinfo returned an empty list") # __all__.append("create_connection") + + diff --git a/trio/ssl.py b/trio/ssl.py new file mode 100644 index 0000000000..0c719a302f --- /dev/null +++ b/trio/ssl.py @@ -0,0 +1,784 @@ +# General theory of operation: +# +# We implement an API that closely mirrors the stdlib ssl module's blocking +# API, and we do it using the stdlib ssl module's non-blocking in-memory API. +# The stdlib non-blocking in-memory API is barely documented, and acts as a +# thin wrapper around openssl, whose documentation also leaves something to be +# desired. So here's the main things you need to know to understand the code +# in this file: +# +# We use an ssl.SSLObject, which exposes the four main I/O operations: +# +# - do_handshake: performs the initial handshake. Must be called once at the +# beginning of each connection; is a no-op once it's completed once. +# +# - write: takes some unencrypted data and attempts to send it to the remote +# peer. + +# - read: attempts to decrypt and return some data from the remote peer. +# +# - unwrap: this is weirdly named; maybe it helps to realize that the thing it +# wraps is called SSL_shutdown. It sends a cryptographically signed message +# saying "I'm closing this connection now", and then waits to receive the +# same from the remote peer (unless we already received one, in which case +# it returns immediately). +# +# All of these operations read and write from some in-memory buffers called +# "BIOs", which are an opaque OpenSSL-specific object that's basically +# semantically equivalent to a Python bytearray. When they want to send some +# bytes to the remote peer, they append them to the outgoing BIO, and when +# they want to receive some bytes from the remote peer, they try to pull them +# out of the incoming BIO. "Sending" always succeeds, because the outgoing BIO +# can always be extended to hold more data. "Receiving" acts sort of like a +# non-blocking socket: it might manage to get some data immediately, or it +# might fail and need to be tried again later. We can also directly add or +# remove data from the BIOs whenever we want. +# +# Now the problem is that while these I/O operations are opaque atomic +# operations from the point of view of us calling them, under the hood they +# might require some arbitrary sequence of sends and receives from the remote +# peer. This is particularly true for do_handshake, which generally requires a +# few round trips, but it's also true for write and read, due to an evil thing +# called "renegotiation". +# +# Renegotiation is the process by which one of the peers might arbitrarily +# decide to redo the handshake at any time. Did I mention it's evil? It's +# pretty evil, and almost universally hated. The HTTP/2 spec forbids the use +# of TLS renegotiation for HTTP/2 connections. TLS 1.3 removes it from the +# protocol entirely. It's impossible to trigger a renegotiation if using +# Python's ssl module. OpenSSL's renegotiation support is pretty buggy [1]. +# Nonetheless, it does get used in real life, mostly in two cases: +# +# 1) Normally in TLS 1.2 and below, when the client side of a connection wants +# to present a certificate to prove their identity, that certificate gets sent +# in plaintext. This is bad, because it means that anyone eavesdropping can +# see who's connecting – it's like sending your username in plain text. Not as +# bad as sending your password in plain text, but still, pretty bad. However, +# renegotiations *are* encrypted. So as a workaround, it's not uncommon for +# systems that want to use client certificates to first do an anonymous +# handshake, and then to turn around and do a second handshake (= +# renegotiation) and this time ask for a client cert. Or sometimes this is +# done on a case-by-case basis, e.g. a web server might accept a connection, +# read the request, and then once it sees the page you're asking for it might +# stop and ask you for a certificate. +# +# 2) In principle the same TLS connection can be used for an arbitrarily long +# time, and might transmit arbitrarily large amounts of data. But this creates +# a cryptographic problem: an attacker who has access to arbitrarily large +# amounts of data that's all encrypted using the same key may eventually be +# able to use this to figure out the key. Is this a real practical problem? I +# have no idea, I'm not a cryptographer. In any case, some people worry that +# it's a problem, so their TLS libraries are designed to automatically trigger +# a renegotation every once in a while on some sort of timer. +# +# The end result is that you might be going along, minding your own business, +# and then *bam*! a wild renegotiation appears! And you just have to cope. +# +# The reason that coping with renegotiations is difficult is that some +# unassuming "read" or "write" call might find itself unable to progress until +# it does a handshake, which remember is a process with multiple round +# trips. So read might have to send data, and write might have to receive +# data, and this might happen multiple times. And some of those attempts might +# fail because there isn't any data yet, and need to be retried. Managing all +# this is pretty complicated. +# +# Here's how openssl (and thus the stdlib ssl module) handle this. All of the +# I/O operations above follow the same rules. When you call one of them: +# +# - it might write some data to the outgoing BIO +# - it might read some data from the incoming BIO +# - it might raise SSLWantReadError if it can't complete without reading more +# data from the incoming BIO. This is important: the "read" in ReadError +# refers to reading from the *underlying* stream. +# - (and in principle it might raise SSLWantWriteError too, but that never +# happens when using memory BIOs, so never mind) +# +# If it doesn't raise an error, then the operation completed successfully +# (though we still need to take any outgoing data out of the memory buffer and +# put it onto the wire). If it *does* raise an error, then we need to retry +# *exactly that method call* later – in particular, if a 'write' failed, we +# need to try again later *with the same data*, because openssl might have +# already committed some of the initial parts of our data to its output even +# though it didn't tell us that, and has remembered that the next time we call +# write it needs to skip the first 1024 bytes or whatever it is. (Well, +# technically, we're actually allowed to call 'write' again with a data buffer +# which is the same as our old one PLUS some extra stuff added onto the end, +# but in trio that never comes up so never mind.) +# +# There are some people online who claim that once you've gotten a Want*Error +# then the *very next call* you make to openssl *must* be the same as the +# previous one. I'm pretty sure those people are wrong. In particular, it's +# okay to call write, get a WantReadError, and then call read a few times; +# it's just that *the next time you call write*, it has to be with the same +# data. +# +# One final wrinkle: we want our SSLStream to support full-duplex operation, +# i.e. it should be possible for one task to be calling send_all while another +# task is calling receive_some. But renegotiation makes this a big hassle, because +# even if SSLStream's restricts themselves to one task calling send_all and one +# task calling receive_some, those two tasks might end up both wanting to call +# send_all, or both to call receive_some at the same time *on the underlying +# stream*. So we have to do some careful locking to hide this problem from our +# users. +# +# (Renegotiation is evil.) +# +# So our basic strategy is to define a single helper method called "_retry", +# which has generic logic for dealing with SSLWantReadError, pushing data from +# the outgoing BIO to the wire, reading data from the wire to the incoming +# BIO, retrying an I/O call until it works, and synchronizing with other tasks +# that might be calling _retry concurrently. Basically it takes an SSLObject +# non-blocking in-memory method and converts it into a trio async blocking +# method. _retry is only about 30 lines of code, but all these cases +# multiplied by concurrent calls make it extremely tricky, so there are lots +# of comments down below on the details, and a really extensive test suite in +# test_ssl.py. And now you know *why* it's so tricky, and can probably +# understand how it works. +# +# [1] https://rt.openssl.org/Ticket/Display.html?id=3712 + +# XX how closely should we match the stdlib API? +# - maybe suppress_ragged_eofs=False is a better default? +# - maybe check crypto folks for advice? +# - this is also interesting: https://bugs.python.org/issue8108#msg102867 + +# Definitely keep an eye on Cory's TLS API ideas on security-sig etc. + +# XX document behavior on cancellation/error (i.e.: all is lost abandon +# stream) +# docs will need to make very clear that this is different from all the other +# cancellations in core trio + +import operator as _operator +import ssl as _stdlib_ssl +from enum import Enum as _Enum + +from . import _core +from .abc import Stream as _Stream +from . import _streams +from . import _sync +from ._util import UnLock as _UnLock + +__all__ = ["SSLStream"] + +################################################################ +# Faking the stdlib ssl API +################################################################ + +def _reexport(name): + try: + value = getattr(_stdlib_ssl, name) + except AttributeError: + pass + else: + globals()[name] = value + __all__.append(name) + +for _name in [ + "SSLError", "SSLZeroReturnError", "SSLSyscallError", "SSLEOFError", + "CertificateError", "create_default_context", "match_hostname", + "cert_time_to_seconds", "DER_cert_to_PEM_cert", "PEM_cert_to_DER_cert", + "get_default_verify_paths", "Purpose", "enum_certificates", + "enum_crls", "SSLSession", "VerifyMode", "VerifyFlags", "Options", + "AlertDescription", "SSLErrorNumber", + # Intentionally not re-exported: SSLContext +]: + _reexport(_name) + +for _name in _stdlib_ssl.__dict__.keys(): + if _name == _name.upper(): + _reexport(_name) + + +################################################################ +# SSLStream +################################################################ + +class _Once: + def __init__(self, afn, *args): + self._afn = afn + self._args = args + self.started = False + self._done = _sync.Event() + + async def ensure(self, *, checkpoint): + if not self.started: + self.started = True + await self._afn(*self._args) + self._done.set() + elif not checkpoint and self._done.is_set(): + return + else: + await self._done.wait() + + +_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) + + +class SSLStream(_Stream): + """Encrypted communication using SSL/TLS. + + :class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and + allows you to perform encrypted communication over it using the usual + :class:`~trio.abc.Stream` interface. You pass regular data to + :meth:`send_all`, then it encrypts it and sends the encrypted data on the + underlying :class:`~trio.abc.Stream`; :meth:`receive_some` takes encrypted + data out of the underlying :class:`~trio.abc.Stream` and decrypts it + before returning it. + + You should read the standard library's :mod:`ssl` documentation carefully + before attempting to use this class, and probably other general + documentation on SSL/TLS as well. SSL/TLS is subtle and quick to + anger. Really. I'm not kidding. + + Args: + transport_stream (~trio.abc.Stream): The stream used to transport + encrypted data. Required. + + ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` used for + this connection. Required. Usually created by calling + :func:`trio.ssl.create_default_context() + `. + + server_hostname (str or None): The name of the server being connected + to. Used for `SNI + `__ and for + validating the server's certificate (if hostname checking is + enabled). This is effectively mandatory for clients, and actually + mandatory if ``ssl_context.check_hostname`` is True. + + server_side (bool): Whether this stream is acting as a client or + server. Defaults to False, i.e. client mode. + + https_compatible (bool): There are two versions of SSL/TLS commonly + encountered in the wild: the standard version, and the version used + for HTTPS (HTTP-over-SSL/TLS). + + Standard-compliant SSL/TLS implementations always send a + cryptographically signed ``close_notify`` message before closing the + connection. This is important because if the underlying transport + were simply closed, then there wouldn't be any way for the other + side to know whether the connection was intentionally closed by the + peer that they negotiated a cryptographic connection to, or by some + `man-in-the-middle + `__ attacker + who can't manipulate the cryptographic stream, but can manipulate + the transport layer (a so-called "truncation attack"). + + However, this part of the standard is widely ignored by real-world + HTTPS implementations, which means that if you want to interoperate + with them, then you NEED to ignore it too. + + Fortunately this isn't as bad as it sounds, because the HTTP + protocol already includes its own equivalent of ``close_notify``, so + doing this again at the SSL/TLS level is redundant. But not all + protocols do! Therefore, by default Trio implements the safer + standard-compliant version (``https_compatible=False``). But if + you're speaking HTTPS or some other protocol where + ``close_notify``\s are commonly skipped, then you should set + ``https_compatible=True``; with this setting, Trio will neither + expect nor send ``close_notify`` messages. + + If you have code that was written to use :class:`ssl.SSLSocket` and + now you're porting it to Trio, then it may be useful to know that a + difference between :class:`SSLStream` and :class:`ssl.SSLSocket` is + that :class:`~ssl.SSLSocket` implements the + ``https_compatible=True`` behavior by default. + + max_refill_bytes (int): :class:`~ssl.SSLSocket` maintains an internal + buffer of incoming data, and when it runs low then it calls + :meth:`receive_some` on the underlying transport stream to refill + it. This argument lets you set the ``max_bytes`` argument passed to + the *underlying* :meth:`receive_some` call. It doesn't affect calls + to *this* class's :meth:`receive_some`, or really anything else + user-observable except possibly performance. You probably don't need + to worry about this. + + Attributes: + transport_stream (trio.abc.Stream): The underlying transport stream + that was passed to ``__init__``. An example of when this would be + useful is if you're using :class:`SSLStream` over a + :class:`~trio.SocketStream` and want to call the + :class:`~trio.SocketStream`'s :meth:`~trio.SocketStream.setsockopt` + method. + + Internally, this class is implemented using an instance of + :class:`ssl.SSLObject`, and all of :class:`~ssl.SSLObject`'s methods and + attributes are re-exported as methods and attributes on this class. + + This also means that you register a SNI callback using + :meth:`~ssl.SSLContext.set_servername_callback`, then the first argument + your callback receives will be a :class:`ssl.SSLObject`. + + """ + def __init__( + self, transport_stream, ssl_context, + *, + server_hostname=None, server_side=False, + https_compatible=False, max_refill_bytes=32 * 1024 + ): + self.transport_stream = transport_stream + self._state = _State.OK + self._max_bytes = max_refill_bytes + self._https_compatible = https_compatible + self._outgoing = _stdlib_ssl.MemoryBIO() + self._incoming = _stdlib_ssl.MemoryBIO() + self._ssl_object = ssl_context.wrap_bio( + self._incoming, self._outgoing, + server_side=server_side, server_hostname=server_hostname) + # Tracks whether we've already done the initial handshake + self._handshook = _Once(self._do_handshake) + + # These are used to synchronize access to self.transport_stream + self._inner_send_lock = _sync.StrictFIFOLock() + self._inner_recv_count = 0 + self._inner_recv_lock = _sync.Lock() + + # These are used to make sure that our caller doesn't attempt to make + # multiple concurrent calls to send_all/wait_send_all_might_not_block + # or to receive_some. + self._outer_send_lock = _UnLock( + _core.ResourceBusyError, + "another task is currently sending data on this SSLStream") + self._outer_recv_lock = _UnLock( + _core.ResourceBusyError, + "another task is currently receiving data on this SSLStream") + + _forwarded = { + "context", "server_side", "server_hostname", "session", + "session_reused", "getpeercert", "selected_npn_protocol", "cipher", + "shared_ciphers", "compression", "pending", "get_channel_binding", + "selected_alpn_protocol", "version", + } + def __getattr__(self, name): + if name in self._forwarded: + return getattr(self._ssl_object, name) + else: + raise AttributeError(name) + + def __setattr__(self, name, value): + if name in self._forwarded: + setattr(self._ssl_object, name, value) + else: + super().__setattr__(name, value) + + def __dir__(self): + return super().__dir__() + list(self._forwarded) + + def _check_status(self): + if self._state is _State.OK: + return + elif self._state is _State.BROKEN: + raise _streams.BrokenStreamError + elif self._state is _State.CLOSED: + raise _streams.ClosedStreamError + else: # pragma: no cover + assert False + + # This is probably the single trickiest function in trio. It has lots of + # comments, though, just make sure to think carefully if you ever have to + # touch it. The big comment at the top of this file will help explain + # too. + async def _retry(self, fn, *args, ignore_want_read=False): + await _core.yield_if_cancelled() + yielded = False + try: + finished = False + while not finished: + # WARNING: this code needs to be very careful with when it + # calls 'await'! There might be multiple tasks calling this + # function at the same time trying to do different operations, + # so we need to be careful to: + # + # 1) interact with the SSLObject, then + # 2) await on exactly one thing that lets us make forward + # progress, then + # 3) loop or exit + # + # In particular we don't want to yield while interacting with + # the SSLObject (because it's shared state, so someone else + # might come in and mess with it while we're suspended), and + # we don't want to yield *before* starting the operation that + # will help us make progress, because then someone else might + # come in and leapfrog us. + + # Call the SSLObject method, and get its result. + # + # NB: despite what the docs say, SSLWantWriteError can't + # happen – "Writes to memory BIOs will always succeed if + # memory is available: that is their size can grow + # indefinitely." + # https://wiki.openssl.org/index.php/Manual:BIO_s_mem(3) + want_read = False + ret = None + try: + ret = fn(*args) + except _stdlib_ssl.SSLWantReadError: + want_read = True + except (SSLError, CertificateError) as exc: + self._state = _State.BROKEN + raise _streams.BrokenStreamError from exc + else: + finished = True + if ignore_want_read: + want_read = False + finished = True + to_send = self._outgoing.read() + + # Outputs from the above code block are: + # + # - to_send: bytestring; if non-empty then we need to send + # this data to make forward progress + # + # - want_read: True if we need to receive_some some data to make + # forward progress + # + # - finished: False means that we need to retry the call to + # fn(*args) again, after having pushed things forward. True + # means we still need to do whatever was said (in particular + # send any data in to_send), but once we do then we're + # done. + # + # - ret: the operation's return value. (Meaningless unless + # finished is True.) + # + # Invariant: want_read and finished can't both be True at the + # same time. + # + # Now we need to move things forward. There are two things we + # might have to do, and any given operation might require + # either, both, or neither to proceed: + # + # - send the data in to_send + # + # - receive_some some data and put it into the incoming BIO + # + # Our strategy is: if there's data to send, send it; + # *otherwise* if there's data to receive_some, receive_some it. + # + # If both need to happen, then we only send. Why? Well, we + # know that *right now* we have to both send and receive_some + # before the operation can complete. But as soon as we yield, + # that information becomes potentially stale – e.g. while + # we're sending, some other task might go and receive_some the + # data we need and put it into the incoming BIO. And if it + # does, then we *definitely don't* want to do a receive_some – + # there might not be any more data coming, and we'd deadlock! + # We could do something tricky to keep track of whether a + # receive_some happens while we're sending, but the case where + # we have to do both is very unusual (only during a + # renegotation), so it's better to keep things simple. So we + # do just one potentially-blocking operation, then check again + # for fresh information. + # + # And we prioritize sending over receiving because, if there + # are multiple tasks that want to receive_some, then it + # doesn't matter what order they go in. But if there are + # multiple tasks that want to send, then they each have + # different data, and the data needs to get put onto the wire + # in the same order that it was retrieved from the outgoing + # BIO. So if we have data to send, that *needs* to be the + # *very* *next* *thing* we do, to make sure no-one else sneaks + # in before us. Or if we can't send immediately because + # someone else is, then we at least need to get in line + # immediately. + if to_send: + # NOTE: This relies on the lock being strict FIFO fair! + async with self._inner_send_lock: + yielded = True + try: + await self.transport_stream.send_all(to_send) + except: + # Some unknown amount of our data got sent, and we + # don't know how much. This stream is doomed. + self._state = _State.BROKEN + raise + elif want_read: + # It's possible that someone else is already blocked in + # transport_stream.receive_some. If so then we want to + # wait for them to finish, but we don't want to call + # transport_stream.receive_some again ourselves; we just + # want to loop around and check if their contribution + # helped anything. So we make a note of how many times + # some task has been through here before taking the lock, + # and if it's changed by the time we get the lock, then we + # skip calling transport_stream.receive_some and loop + # around immediately. + recv_count = self._inner_recv_count + async with self._inner_recv_lock: + yielded = True + if recv_count == self._inner_recv_count: + data = await self.transport_stream.receive_some(self._max_bytes) + if not data: + self._incoming.write_eof() + else: + self._incoming.write(data) + self._inner_recv_count += 1 + return ret + finally: + if not yielded: + await _core.yield_briefly_no_cancel() + + async def _do_handshake(self): + try: + await self._retry(self._ssl_object.do_handshake) + except: + self._state = _State.BROKEN + raise + + async def do_handshake(self): + """Ensure that the initial handshake has completed. + + The SSL protocol requires an initial handshake to exchange + certificates, select cryptographic keys, and so forth, before any + actual data can be sent or received. You don't have to call this + method; if you don't, then :class:`SSLStream` will automatically + peform the handshake as needed, the first time you try to send or + receive data. But if you want to trigger it manually – for example, + because you want to look at the peer's certificate before you start + talking to them – then you can call this method. + + If the initial handshake is already in progress in another task, this + waits for it to complete and then returns. + + If the initial handshake has already completed, this returns + immediately without doing anything (except executing a checkpoint). + + .. warning:: If this method is cancelled, then it may leave the + :class:`SSLStream` in an unusable state. If this happens then any + future attempt to use the object will raise + :exc:`trio.BrokenStreamError`. + + """ + try: + self._check_status() + except: + await _core.yield_briefly() + raise + await self._handshook.ensure(checkpoint=True) + + # Most things work if we don't explicitly force do_handshake to be called + # before calling receive_some or send_all, because openssl will + # automatically perform the handshake on the first SSL_{read,write} + # call. BUT, allowing openssl to do this will disable Python's hostname + # checking!!! See: + # https://bugs.python.org/issue30141 + # So we *definitely* have to make sure that do_handshake is called + # before doing anything else. + async def receive_some(self, max_bytes): + """Read some data from the underlying transport, decrypt it, and + return it. + + See :meth:`trio.abc.ReceiveStream.receive_some` for details. + + .. warning:: If this method is cancelled while the initial handshake + or a renegotiation are in progress, then it may leave the + :class:`SSLStream` in an unusable state. If this happens then any + future attempt to use the object will raise + :exc:`trio.BrokenStreamError`. + + """ + async with self._outer_recv_lock: + self._check_status() + try: + await self._handshook.ensure(checkpoint=False) + except _streams.BrokenStreamError as exc: + # For some reason, EOF before handshake sometimes raises + # SSLSyscallError instead of SSLEOFError (e.g. on my linux + # laptop, but not on appveyor). Thanks openssl. + if (self._https_compatible + and isinstance( + exc.__cause__, (SSLEOFError, SSLSyscallError))): + return b"" + else: + raise + max_bytes = _operator.index(max_bytes) + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") + try: + return await self._retry(self._ssl_object.read, max_bytes) + except _streams.BrokenStreamError as exc: + # This isn't quite equivalent to just returning b"" in the + # first place, because we still end up with self._state set to + # BROKEN. But that's actually fine, because after getting an + # EOF on TLS then the only thing you can do is close the + # stream, and closing doesn't care about the state. + if (self._https_compatible + and isinstance(exc.__cause__, SSLEOFError)): + return b"" + else: + raise + + async def send_all(self, data): + """Encrypt some data and then send it on the underlying transport. + + See :meth:`trio.abc.SendStream.send_all` for details. + + .. warning:: If this method is cancelled, then it may leave the + :class:`SSLStream` in an unusable state. If this happens then any + attempt to use the object will raise + :exc:`trio.BrokenStreamError`. + + """ + async with self._outer_send_lock: + self._check_status() + await self._handshook.ensure(checkpoint=False) + # SSLObject interprets write(b"") as an EOF for some reason, which + # is not what we want. + if not data: + await _core.yield_briefly() + return + await self._retry(self._ssl_object.write, data) + + async def unwrap(self): + """Cleanly close down the SSL/TLS encryption layer, allowing the + underlying stream to be used for unencrypted communication. + + You almost certainly don't need this. + + Returns: + A pair ``(transport_stream, trailing_bytes)``, where + ``transport_stream`` is the underlying transport stream, and + ``trailing_bytes`` is a byte string. Since :class:`SSLStream` + doesn't necessarily know where the end of the encrypted data will + be, it can happen that it accidentally reads too much from the + underlying stream. ``trailing_bytes`` contains this extra data; you + should process it as if it was returned from a call to + ``transport_stream.receive_some(...)``. + + """ + async with self._outer_recv_lock, self._outer_send_lock: + self._check_status() + await self._handshook.ensure(checkpoint=False) + await self._retry(self._ssl_object.unwrap) + transport_stream = self.transport_stream + self.transport_stream = None + self._state = _State.CLOSED + return (transport_stream, self._incoming.read()) + + def forceful_close(self): + """Forcefully closes the underlying transport and marks this stream as + closed. + + """ + if self._state is not _State.CLOSED: + self._state = _State.CLOSED + self.transport_stream.forceful_close() + + async def graceful_close(self): + """Gracefully shut down this connection, and close the underlying + transport. + + If ``https_compatible`` is False (the default), then this attempts to + first send a ``close_notify`` and then close the underlying stream by + calling its :meth:`~trio.abc.AsyncResource.graceful_close` method. + + If ``https_compatible`` is set to True, then this simply closes the + underlying stream and marks this stream as closed. + + """ + if self._state is _State.CLOSED: + await _core.yield_briefly() + return + if self._state is _State.BROKEN or self._https_compatible: + self._state = _State.CLOSED + await self.transport_stream.graceful_close() + return + try: + await self._handshook.ensure(checkpoint=False) + # Here, we call SSL_shutdown *once*, because we want to send a + # close_notify but *not* wait for the other side to send back a + # response. In principle it would be more polite to wait for the + # other side to reply with their own close_notify. However, if + # they aren't paying attention (e.g., if they're just sending + # data and not receiving) then we will never notice our + # close_notify and we'll be waiting forever. Eventually we'll time + # out (hopefully), but it's still kind of nasty. And we can't + # require the other side to always be receiving, because (a) + # backpressure is kind of important, and (b) I bet there are + # broken TLS implementations out there that don't receive all the + # time. (Like e.g. anyone using Python ssl in synchronous mode.) + # + # The send-then-immediately-close behavior is explicitly allowed + # by the TLS specs, so we're ok on that. + # + # Subtlety: SSLObject.unwrap will immediately call it a second + # time, and the second time will raise SSLWantReadError because + # there hasn't been time for the other side to respond + # yet. (Unless they spontaneously sent a close_notify before we + # called this, and it's either already been processed or gets + # pulled out of the buffer by Python's second call.) So the way to + # do what we want is to ignore SSLWantReadError on this call. + # + # Also, because the other side might have already sent + # close_notify and closed their connection then it's possible that + # our attempt to send close_notify will raise + # BrokenStreamError. This is totally legal, and in fact can happen + # with two well-behaved trio programs talking to each other, so we + # don't want to raise an error. So we suppress BrokenStreamError + # here. (This is safe, because literally the only thing this call + # to _retry will do is send the close_notify alert, so that's + # surely where the error comes from.) + # + # FYI in some cases this could also raise SSLSyscallError which I + # think is because SSL_shutdown is terrible. (Check out that note + # at the bottom of the man page saying that it sometimes gets + # raised spuriously.) I haven't seen this since we switched to + # immediately closing the socket, and I don't know exactly what + # conditions cause it and how to respond, so for now we're just + # letting that happen. But if you start seeing it, then hopefully + # this will give you a little head start on tracking it down, + # because whoa did this puzzle us at the 2017 PyCon sprints. + try: + await self._retry( + self._ssl_object.unwrap, ignore_want_read=True) + except _streams.BrokenStreamError: + pass + # Close the underlying stream + await self.transport_stream.graceful_close() + except: + self.transport_stream.forceful_close() + raise + finally: + self._state = _State.CLOSED + + async def wait_send_all_might_not_block(self): + """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`. + + """ + # This method's implementation is deceptively simple. + # + # First, we take the outer send lock, because of trio's standard + # semantics that wait_send_all_might_not_block and send_all conflict. + async with self._outer_send_lock: + self._check_status() + # Then we take the inner send lock. We know that no other tasks + # are calling self.send_all or self.wait_send_all_might_not_block, + # because we have the outer_send_lock. But! There might be another + # task calling self.receive_some -> transport_stream.send_all, in + # which case if we were to call + # transport_stream.wait_send_all_might_not_block directly we'd + # have two tasks doing write-related operations on + # transport_stream simultaneously, which is not allowed. We + # *don't* want to raise this conflict to our caller, because it's + # purely an internal affair – all they did was call + # wait_send_all_might_not_block and receive_some at the same time, + # which is totally valid. And waiting for the lock is OK, because + # a call to send_all certainly wouldn't complete while the other + # task holds the lock. + async with self._inner_send_lock: + # Now we have the lock, which creates another potential + # problem: what if a call to self.receive_some attempts to do + # transport_stream.send_all now? It'll have to wait for us to + # finish! But that's OK, because we release the lock as soon + # as the underlying stream becomes writable, and the + # self.receive_some call wasn't going to make any progress + # until then anyway. + # + # Of course, this does mean we might return *before* the + # stream is logically writable, because immediately after we + # return self.receive_some might write some data and make it + # non-writable again. But that's OK too, + # wait_send_all_might_not_block only guarantees that it + # doesn't return late. + await self.transport_stream.wait_send_all_might_not_block() diff --git a/trio/testing.py b/trio/testing.py index a04e4869e0..93a27726ed 100644 --- a/trio/testing.py +++ b/trio/testing.py @@ -5,17 +5,27 @@ from collections import defaultdict import time from math import inf +import random +import operator import attr from async_generator import async_generator, yield_ -from ._util import acontextmanager +from . import _util from . import _core -from . import Event, sleep -from .abc import Clock +from . import _streams +from . import Event as _Event, sleep as _sleep +from . import abc as _abc -__all__ = ["wait_all_tasks_blocked", "trio_test", "MockClock", - "assert_yields", "assert_no_yields", "Sequencer"] +__all__ = [ + "wait_all_tasks_blocked", "trio_test", "MockClock", + "assert_yields", "assert_no_yields", "Sequencer", + "MemorySendStream", "MemoryReceiveStream", "memory_stream_pump", + "memory_stream_one_way_pair", "memory_stream_pair", + "lockstep_stream_one_way_pair", "lockstep_stream_pair", + "check_one_way_stream", "check_two_way_stream", + "check_half_closeable_stream", +] # re-export from ._core import wait_all_tasks_blocked @@ -32,7 +42,7 @@ def trio_test(fn): @wraps(fn) def wrapper(**kwargs): __tracebackhide__ = True - clocks = [c for c in kwargs.values() if isinstance(c, Clock)] + clocks = [c for c in kwargs.values() if isinstance(c, _abc.Clock)] if not clocks: clock = None elif len(clocks) == 1: @@ -42,10 +52,15 @@ def wrapper(**kwargs): return _core.run(partial(fn, **kwargs), clock=clock) return wrapper + +################################################################ +# The glorious MockClock +################################################################ + # Prior art: # https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html # https://github.com/ztellman/manifold/issues/57 -class MockClock(Clock): +class MockClock(_abc.Clock): """A user-controllable clock suitable for writing tests. Args: @@ -61,41 +76,53 @@ class MockClock(Clock): .. attribute:: autojump_threshold - If all tasks are blocked for this many real seconds (i.e., according to - the actual clock, not this clock), then this clock automatically jumps - ahead to the run loop's next scheduled timeout. Default is - :data:`math.inf`, i.e., to never autojump. You can assign to this - attribute to change it. + The clock keeps an eye on the run loop, and if at any point it detects + that all tasks have been blocked for this many real seconds (i.e., + according to the actual clock, not this clock), then the clock + automatically jumps ahead to the run loop's next scheduled + timeout. Default is :data:`math.inf`, i.e., to never autojump. You can + assign to this attribute to change it. + + Basically the idea is that if you have code or tests that use sleeps + and timeouts, you can use this to make it run much faster, totally + automatically. (At least, as long as those sleeps/timeouts are + happening inside trio; if your test involves talking to external + service and waiting for it to timeout then obviously we can't help you + there.) You should set this to the smallest value that lets you reliably avoid "false alarms" where some I/O is in flight (e.g. between two halves of a socketpair) but the threshold gets triggered and time gets advanced anyway. This will depend on the details of your tests and test environment. If you aren't doing any I/O (like in our sleeping example - above) then setting it to zero is fine. - - Note that setting this attribute interacts with the run loop, so it can - only be done from inside a run context or (as a special case) before - calling :func:`trio.run`. + above) then just set it to zero, and the clock will jump whenever all + tasks are blocked. .. warning:: If you're using :func:`wait_all_tasks_blocked` and :attr:`autojump_threshold` together, then you have to be - careful. Setting :attr:`autojump_threshold` acts like a task - calling:: + careful. Setting :attr:`autojump_threshold` acts like a background + task calling:: while True: - await wait_all_tasks_blocked(cushion=clock.autojump_threshold) + await wait_all_tasks_blocked( + cushion=clock.autojump_threshold, tiebreaker=float("inf")) This means that if you call :func:`wait_all_tasks_blocked` with a cushion *larger* than your autojump threshold, then your call to :func:`wait_all_tasks_blocked` will never return, because the autojump task will keep waking up before your task does, and each - time it does it'll reset your task's timer. + time it does it'll reset your task's timer. However, if your cushion + and the autojump threshold are the *same*, then the autojump's + tiebreaker will prevent them from interfering (unless you also set + your tiebreaker to infinity for some reason. Don't do that). As an + important special case: this means that if you set an autojump + threshold of zero and use :func:`wait_all_tasks_blocked` with the + default zero cushion, then everything will work fine. - **Summary**: you should set :attr:`autojump_threshold` to be at *least* - as large as the largest cushion you plan to pass to + **Summary**: you should set :attr:`autojump_threshold` to be at + least as large as the largest cushion you plan to pass to :func:`wait_all_tasks_blocked`. """ @@ -146,11 +173,12 @@ async def _autojumper(self): self._autojump_cancel_scope = cancel_scope try: # If the autojump_threshold changes, then the setter does - # cancel_scope.cancel() ,which causes this line to raise - # Cancelled, which is absorbed by the cancel scope above, - # and effectively just causes us to skip start the loop - # over, like a 'continue' here. - await wait_all_tasks_blocked(self._autojump_threshold) + # cancel_scope.cancel(), which causes the next line here + # to raise Cancelled, which is absorbed by the cancel + # scope above, and effectively just causes us to skip back + # to the start the loop, like a 'continue'. + await wait_all_tasks_blocked( + self._autojump_threshold, float("inf")) statistics = _core.current_statistics() jump = statistics.seconds_to_next_deadline if jump < inf: @@ -160,7 +188,7 @@ async def _autojumper(self): # until some actual I/O arrives (or maybe another # wait_all_tasks_blocked task wakes up). That's fine, # but if our threshold is zero then this will become a - # busy-wait -- so insert a small-but-non-zero sleep to + # busy-wait -- so insert a small-but-non-zero _sleep to # avoid that. if self._autojump_threshold == 0: await wait_all_tasks_blocked(0.01) @@ -221,6 +249,10 @@ def jump(self, seconds): self._virtual_base += seconds +################################################################ +# Testing checkpoints +################################################################ + @contextmanager def _assert_yields_or_not(expected): __tracebackhide__ = True @@ -276,6 +308,10 @@ def assert_no_yields(): return _assert_yields_or_not(False) +################################################################ +# Sequencer +################################################################ + @attr.s(slots=True, cmp=False, hash=False) class Sequencer: """A convenience class for forcing code in different tasks to run in an @@ -318,11 +354,11 @@ async def main(): """ _sequence_points = attr.ib( - default=attr.Factory(lambda: defaultdict(Event)), init=False) + default=attr.Factory(lambda: defaultdict(_Event)), init=False) _claimed = attr.ib(default=attr.Factory(set), init=False) _broken = attr.ib(default=False, init=False) - @acontextmanager + @_util.acontextmanager @async_generator async def __call__(self, position): if position in self._claimed: @@ -347,3 +383,993 @@ async def __call__(self, position): await yield_() finally: self._sequence_points[position + 1].set() + + +################################################################ +# Generic stream tests +################################################################ + +class _CloseBoth: + def __init__(self, both): + self._both = list(both) + + async def __aenter__(self): + return self._both + + async def __aexit__(self, *args): + try: + self._both[0].forceful_close() + finally: + self._both[1].forceful_close() + + +@contextmanager +def _assert_raises(exc): + __tracebackhide__ = True + try: + yield + except exc: + pass + else: + raise AssertionError("expected exception: {}".format(exc)) + + +async def check_one_way_stream(stream_maker, clogged_stream_maker): + """Perform a number of generic tests on a custom one-way stream + implementation. + + Args: + stream_maker: An async (!) function which returns a connected + (:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`) + pair. + clogged_stream_maker: Either None, or an async function similar to + stream_maker, but with the extra property that the returned stream + is in a state where ``send_all`` and + ``wait_send_all_might_not_block`` will block until ``receive_some`` + has been called. This allows for more thorough testing of some edge + cases, especially around ``wait_send_all_might_not_block``. + + Raises: + AssertionError: if a test fails. + + """ + async with _CloseBoth(await stream_maker()) as (s, r): + assert isinstance(s, _abc.SendStream) + assert isinstance(r, _abc.ReceiveStream) + + async def do_send_all(data): + with assert_yields(): + assert await s.send_all(data) is None + + async def do_receive_some(max_bytes): + with assert_yields(): + return await r.receive_some(1) + + async def checked_receive_1(expected): + assert await do_receive_some(1) == expected + + async def do_graceful_close(resource): + with assert_yields(): + await resource.graceful_close() + + # Simple sending/receiving + async with _core.open_nursery() as nursery: + nursery.spawn(do_send_all, b"x") + nursery.spawn(checked_receive_1, b"x") + + async def send_empty_then_y(): + # Streams should tolerate sending b"" without giving it any + # special meaning. + await do_send_all(b"") + await do_send_all(b"y") + + async with _core.open_nursery() as nursery: + nursery.spawn(send_empty_then_y) + nursery.spawn(checked_receive_1, b"y") + + ### Checking various argument types + + # send_all accepts bytearray and memoryview + async with _core.open_nursery() as nursery: + nursery.spawn(do_send_all, bytearray(b"1")) + nursery.spawn(checked_receive_1, b"1") + + async with _core.open_nursery() as nursery: + nursery.spawn(do_send_all, memoryview(b"2")) + nursery.spawn(checked_receive_1, b"2") + + # max_bytes must be a positive integer + with _assert_raises(ValueError): + await r.receive_some(-1) + with _assert_raises(ValueError): + await r.receive_some(0) + with _assert_raises(TypeError): + await r.receive_some(1.5) + + with _assert_raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + nursery.spawn(do_receive_some, 1) + nursery.spawn(do_receive_some, 1) + + # Method always has to exist, and an empty stream with a blocked + # receive_some should *always* allow send_all. (Technically it's legal + # for send_all to wait until receive_some is called to run, though; a + # stream doesn't *have* to have any internal buffering. That's why we + # spawn a concurrent receive_some call, then cancel it.) + async def simple_check_wait_send_all_might_not_block(scope): + with assert_yields(): + await s.wait_send_all_might_not_block() + scope.cancel() + + async with _core.open_nursery() as nursery: + nursery.spawn(simple_check_wait_send_all_might_not_block, + nursery.cancel_scope) + nursery.spawn(do_receive_some, 1) + + # closing the r side leads to BrokenStreamError on the s side + # (eventually) + async def expect_broken_stream_on_send(): + with _assert_raises(_streams.BrokenStreamError): + while True: + await do_send_all(b"x" * 100) + + async with _core.open_nursery() as nursery: + nursery.spawn(expect_broken_stream_on_send) + nursery.spawn(do_graceful_close, r) + + # once detected, the stream stays broken + with _assert_raises(_streams.BrokenStreamError): + await do_send_all(b"x" * 100) + + # r closed -> ClosedStreamError on the receive side + with _assert_raises(_streams.ClosedStreamError): + await do_receive_some(4096) + + # we can close the same stream repeatedly, it's fine + r.forceful_close() + with assert_yields(): + await r.graceful_close() + r.forceful_close() + + # closing the sender side + with assert_yields(): + await s.graceful_close() + + # now trying to send raises ClosedStreamError + with _assert_raises(_streams.ClosedStreamError): + await do_send_all(b"x" * 100) + + # ditto for wait_send_all_might_not_block + with _assert_raises(_streams.ClosedStreamError): + with assert_yields(): + await s.wait_send_all_might_not_block() + + # and again, repeated closing is fine + s.forceful_close() + await do_graceful_close(s) + s.forceful_close() + + async with _CloseBoth(await stream_maker()) as (s, r): + # if send-then-graceful-close, receiver gets data then b"" + async def send_then_close(): + await do_send_all(b"y") + await do_graceful_close(s) + + async def receive_send_then_close(): + # We want to make sure that if the sender closes the stream before + # we read anything, then we still get all the data. But some + # streams might block on the do_send_all call. So we let the + # sender get as far as it can, then we receive. + await wait_all_tasks_blocked() + await checked_receive_1(b"y") + await checked_receive_1(b"") + await do_graceful_close(r) + + async with _core.open_nursery() as nursery: + nursery.spawn(send_then_close) + nursery.spawn(receive_send_then_close) + + # using forceful_close also makes things closed + async with _CloseBoth(await stream_maker()) as (s, r): + r.forceful_close() + + with _assert_raises(_streams.BrokenStreamError): + while True: + await do_send_all(b"x" * 100) + + with _assert_raises(_streams.ClosedStreamError): + await do_receive_some(4096) + + async with _CloseBoth(await stream_maker()) as (s, r): + s.forceful_close() + + with _assert_raises(_streams.ClosedStreamError): + await do_send_all(b"123") + + # after the sender does a forceful close, the receiver might either + # get BrokenStreamError or a clean b""; either is OK. Not OK would be + # if it freezes, or returns data. + try: + await checked_receive_1(b"") + except _streams.BrokenStreamError: + pass + + # cancelled graceful_close still closes + async with _CloseBoth(await stream_maker()) as (s, r): + with _core.open_cancel_scope() as scope: + scope.cancel() + await r.graceful_close() + + with _core.open_cancel_scope() as scope: + scope.cancel() + await s.graceful_close() + + with _assert_raises(_streams.ClosedStreamError): + await do_send_all(b"123") + + with _assert_raises(_streams.ClosedStreamError): + await do_receive_some(4096) + + # Check that we can still gracefully close a stream after an operation has + # been cancelled. This can be challenging if cancellation can leave the + # stream internals in an inconsistent state, e.g. for + # SSLStream. Unfortunately this test isn't very thorough; the really + # challenging case for something like SSLStream is it gets cancelled + # *while* it's sending data on the underlying, not before. But testing + # that requires some special-case handling of the particular stream setup; + # we can't do it here. Maybe we could do a bit better with + # https://github.com/python-trio/trio/issues/77 + async with _CloseBoth(await stream_maker()) as (s, r): + async def expect_cancelled(afn, *args): + with _assert_raises(_core.Cancelled): + await afn(*args) + + with _core.open_cancel_scope() as scope: + scope.cancel() + async with _core.open_nursery() as nursery: + nursery.spawn(expect_cancelled, do_send_all, b"x") + nursery.spawn(expect_cancelled, do_receive_some, 1) + + async with _core.open_nursery() as nursery: + nursery.spawn(do_graceful_close, s) + nursery.spawn(do_graceful_close, r) + + # check wait_send_all_might_not_block, if we can + if clogged_stream_maker is not None: + async with _CloseBoth(await clogged_stream_maker()) as (s, r): + record = [] + + async def waiter(cancel_scope): + record.append("waiter sleeping") + with assert_yields(): + await s.wait_send_all_might_not_block() + record.append("waiter wokeup") + cancel_scope.cancel() + + async def receiver(): + # give wait_send_all_might_not_block a chance to block + await wait_all_tasks_blocked() + record.append("receiver starting") + while True: + await r.receive_some(16834) + + async with _core.open_nursery() as nursery: + nursery.spawn(waiter, nursery.cancel_scope) + await wait_all_tasks_blocked() + nursery.spawn(receiver) + + assert record == [ + "waiter sleeping", + "receiver starting", + "waiter wokeup", + ] + + async with _CloseBoth(await clogged_stream_maker()) as (s, r): + # simultaneous wait_send_all_might_not_block fails + with _assert_raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + nursery.spawn(s.wait_send_all_might_not_block) + nursery.spawn(s.wait_send_all_might_not_block) + + # and simultaneous send_all and wait_send_all_might_not_block (NB + # this test might destroy the stream b/c we end up cancelling + # send_all and e.g. SSLStream can't handle that, so we have to + # recreate afterwards) + with _assert_raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + nursery.spawn(s.wait_send_all_might_not_block) + nursery.spawn(s.send_all, b"123") + + async with _CloseBoth(await clogged_stream_maker()) as (s, r): + # send_all and send_all blocked simultaneously should also raise + # (but again this might destroy the stream) + with _assert_raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + nursery.spawn(s.send_all, b"123") + nursery.spawn(s.send_all, b"123") + + # closing the receiver causes wait_send_all_might_not_block to return + async with _CloseBoth(await clogged_stream_maker()) as (s, r): + async def sender(): + try: + with assert_yields(): + await s.wait_send_all_might_not_block() + except _streams.BrokenStreamError: + pass + + async def receiver(): + await wait_all_tasks_blocked() + r.forceful_close() + + async with _core.open_nursery() as nursery: + nursery.spawn(sender) + nursery.spawn(receiver) + + # and again with the call starting after the close + async with _CloseBoth(await clogged_stream_maker()) as (s, r): + r.forceful_close() + try: + with assert_yields(): + await s.wait_send_all_might_not_block() + except _streams.BrokenStreamError: + pass + + +async def check_two_way_stream(stream_maker, clogged_stream_maker): + """Perform a number of generic tests on a custom two-way stream + implementation. + + This is similar to :func:`check_one_way_stream`, except that the maker + functions are expected to return objects implementing the + :class:`~trio.abc.Stream` interface. + + This function tests a *superset* of what :func:`check_one_way_stream` + checks – if you call this, then you don't need to also call + :func:`check_one_way_stream`. + + """ + await check_one_way_stream(stream_maker, clogged_stream_maker) + + async def flipped_stream_maker(): + return reversed(await stream_maker()) + if clogged_stream_maker is not None: + async def flipped_clogged_stream_maker(): + return reversed(await clogged_stream_maker()) + else: + flipped_clogged_stream_maker = None + await check_one_way_stream( + flipped_stream_maker, flipped_clogged_stream_maker) + + async with _CloseBoth(await stream_maker()) as (s1, s2): + assert isinstance(s1, _abc.Stream) + assert isinstance(s2, _abc.Stream) + + # Duplex can be a bit tricky, might as well check it as well + DUPLEX_TEST_SIZE = 2 ** 20 + CHUNK_SIZE_MAX = 2 ** 14 + + r = random.Random(0) + i = r.getrandbits(8 * DUPLEX_TEST_SIZE) + test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little") + + async def sender(s, data, seed): + r = random.Random(seed) + m = memoryview(data) + while m: + chunk_size = r.randint(1, CHUNK_SIZE_MAX) + await s.send_all(m[:chunk_size]) + m = m[chunk_size:] + + async def receiver(s, data, seed): + r = random.Random(seed) + got = bytearray() + while len(got) < len(data): + chunk = await s.receive_some(r.randint(1, CHUNK_SIZE_MAX)) + assert chunk + got += chunk + assert got == data + + async with _core.open_nursery() as nursery: + nursery.spawn(sender, s1, test_data, 0) + nursery.spawn(sender, s2, test_data[::-1], 1) + nursery.spawn(receiver, s1, test_data[::-1], 2) + nursery.spawn(receiver, s2, test_data, 3) + + async def expect_receive_some_empty(): + assert await s2.receive_some(10) == b"" + await s2.graceful_close() + + async with _core.open_nursery() as nursery: + nursery.spawn(expect_receive_some_empty) + nursery.spawn(s1.graceful_close) + + +async def check_half_closeable_stream(stream_maker, clogged_stream_maker): + """Perform a number of generic tests on a custom half-closeable stream + implementation. + + This is similar to :func:`check_two_way_stream`, except that the maker + functions are expected to return objects that implement the + :class:`~trio.abc.HalfCloseableStream` interface. + + This function tests a *superset* of what :func:`check_two_way_stream` + checks – if you call this, then you don't need to also call + :func:`check_two_way_stream`. + + """ + await check_two_way_stream(stream_maker, clogged_stream_maker) + + async with _CloseBoth(await stream_maker()) as (s1, s2): + assert isinstance(s1, _abc.HalfCloseableStream) + assert isinstance(s2, _abc.HalfCloseableStream) + + async def send_x_then_eof(s): + await s.send_all(b"x") + with assert_yields(): + await s.send_eof() + + async def expect_x_then_eof(r): + await wait_all_tasks_blocked() + assert await r.receive_some(10) == b"x" + assert await r.receive_some(10) == b"" + + async with _core.open_nursery() as nursery: + nursery.spawn(send_x_then_eof, s1) + nursery.spawn(expect_x_then_eof, s2) + + # now sending is disallowed + with _assert_raises(_streams.ClosedStreamError): + await s1.send_all(b"y") + + # but we can do send_eof again + with assert_yields(): + await s1.send_eof() + + # and we can still send stuff back the other way + async with _core.open_nursery() as nursery: + nursery.spawn(send_x_then_eof, s2) + nursery.spawn(expect_x_then_eof, s1) + + if clogged_stream_maker is not None: + async with _CloseBoth(await clogged_stream_maker()) as (s1, s2): + # send_all and send_eof simultaneously is not ok + with _assert_raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + t = nursery.spawn(s1.send_all, b"x") + await wait_all_tasks_blocked() + assert t.result is None + nursery.spawn(s1.send_eof) + + async with _CloseBoth(await clogged_stream_maker()) as (s1, s2): + # wait_send_all_might_not_block and send_eof simultaneously is not + # ok either + with _assert_raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + t = nursery.spawn(s1.wait_send_all_might_not_block) + await wait_all_tasks_blocked() + assert t.result is None + nursery.spawn(s1.send_eof) + + +################################################################ +# In-memory streams +################################################################ + +class _UnboundedByteQueue: + def __init__(self): + self._data = bytearray() + self._closed = False + self._lot = _core.ParkingLot() + self._fetch_lock = _util.UnLock( + _core.ResourceBusyError, "another task is already fetching data") + + def close(self): + self._closed = True + self._lot.unpark_all() + + def put(self, data): + if self._closed: + raise _streams.ClosedStreamError("virtual connection closed") + self._data += data + self._lot.unpark_all() + + def _check_max_bytes(self, max_bytes): + if max_bytes is None: + return + max_bytes = operator.index(max_bytes) + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") + + def _get_impl(self, max_bytes): + assert self._closed or self._data + if max_bytes is None: + max_bytes = len(self._data) + if self._data: + chunk = self._data[:max_bytes] + del self._data[:max_bytes] + assert chunk + return chunk + else: + return bytearray() + + def get_nowait(self, max_bytes=None): + with self._fetch_lock.sync: + self._check_max_bytes(max_bytes) + if not self._closed and not self._data: + raise _core.WouldBlock + return self._get_impl(max_bytes) + + async def get(self, max_bytes=None): + async with self._fetch_lock: + self._check_max_bytes(max_bytes) + if not self._closed and not self._data: + await self._lot.park() + return self._get_impl(max_bytes) + + +class MemorySendStream(_abc.SendStream): + """An in-memory :class:`~trio.abc.SendStream`. + + Args: + send_all_hook: An async function, or None. Called from + :meth:`send_all`. Can do whatever you like. + wait_send_all_might_not_block_hook: An async function, or None. Called + from :meth:`wait_send_all_might_not_block`. Can do whatever you + like. + close_hook: A synchronous function, or None. Called from + :meth:`forceful_close`. Can do whatever you like. + + .. attribute:: send_all_hook + wait_send_all_might_not_block_hook + close_hook + + All of these hooks are also exposed as attributes on the object, and + you can change them at any time. + + """ + def __init__(self, + send_all_hook=None, + wait_send_all_might_not_block_hook=None, + close_hook=None): + self._lock = _util.UnLock( + _core.ResourceBusyError, "another task is using this stream") + self._outgoing = _UnboundedByteQueue() + self.send_all_hook = send_all_hook + self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook + self.close_hook = close_hook + + async def send_all(self, data): + """Places the given data into the object's internal buffer, and then + calls the :attr:`send_all_hook` (if any). + + """ + # The lock itself is a checkpoint, but then we also yield inside the + # lock to give ourselves a chance to detect buggy user code that calls + # this twice at the same time. + async with self._lock: + await _core.yield_briefly() + self._outgoing.put(data) + if self.send_all_hook is not None: + await self.send_all_hook() + + async def wait_send_all_might_not_block(self): + """Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and + then returns immediately. + + """ + # The lock itself is a checkpoint, but then we also yield inside the + # lock to give ourselves a chance to detect buggy user code that calls + # this twice at the same time. + async with self._lock: + await _core.yield_briefly() + # check for being closed: + self._outgoing.put(b"") + if self.wait_send_all_might_not_block_hook is not None: + await self.wait_send_all_might_not_block_hook() + + def forceful_close(self): + """Marks this stream as closed, and then calls the :attr:`close_hook` + (if any). + + """ + self._outgoing.close() + if self.close_hook is not None: + self.close_hook() + + async def get_data(self, max_bytes=None): + """Retrieves data from the internal buffer, blocking if necessary. + + Args: + max_bytes (int or None): The maximum amount of data to + retrieve. None (the default) means to retrieve all the data + that's present (but still blocks until at least one byte is + available). + + Returns: + If this stream has been closed, an empty bytearray. Otherwise, the + requested data. + + """ + return await self._outgoing.get(max_bytes) + + def get_data_nowait(self, max_bytes=None): + """Retrieves data from the internal buffer, but doesn't block. + + See :meth:`get_data` for details. + + Raises: + trio.WouldBlock: if no data is available to retrieve. + + """ + return self._outgoing.get_nowait(max_bytes) + + +class MemoryReceiveStream(_abc.ReceiveStream): + """An in-memory :class:`~trio.abc.ReceiveStream`. + + Args: + receive_some_hook: An async function, or None. Called from + :meth:`receive_some`. Can do whatever you like. + close_hook: A synchronous function, or None. Called from + :meth:`forceful_close`. Can do whatever you like. + + .. attribute:: receive_some_hook + close_hook + + Both hooks are also exposed as attributes on the object, and you can + change them at any time. + + """ + def __init__(self, receive_some_hook=None, close_hook=None): + self._lock = _util.UnLock( + _core.ResourceBusyError, "another task is using this stream") + self._incoming = _UnboundedByteQueue() + self._closed = False + self.receive_some_hook = receive_some_hook + self.close_hook = close_hook + + async def receive_some(self, max_bytes): + """Calls the :attr:`receive_some_hook` (if any), and then retrieves + data from the internal buffer, blocking if necessary. + + """ + # The lock itself is a checkpoint, but then we also yield inside the + # lock to give ourselves a chance to detect buggy user code that calls + # this twice at the same time. + async with self._lock: + await _core.yield_briefly() + if max_bytes is None: + raise TypeError("max_bytes must not be None") + if self._closed: + raise _streams.ClosedStreamError + if self.receive_some_hook is not None: + await self.receive_some_hook() + return await self._incoming.get(max_bytes) + + def forceful_close(self): + """Discards any pending data from the internal buffer, and marks this + stream as closed. + + """ + # discard any pending data + self._closed = True + try: + self._incoming.get_nowait() + except _core.WouldBlock: + pass + self._incoming.close() + if self.close_hook is not None: + self.close_hook() + + def put_data(self, data): + """Appends the given data to the internal buffer. + + """ + self._incoming.put(data) + + def put_eof(self): + """Adds an end-of-file marker to the internal buffer. + + """ + self._incoming.close() + + +def memory_stream_pump( + memory_send_stream, memory_recieve_stream, *, max_bytes=None): + """Take data out of the given :class:`MemorySendStream`'s internal buffer, + and put it into the given :class:`MemoryReceiveStream`'s internal buffer. + + Args: + memory_send_stream (MemorySendStream): The stream to get data from. + memory_recieve_stream (MemoryReceiveStream): The stream to put data into. + max_bytes (int or None): The maximum amount of data to transfer in this + call, or None to transfer all available data. + + Returns: + True if it successfully transferred some data, or False if there was no + data to transfer. + + This is used to implement :func:`memory_stream_one_way_pair` and + :func:`memory_stream_pair`; see the latter's docstring for an example + of how you might use it yourself. + + """ + try: + data = memory_send_stream.get_data_nowait(max_bytes) + except _core.WouldBlock: + return False + try: + if not data: + memory_recieve_stream.put_eof() + else: + memory_recieve_stream.put_data(data) + except _streams.ClosedStreamError: + raise _streams.BrokenStreamError("MemoryReceiveStream was closed") + return True + + +def memory_stream_one_way_pair(): + """Create a connected, pure-Python, unidirectional stream with infinite + buffering and flexible configuration options. + + You can think of this as being a no-operating-system-involved + trio-streamsified version of :func:`os.pipe` (except that :func:`os.pipe` + returns the streams in the wrong order – we follow the superior convention + that data flows from left to right). + + Returns: + A tuple (:class:`MemorySendStream`, :class:`MemoryReceiveStream`), where + the :class:`MemorySendStream` has its hooks set up so that it calls + :func:`memory_stream_pump` from its + :attr:`~MemorySendStream.send_all_hook` and + :attr:`~MemorySendStream.close_hook`. + + The end result is that data automatically flows from the + :class:`MemorySendStream` to the :class:`MemoryReceiveStream`. But you're + also free to rearrange things however you like. For example, you can + temporarily set the :attr:`~MemorySendStream.send_all_hook` to None if you + want to simulate a stall in data transmission. Or see + :func:`memory_stream_pair` for a more elaborate example. + + """ + send_stream = MemorySendStream() + recv_stream = MemoryReceiveStream() + def pump_from_send_stream_to_recv_stream(): + memory_stream_pump(send_stream, recv_stream) + async def async_pump_from_send_stream_to_recv_stream(): + pump_from_send_stream_to_recv_stream() + send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream + send_stream.close_hook = pump_from_send_stream_to_recv_stream + return send_stream, recv_stream + + +def _make_stapled_pair(one_way_pair): + pipe1_send, pipe1_recv = one_way_pair() + pipe2_send, pipe2_recv = one_way_pair() + stream1 = _streams.StapledStream(pipe1_send, pipe2_recv) + stream2 = _streams.StapledStream(pipe2_send, pipe1_recv) + return stream1, stream2 + + +def memory_stream_pair(): + """Create a connected, pure-Python, bidirectional stream with infinite + buffering and flexible configuration options. + + This is a convenience function that creates two one-way streams using + :func:`memory_stream_one_way_pair`, and then uses + :class:`~trio.StapledStream` to combine them into a single bidirectional + stream. + + This is like a no-operating-system-involved, trio-streamsified version of + :func:`socket.socketpair`. + + Returns: + A pair of :class:`~trio.StapledStream` objects that are connected so + that data automatically flows from one to the other in both directions. + + After creating a stream pair, you can send data back and forth, which is + enough for simple tests:: + + left, right = memory_stream_pair() + await left.send_all(b"123") + assert await right.receive_some(10) == b"123" + await right.send_all(b"456") + assert await left.receive_some(10) == b"456" + + But if you read the docs for :class:`~trio.StapledStream` and + :func:`memory_stream_one_way_pair`, you'll see that all the pieces + involved in wiring this up are public APIs, so you can adjust to suit the + requirements of your tests. For example, here's how to tweak a stream so + that data flowing from left to right trickles in one byte at a time (but + data flowing from right to left proceeds at full speed):: + + left, right = memory_stream_pair() + async def trickle(): + # left is a StapledStream, and left.send_stream is a MemorySendStream + # right is a StapledStream, and right.recv_stream is a MemoryReceiveStream + while memory_stream_pump(left.send_stream, right.recv_stream, max_byes=1): + # Pause between each byte + await trio.sleep(1) + # Normally this send_all_hook calls memory_stream_pump directly without + # passing in a max_bytes. We replace it with our custom version: + left.send_stream.send_all_hook = trickle + + And here's a simple test using our modified stream objects:: + + async def sender(): + await left.send_all(b"12345") + await left.send_eof() + + async def receiver(): + while True: + data = await right.receive_some(10) + if data == b"": + return + print(data) + + async with trio.open_nursery() as nursery: + nursery.spawn(sender) + nursery.spawn(receiver) + + By default, this will print ``b"12345"`` and then immediately exit; with + our trickle stream it instead sleeps 1 second, then prints ``b"1"``, then + sleeps 1 second, then prints ``b"2"``, etc. + + Pro-tip: you can insert sleep calls (like in our example above) to + manipulate the flow of data across tasks... and then use + :class:`MockClock` and its :attr:`~MockClock.autojump_threshold` + functionality to keep your test suite running quickly. + + If you want to stress test a protocol implementation, one nice trick is to + use the :mod:`random` module (preferably with a fixed seed) to move random + numbers of bytes at a time, and insert random sleeps in between them. You + can also set up a custom :attr:`~MemoryReceiveStream.receive_some_hook` if + you want to manipulate things on the receiving side, and not just the + sending side. + + """ + return _make_stapled_pair(memory_stream_one_way_pair) + + +class _LockstepByteQueue: + def __init__(self): + self._data = bytearray() + self._sender_closed = False + self._receiver_closed = False + self._receiver_waiting = False + self._waiters = _core.ParkingLot() + self._send_lock = _util.UnLock( + _core.ResourceBusyError, "another task is already sending") + self._receive_lock = _util.UnLock( + _core.ResourceBusyError, "another task is already receiving") + + def _something_happened(self): + self._waiters.unpark_all() + + async def _wait_for(self, fn): + while not fn(): + await self._waiters.park() + + def close_sender(self): + # close while send_all is in progress is undefined + self._sender_closed = True + self._something_happened() + + def close_receiver(self): + self._receiver_closed = True + self._something_happened() + + async def send_all(self, data): + async with self._send_lock: + if self._sender_closed: + raise _streams.ClosedStreamError + if self._receiver_closed: + raise _streams.BrokenStreamError + assert not self._data + self._data += data + self._something_happened() + await self._wait_for( + lambda: not self._data or self._receiver_closed) + if self._data and self._receiver_closed: + raise _streams.BrokenStreamError + if not self._data: + return + + async def wait_send_all_might_not_block(self): + async with self._send_lock: + if self._sender_closed: + raise _streams.ClosedStreamError + if self._receiver_closed: + return + await self._wait_for( + lambda: self._receiver_waiting or self._receiver_closed) + + async def receive_some(self, max_bytes): + async with self._receive_lock: + # Argument validation + max_bytes = operator.index(max_bytes) + if max_bytes < 1: + raise ValueError("max_bytes must be >= 1") + # State validation + if self._receiver_closed: + raise _streams.ClosedStreamError + # Wake wait_send_all_might_not_block and wait for data + self._receiver_waiting = True + self._something_happened() + try: + await self._wait_for(lambda: self._data or self._sender_closed) + finally: + self._receiver_waiting = False + # Get data, possibly waking send_all + if self._data: + got = self._data[:max_bytes] + del self._data[:max_bytes] + self._something_happened() + return got + else: + assert self._sender_closed + return b"" + + +class _LockstepSendStream(_abc.SendStream): + def __init__(self, lbq): + self._lbq = lbq + + def forceful_close(self): + self._lbq.close_sender() + + async def send_all(self, data): + await self._lbq.send_all(data) + + async def wait_send_all_might_not_block(self): + await self._lbq.wait_send_all_might_not_block() + + +class _LockstepReceiveStream(_abc.ReceiveStream): + def __init__(self, lbq): + self._lbq = lbq + + def forceful_close(self): + self._lbq.close_receiver() + + async def receive_some(self, max_bytes): + return await self._lbq.receive_some(max_bytes) + + +def lockstep_stream_one_way_pair(): + """Create a connected, pure Python, unidirectional stream where data flows + in lockstep. + + Returns: + A tuple + (:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`). + + This stream has *absolutely no* buffering. Each call to + :meth:`~trio.abc.SendStream.send_all` will block until all the given data + has been returned by a call to + :meth:`~trio.abc.ReceiveStream.receive_some`. + + This can be useful for testing flow control mechanisms in an extreme case, + or for setting up "clogged" streams to use with + :func:`check_one_way_stream` and friends. + + """ + + lbq = _LockstepByteQueue() + return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq) + + +def lockstep_stream_pair(): + """Create a connected, pure-Python, bidirectional stream where data flows + in lockstep. + + Returns: + A tuple (:class:`~trio.StapledStream`, :class:`~trio.StapledStream`). + + This is a convenience function that creates two one-way streams using + :func:`lockstep_stream_one_way_pair`, and then uses + :class:`~trio.StapledStream` to combine them into a single bidirectional + stream. + + """ + return _make_stapled_pair(lockstep_stream_one_way_pair) diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py new file mode 100644 index 0000000000..4c0392ca42 --- /dev/null +++ b/trio/tests/test_abc.py @@ -0,0 +1,38 @@ +import pytest + +import attr + +from ..testing import assert_yields +from .. import abc as tabc + +async def test_AsyncResource_defaults(): + @attr.s + class MyAR(tabc.AsyncResource): + record = attr.ib(default=attr.Factory(list)) + + def forceful_close(self): + self.record.append("fc") + + async with MyAR() as myar: + assert isinstance(myar, MyAR) + assert myar.record == [] + + assert myar.record == ["fc"] + + with assert_yields(): + await myar.graceful_close() + assert myar.record == ["fc", "fc"] + + @attr.s + class BadAR(tabc.AsyncResource): + record = attr.ib(default=attr.Factory(list)) + + def forceful_close(self): + self.record.append("boom") + raise KeyError + + badar = BadAR() + with pytest.raises(KeyError): + with assert_yields(): + await badar.graceful_close() + assert badar.record == ["boom"] diff --git a/trio/tests/test_network.py b/trio/tests/test_network.py new file mode 100644 index 0000000000..89c51ce6b5 --- /dev/null +++ b/trio/tests/test_network.py @@ -0,0 +1,77 @@ +import pytest + +import socket as stdlib_socket + +from .. import _core +from ..testing import check_half_closeable_stream, wait_all_tasks_blocked +from .._network import * +from .. import socket as tsocket + +async def test_SocketStream_basics(): + # stdlib socket bad (even if connected) + a, b = stdlib_socket.socketpair() + with a, b: + with pytest.raises(TypeError): + SocketStream(a) + + # DGRAM socket bad + with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock: + with pytest.raises(ValueError): + SocketStream(sock) + + # disconnected socket bad + with tsocket.socket() as sock: + with pytest.raises(ValueError): + SocketStream(sock) + + a, b = tsocket.socketpair() + with a, b: + s = SocketStream(a) + assert s.socket is a + + # Use a real, connected socket to test socket options, because + # socketpair() might give us a unix socket that doesn't support any of + # these options + with tsocket.socket() as listen_sock: + listen_sock.bind(("127.0.0.1", 0)) + listen_sock.listen(1) + with tsocket.socket() as client_sock: + await client_sock.connect(listen_sock.getsockname()) + + s = SocketStream(client_sock) + + # TCP_NODELAY enabled by default + assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY) + # We can disable it though + s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) + assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY) + + b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1) + assert isinstance(b, bytes) + + +async def fill_stream(s): + async def sender(): + while True: + await s.send_all(b"x" * 10000) + + async def waiter(nursery): + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + async with _core.open_nursery() as nursery: + nursery.spawn(sender) + nursery.spawn(waiter, nursery) + + +async def test_SocketStream_and_socket_stream_pair_generic(): + async def stream_maker(): + return socket_stream_pair() + + async def clogged_stream_maker(): + left, right = socket_stream_pair() + await fill_stream(left) + await fill_stream(right) + return left, right + + await check_half_closeable_stream(stream_maker, clogged_stream_maker) diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index c405f7eae7..e043ee6ab1 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -505,26 +505,6 @@ async def test_SocketType_connect_paths(): with pytest.raises(_core.Cancelled): await sock.connect(("127.0.0.1", 80)) - # Handling InterruptedError - class InterruptySocket(stdlib_socket.socket): - def connect(self, *args, **kwargs): - if not hasattr(self, "_connect_count"): - self._connect_count = 0 - self._connect_count += 1 - if self._connect_count < 3: - raise InterruptedError - else: - return super().connect(*args, **kwargs) - with tsocket.socket() as sock, tsocket.socket() as listener: - listener.bind(("127.0.0.1", 0)) - listener.listen() - # Swap in our weird subclass under the trio.socket.SocketType's nose - sock._sock.close() - sock._sock = InterruptySocket() - with assert_yields(): - await sock.connect(listener.getsockname()) - assert sock.getpeername() == listener.getsockname() - # Cancelled in between the connect() call and the connect completing with _core.open_cancel_scope() as cancel_scope: with tsocket.socket() as sock, tsocket.socket() as listener: @@ -550,14 +530,17 @@ def connect(self, *args, **kwargs): assert sock.fileno() == -1 # Failed connect (hopefully after raising BlockingIOError) - with tsocket.socket() as sock, tsocket.socket() as non_listener: - # Claim an unused port - non_listener.bind(("127.0.0.1", 0)) - # ...but don't call listen, so we're guaranteed that connect attempts - # to it will fail. + with tsocket.socket() as sock: with assert_yields(): with pytest.raises(OSError): - await sock.connect(non_listener.getsockname()) + # TCP port 2 is not assigned. Pretty sure nothing will be + # listening there. (We used to bind a port and then *not* call + # listen() to ensure nothing was listening there, but it turns + # out on MacOS if you do this it takes 30 seconds for the + # connect to fail. Really. Also if you use a non-routable + # address. This way fails instantly though. As long as nothing + # is listening on port 2.) + await sock.connect(("127.0.0.1", 2)) async def test_send_recv_variants(): @@ -659,48 +642,45 @@ async def test_SocketType_sendall(): # Check a sendall that has to be split into multiple parts (on most # platforms... on Windows every send() either succeeds or fails as a # whole) - async with _core.open_nursery() as nursery: - send_task = nursery.spawn(a.sendall, b"x" * BIG) + async def sender(): + data = bytearray(BIG) + await a.sendall(data) + # sendall uses memoryviews internally, which temporarily "lock" + # the object they view. If it doesn't clean them up properly, then + # some bytearray operations might raise an error afterwards, which + # would be a pretty weird and annoying side-effect to spring on + # users. So test that this doesn't happen, by forcing the + # bytearray's underlying buffer to be realloc'ed: + data += bytes(BIG) + # (Note: the above line of code doesn't do a very good job at + # testing anything, because: + # - on CPython, the refcount GC generally cleans up memoryviews + # for us even if we're sloppy. + # - on PyPy3, at least as of 5.7.0, the memoryview code and the + # bytearray code conspire so that resizing never fails – if + # resizing forces the bytearray's internal buffer to move, then + # all memoryview references are automagically updated (!!). + # See: + # https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227 + # But I'm leaving the test here in hopes that if this ever changes + # and we break our implementation of sendall, then we'll get some + # early warning...) + + async def receiver(): + # Make sure the sender fills up the kernel buffers and blocks await wait_all_tasks_blocked() nbytes = 0 while nbytes < BIG: nbytes += len(await b.recv(BIG)) - assert send_task.result is not None assert nbytes == BIG - with pytest.raises(BlockingIOError): - b._sock.recv(1) - a, b = tsocket.socketpair() - with a, b: - # Cancel half-way through async with _core.open_nursery() as nursery: - sent_complete = 0 - async def sendall_until_cancelled(): - nonlocal sent_complete - # Need to loop to make sure that we actually do block on - # Windows - while True: - await a.sendall(b"x" * BIG) - sent_complete += BIG - send_task = nursery.spawn(sendall_until_cancelled) - await wait_all_tasks_blocked() - nursery.cancel_scope.cancel() - assert type(send_task.result) is _core.Error - assert isinstance(send_task.result.error, _core.Cancelled) - sent_partial = send_task.result.error.partial_result.bytes_sent - a.close() - sent_total = 0 - while True: - got = len(await b.recv(BIG)) - if not got: - break - sent_total += got - assert sent_complete + sent_partial == sent_total + nursery.spawn(sender) + nursery.spawn(receiver) - a, b = tsocket.socketpair() - with a, b: - # A different error - a.close() - with pytest.raises(OSError) as excinfo: - await a.sendall(b"x") - assert excinfo.value.partial_result.bytes_sent == 0 + # We know that we received BIG bytes of NULs so far. Make sure that + # was all the data in there. + await a.sendall(b"e") + assert await b.recv(10) == b"e" + a.shutdown(tsocket.SHUT_WR) + assert await b.recv(10) == b"" diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py new file mode 100644 index 0000000000..0193dd4b93 --- /dev/null +++ b/trio/tests/test_ssl.py @@ -0,0 +1,1036 @@ +import pytest + +from pathlib import Path +import threading +import socket as stdlib_socket +import ssl as stdlib_ssl +from contextlib import contextmanager + +from OpenSSL import SSL + +import trio +from .. import _core +from .. import _network +from .._streams import BrokenStreamError, ClosedStreamError +from .. import ssl as tssl +from .. import socket as tsocket +from .._util import UnLock + +from .._core.tests.tutil import slow + +from ..testing import ( + assert_yields, Sequencer, memory_stream_pair, lockstep_stream_pair, + check_two_way_stream, +) + +ASSETS_DIR = Path(__file__).parent / "test_ssl_certs" +CA = str(ASSETS_DIR / "trio-test-CA.pem") +CERT1 = str(ASSETS_DIR / "trio-test-1.pem") + + +# We have two different kinds of echo server fixtures we use for testing. The +# first is a real server written using the stdlib ssl module and blocking +# sockets. It runs in a thread and we talk to it over a real socketpair(), to +# validate interoperability in a semi-realistic setting. +# +# The second is a very weird virtual echo server that lives inside a custom +# Stream class. It lives entirely inside the Python object space; there are no +# operating system calls in it at all. No threads, no I/O, nothing. It's +# 'send_all' call takes encrypted data from a client and feeds it directly into +# the server-side TLS state engine to decrypt, then takes that data, feeds it +# back through to get the encrypted response, and returns it from 'receive_some'. This +# gives us full control and reproducibility. This server is written using +# PyOpenSSL, so that we can trigger renegotiations on demand. It also allows +# us to insert random (virtual) delays, to really exercise all the weird paths +# in SSLStream's state engine. +# +# Both present a certificate for "trio-test-1.example.org", that's signed by +# trio-test-CA.pem, an extremely trustworthy CA. + + +SERVER_CTX = stdlib_ssl.create_default_context( + stdlib_ssl.Purpose.CLIENT_AUTH, +) +SERVER_CTX.load_cert_chain(CERT1) + +CLIENT_CTX = stdlib_ssl.create_default_context(cafile=CA) + +# The blocking socket server. +def ssl_echo_serve_sync(sock, *, expect_fail=False): + try: + wrapped = SERVER_CTX.wrap_socket(sock, server_side=True) + wrapped.do_handshake() + while True: + data = wrapped.recv(4096) + if not data: + # graceful shutdown + wrapped.unwrap() + return + wrapped.sendall(data) + except Exception as exc: + if expect_fail: + print("ssl_echo_serve_sync got error as expected:", exc) + else: # pragma: no cover + raise + else: + if expect_fail: # pragma: no cover + print("failed to fail?!") + + +# Fixture that gives a raw socket connected to a trio-test-1 echo server +# (running in a thread). Useful for testing making connections with different +# SSLContexts. +@contextmanager +def ssl_echo_server_raw(**kwargs): + a, b = stdlib_socket.socketpair() + with a, b: + t = threading.Thread( + target=ssl_echo_serve_sync, + args=(b,), + kwargs=kwargs, + ) + t.start() + + yield _network.SocketStream(tsocket.from_stdlib_socket(a)) + + # exiting the context manager closes the sockets, which should force the + # thread to shut down (possibly with an error) + t.join() + + +# Fixture that gives a properly set up SSLStream connected to a trio-test-1 +# echo server (running in a thread) +@contextmanager +def ssl_echo_server(**kwargs): + with ssl_echo_server_raw(**kwargs) as sock: + yield tssl.SSLStream( + sock, CLIENT_CTX, server_hostname="trio-test-1.example.org") + + +# The weird in-memory server ... thing. +# Doesn't inherit from Stream because I left out the methods that we don't +# actually need. +class PyOpenSSLEchoStream: + def __init__(self, sleeper=None): + ctx = SSL.Context(SSL.SSLv23_METHOD) + # TLS 1.3 removes renegotiation support. Which is great for them, but + # we still have to support versions before that, and that means we + # need to test renegotation support, which means we need to force this + # to use a lower version where this test server can trigger + # renegotiations. Of course TLS 1.3 support isn't released yet, but + # I'm told that this will work once it is. (And once it is we can + # remove the pragma: no cover too.) Alternatively, once we drop + # support for CPython 3.5 on MacOS, then we could switch to using + # TLSv1_2_METHOD. + # + # Discussion: https://github.com/pyca/pyopenssl/issues/624 + if hasattr(SSL, "OP_NO_TLSv1_3"): # pragma: no cover + ctx.set_options(SSL.OP_NO_TLSv1_3) + # Unfortunately there's currently no way to say "use 1.3 or worse", we + # can only disable specific versions. And if the two sides start + # negotiating 1.4 at some point in the future, it *might* mean that + # our tests silently stop working properly. So the next line is a + # tripwire to remind us we need to revisit this stuff in 5 years or + # whatever when the next TLS version is released: + assert not hasattr(SSL, "OP_NO_TLSv1_4") + ctx.use_certificate_file(CERT1) + ctx.use_privatekey_file(CERT1) + self._conn = SSL.Connection(ctx, None) + self._conn.set_accept_state() + self._lot = _core.ParkingLot() + self._pending_cleartext = bytearray() + + self._send_all_mutex = UnLock( + _core.ResourceBusyError, + "simultaneous calls to PyOpenSSLEchoStream.send_all") + self._receive_some_mutex = UnLock( + _core.ResourceBusyError, + "simultaneous calls to PyOpenSSLEchoStream.receive_some") + + if sleeper is None: + async def no_op_sleeper(_): + return + self.sleeper = no_op_sleeper + else: + self.sleeper = sleeper + + def forceful_close(self): + self._conn.bio_shutdown() + + async def graceful_close(self): + self.forceful_close() + + def renegotiate_pending(self): + return self._conn.renegotiate_pending() + + def renegotiate(self): + # Returns false if a renegotation is already in progress, meaning + # nothing happens. + assert self._conn.renegotiate() + + async def wait_send_all_might_not_block(self): + async with self._send_all_mutex: + await _core.yield_briefly() + await self.sleeper("wait_send_all_might_not_block") + + async def send_all(self, data): + print(" --> transport_stream.send_all") + async with self._send_all_mutex: + await _core.yield_briefly() + await self.sleeper("send_all") + self._conn.bio_write(data) + while True: + await self.sleeper("send_all") + try: + data = self._conn.recv(1) + except SSL.ZeroReturnError: + self._conn.shutdown() + print("renegotiations:", self._conn.total_renegotiations()) + break + except SSL.WantReadError: + break + else: + self._pending_cleartext += data + self._lot.unpark_all() + await self.sleeper("send_all") + print(" <-- transport_stream.send_all finished") + + async def receive_some(self, nbytes): + print(" --> transport_stream.receive_some") + async with self._receive_some_mutex: + try: + await _core.yield_briefly() + while True: + await self.sleeper("receive_some") + try: + return self._conn.bio_read(nbytes) + except SSL.WantReadError: + # No data in our ciphertext buffer; try to generate + # some. + if self._pending_cleartext: + # We have some cleartext; maybe we can encrypt it + # and then return it. + print(" trying", self._pending_cleartext) + try: + # PyOpenSSL bug: doesn't accept bytearray + # https://github.com/pyca/pyopenssl/issues/621 + next_byte = self._pending_cleartext[0:1] + self._conn.send(bytes(next_byte)) + # Apparently this next bit never gets hit in the + # test suite, but it's not an interesting omission + # so let's pragma it. + except SSL.WantReadError: # pragma: no cover + # We didn't manage to send the cleartext (and + # in particular we better leave it there to + # try again, due to openssl's retry + # semantics), but it's possible we pushed a + # renegotiation forward and *now* we have data + # to send. + try: + return self._conn.bio_read(nbytes) + except SSL.WantReadError: + # Nope. We're just going to have to wait + # for someone to call send_all() to give + # use more data. + print("parking (a)") + await self._lot.park() + else: + # We successfully sent that byte, so we don't + # have to again. + del self._pending_cleartext[0:1] + else: + # no pending cleartext; nothing to do but wait for + # someone to call send_all + print("parking (b)") + await self._lot.park() + finally: + await self.sleeper("receive_some") + print(" <-- transport_stream.receive_some finished") + + +async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): + # Make sure that PyOpenSSLEchoStream complains if two tasks call send_all + # at the same time, or ditto for receive_some. The tricky cases where SSLStream + # might accidentally do this are during renegotation, which we test using + # PyOpenSSLEchoStream, so this makes sure that if we do have a bug then + # PyOpenSSLEchoStream will notice and complain. + + s = PyOpenSSLEchoStream() + with pytest.raises(_core.ResourceBusyError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(s.send_all, b"x") + nursery.spawn(s.send_all, b"x") + assert "simultaneous" in str(excinfo.value) + + s = PyOpenSSLEchoStream() + with pytest.raises(_core.ResourceBusyError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(s.send_all, b"x") + nursery.spawn(s.wait_send_all_might_not_block) + assert "simultaneous" in str(excinfo.value) + + s = PyOpenSSLEchoStream() + with pytest.raises(_core.ResourceBusyError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(s.wait_send_all_might_not_block) + nursery.spawn(s.wait_send_all_might_not_block) + assert "simultaneous" in str(excinfo.value) + + s = PyOpenSSLEchoStream() + with pytest.raises(_core.ResourceBusyError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(s.receive_some, 1) + nursery.spawn(s.receive_some, 1) + assert "simultaneous" in str(excinfo.value) + + +@contextmanager +def virtual_ssl_echo_server(**kwargs): + fakesock = PyOpenSSLEchoStream(**kwargs) + yield tssl.SSLStream( + fakesock, CLIENT_CTX, server_hostname="trio-test-1.example.org") + + +def ssl_wrap_pair(client_transport, server_transport, + *, client_kwargs={}, server_kwargs={}): + client_ssl = tssl.SSLStream( + client_transport, CLIENT_CTX, + server_hostname="trio-test-1.example.org", **client_kwargs) + server_ssl = tssl.SSLStream( + server_transport, SERVER_CTX, server_side=True, **server_kwargs) + return client_ssl, server_ssl + +def ssl_memory_stream_pair(**kwargs): + client_transport, server_transport = memory_stream_pair() + return ssl_wrap_pair(client_transport, server_transport, **kwargs) + +def ssl_lockstep_stream_pair(**kwargs): + client_transport, server_transport = lockstep_stream_pair() + return ssl_wrap_pair(client_transport, server_transport, **kwargs) + + +def test_exports(): + # Just a quick check to make sure _reexport isn't totally broken + assert hasattr(tssl, "SSLError") + assert "SSLError" in tssl.__all__ + + assert hasattr(tssl, "Purpose") + assert "Purpose" in tssl.__all__ + + # Intentionally omitted + assert not hasattr(tssl, "SSLContext") + + +# Simple smoke test for handshake/send/receive/shutdown talking to a +# synchronous server, plus make sure that we do the bare minimum of +# certificate checking (even though this is really Python's responsibility) +async def test_ssl_client_basics(): + # Everything OK + with ssl_echo_server() as s: + assert not s.server_side + await s.send_all(b"x") + assert await s.receive_some(1) == b"x" + await s.graceful_close() + + # Didn't configure the CA file, should fail + with ssl_echo_server_raw(expect_fail=True) as sock: + client_ctx = stdlib_ssl.create_default_context() + s = tssl.SSLStream( + sock, client_ctx, server_hostname="trio-test-1.example.org") + assert not s.server_side + with pytest.raises(BrokenStreamError) as excinfo: + await s.send_all(b"x") + assert isinstance(excinfo.value.__cause__, tssl.SSLError) + + # Trusted CA, but wrong host name + with ssl_echo_server_raw(expect_fail=True) as sock: + s = tssl.SSLStream( + sock, CLIENT_CTX, server_hostname="trio-test-2.example.org") + assert not s.server_side + with pytest.raises(BrokenStreamError) as excinfo: + await s.send_all(b"x") + assert isinstance(excinfo.value.__cause__, tssl.CertificateError) + + +async def test_ssl_server_basics(): + a, b = stdlib_socket.socketpair() + with a, b: + server_sock = tsocket.from_stdlib_socket(b) + server_transport = tssl.SSLStream( + _network.SocketStream(server_sock), SERVER_CTX, server_side=True) + assert server_transport.server_side + + def client(): + client_sock = CLIENT_CTX.wrap_socket( + a, server_hostname="trio-test-1.example.org") + client_sock.sendall(b"x") + assert client_sock.recv(1) == b"y" + client_sock.sendall(b"z") + client_sock.unwrap() + t = threading.Thread(target=client) + t.start() + + assert await server_transport.receive_some(1) == b"x" + await server_transport.send_all(b"y") + assert await server_transport.receive_some(1) == b"z" + assert await server_transport.receive_some(1) == b"" + await server_transport.graceful_close() + + t.join() + + +async def test_attributes(): + with ssl_echo_server_raw(expect_fail=True) as sock: + good_ctx = CLIENT_CTX + bad_ctx = stdlib_ssl.create_default_context() + s = tssl.SSLStream( + sock, good_ctx, server_hostname="trio-test-1.example.org") + + assert s.transport_stream is sock + + # Forwarded attribute getting + assert s.context is good_ctx + assert s.server_side == False + assert s.server_hostname == "trio-test-1.example.org" + with pytest.raises(AttributeError): + s.asfdasdfsa + + # __dir__ + assert "transport_stream" in dir(s) + assert "context" in dir(s) + + # Setting the attribute goes through to the underlying object + + # most attributes on SSLObject are read-only + with pytest.raises(AttributeError): + s.server_side = True + with pytest.raises(AttributeError): + s.server_hostname = "asdf" + + # but .context is *not*. Check that we forward attribute setting by + # making sure that after we set the bad context our handshake indeed + # fails: + s.context = bad_ctx + assert s.context is bad_ctx + with pytest.raises(BrokenStreamError) as excinfo: + await s.do_handshake() + assert isinstance(excinfo.value.__cause__, tssl.SSLError) + + +# Note: this test fails horribly if we force TLS 1.2 and trigger a +# renegotiation at the beginning (e.g. by switching to the pyopenssl +# server). Usually the client crashes in SSLObject.write with "UNEXPECTED +# RECORD"; sometimes we get something more exotic like a SyscallError. This is +# odd because openssl isn't doing any syscalls, but so it goes. After lots of +# websearching I'm pretty sure this is due to a bug in OpenSSL, where it just +# can't reliably handle full-duplex communication combined with +# renegotiation. Nice, eh? +# +# https://rt.openssl.org/Ticket/Display.html?id=3712 +# https://rt.openssl.org/Ticket/Display.html?id=2481 +# http://openssl.6102.n7.nabble.com/TLS-renegotiation-failure-on-receiving-application-data-during-handshake-td48127.html +# https://stackoverflow.com/questions/18728355/ssl-renegotiation-with-full-duplex-socket-communication +# +# In some variants of this test (maybe only against the java server?) I've +# also seen cases where our send_all blocks waiting to write, and then our receive_some +# also blocks waiting to write, and they never wake up again. It looks like +# some kind of deadlock. I suspect there may be an issue where we've filled up +# the send buffers, and the remote side is trying to handle the renegotiation +# from inside a write() call, so it has a problem: there's all this application +# data clogging up the pipe, but it can't process and return it to the +# application because it's in write(), and it doesn't want to buffer infinite +# amounts of data, and... actually I guess those are the only two choices. +# +# NSS even documents that you shouldn't try to do a renegotiation except when +# the connection is idle: +# +# https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/SSL_functions/sslfnc.html#1061582 +# +# I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it... + +async def test_full_duplex_basics(): + CHUNKS = 30 + CHUNK_SIZE = 32768 + EXPECTED = CHUNKS * CHUNK_SIZE + + sent = bytearray() + received = bytearray() + + async def sender(s): + nonlocal sent + for i in range(CHUNKS): + print(i) + chunk = bytes([i] * CHUNK_SIZE) + sent += chunk + await s.send_all(chunk) + + async def receiver(s): + nonlocal received + while len(received) < EXPECTED: + chunk = await s.receive_some(CHUNK_SIZE // 2) + received += chunk + + with ssl_echo_server() as s: + async with _core.open_nursery() as nursery: + nursery.spawn(sender, s) + nursery.spawn(receiver, s) + # And let's have some doing handshakes too, everyone + # simultaneously + nursery.spawn(s.do_handshake) + nursery.spawn(s.do_handshake) + + await s.graceful_close() + + assert len(sent) == len(received) == EXPECTED + assert sent == received + + +async def test_renegotiation_simple(): + with virtual_ssl_echo_server() as s: + await s.do_handshake() + + s.transport_stream.renegotiate() + await s.send_all(b"a") + assert await s.receive_some(1) == b"a" + + # Have to send some more data back and forth to make sure the + # renegotiation is finished before shutting down the + # connection... otherwise openssl raises an error. I think this is a + # bug in openssl but what can ya do. + await s.send_all(b"b") + assert await s.receive_some(1) == b"b" + + await s.graceful_close() + + +@slow +async def test_renegotiation_randomized(mock_clock): + # The only blocking things in this function are our random sleeps, so 0 is + # a good threshold. + mock_clock.autojump_threshold = 0 + + import random + r = random.Random(0) + + async def sleeper(_): + await trio.sleep(r.uniform(0, 10)) + + async def clear(): + while s.transport_stream.renegotiate_pending(): + with assert_yields(): + await send(b"-") + with assert_yields(): + await expect(b"-") + print("-- clear --") + + async def send(byte): + await s.transport_stream.sleeper("outer send") + print("calling SSLStream.send_all", byte) + with assert_yields(): + await s.send_all(byte) + + async def expect(expected): + await s.transport_stream.sleeper("expect") + print("calling SSLStream.receive_some, expecting", expected) + assert len(expected) == 1 + with assert_yields(): + assert await s.receive_some(1) == expected + + with virtual_ssl_echo_server(sleeper=sleeper) as s: + await s.do_handshake() + + await send(b"a") + s.transport_stream.renegotiate() + await expect(b"a") + + await clear() + + for i in range(100): + b1 = bytes([i % 0xff]) + b2 = bytes([(2 * i) % 0xff]) + s.transport_stream.renegotiate() + async with _core.open_nursery() as nursery: + nursery.spawn(send, b1) + nursery.spawn(expect, b1) + async with _core.open_nursery() as nursery: + nursery.spawn(expect, b2) + nursery.spawn(send, b2) + await clear() + + for i in range(100): + b1 = bytes([i % 0xff]) + b2 = bytes([(2 * i) % 0xff]) + await send(b1) + s.transport_stream.renegotiate() + await expect(b1) + async with _core.open_nursery() as nursery: + nursery.spawn(expect, b2) + nursery.spawn(send, b2) + await clear() + + # Checking that wait_send_all_might_not_block and receive_some don't + # conflict: + + # 1) Set up a situation where expect (receive_some) is blocked sending, + # and wait_send_all_might_not_block comes in. + + # Our receive_some() call will get stuck when it hits send_all + async def sleeper_with_slow_send_all(method): + if method == "send_all": + await trio.sleep(100000) + + # And our wait_send_all_might_not_block call will give it time to get + # stuck, and then start + async def sleep_then_wait_writable(): + await trio.sleep(1000) + await s.wait_send_all_might_not_block() + + with virtual_ssl_echo_server(sleeper=sleeper_with_slow_send_all) as s: + await send(b"x") + s.transport_stream.renegotiate() + async with _core.open_nursery() as nursery: + nursery.spawn(expect, b"x") + nursery.spawn(sleep_then_wait_writable) + + await clear() + + await s.graceful_close() + + # 2) Same, but now wait_send_all_might_not_block is stuck when + # receive_some tries to send. + + async def sleeper_with_slow_wait_writable_and_expect(method): + if method == "wait_send_all_might_not_block": + await trio.sleep(100000) + elif method == "expect": + await trio.sleep(1000) + + with virtual_ssl_echo_server( + sleeper=sleeper_with_slow_wait_writable_and_expect) as s: + await send(b"x") + s.transport_stream.renegotiate() + async with _core.open_nursery() as nursery: + nursery.spawn(expect, b"x") + nursery.spawn(s.wait_send_all_might_not_block) + + await clear() + + await s.graceful_close() + + +async def test_resource_busy_errors(): + async def do_send_all(): + with assert_yields(): + await s.send_all(b"x") + + async def do_receive_some(): + with assert_yields(): + await s.receive_some(1) + + async def do_wait_send_all_might_not_block(): + with assert_yields(): + await s.wait_send_all_might_not_block() + + s, _ = ssl_lockstep_stream_pair() + with pytest.raises(_core.ResourceBusyError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(do_send_all) + nursery.spawn(do_send_all) + assert "another task" in str(excinfo.value) + + s, _ = ssl_lockstep_stream_pair() + with pytest.raises(_core.ResourceBusyError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(do_receive_some) + nursery.spawn(do_receive_some) + assert "another task" in str(excinfo.value) + + s, _ = ssl_lockstep_stream_pair() + with pytest.raises(_core.ResourceBusyError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(do_send_all) + nursery.spawn(do_wait_send_all_might_not_block) + assert "another task" in str(excinfo.value) + + s, _ = ssl_lockstep_stream_pair() + with pytest.raises(_core.ResourceBusyError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(do_wait_send_all_might_not_block) + nursery.spawn(do_wait_send_all_might_not_block) + assert "another task" in str(excinfo.value) + + +async def test_wait_writable_calls_underlying_wait_writable(): + record = [] + class NotAStream: + async def wait_send_all_might_not_block(self): + record.append("ok") + ctx = stdlib_ssl.create_default_context() + s = tssl.SSLStream(NotAStream(), ctx, server_hostname="x") + await s.wait_send_all_might_not_block() + assert record == ["ok"] + + +async def test_checkpoints(): + with ssl_echo_server() as s: + with assert_yields(): + await s.do_handshake() + with assert_yields(): + await s.do_handshake() + with assert_yields(): + await s.wait_send_all_might_not_block() + with assert_yields(): + await s.send_all(b"xxx") + with assert_yields(): + await s.receive_some(1) + # These receive_some's in theory could return immediately, because the + # "xxx" was sent in a single record and after the first + # receive_some(1) the rest are sitting inside the SSLObject's internal + # buffers. + with assert_yields(): + await s.receive_some(1) + with assert_yields(): + await s.receive_some(1) + with assert_yields(): + await s.unwrap() + + with ssl_echo_server() as s: + await s.do_handshake() + with assert_yields(): + await s.graceful_close() + + +async def test_send_all_empty_string(): + with ssl_echo_server() as s: + await s.do_handshake() + + # underlying SSLObject interprets writing b"" as indicating an EOF, + # for some reason. Make sure we don't inherit this. + with assert_yields(): + await s.send_all(b"") + with assert_yields(): + await s.send_all(b"") + await s.send_all(b"x") + assert await s.receive_some(1) == b"x" + + await s.graceful_close() + + +@pytest.mark.parametrize("https_compatible", [False, True]) +async def test_SSLStream_generic(https_compatible): + async def stream_maker(): + return ssl_memory_stream_pair( + client_kwargs={"https_compatible": https_compatible}, + server_kwargs={"https_compatible": https_compatible}, + ) + + async def clogged_stream_maker(): + client, server = ssl_lockstep_stream_pair() + # If we don't do handshakes up front, then we run into a problem in + # the following situation: + # - server does wait_send_all_might_not_block + # - client does receive_some to unclog it + # Then the client's receive_some will actually send some data to start + # the handshake, and itself get stuck. + async with _core.open_nursery() as nursery: + nursery.spawn(client.do_handshake) + nursery.spawn(server.do_handshake) + return client, server + + await check_two_way_stream(stream_maker, clogged_stream_maker) + + +async def test_unwrap(): + client_ssl, server_ssl = ssl_memory_stream_pair() + client_transport = client_ssl.transport_stream + server_transport = server_ssl.transport_stream + + seq = Sequencer() + + async def client(): + await client_ssl.do_handshake() + await client_ssl.send_all(b"x") + assert await client_ssl.receive_some(1) == b"y" + await client_ssl.send_all(b"z") + + # After sending that, disable outgoing data from our end, to make + # sure the server doesn't see our EOF until after we've sent some + # trailing data + async with seq(0): + send_all_hook = client_transport.send_stream.send_all_hook + client_transport.send_stream.send_all_hook = None + + assert await client_ssl.receive_some(1) == b"" + assert client_ssl.transport_stream is client_transport + # We just received EOF. Unwrap the connection and send some more. + raw, trailing = await client_ssl.unwrap() + assert raw is client_transport + assert trailing == b"" + assert client_ssl.transport_stream is None + await raw.send_all(b"trailing") + + # Reconnect the streams. Now the server will receive both our shutdown + # acknowledgement + the trailing data in a single lump. + client_transport.send_stream.send_all_hook = send_all_hook + await client_transport.send_stream.send_all_hook() + + async def server(): + await server_ssl.do_handshake() + assert await server_ssl.receive_some(1) == b"x" + await server_ssl.send_all(b"y") + assert await server_ssl.receive_some(1) == b"z" + # Now client is blocked waiting for us to send something, but + # instead we close the TLS connection (with sequencer to make sure + # that the client won't see and automatically respond before we've had + # a chance to disable the client->server transport) + async with seq(1): + raw, trailing = await server_ssl.unwrap() + assert raw is server_transport + assert trailing == b"trailing" + assert server_ssl.transport_stream is None + + async with _core.open_nursery() as nursery: + nursery.spawn(client) + nursery.spawn(server) + + +async def test_closing_nice_case(): + # the nice case: graceful closes all around + + client_ssl, server_ssl = ssl_memory_stream_pair() + client_transport = client_ssl.transport_stream + + # Both the handshake and the close require back-and-forth discussion, so + # we need to run them concurrently + async def client_closer(): + with assert_yields(): + await client_ssl.graceful_close() + + async def server_closer(): + assert await server_ssl.receive_some(10) == b"" + assert await server_ssl.receive_some(10) == b"" + with assert_yields(): + await server_ssl.graceful_close() + + async with _core.open_nursery() as nursery: + nursery.spawn(client_closer) + nursery.spawn(server_closer) + + # closing the SSLStream also closes its transport + with pytest.raises(ClosedStreamError): + await client_transport.send_all(b"123") + + # once closed, it's OK to close again + with assert_yields(): + await client_ssl.graceful_close() + client_ssl.forceful_close() + + # Trying to send more data does not work + with assert_yields(): + with pytest.raises(ClosedStreamError): + await server_ssl.send_all(b"123") + + # And once the connection is has been closed *locally*, then instead of + # getting empty bytestrings we get a proper error + with assert_yields(): + with pytest.raises(ClosedStreamError): + await client_ssl.receive_some(10) == b"" + + with assert_yields(): + with pytest.raises(ClosedStreamError): + await client_ssl.unwrap() + + with assert_yields(): + with pytest.raises(ClosedStreamError): + await client_ssl.do_handshake() + + # Check that a graceful close *before* handshaking gives a clean EOF on + # the other side + client_ssl, server_ssl = ssl_memory_stream_pair() + async def expect_eof_server(): + with assert_yields(): + assert await server_ssl.receive_some(10) == b"" + with assert_yields(): + await server_ssl.graceful_close() + + async with _core.open_nursery() as nursery: + nursery.spawn(client_ssl.graceful_close) + nursery.spawn(expect_eof_server) + + +async def test_send_all_fails_in_the_middle(): + client, server = ssl_memory_stream_pair() + + async with _core.open_nursery() as nursery: + nursery.spawn(client.do_handshake) + nursery.spawn(server.do_handshake) + + async def bad_hook(): + raise KeyError + + client.transport_stream.send_stream.send_all_hook = bad_hook + + with pytest.raises(KeyError): + await client.send_all(b"x") + + with pytest.raises(BrokenStreamError): + await client.wait_send_all_might_not_block() + + closed = 0 + def close_hook(): + nonlocal closed + closed += 1 + + client.transport_stream.send_stream.close_hook = close_hook + client.transport_stream.receive_stream.close_hook = close_hook + await client.graceful_close() + + assert closed == 2 + + +async def test_ssl_over_ssl(): + client_0, server_0 = memory_stream_pair() + + client_1 = tssl.SSLStream( + client_0, CLIENT_CTX, server_hostname="trio-test-1.example.org") + server_1 = tssl.SSLStream( + server_0, SERVER_CTX, server_side=True) + + client_2 = tssl.SSLStream( + client_1, CLIENT_CTX, server_hostname="trio-test-1.example.org") + server_2 = tssl.SSLStream( + server_1, SERVER_CTX, server_side=True) + + async def client(): + await client_2.send_all(b"hi") + assert await client_2.receive_some(10) == b"bye" + + async def server(): + assert await server_2.receive_some(10) == b"hi" + await server_2.send_all(b"bye") + + async with _core.open_nursery() as nursery: + nursery.spawn(client) + nursery.spawn(server) + + +async def test_ssl_bad_shutdown(): + client, server = ssl_memory_stream_pair() + + async with _core.open_nursery() as nursery: + nursery.spawn(client.do_handshake) + nursery.spawn(server.do_handshake) + + client.forceful_close() + # now the server sees a broken stream + with pytest.raises(BrokenStreamError): + await server.receive_some(10) + with pytest.raises(BrokenStreamError): + await server.send_all(b"x" * 10) + + await server.graceful_close() + + +async def test_ssl_bad_shutdown_but_its_ok(): + client, server = ssl_memory_stream_pair( + server_kwargs={"https_compatible": True}, + client_kwargs={"https_compatible": True}) + + async with _core.open_nursery() as nursery: + nursery.spawn(client.do_handshake) + nursery.spawn(server.do_handshake) + + client.forceful_close() + # the server sees that as a clean shutdown + assert await server.receive_some(10) == b"" + with pytest.raises(BrokenStreamError): + await server.send_all(b"x" * 10) + + await server.graceful_close() + + +async def test_ssl_https_compatibility_disagreement(): + client, server = ssl_memory_stream_pair( + server_kwargs={"https_compatible": False}, + client_kwargs={"https_compatible": True}) + + async with _core.open_nursery() as nursery: + nursery.spawn(client.do_handshake) + nursery.spawn(server.do_handshake) + + # client is in HTTPS-mode, server is not + # so client doing graceful_shutdown causes an error on server + async def receive_and_expect_error(): + with pytest.raises(BrokenStreamError) as excinfo: + await server.receive_some(10) + assert isinstance(excinfo.value.__cause__, tssl.SSLEOFError) + + async with _core.open_nursery() as nursery: + nursery.spawn(client.graceful_close) + nursery.spawn(receive_and_expect_error) + + +async def test_https_mode_eof_before_handshake(): + client, server = ssl_memory_stream_pair( + server_kwargs={"https_compatible": True}, + client_kwargs={"https_compatible": True}) + + async def server_expect_clean_eof(): + assert await server.receive_some(10) == b"" + + async with _core.open_nursery() as nursery: + nursery.spawn(client.graceful_close) + nursery.spawn(server_expect_clean_eof) + + +async def test_send_error_during_handshake(): + client, server = ssl_memory_stream_pair() + + async def bad_hook(): + raise KeyError + + client.transport_stream.send_stream.send_all_hook = bad_hook + + with pytest.raises(KeyError): + with assert_yields(): + await client.do_handshake() + + with pytest.raises(BrokenStreamError): + with assert_yields(): + await client.do_handshake() + + +async def test_receive_error_during_handshake(): + client, server = ssl_memory_stream_pair() + + async def bad_hook(): + raise KeyError + + client.transport_stream.receive_stream.receive_some_hook = bad_hook + + async def client_side(cancel_scope): + with pytest.raises(KeyError): + with assert_yields(): + await client.do_handshake() + cancel_scope.cancel() + + async with _core.open_nursery() as nursery: + nursery.spawn(client_side, nursery.cancel_scope) + nursery.spawn(server.do_handshake) + + with pytest.raises(BrokenStreamError): + with assert_yields(): + await client.do_handshake() + + +async def test_getpeercert(): + # Make sure we're not affected by https://bugs.python.org/issue29334 + client, server = ssl_memory_stream_pair() + + async with _core.open_nursery() as nursery: + nursery.spawn(client.do_handshake) + nursery.spawn(server.do_handshake) + + assert server.getpeercert() is None + assert ((("commonName", "trio-test-1.example.org"),) + in client.getpeercert()["subject"]) diff --git a/trio/tests/test_ssl_certs/make-test-certs.sh b/trio/tests/test_ssl_certs/make-test-certs.sh new file mode 100755 index 0000000000..0964cb14e0 --- /dev/null +++ b/trio/tests/test_ssl_certs/make-test-certs.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# This file generates 3 .pem files that we use in the test suite. +# +# trio-test-CA.pem contains the public key for a root CA certificate. +# +# trio-test-{1,2}.pem contain, concatenated in order: +# - a private key +# - a certificate signed by our CA claiming that this is the key for the host +# "trio-test-$N.example.org" +# - the root CA certificate again (to complete the cert chain) +# +# End result is that if you do +# +# ssl.create_default_context(cafile="trio-test-CA.pem") +# +# then your SSLContext will trust the trio-test-{1,2} certificates. +# +# And if you do +# +# sslcontext.load_cert_chain("trio-test-1.pem") +# +# then you can claim to be trio-test-1.example.org. + +set -uxe -o pipefail + +# Generate a self-signed 2048-bit RSA key as our signing root +openssl req -x509 -newkey rsa:2048 -nodes -sha256 -days 99999 \ + -subj '/O=Trio test CA' \ + -keyout trio-test-CA.key -out trio-test-CA.pem + +for CERT in 1; do + # Create a key and CSR. + # + # Our tests only use one name, so CN= is enough. (Otherwise we would need + # to use subjectAltNames=, which *replaces* CN=.) + openssl req -new -newkey rsa:2048 -nodes -sha256 \ + -subj "/CN=trio-test-${CERT}.example.org" \ + -keyout trio-test-${CERT}.key -out trio-test-${CERT}.csr + # Use the CSR and CA to sign the key, generating a certificate + openssl x509 -req -in trio-test-${CERT}.csr -sha256 -days 99999 \ + -CA trio-test-CA.pem -CAkey trio-test-CA.key \ + -set_serial ${CERT} \ + -out trio-test-${CERT}.crt + # Combine key/cert/root-CA into a single file for convenience + # (see https://docs.python.org/3/library/ssl.html#ssl-certificates) + cat trio-test-${CERT}.key trio-test-${CERT}.crt trio-test-CA.pem \ + > trio-test-${CERT}.pem + rm -f trio-test-${CERT}.{csr,key,crt} +done + +# We don't need the signing key anymore, remove it to reduce clutter +rm -f trio-test-CA.key diff --git a/trio/tests/test_ssl_certs/trio-test-1.pem b/trio/tests/test_ssl_certs/trio-test-1.pem new file mode 100644 index 0000000000..1431aa033d --- /dev/null +++ b/trio/tests/test_ssl_certs/trio-test-1.pem @@ -0,0 +1,64 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDQC/Q9j1KiWr56 +QUSL5CQEmWecI/2IH1bOekAn9hy+xb98DIsJXy9u7u0gjFTAHJh1ysvRQ8gacdsC +locBnt6+ElDPpJgiGHyf5NWzrJ/4uNwM4FKZYUObA5QeM8CTJzuCvFkH3MI17yJk +3uqz/HfiUdeA2QO4iBDgqGj4LjIfO16C43VoaHvb8ePsRoBHlWCzZGqp2q98lMgi +Ed7bq1IlrGb5MzU4bnZ0k0fDi4adSG5TI9W+aYwTz+OvZIZ9xqTjj5+7Zh4IAnBh +Ql23tVgS+Kuemcf5zS0vVUxG1paATKB/MBqtfeyXJyrMMXSUR9Qv0UAaGHNqPQGZ +ituO1tYXAgMBAAECggEAHJLXu7C4j7XY3V+jc3ck/0C2ezpyMsTjHj6qGxLxRb5R +G095tRLOp/TGuqaraStEQUFWFuqxS/iBNOzJpA5W11IaqToY7u3gB/Hc6+10lyuE +hXw1u/0g1OR77l37P/qucLk/nRXT0qaCWcpH/+pX6MyGxZqIqUp+zuwyZouptKIq +QE1wDy6AEKLSraan3dh9pEd1FGJB7lKKnPSgbwP2H9dgtWgYYOgNAZSszKWX7rGT ++kEH0aPC5OOpkG58saF4d/eI+KRU+/QGZcLrFZsiYuWxFcFkrnKKx8PV5rjrLRFG ++D13zASaTFJ4W6gLCTVb9lNDofhHb8JCcXrZi77fWQKBgQD3Vf+94Fw2Ybvr84G0 +7p1kCf6naDMcvRPHl0pGI/oCP7gwkgtdBnNLFU2Y8IS6YvN8jMQzPB3gvwj3T3xK +LPoM5Ie/RJYbPt3tNuXpzLY6NYx4BZ1rvK7h7ShTDNf8zb+pIdk8k0Xw3rVURWRI +WLA6jSf3Oc9c08/X8TWpZpwaEwKBgQDXVaBUSDz7+zH6WOtRSq5Wr1bDkubYnJ8Y +RsUN4Ju/59hhnIRewDty/28jtXcCtpPH1CrYf/mvsU8k6cn/u3pPbhgIgSDLdZT1 +j/kODgetfhXTbMZAp70Er9svlyUVlo0xV3fNMcj/FCVeVIF/0Icn/jUTJNTdlxbW +ch2JHRjUbQKBgQCzwU7CoqKh61n2W90ysBC3OgRXioVLJ6eOcUfLvi3fIIwu0JVt +oFh+gxcIRhVQmMW5CV02l0RnqK9Nffkot5Nrd1OpEKG/X2tPEYz65Iqzt2NFf18v +g8vd6sxZv4Xh926J703AlpBIRLOocV42ri41/4zCQsOQBWiS2n1Thn2A/QKBgEbm +58q4mnPxywv+eUUkDPF3/F6bIS2TrILm0n12RnJS2ZmSWreEHk8IMkUUvCIFkfVL +M+xjfwhNnpyt6hgtV+GNg5ZRRkYX6jtM85mgHwEOMguSlli1onRHnyk1YD2Se90S +St0ilmb+8Cr2MkmulMIjXsB18S0hUaC8pGMAVKulAoGARusNpVPh3M8hc7k1l5tn +qOy7r/DkaxUXAje8c52x7meRXQenFmwbjLvtpgF1wh3RYzh48Ljb4An4vIy+dzC0 +ZRpE46qQg05cRksRoPWwpAD9oJjrXgBV/QPHMG4eTJXqh8vKECjlAhDD5HduP191 +lZiUjv4qGPFam1Iuhc2wRhc= +-----END PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICrzCCAZcCAQEwDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UECgwMVHJpbyB0ZXN0 +IENBMCAXDTE3MDQyNTAyMjgzNloYDzIyOTEwMjA3MDIyODM2WjAiMSAwHgYDVQQD +DBd0cmlvLXRlc3QtMS5leGFtcGxlLm9yZzCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBANAL9D2PUqJavnpBRIvkJASZZ5wj/YgfVs56QCf2HL7Fv3wMiwlf +L27u7SCMVMAcmHXKy9FDyBpx2wKWhwGe3r4SUM+kmCIYfJ/k1bOsn/i43AzgUplh +Q5sDlB4zwJMnO4K8WQfcwjXvImTe6rP8d+JR14DZA7iIEOCoaPguMh87XoLjdWho +e9vx4+xGgEeVYLNkaqnar3yUyCIR3turUiWsZvkzNThudnSTR8OLhp1IblMj1b5p +jBPP469khn3GpOOPn7tmHggCcGFCXbe1WBL4q56Zx/nNLS9VTEbWloBMoH8wGq19 +7JcnKswxdJRH1C/RQBoYc2o9AZmK247W1hcCAwEAATANBgkqhkiG9w0BAQsFAAOC +AQEAPJ18+JGxFSFsRTt8P4RMQjTerjIOealquLlTUxflwb7EkjTnGtbVibAKDYoB +nwGgVPk5cXEWMAa31Ub+riG6PaXQirwzAZZYIfffFCCAi7JX41PM0TXWhNEWIvAT +7t9wREIo1rU337PrjlxNZx/lJ3Iwglxo5ol7Ecguq3x7rXbl5tW9oCOmbKexnAGq +dU9flNTCnQNuthfJToIIbTIiDMJB0vMvIHqGeehu+xf2slv8CNbSW3cfs2RHsLS7 +kZelM6ALMhOtHcW7V89bzGbp1OK0w3qIW2gixEO9ujLpNeFfWsDXo+fZRyQ7imq5 +wnZOkvKSUGrtrrschhr6X6bd7w== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDBjCCAe6gAwIBAgIJAK25bq8asqsuMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV +BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MjUwMjI4MzZaGA8yMjkxMDIwNzAyMjgz +NlowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAoNv5sSeaMn0RR4zovWPFUpUNeaEyAvJvFU1SRrvZmBwQUihK +FKaFfW820T5fEfAV266bFTpyw8DnxLGe6nlNqfF5H3yrwDNLQUebvC7K7xZy07zv +5aYqxVPlD/JqSqBzMLB4gl0nrKVRU+6y06nKlJ+6InRZ9KhhHI9ak7081aElDad2 +50iPiW3jEC50L3gBY1DVEbUQHcmOdS1Ne5keY2AgF1umZz4obsjFQghH+NwbVRq8 +VyX8dtbxnsLq4CTFUZpHz48GDKQz6ipRl6zWiMOreMPTNXmjR9taoxtXa1+m47PE +gb3oJ8IGUbxCk7Bd85uVCWRvA06gm26eaHraRQIDAQABo1MwUTAdBgNVHQ4EFgQU +Q2zAnu8wesWfG4DkT1SMqOh3G8swHwYDVR0jBBgwFoAUQ2zAnu8wesWfG4DkT1SM +qOh3G8swDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAl8PY4uy4 +/dKVLwHEhh7HGbRDibs5xK87qGUZQ6FmiNWMI2X4aNAXuqrMMKMs4151kpNBqif7 +7f3wMjKj2tIfrEm50aVW72aihGdNY7W88rfh1xnxE/SZ4l3m3UMF8a0NRPkhkEMC +Nk/ngoM20YI11OEO/VHZzef59DCx1Pfl0+T4lRgA4lbgG0dNTIggy9uhdZsWemF1 +gVeiyivEtJIy2RmvwA9rP+ZwYLVoYSV1lCOVrOlO1VWadAe//NYAKzcuBpufjWNz +vY+G5AzWfmNtu3RegDw+xEO31mGF7oKqCuESqTsvTbYBymi/l0qFazovrUo8js4G +5GjPPPACNXyNJw== +-----END CERTIFICATE----- diff --git a/trio/tests/test_ssl_certs/trio-test-CA.pem b/trio/tests/test_ssl_certs/trio-test-CA.pem new file mode 100644 index 0000000000..ca5eb63e8c --- /dev/null +++ b/trio/tests/test_ssl_certs/trio-test-CA.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBjCCAe6gAwIBAgIJAK25bq8asqsuMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV +BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MjUwMjI4MzZaGA8yMjkxMDIwNzAyMjgz +NlowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAoNv5sSeaMn0RR4zovWPFUpUNeaEyAvJvFU1SRrvZmBwQUihK +FKaFfW820T5fEfAV266bFTpyw8DnxLGe6nlNqfF5H3yrwDNLQUebvC7K7xZy07zv +5aYqxVPlD/JqSqBzMLB4gl0nrKVRU+6y06nKlJ+6InRZ9KhhHI9ak7081aElDad2 +50iPiW3jEC50L3gBY1DVEbUQHcmOdS1Ne5keY2AgF1umZz4obsjFQghH+NwbVRq8 +VyX8dtbxnsLq4CTFUZpHz48GDKQz6ipRl6zWiMOreMPTNXmjR9taoxtXa1+m47PE +gb3oJ8IGUbxCk7Bd85uVCWRvA06gm26eaHraRQIDAQABo1MwUTAdBgNVHQ4EFgQU +Q2zAnu8wesWfG4DkT1SMqOh3G8swHwYDVR0jBBgwFoAUQ2zAnu8wesWfG4DkT1SM +qOh3G8swDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAl8PY4uy4 +/dKVLwHEhh7HGbRDibs5xK87qGUZQ6FmiNWMI2X4aNAXuqrMMKMs4151kpNBqif7 +7f3wMjKj2tIfrEm50aVW72aihGdNY7W88rfh1xnxE/SZ4l3m3UMF8a0NRPkhkEMC +Nk/ngoM20YI11OEO/VHZzef59DCx1Pfl0+T4lRgA4lbgG0dNTIggy9uhdZsWemF1 +gVeiyivEtJIy2RmvwA9rP+ZwYLVoYSV1lCOVrOlO1VWadAe//NYAKzcuBpufjWNz +vY+G5AzWfmNtu3RegDw+xEO31mGF7oKqCuESqTsvTbYBymi/l0qFazovrUo8js4G +5GjPPPACNXyNJw== +-----END CERTIFICATE----- diff --git a/trio/tests/test_streams.py b/trio/tests/test_streams.py new file mode 100644 index 0000000000..4fd282fd61 --- /dev/null +++ b/trio/tests/test_streams.py @@ -0,0 +1,108 @@ +import pytest + +import attr + +from ..abc import SendStream, ReceiveStream +from .._streams import StapledStream + +@attr.s +class RecordSendStream(SendStream): + record = attr.ib(default=attr.Factory(list)) + + async def send_all(self, data): + self.record.append(("send_all", data)) + + async def wait_send_all_might_not_block(self): + self.record.append("wait_send_all_might_not_block") + + async def graceful_close(self): + self.record.append("graceful_close") + + def forceful_close(self): + self.record.append("forceful_close") + + +@attr.s +class RecordReceiveStream(ReceiveStream): + record = attr.ib(default=attr.Factory(list)) + + async def receive_some(self, max_bytes): + self.record.append(("receive_some", max_bytes)) + + async def graceful_close(self): + self.record.append("graceful_close") + + def forceful_close(self): + self.record.append("forceful_close") + + +async def test_StapledStream(): + send_stream = RecordSendStream() + receive_stream = RecordReceiveStream() + stapled = StapledStream(send_stream, receive_stream) + + assert stapled.send_stream is send_stream + assert stapled.receive_stream is receive_stream + + await stapled.send_all(b"foo") + await stapled.wait_send_all_might_not_block() + assert send_stream.record == [ + ("send_all", b"foo"), "wait_send_all_might_not_block", + ] + send_stream.record.clear() + + await stapled.send_eof() + assert send_stream.record == ["graceful_close"] + send_stream.record.clear() + + async def fake_send_eof(): + send_stream.record.append("send_eof") + send_stream.send_eof = fake_send_eof + await stapled.send_eof() + assert send_stream.record == ["send_eof"] + + send_stream.record.clear() + assert receive_stream.record == [] + + await stapled.receive_some(1234) + assert receive_stream.record == [("receive_some", 1234)] + assert send_stream.record == [] + receive_stream.record.clear() + + await stapled.graceful_close() + stapled.forceful_close() + assert receive_stream.record == ["graceful_close", "forceful_close"] + assert send_stream.record == ["graceful_close", "forceful_close"] + + +async def test_StapledStream_with_erroring_close(): + class BrokenSendStream(RecordSendStream): + def forceful_close(self): + super().forceful_close() + raise KeyError + + async def graceful_close(self): + await super().graceful_close() + raise ValueError + + class BrokenReceiveStream(RecordReceiveStream): + def forceful_close(self): + super().forceful_close() + raise KeyError + + async def graceful_close(self): + await super().graceful_close() + raise ValueError + + stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream()) + + with pytest.raises(KeyError) as excinfo: + stapled.forceful_close() + assert isinstance(excinfo.value.__context__, KeyError) + + with pytest.raises(ValueError) as excinfo: + await stapled.graceful_close() + assert isinstance(excinfo.value.__context__, ValueError) + + assert stapled.send_stream.record == ["forceful_close", "graceful_close"] + assert stapled.receive_stream.record == ["forceful_close", "graceful_close"] diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index e8e1db4154..3f5e622abd 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -94,10 +94,14 @@ async def test_Semaphore_bounded(): assert bs.value == 1 -async def test_Lock(): - l = Lock() +@pytest.mark.parametrize( + "lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__) +async def test_Lock_and_StrictFIFOLock(lockcls): + l = lockcls() assert not l.locked() repr(l) # smoke test + # make sure repr uses the right name for subclasses + assert lockcls.__name__ in repr(l) with assert_yields(): async with l: assert l.locked() @@ -161,6 +165,8 @@ async def holder(): async def test_Condition(): with pytest.raises(TypeError): Condition(Semaphore(1)) + with pytest.raises(TypeError): + Condition(StrictFIFOLock) l = Lock() c = Condition(l) assert not l.locked() @@ -436,6 +442,7 @@ def release(self): lock_factories = [ lambda: Semaphore(1), Lock, + StrictFIFOLock, lambda: QueueLock1(10), lambda: QueueLock1(1), QueueLock2, @@ -443,6 +450,7 @@ def release(self): lock_factory_names = [ "Semaphore(1)", "Lock", + "StrictFIFOLock", "QueueLock1(10)", "QueueLock1(1)", "QueueLock2", @@ -483,7 +491,7 @@ async def worker(lock_like): # Several workers queue on the same lock; make sure they each get it, in # order. @generic_lock_test -async def test_generic_lock_fairness(lock_factory): +async def test_generic_lock_fifo_fairness(lock_factory): initial_order = [] record = [] LOOPS = 5 diff --git a/trio/tests/test_testing.py b/trio/tests/test_testing.py index d14947a3f1..36587bce79 100644 --- a/trio/tests/test_testing.py +++ b/trio/tests/test_testing.py @@ -5,7 +5,9 @@ from .. import sleep from .. import _core +from .. import _streams from ..testing import * +from ..testing import _assert_raises, _UnboundedByteQueue async def test_wait_all_tasks_blocked(): record = [] @@ -34,12 +36,14 @@ async def cancelled_while_waiting(): nursery.cancel_scope.cancel() assert t4.result.unwrap() == "ok" + async def test_wait_all_tasks_blocked_with_timeouts(mock_clock): record = [] async def timeout_task(): record.append("tt start") await sleep(5) record.append("tt finished") + async with _core.open_nursery() as nursery: t = nursery.spawn(timeout_task) await wait_all_tasks_blocked() @@ -76,6 +80,7 @@ async def wait_big_cushion(): nursery.spawn(wait_small_cushion) nursery.spawn(wait_small_cushion) nursery.spawn(wait_big_cushion) + assert record == [ "blink start", "wait_no_cushion end", @@ -85,6 +90,35 @@ async def wait_big_cushion(): "wait_big_cushion end", ] + +async def test_wait_all_tasks_blocked_with_tiebreaker(): + record = [] + + async def do_wait(cushion, tiebreaker): + await wait_all_tasks_blocked(cushion=cushion, tiebreaker=tiebreaker) + record.append((cushion, tiebreaker)) + + async with _core.open_nursery() as nursery: + nursery.spawn(do_wait, 0, 0) + nursery.spawn(do_wait, 0, -1) + nursery.spawn(do_wait, 0, 1) + nursery.spawn(do_wait, 0, -1) + nursery.spawn(do_wait, 0.0001, 10) + nursery.spawn(do_wait, 0.0001, -10) + + assert record == sorted(record) + assert record == [ + (0, -1), + (0, -1), + (0, 0), + (0, 1), + (0.0001, -10), + (0.0001, 10), + ] + + +################################################################ + async def test_assert_yields(): with assert_yields(): await _core.yield_briefly() @@ -121,6 +155,8 @@ async def test_assert_yields(): await _core.yield_briefly() +################################################################ + async def test_Sequencer(): record = [] def t(val): @@ -187,6 +223,8 @@ async def child(i): pass # pragma: no cover +################################################################ + def test_mock_clock(): REAL_NOW = 123.0 c = MockClock() @@ -266,11 +304,11 @@ async def test_mock_clock_autojump(mock_clock): # This should wake up at the same time as the autojump_threshold, and # confuse things. There is no deadline, so it shouldn't actually jump # the clock. But does it handle the situation gracefully? - await wait_all_tasks_blocked(0.02) + await wait_all_tasks_blocked(cushion=0.02, tiebreaker=float("inf")) # And again with threshold=0, because that has some special # busy-wait-avoidance logic: mock_clock.autojump_threshold = 0 - await wait_all_tasks_blocked() + await wait_all_tasks_blocked(tiebreaker=float("inf")) # set up a situation where the autojump task is blocked for a long long # time, to make sure that cancel-and-adjust-threshold logic is working @@ -307,3 +345,402 @@ def test_mock_clock_autojump_preset(): real_start = time.monotonic() _core.run(sleep, 10000, clock=mock_clock) assert time.monotonic() - real_start < 1 + + +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked(mock_clock): + # Checks that autojump_threshold=0 doesn't interfere with + # calling wait_all_tasks_blocked with the default cushion=0 and arbitrary + # tiebreakers. + + mock_clock.autojump_threshold = 0 + + record = [] + + async def sleeper(): + await sleep(100) + record.append("yawn") + + async def waiter(): + for i in range(10): + await wait_all_tasks_blocked(tiebreaker=i) + record.append(i) + await sleep(1000) + record.append("waiter done") + + async with _core.open_nursery() as nursery: + nursery.spawn(sleeper) + nursery.spawn(waiter) + + assert record == list(range(10)) + ["yawn", "waiter done"] + + +################################################################ + +async def test__assert_raises(): + with pytest.raises(AssertionError): + with _assert_raises(RuntimeError): + 1 + 1 + + with pytest.raises(TypeError): + with _assert_raises(RuntimeError): + "foo" + 1 + + with _assert_raises(RuntimeError): + raise RuntimeError + +# This is a private implementation detail, but it's complex enough to be worth +# testing directly +async def test__UnboundeByteQueue(): + ubq = _UnboundedByteQueue() + + ubq.put(b"123") + ubq.put(b"456") + assert ubq.get_nowait(1) == b"1" + assert ubq.get_nowait(10) == b"23456" + ubq.put(b"789") + assert ubq.get_nowait() == b"789" + + with pytest.raises(_core.WouldBlock): + ubq.get_nowait(10) + with pytest.raises(_core.WouldBlock): + ubq.get_nowait() + + with pytest.raises(TypeError): + ubq.put("string") + + ubq.put(b"abc") + with assert_yields(): + assert await ubq.get(10) == b"abc" + ubq.put(b"def") + ubq.put(b"ghi") + with assert_yields(): + assert await ubq.get(1) == b"d" + with assert_yields(): + assert await ubq.get() == b"efghi" + + async def putter(data): + await wait_all_tasks_blocked() + ubq.put(data) + + async def getter(expect): + with assert_yields(): + assert await ubq.get() == expect + + async with _core.open_nursery() as nursery: + nursery.spawn(getter, b"xyz") + nursery.spawn(putter, b"xyz") + + # Two gets at the same time -> ResourceBusyError + with pytest.raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + nursery.spawn(getter, b"asdf") + nursery.spawn(getter, b"asdf") + + # Closing + + ubq.close() + with pytest.raises(_streams.ClosedStreamError): + ubq.put(b"---") + + assert ubq.get_nowait(10) == b"" + assert ubq.get_nowait() == b"" + assert await ubq.get(10) == b"" + assert await ubq.get() == b"" + + # close is idempotent + ubq.close() + + # close wakes up blocked getters + ubq2 = _UnboundedByteQueue() + + async def closer(): + await wait_all_tasks_blocked() + ubq2.close() + + async with _core.open_nursery() as nursery: + nursery.spawn(getter, b"") + nursery.spawn(closer) + + +async def test_MemorySendStream(): + mss = MemorySendStream() + + async def do_send_all(data): + with assert_yields(): + await mss.send_all(data) + + await do_send_all(b"123") + assert mss.get_data_nowait(1) == b"1" + assert mss.get_data_nowait() == b"23" + + with assert_yields(): + await mss.wait_send_all_might_not_block() + + with pytest.raises(_core.WouldBlock): + mss.get_data_nowait() + with pytest.raises(_core.WouldBlock): + mss.get_data_nowait(10) + + await do_send_all(b"456") + with assert_yields(): + assert await mss.get_data() == b"456" + + with pytest.raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + nursery.spawn(do_send_all, b"xxx") + nursery.spawn(do_send_all, b"xxx") + + with assert_yields(): + await mss.graceful_close() + + assert await mss.get_data() == b"xxx" + assert await mss.get_data() == b"" + with pytest.raises(_streams.ClosedStreamError): + await do_send_all(b"---") + + # hooks + + assert mss.send_all_hook is None + assert mss.wait_send_all_might_not_block_hook is None + assert mss.close_hook is None + + record = [] + async def send_all_hook(): + # hook runs after send_all does its work (can pull data out) + assert mss2.get_data_nowait() == b"abc" + record.append("send_all_hook") + async def wait_send_all_might_not_block_hook(): + record.append("wait_send_all_might_not_block_hook") + def close_hook(): + record.append("close_hook") + + mss2 = MemorySendStream( + send_all_hook, + wait_send_all_might_not_block_hook, + close_hook, + ) + + assert mss2.send_all_hook is send_all_hook + assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook + assert mss2.close_hook is close_hook + + await mss2.send_all(b"abc") + await mss2.wait_send_all_might_not_block() + mss2.forceful_close() + + assert record == [ + "send_all_hook", + "wait_send_all_might_not_block_hook", + "close_hook", + ] + + +async def test_MemoryRecieveStream(): + mrs = MemoryReceiveStream() + + async def do_receive_some(max_bytes): + with assert_yields(): + return await mrs.receive_some(max_bytes) + + mrs.put_data(b"abc") + assert await do_receive_some(1) == b"a" + assert await do_receive_some(10) == b"bc" + with pytest.raises(TypeError): + await do_receive_some(None) + + with pytest.raises(_core.ResourceBusyError): + async with _core.open_nursery() as nursery: + nursery.spawn(do_receive_some, 10) + nursery.spawn(do_receive_some, 10) + + assert mrs.receive_some_hook is None + + mrs.put_data(b"def") + mrs.put_eof() + mrs.put_eof() + + assert await do_receive_some(10) == b"def" + assert await do_receive_some(10) == b"" + assert await do_receive_some(10) == b"" + + with pytest.raises(_streams.ClosedStreamError): + mrs.put_data(b"---") + + async def receive_some_hook(): + mrs2.put_data(b"xxx") + + record = [] + def close_hook(): + record.append("closed") + + mrs2 = MemoryReceiveStream(receive_some_hook, close_hook) + assert mrs2.receive_some_hook is receive_some_hook + assert mrs2.close_hook is close_hook + + mrs2.put_data(b"yyy") + assert await mrs2.receive_some(10) == b"yyyxxx" + assert await mrs2.receive_some(10) == b"xxx" + assert await mrs2.receive_some(10) == b"xxx" + + mrs2.put_data(b"zzz") + mrs2.receive_some_hook = None + assert await mrs2.receive_some(10) == b"zzz" + + mrs2.put_data(b"lost on close") + with assert_yields(): + await mrs2.graceful_close() + assert record == ["closed"] + + with pytest.raises(_streams.ClosedStreamError): + await mrs2.receive_some(10) + + +async def test_MemoryRecvStream_closing(): + mrs = MemoryReceiveStream() + # close with no pending data + mrs.forceful_close() + with pytest.raises(_streams.ClosedStreamError): + assert await mrs.receive_some(10) == b"" + # repeated closes ok + mrs.forceful_close() + # put_data now fails + with pytest.raises(_streams.ClosedStreamError): + mrs.put_data(b"123") + + mrs2 = MemoryReceiveStream() + # close with pending data + mrs2.put_data(b"xyz") + mrs2.forceful_close() + with pytest.raises(_streams.ClosedStreamError): + await mrs2.receive_some(10) + + +async def test_memory_stream_pump(): + mss = MemorySendStream() + mrs = MemoryReceiveStream() + + # no-op if no data present + memory_stream_pump(mss, mrs) + + await mss.send_all(b"123") + memory_stream_pump(mss, mrs) + assert await mrs.receive_some(10) == b"123" + + await mss.send_all(b"456") + assert memory_stream_pump(mss, mrs, max_bytes=1) + assert await mrs.receive_some(10) == b"4" + assert memory_stream_pump(mss, mrs, max_bytes=1) + assert memory_stream_pump(mss, mrs, max_bytes=1) + assert not memory_stream_pump(mss, mrs, max_bytes=1) + assert await mrs.receive_some(10) == b"56" + + mss.forceful_close() + memory_stream_pump(mss, mrs) + assert await mrs.receive_some(10) == b"" + + +async def test_memory_stream_one_way_pair(): + s, r = memory_stream_one_way_pair() + assert s.send_all_hook is not None + assert s.wait_send_all_might_not_block_hook is None + assert s.close_hook is not None + assert r.receive_some_hook is None + await s.send_all(b"123") + assert await r.receive_some(10) == b"123" + + # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook + async def sender(): + await wait_all_tasks_blocked() + await s.send_all(b"abc") + + async def receiver(expected): + assert await r.receive_some(10) == expected + + async with _core.open_nursery() as nursery: + nursery.spawn(receiver, b"abc") + nursery.spawn(sender) + + # And this fails if we don't pump from close_hook + async def graceful_closer(): + await wait_all_tasks_blocked() + await s.graceful_close() + + async with _core.open_nursery() as nursery: + nursery.spawn(receiver, b"") + nursery.spawn(graceful_closer) + + s, r = memory_stream_one_way_pair() + + async def forceful_closer(): + await wait_all_tasks_blocked() + s.forceful_close() + + async with _core.open_nursery() as nursery: + nursery.spawn(receiver, b"") + nursery.spawn(forceful_closer) + + s, r = memory_stream_one_way_pair() + + old = s.send_all_hook + s.send_all_hook = None + await s.send_all(b"456") + + async def cancel_after_idle(nursery): + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + async def check_for_cancel(): + with pytest.raises(_core.Cancelled): + # This should block forever... or until cancelled. Even though we + # sent some data on the send stream. + await r.receive_some(10) + + async with _core.open_nursery() as nursery: + nursery.spawn(cancel_after_idle, nursery) + nursery.spawn(check_for_cancel) + + s.send_all_hook = old + await s.send_all(b"789") + assert await r.receive_some(10) == b"456789" + + +async def test_memory_stream_pair(): + a, b = memory_stream_pair() + await a.send_all(b"123") + await b.send_all(b"abc") + assert await b.receive_some(10) == b"123" + assert await a.receive_some(10) == b"abc" + + await a.send_eof() + assert await b.receive_some(10) == b"" + + async def sender(): + await wait_all_tasks_blocked() + await b.send_all(b"xyz") + + async def receiver(): + assert await a.receive_some(10) == b"xyz" + + async with _core.open_nursery() as nursery: + nursery.spawn(receiver) + nursery.spawn(sender) + + +async def test_memory_streams_with_generic_tests(): + async def one_way_stream_maker(): + return memory_stream_one_way_pair() + await check_one_way_stream(one_way_stream_maker, None) + + async def half_closeable_stream_maker(): + return memory_stream_pair() + await check_half_closeable_stream(half_closeable_stream_maker, None) + + +async def test_lockstep_streams_with_generic_tests(): + async def one_way_stream_maker(): + return lockstep_stream_one_way_pair() + await check_one_way_stream(one_way_stream_maker, one_way_stream_maker) + + async def two_way_stream_maker(): + return lockstep_stream_pair() + await check_two_way_stream(two_way_stream_maker, two_way_stream_maker) diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index 6f39beeca8..00b0907d32 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -2,16 +2,59 @@ import signal -from .. import _util +from .._util import * +from .. import _core +from ..testing import wait_all_tasks_blocked, assert_yields -async def test_signal_raise(): +def test_signal_raise(): record = [] def handler(signum, _): record.append(signum) old = signal.signal(signal.SIGFPE, handler) try: - _util.signal_raise(signal.SIGFPE) + signal_raise(signal.SIGFPE) finally: signal.signal(signal.SIGFPE, old) assert record == [signal.SIGFPE] + + +async def test_UnLock(): + ul1 = UnLock(RuntimeError, "ul1") + ul2 = UnLock(ValueError) + + async with ul1: + with assert_yields(): + async with ul2: + print("ok") + + with pytest.raises(RuntimeError) as excinfo: + async with ul1: + with assert_yields(): + async with ul1: + pass # pragma: no cover + assert "ul1" in str(excinfo.value) + + with pytest.raises(ValueError) as excinfo: + async with ul2: + with assert_yields(): + async with ul2: + pass # pragma: no cover + + async def wait_with_ul1(): + async with ul1: + await wait_all_tasks_blocked() + + with pytest.raises(RuntimeError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.spawn(wait_with_ul1) + nursery.spawn(wait_with_ul1) + assert "ul1" in str(excinfo.value) + + # mixing sync and async entry + with pytest.raises(RuntimeError) as excinfo: + with ul1.sync: + with assert_yields(): + async with ul1: + pass # pragma: no cover + assert "ul1" in str(excinfo.value)