diff --git a/.travis.yml b/.travis.yml
index 4cd148217a..990f89688b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -18,6 +18,12 @@ matrix:
- os: linux
language: generic
env: USE_PYPY_NIGHTLY=1
+ # 5.7.1 has some minor bugs that are nonetheless annoying/hard to
+ # work around; let's wait for the next beta before we start
+ # testing pypy releases.
+ # - os: linux
+ # language: generic
+ # env: USE_PYPY_RELEASE=1
- os: osx
language: generic
env: MACPYTHON=3.5.3
diff --git a/MANIFEST.in b/MANIFEST.in
index 1051aeb219..ca841cc0c8 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -2,5 +2,6 @@ include LICENSE LICENSE.MIT LICENSE.APACHE2
include README.rst
include CODE_OF_CONDUCT.md
include test-requirements.txt
+recursive-include trio/tests/test_ssl_certs *.pem
recursive-include docs *
prune docs/build
diff --git a/ci/travis.sh b/ci/travis.sh
index eb8b78cc70..9dcb7cdb27 100755
--- a/ci/travis.sh
+++ b/ci/travis.sh
@@ -3,7 +3,7 @@
set -ex
if [ "$TRAVIS_OS_NAME" = "osx" ]; then
- curl -o macpython.pkg https://www.python.org/ftp/python/${MACPYTHON}/python-${MACPYTHON}-macosx10.6.pkg
+ curl -Lo macpython.pkg https://www.python.org/ftp/python/${MACPYTHON}/python-${MACPYTHON}-macosx10.6.pkg
sudo installer -pkg macpython.pkg -target /
ls /Library/Frameworks/Python.framework/Versions/*/bin/
if expr "${MACPYTHON}" : 2; then
@@ -18,7 +18,7 @@ if [ "$TRAVIS_OS_NAME" = "osx" ]; then
fi
if [ "$USE_PYPY_NIGHTLY" = "1" ]; then
- curl -o pypy.tar.bz2 http://buildbot.pypy.org/nightly/py3.5/pypy-c-jit-latest-linux64.tar.bz2
+ curl -Lo pypy.tar.bz2 http://buildbot.pypy.org/nightly/py3.5/pypy-c-jit-latest-linux64.tar.bz2
tar xaf pypy.tar.bz2
# something like "pypy-c-jit-89963-748aa3022295-linux64"
PYPY_DIR=$(echo pypy-c-jit-*)
@@ -29,6 +29,18 @@ if [ "$USE_PYPY_NIGHTLY" = "1" ]; then
source testenv/bin/activate
fi
+if [ "$USE_PYPY_RELEASE" = "1" ]; then
+ curl -Lo pypy.tar.bz2 https://bitbucket.org/squeaky/portable-pypy/downloads/pypy3.5-5.7.1-beta-linux_x86_64-portable.tar.bz2
+ tar xaf pypy.tar.bz2
+ # something like "pypy3.5-5.7.1-beta-linux_x86_64-portable"
+ PYPY_DIR=$(echo pypy3.5-*)
+ PYTHON_EXE=$PYPY_DIR/bin/pypy3
+ $PYTHON_EXE -m ensurepip
+ $PYTHON_EXE -m pip install virtualenv
+ $PYTHON_EXE -m virtualenv testenv
+ source testenv/bin/activate
+fi
+
pip install -U pip setuptools wheel
python setup.py sdist --formats=zip
diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst
index 543d9381f6..ab87dc52c9 100644
--- a/docs/source/reference-core.rst
+++ b/docs/source/reference-core.rst
@@ -374,7 +374,7 @@ work that was happening within the scope that was cancelled.
Looking at this, you might wonder how you can tell whether the inner
block timed out – perhaps you want to do something different, like try
a fallback procedure or report a failure to our caller. To make this
-easier, :func:`move_on_after`'s ``__enter__`` function returns an
+easier, :func:`move_on_after`\'s ``__enter__`` function returns an
object representing this cancel scope, which we can use to check
whether this scope caught a :exc:`Cancelled` exception::
@@ -1364,6 +1364,9 @@ don't have any special access to trio's internals.)
.. autoclass:: Lock
:members:
+.. autoclass:: StrictFIFOLock
+ :members:
+
.. autoclass:: Condition
:members:
@@ -1534,12 +1537,14 @@ trio's internal scheduling decisions.
Exceptions
----------
-.. autoexception:: TrioInternalError
-
.. autoexception:: Cancelled
.. autoexception:: TooSlowError
.. autoexception:: WouldBlock
+.. autoexception:: ResourceBusyError
+
.. autoexception:: RunFinishedError
+
+.. autoexception:: TrioInternalError
diff --git a/docs/source/reference-hazmat.rst b/docs/source/reference-hazmat.rst
index fe8a1caef7..9ca67e9fa0 100644
--- a/docs/source/reference-hazmat.rst
+++ b/docs/source/reference-hazmat.rst
@@ -48,7 +48,7 @@ All environments provide the following functions:
:raises TypeError:
if the given object is not of type :func:`socket.socket`.
- :raises RuntimeError:
+ :raises trio.ResourceBusyError:
if another task is already waiting for the given socket to
become readable.
@@ -62,7 +62,7 @@ All environments provide the following functions:
:raises TypeError:
if the given object is not of type :func:`socket.socket`.
- :raises RuntimeError:
+ :raises trio.ResourceBusyError:
if another task is already waiting for the given socket to
become writable.
@@ -85,7 +85,7 @@ Unix-like systems provide the following functions:
:arg fd:
integer file descriptor, or else an object with a ``fileno()`` method
- :raises RuntimeError:
+ :raises trio.ResourceBusyError:
if another task is already waiting for the given fd to
become readable.
@@ -103,7 +103,7 @@ Unix-like systems provide the following functions:
:arg fd:
integer file descriptor, or else an object with a ``fileno()`` method
- :raises RuntimeError:
+ :raises trio.ResourceBusyError:
if another task is already waiting for the given fd to
become writable.
diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst
index c49ecd589e..c43313af77 100644
--- a/docs/source/reference-io.rst
+++ b/docs/source/reference-io.rst
@@ -1,26 +1,193 @@
+.. currentmodule:: trio
+
I/O in Trio
===========
+.. note::
+
+ Please excuse our dust! `geocities-construction-worker.gif
+ `__
+
+ You're looking at the documentation for trio's development branch,
+ which is currently about half-way through implementing a proper
+ high-level networking API. If you want to know how to do networking
+ in trio *right now*, then you might want to jump down to read about
+ :mod:`trio.socket`, which is the already-working lower-level
+ API. Alternatively, you can read on for a (somewhat disorganized)
+ preview of coming attractions.
+
+.. _abstract-stream-api:
+
+The abstract Stream API
+-----------------------
+
+Trio provides a set of abstract base classes that define a standard
+interface for unidirectional and bidirectional byte streams.
+
+Why is this useful? Because it lets you write generic protocol
+implementations that can work over arbitrary transports, and easily
+create complex transport configurations. Here's some examples:
+
+* :class:`trio.SocketStream` wraps a raw socket (like a TCP connection
+ over the network), and converts it to the standard stream interface.
+
+* :class:`trio.ssl.SSLStream` is a "stream adapter" that can take any
+ object that implements the :class:`trio.abc.Stream` interface, and
+ convert it into an encrypted stream. In trio the standard way to
+ speak SSL over the network is to wrap an
+ :class:`~trio.ssl.SSLStream` around a :class:`~trio.SocketStream`.
+
+* If you spawn a subprocess then you can get a
+ :class:`~trio.abc.SendStream` that lets you write to its stdin, and
+ a :class:`~trio.abc.ReceiveStream` that lets you read from its
+ stdout. If for some reason you wanted to speak SSL to a subprocess,
+ you could use a :class:`StapledStream` to combine its stdin/stdout
+ into a single bidirectional :class:`~trio.abc.Stream`, and then wrap
+ that in an :class:`~trio.ssl.SSLStream`::
+
+ ssl_context = trio.ssl.create_default_context()
+ ssl_context.check_hostname = False
+ s = SSLStream(StapledStream(process.stdin, process.stdout), ssl_context)
+
+ [Note: subprocess support is not implemented yet, but that's the
+ plan. Unless it is implemented, and I forgot to remove this note.]
+
+* It sometimes happens that you want to connect to an HTTPS server,
+ but you have to go through a web proxy... and the proxy also uses
+ HTTPS. So you end up having to do `SSL-on-top-of-SSL
+ `__. In
+ trio this is trivial – just wrap your first
+ :class:`~trio.ssl.SSLStream` in a second
+ :class:`~trio.ssl.SSLStream`::
+
+ # Get a raw SocketStream connection to the proxy:
+ s0 = await open_tcp_stream("proxy", 443)
+
+ # Set up SSL connection to proxy:
+ s1 = SSLStream(s0, proxy_ssl_context, server_hostname="proxy")
+ # Request a connection to the website
+ await s1.send_all(b"CONNECT website:443 HTTP/1.0\r\n")
+ await check_CONNECT_response(s1)
+
+ # Set up SSL connection to the real website. Notice that s1 is
+ # already an SSLStream object, and here we're wrapping a second
+ # SSLStream object around it.
+ s2 = SSLStream(s1, website_ssl_context, server_hostname="website")
+ # Make our request
+ await s2.send_all(b"GET /index.html HTTP/1.0\r\n")
+ ...
+
+* The :mod:`trio.testing` module provides a set of :ref:`flexible
+ in-memory stream object implementations `, so if
+ you have a protocol implementation to test then you can can spawn
+ two tasks, set up a virtual "socket" connecting them, and then do
+ things like inject random-but-repeatable delays into the connection.
+
+
+Abstract base classes
+~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: trio.abc
+
+.. autoclass:: trio.abc.AsyncResource
+ :members:
+
+.. autoclass:: trio.abc.SendStream
+ :members:
+ :show-inheritance:
+
+.. autoclass:: trio.abc.ReceiveStream
+ :members:
+ :show-inheritance:
+
+.. autoclass:: trio.abc.Stream
+ :members:
+ :show-inheritance:
+
+.. autoclass:: trio.abc.HalfCloseableStream
+ :members:
+ :show-inheritance:
+
+.. currentmodule:: trio
+
+.. autoexception:: BrokenStreamError
+
+.. autoexception:: ClosedStreamError
+
+
+Generic stream implementations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Trio currently provides one generic utility class for working with
+streams. And if you want to test code that's written against the
+streams interface, you should also check out :ref:`testing-streams` in
+:mod:`trio.testing`.
+
+.. autoclass:: StapledStream
+ :members:
+ :show-inheritance:
+
+
Sockets and networking
-----------------------
+~~~~~~~~~~~~~~~~~~~~~~
+
+The high-level network interface is built on top of our stream
+abstraction.
+
+.. autoclass:: SocketStream
+ :members:
+ :show-inheritance:
+
+.. autofunction:: socket_stream_pair
+
+
+SSL / TLS support
+~~~~~~~~~~~~~~~~~
+
+.. module:: trio.ssl
+
+The :mod:`trio.ssl` module implements SSL/TLS support for Trio, using
+the standard library :mod:`ssl` module. It re-exports most of
+:mod:`ssl`\'s API, with the notable exception of
+:class:`ssl.SSLContext`, which has unsafe defaults; if you really want
+to use :class:`ssl.SSLContext` you can import it from :mod:`ssl`, but
+normally you should create your contexts using
+:func:`trio.ssl.create_default_context <ssl.create_default_context>`.
+
+Instead of using :meth:`ssl.SSLContext.wrap_socket`, though, you
+create a :class:`SSLStream`:
+
+.. autoclass:: SSLStream
+ :show-inheritance:
+ :members:
+
+
+Low-level sockets and networking
+--------------------------------
.. module:: trio.socket
-The :mod:`trio.socket` module provides trio's basic networking API.
+The :mod:`trio.socket` module provides trio's basic low-level
+networking API. If you're doing ordinary things with stream-oriented
+connections over IPv4/IPv6/Unix domain sockets, then you probably want
+to stick to the high-level API described above. If you want to use
+UDP, or exotic address families like ``AF_BLUETOOTH``, or otherwise
+get direct access to all the quirky bits of your system's networking
+API, then you're in the right place.
-:mod:`trio.socket`\'s top-level exports
+:mod:`trio.socket`: top-level exports
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Generally, :mod:`trio.socket`\'s API mirrors that of the standard
-library :mod:`socket` module. Most constants (like ``SOL_SOCKET``) and
-simple utilities (like :func:`~socket.inet_aton`) are simply
-re-exported unchanged. But there are also some differences:
+Generally, the API exposed by :mod:`trio.socket` mirrors that of the
+standard library :mod:`socket` module. Most constants (like
+``SOL_SOCKET``) and simple utilities (like :func:`~socket.inet_aton`)
+are simply re-exported unchanged. But there are also some differences:
-All functions that return sockets (e.g. :func:`socket.socket`,
-:func:`socket.socketpair`, ...) are modified to return trio sockets
-instead. In addition, there is a new function to directly convert a
-standard library socket into a trio socket:
+All functions that return socket objects (e.g. :func:`socket.socket`,
+:func:`socket.socketpair`, ...) are modified to return trio socket
+objects instead. In addition, there is a new function to directly
+convert a standard library socket into a trio socket:
.. autofunction:: from_stdlib_socket
@@ -128,9 +295,7 @@ Socket objects
to port hijacking attacks
`__.
- 2. ``TCP_NODELAY`` is enabled by default.
-
- 3. ``IPV6_V6ONLY`` is disabled, i.e., by default on dual-stack
+ 2. ``IPV6_V6ONLY`` is disabled, i.e., by default on dual-stack
hosts a ``AF_INET6`` socket is able to communicate with both
IPv4 and IPv6 peers, where the IPv4 peers appear to be in the
`"IPv4-mapped" portion of IPv6 address space
@@ -143,10 +308,6 @@ Socket objects
This makes trio applications behave more consistently across
different environments.
- 4. On platforms where it's supported (recent Linux and recent
- MacOS), ``TCP_NOTSENT_LOWAT`` is enabled with a reasonable
- buffer size (currently 16 KiB).
-
See `issue #72 `__ for
discussion of these defaults.
@@ -163,12 +324,6 @@ Socket objects
`Not implemented yet! `__
- The following methods are *not* provided:
-
- * :meth:`~socket.socket.send`: This method has confusing semantics
- hidden under a friendly name, and makes it too easy to create
- subtle bugs. Use :meth:`sendall` instead.
-
The following methods are identical to their equivalents in
:func:`socket.socket`, except async, and the ones that take address
arguments require pre-resolved addresses:
@@ -180,6 +335,7 @@ Socket objects
* :meth:`~socket.socket.recvfrom_into`
* :meth:`~socket.socket.recvmsg` (if available)
* :meth:`~socket.socket.recvmsg_into` (if available)
+ * :meth:`~socket.socket.send`
* :meth:`~socket.socket.sendto`
* :meth:`~socket.socket.sendmsg` (if available)
@@ -204,37 +360,6 @@ Socket objects
* :meth:`~socket.socket.get_inheritable`
-The abstract Stream API
------------------------
-
-(this is currently more of a sketch than something actually useful,
-`see issue #73 `__)
-
-.. currentmodule:: trio
-
-.. autoclass:: AsyncResource
- :members:
- :undoc-members:
-
-.. autoclass:: SendStream
- :members:
- :undoc-members:
-
-.. autoclass:: RecvStream
- :members:
- :undoc-members:
-
-.. autoclass:: Stream
- :members:
- :undoc-members:
-
-
-TLS support
------------
-
-`Not implemented yet! `__
-
-
Async disk I/O
--------------
diff --git a/docs/source/reference-testing.rst b/docs/source/reference-testing.rst
index c45e75492c..66c138f1ea 100644
--- a/docs/source/reference-testing.rst
+++ b/docs/source/reference-testing.rst
@@ -71,6 +71,92 @@ Inter-task ordering
.. autofunction:: wait_all_tasks_blocked
+.. _testing-streams:
+
+Streams
+-------
+
+Virtual, controllable streams
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+One particularly challenging problem when testing network protocols is
+making sure that your implementation can handle data whose flow gets
+broken up in weird ways and arrives with weird timings: localhost
+connections tend to be much better behaved than real networks, so if
+you only test on localhost then you might get bitten later. To help
+you out, trio provides some fully in-memory implementations of the
+stream interfaces (see :ref:`abstract-stream-api`), that let you write
+all kinds of interestingly evil tests.
+
+There are a few pieces here, so here's how they fit together:
+
+:func:`memory_stream_pair` gives you a pair of connected,
+bidirectional streams. It's like :func:`socket.socketpair`, but
+without any involvement from that pesky operating system and its
+networking stack.
+
+To build a bidirectional stream, :func:`memory_stream_pair` uses
+two unidirectional streams. It gets these by calling
+:func:`memory_stream_one_way_pair`.
+
+:func:`memory_stream_one_way_pair`, in turn, is implemented using the
+low-ish level classes :class:`MemorySendStream` and
+:class:`MemoryReceiveStream`. These are implementations of (you
+guessed it) :class:`trio.abc.SendStream` and
+:class:`trio.abc.ReceiveStream` that on their own, aren't attached to
+anything – "sending" and "receiving" just put data into and get data
+out of a private internal buffer that each object owns. They also have
+some interesting hooks you can set, that let you customize the
+behavior of their methods. This is where you can insert the evil, if
+you want it. :func:`memory_stream_one_way_pair` takes advantage of
+these hooks in a relatively boring way: it just sets it up so that
+when you call ``sendall``, or when you close the send stream, then it
+automatically triggers a call to :func:`memory_stream_pump`, which is
+a convenience function that takes data out of a
+:class:`MemorySendStream`\'s buffer and puts it into a
+:class:`MemoryReceiveStream`\'s buffer. But that's just the default –
+you can replace this with whatever arbitrary behavior you want.
+
+Trio also provides some specialized functions for testing completely
+**un**\buffered streams: :func:`lockstep_stream_one_way_pair` and
+:func:`lockstep_stream_pair`. These aren't customizable, but they do
+exhibit an extreme kind of behavior that's otherwise hard to
+implement.
+
+
+API details
+~~~~~~~~~~~
+
+.. autoclass:: MemorySendStream
+ :members:
+
+.. autoclass:: MemoryReceiveStream
+ :members:
+
+.. autofunction:: memory_stream_pump
+
+.. autofunction:: memory_stream_one_way_pair
+
+.. autofunction:: memory_stream_pair
+
+.. autofunction:: lockstep_stream_one_way_pair
+
+.. autofunction:: lockstep_stream_pair
+
+
+Testing custom stream implementations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Trio also provides some functions to help you test your custom stream
+implementations:
+
+.. autofunction:: check_one_way_stream
+
+.. autofunction:: check_two_way_stream
+
+.. autofunction:: check_half_closeable_stream
+
+
Testing checkpoints
--------------------
diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst
index 63f5880346..def13e2f78 100644
--- a/docs/source/tutorial.rst
+++ b/docs/source/tutorial.rst
@@ -855,9 +855,13 @@ thing, and then we'll discuss the pieces:
.. literalinclude:: tutorial/echo-server-low-level.py
:linenos:
-The actual echo server implementation should be fairly familiar at
-this point. Each incoming connection from an echo client gets handled
-by its own dedicated task, running the ``echo_server`` function:
+Let's start with ``echo_server``. As we'll see below, each time an
+echo client connects, our server will spawn a child task running
+``echo_server``; there might be lots of these running at once if lots
+of clients are connected. Its job is to read data from its particular
+client, and then echo it back. It should be pretty straightforward to
+understand, because it uses the same socket functions we saw in the
+last section:
.. literalinclude:: tutorial/echo-server-low-level.py
:linenos:
@@ -922,8 +926,8 @@ tasks, we end up with a task tree that looks like:
┆
This lets ``parent`` focus on supervising the children,
-``echo_listener`` focus on listening for new connections, each
-``echo_server`` call will handle a single client.
+``echo_listener`` focus on listening for new connections, and each
+``echo_server`` focus on handling a single client.
Once we know this trick, the listener code becomes pretty
straightforward:
diff --git a/notes-to-self/server.crt b/notes-to-self/server.crt
new file mode 100644
index 0000000000..9c58d8e65b
--- /dev/null
+++ b/notes-to-self/server.crt
@@ -0,0 +1,19 @@
+-----BEGIN CERTIFICATE-----
+MIIDBjCCAe4CCQDq+3W9D8C4ejANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB
+VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0
+cyBQdHkgTHRkMB4XDTE3MDMxOTAzMDk1MVoXDTE4MDMxOTAzMDk1MVowRTELMAkG
+A1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0
+IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
+AOwDDFVh8pvIrhZtIIX6pb3/PO5SM3rWsfoyyHi73GxemIiEHfYEjMKN8Eo10jUv
+4G0n8VlrrmuhGR+UuHY6jCxjoCuYWszQhwBZBaeGE24ydtO/RE24yhNsJHPQWXMe
+TL4mg1EBjJYXTwNhd7SwgCpkBQ+724ZJg+CmiPuYhVLdvjjUUmwiSbeueyULIPEJ
+G1EWkKdU5pYtyyTZoc0x2YEjes3YNWY563yk+RljvidFBMyAX8N3fF4yrCCHDeY6
+UPdpXry/BJcEJm7PY2lMhbL71T6499qKnmSaWyJjm+KqbXSEYXoWDVBBvg5pR9Ia
+XSoJ1MTfJ8eYnZDs5mETYDkCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEApaW8WiKA
+3yDOUVzgwkeX3HxvxfhxtMTPBmO1M8YgX1yi+URkkKakc6bg3XW1saQrxBkXWwBr
+81Atd0tOLwHsC1HPd7Y5Q/1LKZiYFq2Sva6eZfeedRF/0f/SQC+rSvZNI5DIVPS4
+jW/EpyMKIeerIyWeFXz0/NWcYLCDWN6m2iDtR3m98bJcqSdUemLgyR13EAWsaVZ7
+dB6nkwGl9e78SOIHeGYg1Fb0B7IN2Tqw2tO3Xn0mzhvqs65OYuYo4pB0FzxiySAB
+q2nrgu6kGhkQw/RQ8QJ5MYjydYqCU0I4Qve1W7RoUxRnJvxJrMuvcdlMeboASKNl
+L7YQurFGvAAiZQ==
+-----END CERTIFICATE-----
diff --git a/notes-to-self/server.csr b/notes-to-self/server.csr
new file mode 100644
index 0000000000..f0fbc3829d
--- /dev/null
+++ b/notes-to-self/server.csr
@@ -0,0 +1,16 @@
+-----BEGIN CERTIFICATE REQUEST-----
+MIICijCCAXICAQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUx
+ITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDCCASIwDQYJKoZIhvcN
+AQEBBQADggEPADCCAQoCggEBAOwDDFVh8pvIrhZtIIX6pb3/PO5SM3rWsfoyyHi7
+3GxemIiEHfYEjMKN8Eo10jUv4G0n8VlrrmuhGR+UuHY6jCxjoCuYWszQhwBZBaeG
+E24ydtO/RE24yhNsJHPQWXMeTL4mg1EBjJYXTwNhd7SwgCpkBQ+724ZJg+CmiPuY
+hVLdvjjUUmwiSbeueyULIPEJG1EWkKdU5pYtyyTZoc0x2YEjes3YNWY563yk+Rlj
+vidFBMyAX8N3fF4yrCCHDeY6UPdpXry/BJcEJm7PY2lMhbL71T6499qKnmSaWyJj
+m+KqbXSEYXoWDVBBvg5pR9IaXSoJ1MTfJ8eYnZDs5mETYDkCAwEAAaAAMA0GCSqG
+SIb3DQEBCwUAA4IBAQC+LhkPmCjxk5Nzn743u+7D/YzNhjv8Xv4aGUjjNyspVNso
+tlCAWkW2dWo8USvQrMUz5yl6qj6QQlg0QaYfaIiK8pkGz4s+Sh1plz1Eaa7QDK4O
+0wmtP6KkJyQW561ZY8sixS1DevKOmsp2Pa9fWU/vqKfzRv85A975XNadp6hkxXd7
+YOZCrSZjTnakpQvKoItvT9Xk7yKP6BI6h/03XORscbW/HyvLGoVLdE80yIkmjSot
+3JXxHspT27bWNWhz/Slph3UFaVyOVGXFTAqkLDZ3OISMnuC+q/t38EHYkR1aev/l
+4WogCtlWkFZ3bmhmlhJrH/bdTEkM6WopwoC6bczh
+-----END CERTIFICATE REQUEST-----
diff --git a/notes-to-self/server.key b/notes-to-self/server.key
new file mode 100644
index 0000000000..c0ba0b8582
--- /dev/null
+++ b/notes-to-self/server.key
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEogIBAAKCAQEA7AMMVWHym8iuFm0ghfqlvf887lIzetax+jLIeLvcbF6YiIQd
+9gSMwo3wSjXSNS/gbSfxWWuua6EZH5S4djqMLGOgK5hazNCHAFkFp4YTbjJ2079E
+TbjKE2wkc9BZcx5MviaDUQGMlhdPA2F3tLCAKmQFD7vbhkmD4KaI+5iFUt2+ONRS
+bCJJt657JQsg8QkbURaQp1Tmli3LJNmhzTHZgSN6zdg1ZjnrfKT5GWO+J0UEzIBf
+w3d8XjKsIIcN5jpQ92levL8ElwQmbs9jaUyFsvvVPrj32oqeZJpbImOb4qptdIRh
+ehYNUEG+DmlH0hpdKgnUxN8nx5idkOzmYRNgOQIDAQABAoIBABuus9ij93fsTvcU
+b7cnUh95+6ScgatL2W5WXItExbL0WYHRtU3w9K2xRlj9/Rz98536DGYHqlq3d6Hr
+qMM9VMm0GcpjQWs6nksdJfujT04inytxCMrw/MrQaWooKwXErQ20qLxsqRfFvh/Q
+Y+EOvsm6F5nj1/jlUJGeFv0jw6eXXxH6bqVUVYIaVCpAMB5Sm8caQ4dAI9UESZJv
+vuucT24iSyV8vp060L1tNKgRUr5e2CMfbucauZh0nLALPAyu1I07Ce62q9wLLw66
+c2FLHcZBkTGvL0bPe89ttJJuK0jttHV6GQ/OneytezZFxLw1DMsG3VxzbXt2X7AN
+noGzrDECgYEA/fnK0xlNir9bTNmOID42GoVUYF6iYWUf//rRlCKsRofSlPjTZDJK
+grl/plTBKDE6qDDEkB1mrEkJufqP3slyq66NfkP0NLoo+PFkGSnsbvUvFNYwcYvH
+7w2NWo/GvM4DJRqHvrETryBQwQtBJFsq9biWd3+hNCXYrhawKGqbzw0CgYEA7eSa
+T6zIdmvszG5x1XzQ3k29GwUA4SLL1YfV2pnLxoMZgW5Q6T+cOCpJdEuL0BXCNelP
+gk0gNXNvCzylwVC0BbpefFaJYsWK6gVg1EwDkiZcGx4FnKd0TWYer6RWrZ9cVohT
+eNwix9kKVef7chf+2006eE1O8D0UYwZMpGifqt0CgYAKjmtjwtV6QuHkm9ZQeMV+
+7LPJHaXaLn3aAe7cHWTTuamDD6SZsY1vSY6It1Uf+ovZmc1RwCcYWiDRXhzEwdLG
+WAcBjImF94bkcgQbF6cAJajDUPPKhGjXAtUxQnCcQGPZEvU5c9rBmLJCk9ktTazH
+cdivNtrYdApBkifYRjYbsQKBgDZl0ctqTSSXJTzR/IG+2twalqV5DWxt0oJvXz1v
+caNhExH/sczEWOqW8NkA9WWNtC0zvpSjIjxWuwuswJJl6+Rra3OvLhdB6LP+qteg
+0ig3UVR6FvptaDDSqy2qvI9TI4A+CChY3jMotC5Ur7C1P/fRvw8HToesz96c8CWg
+LvKZAoGAS4VW2YaYVlN14mkKGtHbZq0XWZGursAwYJn6H5jBCjl+89tsFzfCnTER
+hZFH+zs281v1bVbCFws8lWrZg8AZh55dG8CcLtuCkTyXJ/aAdlan0+FmXV60+RLP
+Z1TyykQG/oDgO1Z+5GrcN2b0FOFaSbH2NRzRlhyOI63yTQi4lT8=
+-----END RSA PRIVATE KEY-----
diff --git a/notes-to-self/server.orig.key b/notes-to-self/server.orig.key
new file mode 100644
index 0000000000..28fac173ff
--- /dev/null
+++ b/notes-to-self/server.orig.key
@@ -0,0 +1,30 @@
+-----BEGIN RSA PRIVATE KEY-----
+Proc-Type: 4,ENCRYPTED
+DEK-Info: DES-EDE3-CBC,D9C5B2214855387C
+
+gYuCZiXsU74IOErbOGOmc1y/BFP1N7UuRO19tidUrq1O6sreJSAVKRibIynAwmXj
+p5xvaAnBBIZIH6X7I2vduJgtUeeyvy5yxR98pD6liRKDxFaVD+O1m5IZxSbAs2De
+olk4Zlv3YULpbVF6Ud+QuLmgbqfmT+8NVGm4MwRey7Gkj+LEfGrNjpfgLqNRIaUZ
+XDPQh9HLZYCsAbz5OeRHwJwawLO74fvWBkFjsQyoWLgJzqZFmt15SyrRufBeKYP6
+oKKemsiW8/A2+i6Rb1vHYOJJ6c9jeeHPJkZSbfNWf4/Z702DAMIisbHmTCQzsUrX
+178d2Z7sDKcuDCQ1EInnLRb3YET/V83wGDWyHxWepaHLWHd5S7tFbsqZFsXxIuYM
+lcZZVSPsOLnG2SozZK+Tr2RX7jkI4Kmfh0RDgtKBYQQopZjRSFUG2hvvH3EIxVIf
+JyUG8AA5RT1J9tkcSJJ5MS40So7i3eyAuZXuYVSkuDai/mu2IUU8vYnwB8a1psU1
+P2CGUj2AFopvMAfSOYIPGHpcIn+lvxuXUdczR/Yikp/BhGT+diJjP68CUsMBdyq7
+pcVmMVyQPVpcsMag3IXGgIAF1v1GhO3zDMd1uXA1lyrHQa6CEah3z+4WFSWwYZ0I
+OZz5qM9bnfKoAQesp+xmcZhs8cbrblMRVDkWiPUixxKVJk3eBUsMoa1WYq/2u0ly
+EgvNlY39B/3eiLi+k+S6gVGT8a4AP6n4RuxPD4g0A79bI1LpC3xU6vcKV/GyIP69
+t2DHR2q9UDEiRj9DxjuTzxew7eRX8ktD7DhYV06BxrgQIRRiL9MrZRKGuqzXcMP/
+kWY71ioFZJ1ViZkpy7bEsYrpF2XBjGge3We0s2udnrY3r3ogxOjtZiT27e2zEbXD
+T59C3gecuEzCSCZ3eQtdcVC9m3RdHMTNNKvqmTVFPgfGOoM5u2gG+rYjhetbpTDB
+T5drcEEAcV11DHuokU4tlqOdIWdLuBsK3xgO98JasEr1LyYJT1fnjB+6AbhjfSS2
+p5TPekmSwaZbaBzwfP1xmhINJm388GCROXMkc9iLAWN9npHhssfMAA2WMXqDTgSt
+34oUnHgLGmvOm5HzJE/tTR1WP1Rye4nKNLwsbk2x7WXxqcNUPYc+OVmZbsl/R5Gz
+3zRHPts01mT/eaSfqj1wkJgpYtDQLPO+V1fc2pDgJmQMYyr7OCLI6I9GJBlB8gVq
+aemv0TMi3/eUVyJRaAHxAAi7YMsrSkuKUrsbLfIIRgViaEy+1stFa9iWiHJT0DKJ
+0fOqtwcL8OYJURyG/D29yUP5qBJcrFuIYk8uI1wtfDNMeAI4LWoWwMhBLtB6POY+
+a/qmMewFzrGGsR9R0ptwtlplhvJVeArfLYGngnbgBV4vwchjLQTR2RMouZWlwRH9
+NWX6EqsIk/zzYvu+o7sBC2839D3GCPQMmgKqSWwmlf2a76mqZk2duTO9+0v6+e+F
+Qc44ndLFE+mEibXkm9PMHvPsXOUdC4KPpugC/aZbn4OCqVd3eSl7k+PZGKZua6IJ
+ybhosNzQc4lg25K7iMxRXpK5WrOgEXSAA3kUquDRTWHshpz/Avwbgw==
+-----END RSA PRIVATE KEY-----
diff --git a/notes-to-self/ssl-handshake/ssl-handshake.py b/notes-to-self/ssl-handshake/ssl-handshake.py
new file mode 100644
index 0000000000..81d875be6a
--- /dev/null
+++ b/notes-to-self/ssl-handshake/ssl-handshake.py
@@ -0,0 +1,131 @@
+import ssl
+import socket
+import threading
+from contextlib import contextmanager
+
+BUFSIZE = 4096
+
+server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+server_ctx.load_cert_chain("trio-test-1.pem")
+
+def _ssl_echo_serve_sync(sock):
+ try:
+ wrapped = server_ctx.wrap_socket(sock, server_side=True)
+ while True:
+ data = wrapped.recv(BUFSIZE)
+ if not data:
+ wrapped.unwrap()
+ return
+ wrapped.sendall(data)
+ except BrokenPipeError:
+ pass
+
+@contextmanager
+def echo_server_connection():
+ client_sock, server_sock = socket.socketpair()
+ with client_sock, server_sock:
+ t = threading.Thread(
+ target=_ssl_echo_serve_sync, args=(server_sock,), daemon=True)
+ t.start()
+
+ yield client_sock
+
+class ManuallyWrappedSocket:
+ def __init__(self, ctx, sock, **kwargs):
+ self.incoming = ssl.MemoryBIO()
+ self.outgoing = ssl.MemoryBIO()
+ self.obj = ctx.wrap_bio(self.incoming, self.outgoing, **kwargs)
+ self.sock = sock
+
+ def _retry(self, fn, *args):
+ finished = False
+ while not finished:
+ want_read = False
+ try:
+ ret = fn(*args)
+ except ssl.SSLWantReadError:
+ want_read = True
+ except ssl.SSLWantWriteError:
+ # can't happen, but if it did this would be the right way to
+ # handle it anyway
+ pass
+ else:
+ finished = True
+ # do any sending
+ data = self.outgoing.read()
+ if data:
+ self.sock.sendall(data)
+ # do any receiving
+ if want_read:
+ data = self.sock.recv(BUFSIZE)
+ if not data:
+ self.incoming.write_eof()
+ else:
+ self.incoming.write(data)
+ # then retry if necessary
+ return ret
+
+ def do_handshake(self):
+ self._retry(self.obj.do_handshake)
+
+ def recv(self, bufsize):
+ return self._retry(self.obj.read, bufsize)
+
+ def sendall(self, data):
+ self._retry(self.obj.write, data)
+
+ def unwrap(self):
+ self._retry(self.obj.unwrap)
+ return self.sock
+
+
+def wrap_socket_via_wrap_socket(ctx, sock, **kwargs):
+ return ctx.wrap_socket(sock, do_handshake_on_connect=False, **kwargs)
+
+def wrap_socket_via_wrap_bio(ctx, sock, **kwargs):
+ return ManuallyWrappedSocket(ctx, sock, **kwargs)
+
+
+for wrap_socket in [
+ wrap_socket_via_wrap_socket,
+ wrap_socket_via_wrap_bio,
+]:
+ print("\n--- checking {} ---\n".format(wrap_socket.__name__))
+
+ print("checking with do_handshake + correct hostname...")
+ with echo_server_connection() as client_sock:
+ client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem")
+ wrapped = wrap_socket(
+ client_ctx, client_sock, server_hostname="trio-test-1.example.org")
+ wrapped.do_handshake()
+ wrapped.sendall(b"x")
+ assert wrapped.recv(1) == b"x"
+ wrapped.unwrap()
+ print("...success")
+
+ print("checking with do_handshake + wrong hostname...")
+ with echo_server_connection() as client_sock:
+ client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem")
+ wrapped = wrap_socket(
+ client_ctx, client_sock, server_hostname="trio-test-2.example.org")
+ try:
+ wrapped.do_handshake()
+ except Exception:
+ print("...got error as expected")
+ else:
+ print("??? no error ???")
+
+ print("checking withOUT do_handshake + wrong hostname...")
+ with echo_server_connection() as client_sock:
+ client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem")
+ wrapped = wrap_socket(
+ client_ctx, client_sock, server_hostname="trio-test-2.example.org")
+ # We forgot to call do_handshake
+ # But the hostname is wrong so something had better error out...
+ sent = b"x"
+ print("sending", sent)
+ wrapped.sendall(sent)
+ got = wrapped.recv(1)
+ print("got:", got)
+ assert got == sent
+ print("!!!! successful chat with invalid host! we have been haxored!")
diff --git a/notes-to-self/ssl-handshake/trio-test-1.pem b/notes-to-self/ssl-handshake/trio-test-1.pem
new file mode 100644
index 0000000000..a0c1b773f9
--- /dev/null
+++ b/notes-to-self/ssl-handshake/trio-test-1.pem
@@ -0,0 +1,64 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDZ8yz1OrHX7aHp
+Erfa1ds8kmYfqYomjgy5wDsGdb8i1gF4uxhHCRDQtNZANVOVXI7R3TMchA1GMxzA
+ZYDuBDuEUsqTktbTEBNb4GjOhyMu1fF4dX/tMxf7GB+flTx178eE2exTOZLmSmBa
+2laoDVe3CrBAYE7nZtBF630jKKKMsUuIl0CbFRHajpoqM3e3CeCo4KcbBzgujRA3
+AsVV6y5qhMH2zqLkOYaurVUfEkdjqoHFgj1VbjWpkTbrXAxPwW6v/uZK056bHgBg
+go03RyWexaPapsF2oUm2JNdSN3z7MP0umKphO2n9icyGt9Bmkm2AKs3dA45VLPXh
++NohluqJAgMBAAECggEARlfWCtAG1ko8F52S+W5MdCBMFawCiq8OLGV+p3cZWYT4
+tJ6uFz81ziaPf+m2MF7POazK8kksf5u/i9k245s6GlseRsL90uE9XknvibjUAinK
+5bYGs+fptYDzs+3WtbnOC3LKc5IBd5JJxwjxLwwfY1RvzldHIChu0CJRISfcTsvR
+occ8hXdeft7svNymvTuwQd05u1yjzL4RwF8Be76i17j5+jDsrAaUKdxxwGNAyOU7
+OKrUY6G851T6NUGgC19iXAJ1wN9tVGIR5QOs3J/s6dCctnX5tN8Di7prkXCKvVlm
+vhpC8XWWG+c3LhS90wmEBvKS0AfUeoPDHxMOLyzKgQKBgQD07lZRO0nsc38+PVaI
+NrvlP90Q8OgbwMIC52jmSZK3b5YSh3TrllsbCg6hzUk1SAJsa3qi7B1vq36Fd+rG
+LGDRW9xY0cfShLhzqvZWi45zU/RYnEcWHOuXQshLikx1DWUpg2KbLSVT2/lyvzmn
+QgM1Te8CSxW5vrBRVfluXoJuEwKBgQDjzLAbwk/wdjITKlQtirtsJEzWi3LGuUrg
+Z2kMz+0ztUU5d1oFL9B5xh0CwK8bpK9kYnoVZSy/r5+mGHqyz1eKaDdAXIR13nC0
+g7aZbTZzbt2btvuNZc3NCzRffHF3sCqp8a+oCryHyITjZcA+WYeU8nG0TQ5O8Zgr
+Skbo1JGocwKBgQC4jCx1oFqe0pd5afYdREBnB6ul7B63apHEZmBfw+fMV0OYSoAK
+Uovq37UOrQMQJmXNE16gC5BSZ8E5B5XaI+3/UVvBgK8zK9VfMd3Sb+yxcPyXF4lo
+W/oXSrZoVJgvShyDHv/ZNDb/7KsTjon+QHryWvpPnAuOnON1JXZ/dq6ICQKBgCZF
+AukG8esR0EPL/qxP/ECksIvyjWu5QU0F0m4mmFDxiRmoZWUtrTZoBAOsXz6johuZ
+N61Ue/oQBSAgSKy1jJ1h+LZFVLOAlSqeXhTUditaWryINyaADdz+nuPTwjQ7Uk+O
+nNX8R8P/+eNB+tP+snphaJzDvT2h9NCA//ypiXblAoGAJoLmotPI+P3KIRVzESL0
+DAsVmeijtXE3H+R4nwqUDQbBbFKx0/u2pbON+D5C9llaGiuUp9H+awtwQRYhToeX
+CNguwWrcpuhFOCeXDHDWF/0NIZYD2wBMxjF/eUarvoLaT4Gi0yyWh5ExIKOW4bFk
+EojUPSJ3gomOUp5bIFcSmSU=
+-----END PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIICrzCCAZcCAQEwDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UECgwMVHJpbyB0ZXN0
+IENBMCAXDTE3MDQwOTEwMDcyMVoYDzIyOTEwMTIyMTAwNzIxWjAiMSAwHgYDVQQD
+DBd0cmlvLXRlc3QtMS5leGFtcGxlLm9yZzCCASIwDQYJKoZIhvcNAQEBBQADggEP
+ADCCAQoCggEBANnzLPU6sdftoekSt9rV2zySZh+piiaODLnAOwZ1vyLWAXi7GEcJ
+ENC01kA1U5VcjtHdMxyEDUYzHMBlgO4EO4RSypOS1tMQE1vgaM6HIy7V8Xh1f+0z
+F/sYH5+VPHXvx4TZ7FM5kuZKYFraVqgNV7cKsEBgTudm0EXrfSMoooyxS4iXQJsV
+EdqOmiozd7cJ4KjgpxsHOC6NEDcCxVXrLmqEwfbOouQ5hq6tVR8SR2OqgcWCPVVu
+NamRNutcDE/Bbq/+5krTnpseAGCCjTdHJZ7Fo9qmwXahSbYk11I3fPsw/S6YqmE7
+af2JzIa30GaSbYAqzd0DjlUs9eH42iGW6okCAwEAATANBgkqhkiG9w0BAQsFAAOC
+AQEAlRNA96H88lVnzlpQUYt0pwpoy7B3/CDe8Uvl41thKEfTjb+SIo95F4l+fi+l
+jISWSonAYXRMNqymPMXl2ir0NigxfvvrcjggER3khASIs0l1ICwTNTv2a40NnFY6
+ZjTaBeSZ/lAi7191AkENDYvMl3aGhb6kALVIbos4/5LvJYF/UXvQfrjriLWZq/I3
+WkvduU9oSi0EA4Jt9aAhblsgDHMBL0+LU8Nl1tgzy2/NePcJWjzBRQDlF8uxCQ+2
+LesZongKQ+lebS4eYbNs0s810h8hrOEcn7VWn7FfxZRkjeaKIst2FCHmdr5JJgxj
+8fw+s7l2UkrNURAJ4IRNQvPB+w==
+-----END CERTIFICATE-----
+-----BEGIN CERTIFICATE-----
+MIIDBjCCAe6gAwIBAgIJAIUF+wna+nuzMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV
+BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MDkxMDA3MjFaGA8yMjkxMDEyMjEwMDcy
+MVowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC
+AQ8AMIIBCgKCAQEAyhE82Cbq+2c2f9M+vj2f9v+z+0+bMZDUVPSXhBDiRdKubt+K
+f9vY+ZH3ze1sm0iNgO6xU3OsDTlzO5z0TpsvEEbs0wgsJDUXD7Y8Fb1zH2jaVCro
+Y6KcVfFZvD96zsVCnZy0vMsYJw20iIL0RNCtr17lXWVxd17OoVy91djFD9v/cixu
+LRIr+N7pa8BDLUQUO/g0ui9YSC9Wgf67mr93KXKPGwjTHBGdjeZeex198j5aZjZR
+lkPH/9g5d3hP7EI0EAIMDVd4dvwNJgZzv+AZINbKLAkQyE9AAm+xQ7qBSvdfAvKq
+N/fwaFevmyrxUBcfoQxSpds8njWDb3dQzCn7ywIDAQABo1MwUTAdBgNVHQ4EFgQU
+JiilveozF8Qpyy2fS3wV4foVRCswHwYDVR0jBBgwFoAUJiilveozF8Qpyy2fS3wV
+4foVRCswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAkcthPjkJ
++npvMKyQtUx7CSVcT4Ar0jHrvfPSg17ipyLv+MhSTIbS2VwhSYXxNmu+oBWKuYUs
+BnNxly3+SOcs+dTP3GBMng91SBsz5hhbP4ws8uUtCvYJauzeHbeY67R14RT8Ws/b
+mP6HDiybN7zy6LOKGCiz+sCoJqVZG/yBYO87iQsTTyNttgoG27yUSvzP07EQwUa5
+F9dI9Wn+4b5wP2ofMCu3asTbKXjfFbz3w5OkRgpGYhC4jhDdOw/819+01R9//GrM
+54Gme03yDAAM7nGihr1Xtld3dp2gLuqv0WgxKBqvG5X+nCbr2WamscAP5qz149vo
+y6Hq6P4mm2GmZw==
+-----END CERTIFICATE-----
diff --git a/notes-to-self/ssl-handshake/trio-test-CA.pem b/notes-to-self/ssl-handshake/trio-test-CA.pem
new file mode 100644
index 0000000000..9bf34001b2
--- /dev/null
+++ b/notes-to-self/ssl-handshake/trio-test-CA.pem
@@ -0,0 +1,19 @@
+-----BEGIN CERTIFICATE-----
+MIIDBjCCAe6gAwIBAgIJAIUF+wna+nuzMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV
+BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MDkxMDA3MjFaGA8yMjkxMDEyMjEwMDcy
+MVowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC
+AQ8AMIIBCgKCAQEAyhE82Cbq+2c2f9M+vj2f9v+z+0+bMZDUVPSXhBDiRdKubt+K
+f9vY+ZH3ze1sm0iNgO6xU3OsDTlzO5z0TpsvEEbs0wgsJDUXD7Y8Fb1zH2jaVCro
+Y6KcVfFZvD96zsVCnZy0vMsYJw20iIL0RNCtr17lXWVxd17OoVy91djFD9v/cixu
+LRIr+N7pa8BDLUQUO/g0ui9YSC9Wgf67mr93KXKPGwjTHBGdjeZeex198j5aZjZR
+lkPH/9g5d3hP7EI0EAIMDVd4dvwNJgZzv+AZINbKLAkQyE9AAm+xQ7qBSvdfAvKq
+N/fwaFevmyrxUBcfoQxSpds8njWDb3dQzCn7ywIDAQABo1MwUTAdBgNVHQ4EFgQU
+JiilveozF8Qpyy2fS3wV4foVRCswHwYDVR0jBBgwFoAUJiilveozF8Qpyy2fS3wV
+4foVRCswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAkcthPjkJ
++npvMKyQtUx7CSVcT4Ar0jHrvfPSg17ipyLv+MhSTIbS2VwhSYXxNmu+oBWKuYUs
+BnNxly3+SOcs+dTP3GBMng91SBsz5hhbP4ws8uUtCvYJauzeHbeY67R14RT8Ws/b
+mP6HDiybN7zy6LOKGCiz+sCoJqVZG/yBYO87iQsTTyNttgoG27yUSvzP07EQwUa5
+F9dI9Wn+4b5wP2ofMCu3asTbKXjfFbz3w5OkRgpGYhC4jhDdOw/819+01R9//GrM
+54Gme03yDAAM7nGihr1Xtld3dp2gLuqv0WgxKBqvG5X+nCbr2WamscAP5qz149vo
+y6Hq6P4mm2GmZw==
+-----END CERTIFICATE-----
diff --git a/notes-to-self/sslobject.py b/notes-to-self/sslobject.py
new file mode 100644
index 0000000000..cfac98676e
--- /dev/null
+++ b/notes-to-self/sslobject.py
@@ -0,0 +1,76 @@
+from contextlib import contextmanager
+import ssl
+
+client_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+client_ctx.check_hostname = False
+client_ctx.verify_mode = ssl.CERT_NONE
+
+cinb = ssl.MemoryBIO()
+coutb = ssl.MemoryBIO()
+cso = client_ctx.wrap_bio(cinb, coutb)
+
+server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+server_ctx.load_cert_chain("server.crt", "server.key", "xxxx")
+sinb = ssl.MemoryBIO()
+soutb = ssl.MemoryBIO()
+sso = server_ctx.wrap_bio(sinb, soutb, server_side=True)
+
+@contextmanager
+def expect(etype):
+ try:
+ yield
+ except etype:
+ pass
+ else:
+ raise AssertionError("expected {}".format(etype))
+
+with expect(ssl.SSLWantReadError):
+ cso.do_handshake()
+assert not cinb.pending
+assert coutb.pending
+
+with expect(ssl.SSLWantReadError):
+ sso.do_handshake()
+assert not sinb.pending
+assert not soutb.pending
+
+# A trickle is not enough
+# sinb.write(coutb.read(1))
+# with expect(ssl.SSLWantReadError):
+# cso.do_handshake()
+# with expect(ssl.SSLWantReadError):
+# sso.do_handshake()
+
+sinb.write(coutb.read())
+# Now it should be able to respond
+with expect(ssl.SSLWantReadError):
+ sso.do_handshake()
+assert soutb.pending
+
+cinb.write(soutb.read())
+with expect(ssl.SSLWantReadError):
+ cso.do_handshake()
+
+sinb.write(coutb.read())
+# server done!
+sso.do_handshake()
+assert soutb.pending
+
+# client done!
+cinb.write(soutb.read())
+cso.do_handshake()
+
+cso.write(b"hello")
+sinb.write(coutb.read())
+assert sso.read(10) == b"hello"
+with expect(ssl.SSLWantReadError):
+ sso.read(10)
+
+# cso.write(b"x" * 2 ** 30)
+# print(coutb.pending)
+
+assert not coutb.pending
+assert not cinb.pending
+sso.do_handshake()
+assert not coutb.pending
+assert not cinb.pending
diff --git a/setup.py b/setup.py
index 92ad6c4ddb..9a45bc9550 100644
--- a/setup.py
+++ b/setup.py
@@ -78,6 +78,9 @@
# https://bitbucket.org/pypa/wheel/issues/181/bdist_wheel-silently-discards-pep-508
#"cffi; os_name == 'nt'", # "cffi is required on windows"
],
+ # This means, just install *everything* you see under trio/, even if it
+ # doesn't look like a source file, so long as it appears in MANIFEST.in:
+ include_package_data=True,
# Quirky bdist_wheel-specific way:
# https://wheel.readthedocs.io/en/latest/#defining-conditional-dependencies
# also supported by pip and setuptools, as long as they're vaguely
diff --git a/test-requirements.txt b/test-requirements.txt
index d43ae6ee1c..4373f85907 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -1,3 +1,4 @@
pytest >= 3.0.6 # fixes a bug in handling async def test_*
pytest-cov
-ipython # for the IPython traceback integration tests
+ipython # for the IPython traceback integration tests
+pyOpenSSL # for the ssl tests
diff --git a/trio/__init__.py b/trio/__init__.py
index fc122d55ee..040b2fdf10 100644
--- a/trio/__init__.py
+++ b/trio/__init__.py
@@ -31,6 +31,7 @@
__all__.append(_symbol)
del _symbol, _value
+
from ._timeouts import *
__all__ += _timeouts.__all__
@@ -46,7 +47,11 @@
from ._signals import *
__all__ += _signals.__all__
+from ._network import *
+__all__ += _network.__all__
+
# Imported by default
from . import socket
from . import abc
+from . import ssl
# Not imported by default: testing
diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py
index 589e5edda6..20f5235eb0 100644
--- a/trio/_core/_exceptions.py
+++ b/trio/_core/_exceptions.py
@@ -2,8 +2,8 @@
# Re-exported
__all__ = [
- "TrioInternalError", "RunFinishedError", "WouldBlock",
- "Cancelled", "PartialResult",
+ "TrioInternalError", "RunFinishedError", "WouldBlock", "Cancelled",
+ "ResourceBusyError",
]
class TrioInternalError(Exception):
@@ -75,7 +75,13 @@ class Cancelled(BaseException):
Cancelled.__module__ = "trio"
-@attr.s(slots=True, frozen=True)
-class PartialResult:
- # XX
- bytes_sent = attr.ib()
+class ResourceBusyError(Exception):
+ """Raised when a task attempts to use a resource that some other task is
+ already using, and this would lead to bugs and nonsense.
+
+ For example, if two tasks try to send data through the same socket at the
+ same time, trio will raise :class:`ResourceBusyError` instead of letting
+ the data get scrambled.
+
+ """
+ResourceBusyError.__module__ = "trio"
diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py
index 838a68b507..0b08760157 100644
--- a/trio/_core/_io_epoll.py
+++ b/trio/_core/_io_epoll.py
@@ -98,7 +98,8 @@ async def _epoll_wait(self, fd, attr_name):
self._registered[fd] = EpollWaiters()
waiters = self._registered[fd]
if getattr(waiters, attr_name) is not None:
- raise RuntimeError(
+ await _core.yield_briefly()
+ raise _core.ResourceBusyError(
"another task is already reading / writing this fd")
setattr(waiters, attr_name, _core.current_task())
self._update_registrations(fd, currently_registered)
diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py
index 39885275ff..f09870a4d4 100644
--- a/trio/_core/_io_kqueue.py
+++ b/trio/_core/_io_kqueue.py
@@ -80,7 +80,7 @@ def current_kqueue(self):
def monitor_kevent(self, ident, filter):
key = (ident, filter)
if key in self._registered:
- raise RuntimeError(
+ raise _core.ResourceBusyError(
"attempt to register multiple listeners for same "
"ident/filter pair")
q = _core.UnboundedQueue()
@@ -95,7 +95,8 @@ def monitor_kevent(self, ident, filter):
async def wait_kevent(self, ident, filter, abort_func):
key = (ident, filter)
if key in self._registered:
- raise RuntimeError(
+ await _core.yield_briefly()
+ raise _core.ResourceBusyError(
"attempt to register multiple listeners for same "
"ident/filter pair")
self._registered[key] = _core.current_task()
diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py
index 3820d9e83c..8725a08769 100644
--- a/trio/_core/_io_windows.py
+++ b/trio/_core/_io_windows.py
@@ -294,7 +294,7 @@ async def wait_overlapped(self, handle, lpOverlapped):
if isinstance(lpOverlapped, int):
lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped)
if lpOverlapped in self._overlapped_waiters:
- raise RuntimeError(
+ raise _core.ResourceBusyError(
"another task is already waiting on that lpOverlapped")
task = _core.current_task()
self._overlapped_waiters[lpOverlapped] = task
@@ -339,9 +339,11 @@ async def _wait_socket(self, which, sock):
# sockets in another thread? And on unix we don't handle this case at
# all), but hey, why not.
if type(sock) is not stdlib_socket.socket:
+ await _core.yield_briefly()
raise TypeError("need a stdlib socket")
if sock in self._socket_waiters[which]:
- raise RuntimeError(
+ await _core.yield_briefly()
+ raise _core.ResourceBusyError(
"another task is already waiting to {} this socket"
.format(which))
self._socket_waiters[which][sock] = _core.current_task()
diff --git a/trio/_core/_run.py b/trio/_core/_run.py
index 263cbc4dee..d22ff10a1a 100644
--- a/trio/_core/_run.py
+++ b/trio/_core/_run.py
@@ -366,14 +366,22 @@ class Task:
coro = attr.ib()
_runner = attr.ib()
name = attr.ib()
+ # Invariant:
+ # - for unfinished tasks, result is None
+ # - for finished tasks, result is a Result object
result = attr.ib(default=None)
- # tasks start out unscheduled, and unscheduled tasks have None here
+ # Invariant:
+ # - for unscheduled tasks, _next_send is None
+ # - for scheduled tasks, _next_send is a Result object
+ # Tasks start out unscheduled.
_next_send = attr.ib(default=None)
_abort_func = attr.ib(default=None)
# Task-local values, see _local.py
_locals = attr.ib(default=attr.Factory(dict))
+ # these are counts of how many cancel/schedule points this task has
+ # executed, for assert{_no,}_yields
# XX maybe these should be exposed as part of a statistics() method?
_cancel_points = attr.ib(default=0)
_schedule_points = attr.ib(default=0)
@@ -955,7 +963,7 @@ def _deliver_ki_cb(self):
@_public
@_hazmat
- async def wait_all_tasks_blocked(self, cushion=0.0):
+ async def wait_all_tasks_blocked(self, cushion=0.0, tiebreaker=0):
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
@@ -973,7 +981,9 @@ async def wait_all_tasks_blocked(self, cushion=0.0):
then the one with the shortest ``cushion`` is the one woken (and the
this task becoming unblocked resets the timers for the remaining
tasks). If there are multiple tasks that have exactly the same
- ``cushion``, then all are woken.
+ ``cushion``, then the one with the lowest ``tiebreaker`` value is
+ woken first. And if there are multiple tasks with the same ``cushion``
+ and the same ``tiebreaker``, then all are woken.
You should also consider :class:`trio.testing.Sequencer`, which
provides a more explicit way to control execution ordering within a
@@ -1010,7 +1020,7 @@ async def test_lock_fairness():
"""
task = current_task()
- key = (cushion, id(task))
+ key = (cushion, tiebreaker, id(task))
self.waiting_for_idle[key] = task
def abort(_):
del self.waiting_for_idle[key]
@@ -1201,7 +1211,7 @@ def run_impl(runner, async_fn, args):
idle_primed = False
if runner.waiting_for_idle:
- cushion, _ = runner.waiting_for_idle.keys()[0]
+ cushion, tiebreaker, _ = runner.waiting_for_idle.keys()[0]
if cushion < timeout:
timeout = cushion
idle_primed = True
@@ -1229,7 +1239,7 @@ def run_impl(runner, async_fn, args):
if not runner.runq and idle_primed:
while runner.waiting_for_idle:
key, task = runner.waiting_for_idle.peekitem(0)
- if key[0] == cushion:
+ if key[:2] == (cushion, tiebreaker):
del runner.waiting_for_idle[key]
runner.reschedule(task)
else:
diff --git a/trio/_core/tests/test_io.py b/trio/_core/tests/test_io.py
index 2047e95644..b7d5697f87 100644
--- a/trio/_core/tests/test_io.py
+++ b/trio/_core/tests/test_io.py
@@ -145,8 +145,9 @@ async def test_double_read(socketpair, wait_readable):
async with _core.open_nursery() as nursery:
nursery.spawn(wait_readable, a)
await wait_all_tasks_blocked()
- with pytest.raises(RuntimeError):
- await wait_readable(a)
+ with assert_yields():
+ with pytest.raises(_core.ResourceBusyError):
+ await wait_readable(a)
nursery.cancel_scope.cancel()
@write_socket_test
@@ -158,8 +159,9 @@ async def test_double_write(socketpair, wait_writable):
async with _core.open_nursery() as nursery:
nursery.spawn(wait_writable, a)
await wait_all_tasks_blocked()
- with pytest.raises(RuntimeError):
- await wait_writable(a)
+ with assert_yields():
+ with pytest.raises(_core.ResourceBusyError):
+ await wait_writable(a)
nursery.cancel_scope.cancel()
diff --git a/trio/_network.py b/trio/_network.py
new file mode 100644
index 0000000000..e0bf67f3a8
--- /dev/null
+++ b/trio/_network.py
@@ -0,0 +1,165 @@
+# "High-level" networking interface
+
+import errno
+from contextlib import contextmanager
+
+from . import _core
+from . import socket as tsocket
+from ._util import UnLock
+from .abc import HalfCloseableStream
+from ._streams import ClosedStreamError, BrokenStreamError
+
+__all__ = ["SocketStream", "socket_stream_pair"]
+
+_closed_stream_errnos = {
+ # Unix
+ errno.EBADF,
+ # Windows
+ errno.ENOTSOCK,
+}
+
+@contextmanager
+def _translate_socket_errors_to_stream_errors():
+ try:
+ yield
+ except OSError as exc:
+ if exc.errno in _closed_stream_errnos:
+ raise ClosedStreamError(
+ "this socket was already closed") from None
+ else:
+ raise BrokenStreamError(
+ "socket connection broken: {}".format(exc)) from exc
+
+class SocketStream(HalfCloseableStream):
+ """An implementation of the :class:`trio.abc.HalfCloseableStream`
+ interface based on a raw network socket.
+
+ Args:
+ sock (trio.socket.SocketType): The trio socket object to wrap. Must have
+ type ``SOCK_STREAM``, and be connected.
+
+ By default, :class:`SocketStream` enables ``TCP_NODELAY``, and (on
+ platforms where it's supported) enables ``TCP_NOTSENT_LOWAT`` with a
+ reasonable buffer size (currently 16 KiB) – see `issue #72
+    <https://github.com/python-trio/trio/issues/72>`__ for discussion. You can
+ of course override these defaults by calling :meth:`setsockopt`.
+
+ Once a :class:`SocketStream` object is constructed, it implements the full
+ :class:`trio.abc.HalfCloseableStream` interface. In addition, it provides
+ a few extra features:
+
+ .. attribute:: socket
+
+ The :class:`trio.socket.SocketType` object that this stream wraps.
+
+ """
+ def __init__(self, sock):
+ if not isinstance(sock, tsocket.SocketType):
+ raise TypeError("SocketStream requires trio.socket.SocketType")
+ if sock._real_type != tsocket.SOCK_STREAM:
+ raise ValueError("SocketStream requires a SOCK_STREAM socket")
+ try:
+ sock.getpeername()
+ except OSError:
+ err = ValueError("SocketStream requires a connected socket")
+ raise err from None
+
+ self.socket = sock
+ self._send_lock = UnLock(
+ _core.ResourceBusyError,
+ "another task is currently sending data on this SocketStream")
+
+ # Socket defaults:
+
+ # Not supported on e.g. unix domain sockets
+ try:
+ self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, True)
+ except OSError:
+ pass
+
+ if hasattr(tsocket, "TCP_NOTSENT_LOWAT"):
+ try:
+ # 16 KiB is pretty arbitrary and could probably do with some
+ # tuning. (Apple is also setting this by default in CFNetwork
+ # apparently -- I'm curious what value they're using, though I
+ # couldn't find it online trivially. CFNetwork-129.20 source
+ # has no mentions of TCP_NOTSENT_LOWAT. This presentation says
+ # "typically 8 kilobytes":
+ # http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1
+ # ). The theory is that you want it to be bandwidth *
+ # rescheduling interval.
+ self.setsockopt(
+ tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2 ** 14)
+ except OSError:
+ pass
+
+ async def send_all(self, data):
+ if self.socket._did_SHUT_WR:
+ await _core.yield_briefly()
+ raise ClosedStreamError("can't send data after sending EOF")
+ with self._send_lock.sync:
+ with _translate_socket_errors_to_stream_errors():
+ await self.socket.sendall(data)
+
+ async def wait_send_all_might_not_block(self):
+ async with self._send_lock:
+ if self.socket.fileno() == -1:
+ raise ClosedStreamError
+ with _translate_socket_errors_to_stream_errors():
+ await self.socket.wait_writable()
+
+ async def send_eof(self):
+ async with self._send_lock:
+ # On MacOS, calling shutdown a second time raises ENOTCONN, but
+ # send_eof needs to be idempotent.
+ if self.socket._did_SHUT_WR:
+ return
+ with _translate_socket_errors_to_stream_errors():
+ self.socket.shutdown(tsocket.SHUT_WR)
+
+ async def receive_some(self, max_bytes):
+ if max_bytes < 1:
+ await _core.yield_briefly()
+ raise ValueError("max_bytes must be >= 1")
+ with _translate_socket_errors_to_stream_errors():
+ return await self.socket.recv(max_bytes)
+
+ def forceful_close(self):
+ self.socket.close()
+
+ # graceful_close, __aenter__, __aexit__ inherited from HalfCloseableStream
+ # are OK
+
+ def setsockopt(self, level, option, value):
+ """Set an option on the underlying socket.
+
+ See :meth:`socket.socket.setsockopt` for details.
+
+ """
+ return self.socket.setsockopt(level, option, value)
+
+ def getsockopt(self, level, option, buffersize=0):
+ """Check the current value of an option on the underlying socket.
+
+ See :meth:`socket.socket.getsockopt` for details.
+
+ """
+ # This is to work around
+ # https://bitbucket.org/pypy/pypy/issues/2561
+ # We should be able to drop it when the next PyPy3 beta is released.
+ if buffersize == 0:
+ return self.socket.getsockopt(level, option)
+ else:
+ return self.socket.getsockopt(level, option, buffersize)
+
+
+def socket_stream_pair():
+ """Returns a pair of connected :class:`SocketStream` objects.
+
+ This is a convenience function that uses :func:`socket.socketpair` to
+ create the sockets, and then converts the result into
+ :class:`SocketStream`\s.
+
+ """
+ left, right = tsocket.socketpair()
+ return SocketStream(left), SocketStream(right)
diff --git a/trio/_ssl.py b/trio/_ssl.py
deleted file mode 100644
index f001374a6e..0000000000
--- a/trio/_ssl.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Use SSLObject to make a generic wrapper around Stream
-
-# SSL shutdown:
-# - call unwrap() on the SSLSocket/SSLObject
-# - this sends the "all done here" SSL message
-# - but in many practical applications this is neither sent nor checked for,
-# e.g. HTTPS usually ignores it:
-# https://security.stackexchange.com/questions/82028/ssl-tls-is-a-server-always-required-to-respond-to-a-close-notify
-# BUT it is important in some cases, so should be possible to handle
-# properly.
-#
-# I think the answer is: close is synchronous, and the TLS Stream also has an
-# async def unwrap() which sends the close_notify message.
-# Possibly we should also default suppress_ragged_eofs to False, unlike the
-# stdlib? not sure.
-
-# XX how closely should we match the stdlib API?
-# - maybe suppress_ragged_eofs=False is a better default?
-# - maybe check crypto folks for advice?
-# - this is also interesting: https://bugs.python.org/issue8108#msg102867
-
-# Definitely keep an eye on Cory's TLS API ideas on security-sig etc.
-
-# from ._stream import Stream
-
-# class SSLStream(Stream):
-# async def unwrap(self):
-# # does a clean shutdown (!) by calling SSLObject.unwrap(), sends the
-# # resulting close_notify data, and *then* returns the underlying
-# # stream.
-# XX
-
-# # XX
diff --git a/trio/_streams.py b/trio/_streams.py
index 5ac0b58989..f94fdb4d85 100644
--- a/trio/_streams.py
+++ b/trio/_streams.py
@@ -4,129 +4,129 @@
import attr
from . import _core
+from .abc import HalfCloseableStream
-__all__ = ["AsyncResource", "SendStream", "RecvStream", "Stream"]
-
-# The close API is a big question here.
-#
-# Technically, socket close can block if you set various weird
-# lingering-related options. This doesn't seem very useful though.
-#
-# For other kinds of channels, though, the natural implementation definitely
-# can block – e.g. TLS wants to send a goodbye message, and if we're tunneling
-# over ssh or HTTP/2 e.g. then again closing requires sending some actual
-# data. So we need a concept of a blocking close.
-#
-# BUT, if the other side is uncooperative, we can't necessarily block in
-# close. So if a blocking close is cancelled, we need to do some sort of
-# forceful cleanup before raising the exception.
-#
-# (Probably implementing H2-based streams will be a useful forcing function
-# here to figure this out.)
-
-class AsyncResource(metaclass=abc.ABCMeta):
- __slots__ = ()
-
- @abc.abstractmethod
- def forceful_close(self):
- """Force an immediate close of this resource.
+__all__ = [
+ "BrokenStreamError", "ClosedStreamError", "StapledStream",
+]
- This will never block, but (depending on the resource in question) it
- might be a "rude" shutdown.
- """
- pass
+class BrokenStreamError(Exception):
+    """Raised when an attempt to use a stream fails due to external
+ circumstances.
- async def graceful_close(self):
- """Close this resource, gracefully.
+ For example, you might get this if you try to send data on a stream where
+ the remote side has already closed the connection.
- This may block in order to perform a "graceful" shutdown (for example,
- sending a message alerting the other side of a connection that it is
- about to close). But, if cancelled, then it still *must* close the
- underlying resource.
+ You *don't* get this error if *you* closed the stream – in that case you
+ get :class:`ClosedStreamError`.
- Default implementation is to perform a :meth:`forceful_close` and then
- execute a checkpoint.
- """
- self.forceful_close()
- await _core.yield_briefly()
+ This exception's ``__cause__`` attribute will often contain more
+ information about the underlying error.
- async def __aenter__(self):
- return self
+ """
+ pass
- async def __aexit__(self, *args):
- await self.graceful_close()
-# XX added in 3.6
-if hasattr(contextlib, "AbstractContextManager"):
- contextlib.AbstractContextManager.register(AsyncResource)
+class ClosedStreamError(Exception):
+    """Raised when an attempt to use a stream fails because the
+ stream was already closed locally.
-class SendStream(AsyncResource):
- __slots__ = ()
+ You *only* get this error if *your* code closed the stream object you're
+ attempting to use by calling
+ :meth:`~trio.abc.AsyncResource.graceful_close` or
+ similar. (:meth:`~trio.abc.SendStream.send_all` might also raise this if
+ you already called :meth:`~trio.abc.HalfCloseableStream.send_eof`.)
+ Therefore this exception generally indicates a bug in your code.
- @abc.abstractmethod
- async def sendall(self, data):
- pass
+ If a problem arises elsewhere, for example due to a network failure or a
+ misbehaving peer, then you get :class:`BrokenStreamError` instead.
- # This is only a hint, because in some cases we don't know (Windows), or
- # we have only a noisy signal (TLS). And in the use cases this is included
- # to account for, returning before it's actually writable is NBD, it just
- # makes them slightly less efficient.
- @abc.abstractmethod
- async def wait_maybe_writable(self):
- pass
+ """
+ pass
- @property
- @abc.abstractmethod
- def can_send_eof(self):
- pass
- @abc.abstractmethod
- def send_eof(self):
- pass
+@attr.s(slots=True, cmp=False, hash=False)
+class StapledStream(HalfCloseableStream):
+    """This class `staples <https://en.wikipedia.org/wiki/Staple_(fastener)>`__
+    together two unidirectional streams to make a single bidirectional stream.
-class RecvStream(AsyncResource):
- __slots__ = ()
+ Args:
+ send_stream (~trio.abc.SendStream): The stream to use for sending.
+ receive_stream (~trio.abc.ReceiveStream): The stream to use for
+ receiving.
- @abc.abstractmethod
- async def recv(self, max_bytes):
- pass
+ Example:
-class Stream(SendStream, RecvStream):
- __slots__ = ()
+ A silly way to make a stream that echoes back whatever you write to
+ it::
- @staticmethod
- def staple(send_stream, recv_stream):
- return StapledStream(send_stream=send_stream, recv_stream=recv_stream)
+ sock1, sock2 = trio.socket.socketpair()
+ echo_stream = StapledStream(SocketStream(sock1), SocketStream(sock2))
+ await echo_stream.send_all(b"x")
+ assert await echo_stream.receive_some(1) == b"x"
-@attr.s(slots=True, cmp=False, hash=False)
-class StapledStream(Stream):
+ :class:`StapledStream` objects implement the methods in the
+ :class:`~trio.abc.HalfCloseableStream` interface. They also have two
+ additional public attributes:
+
+ .. attribute:: send_stream
+
+ The underlying :class:`~trio.abc.SendStream`. :meth:`send_all` and
+ :meth:`wait_send_all_might_not_block` are delegated to this object.
+
+ .. attribute:: receive_stream
+
+ The underlying :class:`~trio.abc.ReceiveStream`. :meth:`receive_some`
+ is delegated to this object.
+
+ """
send_stream = attr.ib()
- recv_stream = attr.ib()
+ receive_stream = attr.ib()
- async def sendall(self, data):
- return await self.send_stream.sendall(data)
+ async def send_all(self, data):
+ """Calls ``self.send_stream.send_all``.
- async def wait_maybe_writable(self):
- return await self.send_stream.wait_maybe_writable()
+ """
+ return await self.send_stream.send_all(data)
- @property
- def can_send_eof(self):
- return self.send_stream.can_send_eof
+ async def wait_send_all_might_not_block(self):
+ """Calls ``self.send_stream.wait_send_all_might_not_block``.
- def send_eof(self):
- return self.send_stream.send_eof()
+ """
+ return await self.send_stream.wait_send_all_might_not_block()
+
+ async def send_eof(self):
+ """Shuts down the send side of the stream.
+
+ If ``self.send_stream.send_eof`` exists, then calls it. Otherwise,
+ calls ``self.send_stream.graceful_close()``.
- async def recv(self, max_bytes):
- return self.recv_stream.recv(max_bytes)
+ """
+ if hasattr(self.send_stream, "send_eof"):
+ return await self.send_stream.send_eof()
+ else:
+ return await self.send_stream.graceful_close()
+
+ async def receive_some(self, max_bytes):
+ """Calls ``self.receive_stream.receive_some``.
+
+ """
+ return await self.receive_stream.receive_some(max_bytes)
def forceful_close(self):
+ """Calls ``forceful_close`` on both underlying streams.
+
+ """
try:
self.send_stream.forceful_close()
finally:
- self.recv_stream.forceful_close()
+ self.receive_stream.forceful_close()
async def graceful_close(self):
+ """Calls ``graceful_close`` on both underlying streams.
+
+ """
try:
await self.send_stream.graceful_close()
finally:
- await self.recv_stream.graceful_close()
+ await self.receive_stream.graceful_close()
diff --git a/trio/_sync.py b/trio/_sync.py
index e41e4f843b..fd0dcd5691 100644
--- a/trio/_sync.py
+++ b/trio/_sync.py
@@ -6,7 +6,8 @@
from . import _core
from ._util import aiter_compat
-__all__ = ["Event", "Semaphore", "Lock", "Condition", "Queue"]
+__all__ = [
+ "Event", "Semaphore", "Lock", "StrictFIFOLock", "Condition", "Queue"]
@attr.s(slots=True, repr=False, cmp=False, hash=False)
class Event:
@@ -247,7 +248,8 @@ def __repr__(self):
else:
s1 = "unlocked"
s2 = ""
- return "<{} trio.Lock object at {:#x}{}>".format(s1, id(self), s2)
+ return ("<{} {} object at {:#x}{}>"
+ .format(s1, self.__class__.__name__, id(self), s2))
def locked(self):
"""Check whether the lock is currently held.
@@ -327,6 +329,69 @@ def statistics(self):
)
+class StrictFIFOLock(Lock):
+ """A variant of :class:`Lock` where tasks are guaranteed to acquire the
+ lock in strict first-come-first-served order.
+
+ An example of when this is useful is if you're implementing something like
+ :class:`trio.ssl.SSLStream` or an HTTP/2 server using `h2
+    <https://hyper-h2.readthedocs.io/>`__, where you have multiple concurrent
+ tasks that are interacting with a shared state machine, and at
+ unpredictable moments the state machine requests that a chunk of data be
+ sent over the network. (For example, when using h2 simply reading incoming
+ data can occasionally `create outgoing data to send
+ `__.) The challenge is to make
+ sure that these chunks are sent in the correct order, without being
+ garbled.
+
+ One option would be to use a regular :class:`Lock`, and wrap it around
+ every interaction with the state machine::
+
+ # This approach is sometimes workable but often sub-optimal; see below
+ async with lock:
+ state_machine.do_something()
+ if state_machine.has_data_to_send():
+ await conn.sendall(state_machine.get_data_to_send())
+
+ But this can be problematic. If you're using h2 then *usually* reading
+ incoming data doesn't create the need to send any data, so we don't want
+ to force every task that tries to read from the network to sit and wait
+ a potentially long time for ``sendall`` to finish. And in some situations
+ this could even potentially cause a deadlock, if the remote peer is
+ waiting for you to read some data before it accepts the data you're
+ sending.
+
+ :class:`StrictFIFOLock` provides an alternative. We can rewrite our
+ example like::
+
+ # Note: no awaits between when we start using the state machine and
+ # when we block to take the lock!
+ state_machine.do_something()
+ if state_machine.has_data_to_send():
+ # Notice that we fetch the data to send out of the state machine
+ # *before* sleeping, so that other tasks won't see it.
+ chunk = state_machine.get_data_to_send()
+ async with strict_fifo_lock:
+ await conn.sendall(chunk)
+
+ First we do all our interaction with the state machine in a single
+ scheduling quantum (notice there are no ``await``\s in there), so it's
+ automatically atomic with respect to other tasks. And then if and only if
+ we have data to send, we get in line to send it – and
+ :class:`StrictFIFOLock` guarantees that each task will send its data in
+ the same order that the state machine generated it.
+
+ Currently, :class:`StrictFIFOLock` is simply an alias for :class:`Lock`,
+ but (a) this may not always be true in the future, especially if trio ever
+ implements `more sophisticated scheduling policies
+    <https://github.com/python-trio/trio/issues/32>`__, and (b) the above code
+ is relying on a pretty subtle property of its lock. Using a
+ :class:`StrictFIFOLock` acts as an executable reminder that you're relying
+ on this property.
+
+ """
+
+
@attr.s(frozen=True)
class _ConditionStatistics:
tasks_waiting = attr.ib()
@@ -350,7 +415,7 @@ class Condition:
def __init__(self, lock=None):
if lock is None:
lock = Lock()
- if not isinstance(lock, Lock):
+ if not type(lock) is Lock:
raise TypeError("lock must be a trio.Lock")
self._lock = lock
self._lot = _core.ParkingLot()
diff --git a/trio/_util.py b/trio/_util.py
index 027ddae1e5..0986e02440 100644
--- a/trio/_util.py
+++ b/trio/_util.py
@@ -6,7 +6,14 @@
import async_generator
-__all__ = ["signal_raise", "aitercompat", "acontextmanager"]
+# There's a dependency loop here... _core is allowed to use this file (in fact
+# it's the *only* file in the main trio/ package it's allowed to use), but
+# UnLock needs yield_briefly so it also has to import _core. Possibly we
+# should split this file into two: one for true generic low-level utility
+# code, and one for higher level helpers?
+from . import _core
+
+__all__ = ["signal_raise", "aiter_compat", "acontextmanager", "UnLock"]
# Equivalent to the C function raise(), which Python doesn't wrap
if os.name == "nt":
@@ -67,7 +74,9 @@ async def __aiter__(*args, **kwargs):
# Very much derived from the one in contextlib, by copy/pasting and then
-# asyncifying everything.
+# asyncifying everything. (Also I dropped the obscure support for using
+# context managers as function decorators. It could be re-added; I just
+# couldn't be bothered.)
# So this is a derivative work licensed under the PSF License, which requires
# the following notice:
#
@@ -138,3 +147,44 @@ def helper(*args, **kwds):
# A hint for sphinxcontrib-trio:
helper.__returns_acontextmanager__ = True
return helper
+
+
+class _UnLockSync:
+ def __init__(self, exc, *args):
+ self._exc = exc
+ self._args = args
+ self._held = False
+
+ def __enter__(self):
+ if self._held:
+ raise self._exc(*self._args)
+ else:
+ self._held = True
+
+ def __exit__(self, *args):
+ self._held = False
+
+
+class UnLock:
+ """An unnecessary lock.
+
+ Use as an async context manager; if two tasks enter it at the same
+ time then the second one raises an error. You can use it when there are
+ two pieces of code that *would* collide and need a lock if they ever were
+ called at the same time, but that should never happen.
+
+ We use this in particular for things like, making sure that two different
+ tasks don't call sendall simultaneously on the same stream.
+
+ This executes a checkpoint on entry. That's the only reason it's async.
+
+ """
+ def __init__(self, exc, *args):
+ self.sync = _UnLockSync(exc, *args)
+
+ async def __aenter__(self):
+ await _core.yield_briefly()
+ return self.sync.__enter__()
+
+ async def __aexit__(self, *args):
+ return self.sync.__exit__()
diff --git a/trio/abc.py b/trio/abc.py
index 6bb7e53785..0973745d08 100644
--- a/trio/abc.py
+++ b/trio/abc.py
@@ -1,6 +1,11 @@
+import contextlib as _contextlib
import abc as _abc
+from . import _core
-__all__ = ["Clock", "Instrument"]
+__all__ = [
+ "Clock", "Instrument", "AsyncResource", "SendStream", "ReceiveStream",
+ "Stream", "HalfCloseableStream",
+]
class Clock(_abc.ABC):
"""The interface for custom run loop clocks.
@@ -130,3 +135,267 @@ def after_io_wait(self, timeout):
whether any I/O was ready.
"""
+
+
+# We use ABCMeta instead of ABC, plus setting __slots__=(), so as not to force
+# a __dict__ onto subclasses.
+class AsyncResource(metaclass=_abc.ABCMeta):
+ """A standard interface for resources that need to be cleaned up, and
+ where that cleanup may require blocking operations.
+
+ This class distinguishes between "graceful" closes, which may perform I/O
+ and thus block, and a "forceful" close, which cannot. For example, cleanly
+ shutting down a TLS-encrypted connection requires sending a "goodbye"
+ message; but if a peer has become non-responsive, then sending this
+ message might block forever, so we may want to just drop the connection
+ instead. Therefore the :meth:`graceful_close` method is unusual in that it
+ should always close the connection (or at least make its best attempt)
+ *even if it fails*; failure indicates a failure to achieve grace, not a
+ failure to close the connection.
+
+ Objects that implement this interface can be used as async context
+ managers, i.e., you can write::
+
+ async with create_resource() as some_async_resource:
+ ...
+
+ Entering the context manager is synchronous (not a checkpoint); exiting it
+ calls :meth:`graceful_close`. The default implementations of
+ ``__aenter__`` and ``__aexit__`` should be adequate for all subclasses.
+
+ """
+ __slots__ = ()
+
+ @_abc.abstractmethod
+ def forceful_close(self):
+ """Force an immediate close of this resource.
+
+ This will never block, but (depending on the resource in question) it
+ might be a "rude" shutdown.
+
+ If the resource is already closed, then this method should silently
+ succeed.
+
+ """
+
+ async def graceful_close(self):
+ """Close this resource, gracefully.
+
+ This may block in order to perform a "graceful" shutdown (for example,
+ sending a "goodbye" message). But, if this fails (e.g., due to being
+ cancelled), then it still *must* close the underlying resource,
+ possibly by calling :meth:`forceful_close`.
+
+ If the resource is already closed, then this method should silently
+ succeed.
+
+ :class:`AsyncResource` provides a default implementation of this
+ method that's suitable for resources that don't distinguish between
+ graceful and forceful closure: it simply calls :meth:`forceful_close`
+ and then executes a checkpoint.
+
+ """
+ try:
+ self.forceful_close()
+ finally:
+ await _core.yield_briefly()
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, *args):
+ await self.graceful_close()
+
+
+class SendStream(AsyncResource):
+ """A standard interface for sending data on a byte stream.
+
+ The underlying stream may be unidirectional, or bidirectional. If it's
+ bidirectional, then you probably want to also implement
+ :class:`ReceiveStream`, which makes your object a :class:`Stream`.
+
+ Every :class:`SendStream` also implements the :class:`AsyncResource`
+ interface.
+
+ """
+ __slots__ = ()
+
+ @_abc.abstractmethod
+ async def send_all(self, data):
+ """Sends the given data through the stream, blocking if necessary.
+
+ Args:
+ data (bytes, bytearray, or memoryview): The data to send.
+
+ Raises:
+ trio.ResourceBusyError: if another task is already executing a
+ :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or
+ :meth:`HalfCloseableStream.send_eof` on this stream.
+
+ """
+
+ @_abc.abstractmethod
+ async def wait_send_all_might_not_block(self):
+ """Block until it's possible that :meth:`send_all` might not block.
+
+ This method may return early: it's possible that after it returns,
+ :meth:`send_all` will still block. (In the worst case, if no better
+ implementation is available, then it might always return immediately
+ without blocking. It's nice to do better than that when possible,
+ though.)
+
+ This method **must not** return *late*: if it's possible for
+ :meth:`send_all` to complete without blocking, then it must
+ return. When implementing it, err on the side of returning early.
+
+ Raises:
+ trio.ResourceBusyError: if another task is already executing a
+ :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or
+ :meth:`HalfCloseableStream.send_eof` on this stream.
+
+ Note:
+
+ This method is intended to aid in implementing protocols that want
+ to delay choosing which data to send until the last moment. E.g.,
+ suppose you're working on an implementation of a remote display server
+ like `VNC
+ `__, and
+ the network connection is currently backed up so that if you call
+ :meth:`send_all` now then it will sit for 0.5 seconds before actually
+ sending anything. In this case it doesn't make sense to take a
+ screenshot, then wait 0.5 seconds, and then send it, because the
+ screen will keep changing while you wait; it's better to wait 0.5
+ seconds, then take the screenshot, and then send it, because this
+ way the data you deliver will be more
+ up-to-date. Using :meth:`wait_send_all_might_not_block` makes it
+ possible to implement the better strategy.
+
+ If you use this method, you might also want to read up on
+ ``TCP_NOTSENT_LOWAT``.
+
+ Further reading:
+
+ * `Prioritization Only Works When There's Pending Data to Prioritize
+ `__
+
+ * WWDC 2015: Your App and Next Generation Networks: `slides
+ `__,
+ `video and transcript
+ `__
+
+ """
+
+
+class ReceiveStream(AsyncResource):
+ """A standard interface for receiving data on a byte stream.
+
+ The underlying stream may be unidirectional, or bidirectional. If it's
+ bidirectional, then you probably want to also implement
+ :class:`SendStream`, which makes your object a :class:`Stream`.
+
+ Every :class:`ReceiveStream` also implements the :class:`AsyncResource`
+ interface.
+
+ """
+ __slots__ = ()
+
+ @_abc.abstractmethod
+ async def receive_some(self, max_bytes):
+ """Wait until there is data available on this stream, and then return
+ at most ``max_bytes`` of it.
+
+ A return value of ``b""`` (an empty bytestring) indicates that the
+ stream has reached end-of-file. Implementations should be careful that
+ they return ``b""`` if, and only if, the stream has reached
+ end-of-file!
+
+ This method will return as soon as any data is available, so it may
+ return fewer than ``max_bytes`` of data. But it will never return
+ more.
+
+ Args:
+ max_bytes (int): The maximum number of bytes to return. Must be
+ greater than zero.
+
+ Returns:
+ bytes or bytearray: The data received.
+
+ Raises:
+ trio.ResourceBusyError: if two tasks attempt to call
+ :meth:`receive_some` on the same stream at the same time.
+ trio.BrokenStreamError: if something has gone wrong, and the stream
+ is broken.
+ trio.ClosedStreamError: if someone already called one of the close
+ methods on this stream object.
+
+ """
+
+
+class Stream(SendStream, ReceiveStream):
+ """A standard interface for interacting with bidirectional byte streams.
+
+ A :class:`Stream` is an object that implements both the
+ :class:`SendStream` and :class:`ReceiveStream` interfaces.
+
+ If implementing this interface, you should consider whether you can go one
+ step further and implement :class:`HalfCloseableStream`.
+
+ """
+ __slots__ = ()
+
+
+class HalfCloseableStream(Stream):
+ """This interface extends :class:`Stream` to also allow closing the send
+ part of the stream without closing the receive part.
+
+ """
+ __slots__ = ()
+
+
+ @_abc.abstractmethod
+ async def send_eof(self):
+ """Send an end-of-file indication on this stream, if possible.
+
+ The difference between :meth:`send_eof` and
+ :meth:`~AsyncResource.graceful_close` is that :meth:`send_eof` is a
+ *unidirectional* end-of-file indication. After you call this method,
+ you shouldn't try sending any more data on this stream, and your
+ remote peer should receive an end-of-file indication (eventually,
+ after receiving all the data you sent before that). But, they may
+ continue to send data to you, and you can continue to receive it by
+ calling :meth:`~ReceiveStream.receive_some`. You can think of it as
+ calling :meth:`~AsyncResource.graceful_close` on just the
+ :class:`SendStream` "half" of the stream object (and in fact that's
+ literally how :class:`trio.StapledStream` implements it).
+
+ Examples:
+
+ * On a socket, this corresponds to ``shutdown(..., SHUT_WR)`` (`man
+ page `__).
+
+ * The SSH protocol provides the ability to multiplex bidirectional
+ "channels" on top of a single encrypted connection. A trio
+ implementation of SSH could expose these channels as
+ :class:`HalfCloseableStream` objects, and calling :meth:`send_eof`
+ would send an ``SSH_MSG_CHANNEL_EOF`` request (see `RFC 4254 §5.3
+ `__).
+
+ * On an SSL/TLS-encrypted connection, the protocol doesn't provide any
+ way to do a unidirectional shutdown without closing the connection
+ entirely, so :class:`~trio.ssl.SSLStream` implements
+ :class:`Stream`, not :class:`HalfCloseableStream`.
+
+ If an EOF has already been sent, then this method should silently
+ succeed.
+
+ Raises:
+ trio.ResourceBusyError: if another task is already executing a
+ :meth:`~SendStream.send_all`,
+ :meth:`~SendStream.wait_send_all_might_not_block`, or
+ :meth:`send_eof` on this stream.
+ trio.BrokenStreamError: if something has gone wrong, and the stream
+ is broken.
+ trio.ClosedStreamError: if someone already called one of the close
+ methods on this stream object.
+
+ """
diff --git a/trio/socket.py b/trio/socket.py
index f383f3f7aa..c5c2660aea 100644
--- a/trio/socket.py
+++ b/trio/socket.py
@@ -1,7 +1,9 @@
from functools import wraps as _wraps, partial as _partial
import socket as _stdlib_socket
import sys as _sys
-import os
+import os as _os
+from contextlib import contextmanager as _contextmanager
+import errno as _errno
from . import _core
from ._threads import run_in_worker_thread as _run_in_worker_thread
@@ -185,6 +187,13 @@ def __init__(self, sock):
.format(type(sock).__name__))
self._sock = sock
self._sock.setblocking(False)
+ self._did_SHUT_WR = False
+
+ # Hopefully Python will eventually make something like this public
+ # (see bpo-21327) but I don't want to make it public myself and then
+ # find out they picked a different name... this is used internally in
+ # this file and also elsewhere in trio.
+ self._real_type = sock.type & _SOCK_TYPE_MASK
# Defaults:
if self._sock.family == AF_INET6:
@@ -200,25 +209,6 @@ def __init__(self, sock):
else:
self.setsockopt(SOL_SOCKET, SO_REUSEADDR, True)
- try:
- self.setsockopt(IPPROTO_TCP, TCP_NODELAY, True)
- except OSError:
- pass
-
- try:
- # 16 KiB is pretty arbitrary and could probably do with some
- # tuning. (Apple is also setting this by default in CFNetwork
- # apparently -- I'm curious what value they're using, though I
- # couldn't find it online trivially. CFNetwork-129.20 source has
- # no mentions of TCP_NOTSENT_LOWAT. This presentation says
- # "typically 8 kilobytes":
- # http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1
- # ). The theory is that you want it to be bandwidth * rescheduling
- # interval.
- self.setsockopt(IPPROTO_TCP, TCP_NOTSENT_LOWAT, 2 ** 14)
- except (NameError, OSError):
- pass
-
################################################################
# Simple + portable methods and attributes
################################################################
@@ -235,7 +225,7 @@ def __init__(self, sock):
_forward = {
"detach", "get_inheritable", "set_inheritable", "fileno",
"getpeername", "getsockname", "getsockopt", "setsockopt", "listen",
- "shutdown", "close", "share",
+ "close", "share",
}
def __getattr__(self, name):
if name in self._forward:
@@ -282,6 +272,16 @@ def bind(self, address):
self._check_address(address, require_resolved=True)
return self._sock.bind(address)
+ def shutdown(self, flag):
+ # no need to worry about return value b/c always returns None:
+ self._sock.shutdown(flag)
+ # only do this if the call succeeded:
+ if flag in [SHUT_WR, SHUT_RDWR]:
+ self._did_SHUT_WR = True
+
+ async def wait_writable(self):
+ await _core.wait_socket_writable(self._sock)
+
################################################################
# Address handling
################################################################
@@ -315,7 +315,7 @@ def _check_address(self, address, *, require_resolved):
_stdlib_socket.getaddrinfo(
address[0], address[1],
self._sock.family,
- self._sock.type & _SOCK_TYPE_MASK,
+ self._real_type,
self._sock.proto,
flags=_NUMERIC_ONLY)
except gaierror as exc:
@@ -352,7 +352,7 @@ async def _resolve_address(self, address, flags):
gai_res = await getaddrinfo(
address[0], address[1],
self._sock.family,
- self._sock.type & _SOCK_TYPE_MASK,
+ self._real_type,
self._sock.proto,
flags)
# AFAICT from the spec it's not possible for getaddrinfo to return an
@@ -492,19 +492,54 @@ async def connect(self, address):
# notification. This means it isn't really cancellable...
async with _try_sync():
self._check_address(address, require_resolved=True)
- # For some reason, PEP 475 left InterruptedError as a
- # possible error for non-blocking connect
- # (specifically). But as far as I know, EINTR always means
- # you need to redo the call (with the extremely special
- # exception of close() on Linux, but that's unrelated, and
- # POSIX is cranky at them about it). If the kernel wanted
- # to signal that the connect really was in progress then
- # it'd have used EINPROGRESS. So we retry:
- while True:
- try:
- return self._sock.connect(address)
- except InterruptedError:
- pass
+ # An interesting puzzle: can a non-blocking connect() return EINTR
+ # (= raise InterruptedError)? PEP 475 specifically left this as
+ # the one place where it lets an InterruptedError escape instead
+ # of automatically retrying. This is based on the idea that EINTR
+ # from connect means that the connection was already started, and
+ # will continue in the background. For a blocking connect, this
+ # sort of makes sense: if it returns EINTR then the connection
+ # attempt is continuing in the background, and on many systems you
+ # can't then call connect() again because there is already a
+ # connect happening. See:
+ #
+ # http://www.madore.org/~david/computers/connect-intr.html
+ #
+ # For a non-blocking connect, it doesn't make as much sense --
+ # surely the interrupt didn't happen after we successfully
+ # initiated the connect and are just waiting for it to complete,
+ # because a non-blocking connect does not wait! And the spec
+ # describes the interaction between EINTR/blocking connect, but
+ # doesn't have anything useful to say about non-blocking connect:
+ #
+ # http://pubs.opengroup.org/onlinepubs/007904975/functions/connect.html
+ #
+ # So we have a conundrum: if EINTR means that the connect() hasn't
+ # happened (like it does for essentially every other syscall),
+ # then InterruptedError should be caught and retried. If EINTR
+ # means that the connect() has successfully started, then
+ # InterruptedError should be caught and ignored. Which should we
+ # do?
+ #
+ # In practice, the resolution is probably that non-blocking
+ # connect simply never returns EINTR, so the question of how to
+ # handle it is moot. Someone spelunked MacOS/FreeBSD and
+ # confirmed this is true there:
+ #
+ # https://stackoverflow.com/questions/14134440/eintr-and-non-blocking-calls
+ #
+ # and exarkun seems to think it's true in general of non-blocking
+ # calls:
+ #
+ # https://twistedmatrix.com/pipermail/twisted-python/2010-September/022864.html
+ # (and indeed, AFAICT twisted doesn't try to handle
+ # InterruptedError).
+ #
+ # So we don't try to catch InterruptedError. This way if it
+ # happens, someone will hopefully tell us, and then hopefully we
+ # can investigate their system to figure out what its semantics
+ # are.
+ return self._sock.connect(address)
# It raised BlockingIOError, meaning that it's started the
# connection attempt. We wait for it to complete:
try:
@@ -518,7 +553,7 @@ async def connect(self, address):
# Okay, the connect finished, but it might have failed:
err = self._sock.getsockopt(SOL_SOCKET, SO_ERROR)
if err != 0:
- raise OSError(err, "Error in connect: " + os.strerror(err))
+ raise OSError(err, "Error in connect: " + _os.strerror(err))
################################################################
# recv
@@ -568,7 +603,7 @@ async def connect(self, address):
# send
################################################################
- _send = _make_simple_sock_method_wrapper(
+ send = _make_simple_sock_method_wrapper(
"send", _core.wait_socket_writable)
################################################################
@@ -621,22 +656,23 @@ async def sendall(self, data, flags=0):
``flags`` are passed on to ``send``.
- If an error occurs or the operation is cancelled, then the resulting
- exception will have a ``.partial_result`` attribute with a
- ``.bytes_sent`` attribute containing the number of bytes sent.
+ Most low-level operations in trio provide a guarantee: if they raise
+ :exc:`trio.Cancelled`, this means that they had no effect, so the
+ system remains in a known state. This is **not true** for
+ :meth:`sendall`. If this operation raises :exc:`trio.Cancelled` (or
+ any other exception for that matter), then it may have sent some, all,
+ or none of the requested data, and there is no way to know which.
"""
with memoryview(data) as data:
+ if not data:
+ await _core.yield_briefly()
+ return
total_sent = 0
- try:
- while data:
- sent = await self._send(data, flags)
- total_sent += sent
- data = data[sent:]
- except BaseException as exc:
- pr = _core.PartialResult(bytes_sent=total_sent)
- exc.partial_result = pr
- raise
+ while total_sent < len(data):
+ with data[total_sent:] as remaining:
+ sent = await self.send(remaining, flags)
+ total_sent += sent
################################################################
# sendfile
@@ -705,3 +741,5 @@ async def sendall(self, data, flags=0):
# else:
# raise OSError("getaddrinfo returned an empty list")
# __all__.append("create_connection")
+
+
diff --git a/trio/ssl.py b/trio/ssl.py
new file mode 100644
index 0000000000..0c719a302f
--- /dev/null
+++ b/trio/ssl.py
@@ -0,0 +1,784 @@
+# General theory of operation:
+#
+# We implement an API that closely mirrors the stdlib ssl module's blocking
+# API, and we do it using the stdlib ssl module's non-blocking in-memory API.
+# The stdlib non-blocking in-memory API is barely documented, and acts as a
+# thin wrapper around openssl, whose documentation also leaves something to be
+# desired. So here's the main things you need to know to understand the code
+# in this file:
+#
+# We use an ssl.SSLObject, which exposes the four main I/O operations:
+#
+# - do_handshake: performs the initial handshake. Must be called once at the
+# beginning of each connection; is a no-op once it's completed once.
+#
+# - write: takes some unencrypted data and attempts to send it to the remote
+# peer.
+#
+# - read: attempts to decrypt and return some data from the remote peer.
+#
+# - unwrap: this is weirdly named; maybe it helps to realize that the thing it
+# wraps is called SSL_shutdown. It sends a cryptographically signed message
+# saying "I'm closing this connection now", and then waits to receive the
+# same from the remote peer (unless we already received one, in which case
+# it returns immediately).
+#
+# All of these operations read and write from some in-memory buffers called
+# "BIOs", which are an opaque OpenSSL-specific object that's basically
+# semantically equivalent to a Python bytearray. When they want to send some
+# bytes to the remote peer, they append them to the outgoing BIO, and when
+# they want to receive some bytes from the remote peer, they try to pull them
+# out of the incoming BIO. "Sending" always succeeds, because the outgoing BIO
+# can always be extended to hold more data. "Receiving" acts sort of like a
+# non-blocking socket: it might manage to get some data immediately, or it
+# might fail and need to be tried again later. We can also directly add or
+# remove data from the BIOs whenever we want.
+#
+# Now the problem is that while these I/O operations are opaque atomic
+# operations from the point of view of us calling them, under the hood they
+# might require some arbitrary sequence of sends and receives from the remote
+# peer. This is particularly true for do_handshake, which generally requires a
+# few round trips, but it's also true for write and read, due to an evil thing
+# called "renegotiation".
+#
+# Renegotiation is the process by which one of the peers might arbitrarily
+# decide to redo the handshake at any time. Did I mention it's evil? It's
+# pretty evil, and almost universally hated. The HTTP/2 spec forbids the use
+# of TLS renegotiation for HTTP/2 connections. TLS 1.3 removes it from the
+# protocol entirely. It's impossible to trigger a renegotiation if using
+# Python's ssl module. OpenSSL's renegotiation support is pretty buggy [1].
+# Nonetheless, it does get used in real life, mostly in two cases:
+#
+# 1) Normally in TLS 1.2 and below, when the client side of a connection wants
+# to present a certificate to prove their identity, that certificate gets sent
+# in plaintext. This is bad, because it means that anyone eavesdropping can
+# see who's connecting – it's like sending your username in plain text. Not as
+# bad as sending your password in plain text, but still, pretty bad. However,
+# renegotiations *are* encrypted. So as a workaround, it's not uncommon for
+# systems that want to use client certificates to first do an anonymous
+# handshake, and then to turn around and do a second handshake (=
+# renegotiation) and this time ask for a client cert. Or sometimes this is
+# done on a case-by-case basis, e.g. a web server might accept a connection,
+# read the request, and then once it sees the page you're asking for it might
+# stop and ask you for a certificate.
+#
+# 2) In principle the same TLS connection can be used for an arbitrarily long
+# time, and might transmit arbitrarily large amounts of data. But this creates
+# a cryptographic problem: an attacker who has access to arbitrarily large
+# amounts of data that's all encrypted using the same key may eventually be
+# able to use this to figure out the key. Is this a real practical problem? I
+# have no idea, I'm not a cryptographer. In any case, some people worry that
+# it's a problem, so their TLS libraries are designed to automatically trigger
+# a renegotiation every once in a while on some sort of timer.
+#
+# The end result is that you might be going along, minding your own business,
+# and then *bam*! a wild renegotiation appears! And you just have to cope.
+#
+# The reason that coping with renegotiations is difficult is that some
+# unassuming "read" or "write" call might find itself unable to progress until
+# it does a handshake, which remember is a process with multiple round
+# trips. So read might have to send data, and write might have to receive
+# data, and this might happen multiple times. And some of those attempts might
+# fail because there isn't any data yet, and need to be retried. Managing all
+# this is pretty complicated.
+#
+# Here's how openssl (and thus the stdlib ssl module) handle this. All of the
+# I/O operations above follow the same rules. When you call one of them:
+#
+# - it might write some data to the outgoing BIO
+# - it might read some data from the incoming BIO
+# - it might raise SSLWantReadError if it can't complete without reading more
+# data from the incoming BIO. This is important: the "read" in ReadError
+# refers to reading from the *underlying* stream.
+# - (and in principle it might raise SSLWantWriteError too, but that never
+# happens when using memory BIOs, so never mind)
+#
+# If it doesn't raise an error, then the operation completed successfully
+# (though we still need to take any outgoing data out of the memory buffer and
+# put it onto the wire). If it *does* raise an error, then we need to retry
+# *exactly that method call* later – in particular, if a 'write' failed, we
+# need to try again later *with the same data*, because openssl might have
+# already committed some of the initial parts of our data to its output even
+# though it didn't tell us that, and has remembered that the next time we call
+# write it needs to skip the first 1024 bytes or whatever it is. (Well,
+# technically, we're actually allowed to call 'write' again with a data buffer
+# which is the same as our old one PLUS some extra stuff added onto the end,
+# but in trio that never comes up so never mind.)
+#
+# There are some people online who claim that once you've gotten a Want*Error
+# then the *very next call* you make to openssl *must* be the same as the
+# previous one. I'm pretty sure those people are wrong. In particular, it's
+# okay to call write, get a WantReadError, and then call read a few times;
+# it's just that *the next time you call write*, it has to be with the same
+# data.
+#
+# One final wrinkle: we want our SSLStream to support full-duplex operation,
+# i.e. it should be possible for one task to be calling send_all while another
+# task is calling receive_some. But renegotiation makes this a big hassle, because
+# even if SSLStream's users restrict themselves to one task calling send_all and one
+# task calling receive_some, those two tasks might end up both wanting to call
+# send_all, or both to call receive_some at the same time *on the underlying
+# stream*. So we have to do some careful locking to hide this problem from our
+# users.
+#
+# (Renegotiation is evil.)
+#
+# So our basic strategy is to define a single helper method called "_retry",
+# which has generic logic for dealing with SSLWantReadError, pushing data from
+# the outgoing BIO to the wire, reading data from the wire to the incoming
+# BIO, retrying an I/O call until it works, and synchronizing with other tasks
+# that might be calling _retry concurrently. Basically it takes an SSLObject
+# non-blocking in-memory method and converts it into a trio async blocking
+# method. _retry is only about 30 lines of code, but all these cases
+# multiplied by concurrent calls make it extremely tricky, so there are lots
+# of comments down below on the details, and a really extensive test suite in
+# test_ssl.py. And now you know *why* it's so tricky, and can probably
+# understand how it works.
+#
+# [1] https://rt.openssl.org/Ticket/Display.html?id=3712
+
+# XX how closely should we match the stdlib API?
+# - maybe suppress_ragged_eofs=False is a better default?
+# - maybe check crypto folks for advice?
+# - this is also interesting: https://bugs.python.org/issue8108#msg102867
+
+# Definitely keep an eye on Cory's TLS API ideas on security-sig etc.
+
+# XX document behavior on cancellation/error (i.e.: all is lost abandon
+# stream)
+# docs will need to make very clear that this is different from all the other
+# cancellations in core trio
+
+import operator as _operator
+import ssl as _stdlib_ssl
+from enum import Enum as _Enum
+
+from . import _core
+from .abc import Stream as _Stream
+from . import _streams
+from . import _sync
+from ._util import UnLock as _UnLock
+
+__all__ = ["SSLStream"]
+
+################################################################
+# Faking the stdlib ssl API
+################################################################
+
+def _reexport(name):
+ try:
+ value = getattr(_stdlib_ssl, name)
+ except AttributeError:
+ pass
+ else:
+ globals()[name] = value
+ __all__.append(name)
+
+for _name in [
+ "SSLError", "SSLZeroReturnError", "SSLSyscallError", "SSLEOFError",
+ "CertificateError", "create_default_context", "match_hostname",
+ "cert_time_to_seconds", "DER_cert_to_PEM_cert", "PEM_cert_to_DER_cert",
+ "get_default_verify_paths", "Purpose", "enum_certificates",
+ "enum_crls", "SSLSession", "VerifyMode", "VerifyFlags", "Options",
+ "AlertDescription", "SSLErrorNumber",
+ # Intentionally not re-exported: SSLContext
+]:
+ _reexport(_name)
+
+for _name in _stdlib_ssl.__dict__.keys():
+ if _name == _name.upper():
+ _reexport(_name)
+
+
+################################################################
+# SSLStream
+################################################################
+
+class _Once:
+ def __init__(self, afn, *args):
+ self._afn = afn
+ self._args = args
+ self.started = False
+ self._done = _sync.Event()
+
+ async def ensure(self, *, checkpoint):
+ if not self.started:
+ self.started = True
+ await self._afn(*self._args)
+ self._done.set()
+ elif not checkpoint and self._done.is_set():
+ return
+ else:
+ await self._done.wait()
+
+
+_State = _Enum("_State", ["OK", "BROKEN", "CLOSED"])
+
+
+class SSLStream(_Stream):
+ """Encrypted communication using SSL/TLS.
+
+ :class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and
+ allows you to perform encrypted communication over it using the usual
+ :class:`~trio.abc.Stream` interface. You pass regular data to
+ :meth:`send_all`, then it encrypts it and sends the encrypted data on the
+ underlying :class:`~trio.abc.Stream`; :meth:`receive_some` takes encrypted
+ data out of the underlying :class:`~trio.abc.Stream` and decrypts it
+ before returning it.
+
+ You should read the standard library's :mod:`ssl` documentation carefully
+ before attempting to use this class, and probably other general
+ documentation on SSL/TLS as well. SSL/TLS is subtle and quick to
+ anger. Really. I'm not kidding.
+
+ Args:
+ transport_stream (~trio.abc.Stream): The stream used to transport
+ encrypted data. Required.
+
+ ssl_context (~ssl.SSLContext): The :class:`~ssl.SSLContext` used for
+ this connection. Required. Usually created by calling
+ :func:`trio.ssl.create_default_context()
+ `.
+
+ server_hostname (str or None): The name of the server being connected
+ to. Used for `SNI
+ `__ and for
+ validating the server's certificate (if hostname checking is
+ enabled). This is effectively mandatory for clients, and actually
+ mandatory if ``ssl_context.check_hostname`` is True.
+
+ server_side (bool): Whether this stream is acting as a client or
+ server. Defaults to False, i.e. client mode.
+
+ https_compatible (bool): There are two versions of SSL/TLS commonly
+ encountered in the wild: the standard version, and the version used
+ for HTTPS (HTTP-over-SSL/TLS).
+
+ Standard-compliant SSL/TLS implementations always send a
+ cryptographically signed ``close_notify`` message before closing the
+ connection. This is important because if the underlying transport
+ were simply closed, then there wouldn't be any way for the other
+ side to know whether the connection was intentionally closed by the
+ peer that they negotiated a cryptographic connection to, or by some
+ `man-in-the-middle
+ `__ attacker
+ who can't manipulate the cryptographic stream, but can manipulate
+ the transport layer (a so-called "truncation attack").
+
+ However, this part of the standard is widely ignored by real-world
+ HTTPS implementations, which means that if you want to interoperate
+ with them, then you NEED to ignore it too.
+
+ Fortunately this isn't as bad as it sounds, because the HTTP
+ protocol already includes its own equivalent of ``close_notify``, so
+ doing this again at the SSL/TLS level is redundant. But not all
+ protocols do! Therefore, by default Trio implements the safer
+ standard-compliant version (``https_compatible=False``). But if
+ you're speaking HTTPS or some other protocol where
+ ``close_notify``\s are commonly skipped, then you should set
+ ``https_compatible=True``; with this setting, Trio will neither
+ expect nor send ``close_notify`` messages.
+
+ If you have code that was written to use :class:`ssl.SSLSocket` and
+ now you're porting it to Trio, then it may be useful to know that a
+ difference between :class:`SSLStream` and :class:`ssl.SSLSocket` is
+ that :class:`~ssl.SSLSocket` implements the
+ ``https_compatible=True`` behavior by default.
+
+ max_refill_bytes (int): :class:`SSLStream` maintains an internal
+ buffer of incoming data, and when it runs low then it calls
+ :meth:`receive_some` on the underlying transport stream to refill
+ it. This argument lets you set the ``max_bytes`` argument passed to
+ the *underlying* :meth:`receive_some` call. It doesn't affect calls
+ to *this* class's :meth:`receive_some`, or really anything else
+ user-observable except possibly performance. You probably don't need
+ to worry about this.
+
+ Attributes:
+ transport_stream (trio.abc.Stream): The underlying transport stream
+ that was passed to ``__init__``. An example of when this would be
+ useful is if you're using :class:`SSLStream` over a
+ :class:`~trio.SocketStream` and want to call the
+ :class:`~trio.SocketStream`'s :meth:`~trio.SocketStream.setsockopt`
+ method.
+
+ Internally, this class is implemented using an instance of
+ :class:`ssl.SSLObject`, and all of :class:`~ssl.SSLObject`'s methods and
+ attributes are re-exported as methods and attributes on this class.
+
+ This also means that if you register a SNI callback using
+ :meth:`~ssl.SSLContext.set_servername_callback`, then the first argument
+ your callback receives will be a :class:`ssl.SSLObject`.
+
+ """
+ def __init__(
+ self, transport_stream, ssl_context,
+ *,
+ server_hostname=None, server_side=False,
+ https_compatible=False, max_refill_bytes=32 * 1024
+ ):
+ self.transport_stream = transport_stream
+ self._state = _State.OK
+ self._max_bytes = max_refill_bytes
+ self._https_compatible = https_compatible
+ self._outgoing = _stdlib_ssl.MemoryBIO()
+ self._incoming = _stdlib_ssl.MemoryBIO()
+ self._ssl_object = ssl_context.wrap_bio(
+ self._incoming, self._outgoing,
+ server_side=server_side, server_hostname=server_hostname)
+ # Tracks whether we've already done the initial handshake
+ self._handshook = _Once(self._do_handshake)
+
+ # These are used to synchronize access to self.transport_stream
+ self._inner_send_lock = _sync.StrictFIFOLock()
+ self._inner_recv_count = 0
+ self._inner_recv_lock = _sync.Lock()
+
+ # These are used to make sure that our caller doesn't attempt to make
+ # multiple concurrent calls to send_all/wait_send_all_might_not_block
+ # or to receive_some.
+ self._outer_send_lock = _UnLock(
+ _core.ResourceBusyError,
+ "another task is currently sending data on this SSLStream")
+ self._outer_recv_lock = _UnLock(
+ _core.ResourceBusyError,
+ "another task is currently receiving data on this SSLStream")
+
+ _forwarded = {
+ "context", "server_side", "server_hostname", "session",
+ "session_reused", "getpeercert", "selected_npn_protocol", "cipher",
+ "shared_ciphers", "compression", "pending", "get_channel_binding",
+ "selected_alpn_protocol", "version",
+ }
+ def __getattr__(self, name):
+ if name in self._forwarded:
+ return getattr(self._ssl_object, name)
+ else:
+ raise AttributeError(name)
+
+ def __setattr__(self, name, value):
+ if name in self._forwarded:
+ setattr(self._ssl_object, name, value)
+ else:
+ super().__setattr__(name, value)
+
+ def __dir__(self):
+ return super().__dir__() + list(self._forwarded)
+
+ def _check_status(self):
+ if self._state is _State.OK:
+ return
+ elif self._state is _State.BROKEN:
+ raise _streams.BrokenStreamError
+ elif self._state is _State.CLOSED:
+ raise _streams.ClosedStreamError
+ else: # pragma: no cover
+ assert False
+
+ # This is probably the single trickiest function in trio. It has lots of
+ # comments, though; just make sure to think carefully if you ever have to
+ # touch it. The big comment at the top of this file will help explain
+ # too.
+ async def _retry(self, fn, *args, ignore_want_read=False):
+ await _core.yield_if_cancelled()
+ yielded = False
+ try:
+ finished = False
+ while not finished:
+ # WARNING: this code needs to be very careful with when it
+ # calls 'await'! There might be multiple tasks calling this
+ # function at the same time trying to do different operations,
+ # so we need to be careful to:
+ #
+ # 1) interact with the SSLObject, then
+ # 2) await on exactly one thing that lets us make forward
+ # progress, then
+ # 3) loop or exit
+ #
+ # In particular we don't want to yield while interacting with
+ # the SSLObject (because it's shared state, so someone else
+ # might come in and mess with it while we're suspended), and
+ # we don't want to yield *before* starting the operation that
+ # will help us make progress, because then someone else might
+ # come in and leapfrog us.
+
+ # Call the SSLObject method, and get its result.
+ #
+ # NB: despite what the docs say, SSLWantWriteError can't
+ # happen – "Writes to memory BIOs will always succeed if
+ # memory is available: that is their size can grow
+ # indefinitely."
+ # https://wiki.openssl.org/index.php/Manual:BIO_s_mem(3)
+ want_read = False
+ ret = None
+ try:
+ ret = fn(*args)
+ except _stdlib_ssl.SSLWantReadError:
+ want_read = True
+ except (SSLError, CertificateError) as exc:
+ self._state = _State.BROKEN
+ raise _streams.BrokenStreamError from exc
+ else:
+ finished = True
+ if ignore_want_read:
+ want_read = False
+ finished = True
+ to_send = self._outgoing.read()
+
+ # Outputs from the above code block are:
+ #
+ # - to_send: bytestring; if non-empty then we need to send
+ # this data to make forward progress
+ #
+ # - want_read: True if we need to receive_some some data to make
+ # forward progress
+ #
+ # - finished: False means that we need to retry the call to
+ # fn(*args) again, after having pushed things forward. True
+ # means we still need to do whatever was said (in particular
+ # send any data in to_send), but once we do then we're
+ # done.
+ #
+ # - ret: the operation's return value. (Meaningless unless
+ # finished is True.)
+ #
+ # Invariant: want_read and finished can't both be True at the
+ # same time.
+ #
+ # Now we need to move things forward. There are two things we
+ # might have to do, and any given operation might require
+ # either, both, or neither to proceed:
+ #
+ # - send the data in to_send
+ #
+ # - receive_some some data and put it into the incoming BIO
+ #
+ # Our strategy is: if there's data to send, send it;
+ # *otherwise* if there's data to receive_some, receive_some it.
+ #
+ # If both need to happen, then we only send. Why? Well, we
+ # know that *right now* we have to both send and receive_some
+ # before the operation can complete. But as soon as we yield,
+ # that information becomes potentially stale – e.g. while
+ # we're sending, some other task might go and receive_some the
+ # data we need and put it into the incoming BIO. And if it
+ # does, then we *definitely don't* want to do a receive_some –
+ # there might not be any more data coming, and we'd deadlock!
+ # We could do something tricky to keep track of whether a
+ # receive_some happens while we're sending, but the case where
+ # we have to do both is very unusual (only during a
+ # renegotiation), so it's better to keep things simple. So we
+ # do just one potentially-blocking operation, then check again
+ # for fresh information.
+ #
+ # And we prioritize sending over receiving because, if there
+ # are multiple tasks that want to receive_some, then it
+ # doesn't matter what order they go in. But if there are
+ # multiple tasks that want to send, then they each have
+ # different data, and the data needs to get put onto the wire
+ # in the same order that it was retrieved from the outgoing
+ # BIO. So if we have data to send, that *needs* to be the
+ # *very* *next* *thing* we do, to make sure no-one else sneaks
+ # in before us. Or if we can't send immediately because
+ # someone else is, then we at least need to get in line
+ # immediately.
+ if to_send:
+ # NOTE: This relies on the lock being strict FIFO fair!
+ async with self._inner_send_lock:
+ yielded = True
+ try:
+ await self.transport_stream.send_all(to_send)
+ except:
+ # Some unknown amount of our data got sent, and we
+ # don't know how much. This stream is doomed.
+ self._state = _State.BROKEN
+ raise
+ elif want_read:
+ # It's possible that someone else is already blocked in
+ # transport_stream.receive_some. If so then we want to
+ # wait for them to finish, but we don't want to call
+ # transport_stream.receive_some again ourselves; we just
+ # want to loop around and check if their contribution
+ # helped anything. So we make a note of how many times
+ # some task has been through here before taking the lock,
+ # and if it's changed by the time we get the lock, then we
+ # skip calling transport_stream.receive_some and loop
+ # around immediately.
+ recv_count = self._inner_recv_count
+ async with self._inner_recv_lock:
+ yielded = True
+ if recv_count == self._inner_recv_count:
+ data = await self.transport_stream.receive_some(self._max_bytes)
+ if not data:
+ self._incoming.write_eof()
+ else:
+ self._incoming.write(data)
+ self._inner_recv_count += 1
+ return ret
+ finally:
+ if not yielded:
+ await _core.yield_briefly_no_cancel()
+
+ async def _do_handshake(self):
+ try:
+ await self._retry(self._ssl_object.do_handshake)
+ except:
+ self._state = _State.BROKEN
+ raise
+
+ async def do_handshake(self):
+ """Ensure that the initial handshake has completed.
+
+ The SSL protocol requires an initial handshake to exchange
+ certificates, select cryptographic keys, and so forth, before any
+ actual data can be sent or received. You don't have to call this
+ method; if you don't, then :class:`SSLStream` will automatically
+ perform the handshake as needed, the first time you try to send or
+ receive data. But if you want to trigger it manually – for example,
+ because you want to look at the peer's certificate before you start
+ talking to them – then you can call this method.
+
+ If the initial handshake is already in progress in another task, this
+ waits for it to complete and then returns.
+
+ If the initial handshake has already completed, this returns
+ immediately without doing anything (except executing a checkpoint).
+
+ .. warning:: If this method is cancelled, then it may leave the
+ :class:`SSLStream` in an unusable state. If this happens then any
+ future attempt to use the object will raise
+ :exc:`trio.BrokenStreamError`.
+
+ """
+ try:
+ self._check_status()
+ except:
+ await _core.yield_briefly()
+ raise
+ await self._handshook.ensure(checkpoint=True)
+
+ # Most things work if we don't explicitly force do_handshake to be called
+ # before calling receive_some or send_all, because openssl will
+ # automatically perform the handshake on the first SSL_{read,write}
+ # call. BUT, allowing openssl to do this will disable Python's hostname
+ # checking!!! See:
+ # https://bugs.python.org/issue30141
+ # So we *definitely* have to make sure that do_handshake is called
+ # before doing anything else.
+ async def receive_some(self, max_bytes):
+ """Read some data from the underlying transport, decrypt it, and
+ return it.
+
+ See :meth:`trio.abc.ReceiveStream.receive_some` for details.
+
+ .. warning:: If this method is cancelled while the initial handshake
+ or a renegotiation are in progress, then it may leave the
+ :class:`SSLStream` in an unusable state. If this happens then any
+ future attempt to use the object will raise
+ :exc:`trio.BrokenStreamError`.
+
+ """
+ async with self._outer_recv_lock:
+ self._check_status()
+ try:
+ await self._handshook.ensure(checkpoint=False)
+ except _streams.BrokenStreamError as exc:
+ # For some reason, EOF before handshake sometimes raises
+ # SSLSyscallError instead of SSLEOFError (e.g. on my linux
+ # laptop, but not on appveyor). Thanks openssl.
+ if (self._https_compatible
+ and isinstance(
+ exc.__cause__, (SSLEOFError, SSLSyscallError))):
+ return b""
+ else:
+ raise
+ max_bytes = _operator.index(max_bytes)
+ if max_bytes < 1:
+ raise ValueError("max_bytes must be >= 1")
+ try:
+ return await self._retry(self._ssl_object.read, max_bytes)
+ except _streams.BrokenStreamError as exc:
+ # This isn't quite equivalent to just returning b"" in the
+ # first place, because we still end up with self._state set to
+ # BROKEN. But that's actually fine, because after getting an
+ # EOF on TLS then the only thing you can do is close the
+ # stream, and closing doesn't care about the state.
+ if (self._https_compatible
+ and isinstance(exc.__cause__, SSLEOFError)):
+ return b""
+ else:
+ raise
+
+ async def send_all(self, data):
+ """Encrypt some data and then send it on the underlying transport.
+
+ See :meth:`trio.abc.SendStream.send_all` for details.
+
+ .. warning:: If this method is cancelled, then it may leave the
+ :class:`SSLStream` in an unusable state. If this happens then any
+ attempt to use the object will raise
+ :exc:`trio.BrokenStreamError`.
+
+ """
+ async with self._outer_send_lock:
+ self._check_status()
+ await self._handshook.ensure(checkpoint=False)
+ # SSLObject interprets write(b"") as an EOF for some reason, which
+ # is not what we want.
+ if not data:
+ await _core.yield_briefly()
+ return
+ await self._retry(self._ssl_object.write, data)
+
+ async def unwrap(self):
+ """Cleanly close down the SSL/TLS encryption layer, allowing the
+ underlying stream to be used for unencrypted communication.
+
+ You almost certainly don't need this.
+
+ Returns:
+ A pair ``(transport_stream, trailing_bytes)``, where
+ ``transport_stream`` is the underlying transport stream, and
+ ``trailing_bytes`` is a byte string. Since :class:`SSLStream`
+ doesn't necessarily know where the end of the encrypted data will
+ be, it can happen that it accidentally reads too much from the
+ underlying stream. ``trailing_bytes`` contains this extra data; you
+ should process it as if it was returned from a call to
+ ``transport_stream.receive_some(...)``.
+
+ """
+ async with self._outer_recv_lock, self._outer_send_lock:
+ self._check_status()
+ await self._handshook.ensure(checkpoint=False)
+ await self._retry(self._ssl_object.unwrap)
+ transport_stream = self.transport_stream
+ self.transport_stream = None
+ self._state = _State.CLOSED
+ return (transport_stream, self._incoming.read())
+
+ def forceful_close(self):
+ """Forcefully closes the underlying transport and marks this stream as
+ closed.
+
+ """
+ if self._state is not _State.CLOSED:
+ self._state = _State.CLOSED
+ self.transport_stream.forceful_close()
+
+ async def graceful_close(self):
+ """Gracefully shut down this connection, and close the underlying
+ transport.
+
+ If ``https_compatible`` is False (the default), then this attempts to
+ first send a ``close_notify`` and then close the underlying stream by
+ calling its :meth:`~trio.abc.AsyncResource.graceful_close` method.
+
+ If ``https_compatible`` is set to True, then this simply closes the
+ underlying stream and marks this stream as closed.
+
+ """
+ if self._state is _State.CLOSED:
+ await _core.yield_briefly()
+ return
+ if self._state is _State.BROKEN or self._https_compatible:
+ self._state = _State.CLOSED
+ await self.transport_stream.graceful_close()
+ return
+ try:
+ await self._handshook.ensure(checkpoint=False)
+ # Here, we call SSL_shutdown *once*, because we want to send a
+ # close_notify but *not* wait for the other side to send back a
+ # response. In principle it would be more polite to wait for the
+ # other side to reply with their own close_notify. However, if
+ # they aren't paying attention (e.g., if they're just sending
+ # data and not receiving) then we will never notice our
+ # close_notify and we'll be waiting forever. Eventually we'll time
+ # out (hopefully), but it's still kind of nasty. And we can't
+ # require the other side to always be receiving, because (a)
+ # backpressure is kind of important, and (b) I bet there are
+ # broken TLS implementations out there that don't receive all the
+ # time. (Like e.g. anyone using Python ssl in synchronous mode.)
+ #
+ # The send-then-immediately-close behavior is explicitly allowed
+ # by the TLS specs, so we're ok on that.
+ #
+ # Subtlety: SSLObject.unwrap will immediately call it a second
+ # time, and the second time will raise SSLWantReadError because
+ # there hasn't been time for the other side to respond
+ # yet. (Unless they spontaneously sent a close_notify before we
+ # called this, and it's either already been processed or gets
+ # pulled out of the buffer by Python's second call.) So the way to
+ # do what we want is to ignore SSLWantReadError on this call.
+ #
+ # Also, because the other side might have already sent
+ # close_notify and closed their connection then it's possible that
+ # our attempt to send close_notify will raise
+ # BrokenStreamError. This is totally legal, and in fact can happen
+ # with two well-behaved trio programs talking to each other, so we
+ # don't want to raise an error. So we suppress BrokenStreamError
+ # here. (This is safe, because literally the only thing this call
+ # to _retry will do is send the close_notify alert, so that's
+ # surely where the error comes from.)
+ #
+ # FYI in some cases this could also raise SSLSyscallError which I
+ # think is because SSL_shutdown is terrible. (Check out that note
+ # at the bottom of the man page saying that it sometimes gets
+ # raised spuriously.) I haven't seen this since we switched to
+ # immediately closing the socket, and I don't know exactly what
+ # conditions cause it and how to respond, so for now we're just
+ # letting that happen. But if you start seeing it, then hopefully
+ # this will give you a little head start on tracking it down,
+ # because whoa did this puzzle us at the 2017 PyCon sprints.
+ try:
+ await self._retry(
+ self._ssl_object.unwrap, ignore_want_read=True)
+ except _streams.BrokenStreamError:
+ pass
+ # Close the underlying stream
+ await self.transport_stream.graceful_close()
+ except:
+ self.transport_stream.forceful_close()
+ raise
+ finally:
+ self._state = _State.CLOSED
+
+ async def wait_send_all_might_not_block(self):
+ """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`.
+
+ """
+ # This method's implementation is deceptively simple.
+ #
+ # First, we take the outer send lock, because of trio's standard
+ # semantics that wait_send_all_might_not_block and send_all conflict.
+ async with self._outer_send_lock:
+ self._check_status()
+ # Then we take the inner send lock. We know that no other tasks
+ # are calling self.send_all or self.wait_send_all_might_not_block,
+ # because we have the outer_send_lock. But! There might be another
+ # task calling self.receive_some -> transport_stream.send_all, in
+ # which case if we were to call
+ # transport_stream.wait_send_all_might_not_block directly we'd
+ # have two tasks doing write-related operations on
+ # transport_stream simultaneously, which is not allowed. We
+ # *don't* want to raise this conflict to our caller, because it's
+ # purely an internal affair – all they did was call
+ # wait_send_all_might_not_block and receive_some at the same time,
+ # which is totally valid. And waiting for the lock is OK, because
+ # a call to send_all certainly wouldn't complete while the other
+ # task holds the lock.
+ async with self._inner_send_lock:
+ # Now we have the lock, which creates another potential
+ # problem: what if a call to self.receive_some attempts to do
+ # transport_stream.send_all now? It'll have to wait for us to
+ # finish! But that's OK, because we release the lock as soon
+ # as the underlying stream becomes writable, and the
+ # self.receive_some call wasn't going to make any progress
+ # until then anyway.
+ #
+ # Of course, this does mean we might return *before* the
+ # stream is logically writable, because immediately after we
+ # return self.receive_some might write some data and make it
+ # non-writable again. But that's OK too,
+ # wait_send_all_might_not_block only guarantees that it
+ # doesn't return late.
+ await self.transport_stream.wait_send_all_might_not_block()
diff --git a/trio/testing.py b/trio/testing.py
index a04e4869e0..93a27726ed 100644
--- a/trio/testing.py
+++ b/trio/testing.py
@@ -5,17 +5,27 @@
from collections import defaultdict
import time
from math import inf
+import random
+import operator
import attr
from async_generator import async_generator, yield_
-from ._util import acontextmanager
+from . import _util
from . import _core
-from . import Event, sleep
-from .abc import Clock
+from . import _streams
+from . import Event as _Event, sleep as _sleep
+from . import abc as _abc
-__all__ = ["wait_all_tasks_blocked", "trio_test", "MockClock",
- "assert_yields", "assert_no_yields", "Sequencer"]
+__all__ = [
+ "wait_all_tasks_blocked", "trio_test", "MockClock",
+ "assert_yields", "assert_no_yields", "Sequencer",
+ "MemorySendStream", "MemoryReceiveStream", "memory_stream_pump",
+ "memory_stream_one_way_pair", "memory_stream_pair",
+ "lockstep_stream_one_way_pair", "lockstep_stream_pair",
+ "check_one_way_stream", "check_two_way_stream",
+ "check_half_closeable_stream",
+]
# re-export
from ._core import wait_all_tasks_blocked
@@ -32,7 +42,7 @@ def trio_test(fn):
@wraps(fn)
def wrapper(**kwargs):
__tracebackhide__ = True
- clocks = [c for c in kwargs.values() if isinstance(c, Clock)]
+ clocks = [c for c in kwargs.values() if isinstance(c, _abc.Clock)]
if not clocks:
clock = None
elif len(clocks) == 1:
@@ -42,10 +52,15 @@ def wrapper(**kwargs):
return _core.run(partial(fn, **kwargs), clock=clock)
return wrapper
+
+################################################################
+# The glorious MockClock
+################################################################
+
# Prior art:
# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html
# https://github.com/ztellman/manifold/issues/57
-class MockClock(Clock):
+class MockClock(_abc.Clock):
"""A user-controllable clock suitable for writing tests.
Args:
@@ -61,41 +76,53 @@ class MockClock(Clock):
.. attribute:: autojump_threshold
- If all tasks are blocked for this many real seconds (i.e., according to
- the actual clock, not this clock), then this clock automatically jumps
- ahead to the run loop's next scheduled timeout. Default is
- :data:`math.inf`, i.e., to never autojump. You can assign to this
- attribute to change it.
+ The clock keeps an eye on the run loop, and if at any point it detects
+ that all tasks have been blocked for this many real seconds (i.e.,
+ according to the actual clock, not this clock), then the clock
+ automatically jumps ahead to the run loop's next scheduled
+ timeout. Default is :data:`math.inf`, i.e., to never autojump. You can
+ assign to this attribute to change it.
+
+ Basically the idea is that if you have code or tests that use sleeps
+ and timeouts, you can use this to make it run much faster, totally
+ automatically. (At least, as long as those sleeps/timeouts are
+ happening inside trio; if your test involves talking to an external
+ service and waiting for it to time out then obviously we can't help you
+ there.)
You should set this to the smallest value that lets you reliably avoid
"false alarms" where some I/O is in flight (e.g. between two halves of
a socketpair) but the threshold gets triggered and time gets advanced
anyway. This will depend on the details of your tests and test
environment. If you aren't doing any I/O (like in our sleeping example
- above) then setting it to zero is fine.
-
- Note that setting this attribute interacts with the run loop, so it can
- only be done from inside a run context or (as a special case) before
- calling :func:`trio.run`.
+ above) then just set it to zero, and the clock will jump whenever all
+ tasks are blocked.
.. warning::
If you're using :func:`wait_all_tasks_blocked` and
:attr:`autojump_threshold` together, then you have to be
- careful. Setting :attr:`autojump_threshold` acts like a task
- calling::
+ careful. Setting :attr:`autojump_threshold` acts like a background
+ task calling::
while True:
- await wait_all_tasks_blocked(cushion=clock.autojump_threshold)
+ await wait_all_tasks_blocked(
+ cushion=clock.autojump_threshold, tiebreaker=float("inf"))
This means that if you call :func:`wait_all_tasks_blocked` with a
cushion *larger* than your autojump threshold, then your call to
:func:`wait_all_tasks_blocked` will never return, because the
autojump task will keep waking up before your task does, and each
- time it does it'll reset your task's timer.
+ time it does it'll reset your task's timer. However, if your cushion
+ and the autojump threshold are the *same*, then the autojump's
+ tiebreaker will prevent them from interfering (unless you also set
+ your tiebreaker to infinity for some reason. Don't do that). As an
+ important special case: this means that if you set an autojump
+ threshold of zero and use :func:`wait_all_tasks_blocked` with the
+ default zero cushion, then everything will work fine.
- **Summary**: you should set :attr:`autojump_threshold` to be at *least*
- as large as the largest cushion you plan to pass to
+ **Summary**: you should set :attr:`autojump_threshold` to be at
+ least as large as the largest cushion you plan to pass to
:func:`wait_all_tasks_blocked`.
"""
@@ -146,11 +173,12 @@ async def _autojumper(self):
self._autojump_cancel_scope = cancel_scope
try:
# If the autojump_threshold changes, then the setter does
- # cancel_scope.cancel() ,which causes this line to raise
- # Cancelled, which is absorbed by the cancel scope above,
- # and effectively just causes us to skip start the loop
- # over, like a 'continue' here.
- await wait_all_tasks_blocked(self._autojump_threshold)
+ # cancel_scope.cancel(), which causes the next line here
+ # to raise Cancelled, which is absorbed by the cancel
+ # scope above, and effectively just causes us to skip back
+ # to the start the loop, like a 'continue'.
+ await wait_all_tasks_blocked(
+ self._autojump_threshold, float("inf"))
statistics = _core.current_statistics()
jump = statistics.seconds_to_next_deadline
if jump < inf:
@@ -160,7 +188,7 @@ async def _autojumper(self):
# until some actual I/O arrives (or maybe another
# wait_all_tasks_blocked task wakes up). That's fine,
# but if our threshold is zero then this will become a
- # busy-wait -- so insert a small-but-non-zero sleep to
+ # busy-wait -- so insert a small-but-non-zero _sleep to
# avoid that.
if self._autojump_threshold == 0:
await wait_all_tasks_blocked(0.01)
@@ -221,6 +249,10 @@ def jump(self, seconds):
self._virtual_base += seconds
+################################################################
+# Testing checkpoints
+################################################################
+
@contextmanager
def _assert_yields_or_not(expected):
__tracebackhide__ = True
@@ -276,6 +308,10 @@ def assert_no_yields():
return _assert_yields_or_not(False)
+################################################################
+# Sequencer
+################################################################
+
@attr.s(slots=True, cmp=False, hash=False)
class Sequencer:
"""A convenience class for forcing code in different tasks to run in an
@@ -318,11 +354,11 @@ async def main():
"""
_sequence_points = attr.ib(
- default=attr.Factory(lambda: defaultdict(Event)), init=False)
+ default=attr.Factory(lambda: defaultdict(_Event)), init=False)
_claimed = attr.ib(default=attr.Factory(set), init=False)
_broken = attr.ib(default=False, init=False)
- @acontextmanager
+ @_util.acontextmanager
@async_generator
async def __call__(self, position):
if position in self._claimed:
@@ -347,3 +383,993 @@ async def __call__(self, position):
await yield_()
finally:
self._sequence_points[position + 1].set()
+
+
+################################################################
+# Generic stream tests
+################################################################
+
+class _CloseBoth:
+ def __init__(self, both):
+ self._both = list(both)
+
+ async def __aenter__(self):
+ return self._both
+
+ async def __aexit__(self, *args):
+ try:
+ self._both[0].forceful_close()
+ finally:
+ self._both[1].forceful_close()
+
+
+@contextmanager
+def _assert_raises(exc):
+ __tracebackhide__ = True
+ try:
+ yield
+ except exc:
+ pass
+ else:
+ raise AssertionError("expected exception: {}".format(exc))
+
+
+async def check_one_way_stream(stream_maker, clogged_stream_maker):
+ """Perform a number of generic tests on a custom one-way stream
+ implementation.
+
+ Args:
+ stream_maker: An async (!) function which returns a connected
+ (:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`)
+ pair.
+ clogged_stream_maker: Either None, or an async function similar to
+ stream_maker, but with the extra property that the returned stream
+ is in a state where ``send_all`` and
+ ``wait_send_all_might_not_block`` will block until ``receive_some``
+ has been called. This allows for more thorough testing of some edge
+ cases, especially around ``wait_send_all_might_not_block``.
+
+ Raises:
+ AssertionError: if a test fails.
+
+ """
+ async with _CloseBoth(await stream_maker()) as (s, r):
+ assert isinstance(s, _abc.SendStream)
+ assert isinstance(r, _abc.ReceiveStream)
+
+ async def do_send_all(data):
+ with assert_yields():
+ assert await s.send_all(data) is None
+
+ async def do_receive_some(max_bytes):
+ with assert_yields():
+                return await r.receive_some(max_bytes)
+
+ async def checked_receive_1(expected):
+ assert await do_receive_some(1) == expected
+
+ async def do_graceful_close(resource):
+ with assert_yields():
+ await resource.graceful_close()
+
+ # Simple sending/receiving
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_send_all, b"x")
+ nursery.spawn(checked_receive_1, b"x")
+
+ async def send_empty_then_y():
+ # Streams should tolerate sending b"" without giving it any
+ # special meaning.
+ await do_send_all(b"")
+ await do_send_all(b"y")
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(send_empty_then_y)
+ nursery.spawn(checked_receive_1, b"y")
+
+ ### Checking various argument types
+
+ # send_all accepts bytearray and memoryview
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_send_all, bytearray(b"1"))
+ nursery.spawn(checked_receive_1, b"1")
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_send_all, memoryview(b"2"))
+ nursery.spawn(checked_receive_1, b"2")
+
+ # max_bytes must be a positive integer
+ with _assert_raises(ValueError):
+ await r.receive_some(-1)
+ with _assert_raises(ValueError):
+ await r.receive_some(0)
+ with _assert_raises(TypeError):
+ await r.receive_some(1.5)
+
+ with _assert_raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_receive_some, 1)
+ nursery.spawn(do_receive_some, 1)
+
+ # Method always has to exist, and an empty stream with a blocked
+ # receive_some should *always* allow send_all. (Technically it's legal
+ # for send_all to wait until receive_some is called to run, though; a
+ # stream doesn't *have* to have any internal buffering. That's why we
+ # spawn a concurrent receive_some call, then cancel it.)
+ async def simple_check_wait_send_all_might_not_block(scope):
+ with assert_yields():
+ await s.wait_send_all_might_not_block()
+ scope.cancel()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(simple_check_wait_send_all_might_not_block,
+ nursery.cancel_scope)
+ nursery.spawn(do_receive_some, 1)
+
+ # closing the r side leads to BrokenStreamError on the s side
+ # (eventually)
+ async def expect_broken_stream_on_send():
+ with _assert_raises(_streams.BrokenStreamError):
+ while True:
+ await do_send_all(b"x" * 100)
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(expect_broken_stream_on_send)
+ nursery.spawn(do_graceful_close, r)
+
+ # once detected, the stream stays broken
+ with _assert_raises(_streams.BrokenStreamError):
+ await do_send_all(b"x" * 100)
+
+ # r closed -> ClosedStreamError on the receive side
+ with _assert_raises(_streams.ClosedStreamError):
+ await do_receive_some(4096)
+
+ # we can close the same stream repeatedly, it's fine
+ r.forceful_close()
+ with assert_yields():
+ await r.graceful_close()
+ r.forceful_close()
+
+ # closing the sender side
+ with assert_yields():
+ await s.graceful_close()
+
+ # now trying to send raises ClosedStreamError
+ with _assert_raises(_streams.ClosedStreamError):
+ await do_send_all(b"x" * 100)
+
+ # ditto for wait_send_all_might_not_block
+ with _assert_raises(_streams.ClosedStreamError):
+ with assert_yields():
+ await s.wait_send_all_might_not_block()
+
+ # and again, repeated closing is fine
+ s.forceful_close()
+ await do_graceful_close(s)
+ s.forceful_close()
+
+ async with _CloseBoth(await stream_maker()) as (s, r):
+ # if send-then-graceful-close, receiver gets data then b""
+ async def send_then_close():
+ await do_send_all(b"y")
+ await do_graceful_close(s)
+
+ async def receive_send_then_close():
+ # We want to make sure that if the sender closes the stream before
+ # we read anything, then we still get all the data. But some
+ # streams might block on the do_send_all call. So we let the
+ # sender get as far as it can, then we receive.
+ await wait_all_tasks_blocked()
+ await checked_receive_1(b"y")
+ await checked_receive_1(b"")
+ await do_graceful_close(r)
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(send_then_close)
+ nursery.spawn(receive_send_then_close)
+
+ # using forceful_close also makes things closed
+ async with _CloseBoth(await stream_maker()) as (s, r):
+ r.forceful_close()
+
+ with _assert_raises(_streams.BrokenStreamError):
+ while True:
+ await do_send_all(b"x" * 100)
+
+ with _assert_raises(_streams.ClosedStreamError):
+ await do_receive_some(4096)
+
+ async with _CloseBoth(await stream_maker()) as (s, r):
+ s.forceful_close()
+
+ with _assert_raises(_streams.ClosedStreamError):
+ await do_send_all(b"123")
+
+ # after the sender does a forceful close, the receiver might either
+ # get BrokenStreamError or a clean b""; either is OK. Not OK would be
+ # if it freezes, or returns data.
+ try:
+ await checked_receive_1(b"")
+ except _streams.BrokenStreamError:
+ pass
+
+ # cancelled graceful_close still closes
+ async with _CloseBoth(await stream_maker()) as (s, r):
+ with _core.open_cancel_scope() as scope:
+ scope.cancel()
+ await r.graceful_close()
+
+ with _core.open_cancel_scope() as scope:
+ scope.cancel()
+ await s.graceful_close()
+
+ with _assert_raises(_streams.ClosedStreamError):
+ await do_send_all(b"123")
+
+ with _assert_raises(_streams.ClosedStreamError):
+ await do_receive_some(4096)
+
+ # Check that we can still gracefully close a stream after an operation has
+ # been cancelled. This can be challenging if cancellation can leave the
+ # stream internals in an inconsistent state, e.g. for
+ # SSLStream. Unfortunately this test isn't very thorough; the really
+ # challenging case for something like SSLStream is it gets cancelled
+ # *while* it's sending data on the underlying, not before. But testing
+ # that requires some special-case handling of the particular stream setup;
+ # we can't do it here. Maybe we could do a bit better with
+ # https://github.com/python-trio/trio/issues/77
+ async with _CloseBoth(await stream_maker()) as (s, r):
+ async def expect_cancelled(afn, *args):
+ with _assert_raises(_core.Cancelled):
+ await afn(*args)
+
+ with _core.open_cancel_scope() as scope:
+ scope.cancel()
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(expect_cancelled, do_send_all, b"x")
+ nursery.spawn(expect_cancelled, do_receive_some, 1)
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_graceful_close, s)
+ nursery.spawn(do_graceful_close, r)
+
+ # check wait_send_all_might_not_block, if we can
+ if clogged_stream_maker is not None:
+ async with _CloseBoth(await clogged_stream_maker()) as (s, r):
+ record = []
+
+ async def waiter(cancel_scope):
+ record.append("waiter sleeping")
+ with assert_yields():
+ await s.wait_send_all_might_not_block()
+ record.append("waiter wokeup")
+ cancel_scope.cancel()
+
+ async def receiver():
+ # give wait_send_all_might_not_block a chance to block
+ await wait_all_tasks_blocked()
+ record.append("receiver starting")
+ while True:
+                await r.receive_some(16384)
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(waiter, nursery.cancel_scope)
+ await wait_all_tasks_blocked()
+ nursery.spawn(receiver)
+
+ assert record == [
+ "waiter sleeping",
+ "receiver starting",
+ "waiter wokeup",
+ ]
+
+ async with _CloseBoth(await clogged_stream_maker()) as (s, r):
+ # simultaneous wait_send_all_might_not_block fails
+ with _assert_raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(s.wait_send_all_might_not_block)
+ nursery.spawn(s.wait_send_all_might_not_block)
+
+ # and simultaneous send_all and wait_send_all_might_not_block (NB
+ # this test might destroy the stream b/c we end up cancelling
+ # send_all and e.g. SSLStream can't handle that, so we have to
+ # recreate afterwards)
+ with _assert_raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(s.wait_send_all_might_not_block)
+ nursery.spawn(s.send_all, b"123")
+
+ async with _CloseBoth(await clogged_stream_maker()) as (s, r):
+ # send_all and send_all blocked simultaneously should also raise
+ # (but again this might destroy the stream)
+ with _assert_raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(s.send_all, b"123")
+ nursery.spawn(s.send_all, b"123")
+
+ # closing the receiver causes wait_send_all_might_not_block to return
+ async with _CloseBoth(await clogged_stream_maker()) as (s, r):
+ async def sender():
+ try:
+ with assert_yields():
+ await s.wait_send_all_might_not_block()
+ except _streams.BrokenStreamError:
+ pass
+
+ async def receiver():
+ await wait_all_tasks_blocked()
+ r.forceful_close()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(sender)
+ nursery.spawn(receiver)
+
+ # and again with the call starting after the close
+ async with _CloseBoth(await clogged_stream_maker()) as (s, r):
+ r.forceful_close()
+ try:
+ with assert_yields():
+ await s.wait_send_all_might_not_block()
+ except _streams.BrokenStreamError:
+ pass
+
+
+async def check_two_way_stream(stream_maker, clogged_stream_maker):
+ """Perform a number of generic tests on a custom two-way stream
+ implementation.
+
+ This is similar to :func:`check_one_way_stream`, except that the maker
+ functions are expected to return objects implementing the
+ :class:`~trio.abc.Stream` interface.
+
+ This function tests a *superset* of what :func:`check_one_way_stream`
+ checks – if you call this, then you don't need to also call
+ :func:`check_one_way_stream`.
+
+ """
+ await check_one_way_stream(stream_maker, clogged_stream_maker)
+
+ async def flipped_stream_maker():
+ return reversed(await stream_maker())
+ if clogged_stream_maker is not None:
+ async def flipped_clogged_stream_maker():
+ return reversed(await clogged_stream_maker())
+ else:
+ flipped_clogged_stream_maker = None
+ await check_one_way_stream(
+ flipped_stream_maker, flipped_clogged_stream_maker)
+
+ async with _CloseBoth(await stream_maker()) as (s1, s2):
+ assert isinstance(s1, _abc.Stream)
+ assert isinstance(s2, _abc.Stream)
+
+ # Duplex can be a bit tricky, might as well check it as well
+ DUPLEX_TEST_SIZE = 2 ** 20
+ CHUNK_SIZE_MAX = 2 ** 14
+
+ r = random.Random(0)
+ i = r.getrandbits(8 * DUPLEX_TEST_SIZE)
+ test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little")
+
+ async def sender(s, data, seed):
+ r = random.Random(seed)
+ m = memoryview(data)
+ while m:
+ chunk_size = r.randint(1, CHUNK_SIZE_MAX)
+ await s.send_all(m[:chunk_size])
+ m = m[chunk_size:]
+
+ async def receiver(s, data, seed):
+ r = random.Random(seed)
+ got = bytearray()
+ while len(got) < len(data):
+ chunk = await s.receive_some(r.randint(1, CHUNK_SIZE_MAX))
+ assert chunk
+ got += chunk
+ assert got == data
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(sender, s1, test_data, 0)
+ nursery.spawn(sender, s2, test_data[::-1], 1)
+ nursery.spawn(receiver, s1, test_data[::-1], 2)
+ nursery.spawn(receiver, s2, test_data, 3)
+
+ async def expect_receive_some_empty():
+ assert await s2.receive_some(10) == b""
+ await s2.graceful_close()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(expect_receive_some_empty)
+ nursery.spawn(s1.graceful_close)
+
+
+async def check_half_closeable_stream(stream_maker, clogged_stream_maker):
+ """Perform a number of generic tests on a custom half-closeable stream
+ implementation.
+
+ This is similar to :func:`check_two_way_stream`, except that the maker
+ functions are expected to return objects that implement the
+ :class:`~trio.abc.HalfCloseableStream` interface.
+
+ This function tests a *superset* of what :func:`check_two_way_stream`
+ checks – if you call this, then you don't need to also call
+ :func:`check_two_way_stream`.
+
+ """
+ await check_two_way_stream(stream_maker, clogged_stream_maker)
+
+ async with _CloseBoth(await stream_maker()) as (s1, s2):
+ assert isinstance(s1, _abc.HalfCloseableStream)
+ assert isinstance(s2, _abc.HalfCloseableStream)
+
+ async def send_x_then_eof(s):
+ await s.send_all(b"x")
+ with assert_yields():
+ await s.send_eof()
+
+ async def expect_x_then_eof(r):
+ await wait_all_tasks_blocked()
+ assert await r.receive_some(10) == b"x"
+ assert await r.receive_some(10) == b""
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(send_x_then_eof, s1)
+ nursery.spawn(expect_x_then_eof, s2)
+
+ # now sending is disallowed
+ with _assert_raises(_streams.ClosedStreamError):
+ await s1.send_all(b"y")
+
+ # but we can do send_eof again
+ with assert_yields():
+ await s1.send_eof()
+
+ # and we can still send stuff back the other way
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(send_x_then_eof, s2)
+ nursery.spawn(expect_x_then_eof, s1)
+
+ if clogged_stream_maker is not None:
+ async with _CloseBoth(await clogged_stream_maker()) as (s1, s2):
+ # send_all and send_eof simultaneously is not ok
+ with _assert_raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ t = nursery.spawn(s1.send_all, b"x")
+ await wait_all_tasks_blocked()
+ assert t.result is None
+ nursery.spawn(s1.send_eof)
+
+ async with _CloseBoth(await clogged_stream_maker()) as (s1, s2):
+ # wait_send_all_might_not_block and send_eof simultaneously is not
+ # ok either
+ with _assert_raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ t = nursery.spawn(s1.wait_send_all_might_not_block)
+ await wait_all_tasks_blocked()
+ assert t.result is None
+ nursery.spawn(s1.send_eof)
+
+
+################################################################
+# In-memory streams
+################################################################
+
+class _UnboundedByteQueue:
+ def __init__(self):
+ self._data = bytearray()
+ self._closed = False
+ self._lot = _core.ParkingLot()
+ self._fetch_lock = _util.UnLock(
+ _core.ResourceBusyError, "another task is already fetching data")
+
+ def close(self):
+ self._closed = True
+ self._lot.unpark_all()
+
+ def put(self, data):
+ if self._closed:
+ raise _streams.ClosedStreamError("virtual connection closed")
+ self._data += data
+ self._lot.unpark_all()
+
+ def _check_max_bytes(self, max_bytes):
+ if max_bytes is None:
+ return
+ max_bytes = operator.index(max_bytes)
+ if max_bytes < 1:
+ raise ValueError("max_bytes must be >= 1")
+
+ def _get_impl(self, max_bytes):
+ assert self._closed or self._data
+ if max_bytes is None:
+ max_bytes = len(self._data)
+ if self._data:
+ chunk = self._data[:max_bytes]
+ del self._data[:max_bytes]
+ assert chunk
+ return chunk
+ else:
+ return bytearray()
+
+ def get_nowait(self, max_bytes=None):
+ with self._fetch_lock.sync:
+ self._check_max_bytes(max_bytes)
+ if not self._closed and not self._data:
+ raise _core.WouldBlock
+ return self._get_impl(max_bytes)
+
+ async def get(self, max_bytes=None):
+ async with self._fetch_lock:
+ self._check_max_bytes(max_bytes)
+ if not self._closed and not self._data:
+ await self._lot.park()
+ return self._get_impl(max_bytes)
+
+
+class MemorySendStream(_abc.SendStream):
+ """An in-memory :class:`~trio.abc.SendStream`.
+
+ Args:
+ send_all_hook: An async function, or None. Called from
+ :meth:`send_all`. Can do whatever you like.
+ wait_send_all_might_not_block_hook: An async function, or None. Called
+ from :meth:`wait_send_all_might_not_block`. Can do whatever you
+ like.
+ close_hook: A synchronous function, or None. Called from
+ :meth:`forceful_close`. Can do whatever you like.
+
+ .. attribute:: send_all_hook
+ wait_send_all_might_not_block_hook
+ close_hook
+
+ All of these hooks are also exposed as attributes on the object, and
+ you can change them at any time.
+
+ """
+ def __init__(self,
+ send_all_hook=None,
+ wait_send_all_might_not_block_hook=None,
+ close_hook=None):
+ self._lock = _util.UnLock(
+ _core.ResourceBusyError, "another task is using this stream")
+ self._outgoing = _UnboundedByteQueue()
+ self.send_all_hook = send_all_hook
+ self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook
+ self.close_hook = close_hook
+
+ async def send_all(self, data):
+ """Places the given data into the object's internal buffer, and then
+ calls the :attr:`send_all_hook` (if any).
+
+ """
+ # The lock itself is a checkpoint, but then we also yield inside the
+ # lock to give ourselves a chance to detect buggy user code that calls
+ # this twice at the same time.
+ async with self._lock:
+ await _core.yield_briefly()
+ self._outgoing.put(data)
+ if self.send_all_hook is not None:
+ await self.send_all_hook()
+
+ async def wait_send_all_might_not_block(self):
+ """Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and
+ then returns immediately.
+
+ """
+ # The lock itself is a checkpoint, but then we also yield inside the
+ # lock to give ourselves a chance to detect buggy user code that calls
+ # this twice at the same time.
+ async with self._lock:
+ await _core.yield_briefly()
+ # check for being closed:
+ self._outgoing.put(b"")
+ if self.wait_send_all_might_not_block_hook is not None:
+ await self.wait_send_all_might_not_block_hook()
+
+ def forceful_close(self):
+ """Marks this stream as closed, and then calls the :attr:`close_hook`
+ (if any).
+
+ """
+ self._outgoing.close()
+ if self.close_hook is not None:
+ self.close_hook()
+
+ async def get_data(self, max_bytes=None):
+ """Retrieves data from the internal buffer, blocking if necessary.
+
+ Args:
+ max_bytes (int or None): The maximum amount of data to
+ retrieve. None (the default) means to retrieve all the data
+ that's present (but still blocks until at least one byte is
+ available).
+
+ Returns:
+ If this stream has been closed, an empty bytearray. Otherwise, the
+ requested data.
+
+ """
+ return await self._outgoing.get(max_bytes)
+
+ def get_data_nowait(self, max_bytes=None):
+ """Retrieves data from the internal buffer, but doesn't block.
+
+ See :meth:`get_data` for details.
+
+ Raises:
+ trio.WouldBlock: if no data is available to retrieve.
+
+ """
+ return self._outgoing.get_nowait(max_bytes)
+
+
+class MemoryReceiveStream(_abc.ReceiveStream):
+ """An in-memory :class:`~trio.abc.ReceiveStream`.
+
+ Args:
+ receive_some_hook: An async function, or None. Called from
+ :meth:`receive_some`. Can do whatever you like.
+ close_hook: A synchronous function, or None. Called from
+ :meth:`forceful_close`. Can do whatever you like.
+
+ .. attribute:: receive_some_hook
+ close_hook
+
+ Both hooks are also exposed as attributes on the object, and you can
+ change them at any time.
+
+ """
+ def __init__(self, receive_some_hook=None, close_hook=None):
+ self._lock = _util.UnLock(
+ _core.ResourceBusyError, "another task is using this stream")
+ self._incoming = _UnboundedByteQueue()
+ self._closed = False
+ self.receive_some_hook = receive_some_hook
+ self.close_hook = close_hook
+
+ async def receive_some(self, max_bytes):
+ """Calls the :attr:`receive_some_hook` (if any), and then retrieves
+ data from the internal buffer, blocking if necessary.
+
+ """
+ # The lock itself is a checkpoint, but then we also yield inside the
+ # lock to give ourselves a chance to detect buggy user code that calls
+ # this twice at the same time.
+ async with self._lock:
+ await _core.yield_briefly()
+ if max_bytes is None:
+ raise TypeError("max_bytes must not be None")
+ if self._closed:
+ raise _streams.ClosedStreamError
+ if self.receive_some_hook is not None:
+ await self.receive_some_hook()
+ return await self._incoming.get(max_bytes)
+
+ def forceful_close(self):
+ """Discards any pending data from the internal buffer, and marks this
+ stream as closed.
+
+ """
+ # discard any pending data
+ self._closed = True
+ try:
+ self._incoming.get_nowait()
+ except _core.WouldBlock:
+ pass
+ self._incoming.close()
+ if self.close_hook is not None:
+ self.close_hook()
+
+ def put_data(self, data):
+ """Appends the given data to the internal buffer.
+
+ """
+ self._incoming.put(data)
+
+ def put_eof(self):
+ """Adds an end-of-file marker to the internal buffer.
+
+ """
+ self._incoming.close()
+
+
+def memory_stream_pump(
+        memory_send_stream, memory_receive_stream, *, max_bytes=None):
+    """Take data out of the given :class:`MemorySendStream`'s internal buffer,
+    and put it into the given :class:`MemoryReceiveStream`'s internal buffer.
+
+    Args:
+      memory_send_stream (MemorySendStream): The stream to get data from.
+      memory_receive_stream (MemoryReceiveStream): The stream to put data into.
+      max_bytes (int or None): The maximum amount of data to transfer in this
+          call, or None to transfer all available data.
+
+    Returns:
+      True if it successfully transferred some data, or False if there was no
+      data to transfer.
+
+    This is used to implement :func:`memory_stream_one_way_pair` and
+    :func:`memory_stream_pair`; see the latter's docstring for an example
+    of how you might use it yourself.
+
+    """
+    try:
+        data = memory_send_stream.get_data_nowait(max_bytes)
+    except _core.WouldBlock:
+        return False
+    try:
+        if not data:
+            memory_receive_stream.put_eof()
+        else:
+            memory_receive_stream.put_data(data)
+    except _streams.ClosedStreamError:
+        raise _streams.BrokenStreamError("MemoryReceiveStream was closed")
+    return True
+
+
+def memory_stream_one_way_pair():
+ """Create a connected, pure-Python, unidirectional stream with infinite
+ buffering and flexible configuration options.
+
+ You can think of this as being a no-operating-system-involved
+ trio-streamsified version of :func:`os.pipe` (except that :func:`os.pipe`
+ returns the streams in the wrong order – we follow the superior convention
+ that data flows from left to right).
+
+ Returns:
+ A tuple (:class:`MemorySendStream`, :class:`MemoryReceiveStream`), where
+ the :class:`MemorySendStream` has its hooks set up so that it calls
+ :func:`memory_stream_pump` from its
+ :attr:`~MemorySendStream.send_all_hook` and
+ :attr:`~MemorySendStream.close_hook`.
+
+ The end result is that data automatically flows from the
+ :class:`MemorySendStream` to the :class:`MemoryReceiveStream`. But you're
+ also free to rearrange things however you like. For example, you can
+ temporarily set the :attr:`~MemorySendStream.send_all_hook` to None if you
+ want to simulate a stall in data transmission. Or see
+ :func:`memory_stream_pair` for a more elaborate example.
+
+ """
+ send_stream = MemorySendStream()
+ recv_stream = MemoryReceiveStream()
+ def pump_from_send_stream_to_recv_stream():
+ memory_stream_pump(send_stream, recv_stream)
+ async def async_pump_from_send_stream_to_recv_stream():
+ pump_from_send_stream_to_recv_stream()
+ send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream
+ send_stream.close_hook = pump_from_send_stream_to_recv_stream
+ return send_stream, recv_stream
+
+
+def _make_stapled_pair(one_way_pair):
+ pipe1_send, pipe1_recv = one_way_pair()
+ pipe2_send, pipe2_recv = one_way_pair()
+ stream1 = _streams.StapledStream(pipe1_send, pipe2_recv)
+ stream2 = _streams.StapledStream(pipe2_send, pipe1_recv)
+ return stream1, stream2
+
+
+def memory_stream_pair():
+ """Create a connected, pure-Python, bidirectional stream with infinite
+ buffering and flexible configuration options.
+
+ This is a convenience function that creates two one-way streams using
+ :func:`memory_stream_one_way_pair`, and then uses
+ :class:`~trio.StapledStream` to combine them into a single bidirectional
+ stream.
+
+ This is like a no-operating-system-involved, trio-streamsified version of
+ :func:`socket.socketpair`.
+
+ Returns:
+ A pair of :class:`~trio.StapledStream` objects that are connected so
+ that data automatically flows from one to the other in both directions.
+
+ After creating a stream pair, you can send data back and forth, which is
+ enough for simple tests::
+
+ left, right = memory_stream_pair()
+ await left.send_all(b"123")
+ assert await right.receive_some(10) == b"123"
+ await right.send_all(b"456")
+ assert await left.receive_some(10) == b"456"
+
+ But if you read the docs for :class:`~trio.StapledStream` and
+ :func:`memory_stream_one_way_pair`, you'll see that all the pieces
+ involved in wiring this up are public APIs, so you can adjust to suit the
+ requirements of your tests. For example, here's how to tweak a stream so
+ that data flowing from left to right trickles in one byte at a time (but
+ data flowing from right to left proceeds at full speed)::
+
+ left, right = memory_stream_pair()
+ async def trickle():
+ # left is a StapledStream, and left.send_stream is a MemorySendStream
+        # right is a StapledStream, and right.receive_stream is a MemoryReceiveStream
+        while memory_stream_pump(left.send_stream, right.receive_stream, max_bytes=1):
+ # Pause between each byte
+ await trio.sleep(1)
+ # Normally this send_all_hook calls memory_stream_pump directly without
+ # passing in a max_bytes. We replace it with our custom version:
+ left.send_stream.send_all_hook = trickle
+
+ And here's a simple test using our modified stream objects::
+
+ async def sender():
+ await left.send_all(b"12345")
+ await left.send_eof()
+
+ async def receiver():
+ while True:
+ data = await right.receive_some(10)
+ if data == b"":
+ return
+ print(data)
+
+ async with trio.open_nursery() as nursery:
+ nursery.spawn(sender)
+ nursery.spawn(receiver)
+
+ By default, this will print ``b"12345"`` and then immediately exit; with
+ our trickle stream it instead sleeps 1 second, then prints ``b"1"``, then
+ sleeps 1 second, then prints ``b"2"``, etc.
+
+ Pro-tip: you can insert sleep calls (like in our example above) to
+ manipulate the flow of data across tasks... and then use
+ :class:`MockClock` and its :attr:`~MockClock.autojump_threshold`
+ functionality to keep your test suite running quickly.
+
+ If you want to stress test a protocol implementation, one nice trick is to
+ use the :mod:`random` module (preferably with a fixed seed) to move random
+ numbers of bytes at a time, and insert random sleeps in between them. You
+ can also set up a custom :attr:`~MemoryReceiveStream.receive_some_hook` if
+ you want to manipulate things on the receiving side, and not just the
+ sending side.
+
+ """
+ return _make_stapled_pair(memory_stream_one_way_pair)
+
+
+class _LockstepByteQueue:
+ def __init__(self):
+ self._data = bytearray()
+ self._sender_closed = False
+ self._receiver_closed = False
+ self._receiver_waiting = False
+ self._waiters = _core.ParkingLot()
+ self._send_lock = _util.UnLock(
+ _core.ResourceBusyError, "another task is already sending")
+ self._receive_lock = _util.UnLock(
+ _core.ResourceBusyError, "another task is already receiving")
+
+ def _something_happened(self):
+ self._waiters.unpark_all()
+
+ async def _wait_for(self, fn):
+ while not fn():
+ await self._waiters.park()
+
+ def close_sender(self):
+ # close while send_all is in progress is undefined
+ self._sender_closed = True
+ self._something_happened()
+
+ def close_receiver(self):
+ self._receiver_closed = True
+ self._something_happened()
+
+ async def send_all(self, data):
+ async with self._send_lock:
+ if self._sender_closed:
+ raise _streams.ClosedStreamError
+ if self._receiver_closed:
+ raise _streams.BrokenStreamError
+ assert not self._data
+ self._data += data
+ self._something_happened()
+ await self._wait_for(
+ lambda: not self._data or self._receiver_closed)
+ if self._data and self._receiver_closed:
+ raise _streams.BrokenStreamError
+ if not self._data:
+ return
+
+ async def wait_send_all_might_not_block(self):
+ async with self._send_lock:
+ if self._sender_closed:
+ raise _streams.ClosedStreamError
+ if self._receiver_closed:
+ return
+ await self._wait_for(
+ lambda: self._receiver_waiting or self._receiver_closed)
+
+ async def receive_some(self, max_bytes):
+ async with self._receive_lock:
+ # Argument validation
+ max_bytes = operator.index(max_bytes)
+ if max_bytes < 1:
+ raise ValueError("max_bytes must be >= 1")
+ # State validation
+ if self._receiver_closed:
+ raise _streams.ClosedStreamError
+ # Wake wait_send_all_might_not_block and wait for data
+ self._receiver_waiting = True
+ self._something_happened()
+ try:
+ await self._wait_for(lambda: self._data or self._sender_closed)
+ finally:
+ self._receiver_waiting = False
+ # Get data, possibly waking send_all
+ if self._data:
+ got = self._data[:max_bytes]
+ del self._data[:max_bytes]
+ self._something_happened()
+ return got
+ else:
+ assert self._sender_closed
+ return b""
+
+
+class _LockstepSendStream(_abc.SendStream):
+ def __init__(self, lbq):
+ self._lbq = lbq
+
+ def forceful_close(self):
+ self._lbq.close_sender()
+
+ async def send_all(self, data):
+ await self._lbq.send_all(data)
+
+ async def wait_send_all_might_not_block(self):
+ await self._lbq.wait_send_all_might_not_block()
+
+
+class _LockstepReceiveStream(_abc.ReceiveStream):
+ def __init__(self, lbq):
+ self._lbq = lbq
+
+ def forceful_close(self):
+ self._lbq.close_receiver()
+
+ async def receive_some(self, max_bytes):
+ return await self._lbq.receive_some(max_bytes)
+
+
+def lockstep_stream_one_way_pair():
+ """Create a connected, pure Python, unidirectional stream where data flows
+ in lockstep.
+
+ Returns:
+ A tuple
+ (:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`).
+
+ This stream has *absolutely no* buffering. Each call to
+ :meth:`~trio.abc.SendStream.send_all` will block until all the given data
+ has been returned by a call to
+ :meth:`~trio.abc.ReceiveStream.receive_some`.
+
+ This can be useful for testing flow control mechanisms in an extreme case,
+ or for setting up "clogged" streams to use with
+ :func:`check_one_way_stream` and friends.
+
+ """
+
+ lbq = _LockstepByteQueue()
+ return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq)
+
+
+def lockstep_stream_pair():
+ """Create a connected, pure-Python, bidirectional stream where data flows
+ in lockstep.
+
+ Returns:
+ A tuple (:class:`~trio.StapledStream`, :class:`~trio.StapledStream`).
+
+ This is a convenience function that creates two one-way streams using
+ :func:`lockstep_stream_one_way_pair`, and then uses
+ :class:`~trio.StapledStream` to combine them into a single bidirectional
+ stream.
+
+ """
+ return _make_stapled_pair(lockstep_stream_one_way_pair)
diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py
new file mode 100644
index 0000000000..4c0392ca42
--- /dev/null
+++ b/trio/tests/test_abc.py
@@ -0,0 +1,38 @@
+import pytest
+
+import attr
+
+from ..testing import assert_yields
+from .. import abc as tabc
+
+async def test_AsyncResource_defaults():
+ @attr.s
+ class MyAR(tabc.AsyncResource):
+ record = attr.ib(default=attr.Factory(list))
+
+ def forceful_close(self):
+ self.record.append("fc")
+
+ async with MyAR() as myar:
+ assert isinstance(myar, MyAR)
+ assert myar.record == []
+
+ assert myar.record == ["fc"]
+
+ with assert_yields():
+ await myar.graceful_close()
+ assert myar.record == ["fc", "fc"]
+
+ @attr.s
+ class BadAR(tabc.AsyncResource):
+ record = attr.ib(default=attr.Factory(list))
+
+ def forceful_close(self):
+ self.record.append("boom")
+ raise KeyError
+
+ badar = BadAR()
+ with pytest.raises(KeyError):
+ with assert_yields():
+ await badar.graceful_close()
+ assert badar.record == ["boom"]
diff --git a/trio/tests/test_network.py b/trio/tests/test_network.py
new file mode 100644
index 0000000000..89c51ce6b5
--- /dev/null
+++ b/trio/tests/test_network.py
@@ -0,0 +1,77 @@
+import pytest
+
+import socket as stdlib_socket
+
+from .. import _core
+from ..testing import check_half_closeable_stream, wait_all_tasks_blocked
+from .._network import *
+from .. import socket as tsocket
+
+async def test_SocketStream_basics():
+ # stdlib socket bad (even if connected)
+ a, b = stdlib_socket.socketpair()
+ with a, b:
+ with pytest.raises(TypeError):
+ SocketStream(a)
+
+ # DGRAM socket bad
+ with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
+ with pytest.raises(ValueError):
+ SocketStream(sock)
+
+ # disconnected socket bad
+ with tsocket.socket() as sock:
+ with pytest.raises(ValueError):
+ SocketStream(sock)
+
+ a, b = tsocket.socketpair()
+ with a, b:
+ s = SocketStream(a)
+ assert s.socket is a
+
+ # Use a real, connected socket to test socket options, because
+ # socketpair() might give us a unix socket that doesn't support any of
+ # these options
+ with tsocket.socket() as listen_sock:
+ listen_sock.bind(("127.0.0.1", 0))
+ listen_sock.listen(1)
+ with tsocket.socket() as client_sock:
+ await client_sock.connect(listen_sock.getsockname())
+
+ s = SocketStream(client_sock)
+
+ # TCP_NODELAY enabled by default
+ assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
+ # We can disable it though
+ s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
+ assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
+
+ b = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
+ assert isinstance(b, bytes)
+
+
+async def fill_stream(s):
+ async def sender():
+ while True:
+ await s.send_all(b"x" * 10000)
+
+ async def waiter(nursery):
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(sender)
+ nursery.spawn(waiter, nursery)
+
+
+async def test_SocketStream_and_socket_stream_pair_generic():
+ async def stream_maker():
+ return socket_stream_pair()
+
+ async def clogged_stream_maker():
+ left, right = socket_stream_pair()
+ await fill_stream(left)
+ await fill_stream(right)
+ return left, right
+
+ await check_half_closeable_stream(stream_maker, clogged_stream_maker)
diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py
index c405f7eae7..e043ee6ab1 100644
--- a/trio/tests/test_socket.py
+++ b/trio/tests/test_socket.py
@@ -505,26 +505,6 @@ async def test_SocketType_connect_paths():
with pytest.raises(_core.Cancelled):
await sock.connect(("127.0.0.1", 80))
- # Handling InterruptedError
- class InterruptySocket(stdlib_socket.socket):
- def connect(self, *args, **kwargs):
- if not hasattr(self, "_connect_count"):
- self._connect_count = 0
- self._connect_count += 1
- if self._connect_count < 3:
- raise InterruptedError
- else:
- return super().connect(*args, **kwargs)
- with tsocket.socket() as sock, tsocket.socket() as listener:
- listener.bind(("127.0.0.1", 0))
- listener.listen()
- # Swap in our weird subclass under the trio.socket.SocketType's nose
- sock._sock.close()
- sock._sock = InterruptySocket()
- with assert_yields():
- await sock.connect(listener.getsockname())
- assert sock.getpeername() == listener.getsockname()
-
# Cancelled in between the connect() call and the connect completing
with _core.open_cancel_scope() as cancel_scope:
with tsocket.socket() as sock, tsocket.socket() as listener:
@@ -550,14 +530,17 @@ def connect(self, *args, **kwargs):
assert sock.fileno() == -1
# Failed connect (hopefully after raising BlockingIOError)
- with tsocket.socket() as sock, tsocket.socket() as non_listener:
- # Claim an unused port
- non_listener.bind(("127.0.0.1", 0))
- # ...but don't call listen, so we're guaranteed that connect attempts
- # to it will fail.
+ with tsocket.socket() as sock:
with assert_yields():
with pytest.raises(OSError):
- await sock.connect(non_listener.getsockname())
+ # TCP port 2 is not assigned. Pretty sure nothing will be
+ # listening there. (We used to bind a port and then *not* call
+ # listen() to ensure nothing was listening there, but it turns
+ # out on MacOS if you do this it takes 30 seconds for the
+ # connect to fail. Really. Also if you use a non-routable
+ # address. This way fails instantly though. As long as nothing
+ # is listening on port 2.)
+ await sock.connect(("127.0.0.1", 2))
async def test_send_recv_variants():
@@ -659,48 +642,45 @@ async def test_SocketType_sendall():
# Check a sendall that has to be split into multiple parts (on most
# platforms... on Windows every send() either succeeds or fails as a
# whole)
- async with _core.open_nursery() as nursery:
- send_task = nursery.spawn(a.sendall, b"x" * BIG)
+ async def sender():
+ data = bytearray(BIG)
+ await a.sendall(data)
+ # sendall uses memoryviews internally, which temporarily "lock"
+ # the object they view. If it doesn't clean them up properly, then
+ # some bytearray operations might raise an error afterwards, which
+ # would be a pretty weird and annoying side-effect to spring on
+ # users. So test that this doesn't happen, by forcing the
+ # bytearray's underlying buffer to be realloc'ed:
+ data += bytes(BIG)
+ # (Note: the above line of code doesn't do a very good job at
+ # testing anything, because:
+ # - on CPython, the refcount GC generally cleans up memoryviews
+ # for us even if we're sloppy.
+ # - on PyPy3, at least as of 5.7.0, the memoryview code and the
+ # bytearray code conspire so that resizing never fails – if
+ # resizing forces the bytearray's internal buffer to move, then
+ # all memoryview references are automagically updated (!!).
+ # See:
+ # https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
+ # But I'm leaving the test here in hopes that if this ever changes
+ # and we break our implementation of sendall, then we'll get some
+ # early warning...)
+
+ async def receiver():
+ # Make sure the sender fills up the kernel buffers and blocks
await wait_all_tasks_blocked()
nbytes = 0
while nbytes < BIG:
nbytes += len(await b.recv(BIG))
- assert send_task.result is not None
assert nbytes == BIG
- with pytest.raises(BlockingIOError):
- b._sock.recv(1)
- a, b = tsocket.socketpair()
- with a, b:
- # Cancel half-way through
async with _core.open_nursery() as nursery:
- sent_complete = 0
- async def sendall_until_cancelled():
- nonlocal sent_complete
- # Need to loop to make sure that we actually do block on
- # Windows
- while True:
- await a.sendall(b"x" * BIG)
- sent_complete += BIG
- send_task = nursery.spawn(sendall_until_cancelled)
- await wait_all_tasks_blocked()
- nursery.cancel_scope.cancel()
- assert type(send_task.result) is _core.Error
- assert isinstance(send_task.result.error, _core.Cancelled)
- sent_partial = send_task.result.error.partial_result.bytes_sent
- a.close()
- sent_total = 0
- while True:
- got = len(await b.recv(BIG))
- if not got:
- break
- sent_total += got
- assert sent_complete + sent_partial == sent_total
+ nursery.spawn(sender)
+ nursery.spawn(receiver)
- a, b = tsocket.socketpair()
- with a, b:
- # A different error
- a.close()
- with pytest.raises(OSError) as excinfo:
- await a.sendall(b"x")
- assert excinfo.value.partial_result.bytes_sent == 0
+ # We know that we received BIG bytes of NULs so far. Make sure that
+ # was all the data in there.
+ await a.sendall(b"e")
+ assert await b.recv(10) == b"e"
+ a.shutdown(tsocket.SHUT_WR)
+ assert await b.recv(10) == b""
diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py
new file mode 100644
index 0000000000..0193dd4b93
--- /dev/null
+++ b/trio/tests/test_ssl.py
@@ -0,0 +1,1036 @@
+import pytest
+
+from pathlib import Path
+import threading
+import socket as stdlib_socket
+import ssl as stdlib_ssl
+from contextlib import contextmanager
+
+from OpenSSL import SSL
+
+import trio
+from .. import _core
+from .. import _network
+from .._streams import BrokenStreamError, ClosedStreamError
+from .. import ssl as tssl
+from .. import socket as tsocket
+from .._util import UnLock
+
+from .._core.tests.tutil import slow
+
+from ..testing import (
+ assert_yields, Sequencer, memory_stream_pair, lockstep_stream_pair,
+ check_two_way_stream,
+)
+
+ASSETS_DIR = Path(__file__).parent / "test_ssl_certs"
+CA = str(ASSETS_DIR / "trio-test-CA.pem")
+CERT1 = str(ASSETS_DIR / "trio-test-1.pem")
+
+
+# We have two different kinds of echo server fixtures we use for testing. The
+# first is a real server written using the stdlib ssl module and blocking
+# sockets. It runs in a thread and we talk to it over a real socketpair(), to
+# validate interoperability in a semi-realistic setting.
+#
+# The second is a very weird virtual echo server that lives inside a custom
+# Stream class. It lives entirely inside the Python object space; there are no
+# operating system calls in it at all. No threads, no I/O, nothing. Its
+# 'send_all' call takes encrypted data from a client and feeds it directly into
+# the server-side TLS state engine to decrypt, then takes that data, feeds it
+# back through to get the encrypted response, and returns it from 'receive_some'. This
+# gives us full control and reproducibility. This server is written using
+# PyOpenSSL, so that we can trigger renegotiations on demand. It also allows
+# us to insert random (virtual) delays, to really exercise all the weird paths
+# in SSLStream's state engine.
+#
+# Both present a certificate for "trio-test-1.example.org", that's signed by
+# trio-test-CA.pem, an extremely trustworthy CA.
+
+
+SERVER_CTX = stdlib_ssl.create_default_context(
+ stdlib_ssl.Purpose.CLIENT_AUTH,
+)
+SERVER_CTX.load_cert_chain(CERT1)
+
+CLIENT_CTX = stdlib_ssl.create_default_context(cafile=CA)
+
+# The blocking socket server.
+def ssl_echo_serve_sync(sock, *, expect_fail=False):
+ try:
+ wrapped = SERVER_CTX.wrap_socket(sock, server_side=True)
+ wrapped.do_handshake()
+ while True:
+ data = wrapped.recv(4096)
+ if not data:
+ # graceful shutdown
+ wrapped.unwrap()
+ return
+ wrapped.sendall(data)
+ except Exception as exc:
+ if expect_fail:
+ print("ssl_echo_serve_sync got error as expected:", exc)
+ else: # pragma: no cover
+ raise
+ else:
+ if expect_fail: # pragma: no cover
+ print("failed to fail?!")
+
+
+# Fixture that gives a raw socket connected to a trio-test-1 echo server
+# (running in a thread). Useful for testing making connections with different
+# SSLContexts.
+@contextmanager
+def ssl_echo_server_raw(**kwargs):
+ a, b = stdlib_socket.socketpair()
+ with a, b:
+ t = threading.Thread(
+ target=ssl_echo_serve_sync,
+ args=(b,),
+ kwargs=kwargs,
+ )
+ t.start()
+
+ yield _network.SocketStream(tsocket.from_stdlib_socket(a))
+
+ # exiting the context manager closes the sockets, which should force the
+ # thread to shut down (possibly with an error)
+ t.join()
+
+
+# Fixture that gives a properly set up SSLStream connected to a trio-test-1
+# echo server (running in a thread)
+@contextmanager
+def ssl_echo_server(**kwargs):
+ with ssl_echo_server_raw(**kwargs) as sock:
+ yield tssl.SSLStream(
+ sock, CLIENT_CTX, server_hostname="trio-test-1.example.org")
+
+
+# The weird in-memory server ... thing.
+# Doesn't inherit from Stream because I left out the methods that we don't
+# actually need.
+class PyOpenSSLEchoStream:
+ def __init__(self, sleeper=None):
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
+ # TLS 1.3 removes renegotiation support. Which is great for them, but
+ # we still have to support versions before that, and that means we
+ # need to test renegotiation support, which means we need to force this
+ # to use a lower version where this test server can trigger
+ # renegotiations. Of course TLS 1.3 support isn't released yet, but
+ # I'm told that this will work once it is. (And once it is we can
+ # remove the pragma: no cover too.) Alternatively, once we drop
+ # support for CPython 3.5 on MacOS, then we could switch to using
+ # TLSv1_2_METHOD.
+ #
+ # Discussion: https://github.com/pyca/pyopenssl/issues/624
+ if hasattr(SSL, "OP_NO_TLSv1_3"): # pragma: no cover
+ ctx.set_options(SSL.OP_NO_TLSv1_3)
+ # Unfortunately there's currently no way to say "use 1.3 or worse", we
+ # can only disable specific versions. And if the two sides start
+ # negotiating 1.4 at some point in the future, it *might* mean that
+ # our tests silently stop working properly. So the next line is a
+ # tripwire to remind us we need to revisit this stuff in 5 years or
+ # whatever when the next TLS version is released:
+ assert not hasattr(SSL, "OP_NO_TLSv1_4")
+ ctx.use_certificate_file(CERT1)
+ ctx.use_privatekey_file(CERT1)
+ self._conn = SSL.Connection(ctx, None)
+ self._conn.set_accept_state()
+ self._lot = _core.ParkingLot()
+ self._pending_cleartext = bytearray()
+
+ self._send_all_mutex = UnLock(
+ _core.ResourceBusyError,
+ "simultaneous calls to PyOpenSSLEchoStream.send_all")
+ self._receive_some_mutex = UnLock(
+ _core.ResourceBusyError,
+ "simultaneous calls to PyOpenSSLEchoStream.receive_some")
+
+ if sleeper is None:
+ async def no_op_sleeper(_):
+ return
+ self.sleeper = no_op_sleeper
+ else:
+ self.sleeper = sleeper
+
+ def forceful_close(self):
+ self._conn.bio_shutdown()
+
+ async def graceful_close(self):
+ self.forceful_close()
+
+ def renegotiate_pending(self):
+ return self._conn.renegotiate_pending()
+
+ def renegotiate(self):
+ # Returns false if a renegotiation is already in progress, meaning
+ # nothing happens.
+ assert self._conn.renegotiate()
+
+ async def wait_send_all_might_not_block(self):
+ async with self._send_all_mutex:
+ await _core.yield_briefly()
+ await self.sleeper("wait_send_all_might_not_block")
+
+ async def send_all(self, data):
+ print(" --> transport_stream.send_all")
+ async with self._send_all_mutex:
+ await _core.yield_briefly()
+ await self.sleeper("send_all")
+ self._conn.bio_write(data)
+ while True:
+ await self.sleeper("send_all")
+ try:
+ data = self._conn.recv(1)
+ except SSL.ZeroReturnError:
+ self._conn.shutdown()
+ print("renegotiations:", self._conn.total_renegotiations())
+ break
+ except SSL.WantReadError:
+ break
+ else:
+ self._pending_cleartext += data
+ self._lot.unpark_all()
+ await self.sleeper("send_all")
+ print(" <-- transport_stream.send_all finished")
+
+ async def receive_some(self, nbytes):
+ print(" --> transport_stream.receive_some")
+ async with self._receive_some_mutex:
+ try:
+ await _core.yield_briefly()
+ while True:
+ await self.sleeper("receive_some")
+ try:
+ return self._conn.bio_read(nbytes)
+ except SSL.WantReadError:
+ # No data in our ciphertext buffer; try to generate
+ # some.
+ if self._pending_cleartext:
+ # We have some cleartext; maybe we can encrypt it
+ # and then return it.
+ print(" trying", self._pending_cleartext)
+ try:
+ # PyOpenSSL bug: doesn't accept bytearray
+ # https://github.com/pyca/pyopenssl/issues/621
+ next_byte = self._pending_cleartext[0:1]
+ self._conn.send(bytes(next_byte))
+ # Apparently this next bit never gets hit in the
+ # test suite, but it's not an interesting omission
+ # so let's pragma it.
+ except SSL.WantReadError: # pragma: no cover
+ # We didn't manage to send the cleartext (and
+ # in particular we better leave it there to
+ # try again, due to openssl's retry
+ # semantics), but it's possible we pushed a
+ # renegotiation forward and *now* we have data
+ # to send.
+ try:
+ return self._conn.bio_read(nbytes)
+ except SSL.WantReadError:
+ # Nope. We're just going to have to wait
+ # for someone to call send_all() to give
+ # us more data.
+ print("parking (a)")
+ await self._lot.park()
+ else:
+ # We successfully sent that byte, so we don't
+ # have to again.
+ del self._pending_cleartext[0:1]
+ else:
+ # no pending cleartext; nothing to do but wait for
+ # someone to call send_all
+ print("parking (b)")
+ await self._lot.park()
+ finally:
+ await self.sleeper("receive_some")
+ print(" <-- transport_stream.receive_some finished")
+
+
+async def test_PyOpenSSLEchoStream_gives_resource_busy_errors():
+ # Make sure that PyOpenSSLEchoStream complains if two tasks call send_all
+ # at the same time, or ditto for receive_some. The tricky cases where SSLStream
+ # might accidentally do this are during renegotiation, which we test using
+ # PyOpenSSLEchoStream, so this makes sure that if we do have a bug then
+ # PyOpenSSLEchoStream will notice and complain.
+
+ s = PyOpenSSLEchoStream()
+ with pytest.raises(_core.ResourceBusyError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(s.send_all, b"x")
+ nursery.spawn(s.send_all, b"x")
+ assert "simultaneous" in str(excinfo.value)
+
+ s = PyOpenSSLEchoStream()
+ with pytest.raises(_core.ResourceBusyError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(s.send_all, b"x")
+ nursery.spawn(s.wait_send_all_might_not_block)
+ assert "simultaneous" in str(excinfo.value)
+
+ s = PyOpenSSLEchoStream()
+ with pytest.raises(_core.ResourceBusyError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(s.wait_send_all_might_not_block)
+ nursery.spawn(s.wait_send_all_might_not_block)
+ assert "simultaneous" in str(excinfo.value)
+
+ s = PyOpenSSLEchoStream()
+ with pytest.raises(_core.ResourceBusyError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(s.receive_some, 1)
+ nursery.spawn(s.receive_some, 1)
+ assert "simultaneous" in str(excinfo.value)
+
+
+@contextmanager
+def virtual_ssl_echo_server(**kwargs):
+ fakesock = PyOpenSSLEchoStream(**kwargs)
+ yield tssl.SSLStream(
+ fakesock, CLIENT_CTX, server_hostname="trio-test-1.example.org")
+
+
+def ssl_wrap_pair(client_transport, server_transport,
+ *, client_kwargs={}, server_kwargs={}):
+ client_ssl = tssl.SSLStream(
+ client_transport, CLIENT_CTX,
+ server_hostname="trio-test-1.example.org", **client_kwargs)
+ server_ssl = tssl.SSLStream(
+ server_transport, SERVER_CTX, server_side=True, **server_kwargs)
+ return client_ssl, server_ssl
+
+def ssl_memory_stream_pair(**kwargs):
+ client_transport, server_transport = memory_stream_pair()
+ return ssl_wrap_pair(client_transport, server_transport, **kwargs)
+
+def ssl_lockstep_stream_pair(**kwargs):
+ client_transport, server_transport = lockstep_stream_pair()
+ return ssl_wrap_pair(client_transport, server_transport, **kwargs)
+
+
+def test_exports():
+ # Just a quick check to make sure _reexport isn't totally broken
+ assert hasattr(tssl, "SSLError")
+ assert "SSLError" in tssl.__all__
+
+ assert hasattr(tssl, "Purpose")
+ assert "Purpose" in tssl.__all__
+
+ # Intentionally omitted
+ assert not hasattr(tssl, "SSLContext")
+
+
+# Simple smoke test for handshake/send/receive/shutdown talking to a
+# synchronous server, plus make sure that we do the bare minimum of
+# certificate checking (even though this is really Python's responsibility)
+async def test_ssl_client_basics():
+ # Everything OK
+ with ssl_echo_server() as s:
+ assert not s.server_side
+ await s.send_all(b"x")
+ assert await s.receive_some(1) == b"x"
+ await s.graceful_close()
+
+ # Didn't configure the CA file, should fail
+ with ssl_echo_server_raw(expect_fail=True) as sock:
+ client_ctx = stdlib_ssl.create_default_context()
+ s = tssl.SSLStream(
+ sock, client_ctx, server_hostname="trio-test-1.example.org")
+ assert not s.server_side
+ with pytest.raises(BrokenStreamError) as excinfo:
+ await s.send_all(b"x")
+ assert isinstance(excinfo.value.__cause__, tssl.SSLError)
+
+ # Trusted CA, but wrong host name
+ with ssl_echo_server_raw(expect_fail=True) as sock:
+ s = tssl.SSLStream(
+ sock, CLIENT_CTX, server_hostname="trio-test-2.example.org")
+ assert not s.server_side
+ with pytest.raises(BrokenStreamError) as excinfo:
+ await s.send_all(b"x")
+ assert isinstance(excinfo.value.__cause__, tssl.CertificateError)
+
+
+async def test_ssl_server_basics():
+ a, b = stdlib_socket.socketpair()
+ with a, b:
+ server_sock = tsocket.from_stdlib_socket(b)
+ server_transport = tssl.SSLStream(
+ _network.SocketStream(server_sock), SERVER_CTX, server_side=True)
+ assert server_transport.server_side
+
+ def client():
+ client_sock = CLIENT_CTX.wrap_socket(
+ a, server_hostname="trio-test-1.example.org")
+ client_sock.sendall(b"x")
+ assert client_sock.recv(1) == b"y"
+ client_sock.sendall(b"z")
+ client_sock.unwrap()
+ t = threading.Thread(target=client)
+ t.start()
+
+ assert await server_transport.receive_some(1) == b"x"
+ await server_transport.send_all(b"y")
+ assert await server_transport.receive_some(1) == b"z"
+ assert await server_transport.receive_some(1) == b""
+ await server_transport.graceful_close()
+
+ t.join()
+
+
+async def test_attributes():
+ with ssl_echo_server_raw(expect_fail=True) as sock:
+ good_ctx = CLIENT_CTX
+ bad_ctx = stdlib_ssl.create_default_context()
+ s = tssl.SSLStream(
+ sock, good_ctx, server_hostname="trio-test-1.example.org")
+
+ assert s.transport_stream is sock
+
+ # Forwarded attribute getting
+ assert s.context is good_ctx
+ assert s.server_side == False
+ assert s.server_hostname == "trio-test-1.example.org"
+ with pytest.raises(AttributeError):
+ s.asfdasdfsa
+
+ # __dir__
+ assert "transport_stream" in dir(s)
+ assert "context" in dir(s)
+
+ # Setting the attribute goes through to the underlying object
+
+ # most attributes on SSLObject are read-only
+ with pytest.raises(AttributeError):
+ s.server_side = True
+ with pytest.raises(AttributeError):
+ s.server_hostname = "asdf"
+
+ # but .context is *not*. Check that we forward attribute setting by
+ # making sure that after we set the bad context our handshake indeed
+ # fails:
+ s.context = bad_ctx
+ assert s.context is bad_ctx
+ with pytest.raises(BrokenStreamError) as excinfo:
+ await s.do_handshake()
+ assert isinstance(excinfo.value.__cause__, tssl.SSLError)
+
+
+# Note: this test fails horribly if we force TLS 1.2 and trigger a
+# renegotiation at the beginning (e.g. by switching to the pyopenssl
+# server). Usually the client crashes in SSLObject.write with "UNEXPECTED
+# RECORD"; sometimes we get something more exotic like a SyscallError. This is
+# odd because openssl isn't doing any syscalls, but so it goes. After lots of
+# websearching I'm pretty sure this is due to a bug in OpenSSL, where it just
+# can't reliably handle full-duplex communication combined with
+# renegotiation. Nice, eh?
+#
+# https://rt.openssl.org/Ticket/Display.html?id=3712
+# https://rt.openssl.org/Ticket/Display.html?id=2481
+# http://openssl.6102.n7.nabble.com/TLS-renegotiation-failure-on-receiving-application-data-during-handshake-td48127.html
+# https://stackoverflow.com/questions/18728355/ssl-renegotiation-with-full-duplex-socket-communication
+#
+# In some variants of this test (maybe only against the java server?) I've
+# also seen cases where our send_all blocks waiting to write, and then our receive_some
+# also blocks waiting to write, and they never wake up again. It looks like
+# some kind of deadlock. I suspect there may be an issue where we've filled up
+# the send buffers, and the remote side is trying to handle the renegotiation
+# from inside a write() call, so it has a problem: there's all this application
+# data clogging up the pipe, but it can't process and return it to the
+# application because it's in write(), and it doesn't want to buffer infinite
+# amounts of data, and... actually I guess those are the only two choices.
+#
+# NSS even documents that you shouldn't try to do a renegotiation except when
+# the connection is idle:
+#
+# https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/SSL_functions/sslfnc.html#1061582
+#
+# I begin to see why HTTP/2 forbids renegotiation and TLS 1.3 removes it...
+
+async def test_full_duplex_basics():
+ CHUNKS = 30
+ CHUNK_SIZE = 32768
+ EXPECTED = CHUNKS * CHUNK_SIZE
+
+ sent = bytearray()
+ received = bytearray()
+
+ async def sender(s):
+ nonlocal sent
+ for i in range(CHUNKS):
+ print(i)
+ chunk = bytes([i] * CHUNK_SIZE)
+ sent += chunk
+ await s.send_all(chunk)
+
+ async def receiver(s):
+ nonlocal received
+ while len(received) < EXPECTED:
+ chunk = await s.receive_some(CHUNK_SIZE // 2)
+ received += chunk
+
+ with ssl_echo_server() as s:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(sender, s)
+ nursery.spawn(receiver, s)
+ # And let's have some doing handshakes too, everyone
+ # simultaneously
+ nursery.spawn(s.do_handshake)
+ nursery.spawn(s.do_handshake)
+
+ await s.graceful_close()
+
+ assert len(sent) == len(received) == EXPECTED
+ assert sent == received
+
+
+async def test_renegotiation_simple():
+ with virtual_ssl_echo_server() as s:
+ await s.do_handshake()
+
+ s.transport_stream.renegotiate()
+ await s.send_all(b"a")
+ assert await s.receive_some(1) == b"a"
+
+ # Have to send some more data back and forth to make sure the
+ # renegotiation is finished before shutting down the
+ # connection... otherwise openssl raises an error. I think this is a
+ # bug in openssl but what can ya do.
+ await s.send_all(b"b")
+ assert await s.receive_some(1) == b"b"
+
+ await s.graceful_close()
+
+
+@slow
+async def test_renegotiation_randomized(mock_clock):
+ # The only blocking things in this function are our random sleeps, so 0 is
+ # a good threshold.
+ mock_clock.autojump_threshold = 0
+
+ import random
+ r = random.Random(0)
+
+ async def sleeper(_):
+ await trio.sleep(r.uniform(0, 10))
+
+ async def clear():
+ while s.transport_stream.renegotiate_pending():
+ with assert_yields():
+ await send(b"-")
+ with assert_yields():
+ await expect(b"-")
+ print("-- clear --")
+
+ async def send(byte):
+ await s.transport_stream.sleeper("outer send")
+ print("calling SSLStream.send_all", byte)
+ with assert_yields():
+ await s.send_all(byte)
+
+ async def expect(expected):
+ await s.transport_stream.sleeper("expect")
+ print("calling SSLStream.receive_some, expecting", expected)
+ assert len(expected) == 1
+ with assert_yields():
+ assert await s.receive_some(1) == expected
+
+ with virtual_ssl_echo_server(sleeper=sleeper) as s:
+ await s.do_handshake()
+
+ await send(b"a")
+ s.transport_stream.renegotiate()
+ await expect(b"a")
+
+ await clear()
+
+ for i in range(100):
+ b1 = bytes([i % 0xff])
+ b2 = bytes([(2 * i) % 0xff])
+ s.transport_stream.renegotiate()
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(send, b1)
+ nursery.spawn(expect, b1)
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(expect, b2)
+ nursery.spawn(send, b2)
+ await clear()
+
+ for i in range(100):
+ b1 = bytes([i % 0xff])
+ b2 = bytes([(2 * i) % 0xff])
+ await send(b1)
+ s.transport_stream.renegotiate()
+ await expect(b1)
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(expect, b2)
+ nursery.spawn(send, b2)
+ await clear()
+
+ # Checking that wait_send_all_might_not_block and receive_some don't
+ # conflict:
+
+ # 1) Set up a situation where expect (receive_some) is blocked sending,
+ # and wait_send_all_might_not_block comes in.
+
+ # Our receive_some() call will get stuck when it hits send_all
+ async def sleeper_with_slow_send_all(method):
+ if method == "send_all":
+ await trio.sleep(100000)
+
+ # And our wait_send_all_might_not_block call will give it time to get
+ # stuck, and then start
+ async def sleep_then_wait_writable():
+ await trio.sleep(1000)
+ await s.wait_send_all_might_not_block()
+
+ with virtual_ssl_echo_server(sleeper=sleeper_with_slow_send_all) as s:
+ await send(b"x")
+ s.transport_stream.renegotiate()
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(expect, b"x")
+ nursery.spawn(sleep_then_wait_writable)
+
+ await clear()
+
+ await s.graceful_close()
+
+ # 2) Same, but now wait_send_all_might_not_block is stuck when
+ # receive_some tries to send.
+
+ async def sleeper_with_slow_wait_writable_and_expect(method):
+ if method == "wait_send_all_might_not_block":
+ await trio.sleep(100000)
+ elif method == "expect":
+ await trio.sleep(1000)
+
+ with virtual_ssl_echo_server(
+ sleeper=sleeper_with_slow_wait_writable_and_expect) as s:
+ await send(b"x")
+ s.transport_stream.renegotiate()
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(expect, b"x")
+ nursery.spawn(s.wait_send_all_might_not_block)
+
+ await clear()
+
+ await s.graceful_close()
+
+
+async def test_resource_busy_errors():
+ async def do_send_all():
+ with assert_yields():
+ await s.send_all(b"x")
+
+ async def do_receive_some():
+ with assert_yields():
+ await s.receive_some(1)
+
+ async def do_wait_send_all_might_not_block():
+ with assert_yields():
+ await s.wait_send_all_might_not_block()
+
+ s, _ = ssl_lockstep_stream_pair()
+ with pytest.raises(_core.ResourceBusyError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_send_all)
+ nursery.spawn(do_send_all)
+ assert "another task" in str(excinfo.value)
+
+ s, _ = ssl_lockstep_stream_pair()
+ with pytest.raises(_core.ResourceBusyError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_receive_some)
+ nursery.spawn(do_receive_some)
+ assert "another task" in str(excinfo.value)
+
+ s, _ = ssl_lockstep_stream_pair()
+ with pytest.raises(_core.ResourceBusyError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_send_all)
+ nursery.spawn(do_wait_send_all_might_not_block)
+ assert "another task" in str(excinfo.value)
+
+ s, _ = ssl_lockstep_stream_pair()
+ with pytest.raises(_core.ResourceBusyError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_wait_send_all_might_not_block)
+ nursery.spawn(do_wait_send_all_might_not_block)
+ assert "another task" in str(excinfo.value)
+
+
+async def test_wait_writable_calls_underlying_wait_writable():
+ record = []
+ class NotAStream:
+ async def wait_send_all_might_not_block(self):
+ record.append("ok")
+ ctx = stdlib_ssl.create_default_context()
+ s = tssl.SSLStream(NotAStream(), ctx, server_hostname="x")
+ await s.wait_send_all_might_not_block()
+ assert record == ["ok"]
+
+
+async def test_checkpoints():
+ with ssl_echo_server() as s:
+ with assert_yields():
+ await s.do_handshake()
+ with assert_yields():
+ await s.do_handshake()
+ with assert_yields():
+ await s.wait_send_all_might_not_block()
+ with assert_yields():
+ await s.send_all(b"xxx")
+ with assert_yields():
+ await s.receive_some(1)
+ # These receive_some's in theory could return immediately, because the
+ # "xxx" was sent in a single record and after the first
+ # receive_some(1) the rest are sitting inside the SSLObject's internal
+ # buffers.
+ with assert_yields():
+ await s.receive_some(1)
+ with assert_yields():
+ await s.receive_some(1)
+ with assert_yields():
+ await s.unwrap()
+
+ with ssl_echo_server() as s:
+ await s.do_handshake()
+ with assert_yields():
+ await s.graceful_close()
+
+
+async def test_send_all_empty_string():
+ with ssl_echo_server() as s:
+ await s.do_handshake()
+
+ # underlying SSLObject interprets writing b"" as indicating an EOF,
+ # for some reason. Make sure we don't inherit this.
+ with assert_yields():
+ await s.send_all(b"")
+ with assert_yields():
+ await s.send_all(b"")
+ await s.send_all(b"x")
+ assert await s.receive_some(1) == b"x"
+
+ await s.graceful_close()
+
+
+@pytest.mark.parametrize("https_compatible", [False, True])
+async def test_SSLStream_generic(https_compatible):
+ async def stream_maker():
+ return ssl_memory_stream_pair(
+ client_kwargs={"https_compatible": https_compatible},
+ server_kwargs={"https_compatible": https_compatible},
+ )
+
+ async def clogged_stream_maker():
+ client, server = ssl_lockstep_stream_pair()
+ # If we don't do handshakes up front, then we run into a problem in
+ # the following situation:
+ # - server does wait_send_all_might_not_block
+ # - client does receive_some to unclog it
+ # Then the client's receive_some will actually send some data to start
+ # the handshake, and itself get stuck.
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client.do_handshake)
+ nursery.spawn(server.do_handshake)
+ return client, server
+
+ await check_two_way_stream(stream_maker, clogged_stream_maker)
+
+
+async def test_unwrap():
+ client_ssl, server_ssl = ssl_memory_stream_pair()
+ client_transport = client_ssl.transport_stream
+ server_transport = server_ssl.transport_stream
+
+ seq = Sequencer()
+
+ async def client():
+ await client_ssl.do_handshake()
+ await client_ssl.send_all(b"x")
+ assert await client_ssl.receive_some(1) == b"y"
+ await client_ssl.send_all(b"z")
+
+ # After sending that, disable outgoing data from our end, to make
+ # sure the server doesn't see our EOF until after we've sent some
+ # trailing data
+ async with seq(0):
+ send_all_hook = client_transport.send_stream.send_all_hook
+ client_transport.send_stream.send_all_hook = None
+
+ assert await client_ssl.receive_some(1) == b""
+ assert client_ssl.transport_stream is client_transport
+ # We just received EOF. Unwrap the connection and send some more.
+ raw, trailing = await client_ssl.unwrap()
+ assert raw is client_transport
+ assert trailing == b""
+ assert client_ssl.transport_stream is None
+ await raw.send_all(b"trailing")
+
+ # Reconnect the streams. Now the server will receive both our shutdown
+ # acknowledgement + the trailing data in a single lump.
+ client_transport.send_stream.send_all_hook = send_all_hook
+ await client_transport.send_stream.send_all_hook()
+
+ async def server():
+ await server_ssl.do_handshake()
+ assert await server_ssl.receive_some(1) == b"x"
+ await server_ssl.send_all(b"y")
+ assert await server_ssl.receive_some(1) == b"z"
+ # Now client is blocked waiting for us to send something, but
+ # instead we close the TLS connection (with sequencer to make sure
+ # that the client won't see and automatically respond before we've had
+ # a chance to disable the client->server transport)
+ async with seq(1):
+ raw, trailing = await server_ssl.unwrap()
+ assert raw is server_transport
+ assert trailing == b"trailing"
+ assert server_ssl.transport_stream is None
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client)
+ nursery.spawn(server)
+
+
+async def test_closing_nice_case():
+ # the nice case: graceful closes all around
+
+ client_ssl, server_ssl = ssl_memory_stream_pair()
+ client_transport = client_ssl.transport_stream
+
+ # Both the handshake and the close require back-and-forth discussion, so
+ # we need to run them concurrently
+ async def client_closer():
+ with assert_yields():
+ await client_ssl.graceful_close()
+
+ async def server_closer():
+ assert await server_ssl.receive_some(10) == b""
+ assert await server_ssl.receive_some(10) == b""
+ with assert_yields():
+ await server_ssl.graceful_close()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client_closer)
+ nursery.spawn(server_closer)
+
+ # closing the SSLStream also closes its transport
+ with pytest.raises(ClosedStreamError):
+ await client_transport.send_all(b"123")
+
+ # once closed, it's OK to close again
+ with assert_yields():
+ await client_ssl.graceful_close()
+ client_ssl.forceful_close()
+
+ # Trying to send more data does not work
+ with assert_yields():
+ with pytest.raises(ClosedStreamError):
+ await server_ssl.send_all(b"123")
+
+    # And once the connection has been closed *locally*, then instead of
+ # getting empty bytestrings we get a proper error
+ with assert_yields():
+ with pytest.raises(ClosedStreamError):
+ await client_ssl.receive_some(10) == b""
+
+ with assert_yields():
+ with pytest.raises(ClosedStreamError):
+ await client_ssl.unwrap()
+
+ with assert_yields():
+ with pytest.raises(ClosedStreamError):
+ await client_ssl.do_handshake()
+
+ # Check that a graceful close *before* handshaking gives a clean EOF on
+ # the other side
+ client_ssl, server_ssl = ssl_memory_stream_pair()
+ async def expect_eof_server():
+ with assert_yields():
+ assert await server_ssl.receive_some(10) == b""
+ with assert_yields():
+ await server_ssl.graceful_close()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client_ssl.graceful_close)
+ nursery.spawn(expect_eof_server)
+
+
+async def test_send_all_fails_in_the_middle():
+ client, server = ssl_memory_stream_pair()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client.do_handshake)
+ nursery.spawn(server.do_handshake)
+
+ async def bad_hook():
+ raise KeyError
+
+ client.transport_stream.send_stream.send_all_hook = bad_hook
+
+ with pytest.raises(KeyError):
+ await client.send_all(b"x")
+
+ with pytest.raises(BrokenStreamError):
+ await client.wait_send_all_might_not_block()
+
+ closed = 0
+ def close_hook():
+ nonlocal closed
+ closed += 1
+
+ client.transport_stream.send_stream.close_hook = close_hook
+ client.transport_stream.receive_stream.close_hook = close_hook
+ await client.graceful_close()
+
+ assert closed == 2
+
+
+async def test_ssl_over_ssl():
+ client_0, server_0 = memory_stream_pair()
+
+ client_1 = tssl.SSLStream(
+ client_0, CLIENT_CTX, server_hostname="trio-test-1.example.org")
+ server_1 = tssl.SSLStream(
+ server_0, SERVER_CTX, server_side=True)
+
+ client_2 = tssl.SSLStream(
+ client_1, CLIENT_CTX, server_hostname="trio-test-1.example.org")
+ server_2 = tssl.SSLStream(
+ server_1, SERVER_CTX, server_side=True)
+
+ async def client():
+ await client_2.send_all(b"hi")
+ assert await client_2.receive_some(10) == b"bye"
+
+ async def server():
+ assert await server_2.receive_some(10) == b"hi"
+ await server_2.send_all(b"bye")
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client)
+ nursery.spawn(server)
+
+
+async def test_ssl_bad_shutdown():
+ client, server = ssl_memory_stream_pair()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client.do_handshake)
+ nursery.spawn(server.do_handshake)
+
+ client.forceful_close()
+ # now the server sees a broken stream
+ with pytest.raises(BrokenStreamError):
+ await server.receive_some(10)
+ with pytest.raises(BrokenStreamError):
+ await server.send_all(b"x" * 10)
+
+ await server.graceful_close()
+
+
+async def test_ssl_bad_shutdown_but_its_ok():
+ client, server = ssl_memory_stream_pair(
+ server_kwargs={"https_compatible": True},
+ client_kwargs={"https_compatible": True})
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client.do_handshake)
+ nursery.spawn(server.do_handshake)
+
+ client.forceful_close()
+ # the server sees that as a clean shutdown
+ assert await server.receive_some(10) == b""
+ with pytest.raises(BrokenStreamError):
+ await server.send_all(b"x" * 10)
+
+ await server.graceful_close()
+
+
+async def test_ssl_https_compatibility_disagreement():
+ client, server = ssl_memory_stream_pair(
+ server_kwargs={"https_compatible": False},
+ client_kwargs={"https_compatible": True})
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client.do_handshake)
+ nursery.spawn(server.do_handshake)
+
+ # client is in HTTPS-mode, server is not
+ # so client doing graceful_shutdown causes an error on server
+ async def receive_and_expect_error():
+ with pytest.raises(BrokenStreamError) as excinfo:
+ await server.receive_some(10)
+ assert isinstance(excinfo.value.__cause__, tssl.SSLEOFError)
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client.graceful_close)
+ nursery.spawn(receive_and_expect_error)
+
+
+async def test_https_mode_eof_before_handshake():
+ client, server = ssl_memory_stream_pair(
+ server_kwargs={"https_compatible": True},
+ client_kwargs={"https_compatible": True})
+
+ async def server_expect_clean_eof():
+ assert await server.receive_some(10) == b""
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client.graceful_close)
+ nursery.spawn(server_expect_clean_eof)
+
+
+async def test_send_error_during_handshake():
+ client, server = ssl_memory_stream_pair()
+
+ async def bad_hook():
+ raise KeyError
+
+ client.transport_stream.send_stream.send_all_hook = bad_hook
+
+ with pytest.raises(KeyError):
+ with assert_yields():
+ await client.do_handshake()
+
+ with pytest.raises(BrokenStreamError):
+ with assert_yields():
+ await client.do_handshake()
+
+
+async def test_receive_error_during_handshake():
+ client, server = ssl_memory_stream_pair()
+
+ async def bad_hook():
+ raise KeyError
+
+ client.transport_stream.receive_stream.receive_some_hook = bad_hook
+
+ async def client_side(cancel_scope):
+ with pytest.raises(KeyError):
+ with assert_yields():
+ await client.do_handshake()
+ cancel_scope.cancel()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client_side, nursery.cancel_scope)
+ nursery.spawn(server.do_handshake)
+
+ with pytest.raises(BrokenStreamError):
+ with assert_yields():
+ await client.do_handshake()
+
+
+async def test_getpeercert():
+ # Make sure we're not affected by https://bugs.python.org/issue29334
+ client, server = ssl_memory_stream_pair()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(client.do_handshake)
+ nursery.spawn(server.do_handshake)
+
+ assert server.getpeercert() is None
+ assert ((("commonName", "trio-test-1.example.org"),)
+ in client.getpeercert()["subject"])
diff --git a/trio/tests/test_ssl_certs/make-test-certs.sh b/trio/tests/test_ssl_certs/make-test-certs.sh
new file mode 100755
index 0000000000..0964cb14e0
--- /dev/null
+++ b/trio/tests/test_ssl_certs/make-test-certs.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+
+# This file generates 3 .pem files that we use in the test suite.
+#
+# trio-test-CA.pem contains the public key for a root CA certificate.
+#
+# trio-test-{1,2}.pem contain, concatenated in order:
+# - a private key
+# - a certificate signed by our CA claiming that this is the key for the host
+# "trio-test-$N.example.org"
+# - the root CA certificate again (to complete the cert chain)
+#
+# End result is that if you do
+#
+# ssl.create_default_context(cafile="trio-test-CA.pem")
+#
+# then your SSLContext will trust the trio-test-{1,2} certificates.
+#
+# And if you do
+#
+# sslcontext.load_cert_chain("trio-test-1.pem")
+#
+# then you can claim to be trio-test-1.example.org.
+
+set -uxe -o pipefail
+
+# Generate a self-signed 2048-bit RSA key as our signing root
+openssl req -x509 -newkey rsa:2048 -nodes -sha256 -days 99999 \
+ -subj '/O=Trio test CA' \
+ -keyout trio-test-CA.key -out trio-test-CA.pem
+
+for CERT in 1; do
+ # Create a key and CSR.
+ #
+ # Our tests only use one name, so CN= is enough. (Otherwise we would need
+ # to use subjectAltNames=, which *replaces* CN=.)
+ openssl req -new -newkey rsa:2048 -nodes -sha256 \
+ -subj "/CN=trio-test-${CERT}.example.org" \
+ -keyout trio-test-${CERT}.key -out trio-test-${CERT}.csr
+ # Use the CSR and CA to sign the key, generating a certificate
+ openssl x509 -req -in trio-test-${CERT}.csr -sha256 -days 99999 \
+ -CA trio-test-CA.pem -CAkey trio-test-CA.key \
+ -set_serial ${CERT} \
+ -out trio-test-${CERT}.crt
+ # Combine key/cert/root-CA into a single file for convenience
+ # (see https://docs.python.org/3/library/ssl.html#ssl-certificates)
+ cat trio-test-${CERT}.key trio-test-${CERT}.crt trio-test-CA.pem \
+ > trio-test-${CERT}.pem
+ rm -f trio-test-${CERT}.{csr,key,crt}
+done
+
+# We don't need the signing key anymore, remove it to reduce clutter
+rm -f trio-test-CA.key
diff --git a/trio/tests/test_ssl_certs/trio-test-1.pem b/trio/tests/test_ssl_certs/trio-test-1.pem
new file mode 100644
index 0000000000..1431aa033d
--- /dev/null
+++ b/trio/tests/test_ssl_certs/trio-test-1.pem
@@ -0,0 +1,64 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDQC/Q9j1KiWr56
+QUSL5CQEmWecI/2IH1bOekAn9hy+xb98DIsJXy9u7u0gjFTAHJh1ysvRQ8gacdsC
+locBnt6+ElDPpJgiGHyf5NWzrJ/4uNwM4FKZYUObA5QeM8CTJzuCvFkH3MI17yJk
+3uqz/HfiUdeA2QO4iBDgqGj4LjIfO16C43VoaHvb8ePsRoBHlWCzZGqp2q98lMgi
+Ed7bq1IlrGb5MzU4bnZ0k0fDi4adSG5TI9W+aYwTz+OvZIZ9xqTjj5+7Zh4IAnBh
+Ql23tVgS+Kuemcf5zS0vVUxG1paATKB/MBqtfeyXJyrMMXSUR9Qv0UAaGHNqPQGZ
+ituO1tYXAgMBAAECggEAHJLXu7C4j7XY3V+jc3ck/0C2ezpyMsTjHj6qGxLxRb5R
+G095tRLOp/TGuqaraStEQUFWFuqxS/iBNOzJpA5W11IaqToY7u3gB/Hc6+10lyuE
+hXw1u/0g1OR77l37P/qucLk/nRXT0qaCWcpH/+pX6MyGxZqIqUp+zuwyZouptKIq
+QE1wDy6AEKLSraan3dh9pEd1FGJB7lKKnPSgbwP2H9dgtWgYYOgNAZSszKWX7rGT
++kEH0aPC5OOpkG58saF4d/eI+KRU+/QGZcLrFZsiYuWxFcFkrnKKx8PV5rjrLRFG
++D13zASaTFJ4W6gLCTVb9lNDofhHb8JCcXrZi77fWQKBgQD3Vf+94Fw2Ybvr84G0
+7p1kCf6naDMcvRPHl0pGI/oCP7gwkgtdBnNLFU2Y8IS6YvN8jMQzPB3gvwj3T3xK
+LPoM5Ie/RJYbPt3tNuXpzLY6NYx4BZ1rvK7h7ShTDNf8zb+pIdk8k0Xw3rVURWRI
+WLA6jSf3Oc9c08/X8TWpZpwaEwKBgQDXVaBUSDz7+zH6WOtRSq5Wr1bDkubYnJ8Y
+RsUN4Ju/59hhnIRewDty/28jtXcCtpPH1CrYf/mvsU8k6cn/u3pPbhgIgSDLdZT1
+j/kODgetfhXTbMZAp70Er9svlyUVlo0xV3fNMcj/FCVeVIF/0Icn/jUTJNTdlxbW
+ch2JHRjUbQKBgQCzwU7CoqKh61n2W90ysBC3OgRXioVLJ6eOcUfLvi3fIIwu0JVt
+oFh+gxcIRhVQmMW5CV02l0RnqK9Nffkot5Nrd1OpEKG/X2tPEYz65Iqzt2NFf18v
+g8vd6sxZv4Xh926J703AlpBIRLOocV42ri41/4zCQsOQBWiS2n1Thn2A/QKBgEbm
+58q4mnPxywv+eUUkDPF3/F6bIS2TrILm0n12RnJS2ZmSWreEHk8IMkUUvCIFkfVL
+M+xjfwhNnpyt6hgtV+GNg5ZRRkYX6jtM85mgHwEOMguSlli1onRHnyk1YD2Se90S
+St0ilmb+8Cr2MkmulMIjXsB18S0hUaC8pGMAVKulAoGARusNpVPh3M8hc7k1l5tn
+qOy7r/DkaxUXAje8c52x7meRXQenFmwbjLvtpgF1wh3RYzh48Ljb4An4vIy+dzC0
+ZRpE46qQg05cRksRoPWwpAD9oJjrXgBV/QPHMG4eTJXqh8vKECjlAhDD5HduP191
+lZiUjv4qGPFam1Iuhc2wRhc=
+-----END PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIICrzCCAZcCAQEwDQYJKoZIhvcNAQELBQAwFzEVMBMGA1UECgwMVHJpbyB0ZXN0
+IENBMCAXDTE3MDQyNTAyMjgzNloYDzIyOTEwMjA3MDIyODM2WjAiMSAwHgYDVQQD
+DBd0cmlvLXRlc3QtMS5leGFtcGxlLm9yZzCCASIwDQYJKoZIhvcNAQEBBQADggEP
+ADCCAQoCggEBANAL9D2PUqJavnpBRIvkJASZZ5wj/YgfVs56QCf2HL7Fv3wMiwlf
+L27u7SCMVMAcmHXKy9FDyBpx2wKWhwGe3r4SUM+kmCIYfJ/k1bOsn/i43AzgUplh
+Q5sDlB4zwJMnO4K8WQfcwjXvImTe6rP8d+JR14DZA7iIEOCoaPguMh87XoLjdWho
+e9vx4+xGgEeVYLNkaqnar3yUyCIR3turUiWsZvkzNThudnSTR8OLhp1IblMj1b5p
+jBPP469khn3GpOOPn7tmHggCcGFCXbe1WBL4q56Zx/nNLS9VTEbWloBMoH8wGq19
+7JcnKswxdJRH1C/RQBoYc2o9AZmK247W1hcCAwEAATANBgkqhkiG9w0BAQsFAAOC
+AQEAPJ18+JGxFSFsRTt8P4RMQjTerjIOealquLlTUxflwb7EkjTnGtbVibAKDYoB
+nwGgVPk5cXEWMAa31Ub+riG6PaXQirwzAZZYIfffFCCAi7JX41PM0TXWhNEWIvAT
+7t9wREIo1rU337PrjlxNZx/lJ3Iwglxo5ol7Ecguq3x7rXbl5tW9oCOmbKexnAGq
+dU9flNTCnQNuthfJToIIbTIiDMJB0vMvIHqGeehu+xf2slv8CNbSW3cfs2RHsLS7
+kZelM6ALMhOtHcW7V89bzGbp1OK0w3qIW2gixEO9ujLpNeFfWsDXo+fZRyQ7imq5
+wnZOkvKSUGrtrrschhr6X6bd7w==
+-----END CERTIFICATE-----
+-----BEGIN CERTIFICATE-----
+MIIDBjCCAe6gAwIBAgIJAK25bq8asqsuMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV
+BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MjUwMjI4MzZaGA8yMjkxMDIwNzAyMjgz
+NlowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC
+AQ8AMIIBCgKCAQEAoNv5sSeaMn0RR4zovWPFUpUNeaEyAvJvFU1SRrvZmBwQUihK
+FKaFfW820T5fEfAV266bFTpyw8DnxLGe6nlNqfF5H3yrwDNLQUebvC7K7xZy07zv
+5aYqxVPlD/JqSqBzMLB4gl0nrKVRU+6y06nKlJ+6InRZ9KhhHI9ak7081aElDad2
+50iPiW3jEC50L3gBY1DVEbUQHcmOdS1Ne5keY2AgF1umZz4obsjFQghH+NwbVRq8
+VyX8dtbxnsLq4CTFUZpHz48GDKQz6ipRl6zWiMOreMPTNXmjR9taoxtXa1+m47PE
+gb3oJ8IGUbxCk7Bd85uVCWRvA06gm26eaHraRQIDAQABo1MwUTAdBgNVHQ4EFgQU
+Q2zAnu8wesWfG4DkT1SMqOh3G8swHwYDVR0jBBgwFoAUQ2zAnu8wesWfG4DkT1SM
+qOh3G8swDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAl8PY4uy4
+/dKVLwHEhh7HGbRDibs5xK87qGUZQ6FmiNWMI2X4aNAXuqrMMKMs4151kpNBqif7
+7f3wMjKj2tIfrEm50aVW72aihGdNY7W88rfh1xnxE/SZ4l3m3UMF8a0NRPkhkEMC
+Nk/ngoM20YI11OEO/VHZzef59DCx1Pfl0+T4lRgA4lbgG0dNTIggy9uhdZsWemF1
+gVeiyivEtJIy2RmvwA9rP+ZwYLVoYSV1lCOVrOlO1VWadAe//NYAKzcuBpufjWNz
+vY+G5AzWfmNtu3RegDw+xEO31mGF7oKqCuESqTsvTbYBymi/l0qFazovrUo8js4G
+5GjPPPACNXyNJw==
+-----END CERTIFICATE-----
diff --git a/trio/tests/test_ssl_certs/trio-test-CA.pem b/trio/tests/test_ssl_certs/trio-test-CA.pem
new file mode 100644
index 0000000000..ca5eb63e8c
--- /dev/null
+++ b/trio/tests/test_ssl_certs/trio-test-CA.pem
@@ -0,0 +1,19 @@
+-----BEGIN CERTIFICATE-----
+MIIDBjCCAe6gAwIBAgIJAK25bq8asqsuMA0GCSqGSIb3DQEBCwUAMBcxFTATBgNV
+BAoMDFRyaW8gdGVzdCBDQTAgFw0xNzA0MjUwMjI4MzZaGA8yMjkxMDIwNzAyMjgz
+NlowFzEVMBMGA1UECgwMVHJpbyB0ZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOC
+AQ8AMIIBCgKCAQEAoNv5sSeaMn0RR4zovWPFUpUNeaEyAvJvFU1SRrvZmBwQUihK
+FKaFfW820T5fEfAV266bFTpyw8DnxLGe6nlNqfF5H3yrwDNLQUebvC7K7xZy07zv
+5aYqxVPlD/JqSqBzMLB4gl0nrKVRU+6y06nKlJ+6InRZ9KhhHI9ak7081aElDad2
+50iPiW3jEC50L3gBY1DVEbUQHcmOdS1Ne5keY2AgF1umZz4obsjFQghH+NwbVRq8
+VyX8dtbxnsLq4CTFUZpHz48GDKQz6ipRl6zWiMOreMPTNXmjR9taoxtXa1+m47PE
+gb3oJ8IGUbxCk7Bd85uVCWRvA06gm26eaHraRQIDAQABo1MwUTAdBgNVHQ4EFgQU
+Q2zAnu8wesWfG4DkT1SMqOh3G8swHwYDVR0jBBgwFoAUQ2zAnu8wesWfG4DkT1SM
+qOh3G8swDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAl8PY4uy4
+/dKVLwHEhh7HGbRDibs5xK87qGUZQ6FmiNWMI2X4aNAXuqrMMKMs4151kpNBqif7
+7f3wMjKj2tIfrEm50aVW72aihGdNY7W88rfh1xnxE/SZ4l3m3UMF8a0NRPkhkEMC
+Nk/ngoM20YI11OEO/VHZzef59DCx1Pfl0+T4lRgA4lbgG0dNTIggy9uhdZsWemF1
+gVeiyivEtJIy2RmvwA9rP+ZwYLVoYSV1lCOVrOlO1VWadAe//NYAKzcuBpufjWNz
+vY+G5AzWfmNtu3RegDw+xEO31mGF7oKqCuESqTsvTbYBymi/l0qFazovrUo8js4G
+5GjPPPACNXyNJw==
+-----END CERTIFICATE-----
diff --git a/trio/tests/test_streams.py b/trio/tests/test_streams.py
new file mode 100644
index 0000000000..4fd282fd61
--- /dev/null
+++ b/trio/tests/test_streams.py
@@ -0,0 +1,108 @@
+import pytest
+
+import attr
+
+from ..abc import SendStream, ReceiveStream
+from .._streams import StapledStream
+
+@attr.s
+class RecordSendStream(SendStream):
+ record = attr.ib(default=attr.Factory(list))
+
+ async def send_all(self, data):
+ self.record.append(("send_all", data))
+
+ async def wait_send_all_might_not_block(self):
+ self.record.append("wait_send_all_might_not_block")
+
+ async def graceful_close(self):
+ self.record.append("graceful_close")
+
+ def forceful_close(self):
+ self.record.append("forceful_close")
+
+
+@attr.s
+class RecordReceiveStream(ReceiveStream):
+ record = attr.ib(default=attr.Factory(list))
+
+ async def receive_some(self, max_bytes):
+ self.record.append(("receive_some", max_bytes))
+
+ async def graceful_close(self):
+ self.record.append("graceful_close")
+
+ def forceful_close(self):
+ self.record.append("forceful_close")
+
+
+async def test_StapledStream():
+ send_stream = RecordSendStream()
+ receive_stream = RecordReceiveStream()
+ stapled = StapledStream(send_stream, receive_stream)
+
+ assert stapled.send_stream is send_stream
+ assert stapled.receive_stream is receive_stream
+
+ await stapled.send_all(b"foo")
+ await stapled.wait_send_all_might_not_block()
+ assert send_stream.record == [
+ ("send_all", b"foo"), "wait_send_all_might_not_block",
+ ]
+ send_stream.record.clear()
+
+ await stapled.send_eof()
+ assert send_stream.record == ["graceful_close"]
+ send_stream.record.clear()
+
+ async def fake_send_eof():
+ send_stream.record.append("send_eof")
+ send_stream.send_eof = fake_send_eof
+ await stapled.send_eof()
+ assert send_stream.record == ["send_eof"]
+
+ send_stream.record.clear()
+ assert receive_stream.record == []
+
+ await stapled.receive_some(1234)
+ assert receive_stream.record == [("receive_some", 1234)]
+ assert send_stream.record == []
+ receive_stream.record.clear()
+
+ await stapled.graceful_close()
+ stapled.forceful_close()
+ assert receive_stream.record == ["graceful_close", "forceful_close"]
+ assert send_stream.record == ["graceful_close", "forceful_close"]
+
+
+async def test_StapledStream_with_erroring_close():
+ class BrokenSendStream(RecordSendStream):
+ def forceful_close(self):
+ super().forceful_close()
+ raise KeyError
+
+ async def graceful_close(self):
+ await super().graceful_close()
+ raise ValueError
+
+ class BrokenReceiveStream(RecordReceiveStream):
+ def forceful_close(self):
+ super().forceful_close()
+ raise KeyError
+
+ async def graceful_close(self):
+ await super().graceful_close()
+ raise ValueError
+
+ stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
+
+ with pytest.raises(KeyError) as excinfo:
+ stapled.forceful_close()
+ assert isinstance(excinfo.value.__context__, KeyError)
+
+ with pytest.raises(ValueError) as excinfo:
+ await stapled.graceful_close()
+ assert isinstance(excinfo.value.__context__, ValueError)
+
+ assert stapled.send_stream.record == ["forceful_close", "graceful_close"]
+ assert stapled.receive_stream.record == ["forceful_close", "graceful_close"]
diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py
index e8e1db4154..3f5e622abd 100644
--- a/trio/tests/test_sync.py
+++ b/trio/tests/test_sync.py
@@ -94,10 +94,14 @@ async def test_Semaphore_bounded():
assert bs.value == 1
-async def test_Lock():
- l = Lock()
+@pytest.mark.parametrize(
+ "lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
+async def test_Lock_and_StrictFIFOLock(lockcls):
+ l = lockcls()
assert not l.locked()
repr(l) # smoke test
+ # make sure repr uses the right name for subclasses
+ assert lockcls.__name__ in repr(l)
with assert_yields():
async with l:
assert l.locked()
@@ -161,6 +165,8 @@ async def holder():
async def test_Condition():
with pytest.raises(TypeError):
Condition(Semaphore(1))
+ with pytest.raises(TypeError):
+ Condition(StrictFIFOLock)
l = Lock()
c = Condition(l)
assert not l.locked()
@@ -436,6 +442,7 @@ def release(self):
lock_factories = [
lambda: Semaphore(1),
Lock,
+ StrictFIFOLock,
lambda: QueueLock1(10),
lambda: QueueLock1(1),
QueueLock2,
@@ -443,6 +450,7 @@ def release(self):
lock_factory_names = [
"Semaphore(1)",
"Lock",
+ "StrictFIFOLock",
"QueueLock1(10)",
"QueueLock1(1)",
"QueueLock2",
@@ -483,7 +491,7 @@ async def worker(lock_like):
# Several workers queue on the same lock; make sure they each get it, in
# order.
@generic_lock_test
-async def test_generic_lock_fairness(lock_factory):
+async def test_generic_lock_fifo_fairness(lock_factory):
initial_order = []
record = []
LOOPS = 5
diff --git a/trio/tests/test_testing.py b/trio/tests/test_testing.py
index d14947a3f1..36587bce79 100644
--- a/trio/tests/test_testing.py
+++ b/trio/tests/test_testing.py
@@ -5,7 +5,9 @@
from .. import sleep
from .. import _core
+from .. import _streams
from ..testing import *
+from ..testing import _assert_raises, _UnboundedByteQueue
async def test_wait_all_tasks_blocked():
record = []
@@ -34,12 +36,14 @@ async def cancelled_while_waiting():
nursery.cancel_scope.cancel()
assert t4.result.unwrap() == "ok"
+
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock):
record = []
async def timeout_task():
record.append("tt start")
await sleep(5)
record.append("tt finished")
+
async with _core.open_nursery() as nursery:
t = nursery.spawn(timeout_task)
await wait_all_tasks_blocked()
@@ -76,6 +80,7 @@ async def wait_big_cushion():
nursery.spawn(wait_small_cushion)
nursery.spawn(wait_small_cushion)
nursery.spawn(wait_big_cushion)
+
assert record == [
"blink start",
"wait_no_cushion end",
@@ -85,6 +90,35 @@ async def wait_big_cushion():
"wait_big_cushion end",
]
+
+async def test_wait_all_tasks_blocked_with_tiebreaker():
+ record = []
+
+ async def do_wait(cushion, tiebreaker):
+ await wait_all_tasks_blocked(cushion=cushion, tiebreaker=tiebreaker)
+ record.append((cushion, tiebreaker))
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_wait, 0, 0)
+ nursery.spawn(do_wait, 0, -1)
+ nursery.spawn(do_wait, 0, 1)
+ nursery.spawn(do_wait, 0, -1)
+ nursery.spawn(do_wait, 0.0001, 10)
+ nursery.spawn(do_wait, 0.0001, -10)
+
+ assert record == sorted(record)
+ assert record == [
+ (0, -1),
+ (0, -1),
+ (0, 0),
+ (0, 1),
+ (0.0001, -10),
+ (0.0001, 10),
+ ]
+
+
+################################################################
+
async def test_assert_yields():
with assert_yields():
await _core.yield_briefly()
@@ -121,6 +155,8 @@ async def test_assert_yields():
await _core.yield_briefly()
+################################################################
+
async def test_Sequencer():
record = []
def t(val):
@@ -187,6 +223,8 @@ async def child(i):
pass # pragma: no cover
+################################################################
+
def test_mock_clock():
REAL_NOW = 123.0
c = MockClock()
@@ -266,11 +304,11 @@ async def test_mock_clock_autojump(mock_clock):
# This should wake up at the same time as the autojump_threshold, and
# confuse things. There is no deadline, so it shouldn't actually jump
# the clock. But does it handle the situation gracefully?
- await wait_all_tasks_blocked(0.02)
+ await wait_all_tasks_blocked(cushion=0.02, tiebreaker=float("inf"))
# And again with threshold=0, because that has some special
# busy-wait-avoidance logic:
mock_clock.autojump_threshold = 0
- await wait_all_tasks_blocked()
+ await wait_all_tasks_blocked(tiebreaker=float("inf"))
# set up a situation where the autojump task is blocked for a long long
# time, to make sure that cancel-and-adjust-threshold logic is working
@@ -307,3 +345,402 @@ def test_mock_clock_autojump_preset():
real_start = time.monotonic()
_core.run(sleep, 10000, clock=mock_clock)
assert time.monotonic() - real_start < 1
+
+
+async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked(mock_clock):
+ # Checks that autojump_threshold=0 doesn't interfere with
+ # calling wait_all_tasks_blocked with the default cushion=0 and arbitrary
+ # tiebreakers.
+
+ mock_clock.autojump_threshold = 0
+
+ record = []
+
+ async def sleeper():
+ await sleep(100)
+ record.append("yawn")
+
+ async def waiter():
+ for i in range(10):
+ await wait_all_tasks_blocked(tiebreaker=i)
+ record.append(i)
+ await sleep(1000)
+ record.append("waiter done")
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(sleeper)
+ nursery.spawn(waiter)
+
+ assert record == list(range(10)) + ["yawn", "waiter done"]
+
+
+################################################################
+
+async def test__assert_raises():
+ with pytest.raises(AssertionError):
+ with _assert_raises(RuntimeError):
+ 1 + 1
+
+ with pytest.raises(TypeError):
+ with _assert_raises(RuntimeError):
+ "foo" + 1
+
+ with _assert_raises(RuntimeError):
+ raise RuntimeError
+
+# This is a private implementation detail, but it's complex enough to be worth
+# testing directly
+async def test__UnboundeByteQueue():
+ ubq = _UnboundedByteQueue()
+
+ ubq.put(b"123")
+ ubq.put(b"456")
+ assert ubq.get_nowait(1) == b"1"
+ assert ubq.get_nowait(10) == b"23456"
+ ubq.put(b"789")
+ assert ubq.get_nowait() == b"789"
+
+ with pytest.raises(_core.WouldBlock):
+ ubq.get_nowait(10)
+ with pytest.raises(_core.WouldBlock):
+ ubq.get_nowait()
+
+ with pytest.raises(TypeError):
+ ubq.put("string")
+
+ ubq.put(b"abc")
+ with assert_yields():
+ assert await ubq.get(10) == b"abc"
+ ubq.put(b"def")
+ ubq.put(b"ghi")
+ with assert_yields():
+ assert await ubq.get(1) == b"d"
+ with assert_yields():
+ assert await ubq.get() == b"efghi"
+
+ async def putter(data):
+ await wait_all_tasks_blocked()
+ ubq.put(data)
+
+ async def getter(expect):
+ with assert_yields():
+ assert await ubq.get() == expect
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(getter, b"xyz")
+ nursery.spawn(putter, b"xyz")
+
+ # Two gets at the same time -> ResourceBusyError
+ with pytest.raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(getter, b"asdf")
+ nursery.spawn(getter, b"asdf")
+
+ # Closing
+
+ ubq.close()
+ with pytest.raises(_streams.ClosedStreamError):
+ ubq.put(b"---")
+
+ assert ubq.get_nowait(10) == b""
+ assert ubq.get_nowait() == b""
+ assert await ubq.get(10) == b""
+ assert await ubq.get() == b""
+
+ # close is idempotent
+ ubq.close()
+
+ # close wakes up blocked getters
+ ubq2 = _UnboundedByteQueue()
+
+ async def closer():
+ await wait_all_tasks_blocked()
+ ubq2.close()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(getter, b"")
+ nursery.spawn(closer)
+
+
+async def test_MemorySendStream():
+ mss = MemorySendStream()
+
+ async def do_send_all(data):
+ with assert_yields():
+ await mss.send_all(data)
+
+ await do_send_all(b"123")
+ assert mss.get_data_nowait(1) == b"1"
+ assert mss.get_data_nowait() == b"23"
+
+ with assert_yields():
+ await mss.wait_send_all_might_not_block()
+
+ with pytest.raises(_core.WouldBlock):
+ mss.get_data_nowait()
+ with pytest.raises(_core.WouldBlock):
+ mss.get_data_nowait(10)
+
+ await do_send_all(b"456")
+ with assert_yields():
+ assert await mss.get_data() == b"456"
+
+ with pytest.raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_send_all, b"xxx")
+ nursery.spawn(do_send_all, b"xxx")
+
+ with assert_yields():
+ await mss.graceful_close()
+
+ assert await mss.get_data() == b"xxx"
+ assert await mss.get_data() == b""
+ with pytest.raises(_streams.ClosedStreamError):
+ await do_send_all(b"---")
+
+ # hooks
+
+ assert mss.send_all_hook is None
+ assert mss.wait_send_all_might_not_block_hook is None
+ assert mss.close_hook is None
+
+ record = []
+ async def send_all_hook():
+ # hook runs after send_all does its work (can pull data out)
+ assert mss2.get_data_nowait() == b"abc"
+ record.append("send_all_hook")
+ async def wait_send_all_might_not_block_hook():
+ record.append("wait_send_all_might_not_block_hook")
+ def close_hook():
+ record.append("close_hook")
+
+ mss2 = MemorySendStream(
+ send_all_hook,
+ wait_send_all_might_not_block_hook,
+ close_hook,
+ )
+
+ assert mss2.send_all_hook is send_all_hook
+ assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
+ assert mss2.close_hook is close_hook
+
+ await mss2.send_all(b"abc")
+ await mss2.wait_send_all_might_not_block()
+ mss2.forceful_close()
+
+ assert record == [
+ "send_all_hook",
+ "wait_send_all_might_not_block_hook",
+ "close_hook",
+ ]
+
+
+async def test_MemoryRecieveStream():
+ mrs = MemoryReceiveStream()
+
+ async def do_receive_some(max_bytes):
+ with assert_yields():
+ return await mrs.receive_some(max_bytes)
+
+ mrs.put_data(b"abc")
+ assert await do_receive_some(1) == b"a"
+ assert await do_receive_some(10) == b"bc"
+ with pytest.raises(TypeError):
+ await do_receive_some(None)
+
+ with pytest.raises(_core.ResourceBusyError):
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(do_receive_some, 10)
+ nursery.spawn(do_receive_some, 10)
+
+ assert mrs.receive_some_hook is None
+
+ mrs.put_data(b"def")
+ mrs.put_eof()
+ mrs.put_eof()
+
+ assert await do_receive_some(10) == b"def"
+ assert await do_receive_some(10) == b""
+ assert await do_receive_some(10) == b""
+
+ with pytest.raises(_streams.ClosedStreamError):
+ mrs.put_data(b"---")
+
+ async def receive_some_hook():
+ mrs2.put_data(b"xxx")
+
+ record = []
+ def close_hook():
+ record.append("closed")
+
+ mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
+ assert mrs2.receive_some_hook is receive_some_hook
+ assert mrs2.close_hook is close_hook
+
+ mrs2.put_data(b"yyy")
+ assert await mrs2.receive_some(10) == b"yyyxxx"
+ assert await mrs2.receive_some(10) == b"xxx"
+ assert await mrs2.receive_some(10) == b"xxx"
+
+ mrs2.put_data(b"zzz")
+ mrs2.receive_some_hook = None
+ assert await mrs2.receive_some(10) == b"zzz"
+
+ mrs2.put_data(b"lost on close")
+ with assert_yields():
+ await mrs2.graceful_close()
+ assert record == ["closed"]
+
+ with pytest.raises(_streams.ClosedStreamError):
+ await mrs2.receive_some(10)
+
+
+async def test_MemoryRecvStream_closing():
+ mrs = MemoryReceiveStream()
+ # close with no pending data
+ mrs.forceful_close()
+ with pytest.raises(_streams.ClosedStreamError):
+ assert await mrs.receive_some(10) == b""
+ # repeated closes ok
+ mrs.forceful_close()
+ # put_data now fails
+ with pytest.raises(_streams.ClosedStreamError):
+ mrs.put_data(b"123")
+
+ mrs2 = MemoryReceiveStream()
+ # close with pending data
+ mrs2.put_data(b"xyz")
+ mrs2.forceful_close()
+ with pytest.raises(_streams.ClosedStreamError):
+ await mrs2.receive_some(10)
+
+
+async def test_memory_stream_pump():
+ mss = MemorySendStream()
+ mrs = MemoryReceiveStream()
+
+ # no-op if no data present
+ memory_stream_pump(mss, mrs)
+
+ await mss.send_all(b"123")
+ memory_stream_pump(mss, mrs)
+ assert await mrs.receive_some(10) == b"123"
+
+ await mss.send_all(b"456")
+ assert memory_stream_pump(mss, mrs, max_bytes=1)
+ assert await mrs.receive_some(10) == b"4"
+ assert memory_stream_pump(mss, mrs, max_bytes=1)
+ assert memory_stream_pump(mss, mrs, max_bytes=1)
+ assert not memory_stream_pump(mss, mrs, max_bytes=1)
+ assert await mrs.receive_some(10) == b"56"
+
+ mss.forceful_close()
+ memory_stream_pump(mss, mrs)
+ assert await mrs.receive_some(10) == b""
+
+
+async def test_memory_stream_one_way_pair():
+ s, r = memory_stream_one_way_pair()
+ assert s.send_all_hook is not None
+ assert s.wait_send_all_might_not_block_hook is None
+ assert s.close_hook is not None
+ assert r.receive_some_hook is None
+ await s.send_all(b"123")
+ assert await r.receive_some(10) == b"123"
+
+ # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
+ async def sender():
+ await wait_all_tasks_blocked()
+ await s.send_all(b"abc")
+
+ async def receiver(expected):
+ assert await r.receive_some(10) == expected
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(receiver, b"abc")
+ nursery.spawn(sender)
+
+ # And this fails if we don't pump from close_hook
+ async def graceful_closer():
+ await wait_all_tasks_blocked()
+ await s.graceful_close()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(receiver, b"")
+ nursery.spawn(graceful_closer)
+
+ s, r = memory_stream_one_way_pair()
+
+ async def forceful_closer():
+ await wait_all_tasks_blocked()
+ s.forceful_close()
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(receiver, b"")
+ nursery.spawn(forceful_closer)
+
+ s, r = memory_stream_one_way_pair()
+
+ old = s.send_all_hook
+ s.send_all_hook = None
+ await s.send_all(b"456")
+
+ async def cancel_after_idle(nursery):
+ await wait_all_tasks_blocked()
+ nursery.cancel_scope.cancel()
+
+ async def check_for_cancel():
+ with pytest.raises(_core.Cancelled):
+ # This should block forever... or until cancelled. Even though we
+ # sent some data on the send stream.
+ await r.receive_some(10)
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(cancel_after_idle, nursery)
+ nursery.spawn(check_for_cancel)
+
+ s.send_all_hook = old
+ await s.send_all(b"789")
+ assert await r.receive_some(10) == b"456789"
+
+
+async def test_memory_stream_pair():
+ a, b = memory_stream_pair()
+ await a.send_all(b"123")
+ await b.send_all(b"abc")
+ assert await b.receive_some(10) == b"123"
+ assert await a.receive_some(10) == b"abc"
+
+ await a.send_eof()
+ assert await b.receive_some(10) == b""
+
+ async def sender():
+ await wait_all_tasks_blocked()
+ await b.send_all(b"xyz")
+
+ async def receiver():
+ assert await a.receive_some(10) == b"xyz"
+
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(receiver)
+ nursery.spawn(sender)
+
+
+async def test_memory_streams_with_generic_tests():
+ async def one_way_stream_maker():
+ return memory_stream_one_way_pair()
+ await check_one_way_stream(one_way_stream_maker, None)
+
+ async def half_closeable_stream_maker():
+ return memory_stream_pair()
+ await check_half_closeable_stream(half_closeable_stream_maker, None)
+
+
+async def test_lockstep_streams_with_generic_tests():
+ async def one_way_stream_maker():
+ return lockstep_stream_one_way_pair()
+ await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
+
+ async def two_way_stream_maker():
+ return lockstep_stream_pair()
+ await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py
index 6f39beeca8..00b0907d32 100644
--- a/trio/tests/test_util.py
+++ b/trio/tests/test_util.py
@@ -2,16 +2,59 @@
import signal
-from .. import _util
+from .._util import *
+from .. import _core
+from ..testing import wait_all_tasks_blocked, assert_yields
-async def test_signal_raise():
+def test_signal_raise():
record = []
def handler(signum, _):
record.append(signum)
old = signal.signal(signal.SIGFPE, handler)
try:
- _util.signal_raise(signal.SIGFPE)
+ signal_raise(signal.SIGFPE)
finally:
signal.signal(signal.SIGFPE, old)
assert record == [signal.SIGFPE]
+
+
+async def test_UnLock():
+ ul1 = UnLock(RuntimeError, "ul1")
+ ul2 = UnLock(ValueError)
+
+ async with ul1:
+ with assert_yields():
+ async with ul2:
+ print("ok")
+
+ with pytest.raises(RuntimeError) as excinfo:
+ async with ul1:
+ with assert_yields():
+ async with ul1:
+ pass # pragma: no cover
+ assert "ul1" in str(excinfo.value)
+
+ with pytest.raises(ValueError) as excinfo:
+ async with ul2:
+ with assert_yields():
+ async with ul2:
+ pass # pragma: no cover
+
+ async def wait_with_ul1():
+ async with ul1:
+ await wait_all_tasks_blocked()
+
+ with pytest.raises(RuntimeError) as excinfo:
+ async with _core.open_nursery() as nursery:
+ nursery.spawn(wait_with_ul1)
+ nursery.spawn(wait_with_ul1)
+ assert "ul1" in str(excinfo.value)
+
+ # mixing sync and async entry
+ with pytest.raises(RuntimeError) as excinfo:
+ with ul1.sync:
+ with assert_yields():
+ async with ul1:
+ pass # pragma: no cover
+ assert "ul1" in str(excinfo.value)