diff --git a/.builds/alpine.yml b/.builds/alpine.yml new file mode 100644 index 0000000000..d4c60f6d7d --- /dev/null +++ b/.builds/alpine.yml @@ -0,0 +1,22 @@ +image: alpine/latest +packages: + - curl + - gcc + - libffi-dev + - musl-dev + - openssl-dev + - python3-dev + # required to build cryptography + - rust + - cargo +sources: + - https://github.com/python-trio/trio +tasks: + - test: | + python3 -m venv venv + source venv/bin/activate + cd trio + CI_BUILD_ID=$JOB_ID CI_BUILD_URL=$JOB_URL ./ci.sh +environment: + CODECOV_TOKEN: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + JOB_NAME: Alpine diff --git a/.builds/fedora.yml b/.builds/fedora.yml new file mode 100644 index 0000000000..eddb2368a0 --- /dev/null +++ b/.builds/fedora.yml @@ -0,0 +1,15 @@ +image: fedora/rawhide +packages: + - python3-devel + - python3-pip +sources: + - https://github.com/python-trio/trio +tasks: + - test: | + python3 -m venv venv + source venv/bin/activate + cd trio + CI_BUILD_ID=$JOB_ID CI_BUILD_URL=$JOB_URL ./ci.sh +environment: + CODECOV_TOKEN: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + JOB_NAME: Fedora diff --git a/.builds/freebsd.yml b/.builds/freebsd.yml new file mode 100644 index 0000000000..2f55cac0e9 --- /dev/null +++ b/.builds/freebsd.yml @@ -0,0 +1,18 @@ +image: freebsd/latest +packages: + - curl + - python39 + - py39-sqlite3 + - rust # required to build cryptography +sources: + - https://github.com/python-trio/trio +tasks: + - setup: sudo ln -s /usr/local/bin/bash /bin/bash + - test: | + python3.9 -m venv venv + source venv/bin/activate + cd trio + CI_BUILD_ID=$JOB_ID CI_BUILD_URL=$JOB_URL ./ci.sh +environment: + CODECOV_TOKEN: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + JOB_NAME: FreeBSD diff --git a/.coveragerc b/.coveragerc index 4bdbf69131..d577aa8adf 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,13 +1,15 @@ [run] branch=True source=trio -# For some reason coverage recording doesn't work for ipython_custom_exc.py, -# so leave it out of reports omit= setup.py - */ipython_custom_exc.py -# 
Omit the generated files in trio/_core starting with _public_ +# These are run in subprocesses, but still don't work. We follow +# coverage's documentation to no avail. + */trio/_core/_tests/test_multierror_scripts/* +# Omit the generated files in trio/_core starting with _generated_ */trio/_core/_generated_* +# Script used to check type completeness that isn't run in tests + */trio/_tests/check_type_completeness.py # The test suite spawns subprocesses to test some stuff, so make sure # this doesn't corrupt the coverage files parallel=True @@ -17,3 +19,13 @@ precision = 1 exclude_lines = pragma: no cover abc.abstractmethod + if TYPE_CHECKING: + if _t.TYPE_CHECKING: + @overload + +partial_branches = + pragma: no branch + if not TYPE_CHECKING: + if not _t.TYPE_CHECKING: + if .* or not TYPE_CHECKING: + if .* or not _t.TYPE_CHECKING: diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000000..1d3079ad5a --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# sorting all imports with isort +933f77b96f0092e1baab4474a9208fc2e379aa32 diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..7fbcb4fe2d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,6 @@ +# For files generated by trio/_tools/gen_exports.py +trio/_core/_generated* linguist-generated=true +# Treat generated files as binary in git diff +trio/_core/_generated* -diff +# don't merge the generated json file, let the user (script) handle it +trio/_tests/verify_types.json merge=binary diff --git a/.github/workflows/autodeps.yml b/.github/workflows/autodeps.yml new file mode 100644 index 0000000000..40cf05726c --- /dev/null +++ b/.github/workflows/autodeps.yml @@ -0,0 +1,82 @@ +name: Autodeps + +on: + workflow_dispatch: + schedule: + - cron: '0 0 1 * *' + +jobs: + Autodeps: + name: Autodeps + timeout-minutes: 10 + runs-on: 'ubuntu-latest' + # 
https://docs.github.com/en/code-security/dependabot/working-with-dependabot/automating-dependabot-with-github-actions#changing-github_token-permissions + permissions: + pull-requests: write + issues: write + repository-projects: write + contents: write + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup python + uses: actions/setup-python@v4 + with: + python-version: "3.8" + - name: Bump dependencies + run: | + python -m pip install -U pip + python -m pip install -r test-requirements.txt + pip-compile -U test-requirements.in + pip-compile -U docs-requirements.in + - name: Black + run: | + # The new dependencies may contain a new black version. + # Commit any changes immediately. + python -m pip install -r test-requirements.txt + black setup.py trio + - name: Commit changes and create automerge PR + env: + GH_TOKEN: ${{ github.token }} + run: | + # setup git repo + git switch --force-create autodeps/bump_from_${GITHUB_SHA:0:6} + git config user.name 'github-actions[bot]' + git config user.email '41898282+github-actions[bot]@users.noreply.github.com' + + if ! git commit -am "Dependency updates"; then + echo "No changes to commit!" 
+ exit 0 + fi + + git push --force --set-upstream origin autodeps/bump_from_${GITHUB_SHA:0:6} + + # git push returns before github is ready for a pr, so we poll until success + for BACKOFF in 1 2 4 8 0; do + sleep $BACKOFF + if gh pr create \ + --label dependencies --body "" \ + --title "Bump dependencies from commit ${GITHUB_SHA:0:6}" \ + ; then + break + fi + done + + if [ $BACKOFF -eq 0 ]; then + echo "Could not create the PR" + exit 1 + fi + + # gh pr create returns before the pr is ready, so we again poll until success + # https://github.com/cli/cli/issues/2619#issuecomment-1240543096 + for BACKOFF in 1 2 4 8 0; do + sleep $BACKOFF + if gh pr merge --auto --squash; then + break + fi + done + + if [ $BACKOFF -eq 0 ]; then + echo "Could not set automerge" + exit 1 + fi diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000000..40af0960f5 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,193 @@ +name: CI + +on: + push: + branches-ignore: + - "dependabot/**" + pull_request: + +concurrency: + group: ${{ github.ref }}-${{ github.workflow }}-${{ github.event_name }}${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) && format('-{0}', github.sha) || '' }} + cancel-in-progress: true + +jobs: + Windows: + name: 'Windows (${{ matrix.python }}, ${{ matrix.arch }}${{ matrix.extra_name }})' + timeout-minutes: 20 + runs-on: 'windows-latest' + strategy: + fail-fast: false + matrix: + # pypy-3.10 is failing, see https://github.com/python-trio/trio/issues/2678 + python: ['3.8', '3.9', '3.10', 'pypy-3.9-nightly'] #, 'pypy-3.10-nightly'] + arch: ['x86', 'x64'] + lsp: [''] + lsp_extract_file: [''] + extra_name: [''] + exclude: + # pypy does not release 32-bit binaries + - python: 'pypy-3.9-nightly' + arch: 'x86' + #- python: 'pypy-3.10-nightly' + # arch: 'x86' + include: + - python: '3.8' + arch: 'x64' + lsp: 
'https://raw.githubusercontent.com/python-trio/trio-ci-assets/master/komodia-based-vpn-setup.zip' + lsp_extract_file: 'komodia-based-vpn-setup.exe' + extra_name: ', with Komodia LSP' + - python: '3.8' + arch: 'x64' + lsp: 'https://www.proxifier.com/download/legacy/ProxifierSetup342.exe' + lsp_extract_file: '' + extra_name: ', with IFS LSP' + #- python: '3.8' + # arch: 'x64' + # lsp: 'http://download.pctools.com/mirror/updates/9.0.0.2308-SDavfree-lite_en.exe' + # lsp_extract_file: '' + # extra_name: ', with non-IFS LSP' + continue-on-error: >- + ${{ + ( + endsWith(matrix.python, '-dev') + || endsWith(matrix.python, '-nightly') + ) + && true + || false + }} + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup python + uses: actions/setup-python@v4 + with: + # This allows the matrix to specify just the major.minor version while still + # expanding it to get the latest patch version including alpha releases. + # This avoids the need to update for each new alpha, beta, release candidate, + # and then finally an actual release version. actions/setup-python doesn't + # support this for PyPy presently so we get no help there. 
+ # + # 'CPython' -> '3.9.0-alpha - 3.9.X' + # 'PyPy' -> 'pypy-3.9' + python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} + architecture: '${{ matrix.arch }}' + cache: pip + cache-dependency-path: test-requirements.txt + - name: Run tests + run: ./ci.sh + shell: bash + env: + LSP: '${{ matrix.lsp }}' + LSP_EXTRACT_FILE: '${{ matrix.lsp_extract_file }}' + - if: always() + uses: codecov/codecov-action@v3 + with: + directory: empty + token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + name: Windows (${{ matrix.python }}, ${{ matrix.arch }}${{ matrix.extra_name }}) + flags: Windows,${{ matrix.python }} + + Ubuntu: + name: 'Ubuntu (${{ matrix.python }}${{ matrix.extra_name }})' + timeout-minutes: 10 + runs-on: 'ubuntu-latest' + strategy: + fail-fast: false + matrix: + python: ['pypy-3.9', 'pypy-3.10', '3.8', '3.9', '3.10', '3.11', '3.12-dev', 'pypy-3.9-nightly', 'pypy-3.10-nightly'] + check_formatting: ['0'] + extra_name: [''] + include: + - python: '3.8' + check_formatting: '1' + extra_name: ', check formatting' + continue-on-error: >- + ${{ + ( + matrix.check_formatting == '1' + || endsWith(matrix.python, '-dev') + || endsWith(matrix.python, '-nightly') + ) + && true + || false + }} + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup python + uses: actions/setup-python@v4 + if: "!endsWith(matrix.python, '-dev')" + with: + python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} + cache: pip + cache-dependency-path: test-requirements.txt + - name: Setup python (dev) + uses: deadsnakes/action@v2.0.2 + if: endsWith(matrix.python, '-dev') + with: + python-version: '${{ matrix.python }}' + - name: Run tests + run: ./ci.sh + env: + CHECK_FORMATTING: '${{ matrix.check_formatting }}' + - if: always() + uses: codecov/codecov-action@v3 + with: + directory: empty + 
token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + name: Ubuntu (${{ matrix.python }}${{ matrix.extra_name }}) + flags: Ubuntu,${{ matrix.python }} + + macOS: + name: 'macOS (${{ matrix.python }})' + timeout-minutes: 15 + runs-on: 'macos-latest' + strategy: + fail-fast: false + matrix: + python: ['3.8', '3.9', '3.10', 'pypy-3.9-nightly', 'pypy-3.10-nightly'] + continue-on-error: >- + ${{ + ( + endsWith(matrix.python, '-dev') + || endsWith(matrix.python, '-nightly') + ) + && true + || false + }} + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup python + uses: actions/setup-python@v4 + with: + python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} + cache: pip + cache-dependency-path: test-requirements.txt + - name: Run tests + run: ./ci.sh + - if: always() + uses: codecov/codecov-action@v3 + with: + directory: empty + token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + name: macOS (${{ matrix.python }}) + flags: macOS,${{ matrix.python }} + + # https://github.com/marketplace/actions/alls-green#why + check: # This job does nothing and is only used for the branch protection + + if: always() + + needs: + - Windows + - Ubuntu + - macOS + + runs-on: ubuntu-latest + + steps: + - name: Decide whether the needed jobs succeeded or failed + uses: re-actors/alls-green@release/v1 + with: + jobs: ${{ toJSON(needs) }} diff --git a/.gitignore b/.gitignore index a50c10b8a3..057e28568e 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ pip-log.txt htmlcov/ .tox/ .venv/ +pyvenv.cfg .coverage .coverage.* .cache diff --git a/.readthedocs.yml b/.readthedocs.yml index 909ccf1bb3..9fde00ef8f 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -5,7 +5,14 @@ formats: - htmlzip - epub +build: + os: "ubuntu-22.04" + tools: + python: "3.11" + python: - version: 3.7 install: - requirements: docs-requirements.txt + +sphinx: + fail_on_warning: true diff --git a/.style.yapf 
b/.style.yapf deleted file mode 100644 index 745ec88d70..0000000000 --- a/.style.yapf +++ /dev/null @@ -1,185 +0,0 @@ -[style] -# Align closing bracket with visual indentation. -align_closing_bracket_with_visual_indent=True - -# Allow dictionary keys to exist on multiple lines. For example: -# -# x = { -# ('this is the first element of a tuple', -# 'this is the second element of a tuple'): -# value, -# } -allow_multiline_dictionary_keys=False - -# Allow lambdas to be formatted on more than one line. -allow_multiline_lambdas=False - -# Insert a blank line before a class-level docstring. -blank_line_before_class_docstring=False - -# Insert a blank line before a 'def' or 'class' immediately nested -# within another 'def' or 'class'. For example: -# -# class Foo: -# # <------ this blank line -# def method(): -# ... -blank_line_before_nested_class_or_def=False - -# Do not split consecutive brackets. Only relevant when -# dedent_closing_brackets is set. For example: -# -# call_func_that_takes_a_dict( -# { -# 'key1': 'value1', -# 'key2': 'value2', -# } -# ) -# -# would reformat to: -# -# call_func_that_takes_a_dict({ -# 'key1': 'value1', -# 'key2': 'value2', -# }) -coalesce_brackets=False - -# The column limit. -column_limit=79 - -# Indent width used for line continuations. -continuation_indent_width=4 - -# Put closing brackets on a separate line, dedented, if the bracketed -# expression can't fit in a single line. Applies to all kinds of brackets, -# including function definitions and calls. 
For example: -# -# config = { -# 'key1': 'value1', -# 'key2': 'value2', -# } # <--- this bracket is dedented and on a separate line -# -# time_series = self.remote_client.query_entity_counters( -# entity='dev3246.region1', -# key='dns.query_latency_tcp', -# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), -# start_ts=now()-timedelta(days=3), -# end_ts=now(), -# ) # <--- this bracket is dedented and on a separate line -dedent_closing_brackets=True - -# Place each dictionary entry onto its own line. -each_dict_entry_on_separate_line=True - -# The regex for an i18n comment. The presence of this comment stops -# reformatting of that line, because the comments are required to be -# next to the string they translate. -i18n_comment= - -# The i18n function call names. The presence of this function stops -# reformattting on that line, because the string it has cannot be moved -# away from the i18n comment. -i18n_function_call= - -# Indent the dictionary value if it cannot fit on the same line as the -# dictionary key. For example: -# -# config = { -# 'key1': -# 'value1', -# 'key2': value1 + -# value2, -# } -indent_dictionary_value=True - -# The number of columns to use for indentation. -indent_width=4 - -# Join short lines into one line. E.g., single line 'if' statements. -join_multiple_lines=False - -# Use spaces around default or named assigns. -spaces_around_default_or_named_assign=False - -# Use spaces around the power operator. -spaces_around_power_operator=False - -# The number of spaces required before a trailing comment. -spaces_before_comment=2 - -# Insert a space between the ending comma and closing bracket of a list, -# etc. -space_between_ending_comma_and_closing_bracket=False - -# Split before arguments if the argument list is terminated by a -# comma. -split_arguments_when_comma_terminated=True - -# Set to True to prefer splitting before '&', '|' or '^' rather than -# after. 
-split_before_bitwise_operator=True - -# Split before a dictionary or set generator (comp_for). For example, note -# the split before the 'for': -# -# foo = { -# variable: 'Hello world, have a nice day!' -# for variable in bar if variable != 42 -# } -split_before_dict_set_generator=True - -# If an argument / parameter list is going to be split, then split before -# the first argument. -split_before_first_argument=False - -# Set to True to prefer splitting before 'and' or 'or' rather than -# after. -split_before_logical_operator=True - -# Split named assignments onto individual lines. -split_before_named_assigns=True - -# The penalty for splitting right after the opening bracket. -split_penalty_after_opening_bracket=30 - -# The penalty for splitting the line after a unary operator. -split_penalty_after_unary_operator=10000 - -# The penalty for splitting right before an if expression. -split_penalty_before_if_expr=0 - -# The penalty of splitting the line around the '&', '|', and '^' -# operators. -split_penalty_bitwise_operator=300 - -# The penalty for characters over the column limit. -split_penalty_excess_character=4500 - -# The penalty incurred by adding a line split to the unwrapped line. The -# more line splits added the higher the penalty. -split_penalty_for_added_line_split=30 - -# The penalty of splitting a list of "import as" names. For example: -# -# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, -# long_argument_2, -# long_argument_3) -# -# would reformat to something like: -# -# from a_very_long_or_indented_module_name_yada_yad import ( -# long_argument_1, long_argument_2, long_argument_3) -split_penalty_import_names=0 - -# The penalty of splitting the line around the 'and' and 'or' -# operators. -split_penalty_logical_operator=0 - -# Use the Tab character for indentation. -use_tabs=False - -# Without this, yapf likes to write things like -# "foo bar {}". -# format(...) -# which is just awful. 
-split_before_dot=True diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 685b4a8e93..0000000000 --- a/.travis.yml +++ /dev/null @@ -1,41 +0,0 @@ -os: linux -language: python -dist: bionic - -jobs: - include: - # The pypy tests are slow, so we list them first - - python: pypy3.6-7.2.0 - - language: generic - env: PYPY_NIGHTLY_BRANCH=py3.6 - # Qemu tests are also slow - # The unique thing this provides is testing on the given distro's - # kernel, which is important when we use new kernel features. This - # is also good for testing the latest openssl etc., and getting - # early warning of any issues that might happen in the next Ubuntu - # LTS. - - language: generic - # We use bionic for the host, b/c rumor says that Travis's - # 'bionic' systems have nested KVM enabled. - dist: bionic - env: - - "JOB_NAME='Ubuntu 19.10, full VM'" - - "VM_IMAGE=https://cloud-images.ubuntu.com/eoan/current/eoan-server-cloudimg-amd64.img" - # 3.5.0 and 3.5.1 have different __aiter__ semantics than all - # other versions, so we need to test them specially. Travis's - # newer images only provide 3.5.2+, so we have to request the old - # 'trusty' images. - - python: 3.5.0 - dist: trusty - - python: 3.5-dev - - python: 3.6-dev - - python: 3.7-dev - - python: 3.8-dev - - python: nightly - -script: - - ./ci.sh - -branches: - except: - - /^dependabot/.*/ diff --git a/.yapfignore b/.yapfignore deleted file mode 100644 index 677b760f4c..0000000000 --- a/.yapfignore +++ /dev/null @@ -1 +0,0 @@ -**/_generated* diff --git a/LICENSE.MIT b/LICENSE.MIT index b8bb971859..c26b9f32ae 100644 --- a/LICENSE.MIT +++ b/LICENSE.MIT @@ -1,3 +1,5 @@ +Copyright Contributors to the Trio project. 
+ The MIT License (MIT) Permission is hereby granted, free of charge, to any person obtaining diff --git a/MANIFEST.in b/MANIFEST.in index e2fd4c157f..8b92523fb7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,6 +2,6 @@ include LICENSE LICENSE.MIT LICENSE.APACHE2 include README.rst include CODE_OF_CONDUCT.md CONTRIBUTING.md include test-requirements.txt -recursive-include trio/tests/test_ssl_certs *.pem +recursive-include trio/_tests/test_ssl_certs *.pem recursive-include docs * prune docs/build diff --git a/README.rst b/README.rst index 4fe3ee5ebe..016823e1f5 100644 --- a/README.rst +++ b/README.rst @@ -9,14 +9,14 @@ .. image:: https://img.shields.io/badge/docs-read%20now-blue.svg :target: https://trio.readthedocs.io :alt: Documentation - + .. image:: https://img.shields.io/pypi/v/trio.svg :target: https://pypi.org/project/trio :alt: Latest PyPi version .. image:: https://img.shields.io/conda/vn/conda-forge/trio.svg :target: https://anaconda.org/conda-forge/trio - :alt: Latest conda-forge version + :alt: Latest conda-forge version .. image:: https://codecov.io/gh/python-trio/trio/branch/master/graph/badge.svg :target: https://codecov.io/gh/python-trio/trio @@ -25,28 +25,18 @@ Trio – a friendly Python library for async concurrency and I/O ============================================================== -.. Github carefully breaks rendering of SVG directly out of the repo, - so we have to redirect through cdn.rawgit.com - See: - https://github.com/isaacs/github/issues/316 - https://github.com/github/markup/issues/556#issuecomment-288581799 - I also tried rendering to PNG and linking to that locally, which - "works" in that it displays the image, but for some reason it - ignores the width and align directives, so it's actually pretty - useless... - -.. image:: https://cdn.rawgit.com/python-trio/trio/9b0bec646a31e0d0f67b8b6ecc6939726faf3e17/logo/logo-with-background.svg +.. 
image:: https://raw.githubusercontent.com/python-trio/trio/9b0bec646a31e0d0f67b8b6ecc6939726faf3e17/logo/logo-with-background.svg :width: 200px :align: right -The Trio project's goal is to produce a production-quality, +The Trio project aims to produce a production-quality, `permissively licensed `__, async/await-native I/O library for Python. Like all async libraries, its main purpose is to help you write programs that do **multiple things at the same time** with **parallelized I/O**. A web spider that wants to fetch lots of pages in parallel, a web server that needs to -juggle lots of downloads and websocket connections at the same time, a +juggle lots of downloads and websocket connections simultaneously, a process supervisor monitoring multiple subprocesses... that sort of thing. Compared to other libraries, Trio attempts to distinguish itself with an obsessive focus on **usability** and @@ -67,11 +57,11 @@ fun. `Perhaps you'll find the same `__. This project is young and still somewhat experimental: the overall -design is solid and the existing features are fully tested and +design is solid, and the existing features are fully tested and documented, but you may encounter missing functionality or rough edges. We *do* encourage you to use it, but you should `read and subscribe to issue #1 -`__ to get warning and a +`__ to get a warning and a chance to give feedback about any compatibility-breaking changes. @@ -102,14 +92,15 @@ demonstration of implementing the "Happy Eyeballs" algorithm in an older library versus Trio. **Cool, but will it work on my system?** Probably! As long as you have -some kind of Python 3.5-or-better (CPython or the latest PyPy3 are -both fine), and are using Linux, macOS, or Windows, then Trio should -absolutely work. *BSD and illumos likely work too, but we don't have -testing infrastructure for them. And all of our dependencies are pure -Python, except for CFFI on Windows, and that has wheels available, so -installation should be easy. 
- -**I tried it but it's not working.** Sorry to hear that! You can try +some kind of Python 3.8-or-better (CPython or [currently maintained versions of +PyPy3](https://doc.pypy.org/en/latest/faq.html#which-python-versions-does-pypy-implement) +are both fine), and are using Linux, macOS, Windows, or FreeBSD, then Trio +will work. Other environments might work too, but those +are the ones we test on. And all of our dependencies are pure Python, +except for CFFI on Windows, which has wheels available, so +installation should be easy (no C compiler needed). + +**I tried it, but it's not working.** Sorry to hear that! You can try asking for help in our `chat room `__ or `forum `__, `filing a bug @@ -118,7 +109,7 @@ question on StackOverflow `__, and we'll do our best to help you out. -**Trio is awesome and I want to help make it more awesome!** You're +**Trio is awesome, and I want to help make it more awesome!** You're the best! There's tons of work to do – filling in missing functionality, building up an ecosystem of Trio-using libraries, usability testing (e.g., maybe try teaching yourself or a friend to diff --git a/azure-pipelines.yml b/azure-pipelines.yml deleted file mode 100644 index dabc5cb2d2..0000000000 --- a/azure-pipelines.yml +++ /dev/null @@ -1,126 +0,0 @@ -trigger: - branches: - exclude: - - 'dependabot/*' - -jobs: - -- job: 'Windows' - pool: - vmImage: 'windows-latest' - timeoutInMinutes: 20 - strategy: - # Python version list: - # 64-bit: https://www.nuget.org/packages/python/ - # 32-bit: https://www.nuget.org/packages/pythonx86/ - matrix: - # The LSP tests can be super slow for some reason - like - # sometimes it just randomly takes 5 minutes to run the LSP - # installer. So we put them at the top, so they can get started - # earlier. 
- "with IFS LSP, Python 3.7, 64 bit": - python.version: '3.7.5' - python.pkg: 'python' - lsp: 'http://www.proxifier.com/download/ProxifierSetup.exe' - "with non-IFS LSP, Python 3.7, 64 bit": - python.version: '3.7.5' - python.pkg: 'python' - lsp: 'http://download.pctools.com/mirror/updates/9.0.0.2308-SDavfree-lite_en.exe' - "Python 3.5, 32 bit": - python.version: '3.5.4' - python.pkg: 'pythonx86' - "Python 3.5, 64 bit": - python.version: '3.5.4' - python.pkg: 'python' - "Python 3.6, 32 bit": - python.version: '3.6.8' - python.pkg: 'pythonx86' - "Python 3.6, 64 bit": - python.version: '3.6.8' - python.pkg: 'python' - "Python 3.7, 32 bit": - python.version: '3.7.5' - python.pkg: 'pythonx86' - "Python 3.7, 64 bit": - python.version: '3.7.5' - python.pkg: 'python' - "Python 3.8, 32 bit": - python.version: '3.8.0' - python.pkg: 'pythonx86' - "Python 3.8, 64 bit": - python.version: '3.8.0' - python.pkg: 'python' - - steps: - - task: NuGetToolInstaller@0 - - - bash: ./ci.sh - displayName: "Run the actual tests" - - - task: PublishTestResults@2 - inputs: - testResultsFiles: 'test-results.xml' - testRunTitle: 'Windows $(python.pkg) $(python.version)' - condition: succeededOrFailed() - -- job: 'Linux' - pool: - vmImage: 'ubuntu-latest' - timeoutInMinutes: 10 - strategy: - matrix: - "Check docs": - python.version: '3.8' - CHECK_DOCS: 1 - "Formatting and linting": - python.version: '3.8' - CHECK_FORMATTING: 1 - "Python 3.5": - python.version: '3.5' - "Python 3.6": - python.version: '3.6' - "Python 3.7": - python.version: '3.7' - "Python 3.8": - python.version: '3.8' - - steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: '$(python.version)' - - - bash: ./ci.sh - displayName: "Run the actual tests" - - - task: PublishTestResults@2 - inputs: - testResultsFiles: 'test-results.xml' - condition: succeededOrFailed() - -- job: 'macOS' - pool: - vmImage: 'macOS-latest' - timeoutInMinutes: 10 - strategy: - matrix: - "Python 3.5": - python.version: '3.5' - "Python 3.6": - 
python.version: '3.6' - "Python 3.7": - python.version: '3.7' - "Python 3.8": - python.version: '3.8' - - steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: '$(python.version)' - - - bash: ./ci.sh - displayName: "Run the actual tests" - - - task: PublishTestResults@2 - inputs: - testResultsFiles: 'test-results.xml' - condition: succeededOrFailed() diff --git a/check.sh b/check.sh index 0b66ca227b..a0efa531b6 100755 --- a/check.sh +++ b/check.sh @@ -13,13 +13,42 @@ python ./trio/_tools/gen_exports.py --test \ # see https://forum.bors.tech/t/pre-test-and-pre-merge-hooks/322) # autoflake --recursive --in-place . # pyupgrade --py3-plus $(find . -name "*.py") -yapf -rpd setup.py trio \ - || EXIT_STATUS=$? +if ! black --check setup.py trio; then + EXIT_STATUS=1 + black --diff setup.py trio +fi -# Run flake8 without pycodestyle and import-related errors -flake8 trio/ \ - --ignore=D,E,W,F401,F403,F405,F821,F822\ - || EXIT_STATUS=$? +if ! isort --check setup.py trio; then + EXIT_STATUS=1 + isort --diff setup.py trio +fi + +# Run flake8, configured in pyproject.toml +flake8 trio/ || EXIT_STATUS=$? + +# Run mypy on all supported platforms +mypy -m trio -m trio.testing --platform linux || EXIT_STATUS=$? +mypy -m trio -m trio.testing --platform darwin || EXIT_STATUS=$? # tests FreeBSD too +mypy -m trio -m trio.testing --platform win32 || EXIT_STATUS=$? + +# Check pip compile is consistent +pip-compile test-requirements.in +pip-compile docs-requirements.in + +if git status --porcelain | grep -q "requirements.txt"; then + git status --porcelain + git --no-pager diff --color *requirements.txt + EXIT_STATUS=1 +fi + +codespell || EXIT_STATUS=$? + +python trio/_tests/check_type_completeness.py --overwrite-file || EXIT_STATUS=$? +if git status --porcelain trio/_tests/verify_types.json | grep -q "M"; then + echo "Type completeness changed, please update!" 
+ git --no-pager diff --color trio/_tests/verify_types.json + EXIT_STATUS=1 +fi # Finally, leave a really clear warning of any issues and exit if [ $EXIT_STATUS -ne 0 ]; then @@ -31,7 +60,8 @@ Problems were found by static analysis (listed above). To fix formatting and see remaining errors, run pip install -r test-requirements.txt - yapf -rpi setup.py trio + black setup.py trio + isort setup.py trio ./check.sh in your local checkout. diff --git a/ci.sh b/ci.sh index 63f1355b3e..ed97ff738b 100755 --- a/ci.sh +++ b/ci.sh @@ -2,18 +2,14 @@ set -ex -o pipefail +# disable warnings about pyright being out of date +# used in test_exports and in check.sh +export PYRIGHT_PYTHON_IGNORE_WARNINGS=1 + # Log some general info about the environment +uname -a env | sort -if [ "$JOB_NAME" = "" ]; then - if [ "$SYSTEM_JOBIDENTIFIER" != "" ]; then - # azure pipelines - JOB_NAME="$SYSTEM_JOBDISPLAYNAME" - else - JOB_NAME="${TRAVIS_OS_NAME}-${TRAVIS_PYTHON_VERSION:-unknown}" - fi -fi - # Curl's built-in retry system is not very robust; it gives up on lots of # network errors that we want to retry on. Wget might work better, but it's # not installed on azure pipelines's windows boxes. So... let's try some good @@ -30,182 +26,6 @@ function curl-harder() { return 1 } -################################################################ -# Bootstrap python environment, if necessary -################################################################ - -### Azure pipelines + Windows ### - -# On azure pipeline's windows VMs, to get reasonable performance, we need to -# jump through hoops to avoid touching the C:\ drive as much as possible. -if [ "$AGENT_OS" = "Windows_NT" ]; then - # By default temp and cache directories are on C:\. Fix that. 
- export TEMP="${AGENT_TEMPDIRECTORY}" - export TMP="${AGENT_TEMPDIRECTORY}" - export TMPDIR="${AGENT_TEMPDIRECTORY}" - export PIP_CACHE_DIR="${AGENT_TEMPDIRECTORY}\\pip-cache" - - # Download and install Python from scratch onto D:\, instead of using the - # pre-installed versions that azure pipelines provides on C:\. - # Also use -DirectDownload to stop nuget from caching things on C:\. - nuget install "${PYTHON_PKG}" -Version "${PYTHON_VERSION}" \ - -OutputDirectory "$PWD/pyinstall" -ExcludeVersion \ - -Source "https://api.nuget.org/v3/index.json" \ - -Verbosity detailed -DirectDownload -NonInteractive - - pydir="$PWD/pyinstall/${PYTHON_PKG}" - export PATH="${pydir}/tools:${pydir}/tools/scripts:$PATH" - - # Fix an issue with the nuget python 3.5 packages - # https://github.com/python-trio/trio/pull/827#issuecomment-457433940 - rm -f "${pydir}/tools/pyvenv.cfg" || true -fi - -### Travis + macOS ### - -if [ "$TRAVIS_OS_NAME" = "osx" ]; then - JOB_NAME="osx_${MACPYTHON}" - curl-harder -o macpython.pkg https://www.python.org/ftp/python/${MACPYTHON}/python-${MACPYTHON}-macosx10.6.pkg - sudo installer -pkg macpython.pkg -target / - ls /Library/Frameworks/Python.framework/Versions/*/bin/ - PYTHON_EXE=/Library/Frameworks/Python.framework/Versions/*/bin/python3 - # The pip in older MacPython releases doesn't support a new enough TLS - curl-harder -o get-pip.py https://bootstrap.pypa.io/get-pip.py - sudo $PYTHON_EXE get-pip.py - sudo $PYTHON_EXE -m pip install virtualenv - $PYTHON_EXE -m virtualenv testenv - source testenv/bin/activate -fi - -### PyPy nightly (currently on Travis) ### - -if [ "$PYPY_NIGHTLY_BRANCH" != "" ]; then - JOB_NAME="pypy_nightly_${PYPY_NIGHTLY_BRANCH}" - curl-harder -o pypy.tar.bz2 http://buildbot.pypy.org/nightly/${PYPY_NIGHTLY_BRANCH}/pypy-c-jit-latest-linux64.tar.bz2 - if [ ! 
-s pypy.tar.bz2 ]; then - # We know: - # - curl succeeded (200 response code) - # - nonetheless, pypy.tar.bz2 does not exist, or contains no data - # This isn't going to work, and the failure is not informative of - # anything involving Trio. - ls -l - echo "PyPy3 nightly build failed to download – something is wrong on their end." - echo "Skipping testing against the nightly build for right now." - exit 0 - fi - tar xaf pypy.tar.bz2 - # something like "pypy-c-jit-89963-748aa3022295-linux64" - PYPY_DIR=$(echo pypy-c-jit-*) - PYTHON_EXE=$PYPY_DIR/bin/pypy3 - - if ! ($PYTHON_EXE -m ensurepip \ - && $PYTHON_EXE -m pip install virtualenv \ - && $PYTHON_EXE -m virtualenv testenv); then - echo "pypy nightly is broken; skipping tests" - exit 0 - fi - source testenv/bin/activate -fi - -### Qemu virtual-machine inception, on Travis - -if [ "$VM_IMAGE" != "" ]; then - VM_CPU=${VM_CPU:-x86_64} - - sudo apt update - sudo apt install cloud-image-utils qemu-system-x86 - - # If the base image is already present, we don't try downloading it again; - # and we use a scratch image for the actual run, in order to keep the base - # image file pristine. None of this matters when running in CI, but it - # makes local testing much easier. - BASEIMG=$(basename $VM_IMAGE) - if [ ! -e $BASEIMG ]; then - curl-harder "$VM_IMAGE" -o $BASEIMG - fi - rm -f os-working.img - qemu-img create -f qcow2 -b $BASEIMG os-working.img - - # This is the test script, that runs inside the VM, using cloud-init. - # - # This script goes through shell expansion, so use \ to quote any - # $variables you want to expand inside the guest. 
- cloud-localds -H test-host seed.img /dev/stdin << EOF -#!/bin/bash - -set -xeuo pipefail - -# When this script exits, we shut down the machine, which causes the qemu on -# the host to exit -trap "poweroff" exit - -uname -a -echo \$PWD -id -cat /etc/lsb-release -cat /proc/cpuinfo - -# Pass-through JOB_NAME + the env vars that codecov-bash looks at -export JOB_NAME="$JOB_NAME" -export CI="$CI" -export TRAVIS="$TRAVIS" -export TRAVIS_COMMIT="$TRAVIS_COMMIT" -export TRAVIS_PULL_REQUEST_SHA="$TRAVIS_PULL_REQUEST_SHA" -export TRAVIS_JOB_NUMBER="$TRAVIS_JOB_NUMBER" -export TRAVIS_PULL_REQUEST="$TRAVIS_PULL_REQUEST" -export TRAVIS_JOB_ID="$TRAVIS_JOB_ID" -export TRAVIS_REPO_SLUG="$TRAVIS_REPO_SLUG" -export TRAVIS_TAG="$TRAVIS_TAG" -export TRAVIS_BRANCH="$TRAVIS_BRANCH" - -env - -mkdir /host-files -mount -t 9p -o trans=virtio,version=9p2000.L host-files /host-files - -# Install and set up the system Python (assumes Debian/Ubuntu) -apt update -apt install -y python3-dev python3-virtualenv git build-essential curl -python3 -m virtualenv -p python3 /venv -# Uses unbound shell variable PS1, so have to allow that temporarily -set +u -source /venv/bin/activate -set -u - -# And then we re-invoke ourselves! -cd /host-files -./ci.sh - -# We can't pass our exit status out. So if we got this far without error, make -# a marker file where the host can see it. -touch /host-files/SUCCESS -EOF - - rm -f SUCCESS - # Apparently Travis's bionic images have nested virtualization enabled, so - # we can use KVM... but the default user isn't in the appropriate groups - # to use KVM, so we have to use 'sudo' to add that. And then a second - # 'sudo', because by default we have rights to run arbitrary commands as - # root, but we don't have rights to run a command as ourselves but with a - # tweaked group setting. - # - # Travis Linux VMs have 7.5 GiB RAM, so we give our nested VM 6 GiB RAM - # (-m 6144). 
- sudo sudo -u $USER -g kvm qemu-system-$VM_CPU \ - -enable-kvm \ - -M pc \ - -m 6144 \ - -nographic \ - -drive "file=./os-working.img,if=virtio" \ - -drive "file=./seed.img,if=virtio,format=raw" \ - -net nic \ - -net "user,hostfwd=tcp:127.0.0.1:50022-:22" \ - -virtfs local,path=$PWD,security_model=mapped-file,mount_tag=host-files - - test -e SUCCESS - exit -fi - ################################################################ # We have a Python environment! ################################################################ @@ -218,26 +38,50 @@ python -m pip --version python setup.py sdist --formats=zip python -m pip install dist/*.zip -if [ "$CHECK_DOCS" = "1" ]; then - python -m pip install -r docs-requirements.txt - towncrier --yes # catch errors in newsfragments - cd docs - # -n (nit-picky): warn on missing references - # -W: turn warnings into errors - sphinx-build -nW -b html source build -elif [ "$CHECK_FORMATTING" = "1" ]; then +if [ "$CHECK_FORMATTING" = "1" ]; then python -m pip install -r test-requirements.txt source check.sh else # Actual tests python -m pip install -r test-requirements.txt + # So we can run the test for our apport/excepthook interaction working + if [ -e /etc/lsb-release ] && grep -q Ubuntu /etc/lsb-release; then + sudo apt install -q python3-apport + fi + # If we're testing with a LSP installed, then it might break network # stuff, so wait until after we've finished setting everything else # up. if [ "$LSP" != "" ]; then echo "Installing LSP from ${LSP}" - curl-harder -o lsp-installer.exe "$LSP" + # We use --insecure because one of the LSP's has been observed to give + # cert verification errors: + # + # https://github.com/python-trio/trio/issues/1478 + # + # *Normally*, you should never ever use --insecure, especially when + # fetching an executable! But *in this case*, we're intentionally + # installing some untrustworthy quasi-malware onto into a sandboxed + # machine for testing. 
So MITM attacks are really the least of our + # worries. + if [ "$LSP_EXTRACT_FILE" != "" ]; then + # We host the Astrill VPN installer ourselves, and encrypt it + # so as to decrease the chances of becoming an inadvertent + # public redistributor. + curl-harder -o lsp-installer.zip "$LSP" + unzip -P "not very secret trio ci key" lsp-installer.zip "$LSP_EXTRACT_FILE" + mv "$LSP_EXTRACT_FILE" lsp-installer.exe + else + curl-harder --insecure -o lsp-installer.exe "$LSP" + fi + # This is only needed for the Astrill LSP, but there's no harm in + # doing it all the time. The cert was manually extracted by installing + # the package in a VM, clicking "Always trust from this publisher" + # when installing, and then running 'certmgr.msc' and exporting the + # certificate. See: + # http://www.migee.com/2010/09/24/solution-for-unattendedsilent-installs-and-would-you-like-to-install-this-device-software/ + certutil -addstore "TrustedPublisher" trio/_tests/astrill-codesigning-cert.cer # Double-slashes are how you tell windows-bash that you want a single # slash, and don't treat this as a unix-style filename that needs to # be replaced by a windows-style filename. @@ -250,24 +94,39 @@ else netsh winsock show catalog fi - mkdir empty + # We run the tests from inside an empty directory, to make sure Python + # doesn't pick up any .py files from our working dir. Might have been + # pre-created by some of the code above. + mkdir empty || true cd empty INSTALLDIR=$(python -c "import os, trio; print(os.path.dirname(trio.__file__))") - cp ../setup.cfg $INSTALLDIR - if pytest -W error -r a --junitxml=../test-results.xml --run-slow ${INSTALLDIR} --cov="$INSTALLDIR" --cov-config=../.coveragerc --verbose; then + cp ../pyproject.toml $INSTALLDIR + + # TODO: remove this once we have a py.typed file + touch "$INSTALLDIR/py.typed" + + # get mypy tests a nice cache + MYPYPATH=".." 
mypy --config-file= --cache-dir=./.mypy_cache -c "import trio" >/dev/null 2>/dev/null || true + + # support subprocess spawning with coverage.py + echo "import coverage; coverage.process_startup()" | tee -a "$INSTALLDIR/../sitecustomize.py" + + if COVERAGE_PROCESS_START=$(pwd)/../.coveragerc coverage run --rcfile=../.coveragerc -m pytest -r a -p trio._tests.pytest_plugin --junitxml=../test-results.xml --run-slow ${INSTALLDIR} --verbose; then PASSED=true else PASSED=false fi + coverage combine --rcfile ../.coveragerc + coverage report -m --rcfile ../.coveragerc + coverage xml --rcfile ../.coveragerc + # Remove the LSP again; again we want to do this ASAP to avoid # accidentally breaking other stuff. if [ "$LSP" != "" ]; then netsh winsock reset fi - bash <(curl-harder -o codecov.sh https://codecov.io/bash) -n "${JOB_NAME}" - $PASSED fi diff --git a/docs-requirements.in b/docs-requirements.in index 23a1b0f652..d6214ec1d0 100644 --- a/docs-requirements.in +++ b/docs-requirements.in @@ -1,18 +1,24 @@ # RTD is currently installing 1.5.3, which has a bug in :lineno-match: -sphinx >= 1.7.0 +# sphinx-3.4 causes warnings about some trio._abc classes: GH#2338 +sphinx >= 1.7.0, < 6.2 +# jinja2-3.1 causes importerror with sphinx<4.0 +jinja2 < 3.1 sphinx_rtd_theme sphinxcontrib-trio towncrier # Trio's own dependencies cffi; os_name == "nt" -contextvars; python_version < "3.7" attrs >= 19.2.0 sortedcontainers async_generator >= 1.9 idna outcome sniffio +exceptiongroup >= 1.0.0rc9 # See note in test-requirements.in immutables >= 0.6 + +# types used in annotations +pyOpenSSL diff --git a/docs-requirements.txt b/docs-requirements.txt index 7e162227ea..fabf3e901a 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -1,42 +1,111 @@ # -# This file is autogenerated by pip-compile -# To update, run: +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: # -# pip-compile --output-file docs-requirements.txt docs-requirements.in +# 
pip-compile docs-requirements.in # -alabaster==0.7.12 # via sphinx -async-generator==1.10 # via -r docs-requirements.in -attrs==19.3.0 # via -r docs-requirements.in, outcome -babel==2.8.0 # via sphinx -certifi==2019.11.28 # via requests -chardet==3.0.4 # via requests -click==7.1.1 # via towncrier -docutils==0.16 # via sphinx -idna==2.9 # via -r docs-requirements.in, requests -imagesize==1.2.0 # via sphinx -immutables==0.11 # via -r docs-requirements.in -incremental==17.5.0 # via towncrier -jinja2==2.11.1 # via sphinx, towncrier -markupsafe==1.1.1 # via jinja2 -outcome==1.0.1 # via -r docs-requirements.in -packaging==20.3 # via sphinx -pygments==2.6.1 # via sphinx -pyparsing==2.4.6 # via packaging -pytz==2019.3 # via babel -requests==2.23.0 # via sphinx -six==1.14.0 # via packaging -sniffio==1.1.0 # via -r docs-requirements.in -snowballstemmer==2.0.0 # via sphinx -sortedcontainers==2.1.0 # via -r docs-requirements.in -sphinx-rtd-theme==0.4.3 # via -r docs-requirements.in -sphinx==2.4.4 # via -r docs-requirements.in, sphinx-rtd-theme, sphinxcontrib-trio -sphinxcontrib-applehelp==1.0.2 # via sphinx -sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==1.0.3 # via sphinx -sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 # via sphinx -sphinxcontrib-serializinghtml==1.1.4 # via sphinx -sphinxcontrib-trio==1.1.0 # via -r docs-requirements.in -toml==0.10.0 # via towncrier -towncrier==19.2.0 # via -r docs-requirements.in -urllib3==1.25.8 # via requests +alabaster==0.7.13 + # via sphinx +async-generator==1.10 + # via -r docs-requirements.in +attrs==23.1.0 + # via + # -r docs-requirements.in + # outcome +babel==2.12.1 + # via sphinx +certifi==2023.7.22 + # via requests +cffi==1.15.1 + # via cryptography +charset-normalizer==3.2.0 + # via requests +click==8.1.5 + # via + # click-default-group + # towncrier +click-default-group==1.2.2 + # via towncrier +cryptography==41.0.2 + # via pyopenssl +docutils==0.18.1 + # via + # sphinx + # 
sphinx-rtd-theme +exceptiongroup==1.1.2 + # via -r docs-requirements.in +idna==3.4 + # via + # -r docs-requirements.in + # requests +imagesize==1.4.1 + # via sphinx +immutables==0.19 + # via -r docs-requirements.in +importlib-metadata==6.8.0 + # via sphinx +importlib-resources==6.0.0 + # via towncrier +incremental==22.10.0 + # via towncrier +jinja2==3.0.3 + # via + # -r docs-requirements.in + # sphinx + # towncrier +markupsafe==2.1.3 + # via jinja2 +outcome==1.2.0 + # via -r docs-requirements.in +packaging==23.1 + # via sphinx +pycparser==2.21 + # via cffi +pygments==2.15.1 + # via sphinx +pyopenssl==23.2.0 + # via -r docs-requirements.in +pytz==2023.3 + # via babel +requests==2.31.0 + # via sphinx +sniffio==1.3.0 + # via -r docs-requirements.in +snowballstemmer==2.2.0 + # via sphinx +sortedcontainers==2.4.0 + # via -r docs-requirements.in +sphinx==6.1.3 + # via + # -r docs-requirements.in + # sphinx-rtd-theme + # sphinxcontrib-jquery + # sphinxcontrib-trio +sphinx-rtd-theme==1.2.2 + # via -r docs-requirements.in +sphinxcontrib-applehelp==1.0.4 + # via sphinx +sphinxcontrib-devhelp==1.0.2 + # via sphinx +sphinxcontrib-htmlhelp==2.0.1 + # via sphinx +sphinxcontrib-jquery==4.1 + # via sphinx-rtd-theme +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.3 + # via sphinx +sphinxcontrib-serializinghtml==1.1.5 + # via sphinx +sphinxcontrib-trio==1.1.2 + # via -r docs-requirements.in +tomli==2.0.1 + # via towncrier +towncrier==23.6.0 + # via -r docs-requirements.in +urllib3==2.0.3 + # via requests +zipp==3.16.2 + # via + # importlib-metadata + # importlib-resources diff --git a/docs/source/_static/favicon-32.ico b/docs/source/_static/favicon-32.ico deleted file mode 100644 index 1ec20bf8a4..0000000000 Binary files a/docs/source/_static/favicon-32.ico and /dev/null differ diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index d3b0ca89fd..dbebf5c2ae 100644 --- a/docs/source/_templates/layout.html +++ 
b/docs/source/_templates/layout.html @@ -4,8 +4,13 @@ {% extends "!layout.html" %} {% block sidebartitle %} -``, then there will be at least one checkpoint before - each iteration of the loop and one checkpoint after the last - iteration. + trio object>``, then there will be at least one checkpoint in + each iteration of the loop, and it will still checkpoint if the + iterable is empty. * Partial exception for async context managers: Both the entry and exit of an ``async with`` block are @@ -355,7 +355,7 @@ Here's an example:: print("starting...") with trio.move_on_after(5): with trio.move_on_after(10): - await sleep(20) + await trio.sleep(20) print("sleep finished without error") print("move_on_after(10) finished without error") print("move_on_after(5) finished without error") @@ -382,7 +382,7 @@ object representing this cancel scope, which we can use to check whether this scope caught a :exc:`Cancelled` exception:: with trio.move_on_after(5) as cancel_scope: - await sleep(10) + await trio.sleep(10) print(cancel_scope.cancelled_caught) # prints "True" The ``cancel_scope`` object also allows you to check or adjust this @@ -524,6 +524,10 @@ objects. .. autoattribute:: cancel_called +Often there is no need to create :class:`CancelScope` object. Trio +already includes :attr:`~trio.Nursery.cancel_scope` attribute in a +task-related :class:`Nursery` object. We will cover nurseries later in +the manual. Trio also provides several convenience functions for the common situation of just wanting to impose a timeout on some code: @@ -637,7 +641,7 @@ crucial things to keep in mind: * Any unhandled exceptions are re-raised inside the parent task. If there are multiple exceptions, then they're collected up into a - single :exc:`MultiError` exception. + single :exc:`BaseExceptionGroup` or :exc:`ExceptionGroup` exception. 
Since all tasks are descendents of the initial task, one consequence of this is that :func:`run` can't finish until all tasks have @@ -663,7 +667,7 @@ In Trio, child tasks inherit the parent nursery's cancel scopes. So in this example, both the child tasks will be cancelled when the timeout expires:: - with move_on_after(TIMEOUT): + with trio.move_on_after(TIMEOUT): async with trio.open_nursery() as nursery: nursery.start_soon(child1) nursery.start_soon(child2) @@ -674,15 +678,22 @@ Note that what matters here is the scopes that were active when nothing at all:: async with trio.open_nursery() as nursery: - with move_on_after(TIMEOUT): # don't do this! + with trio.move_on_after(TIMEOUT): # don't do this! nursery.start_soon(child) +Why is this so? Well, ``start_soon()`` returns as soon as it has scheduled the new task to start running. The flow of execution in the parent then continues on to exit the ``with trio.move_on_after(TIMEOUT):`` block, at which point Trio forgets about the timeout entirely. In order for the timeout to apply to the child task, Trio must be able to tell that its associated cancel scope will stay open for at least as long as the child task is executing. And Trio can only know that for sure if the cancel scope block is outside the nursery block. + +You might wonder why Trio can't just remember "this task should be cancelled in ``TIMEOUT`` seconds", even after the ``with trio.move_on_after(TIMEOUT):`` block is gone. The reason has to do with :ref:`how cancellation is implemented `. Recall that cancellation is represented by a `Cancelled` exception, which eventually needs to be caught by the cancel scope that caused it. (Otherwise, the exception would take down your whole program!) In order to be able to cancel the child tasks, the cancel scope has to be able to "see" the `Cancelled` exceptions that they raise -- and those exceptions come out of the ``async with open_nursery()`` block, not out of the call to ``start_soon()``. 
+ +If you want a timeout to apply to one task but not another, then you need to put the cancel scope in that individual task's function -- ``child()``, in this example. + +.. _exceptiongroups: Errors in multiple child tasks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Normally, in Python, only one thing happens at a time, which means -that only one thing can wrong at a time. Trio has no such +that only one thing can go wrong at a time. Trio has no such limitation. Consider code like:: async def broken1(): @@ -700,18 +711,102 @@ limitation. Consider code like:: ``broken1`` raises ``KeyError``. ``broken2`` raises ``IndexError``. Obviously ``parent`` should raise some error, but -what? In some sense, the answer should be "both of these at once", but -in Python there can only be one exception at a time. +what? The answer is that both exceptions are grouped in an :exc:`ExceptionGroup`. +:exc:`ExceptionGroup` and its parent class :exc:`BaseExceptionGroup` are used to +encapsulate multiple exceptions being raised at once. + +To catch individual exceptions encapsulated in an exception group, the ``except*`` +clause was introduced in Python 3.11 (:pep:`654`). Here's how it works:: + + try: + async with trio.open_nursery() as nursery: + nursery.start_soon(broken1) + nursery.start_soon(broken2) + except* KeyError as excgroup: + for exc in excgroup.exceptions: + ... # handle each KeyError + except* IndexError as excgroup: + for exc in excgroup.exceptions: + ... # handle each IndexError + +If you want to reraise exceptions, or raise new ones, you can do so, but be aware that +exceptions raised in ``except*`` sections will be raised together in a new exception +group. + +But what if you can't use ``except*`` just yet? Well, for that there is the handy +exceptiongroup_ library which lets you approximate this behavior with exception handler +callbacks:: + + from exceptiongroup import catch + + def handle_keyerrors(excgroup): + for exc in excgroup.exceptions: + ... 
# handle each KeyError + + def handle_indexerrors(excgroup): + for exc in excgroup.exceptions: + ... # handle each IndexError + + with catch({ + KeyError: handle_keyerrors, + IndexError: handle_indexerrors + }): + async with trio.open_nursery() as nursery: + nursery.start_soon(broken1) + nursery.start_soon(broken2) + +The semantics for the handler functions are equal to ``except*`` blocks, except for +setting local variables. If you need to set local variables, you need to declare them +inside the handler function(s) with the ``nonlocal`` keyword:: -Trio's answer is that it raises a :exc:`MultiError` object. This is a -special exception which encapsulates multiple exception objects – -either regular exceptions or nested :exc:`MultiError`\s. To make these -easier to work with, Trio installs a custom `sys.excepthook` that -knows how to print nice tracebacks for unhandled :exc:`MultiError`\s, -and it also provides some helpful utilities like -:meth:`MultiError.catch`, which allows you to catch "part of" a -:exc:`MultiError`. + def handle_keyerrors(excgroup): + nonlocal myflag + myflag = True + + myflag = False + with catch({KeyError: handle_keyerrors}): + async with trio.open_nursery() as nursery: + nursery.start_soon(broken1) +For reasons of backwards compatibility, nurseries raise ``trio.MultiError`` and +``trio.NonBaseMultiError`` which inherit from :exc:`BaseExceptionGroup` and +:exc:`ExceptionGroup`, respectively. Users should refrain from attempting to raise or +catch the Trio specific exceptions themselves, and treat them as if they were standard +:exc:`BaseExceptionGroup` or :exc:`ExceptionGroup` instances instead. + +"Strict" versus "loose" ExceptionGroup semantics +++++++++++++++++++++++++++++++++++++++++++++++++ + +Ideally, in some abstract sense we'd want everything that *can* raise an +`ExceptionGroup` to *always* raise an `ExceptionGroup` (rather than, say, a single +`ValueError`). 
Otherwise, it would be easy to accidentally write something like ``except +ValueError:`` (not ``except*``), which works if a single exception is raised but fails to +catch _anything_ in the case of multiple simultaneous exceptions (even if one of them is +a ValueError). However, this is not how Trio worked in the past: as a concession to +practicality when the ``except*`` syntax hadn't been dreamed up yet, the old +``trio.MultiError`` was raised only when at least two exceptions occurred +simultaneously. Adding a layer of `ExceptionGroup` around every nursery, while +theoretically appealing, would probably break a lot of existing code in practice. + +Therefore, we've chosen to gate the newer, "stricter" behavior behind a parameter +called ``strict_exception_groups``. This is accepted as a parameter to +:func:`open_nursery`, to set the behavior for that nursery, and to :func:`trio.run`, +to set the default behavior for any nursery in your program that doesn't override it. + +* With ``strict_exception_groups=True``, the exception(s) coming out of a nursery will + always be wrapped in an `ExceptionGroup`, so you'll know that if you're handling + single errors correctly, multiple simultaneous errors will work as well. + +* With ``strict_exception_groups=False``, a nursery in which only one task has failed + will raise that task's exception without an additional layer of `ExceptionGroup` + wrapping, so you'll get maximum compatibility with code that was written to + support older versions of Trio. + +To maintain backwards compatibility, the default is ``strict_exception_groups=False``. +The default will eventually change to ``True`` in a future version of Trio, once +Python 3.11 and later versions are in wide use. + +.. 
_exceptiongroup: https://pypi.org/project/exceptiongroup/ Spawning tasks without becoming a parent ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -754,7 +849,7 @@ example, the timeout does *not* apply to ``child`` (or to anything else):: async def do_spawn(nursery): - with move_on_after(TIMEOUT): # don't do this, it has no effect + with trio.move_on_after(TIMEOUT): # don't do this, it has no effect nursery.start_soon(child) async with trio.open_nursery() as nursery: @@ -785,23 +880,23 @@ finishes first:: if not async_fns: raise ValueError("must pass at least one argument") - send_channel, receive_channel = trio.open_memory_channel(0) + winner = None - async def jockey(async_fn): - await send_channel.send(await async_fn()) + async def jockey(async_fn, cancel_scope): + nonlocal winner + winner = await async_fn() + cancel_scope.cancel() async with trio.open_nursery() as nursery: for async_fn in async_fns: - nursery.start_soon(jockey, async_fn) - winner = await receive_channel.receive() - nursery.cancel_scope.cancel() - return winner + nursery.start_soon(jockey, async_fn, nursery.cancel_scope) + + return winner This works by starting a set of tasks which each try to run their -function, and then report back the value it returns. The main task -uses ``receive_channel.receive`` to wait for one to finish; as soon as -the first task crosses the finish line, it cancels the rest, and then -returns the winning value. +function. As soon as the first function completes its execution, the task will set the nonlocal variable ``winner`` +from the outer scope to the result of the function, and cancel the other tasks using the passed in cancel scope. Once all tasks +have been cancelled (which exits the nursery block), the variable ``winner`` will be returned. Here if one or more of the racing functions raises an unhandled exception then Trio's normal handling kicks in: it cancels the others @@ -827,104 +922,8 @@ The nursery API See :meth:`~Nursery.start`. 
- -Working with :exc:`MultiError`\s -++++++++++++++++++++++++++++++++ - -.. autoexception:: MultiError - - .. attribute:: exceptions - - The list of exception objects that this :exc:`MultiError` - represents. - - .. automethod:: filter - - .. automethod:: catch - :with: - -Examples: - -Suppose we have a handler function that discards :exc:`ValueError`\s:: - - def handle_ValueError(exc): - if isinstance(exc, ValueError): - return None - else: - return exc - -Then these both raise :exc:`KeyError`:: - - with MultiError.catch(handle_ValueError): - raise MultiError([KeyError(), ValueError()]) - - with MultiError.catch(handle_ValueError): - raise MultiError([ - ValueError(), - MultiError([KeyError(), ValueError()]), - ]) - -And both of these raise nothing at all:: - - with MultiError.catch(handle_ValueError): - raise MultiError([ValueError(), ValueError()]) - - with MultiError.catch(handle_ValueError): - raise MultiError([ - MultiError([ValueError(), ValueError()]), - ValueError(), - ]) - -You can also return a new or modified exception, for example:: - - def convert_ValueError_to_MyCustomError(exc): - if isinstance(exc, ValueError): - # Similar to 'raise MyCustomError from exc' - new_exc = MyCustomError(...) - new_exc.__cause__ = exc - return new_exc - else: - return exc - -In the example above, we set ``__cause__`` as a form of explicit -context chaining. :meth:`MultiError.filter` and -:meth:`MultiError.catch` also perform implicit exception chaining – if -you return a new exception object, then the new object's -``__context__`` attribute will automatically be set to the original -exception. - -We also monkey patch :class:`traceback.TracebackException` to be able -to handle formatting :exc:`MultiError`\s. 
This means that anything that -formats exception messages like :mod:`logging` will work out of the -box:: - - import logging - - logging.basicConfig() - - try: - raise MultiError([ValueError("foo"), KeyError("bar")]) - except: - logging.exception("Oh no!") - raise - -Will properly log the inner exceptions: - -.. code-block:: none - - ERROR:root:Oh no! - Traceback (most recent call last): - File "", line 2, in - trio.MultiError: ValueError('foo',), KeyError('bar',) - - Details of embedded exception 1: - - ValueError: foo - - Details of embedded exception 2: - - KeyError: 'bar' - +.. autoclass:: TaskStatus + :members: .. _task-local-storage: @@ -977,12 +976,8 @@ work. What we need is something that's *like* a global variable, but that can have different values depending on which request handler is accessing it. -To solve this problem, Python 3.7 added a new module to the standard -library: :mod:`contextvars`. And not only does Trio have built-in -support for :mod:`contextvars`, but if you're using an earlier version -of Python, then Trio makes sure that a backported version of -:mod:`contextvars` is installed. So you can assume :mod:`contextvars` -is there and works regardless of what version of Python you're using. +To solve this problem, Python has a module in the standard +library: :mod:`contextvars`. Here's a toy example demonstrating how to use :mod:`contextvars`: @@ -1012,7 +1007,7 @@ Example output (yours may differ slightly): request 0: Request received finished For more information, read the -`contextvar docs `__. +`contextvars docs `__. .. _synchronization: @@ -1081,7 +1076,7 @@ you'll see that the two tasks politely take turns:: async def loopy_child(number, lock): while True: async with lock: - print("Child {} has the lock!".format(number)) + print(f"Child {number} has the lock!") await trio.sleep(0.5) async def main(): @@ -1099,6 +1094,8 @@ Broadcasting an event with :class:`Event` .. autoclass:: Event :members: +.. 
autoclass:: EventStatistics + :members: .. _channels: @@ -1172,7 +1169,7 @@ the previous version, and then exits cleanly. The only change is the addition of ``async with`` blocks inside the producer and consumer: .. literalinclude:: reference-core/channels-shutdown.py - :emphasize-lines: 10,15 + :emphasize-lines: 11,17 The really important thing here is the producer's ``async with`` . When the producer exits, this closes the ``send_channel``, and that @@ -1251,7 +1248,7 @@ Fortunately, there's a better way! Here's a fixed version of our program above: .. literalinclude:: reference-core/channels-mpmc-fixed.py - :emphasize-lines: 7, 9, 10, 12, 13 + :emphasize-lines: 8, 10, 11, 13, 14 This example demonstrates using the `MemorySendChannel.clone` and `MemoryReceiveChannel.clone` methods. What these do is create copies @@ -1306,7 +1303,7 @@ produces *backpressure*: if the channel producers are running faster than the consumers, then it forces the producers to slow down. You can disable buffering entirely, by doing -``open_memory_channel(0)``. In that case any task calls +``open_memory_channel(0)``. In that case any task that calls :meth:`~trio.abc.SendChannel.send` will wait until another task calls :meth:`~trio.abc.ReceiveChannel.receive`, and vice versa. This is similar to how channels work in the `classic Communicating Sequential Processes @@ -1433,9 +1430,9 @@ than the lower-level primitives discussed in this section. But if you need them, they're here. (If you find yourself reaching for these because you're trying to implement a new higher-level synchronization primitive, then you might also want to check out the facilities in -:mod:`trio.hazmat` for a more direct exposure of Trio's underlying +:mod:`trio.lowlevel` for a more direct exposure of Trio's underlying synchronization logic. 
All of classes discussed in this section are -implemented on top of the public APIs in :mod:`trio.hazmat`; they +implemented on top of the public APIs in :mod:`trio.lowlevel`; they don't have any special access to Trio's internals.) .. autoclass:: CapacityLimiter @@ -1444,8 +1441,14 @@ don't have any special access to Trio's internals.) .. autoclass:: Semaphore :members: +.. We have to use :inherited-members: here because all the actual lock + methods are stashed in _LockImpl. Weird side-effect of having both + Lock and StrictFIFOLock, but wanting both to be marked Final so + neither can inherit from the other. + .. autoclass:: Lock :members: + :inherited-members: .. autoclass:: StrictFIFOLock :members: @@ -1453,6 +1456,194 @@ don't have any special access to Trio's internals.) .. autoclass:: Condition :members: +These primitives return statistics objects that can be inspected. + +.. autoclass:: CapacityLimiterStatistics + :members: + +.. autoclass:: LockStatistics + :members: + +.. autoclass:: ConditionStatistics + :members: + +.. _async-generators: + +Notes on async generators +------------------------- + +Python 3.6 added support for *async generators*, which can use +``await``, ``async for``, and ``async with`` in between their ``yield`` +statements. As you might expect, you use ``async for`` to iterate +over them. :pep:`525` has many more details if you want them. + +For example, the following is a roundabout way to print +the numbers 0 through 9 with a 1-second delay before each one:: + + async def range_slowly(*args): + """Like range(), but adds a 1-second sleep before each value.""" + for value in range(*args): + await trio.sleep(1) + yield value + + async def use_it(): + async for value in range_slowly(10): + print(value) + + trio.run(use_it) + +Trio supports async generators, with some caveats described in this section. 
+ +Finalization +~~~~~~~~~~~~ + +If you iterate over an async generator in its entirety, like the +example above does, then the execution of the async generator will +occur completely in the context of the code that's iterating over it, +and there aren't too many surprises. + +If you abandon a partially-completed async generator, though, such as +by ``break``\ing out of the iteration, things aren't so simple. The +async generator iterator object is still alive, waiting for you to +resume iterating it so it can produce more values. At some point, +Python will realize that you've dropped all references to the +iterator, and will call on Trio to throw in a `GeneratorExit` exception +so that any remaining cleanup code inside the generator has a chance +to run: ``finally`` blocks, ``__aexit__`` handlers, and so on. + +So far, so good. Unfortunately, Python provides no guarantees about +*when* this happens. It could be as soon as you break out of the +``async for`` loop, or an arbitrary amount of time later. It could +even be after the entire Trio run has finished! Just about the only +guarantee is that it *won't* happen in the task that was using the +generator. That task will continue on with whatever else it's doing, +and the async generator cleanup will happen "sometime later, +somewhere else": potentially with different context variables, +not subject to timeouts, and/or after any nurseries you're using have +been closed. 
+ +If you don't like that ambiguity, and you want to ensure that a +generator's ``finally`` blocks and ``__aexit__`` handlers execute as +soon as you're done using it, then you'll need to wrap your use of the +generator in something like `async_generator.aclosing() +`__:: + + # Instead of this: + async for value in my_generator(): + if value == 42: + break + + # Do this: + async with aclosing(my_generator()) as aiter: + async for value in aiter: + if value == 42: + break + +This is cumbersome, but Python unfortunately doesn't provide any other +reliable options. If you use ``aclosing()``, then +your generator's cleanup code executes in the same context as the +rest of its iterations, so timeouts, exceptions, and context +variables work like you'd expect. + +If you don't use ``aclosing()``, then Trio will do +its best anyway, but you'll have to contend with the following semantics: + +* The cleanup of the generator occurs in a cancelled context, i.e., + all blocking calls executed during cleanup will raise `Cancelled`. + This is to compensate for the fact that any timeouts surrounding + the original use of the generator have been long since forgotten. + +* The cleanup runs without access to any :ref:`context variables + ` that may have been present when the generator + was originally being used. + +* If the generator raises an exception during cleanup, then it's + printed to the ``trio.async_generator_errors`` logger and otherwise + ignored. + +* If an async generator is still alive at the end of the whole + call to :func:`trio.run`, then it will be cleaned up after all + tasks have exited and before :func:`trio.run` returns. + Since the "system nursery" has already been closed at this point, + Trio isn't able to support any new calls to + :func:`trio.lowlevel.spawn_system_task`. 
+ +If you plan to run your code on PyPy to take advantage of its better +performance, you should be aware that PyPy is *far more likely* than +CPython to perform async generator cleanup at a time well after the +last use of the generator. (This is a consequence of the fact that +PyPy does not use reference counting to manage memory.) To help catch +issues like this, Trio will issue a `ResourceWarning` (ignored by +default, but enabled when running under ``python -X dev`` for example) +for each async generator that needs to be handled through the fallback +finalization path. + +Cancel scopes and nurseries +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. warning:: You may not write a ``yield`` statement that suspends an async generator + inside a `CancelScope` or `Nursery` that was entered within the generator. + +That is, this is OK:: + + async def some_agen(): + with trio.move_on_after(1): + await long_operation() + yield "first" + async with trio.open_nursery() as nursery: + nursery.start_soon(task1) + nursery.start_soon(task2) + yield "second" + ... + +But this is not:: + + async def some_agen(): + with trio.move_on_after(1): + yield "first" + async with trio.open_nursery() as nursery: + yield "second" + ... + +Async generators decorated with ``@asynccontextmanager`` to serve as +the template for an async context manager are *not* subject to this +constraint, because ``@asynccontextmanager`` uses them in a limited +way that doesn't create problems. + +Violating the rule described in this section will sometimes get you a +useful error message, but Trio is not able to detect all such cases, +so sometimes you'll get an unhelpful `TrioInternalError`. (And +sometimes it will seem to work, which is probably the worst outcome of +all, since then you might not notice the issue until you perform some +minor refactoring of the generator or the code that's iterating it, or +just get unlucky. There is a `proposed Python enhancement +`__ +that would at least make it fail consistently.) 
+ +The reason for the restriction on cancel scopes has to do with the +difficulty of noticing when a generator gets suspended and +resumed. The cancel scopes inside the generator shouldn't affect code +running outside the generator, but Trio isn't involved in the process +of exiting and reentering the generator, so it would be hard pressed +to keep its cancellation plumbing in the correct state. Nurseries +use a cancel scope internally, so they have all the problems of cancel +scopes plus a number of problems of their own: for example, when +the generator is suspended, what should the background tasks do? +There's no good way to suspend them, but if they keep running and throw +an exception, where can that exception be reraised? + +If you have an async generator that wants to ``yield`` from within a nursery +or cancel scope, your best bet is to refactor it to be a separate task +that communicates over memory channels. The ``trio_util`` package offers a +`decorator that does this for you transparently +`__. + +For more discussion, see +Trio issues `264 `__ +(especially `this comment +`__) +and `638 `__. + .. _threads: @@ -1627,6 +1818,66 @@ to spawn a child thread, and then use a :ref:`memory channel .. literalinclude:: reference-core/from-thread-example.py +Threads and task-local storage +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +When working with threads, you can use the same `contextvars` we discussed above, +because their values are preserved. + +This is done by automatically copying the `contextvars` context when you use any of: + +* `trio.to_thread.run_sync` +* `trio.from_thread.run` +* `trio.from_thread.run_sync` + +That means that the values of the context variables are accessible even in worker +threads, or when sending a function to be run in the main/parent Trio thread using +`trio.from_thread.run` *from* one of these worker threads. 
+ +But it also means that as the context is not the same but a copy, if you `set` the +context variable value *inside* one of these functions that work in threads, the +new value will only be available in that context (that was copied). So, the new value +will be available for that function and other internal/children tasks, but the value +won't be available in the parent thread. + +If you need to modify values that would live in the context variables and you need to +make those modifications from the child threads, you can instead set a mutable object +(e.g. a dictionary) in the context variable of the top level/parent Trio thread. +Then in the children, instead of setting the context variable, you can ``get`` the same +object, and modify its values. That way you keep the same object in the context +variable and only mutate it in child threads. + +This way, you can modify the object content in child threads and still access the +new content in the parent thread. + +Here's an example: + +.. literalinclude:: reference-core/thread-contextvars-example.py + +Running that script will result in the output: + +.. code-block:: none + + Processed user 2 with message Hello 2 in a thread worker + Processed user 0 with message Hello 0 in a thread worker + Processed user 1 with message Hello 1 in a thread worker + New contextvar value from worker thread for user 2: Hello 2 + New contextvar value from worker thread for user 1: Hello 1 + New contextvar value from worker thread for user 0: Hello 0 + +If you are using ``contextvars`` or you are using a library that uses them, now you +know how they interact when working with threads in Trio. + +But have in mind that in many cases it might be a lot simpler to *not* use context +variables in your own code and instead pass values in arguments, as it might be more +explicit and might be easier to reason about. + +.. 
note:: + + The context is automatically copied instead of using the same parent context because + a single context can't be used in more than one thread, it's not supported by + ``contextvars``. + Exceptions and warnings ----------------------- diff --git a/docs/source/reference-core/channels-backpressure.py b/docs/source/reference-core/channels-backpressure.py index 50ac67f20a..72cb6900c2 100644 --- a/docs/source/reference-core/channels-backpressure.py +++ b/docs/source/reference-core/channels-backpressure.py @@ -4,6 +4,7 @@ import trio import math + async def producer(send_channel): count = 0 while True: @@ -14,6 +15,7 @@ async def producer(send_channel): print("Sent message:", count) count += 1 + async def consumer(receive_channel): async for value in receive_channel: print("Received message:", value) @@ -21,10 +23,12 @@ async def consumer(receive_channel): # takes 1 second await trio.sleep(1) + async def main(): send_channel, receive_channel = trio.open_memory_channel(math.inf) async with trio.open_nursery() as nursery: nursery.start_soon(producer, send_channel) nursery.start_soon(consumer, receive_channel) + trio.run(main) diff --git a/docs/source/reference-core/channels-mpmc-broken.py b/docs/source/reference-core/channels-mpmc-broken.py index 2a755acba3..7043f0aafa 100644 --- a/docs/source/reference-core/channels-mpmc-broken.py +++ b/docs/source/reference-core/channels-mpmc-broken.py @@ -3,6 +3,7 @@ import trio import random + async def main(): async with trio.open_nursery() as nursery: send_channel, receive_channel = trio.open_memory_channel(0) @@ -13,18 +14,21 @@ async def main(): nursery.start_soon(consumer, "X", receive_channel) nursery.start_soon(consumer, "Y", receive_channel) + async def producer(name, send_channel): async with send_channel: for i in range(3): - await send_channel.send("{} from producer {}".format(i, name)) + await send_channel.send(f"{i} from producer {name}") # Random sleeps help trigger the problem more reliably await 
trio.sleep(random.random()) + async def consumer(name, receive_channel): async with receive_channel: async for value in receive_channel: - print("consumer {} got value {!r}".format(name, value)) + print(f"consumer {name} got value {value!r}") # Random sleeps help trigger the problem more reliably await trio.sleep(random.random()) + trio.run(main) diff --git a/docs/source/reference-core/channels-mpmc-fixed.py b/docs/source/reference-core/channels-mpmc-fixed.py index a3e7044fe7..986e0d0c31 100644 --- a/docs/source/reference-core/channels-mpmc-fixed.py +++ b/docs/source/reference-core/channels-mpmc-fixed.py @@ -1,6 +1,7 @@ import trio import random + async def main(): async with trio.open_nursery() as nursery: send_channel, receive_channel = trio.open_memory_channel(0) @@ -12,18 +13,21 @@ async def main(): nursery.start_soon(consumer, "X", receive_channel.clone()) nursery.start_soon(consumer, "Y", receive_channel.clone()) + async def producer(name, send_channel): async with send_channel: for i in range(3): - await send_channel.send("{} from producer {}".format(i, name)) + await send_channel.send(f"{i} from producer {name}") # Random sleeps help trigger the problem more reliably await trio.sleep(random.random()) + async def consumer(name, receive_channel): async with receive_channel: async for value in receive_channel: - print("consumer {} got value {!r}".format(name, value)) + print(f"consumer {name} got value {value!r}") # Random sleeps help trigger the problem more reliably await trio.sleep(random.random()) + trio.run(main) diff --git a/docs/source/reference-core/channels-shutdown.py b/docs/source/reference-core/channels-shutdown.py index dcd35767ae..cba4e43801 100644 --- a/docs/source/reference-core/channels-shutdown.py +++ b/docs/source/reference-core/channels-shutdown.py @@ -1,19 +1,23 @@ import trio + async def main(): async with trio.open_nursery() as nursery: send_channel, receive_channel = trio.open_memory_channel(0) nursery.start_soon(producer, send_channel) 
nursery.start_soon(consumer, receive_channel) + async def producer(send_channel): async with send_channel: for i in range(3): - await send_channel.send("message {}".format(i)) + await send_channel.send(f"message {i}") + async def consumer(receive_channel): async with receive_channel: async for value in receive_channel: - print("got value {!r}".format(value)) + print(f"got value {value!r}") + trio.run(main) diff --git a/docs/source/reference-core/channels-simple.py b/docs/source/reference-core/channels-simple.py index d04ebd722c..e89347379d 100644 --- a/docs/source/reference-core/channels-simple.py +++ b/docs/source/reference-core/channels-simple.py @@ -1,5 +1,6 @@ import trio + async def main(): async with trio.open_nursery() as nursery: # Open a channel: @@ -9,15 +10,18 @@ async def main(): nursery.start_soon(producer, send_channel) nursery.start_soon(consumer, receive_channel) + async def producer(send_channel): # Producer sends 3 messages for i in range(3): # The producer sends using 'await send_channel.send(...)' - await send_channel.send("message {}".format(i)) + await send_channel.send(f"message {i}") + async def consumer(receive_channel): # The consumer uses an 'async for' loop to receive the values: async for value in receive_channel: - print("got value {!r}".format(value)) + print(f"got value {value!r}") + trio.run(main) diff --git a/docs/source/reference-core/contextvar-example.py b/docs/source/reference-core/contextvar-example.py index 7b98355f90..d2caa0d0dc 100644 --- a/docs/source/reference-core/contextvar-example.py +++ b/docs/source/reference-core/contextvar-example.py @@ -10,7 +10,7 @@ def log(msg): # Read from task-local storage: request_tag = request_info.get() - print("request {}: {}".format(request_tag, msg)) + print(f"request {request_tag}: {msg}") # An example "request handler" that does some work itself and also @@ -29,9 +29,9 @@ async def handle_request(tag): async def concurrent_helper(job): - log("Helper task {} started".format(job)) + 
log(f"Helper task {job} started") await trio.sleep(random.random()) - log("Helper task {} finished".format(job)) + log(f"Helper task {job} finished") # Spawn several "request handlers" simultaneously, to simulate a diff --git a/docs/source/reference-core/from-thread-example.py b/docs/source/reference-core/from-thread-example.py index 71a75d67bf..b95a0e00d3 100644 --- a/docs/source/reference-core/from-thread-example.py +++ b/docs/source/reference-core/from-thread-example.py @@ -21,7 +21,7 @@ async def main(): async with trio.open_nursery() as nursery: # In a background thread, run: - # thread_fn(portal, receive_from_trio, send_to_trio) + # thread_fn(receive_from_trio, send_to_trio) nursery.start_soon( trio.to_thread.run_sync, thread_fn, receive_from_trio, send_to_trio ) diff --git a/docs/source/reference-core/thread-contextvars-example.py b/docs/source/reference-core/thread-contextvars-example.py new file mode 100644 index 0000000000..d6c062bea4 --- /dev/null +++ b/docs/source/reference-core/thread-contextvars-example.py @@ -0,0 +1,47 @@ +import contextvars +import time + +import trio + +request_state = contextvars.ContextVar("request_state") + +# Blocking function that should be run on a thread +# It could be reading or writing files, communicating with a database +# with a driver not compatible with async / await, etc. +def work_in_thread(msg): + # Only use request_state.get() inside the worker thread + state_value = request_state.get() + current_user_id = state_value["current_user_id"] + time.sleep(3) # this would be some blocking call, like reading a file + print(f"Processed user {current_user_id} with message {msg} in a thread worker") + # Modify/mutate the state object, without setting the entire + # contextvar with request_state.set() + state_value["msg"] = msg + + +# An example "request handler" that does some work itself and also +# spawns some helper tasks in threads to execute blocking code. 
+async def handle_request(current_user_id): + # Write to task-local storage: + current_state = {"current_user_id": current_user_id, "msg": ""} + request_state.set(current_state) + + # Here the current implicit contextvars context will be automatically copied + # inside the worker thread + await trio.to_thread.run_sync(work_in_thread, f"Hello {current_user_id}") + # Extract the value set inside the thread in the same object stored in a contextvar + new_msg = current_state["msg"] + print( + f"New contextvar value from worker thread for user {current_user_id}: {new_msg}" + ) + + +# Spawn several "request handlers" simultaneously, to simulate a +# busy server handling multiple requests at the same time. +async def main(): + async with trio.open_nursery() as nursery: + for i in range(3): + nursery.start_soon(handle_request, i) + + +trio.run(main) diff --git a/docs/source/reference-hazmat.rst b/docs/source/reference-hazmat.rst deleted file mode 100644 index 7f32ebf985..0000000000 --- a/docs/source/reference-hazmat.rst +++ /dev/null @@ -1,563 +0,0 @@ -======================================================= - Introspecting and extending Trio with ``trio.hazmat`` -======================================================= - -.. module:: trio.hazmat - -.. warning:: - You probably don't want to use this module. - -:mod:`trio.hazmat` is Trio's "hazardous materials" layer: it contains -APIs useful for introspecting and extending Trio. If you're writing -ordinary, everyday code, then you can ignore this module completely. -But sometimes you need something a bit lower level. Here are some -examples of situations where you should reach for :mod:`trio.hazmat`: - -* You want to implement a new :ref:`synchronization primitive - ` that Trio doesn't (yet) provide, like a - reader-writer lock. -* You want to extract low-level metrics to monitor the health of your - application. 
-* You want to add support for a low-level operating system interface - that Trio doesn't (yet) expose, like watching a filesystem directory - for changes. -* You want to implement an interface for calling between Trio and - another event loop within the same process. -* You're writing a debugger and want to visualize Trio's task tree. -* You need to interoperate with a C library whose API exposes raw file - descriptors. - -Using :mod:`trio.hazmat` isn't really *that* hazardous; in fact you're -already using it – it's how most of the functionality described in -previous chapters is implemented. The APIs described here have -strictly defined and carefully documented semantics, and are perfectly -safe – *if* you read carefully and take proper precautions. Some of -those strict semantics have `nasty big pointy teeth -`__. If you make a -mistake, Trio may not be able to handle it gracefully; conventions and -guarantees that are followed strictly in the rest of Trio do not -always apply. Using this module makes it your responsibility to think -through and handle the nasty cases to expose a friendly Trio-style API -to your users. - - -Debugging and instrumentation -============================= - -Trio tries hard to provide useful hooks for debugging and -instrumentation. Some are documented above (the nursery introspection -attributes, :meth:`trio.Lock.statistics`, etc.). Here are some more. - - -Global statistics ------------------ - -.. autofunction:: current_statistics - - -The current clock ------------------ - -.. autofunction:: current_clock - - -.. _instrumentation: - -Instrument API --------------- - -The instrument API provides a standard way to add custom -instrumentation to the run loop. Want to make a histogram of -scheduling latencies, log a stack trace of any task that blocks the -run loop for >50 ms, or measure what percentage of your process's -running time is spent waiting for I/O? This is the place. 
- -The general idea is that at any given moment, :func:`trio.run` -maintains a set of "instruments", which are objects that implement the -:class:`trio.abc.Instrument` interface. When an interesting event -happens, it loops over these instruments and notifies them by calling -an appropriate method. The tutorial has :ref:`a simple example of -using this for tracing `. - -Since this hooks into Trio at a rather low level, you do have to be -careful. The callbacks are run synchronously, and in many cases if -they error out then there isn't any plausible way to propagate this -exception (for instance, we might be deep in the guts of the exception -propagation machinery...). Therefore our `current strategy -`__ for handling -exceptions raised by instruments is to (a) log an exception to the -``"trio.abc.Instrument"`` logger, which by default prints a stack -trace to standard error and (b) disable the offending instrument. - -You can register an initial list of instruments by passing them to -:func:`trio.run`. :func:`add_instrument` and -:func:`remove_instrument` let you add and remove instruments at -runtime. - -.. autofunction:: add_instrument - -.. autofunction:: remove_instrument - -And here's the interface to implement if you want to build your own -:class:`~trio.abc.Instrument`: - -.. autoclass:: trio.abc.Instrument - :members: - -The tutorial has a :ref:`fully-worked example -` of defining a custom instrument to log -Trio's internal scheduling decisions. - - -Low-level I/O primitives -======================== - -Different environments expose different low-level APIs for performing -async I/O. :mod:`trio.hazmat` exposes these APIs in a relatively -direct way, so as to allow maximum power and flexibility for higher -level code. However, this means that the exact API provided may vary -depending on what system Trio is running on. - - -Universally available API -------------------------- - -All environments provide the following functions: - -.. 
function:: wait_readable(obj) - :async: - - Block until the kernel reports that the given object is readable. - - On Unix systems, ``obj`` must either be an integer file descriptor, - or else an object with a ``.fileno()`` method which returns an - integer file descriptor. Any kind of file descriptor can be passed, - though the exact semantics will depend on your kernel. For example, - this probably won't do anything useful for on-disk files. - - On Windows systems, ``obj`` must either be an integer ``SOCKET`` - handle, or else an object with a ``.fileno()`` method which returns - an integer ``SOCKET`` handle. File descriptors aren't supported, - and neither are handles that refer to anything besides a - ``SOCKET``. - - :raises trio.BusyResourceError: - if another task is already waiting for the given socket to - become readable. - :raises trio.ClosedResourceError: - if another task calls :func:`notify_closing` while this - function is still working. - -.. function:: wait_writable(obj) - :async: - - Block until the kernel reports that the given object is writable. - - See `wait_readable` for the definition of ``obj``. - - :raises trio.BusyResourceError: - if another task is already waiting for the given socket to - become writable. - :raises trio.ClosedResourceError: - if another task calls :func:`notify_closing` while this - function is still working. - - -.. function:: notify_closing(obj) - - Call this before closing a file descriptor (on Unix) or socket (on - Windows). This will cause any `wait_readable` or `wait_writable` - calls on the given object to immediately wake up and raise - `~trio.ClosedResourceError`. - - This doesn't actually close the object – you still have to do that - yourself afterwards. Also, you want to be careful to make sure no - new tasks start waiting on the object in between when you call this - and when it's actually closed. So to close something properly, you - usually want to do these steps in order: - - 1. 
Explicitly mark the object as closed, so that any new attempts - to use it will abort before they start. - 2. Call `notify_closing` to wake up any already-existing users. - 3. Actually close the object. - - It's also possible to do them in a different order if that's more - convenient, *but only if* you make sure not to have any checkpoints in - between the steps. This way they all happen in a single atomic - step, so other tasks won't be able to tell what order they happened - in anyway. - - -Unix-specific API ------------------ - -`FdStream` supports wrapping Unix files (such as a pipe or TTY) as -a stream. - -If you have two different file descriptors for sending and receiving, -and want to bundle them together into a single bidirectional -`~trio.abc.Stream`, then use `trio.StapledStream`:: - - bidirectional_stream = trio.StapledStream( - trio.hazmat.FdStream(write_fd), - trio.hazmat.FdStream(read_fd) - ) - -.. autoclass:: FdStream - :show-inheritance: - :members: - - -Kqueue-specific API -------------------- - -TODO: these are implemented, but are currently more of a sketch than -anything real. See `#26 -`__. - -.. function:: current_kqueue() - -.. function:: wait_kevent(ident, filter, abort_func) - :async: - -.. function:: monitor_kevent(ident, filter) - :with: queue - - -Windows-specific API --------------------- - -.. function:: WaitForSingleObject(handle) - :async: - - Async and cancellable variant of `WaitForSingleObject - `__. - Windows only. - - :arg handle: - A Win32 object handle, as a Python integer. - :raises OSError: - If the handle is invalid, e.g. when it is already closed. - - -TODO: these are implemented, but are currently more of a sketch than -anything real. See `#26 -`__ and `#52 -`__. - -.. function:: register_with_iocp(handle) - -.. function:: wait_overlapped(handle, lpOverlapped) - :async: - -.. function:: current_iocp() - -.. 
function:: monitor_completion_key() - :with: queue - - -Global state: system tasks and run-local variables -================================================== - -.. autoclass:: RunVar - -.. autofunction:: spawn_system_task - - -Trio tokens -=========== - -.. autoclass:: TrioToken() - :members: - -.. autofunction:: current_trio_token - - -Safer KeyboardInterrupt handling -================================ - -Trio's handling of control-C is designed to balance usability and -safety. On the one hand, there are sensitive regions (like the core -scheduling loop) where it's simply impossible to handle arbitrary -:exc:`KeyboardInterrupt` exceptions while maintaining our core -correctness invariants. On the other, if the user accidentally writes -an infinite loop, we do want to be able to break out of that. Our -solution is to install a default signal handler which checks whether -it's safe to raise :exc:`KeyboardInterrupt` at the place where the -signal is received. If so, then we do; otherwise, we schedule a -:exc:`KeyboardInterrupt` to be delivered to the main task at the next -available opportunity (similar to how :exc:`~trio.Cancelled` is -delivered). - -So that's great, but – how do we know whether we're in one of the -sensitive parts of the program or not? - -This is determined on a function-by-function basis. By default, a -function is protected if its caller is, and not if its caller isn't; -this is helpful because it means you only need to override the -defaults at places where you transition from protected code to -unprotected code or vice-versa. - -These transitions are accomplished using two function decorators: - -.. function:: disable_ki_protection() - :decorator: - - Decorator that marks the given regular function, generator - function, async function, or async generator function as - unprotected against :exc:`KeyboardInterrupt`, i.e., the code inside - this function *can* be rudely interrupted by - :exc:`KeyboardInterrupt` at any moment. 
- - If you have multiple decorators on the same function, then this - should be at the bottom of the stack (closest to the actual - function). - - An example of where you'd use this is in implementing something - like :func:`trio.from_thread.run`, which uses - :meth:`TrioToken.run_sync_soon` to get into the Trio - thread. :meth:`~TrioToken.run_sync_soon` callbacks are run with - :exc:`KeyboardInterrupt` protection enabled, and - :func:`trio.from_thread.run` takes advantage of this to safely set up - the machinery for sending a response back to the original thread, but - then uses :func:`disable_ki_protection` when entering the - user-provided function. - -.. function:: enable_ki_protection() - :decorator: - - Decorator that marks the given regular function, generator - function, async function, or async generator function as protected - against :exc:`KeyboardInterrupt`, i.e., the code inside this - function *won't* be rudely interrupted by - :exc:`KeyboardInterrupt`. (Though if it contains any - :ref:`checkpoints `, then it can still receive - :exc:`KeyboardInterrupt` at those. This is considered a polite - interruption.) - - .. warning:: - - Be very careful to only use this decorator on functions that you - know will either exit in bounded time, or else pass through a - checkpoint regularly. (Of course all of your functions should - have this property, but if you mess it up here then you won't - even be able to use control-C to escape!) - - If you have multiple decorators on the same function, then this - should be at the bottom of the stack (closest to the actual - function). - - An example of where you'd use this is on the ``__exit__`` - implementation for something like a :class:`~trio.Lock`, where a - poorly-timed :exc:`KeyboardInterrupt` could leave the lock in an - inconsistent state and cause a deadlock. - -.. autofunction:: currently_ki_protected - - -Sleeping and waking -=================== - -Wait queue abstraction ----------------------- - -.. 
autoclass:: ParkingLot - :members: - :undoc-members: - - -Low-level checkpoint functions ------------------------------- - -.. autofunction:: checkpoint - -The next two functions are used *together* to make up a checkpoint: - -.. autofunction:: checkpoint_if_cancelled -.. autofunction:: cancel_shielded_checkpoint - -These are commonly used in cases where you have an operation that -might-or-might-not block, and you want to implement Trio's standard -checkpoint semantics. Example:: - - async def operation_that_maybe_blocks(): - await checkpoint_if_cancelled() - try: - ret = attempt_operation() - except BlockingIOError: - # need to block and then retry, which we do below - pass - else: - # operation succeeded, finish the checkpoint then return - await cancel_shielded_checkpoint() - return ret - while True: - await wait_for_operation_to_be_ready() - try: - return attempt_operation() - except BlockingIOError: - pass - -This logic is a bit convoluted, but accomplishes all of the following: - -* Every successful execution path passes through a checkpoint (assuming that - ``wait_for_operation_to_be_ready`` is an unconditional checkpoint) - -* Our :ref:`cancellation semantics ` say that - :exc:`~trio.Cancelled` should only be raised if the operation didn't - happen. Using :func:`cancel_shielded_checkpoint` on the early-exit - branch accomplishes this. - -* On the path where we do end up blocking, we don't pass through any - schedule points before that, which avoids some unnecessary work. - -* Avoids implicitly chaining the :exc:`BlockingIOError` with any - errors raised by ``attempt_operation`` or - ``wait_for_operation_to_be_ready``, by keeping the ``while True:`` - loop outside of the ``except BlockingIOError:`` block. - -These functions can also be useful in other situations. 
For example, -when :func:`trio.to_thread.run_sync` schedules some work to run in a -worker thread, it blocks until the work is finished (so it's a -schedule point), but by default it doesn't allow cancellation. So to -make sure that the call always acts as a checkpoint, it calls -:func:`checkpoint_if_cancelled` before starting the thread. - - -Low-level blocking ------------------- - -.. autofunction:: wait_task_rescheduled -.. autoclass:: Abort -.. autofunction:: reschedule - -Here's an example lock class implemented using -:func:`wait_task_rescheduled` directly. This implementation has a -number of flaws, including lack of fairness, O(n) cancellation, -missing error checking, failure to insert a checkpoint on the -non-blocking path, etc. If you really want to implement your own lock, -then you should study the implementation of :class:`trio.Lock` and use -:class:`ParkingLot`, which handles some of these issues for you. But -this does serve to illustrate the basic structure of the -:func:`wait_task_rescheduled` API:: - - class NotVeryGoodLock: - def __init__(self): - self._blocked_tasks = collections.deque() - self._held = False - - async def acquire(self): - while self._held: - task = trio.current_task() - self._blocked_tasks.append(task) - def abort_fn(_): - self._blocked_tasks.remove(task) - return trio.hazmat.Abort.SUCCEEDED - await trio.hazmat.wait_task_rescheduled(abort_fn) - self._held = True - - def release(self): - self._held = False - if self._blocked_tasks: - woken_task = self._blocked_tasks.popleft() - trio.hazmat.reschedule(woken_task) - - -Task API --------- - -.. autofunction:: current_root_task() - -.. autofunction:: current_task() - -.. class:: Task() - - A :class:`Task` object represents a concurrent "thread" of - execution. It has no public constructor; Trio internally creates a - :class:`Task` object for each call to ``nursery.start(...)`` or - ``nursery.start_soon(...)``. 
- - Its public members are mostly useful for introspection and - debugging: - - .. attribute:: name - - String containing this :class:`Task`\'s name. Usually the name - of the function this :class:`Task` is running, but can be - overridden by passing ``name=`` to ``start`` or ``start_soon``. - - .. attribute:: coro - - This task's coroutine object. Example usage: extracting a stack - trace:: - - import traceback - - def walk_coro_stack(coro): - while coro is not None: - if hasattr(coro, "cr_frame"): - # A real coroutine - yield coro.cr_frame, coro.cr_frame.f_lineno - coro = coro.cr_await - else: - # A generator decorated with @types.coroutine - yield coro.gi_frame, coro.gi_frame.f_lineno - coro = coro.gi_yieldfrom - - def print_stack_for_task(task): - ss = traceback.StackSummary.extract(walk_coro_stack(task.coro)) - print("".join(ss.format())) - - .. attribute:: context - - This task's :class:`contextvars.Context` object. - - .. autoattribute:: parent_nursery - - .. autoattribute:: child_nurseries - - .. attribute:: custom_sleep_data - - Trio doesn't assign this variable any meaning, except that it - sets it to ``None`` whenever a task is rescheduled. It can be - used to share data between the different tasks involved in - putting a task to sleep and then waking it up again. (See - :func:`wait_task_rescheduled` for details.) - - -.. _live-coroutine-handoff: - -Handing off live coroutine objects between coroutine runners ------------------------------------------------------------- - -Internally, Python's async/await syntax is built around the idea of -"coroutine objects" and "coroutine runners". A coroutine object -represents the state of an async callstack. But by itself, this is -just a static object that sits there. If you want it to do anything, -you need a coroutine runner to push it forward. Every Trio task has an -associated coroutine object (see :data:`Task.coro`), and the Trio -scheduler acts as their coroutine runner. 
- -But of course, Trio isn't the only coroutine runner in Python – -:mod:`asyncio` has one, other event loops have them, you can even -define your own. - -And in some very, very unusual circumstances, it even makes sense to -transfer a single coroutine object back and forth between different -coroutine runners. That's what this section is about. This is an -*extremely* exotic use case, and assumes a lot of expertise in how -Python async/await works internally. For motivating examples, see -`trio-asyncio issue #42 -`__, and `trio -issue #649 `__. For -more details on how coroutines work, we recommend André Caron's `A -tale of event loops -`__, or -going straight to `PEP 492 -`__ for the full details. - -.. autofunction:: permanently_detach_coroutine_object - -.. autofunction:: temporarily_detach_coroutine_object - -.. autofunction:: reattach_detached_coroutine_object diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 7c700b1328..9207afb41b 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -237,8 +237,7 @@ other constants and functions in the :mod:`ssl` module. .. warning:: Avoid instantiating :class:`ssl.SSLContext` directly. A newly constructed :class:`~ssl.SSLContext` has less secure - defaults than one returned by :func:`ssl.create_default_context`, - dramatically so before Python 3.6. + defaults than one returned by :func:`ssl.create_default_context`. Instead of using :meth:`ssl.SSLContext.wrap_socket`, you create a :class:`SSLStream`: @@ -259,6 +258,55 @@ you call them before the handshake completes: .. autoexception:: NeedHandshakeError +Datagram TLS support +~~~~~~~~~~~~~~~~~~~~ + +Trio also has support for Datagram TLS (DTLS), which is like TLS but +for unreliable UDP connections. This can be useful for applications +where TCP's reliable in-order delivery is problematic, like +teleconferencing, latency-sensitive games, and VPNs. + +Currently, using DTLS with Trio requires PyOpenSSL. 
We hope to +eventually allow the use of the stdlib `ssl` module as well, but +unfortunately that's not yet possible. + +.. warning:: Note that PyOpenSSL is in many ways lower-level than the + `ssl` module – in particular, it currently **HAS NO BUILT-IN + MECHANISM TO VALIDATE CERTIFICATES**. We *strongly* recommend that + you use the `service-identity + `__ library to validate + hostnames and certificates. + +.. autoclass:: DTLSEndpoint + + .. automethod:: connect + + .. automethod:: serve + + .. automethod:: close + +.. autoclass:: DTLSChannel + :show-inheritance: + + .. automethod:: do_handshake + + .. automethod:: send + + .. automethod:: receive + + .. automethod:: close + + .. automethod:: aclose + + .. automethod:: set_ciphertext_mtu + + .. automethod:: get_cleartext_mtu + + .. automethod:: statistics + +.. autoclass:: DTLSChannelStatistics + :members: + .. module:: trio.socket Low-level networking with :mod:`trio.socket` @@ -301,7 +349,7 @@ library socket into a Trio socket: .. autofunction:: from_stdlib_socket -Unlike :func:`socket.socket`, :func:`trio.socket.socket` is a +Unlike :class:`socket.socket`, :func:`trio.socket.socket` is a function, not a class; if you want to check whether an object is a Trio socket, use ``isinstance(obj, trio.socket.SocketType)``. @@ -380,7 +428,7 @@ Socket objects additional error checking. In addition, the following methods are similar to the equivalents - in :func:`socket.socket`, but have some Trio-specific quirks: + in :class:`socket.socket`, but have some Trio-specific quirks: .. method:: connect :async: @@ -421,7 +469,7 @@ Socket objects False otherwise. 
The following methods are identical to their equivalents in - :func:`socket.socket`, except async, and the ones that take address + :class:`socket.socket`, except async, and the ones that take address arguments require pre-resolved addresses: * :meth:`~socket.socket.accept` @@ -437,7 +485,7 @@ Socket objects * :meth:`~socket.socket.sendmsg` (if available) All methods and attributes *not* mentioned above are identical to - their equivalents in :func:`socket.socket`: + their equivalents in :class:`socket.socket`: * :attr:`~socket.socket.family` * :attr:`~socket.socket.type` @@ -456,6 +504,14 @@ Socket objects * :meth:`~socket.socket.set_inheritable` * :meth:`~socket.socket.get_inheritable` +The internal SocketType +~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: _SocketType +.. + TODO: adding `:members:` here gives error due to overload+_wraps on `sendto` + TODO: rewrite ... all of the above when fixing _SocketType vs SocketType + + .. currentmodule:: trio @@ -472,8 +528,7 @@ people switch to async I/O, and then they're surprised and confused when they find it doesn't speed up their program. The next section explains the theory behind async file I/O, to help you better understand your code's behavior. Or, if you just want to get started, -you can `jump down to the API overview -`__. +you can :ref:`jump down to the API overview `. Background: Why is async file I/O useful? The answer may surprise you @@ -590,9 +645,11 @@ Asynchronous path objects Asynchronous file objects ~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: open_file +.. Suppress type annotations here, they refer to lots of internal types. + The normal Python docs go into better detail. +.. autofunction:: open_file(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=None, opener=None) -.. autofunction:: wrap_file +.. autofunction:: wrap_file(file) .. 
interface:: Asynchronous file interface @@ -665,22 +722,43 @@ Spawning subprocesses Trio provides support for spawning other programs as subprocesses, communicating with them via pipes, sending them signals, and waiting -for them to exit. The interface for doing so consists of two layers: +for them to exit. + +Most of the time, this is done through our high-level interface, +`trio.run_process`. It lets you either run a process to completion +while optionally capturing the output, or else run it in a background +task and interact with it while it's running: + +.. autofunction:: trio.run_process + +.. autoclass:: trio.Process + + .. autoattribute:: returncode + + .. automethod:: wait + + .. automethod:: poll -* :func:`trio.run_process` runs a process from start to - finish and returns a :class:`~subprocess.CompletedProcess` object describing - its outputs and return value. This is what you should reach for if you - want to run a process to completion before continuing, while possibly - sending it some input or capturing its output. It is modelled after - the standard :func:`subprocess.run` with some additional features - and safer defaults. + .. automethod:: kill -* `trio.open_process` starts a process in the background and returns a - `Process` object to let you interact with it. Using it requires a - bit more code than `run_process`, but exposes additional - capabilities: back-and-forth communication, processing output as - soon as it is generated, and so forth. It is modelled after the - standard library :class:`subprocess.Popen`. + .. automethod:: terminate + + .. automethod:: send_signal + + .. note:: :meth:`~subprocess.Popen.communicate` is not provided as a + method on :class:`~trio.Process` objects; call :func:`~trio.run_process` + normally for simple capturing, or write the loop yourself if you + have unusual needs. 
:meth:`~subprocess.Popen.communicate` has + quite unusual cancellation behavior in the standard library (on + some platforms it spawns a background thread which continues to + read from the child process even after the timeout has expired) + and we wanted to provide an interface with fewer surprises. + +If `trio.run_process` is too limiting, we also offer a low-level API, +`trio.lowlevel.open_process`. For example, if you want to spawn a +child process that will outlive the parent process and be +orphaned, then `~trio.run_process` can't do that, but +`~trio.lowlevel.open_process` can. .. _subprocess-options: @@ -702,62 +780,10 @@ subprocess`` in order to access constants such as ``PIPE`` or Currently, Trio always uses unbuffered byte streams for communicating with a process, so it does not support the ``encoding``, ``errors``, -``universal_newlines`` (alias ``text`` in 3.7+), and ``bufsize`` +``universal_newlines`` (alias ``text``), and ``bufsize`` options. -Running a process and waiting for it to finish -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The basic interface for running a subprocess start-to-finish is -:func:`trio.run_process`. It always waits for the subprocess to exit -before returning, so there's no need to worry about leaving a process -running by mistake after you've gone on to do other things. -:func:`~trio.run_process` is similar to the standard library -:func:`subprocess.run` function, but tries to have safer defaults: -with no options, the subprocess's input is empty rather than coming -from the user's terminal, and a failure in the subprocess will be -propagated as a :exc:`subprocess.CalledProcessError` exception. Of -course, these defaults can be changed where necessary. - -.. 
autofunction:: trio.run_process - - -Interacting with a process as it runs -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If you want more control than :func:`~trio.run_process` affords, you -can use `trio.open_process` to spawn a subprocess, and then interact -with it using the `Process` interface. - -.. autofunction:: trio.open_process - -.. autoclass:: trio.Process - - .. autoattribute:: returncode - - .. automethod:: aclose - - .. automethod:: wait - - .. automethod:: poll - - .. automethod:: kill - - .. automethod:: terminate - - .. automethod:: send_signal - - .. note:: :meth:`~subprocess.Popen.communicate` is not provided as a - method on :class:`~trio.Process` objects; use :func:`~trio.run_process` - instead, or write the loop yourself if you have unusual - needs. :meth:`~subprocess.Popen.communicate` has quite unusual - cancellation behavior in the standard library (on some platforms it - spawns a background thread which continues to read from the child - process even after the timeout has expired) and we wanted to - provide an interface with fewer surprises. - - .. _subprocess-quoting: Quoting: more than you wanted to know diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst new file mode 100644 index 0000000000..bacebff5ad --- /dev/null +++ b/docs/source/reference-lowlevel.rst @@ -0,0 +1,973 @@ +========================================================= + Introspecting and extending Trio with ``trio.lowlevel`` +========================================================= + +.. module:: trio.lowlevel + +:mod:`trio.lowlevel` contains low-level APIs for introspecting and +extending Trio. If you're writing ordinary, everyday code, then you +can ignore this module completely. But sometimes you need something a +bit lower level. 
Here are some examples of situations where you should +reach for :mod:`trio.lowlevel`: + +* You want to implement a new :ref:`synchronization primitive + ` that Trio doesn't (yet) provide, like a + reader-writer lock. +* You want to extract low-level metrics to monitor the health of your + application. +* You want to use a low-level operating system interface that Trio + doesn't (yet) provide its own wrappers for, like watching a + filesystem directory for changes. +* You want to implement an interface for calling between Trio and + another event loop within the same process. +* You're writing a debugger and want to visualize Trio's task tree. +* You need to interoperate with a C library whose API exposes raw file + descriptors. + +You don't need to be scared of :mod:`trio.lowlevel`, as long as you +take proper precautions. These are real public APIs, with strictly +defined and carefully documented semantics. They're the same tools we +use to implement all the nice high-level APIs in the :mod:`trio` +namespace. But, be careful. Some of those strict semantics have `nasty +big pointy teeth +`__. If you make a +mistake, Trio may not be able to handle it gracefully; conventions and +guarantees that are followed strictly in the rest of Trio do not +always apply. When you use this module, it's your job to think about +how you're going to handle the tricky cases so you can expose a +friendly Trio-style API to your users. + + +Debugging and instrumentation +============================= + +Trio tries hard to provide useful hooks for debugging and +instrumentation. Some are documented above (the nursery introspection +attributes, :meth:`trio.Lock.statistics`, etc.). Here are some more. + + +Global statistics +----------------- + +.. autofunction:: current_statistics + + +The current clock +----------------- + +.. autofunction:: current_clock + + +.. 
_instrumentation: + +Instrument API +-------------- + +The instrument API provides a standard way to add custom +instrumentation to the run loop. Want to make a histogram of +scheduling latencies, log a stack trace of any task that blocks the +run loop for >50 ms, or measure what percentage of your process's +running time is spent waiting for I/O? This is the place. + +The general idea is that at any given moment, :func:`trio.run` +maintains a set of "instruments", which are objects that implement the +:class:`trio.abc.Instrument` interface. When an interesting event +happens, it loops over these instruments and notifies them by calling +an appropriate method. The tutorial has :ref:`a simple example of +using this for tracing `. + +Since this hooks into Trio at a rather low level, you do have to be +careful. The callbacks are run synchronously, and in many cases if +they error out then there isn't any plausible way to propagate this +exception (for instance, we might be deep in the guts of the exception +propagation machinery...). Therefore our `current strategy +`__ for handling +exceptions raised by instruments is to (a) log an exception to the +``"trio.abc.Instrument"`` logger, which by default prints a stack +trace to standard error and (b) disable the offending instrument. + +You can register an initial list of instruments by passing them to +:func:`trio.run`. :func:`add_instrument` and +:func:`remove_instrument` let you add and remove instruments at +runtime. + +.. autofunction:: add_instrument + +.. autofunction:: remove_instrument + +And here's the interface to implement if you want to build your own +:class:`~trio.abc.Instrument`: + +.. autoclass:: trio.abc.Instrument + :members: + +The tutorial has a :ref:`fully-worked example +` of defining a custom instrument to log +Trio's internal scheduling decisions. + + +Low-level process spawning +========================== + +.. 
autofunction:: trio.lowlevel.open_process + + +Low-level I/O primitives +======================== + +Different environments expose different low-level APIs for performing +async I/O. :mod:`trio.lowlevel` exposes these APIs in a relatively +direct way, so as to allow maximum power and flexibility for higher +level code. However, this means that the exact API provided may vary +depending on what system Trio is running on. + + +Universally available API +------------------------- + +All environments provide the following functions: + +.. function:: wait_readable(obj) + :async: + + Block until the kernel reports that the given object is readable. + + On Unix systems, ``obj`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``obj`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + +.. function:: wait_writable(obj) + :async: + + Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``obj``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + + +.. 
function:: notify_closing(obj) + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. + + +Unix-specific API +----------------- + +`FdStream` supports wrapping Unix files (such as a pipe or TTY) as +a stream. + +If you have two different file descriptors for sending and receiving, +and want to bundle them together into a single bidirectional +`~trio.abc.Stream`, then use `trio.StapledStream`:: + + bidirectional_stream = trio.StapledStream( + trio.lowlevel.FdStream(write_fd), + trio.lowlevel.FdStream(read_fd) + ) + +.. autoclass:: FdStream + :show-inheritance: + :members: + + +Kqueue-specific API +------------------- + +TODO: these are implemented, but are currently more of a sketch than +anything real. See `#26 +`__. + +.. function:: current_kqueue() + +.. function:: wait_kevent(ident, filter, abort_func) + :async: + +.. function:: monitor_kevent(ident, filter) + :with: queue + + +Windows-specific API +-------------------- + +.. 
function:: WaitForSingleObject(handle) + :async: + + Async and cancellable variant of `WaitForSingleObject + `__. + Windows only. + + :arg handle: + A Win32 object handle, as a Python integer. + :raises OSError: + If the handle is invalid, e.g. when it is already closed. + + +TODO: these are implemented, but are currently more of a sketch than +anything real. See `#26 +`__ and `#52 +`__. + +.. function:: register_with_iocp(handle) + +.. function:: wait_overlapped(handle, lpOverlapped) + :async: + +.. function:: current_iocp() + +.. function:: monitor_completion_key() + :with: queue + + +Global state: system tasks and run-local variables +================================================== + +.. autoclass:: RunVar + +.. autofunction:: spawn_system_task + + +Trio tokens +=========== + +.. autoclass:: TrioToken() + :members: + +.. autofunction:: current_trio_token + + +Spawning threads +================ + +.. autofunction:: start_thread_soon + + +Safer KeyboardInterrupt handling +================================ + +Trio's handling of control-C is designed to balance usability and +safety. On the one hand, there are sensitive regions (like the core +scheduling loop) where it's simply impossible to handle arbitrary +:exc:`KeyboardInterrupt` exceptions while maintaining our core +correctness invariants. On the other, if the user accidentally writes +an infinite loop, we do want to be able to break out of that. Our +solution is to install a default signal handler which checks whether +it's safe to raise :exc:`KeyboardInterrupt` at the place where the +signal is received. If so, then we do; otherwise, we schedule a +:exc:`KeyboardInterrupt` to be delivered to the main task at the next +available opportunity (similar to how :exc:`~trio.Cancelled` is +delivered). + +So that's great, but – how do we know whether we're in one of the +sensitive parts of the program or not? + +This is determined on a function-by-function basis. 
By default: + +- The top-level function in regular user tasks is unprotected. +- The top-level function in system tasks is protected. +- If a function doesn't specify otherwise, then it inherits the + protection state of its caller. + +This means you only need to override the defaults at places where you +transition from protected code to unprotected code or vice-versa. + +These transitions are accomplished using two function decorators: + +.. function:: disable_ki_protection() + :decorator: + + Decorator that marks the given regular function, generator + function, async function, or async generator function as + unprotected against :exc:`KeyboardInterrupt`, i.e., the code inside + this function *can* be rudely interrupted by + :exc:`KeyboardInterrupt` at any moment. + + If you have multiple decorators on the same function, then this + should be at the bottom of the stack (closest to the actual + function). + + An example of where you'd use this is in implementing something + like :func:`trio.from_thread.run`, which uses + :meth:`TrioToken.run_sync_soon` to get into the Trio + thread. :meth:`~TrioToken.run_sync_soon` callbacks are run with + :exc:`KeyboardInterrupt` protection enabled, and + :func:`trio.from_thread.run` takes advantage of this to safely set up + the machinery for sending a response back to the original thread, but + then uses :func:`disable_ki_protection` when entering the + user-provided function. + +.. function:: enable_ki_protection() + :decorator: + + Decorator that marks the given regular function, generator + function, async function, or async generator function as protected + against :exc:`KeyboardInterrupt`, i.e., the code inside this + function *won't* be rudely interrupted by + :exc:`KeyboardInterrupt`. (Though if it contains any + :ref:`checkpoints `, then it can still receive + :exc:`KeyboardInterrupt` at those. This is considered a polite + interruption.) + + .. 
warning:: + + Be very careful to only use this decorator on functions that you + know will either exit in bounded time, or else pass through a + checkpoint regularly. (Of course all of your functions should + have this property, but if you mess it up here then you won't + even be able to use control-C to escape!) + + If you have multiple decorators on the same function, then this + should be at the bottom of the stack (closest to the actual + function). + + An example of where you'd use this is on the ``__exit__`` + implementation for something like a :class:`~trio.Lock`, where a + poorly-timed :exc:`KeyboardInterrupt` could leave the lock in an + inconsistent state and cause a deadlock. + +.. autofunction:: currently_ki_protected + + +Sleeping and waking +=================== + +Wait queue abstraction +---------------------- + +.. autoclass:: ParkingLot + :members: + :undoc-members: + +.. autoclass:: ParkingLotStatistics + :members: + +Low-level checkpoint functions +------------------------------ + +.. autofunction:: checkpoint + +The next two functions are used *together* to make up a checkpoint: + +.. autofunction:: checkpoint_if_cancelled +.. autofunction:: cancel_shielded_checkpoint + +These are commonly used in cases where you have an operation that +might-or-might-not block, and you want to implement Trio's standard +checkpoint semantics. 
Example:: + + async def operation_that_maybe_blocks(): + await checkpoint_if_cancelled() + try: + ret = attempt_operation() + except BlockingIOError: + # need to block and then retry, which we do below + pass + else: + # operation succeeded, finish the checkpoint then return + await cancel_shielded_checkpoint() + return ret + while True: + await wait_for_operation_to_be_ready() + try: + return attempt_operation() + except BlockingIOError: + pass + +This logic is a bit convoluted, but accomplishes all of the following: + +* Every successful execution path passes through a checkpoint (assuming that + ``wait_for_operation_to_be_ready`` is an unconditional checkpoint) + +* Our :ref:`cancellation semantics ` say that + :exc:`~trio.Cancelled` should only be raised if the operation didn't + happen. Using :func:`cancel_shielded_checkpoint` on the early-exit + branch accomplishes this. + +* On the path where we do end up blocking, we don't pass through any + schedule points before that, which avoids some unnecessary work. + +* Avoids implicitly chaining the :exc:`BlockingIOError` with any + errors raised by ``attempt_operation`` or + ``wait_for_operation_to_be_ready``, by keeping the ``while True:`` + loop outside of the ``except BlockingIOError:`` block. + +These functions can also be useful in other situations. For example, +when :func:`trio.to_thread.run_sync` schedules some work to run in a +worker thread, it blocks until the work is finished (so it's a +schedule point), but by default it doesn't allow cancellation. So to +make sure that the call always acts as a checkpoint, it calls +:func:`checkpoint_if_cancelled` before starting the thread. + + +Low-level blocking +------------------ + +.. autofunction:: wait_task_rescheduled +.. autoclass:: Abort +.. autofunction:: reschedule + +Here's an example lock class implemented using +:func:`wait_task_rescheduled` directly. 
This implementation has a +number of flaws, including lack of fairness, O(n) cancellation, +missing error checking, failure to insert a checkpoint on the +non-blocking path, etc. If you really want to implement your own lock, +then you should study the implementation of :class:`trio.Lock` and use +:class:`ParkingLot`, which handles some of these issues for you. But +this does serve to illustrate the basic structure of the +:func:`wait_task_rescheduled` API:: + + class NotVeryGoodLock: + def __init__(self): + self._blocked_tasks = collections.deque() + self._held = False + + async def acquire(self): + # We might have to try several times to acquire the lock. + while self._held: + # Someone else has the lock, so we have to wait. + task = trio.lowlevel.current_task() + self._blocked_tasks.append(task) + def abort_fn(_): + self._blocked_tasks.remove(task) + return trio.lowlevel.Abort.SUCCEEDED + await trio.lowlevel.wait_task_rescheduled(abort_fn) + # At this point the lock was released -- but someone else + # might have swooped in and taken it again before we + # woke up. So we loop around to check the 'while' condition + # again. + # if we reach this point, it means that the 'while' condition + # has just failed, so we know no-one is holding the lock, and + # we can take it. + self._held = True + + def release(self): + self._held = False + if self._blocked_tasks: + woken_task = self._blocked_tasks.popleft() + trio.lowlevel.reschedule(woken_task) + + +Task API +======== + +.. autofunction:: current_root_task() + +.. autofunction:: current_task() + +.. class:: Task() + + A :class:`Task` object represents a concurrent "thread" of + execution. It has no public constructor; Trio internally creates a + :class:`Task` object for each call to ``nursery.start(...)`` or + ``nursery.start_soon(...)``. + + Its public members are mostly useful for introspection and + debugging: + + .. attribute:: name + + String containing this :class:`Task`\'s name. 
Usually the name + of the function this :class:`Task` is running, but can be + overridden by passing ``name=`` to ``start`` or ``start_soon``. + + .. attribute:: coro + + This task's coroutine object. + + .. automethod:: iter_await_frames + + .. attribute:: context + + This task's :class:`contextvars.Context` object. + + .. autoattribute:: parent_nursery + + .. autoattribute:: eventual_parent_nursery + + .. autoattribute:: child_nurseries + + .. attribute:: custom_sleep_data + + Trio doesn't assign this variable any meaning, except that it + sets it to ``None`` whenever a task is rescheduled. It can be + used to share data between the different tasks involved in + putting a task to sleep and then waking it up again. (See + :func:`wait_task_rescheduled` for details.) + + +.. _guest-mode: + +Using "guest mode" to run Trio on top of other event loops +========================================================== + +What is "guest mode"? +--------------------- + +An event loop acts as a central coordinator to manage all the IO +happening in your program. Normally, that means that your application +has to pick one event loop, and use it for everything. But what if you +like Trio, but also need to use a framework like `Qt +`__ or `PyGame +`__ that has its own event loop? Then you +need some way to run both event loops at once. + +It is possible to combine event loops, but the standard approaches all +have significant downsides: + +- **Polling:** this is where you use a `busy-loop + `__ to manually check + for IO on both event loops many times per second. This adds latency, + and wastes CPU time and electricity. + +- **Pluggable IO backends:** this is where you reimplement one of the + event loop APIs on top of the other, so you effectively end up with + just one event loop. 
This requires a significant amount of work for + each pair of event loops you want to integrate, and different + backends inevitably end up with inconsistent behavior, forcing users + to program against the least-common-denominator. And if the two + event loops expose different feature sets, it may not even be + possible to implement one in terms of the other. + +- **Running the two event loops in separate threads:** This works, but + most event loop APIs aren't thread-safe, so in this approach you + need to keep careful track of which code runs on which event loop, + and remember to use explicit inter-thread messaging whenever you + interact with the other loop – or else risk obscure race conditions + and data corruption. + +That's why Trio offers a fourth option: **guest mode**. Guest mode +lets you execute `trio.run` on top of some other "host" event loop, +like Qt. Its advantages are: + +- Efficiency: guest mode is event-driven instead of using a busy-loop, + so it has low latency and doesn't waste electricity. + +- No need to think about threads: your Trio code runs in the same + thread as the host event loop, so you can freely call sync Trio APIs + from the host, and call sync host APIs from Trio. For example, if + you're making a GUI app with Qt as the host loop, then making a + `cancel button `__ and + connecting it to a `trio.CancelScope` is as easy as writing:: + + # Trio code can create Qt objects without any special ceremony... + my_cancel_button = QPushButton("Cancel") + # ...and Qt can call back to Trio just as easily + my_cancel_button.clicked.connect(my_cancel_scope.cancel) + + (For async APIs, it's not that simple, but you can use sync APIs to + build explicit bridges between the two worlds, e.g. by passing async + functions and their results back and forth through queues.) + +- Consistent behavior: guest mode uses the same code as regular Trio: + the same scheduler, same IO code, same everything. 
So you get the + full feature set and everything acts the way you expect. + +- Simple integration and broad compatibility: pretty much every event + loop offers some threadsafe "schedule a callback" operation, and + that's all you need to use it as a host loop. + + +Really? How is that possible? +----------------------------- + +.. note:: + + You can use guest mode without reading this section. It's included + for those who enjoy understanding how things work. + +All event loops have the same basic structure. They loop through two +operations, over and over: + +1. Wait for the operating system to notify them that something + interesting has happened, like data arriving on a socket or a + timeout passing. They do this by invoking a platform-specific + ``sleep_until_something_happens()`` system call – ``select``, + ``epoll``, ``kqueue``, ``GetQueuedCompletionEvents``, etc. + +2. Run all the user tasks that care about whatever happened, then go + back to step 1. + +The problem here is step 1. Two different event loops on the same +thread can take turns running user tasks in step 2, but when they're +idle and nothing is happening, they can't both invoke their own +``sleep_until_something_happens()`` function at the same time. + +The "polling" and "pluggable backend" strategies solve this by hacking +the loops so both step 1s can run at the same time in the same thread. +Keeping everything in one thread is great for step 2, but the step 1 +hacks create problems. + +The "separate threads" strategy solves this by moving both steps into +separate threads. This makes step 1 work, but the downside is that now +the user tasks in step 2 are running separate threads as well, so +users are forced to deal with inter-thread coordination. + +The idea behind guest mode is to combine the best parts of each +approach: we move Trio's step 1 into a separate worker thread, while +keeping Trio's step 2 in the main host thread. 
This way, when the +application is idle, both event loops do their +``sleep_until_something_happens()`` at the same time in their own +threads. But when the app wakes up and your code is actually running, +it all happens in a single thread. The threading trickiness is all +handled transparently inside Trio. + +Concretely, we unroll Trio's internal event loop into a chain of +callbacks, and as each callback finishes, it schedules the next +callback onto the host loop or a worker thread as appropriate. So the +only thing the host loop has to provide is a way to schedule a +callback onto the main thread from a worker thread. + +Coordinating between Trio and the host loop does add some overhead. +The main cost is switching in and out of the background thread, since +this requires cross-thread messaging. This is cheap (on the order of a +few microseconds, assuming your host loop is implemented efficiently), +but it's not free. + +But, there's a nice optimization we can make: we only *need* the +thread when our ``sleep_until_something_happens()`` call actually +sleeps, that is, when the Trio part of your program is idle and has +nothing to do. So before we switch into the worker thread, we +double-check whether we're idle, and if not, then we skip the worker +thread and jump directly to step 2. This means that your app only pays +the extra thread-switching penalty at moments when it would otherwise +be sleeping, so it should have minimal effect on your app's overall +performance. + +The total overhead will depend on your host loop, your platform, your +application, etc. But we expect that in most cases, apps running in +guest mode should only be 5-10% slower than the same code using +`trio.run`. If you find that's not true for your app, then please let +us know and we'll see if we can fix it! + + +.. 
_guest-run-implementation: + +Implementing guest mode for your favorite event loop +---------------------------------------------------- + +Let's walk through what you need to do to integrate Trio's guest mode +with your favorite event loop. Treat this section like a checklist. + +**Getting started:** The first step is to get something basic working. +Here's a minimal example of running Trio on top of asyncio, that you +can use as a model:: + + import asyncio, trio + + # A tiny Trio program + async def trio_main(): + for _ in range(5): + print("Hello from Trio!") + # This is inside Trio, so we have to use Trio APIs + await trio.sleep(1) + return "trio done!" + + # The code to run it as a guest inside asyncio + async def asyncio_main(): + asyncio_loop = asyncio.get_running_loop() + + def run_sync_soon_threadsafe(fn): + asyncio_loop.call_soon_threadsafe(fn) + + def done_callback(trio_main_outcome): + print(f"Trio program ended with: {trio_main_outcome}") + + # This is where the magic happens: + trio.lowlevel.start_guest_run( + trio_main, + run_sync_soon_threadsafe=run_sync_soon_threadsafe, + done_callback=done_callback, + ) + + # Let the host loop run for a while to give trio_main time to + # finish. (WARNING: This is a hack. See below for better + # approaches.) + # + # This function is in asyncio, so we have to use asyncio APIs. + await asyncio.sleep(10) + + asyncio.run(asyncio_main()) + +You can see we're using asyncio-specific APIs to start up a loop, and +then we call `trio.lowlevel.start_guest_run`. This function is very +similar to `trio.run`, and takes all the same arguments. But it has +two differences: + +First, instead of blocking until ``trio_main`` has finished, it +schedules ``trio_main`` to start running on top of the host loop, and +then returns immediately. So ``trio_main`` is running in the +background – that's why we have to sleep and give it time to finish. 
+ +And second, it requires two extra keyword arguments: +``run_sync_soon_threadsafe``, and ``done_callback``. + +For ``run_sync_soon_threadsafe``, we need a function that takes a +synchronous callback, and schedules it to run on your host loop. And +this function needs to be "threadsafe" in the sense that you can +safely call it from any thread. So you need to figure out how to write +a function that does that using your host loop's API. For asyncio, +this is easy because `~asyncio.loop.call_soon_threadsafe` does exactly +what we need; for your loop, it might be more or less complicated. + +For ``done_callback``, you pass in a function that Trio will +automatically invoke when the Trio run finishes, so you know it's done +and what happened. For this basic starting version, we just print the +result; in the next section we'll discuss better alternatives. + +At this stage you should be able to run a simple Trio program inside +your host loop. Now we'll turn that prototype into something solid. + + +**Loop lifetimes:** One of the trickiest things in most event loops is +shutting down correctly. And having two event loops makes this even +harder! + +If you can, we recommend following this pattern: + +- Start up your host loop +- Immediately call `start_guest_run` to start Trio +- When Trio finishes and your ``done_callback`` is invoked, shut down + the host loop +- Make sure that nothing else shuts down your host loop + +This way, your two event loops have the same lifetime, and your +program automatically exits when your Trio function finishes. + +Here's how we'd extend our asyncio example to implement this pattern: + +.. 
code-block:: python3 + :emphasize-lines: 8-11,19-22 + + # Improved version, that shuts down properly after Trio finishes + async def asyncio_main(): + asyncio_loop = asyncio.get_running_loop() + + def run_sync_soon_threadsafe(fn): + asyncio_loop.call_soon_threadsafe(fn) + + # Revised 'done' callback: set a Future + done_fut = asyncio_loop.create_future() + def done_callback(trio_main_outcome): + done_fut.set_result(trio_main_outcome) + + trio.lowlevel.start_guest_run( + trio_main, + run_sync_soon_threadsafe=run_sync_soon_threadsafe, + done_callback=done_callback, + ) + + # Wait for the guest run to finish + trio_main_outcome = await done_fut + # Pass through the return value or exception from the guest run + return trio_main_outcome.unwrap() + +And then you can encapsulate all this machinery in a utility function +that exposes a `trio.run`-like API, but runs both loops together:: + + def trio_run_with_asyncio(trio_main, *args, **trio_run_kwargs): + async def asyncio_main(): + # same as above + ... + + return asyncio.run(asyncio_main()) + +Technically, it is possible to use other patterns. But there are some +important limitations you have to respect: + +- **You must let the Trio program run to completion.** Many event + loops let you stop the event loop at any point, and any pending + callbacks/tasks/etc. just... don't run. Trio follows a more + structured system, where you can cancel things, but the code always + runs to completion, so ``finally`` blocks run, resources are cleaned + up, etc. If you stop your host loop early, before the + ``done_callback`` is invoked, then that cuts off the Trio run in the + middle without a chance to clean up. This can leave your code in an + inconsistent state, and will definitely leave Trio's internals in an + inconsistent state, which will cause errors if you try to use Trio + again in that thread. 
+ + Some programs need to be able to quit at any time, for example in + response to a GUI window being closed or a user selecting a "Quit" + from a menu. In these cases, we recommend wrapping your whole + program in a `trio.CancelScope`, and cancelling it when you want to + quit. + +- Each host loop can only have one `start_guest_run` at a time. If you + try to start a second one, you'll get an error. If you need to run + multiple Trio functions at the same time, then start up a single + Trio run, open a nursery, and then start your functions as child + tasks in that nursery. + +- Unless you or your host loop register a handler for `signal.SIGINT` + before starting Trio (this is not common), then Trio will take over + delivery of `KeyboardInterrupt`\s. And since Trio can't tell which + host code is safe to interrupt, it will only deliver + `KeyboardInterrupt` into the Trio part of your code. This is fine if + your program is set up to exit when the Trio part exits, because the + `KeyboardInterrupt` will propagate out of Trio and then trigger the + shutdown of your host loop, which is just what you want. + +Given these constraints, we think the simplest approach is to always +start and stop the two loops together. + +**Signal management:** `"Signals" +`__ are a low-level +inter-process communication primitive. When you hit control-C to kill +a program, that uses a signal. Signal handling in Python has `a lot of +moving parts +`__. +One of those parts is `signal.set_wakeup_fd`, which event loops use to +make sure that they wake up when a signal arrives so they can respond +to it. (If you've ever had an event loop ignore you when you hit +control-C, it was probably because they weren't using +`signal.set_wakeup_fd` correctly.) + +But, only one event loop can use `signal.set_wakeup_fd` at a time. And +in guest mode that can cause problems: Trio and the host loop might +start fighting over who's using `signal.set_wakeup_fd`. 
+ +Some event loops, like asyncio, won't work correctly unless they win +this fight. Fortunately, Trio is a little less picky: as long as +*someone* makes sure that the program wakes up when a signal arrives, +it should work correctly. So if your host loop wants +`signal.set_wakeup_fd`, then you should disable Trio's +`signal.set_wakeup_fd` support, and then both loops will work +correctly. + +On the other hand, if your host loop doesn't use +`signal.set_wakeup_fd`, then the only way to make everything work +correctly is to *enable* Trio's `signal.set_wakeup_fd` support. + +By default, Trio assumes that your host loop doesn't use +`signal.set_wakeup_fd`. It does try to detect when this creates a +conflict with the host loop, and print a warning – but unfortunately, +by the time it detects it, the damage has already been done. So if +you're getting this warning, then you should disable Trio's +`signal.set_wakeup_fd` support by passing +``host_uses_signal_set_wakeup_fd=True`` to `start_guest_run`. + +If you aren't seeing any warnings with your initial prototype, you're +*probably* fine. But the only way to be certain is to check your host +loop's source. For example, asyncio may or may not use +`signal.set_wakeup_fd` depending on the Python version and operating +system. + + +**A small optimization:** Finally, consider a small optimization. Some +event loops offer two versions of their "call this function soon" API: +one that can be used from any thread, and one that can only be used +from the event loop thread, with the latter being cheaper. For +example, asyncio has both `~asyncio.loop.call_soon_threadsafe` and +`~asyncio.loop.call_soon`. + +If you have a loop like this, then you can also pass a +``run_sync_soon_not_threadsafe=...`` kwarg to `start_guest_run`, and +Trio will automatically use it when appropriate. + +If your loop doesn't have a split like this, then don't worry about +it; ``run_sync_soon_not_threadsafe=`` is optional. 
(If it's not
+passed, then Trio will just use your threadsafe version in all cases.)
+
+**That's it!** If you've followed all these steps, you should now have
+a cleanly-integrated hybrid event loop. Go make some cool
+GUIs/games/whatever!
+
+
+Limitations
+-----------
+
+In general, almost all Trio features should work in guest mode. The
+exception is features which rely on Trio having a complete picture of
+everything that your program is doing, since obviously, it can't
+control the host loop or see what it's doing.
+
+Custom clocks can be used in guest mode, but they only affect Trio
+timeouts, not host loop timeouts. And the :ref:`autojump clock
+<testing-time>` and related `trio.testing.wait_all_tasks_blocked` can
+technically be used in guest mode, but they'll only take Trio tasks
+into account when deciding whether to jump the clock or whether all
+tasks are blocked.
+
+
+Reference
+---------
+
+.. autofunction:: start_guest_run
+
+
+.. _live-coroutine-handoff:
+
+Handing off live coroutine objects between coroutine runners
+============================================================
+
+Internally, Python's async/await syntax is built around the idea of
+"coroutine objects" and "coroutine runners". A coroutine object
+represents the state of an async callstack. But by itself, this is
+just a static object that sits there. If you want it to do anything,
+you need a coroutine runner to push it forward. Every Trio task has an
+associated coroutine object (see :data:`Task.coro`), and the Trio
+scheduler acts as their coroutine runner.
+
+But of course, Trio isn't the only coroutine runner in Python –
+:mod:`asyncio` has one, other event loops have them, you can even
+define your own.
+
+And in some very, very unusual circumstances, it even makes sense to
+transfer a single coroutine object back and forth between different
+coroutine runners. That's what this section is about. 
This is an +*extremely* exotic use case, and assumes a lot of expertise in how +Python async/await works internally. For motivating examples, see +`trio-asyncio issue #42 +`__, and `trio +issue #649 `__. For +more details on how coroutines work, we recommend André Caron's `A +tale of event loops +`__, or +going straight to `PEP 492 +`__ for the full details. + +.. autofunction:: permanently_detach_coroutine_object + +.. autofunction:: temporarily_detach_coroutine_object + +.. autofunction:: reattach_detached_coroutine_object diff --git a/docs/source/reference-testing.rst b/docs/source/reference-testing.rst index 40a275bbeb..76ecd4a2d4 100644 --- a/docs/source/reference-testing.rst +++ b/docs/source/reference-testing.rst @@ -16,6 +16,8 @@ Test harness integration .. decorator:: trio_test +.. _testing-time: + Time and timeouts ----------------- diff --git a/docs/source/reference-testing/across-realtime.py b/docs/source/reference-testing/across-realtime.py index 41bf502611..300810c065 100644 --- a/docs/source/reference-testing/across-realtime.py +++ b/docs/source/reference-testing/across-realtime.py @@ -6,6 +6,7 @@ YEAR = 365 * 24 * 60 * 60 # seconds + async def task1(): start = trio.current_time() @@ -13,15 +14,15 @@ async def task1(): await trio.sleep(YEAR) duration = trio.current_time() - start - print("task1: woke up; clock says I've slept {} years" - .format(duration / YEAR)) + print(f"task1: woke up; clock says I've slept {duration / YEAR} years") print("task1: sleeping for 1 year, 100 times") for _ in range(100): await trio.sleep(YEAR) duration = trio.current_time() - start - print("task1: slept {} years total".format(duration / YEAR)) + print(f"task1: slept {duration / YEAR} years total") + async def task2(): start = trio.current_time() @@ -30,25 +31,27 @@ async def task2(): await trio.sleep(5 * YEAR) duration = trio.current_time() - start - print("task2: woke up; clock says I've slept {} years" - .format(duration / YEAR)) + print(f"task2: woke up; clock says 
I've slept {duration / YEAR} years") print("task2: sleeping for 500 years") await trio.sleep(500 * YEAR) duration = trio.current_time() - start - print("task2: slept {} years total".format(duration / YEAR)) + print(f"task2: slept {duration / YEAR} years total") + async def main(): async with trio.open_nursery() as nursery: nursery.start_soon(task1) nursery.start_soon(task2) + def run_example(clock): real_start = time.perf_counter() trio.run(main, clock=clock) real_duration = time.perf_counter() - real_start - print("Total real time elapsed: {} seconds".format(real_duration)) + print(f"Total real time elapsed: {real_duration} seconds") + print("Clock where time passes at 100 years per second:\n") run_example(trio.testing.MockClock(rate=100 * YEAR)) diff --git a/docs/source/releasing.rst b/docs/source/releasing.rst index 0ebad58a80..0fe51370d5 100644 --- a/docs/source/releasing.rst +++ b/docs/source/releasing.rst @@ -19,7 +19,7 @@ Things to do for releasing: * Do the actual release changeset - + update version number + + bump version number - increment as per Semantic Versioning rules @@ -29,7 +29,7 @@ Things to do for releasing: - review history change - - ``git rm`` changes + - ``git rm`` the now outdated newfragments + commit @@ -37,31 +37,26 @@ Things to do for releasing: * create pull request to ``python-trio/trio``'s "master" branch -* announce PR on gitter - - + wait for feedback - - + fix problems, if any - * verify that all checks succeeded -* acknowledge the release PR - - + or rather, somebody else should do that +* tag with vVERSION, push tag on ``python-trio/trio`` (not on your personal repository) -* tag with vVERSION +* push to PyPI:: -* push to PyPI + git clean -xdf # maybe run 'git clean -xdn' first to see what it will delete + python3 setup.py sdist bdist_wheel + twine upload dist/* - + ``python3 setup.py sdist bdist_wheel upload`` +* update version number in the same pull request -* announce on gitter + + add ``+dev`` tag to the end -* update 
version number +* merge the release pull request - + add ``+dev`` tag to the end +* make a GitHub release (go to the tag and press "Create release from tag") -* prepare another pull request to "master" + + paste in the new content in ``history.rst`` and convert it to markdown: turn the parts under section into ``---``, update links to just be the links, and whatever else is necessary. - + acknowledge it + + include anything else that might be pertinent, like a link to the commits between the latest and current release. +* announce on gitter diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 3b9255b236..0584446fb7 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -13,14 +13,13 @@ Tutorial still probably read this, because Trio is different.) Trio turns Python into a concurrent language. It takes the core - async/await syntax introduced in 3.5, and uses it to add three + async/await syntax introduced in 3.5, and uses it to add two new pieces of semantics: - cancel scopes: a generic system for managing timeouts and cancellation - nurseries: which let your program do multiple things at the same time - - MultiErrors: for when multiple things go wrong at once Of course it also provides a complete suite of APIs for doing networking, file I/O, using worker threads, @@ -34,9 +33,6 @@ Tutorial print(response) and then again with /delay/10 - (note that asks needs cpython 3.6 though. maybe just for one async - generator?) - value of async/await: show you where the cancellation exceptions can happen -- see pillar re: explicit cancel points @@ -60,8 +56,6 @@ Tutorial and demonstrate start() then point out that you can just use serve_tcp() - exceptions and MultiError - example: catch-all logging in our echo server review of the three (or four) core language extensions @@ -94,7 +88,7 @@ Okay, ready? Let's get started. Before you begin ---------------- -1. Make sure you're using Python 3.5 or newer. +1. 
Make sure you're using Python 3.8 or newer. 2. ``python3 -m pip install --upgrade trio`` (or on Windows, maybe ``py -3 -m pip install --upgrade trio`` – `details @@ -273,7 +267,7 @@ this, that tries to call an async function but leaves out the trio.sleep(2 * x) sleep_time = time.perf_counter() - start_time - print("Woke up after {:.2f} seconds, feeling well rested!".format(sleep_time)) + print(f"Woke up after {sleep_time:.2f} seconds, feeling well rested!") trio.run(broken_double_sleep, 3) @@ -312,7 +306,7 @@ runs: >>>> # but forcing a garbage collection gives us a warning: >>>> import gc >>>> gc.collect() - /home/njs/pypy-3.5-nightly/lib-python/3/importlib/_bootstrap.py:191: RuntimeWarning: coroutine 'sleep' was never awaited + /home/njs/pypy-3.8-nightly/lib-python/3/importlib/_bootstrap.py:191: RuntimeWarning: coroutine 'sleep' was never awaited if _module_locks.get(name) is wr: # XXX PyPy fix? 0 >>>> @@ -442,15 +436,15 @@ Now that we understand ``async with``, let's look at ``parent`` again: :end-at: all done! There are only 4 lines of code that really do anything here. On line -17, we use :func:`trio.open_nursery` to get a "nursery" object, and +20, we use :func:`trio.open_nursery` to get a "nursery" object, and then inside the ``async with`` block we call ``nursery.start_soon`` twice, -on lines 19 and 22. There are actually two ways to call an async +on lines 22 and 25. There are actually two ways to call an async function: the first one is the one we already saw, using ``await async_fn()``; the new one is ``nursery.start_soon(async_fn)``: it asks Trio to start running this async function, *but then returns immediately without waiting for the function to finish*. So after our two calls to ``nursery.start_soon``, ``child1`` and ``child2`` are now running in the -background. And then at line 25, the commented line, we hit the end of +background. 
And then at line 28, the commented line, we hit the end of the ``async with`` block, and the nursery's ``__aexit__`` function runs. What this does is force ``parent`` to stop here and wait for all the children in the nursery to exit. This is why you have to use @@ -599,8 +593,8 @@ Each task runs until it hits the call to :func:`trio.sleep`, and then suddenly we're back in :func:`trio.run` deciding what to run next. How does this happen? The secret is that :func:`trio.run` and :func:`trio.sleep` work together to make it happen: :func:`trio.sleep` -has access to some special magic that lets it pause its entire -call stack, so it sends a note to :func:`trio.run` requesting to be +has access to some special magic that lets it pause itself, +so it sends a note to :func:`trio.run` requesting to be woken again after 1 second, and then suspends the task. And once the task is suspended, Python gives control back to :func:`trio.run`, which decides what to do next. (If this sounds similar to the way that @@ -622,7 +616,7 @@ between the implementation of generators and async functions.) Only async functions have access to the special magic for suspending a task, so only async functions can cause the program to switch to a -different task. What this means if a call *doesn't* have an ``await`` +different task. What this means is that if a call *doesn't* have an ``await`` on it, then you know that it *can't* be a place where your task will be suspended. This makes tasks much `easier to reason about `__ than @@ -639,7 +633,7 @@ wouldn't have been able to pause at the end and wait for the children to finish; we need our cleanup function to be async, which is exactly what ``async with`` gives us. -Now, back to our execution trace. To recap: at this point ``parent`` +Now, back to our execution point. To recap: at this point ``parent`` is waiting on ``child1`` and ``child2``, and both children are sleeping. 
So :func:`trio.run` checks its notes, and sees that there's nothing to be done until those sleeps finish – unless possibly some @@ -774,7 +768,7 @@ above, it's baked into Trio's design that when it has multiple tasks, they take turns, so at each moment only one of them is actively running. We're not so much overcoming the GIL as embracing it. But if you're willing to accept that, plus a bit of extra work to put these new -``async`` and ``await`` keywords in the right places, then in exchange +``async`` and ``await`` keywords in the right places, then in exchange you get: * Excellent scalability: Trio can run 10,000+ tasks simultaneously @@ -836,19 +830,25 @@ Networking with Trio Now let's take what we've learned and use it to do some I/O, which is where async/await really shines. - -An echo client -~~~~~~~~~~~~~~ - -The traditional application for demonstrating network APIs is an "echo -server": a program that accepts arbitrary data from a client, and then -sends that same data right back. (Probably a more relevant example +The traditional toy application for demonstrating network APIs is an +"echo server": a program that awaits arbitrary data from remote clients, +and then sends that same data right back. (Probably a more relevant example these days would be an application that does lots of concurrent HTTP requests, but for that `you need an HTTP library `__ such as `asks `__, so we'll stick with the echo server tradition.) +In this tutorial, we present both ends of the pipe: the client, and the +server. The client periodically sends data to the server, and displays its +answers. The server awaits connections; when a client connects, it recopies +the received data back on the pipe. + + +An echo client +~~~~~~~~~~~~~~ + + To start with, here's an example echo *client*, i.e., the program that will send some data at our echo server and get responses back: @@ -857,6 +857,9 @@ will send some data at our echo server and get responses back: .. 
literalinclude:: tutorial/echo-client.py :linenos: +Note that this code will not work without a TCP server such as the one +we'll implement below. + The overall structure here should be familiar, because it's just like our :ref:`last example `: we have a parent task, which spawns two child tasks to do the actual work, and @@ -1083,7 +1086,7 @@ up, and ``send_all`` will block until the remote side calls Now let's think about this from the server's point of view. Each time it calls ``receive_some``, it gets some data that it needs to send -back. And until it sends it back, the data is sitting around takes up +back. And until it sends it back, the data that is sitting around takes up memory. Computers have finite amounts of RAM, so if our server is well behaved then at some point it needs to stop calling ``receive_some`` until it gets rid of some of the old data by doing its own call to @@ -1143,9 +1146,6 @@ TODO: explain :exc:`Cancelled` TODO: explain how cancellation is also used when one child raises an exception -TODO: show an example :exc:`MultiError` traceback and walk through its -structure - TODO: maybe a brief discussion of :exc:`KeyboardInterrupt` handling? .. diff --git a/docs/source/tutorial/echo-client.py b/docs/source/tutorial/echo-client.py index 06f6a81e7e..244f2831f5 100644 --- a/docs/source/tutorial/echo-client.py +++ b/docs/source/tutorial/echo-client.py @@ -9,23 +9,26 @@ # - must match what we set in our echo server PORT = 12345 + async def sender(client_stream): print("sender: started!") while True: data = b"async can sometimes be confusing, but I believe in you!" 
- print("sender: sending {!r}".format(data)) + print(f"sender: sending {data!r}") await client_stream.send_all(data) await trio.sleep(1) + async def receiver(client_stream): print("receiver: started!") async for data in client_stream: - print("receiver: got data {!r}".format(data)) + print(f"receiver: got data {data!r}") print("receiver: connection closed") sys.exit() + async def parent(): - print("parent: connecting to 127.0.0.1:{}".format(PORT)) + print(f"parent: connecting to 127.0.0.1:{PORT}") client_stream = await trio.open_tcp_stream("127.0.0.1", PORT) async with client_stream: async with trio.open_nursery() as nursery: @@ -35,4 +38,5 @@ async def parent(): print("parent: spawning receiver...") nursery.start_soon(receiver, client_stream) + trio.run(parent) diff --git a/docs/source/tutorial/echo-server.py b/docs/source/tutorial/echo-server.py index a184cdf46b..3751cadd73 100644 --- a/docs/source/tutorial/echo-server.py +++ b/docs/source/tutorial/echo-server.py @@ -11,31 +11,33 @@ CONNECTION_COUNTER = count() + async def echo_server(server_stream): # Assign each connection a unique number to make our debug prints easier # to understand when there are multiple simultaneous connections. ident = next(CONNECTION_COUNTER) - print("echo_server {}: started".format(ident)) + print(f"echo_server {ident}: started") try: async for data in server_stream: - print("echo_server {}: received data {!r}".format(ident, data)) + print(f"echo_server {ident}: received data {data!r}") await server_stream.send_all(data) - print("echo_server {}: connection closed".format(ident)) - # FIXME: add discussion of MultiErrors to the tutorial, and use - # MultiError.catch here. (Not important in this case, but important if the - # server code uses nurseries internally.) + print(f"echo_server {ident}: connection closed") + # FIXME: add discussion of (Base)ExceptionGroup to the tutorial, and use + # exceptiongroup.catch() here. 
(Not important in this case, but important + # if the server code uses nurseries internally.) except Exception as exc: # Unhandled exceptions will propagate into our parent and take # down the whole program. If the exception is KeyboardInterrupt, # that's what we want, but otherwise maybe not... - print("echo_server {}: crashed: {!r}".format(ident, exc)) + print(f"echo_server {ident}: crashed: {exc!r}") + async def main(): await trio.serve_tcp(echo_server, PORT) + # We could also just write 'trio.run(trio.serve_tcp, echo_server, PORT)', but real # programs almost always end up doing other stuff too and then we'd have to go # back and factor it out into a separate function anyway. So it's simplest to # just make it a standalone function from the beginning. trio.run(main) - diff --git a/docs/source/tutorial/tasks-intro.py b/docs/source/tutorial/tasks-intro.py index a316cb933d..e00de363b1 100644 --- a/docs/source/tutorial/tasks-intro.py +++ b/docs/source/tutorial/tasks-intro.py @@ -2,16 +2,19 @@ import trio + async def child1(): print(" child1: started! sleeping now...") await trio.sleep(1) print(" child1: exiting!") + async def child2(): print(" child2: started! sleeping now...") await trio.sleep(1) print(" child2: exiting!") + async def parent(): print("parent: started!") async with trio.open_nursery() as nursery: @@ -25,4 +28,5 @@ async def parent(): # -- we exit the nursery block here -- print("parent: all done!") + trio.run(parent) diff --git a/docs/source/tutorial/tasks-with-trace.py b/docs/source/tutorial/tasks-with-trace.py index 38a1ffe862..a6e40ec8ee 100644 --- a/docs/source/tutorial/tasks-with-trace.py +++ b/docs/source/tutorial/tasks-with-trace.py @@ -2,16 +2,19 @@ import trio + async def child1(): print(" child1: started! sleeping now...") await trio.sleep(1) print(" child1: exiting!") + async def child2(): print(" child2 started! 
sleeping now...") await trio.sleep(1) print(" child2 exiting!") + async def parent(): print("parent: started!") async with trio.open_nursery() as nursery: @@ -25,6 +28,7 @@ async def parent(): # -- we exit the nursery block here -- print("parent: all done!") + class Tracer(trio.abc.Instrument): def before_run(self): print("!!! run started") @@ -32,7 +36,7 @@ def before_run(self): def _print_with_task(self, msg, task): # repr(task) is perhaps more useful than task.name in general, # but in context of a tutorial the extra noise is unhelpful. - print("{}: {}".format(msg, task.name)) + print(f"{msg}: {task.name}") def task_spawned(self, task): self._print_with_task("### new task spawned", task) @@ -51,16 +55,17 @@ def task_exited(self, task): def before_io_wait(self, timeout): if timeout: - print("### waiting for I/O for up to {} seconds".format(timeout)) + print(f"### waiting for I/O for up to {timeout} seconds") else: print("### doing a quick check for I/O") self._sleep_time = trio.current_time() def after_io_wait(self, timeout): duration = trio.current_time() - self._sleep_time - print("### finished I/O check (took {} seconds)".format(duration)) + print(f"### finished I/O check (took {duration} seconds)") def after_run(self): print("!!! 
run finished") + trio.run(parent, instruments=[Tracer()]) diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index a93e159b42..0000000000 --- a/mypy.ini +++ /dev/null @@ -1,26 +0,0 @@ -[mypy] -# TODO: run mypy against several OS/version combos in CI -# https://mypy.readthedocs.io/en/latest/command_line.html#platform-configuration - -# Be flexible about unannotated imports -follow_imports = silent -ignore_missing_imports = True - -# Be strict about use of Mypy -warn_unused_ignores = True -warn_unused_configs = True -warn_redundant_casts = True -warn_return_any = True - -# Avoid subtle backsliding -#disallow_any_decorated = True -#disallow_incomplete_defs = True -#disallow_subclassing_any = True - -# Enable gradually / for new modules -check_untyped_defs = False -disallow_untyped_calls = False -disallow_untyped_defs = False - -# DO NOT use `ignore_errors`; it doesn't apply -# downstream and users have to deal with them. diff --git a/newsfragments/1272.feature.rst b/newsfragments/1272.feature.rst deleted file mode 100644 index 3f540b93e5..0000000000 --- a/newsfragments/1272.feature.rst +++ /dev/null @@ -1,16 +0,0 @@ -If you're using Trio's low-level interfaces like -`trio.hazmat.wait_readable` or similar, and then you close a socket or -file descriptor, you're supposed to call `trio.hazmat.notify_closing` -first so Trio can clean up properly. But what if you forget? In the -past, Trio would tend to either deadlock or explode spectacularly. -Now, it's much more robust to this situation, and should generally -survive. (But note that "survive" is not the same as "give you the -results you were expecting", so you should still call -`~trio.hazmat.notify_closing` when appropriate. This is about harm -reduction and making it easier to debug this kind of mistake, not -something you should rely on.) 
- -If you're using higher-level interfaces outside of the `trio.hazmat` -module, then you don't need to worry about any of this; those -intefaces already take care of calling `~trio.hazmat.notify_closing` -for you. diff --git a/newsfragments/1308.bugfix.rst b/newsfragments/1308.bugfix.rst deleted file mode 100644 index db435e46c1..0000000000 --- a/newsfragments/1308.bugfix.rst +++ /dev/null @@ -1,9 +0,0 @@ -A bug related to the following methods has been introduced in version 0.12.0: - -- `trio.Path.iterdir` -- `trio.Path.glob` -- `trio.Path.rglob` - -The iteration of the blocking generators produced by pathlib was performed in -the trio thread. With this fix, the previous behavior is restored: the blocking -generators are converted into lists in a thread dedicated to blocking IO calls. diff --git a/newsfragments/2668.removal.rst b/newsfragments/2668.removal.rst new file mode 100644 index 0000000000..512f681077 --- /dev/null +++ b/newsfragments/2668.removal.rst @@ -0,0 +1 @@ +Drop support for Python3.7 and PyPy3.7/3.8. diff --git a/newsfragments/README.rst b/newsfragments/README.rst index d01e930f24..52dc0716bb 100644 --- a/newsfragments/README.rst +++ b/newsfragments/README.rst @@ -6,22 +6,31 @@ message and PR description, which are a description of the change as relevant to people working on the code itself.) Each file should be named like ``..rst``, where -```` is an issue numbers, and ```` is one of: +```` is an issue number, and ```` is one of: -* ``feature`` +* ``headline``: a major new feature we want to highlight for users +* ``breaking``: any breaking changes that happen without a proper + deprecation period (note: deprecations, and removal of previously + deprecated features after an appropriate time, go in the + ``deprecated`` category instead) +* ``feature``: any new feature that doesn't qualify for ``headline`` +* ``removal``: removing support for old python versions, or other removals with no deprecation period. 
* ``bugfix`` * ``doc`` -* ``removal`` +* ``deprecated`` * ``misc`` -So for example: ``123.feature.rst``, ``456.bugfix.rst`` +So for example: ``123.headline.rst``, ``456.bugfix.rst``, +``789.deprecated.rst`` If your PR fixes an issue, use that number here. If there is no issue, then after you submit the PR and get the PR number you can add a newsfragment using that instead. -Note that the ``towncrier`` tool will automatically -reflow your text, so don't try to do any fancy formatting. You can -install ``towncrier`` and then run ``towncrier --draft`` if you want -to get a preview of how your change will look in the final release -notes. +Your text can use all the same markup that we use in our Sphinx docs. +For example, you can use double-backticks to mark code snippets, or +single-backticks to link to a function/class/module. + +To check how your formatting looks, the easiest way is to make the PR, +and then after the CI checks run, click on the "Read the Docs build" +details link, and navigate to the release history page. 
diff --git a/notes-to-self/afd-lab.py b/notes-to-self/afd-lab.py index 58a6c22799..ed420dbdbd 100644 --- a/notes-to-self/afd-lab.py +++ b/notes-to-self/afd-lab.py @@ -96,7 +96,7 @@ class AFDLab: def __init__(self): self._afd = _afd_helper_handle() - trio.hazmat.register_with_iocp(self._afd) + trio.lowlevel.register_with_iocp(self._afd) async def afd_poll(self, sock, flags, *, exclusive=0): print(f"Starting a poll for {flags!r}") @@ -127,7 +127,7 @@ async def afd_poll(self, sock, flags, *, exclusive=0): raise try: - await trio.hazmat.wait_overlapped(self._afd, lpOverlapped) + await trio.lowlevel.wait_overlapped(self._afd, lpOverlapped) except: print(f"Poll for {flags!r}: {sys.exc_info()[1]!r}") raise diff --git a/notes-to-self/aio-guest-test.py b/notes-to-self/aio-guest-test.py new file mode 100644 index 0000000000..b64a11bd04 --- /dev/null +++ b/notes-to-self/aio-guest-test.py @@ -0,0 +1,48 @@ +import asyncio +import trio + +async def aio_main(): + loop = asyncio.get_running_loop() + + trio_done_fut = loop.create_future() + def trio_done_callback(main_outcome): + print(f"trio_main finished: {main_outcome!r}") + trio_done_fut.set_result(main_outcome) + + trio.lowlevel.start_guest_run( + trio_main, + run_sync_soon_threadsafe=loop.call_soon_threadsafe, + done_callback=trio_done_callback, + ) + + (await trio_done_fut).unwrap() + + +async def trio_main(): + print("trio_main!") + + to_trio, from_aio = trio.open_memory_channel(float("inf")) + from_trio = asyncio.Queue() + + asyncio.create_task(aio_pingpong(from_trio, to_trio)) + + from_trio.put_nowait(0) + + async for n in from_aio: + print(f"trio got: {n}") + await trio.sleep(1) + from_trio.put_nowait(n + 1) + if n >= 10: + return + +async def aio_pingpong(from_trio, to_trio): + print("aio_pingpong!") + + while True: + n = await from_trio.get() + print(f"aio got: {n}") + await asyncio.sleep(1) + to_trio.send_nowait(n + 1) + + +asyncio.run(aio_main()) diff --git a/notes-to-self/fbsd-pipe-close-notify.py 
b/notes-to-self/fbsd-pipe-close-notify.py new file mode 100644 index 0000000000..7b18f65d6f --- /dev/null +++ b/notes-to-self/fbsd-pipe-close-notify.py @@ -0,0 +1,38 @@ +# This script completes correctly on macOS and FreeBSD 13.0-CURRENT, but hangs +# on FreeBSD 12.1. I'm told the fix will be backported to 12.2 (which is due +# out in October 2020). +# +# Upstream bug: https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350 + +import select +import os +import threading + +r, w = os.pipe() + +os.set_blocking(w, False) + +print("filling pipe buffer") +while True: + try: + os.write(w, b"x") + except BlockingIOError: + break + +_, wfds, _ = select.select([], [w], [], 0) +print("select() says the write pipe is", "writable" if w in wfds else "NOT writable") + +kq = select.kqueue() +event = select.kevent(w, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) +kq.control([event], 0) + +print("closing read end of pipe") +os.close(r) + +_, wfds, _ = select.select([], [w], [], 0) +print("select() says the write pipe is", "writable" if w in wfds else "NOT writable") + +print("waiting for kqueue to report the write end is writable") +got = kq.control([], 1) +print("done!") +print(got) diff --git a/notes-to-self/how-does-windows-so-reuseaddr-work.py b/notes-to-self/how-does-windows-so-reuseaddr-work.py index 64430a0f92..d8d60d1d66 100644 --- a/notes-to-self/how-does-windows-so-reuseaddr-work.py +++ b/notes-to-self/how-does-windows-so-reuseaddr-work.py @@ -10,6 +10,7 @@ modes = ["default", "SO_REUSEADDR", "SO_EXCLUSIVEADDRUSE"] bind_types = ["wildcard", "specific"] + def sock(mode): s = socket.socket(family=socket.AF_INET) if mode == "SO_REUSEADDR": @@ -18,6 +19,7 @@ def sock(mode): s.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) return s + def bind(sock, bind_type): if bind_type == "wildcard": sock.bind(("0.0.0.0", 12345)) @@ -26,6 +28,7 @@ def bind(sock, bind_type): else: assert False + def table_entry(mode1, bind_type1, mode2, bind_type2): with sock(mode1) as sock1: 
bind(sock1, bind_type1) @@ -41,12 +44,22 @@ def table_entry(mode1, bind_type1, mode2, bind_type2): else: return "Success" -print(""" + +print( + """ second bind - | default | SO_REUSEADDR | SO_EXCLUSIVEADDRUSE - | specific| wildcard| specific| wildcard| specific| wildcard -first bind ------------------------------------------------------------""" -# default | wildcard | INUSE | Success | ACCESS | Success | INUSE | Success + | """ + + " | ".join(["%-19s" % mode for mode in modes]) +) + +print(""" """, end="") +for mode in modes: + print(" | " + " | ".join(["%8s" % bind_type for bind_type in bind_types]), end="") + +print( + """ +first bind -----------------------------------------------------------------""" + # default | wildcard | INUSE | Success | ACCESS | Success | INUSE | Success ) for i, mode1 in enumerate(modes): @@ -56,6 +69,8 @@ def table_entry(mode1, bind_type1, mode2, bind_type2): for l, bind_type2 in enumerate(bind_types): entry = table_entry(mode1, bind_type1, mode2, bind_type2) row.append(entry) - #print(mode1, bind_type1, mode2, bind_type2, entry) - print("{:>19} | {:>8} | ".format(mode1, bind_type1) - + " | ".join(["%7s" % entry for entry in row])) + # print(mode1, bind_type1, mode2, bind_type2, entry) + print( + f"{mode1:>19} | {bind_type1:>8} | " + + " | ".join(["%8s" % entry for entry in row]) + ) diff --git a/notes-to-self/print-task-tree.py b/notes-to-self/print-task-tree.py index d4b6dd8da4..38e545853e 100644 --- a/notes-to-self/print-task-tree.py +++ b/notes-to-self/print-task-tree.py @@ -38,7 +38,7 @@ def current_root_task(): - task = trio.hazmat.current_task() + task = trio.lowlevel.current_task() while task.parent_nursery is not None: task = task.parent_nursery.parent_task return task diff --git a/notes-to-self/reopen-pipe.py b/notes-to-self/reopen-pipe.py index 910def397c..5e5b31e41f 100644 --- a/notes-to-self/reopen-pipe.py +++ b/notes-to-self/reopen-pipe.py @@ -3,12 +3,13 @@ import time import tempfile + def check_reopen(r1, w): try: 
print("Reopening read end") - r2 = os.open("/proc/self/fd/{}".format(r1), os.O_RDONLY) + r2 = os.open(f"/proc/self/fd/{r1}", os.O_RDONLY) - print("r1 is {}, r2 is {}".format(r1, r2)) + print(f"r1 is {r1}, r2 is {r2}") print("checking they both can receive from w...") @@ -36,11 +37,12 @@ def check_reopen(r1, w): def sleep_then_write(): time.sleep(1) os.write(w, b"c") + threading.Thread(target=sleep_then_write, daemon=True).start() assert os.read(r1, 1) == b"c" print("r1 definitely seems to be in blocking mode") except Exception as exc: - print("ERROR: {!r}".format(exc)) + print(f"ERROR: {exc!r}") print("-- testing anonymous pipe --") @@ -63,6 +65,6 @@ def sleep_then_write(): print("-- testing socketpair --") import socket + rs, ws = socket.socketpair() check_reopen(rs.fileno(), ws.fileno()) - diff --git a/notes-to-self/schedule-timing.py b/notes-to-self/schedule-timing.py index 441f579f7c..176dcf9220 100644 --- a/notes-to-self/schedule-timing.py +++ b/notes-to-self/schedule-timing.py @@ -4,16 +4,18 @@ LOOPS = 0 RUNNING = True + async def reschedule_loop(depth): if depth == 0: global LOOPS while RUNNING: LOOPS += 1 await trio.sleep(0) - #await trio.hazmat.cancel_shielded_checkpoint() + # await trio.lowlevel.cancel_shielded_checkpoint() else: await reschedule_loop(depth - 1) + async def report_loop(): global RUNNING try: @@ -25,13 +27,15 @@ async def report_loop(): end_count = LOOPS loops = end_count - start_count duration = end_time - start_time - print("{} loops/sec".format(loops / duration)) + print(f"{loops / duration} loops/sec") finally: RUNNING = False + async def main(): async with trio.open_nursery() as nursery: nursery.start_soon(reschedule_loop, 10) nursery.start_soon(report_loop) + trio.run(main) diff --git a/notes-to-self/socket-scaling.py b/notes-to-self/socket-scaling.py index 61527e2552..1571be4d17 100644 --- a/notes-to-self/socket-scaling.py +++ b/notes-to-self/socket-scaling.py @@ -11,14 +11,6 @@ # On Windows: with the old 'select'-based loop, the 
cost of scheduling grew # with the number of outstanding sockets, which was bad. # -# With the new IOCP-based loop, the cost of scheduling is constant, which is -# good. But, we find that the cost of cancelling a single wait_readable -# appears to grow like O(n**2) or so in the number of outstanding -# wait_readables. This is bad -- it means that cancelling all of the -# outstanding operations here is something like O(n**3)! To avoid this, we -# should consider creating multiple AFD helper handles and distributing the -# AFD_POLL operations across them. -# # To run this on Unix systems, you'll probably first have to run: # # ulimit -n 31000 @@ -49,11 +41,11 @@ def pt(desc, *, count=total, item="socket"): pt("socket creation") async with trio.open_nursery() as nursery: for s in sockets: - nursery.start_soon(trio.hazmat.wait_readable, s) + nursery.start_soon(trio.lowlevel.wait_readable, s) await trio.testing.wait_all_tasks_blocked() pt("spawning wait tasks") for _ in range(1000): - await trio.hazmat.cancel_shielded_checkpoint() + await trio.lowlevel.cancel_shielded_checkpoint() pt("scheduling 1000 times", count=1000, item="schedule") nursery.cancel_scope.cancel() pt("cancelling wait tasks") diff --git a/notes-to-self/socketpair-buffering.py b/notes-to-self/socketpair-buffering.py index dd3b1ad97d..5e77a709b7 100644 --- a/notes-to-self/socketpair-buffering.py +++ b/notes-to-self/socketpair-buffering.py @@ -32,6 +32,6 @@ except BlockingIOError: pass - print("setsockopt bufsize {}: {}".format(bufsize, i)) + print(f"setsockopt bufsize {bufsize}: {i}") a.close() b.close() diff --git a/notes-to-self/ssl-handshake/ssl-handshake.py b/notes-to-self/ssl-handshake/ssl-handshake.py index 81d875be6a..18a0e1a675 100644 --- a/notes-to-self/ssl-handshake/ssl-handshake.py +++ b/notes-to-self/ssl-handshake/ssl-handshake.py @@ -8,6 +8,7 @@ server_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) server_ctx.load_cert_chain("trio-test-1.pem") + def _ssl_echo_serve_sync(sock): 
try: wrapped = server_ctx.wrap_socket(sock, server_side=True) @@ -20,16 +21,19 @@ def _ssl_echo_serve_sync(sock): except BrokenPipeError: pass + @contextmanager def echo_server_connection(): client_sock, server_sock = socket.socketpair() with client_sock, server_sock: t = threading.Thread( - target=_ssl_echo_serve_sync, args=(server_sock,), daemon=True) + target=_ssl_echo_serve_sync, args=(server_sock,), daemon=True + ) t.start() yield client_sock + class ManuallyWrappedSocket: def __init__(self, ctx, sock, **kwargs): self.incoming = ssl.MemoryBIO() @@ -82,21 +86,23 @@ def unwrap(self): def wrap_socket_via_wrap_socket(ctx, sock, **kwargs): return ctx.wrap_socket(sock, do_handshake_on_connect=False, **kwargs) + def wrap_socket_via_wrap_bio(ctx, sock, **kwargs): return ManuallyWrappedSocket(ctx, sock, **kwargs) for wrap_socket in [ - wrap_socket_via_wrap_socket, - wrap_socket_via_wrap_bio, + wrap_socket_via_wrap_socket, + wrap_socket_via_wrap_bio, ]: - print("\n--- checking {} ---\n".format(wrap_socket.__name__)) + print(f"\n--- checking {wrap_socket.__name__} ---\n") print("checking with do_handshake + correct hostname...") with echo_server_connection() as client_sock: client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-1.example.org") + client_ctx, client_sock, server_hostname="trio-test-1.example.org" + ) wrapped.do_handshake() wrapped.sendall(b"x") assert wrapped.recv(1) == b"x" @@ -107,7 +113,8 @@ def wrap_socket_via_wrap_bio(ctx, sock, **kwargs): with echo_server_connection() as client_sock: client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-2.example.org") + client_ctx, client_sock, server_hostname="trio-test-2.example.org" + ) try: wrapped.do_handshake() except Exception: @@ -119,7 +126,8 @@ def wrap_socket_via_wrap_bio(ctx, sock, **kwargs): with echo_server_connection() as 
client_sock: client_ctx = ssl.create_default_context(cafile="trio-test-CA.pem") wrapped = wrap_socket( - client_ctx, client_sock, server_hostname="trio-test-2.example.org") + client_ctx, client_sock, server_hostname="trio-test-2.example.org" + ) # We forgot to call do_handshake # But the hostname is wrong so something had better error out... sent = b"x" diff --git a/notes-to-self/sslobject.py b/notes-to-self/sslobject.py index cfac98676e..0692af319c 100644 --- a/notes-to-self/sslobject.py +++ b/notes-to-self/sslobject.py @@ -15,6 +15,7 @@ soutb = ssl.MemoryBIO() sso = server_ctx.wrap_bio(sinb, soutb, server_side=True) + @contextmanager def expect(etype): try: @@ -22,7 +23,8 @@ def expect(etype): except etype: pass else: - raise AssertionError("expected {}".format(etype)) + raise AssertionError(f"expected {etype}") + with expect(ssl.SSLWantReadError): cso.do_handshake() diff --git a/notes-to-self/subprocess-notes.txt b/notes-to-self/subprocess-notes.txt index e33c835640..d3a5c1096c 100644 --- a/notes-to-self/subprocess-notes.txt +++ b/notes-to-self/subprocess-notes.txt @@ -4,7 +4,7 @@ # tenable. We're better off trying os.waitpid(..., os.WNOHANG), and if that # says the process is still going then spawn a thread to sit in waitpid. # ......though that waitpid is non-cancellable so ugh. this is a problem, -# becaues it's also mutating -- you only get to waitpid() once, and you have +# because it's also mutating -- you only get to waitpid() once, and you have # to do it, because zombies. 
I guess we could make sure the waitpid thread is # daemonic and either it gets back to us eventually (even if our first call to # 'await wait()' is cancelled, maybe another one won't be), or else we go away diff --git a/notes-to-self/thread-closure-bug-demo.py b/notes-to-self/thread-closure-bug-demo.py index 514636a1b4..b09a87fe5f 100644 --- a/notes-to-self/thread-closure-bug-demo.py +++ b/notes-to-self/thread-closure-bug-demo.py @@ -8,18 +8,21 @@ COUNT = 100 + def slow_tracefunc(frame, event, arg): # A no-op trace function that sleeps briefly to make us more likely to hit # the race condition. time.sleep(0.01) return slow_tracefunc + def run_with_slow_tracefunc(fn): # settrace() only takes effect when you enter a new frame, so we need this # little dance: sys.settrace(slow_tracefunc) return fn() + def outer(): x = 0 # We hide the done variable inside a list, because we want to use it to @@ -46,13 +49,14 @@ def traced_looper(): t.start() for i in range(COUNT): - print("after {} increments, x is {}".format(i, x)) + print(f"after {i} increments, x is {x}") x += 1 time.sleep(0.01) done[0] = True t.join() - print("Final discrepancy: {} (should be 0)".format(COUNT - x)) + print(f"Final discrepancy: {COUNT - x} (should be 0)") + outer() diff --git a/notes-to-self/thread-dispatch-bench.py b/notes-to-self/thread-dispatch-bench.py index 1625efae17..9afb4bbec8 100644 --- a/notes-to-self/thread-dispatch-bench.py +++ b/notes-to-self/thread-dispatch-bench.py @@ -10,11 +10,13 @@ COUNT = 10000 + def worker(in_q, out_q): while True: job = in_q.get() out_q.put(job()) + def main(): in_q = Queue() out_q = Queue() @@ -28,6 +30,7 @@ def main(): in_q.put(lambda: None) out_q.get() end = time.monotonic() - print("{:.2f} µs/job".format((end - start) / COUNT * 1e6)) + print(f"{(end - start) / COUNT * 1e6:.2f} µs/job") + main() diff --git a/notes-to-self/time-wait-windows-exclusiveaddruse.py b/notes-to-self/time-wait-windows-exclusiveaddruse.py index db3aaad08a..dcb4a27dd0 100644 --- 
a/notes-to-self/time-wait-windows-exclusiveaddruse.py +++ b/notes-to-self/time-wait-windows-exclusiveaddruse.py @@ -8,15 +8,17 @@ import socket from contextlib import contextmanager + @contextmanager def report_outcome(tagline): try: yield except OSError as exc: - print("{}: failed".format(tagline)) - print(" details: {!r}".format(exc)) + print(f"{tagline}: failed") + print(f" details: {exc!r}") else: - print("{}: succeeded".format(tagline)) + print(f"{tagline}: succeeded") + # Set up initial listening socket lsock = socket.socket() diff --git a/notes-to-self/time-wait.py b/notes-to-self/time-wait.py index e865a94982..08c71b0048 100644 --- a/notes-to-self/time-wait.py +++ b/notes-to-self/time-wait.py @@ -31,6 +31,7 @@ import attr + @attr.s(repr=False) class Options: listen1_early = attr.ib(default=None) @@ -49,9 +50,10 @@ def describe(self): for f in attr.fields(self.__class__): value = getattr(self, f.name) if value is not None: - info.append("{}={}".format(f.name, value)) + info.append(f"{f.name}={value}") return "Set/unset: {}".format(", ".join(info)) + def time_wait(options): print(options.describe()) @@ -60,7 +62,7 @@ def time_wait(options): listen0 = socket.socket() listen0.bind(("127.0.0.1", 0)) sockaddr = listen0.getsockname() - #print(" ", sockaddr) + # print(" ", sockaddr) listen0.close() listen1 = socket.socket() @@ -98,6 +100,7 @@ def time_wait(options): else: print(" -> ok") + time_wait(Options()) time_wait(Options(listen1_early=True, server=True, listen2=True)) time_wait(Options(listen1_early=True)) diff --git a/notes-to-self/tiny-thread-pool.py b/notes-to-self/tiny-thread-pool.py deleted file mode 100644 index 85afc4eb3e..0000000000 --- a/notes-to-self/tiny-thread-pool.py +++ /dev/null @@ -1,143 +0,0 @@ -# This is some very messy notes on how we might implement a thread cache - -import threading -import Queue - -# idea: -# -# unbounded thread pool; tracks how many threads are "available" and how much -# work there is to do; if work > available 
threads, spawn a new thread -# -# if a thread sits idle for >N ms, exit -# -# we don't need to support job cancellation -# -# we do need to mark a thread as "available" just before it -# signals back to Trio that it's done, to maintain the invariant that all -# unavailable threads are inside the limiter= protection -# -# maintaining this invariant while exiting can be a bit tricky -# -# maybe a simple target should be to always have 1 idle thread - -# XX we can't use a single shared dispatch queue, because we need LIFO -# scheduling, or else the idle-thread timeout won't work! -# -# instead, keep a list/deque/OrderedDict/something of idle threads, and -# dispatch by popping one off; put things back by pushing them on the end -# maybe one shared dispatch Lock, plus a Condition for each thread -# dispatch by dropping the job into the place where the thread can see it and -# then signalling its Condition? or could have separate locks - -@attr.s(frozen=True) -class Job: - main = attr.ib() - main_args = attr.ib() - finish = attr.ib() - finish_args = attr.ib() - -class EXIT: - pass - -class ThreadCache: - def __init__(self): - self._lock = threading.Lock() - self._idle_workers = deque() - self._closed = False - - def close(self): - self._closed = True - with self._lock: - while self._idle_workers: - self._idle_workers.pop().submit(None) - - def submit(self, job): - with self._lock: - if not self._idle_workers: - WorkerThread(self, self._lock, job) - else: - worker = self._idle_workers.pop() - worker.submit(job) - - # Called from another thread - # Must be called with the lock held - def remove_idle_worker(self, worker): - self._idle_workers.remove(worker) - - # Called from another thread - # Lock is *not* held - def add_idle_worker(self, worker): - if self._closed: - with self._lock: - worker.submit - self._idle_workers.append(worker) - -# XX thread name - -IDLE_TIMEOUT = 1.0 - -class WorkerThread: - def __init__(self, cache, lock, initial_job): - self._cache = cache - 
self._condition = threading.Condition(lock) - self._job = None - self._thread = threading.Thread( - target=self._loop, args=(initial_job,), daemon=True) - self._thread.start() - - # Must be called with the lock held - def submit(self, job): - assert self._job is None - self._job = job - self._condition.notify() - - def _loop(self, initial_job): - self._run_job(initial_job) - while True: - with self._condition: - self._condition.wait(IDLE_TIMEOUT): - job = self._job - self._job = None - if job is None: - self._cache.remove_idle_worker(self) - return - # Dropped the lock, and have a job to do - self._run_job(job) - - def _run_job(self, job): - job.main(*job.main_args) - self._cache.add_idle_worker(self) - job.finish(*job.finish_args) - - -# Probably the interface should be: trio.hazmat.call_soon_in_worker_thread? - -# Enqueueing work: -# put into unbounded queue -# with lock: -# if idle_threads: -# idle_threads -= 1 -# else: -# spawn a new thread (it starts out non-idle) -# -# Thread shutdown: -# with lock: -# idle_threads -= 1 -# check for work one last time, and then either exit or do it -# -# Thread startup: -# -# check for work -# while True: -# mark self as idle -# check for work (with timeout) -# either do work or shutdown - -# if we want to support QueueUserAPC cancellation, we need a way to get back -# the thread id... maybe that just works like -# -# def WaitForSingleObjectEx_thread_fn(...): -# with lock: -# check if already cancelled -# put our thread id where main thread can find it -# WaitForSingleObjectEx(...) 
diff --git a/notes-to-self/trace.py b/notes-to-self/trace.py index 700e2e7e81..c024a36ba5 100644 --- a/notes-to-self/trace.py +++ b/notes-to-self/trace.py @@ -88,7 +88,7 @@ def after_task_step(self, task): def task_scheduled(self, task): try: - waker = trio.hazmat.current_task() + waker = trio.lowlevel.current_task() except RuntimeError: pass else: diff --git a/notes-to-self/win-waitable-timer.py b/notes-to-self/win-waitable-timer.py new file mode 100644 index 0000000000..92bfd7a39a --- /dev/null +++ b/notes-to-self/win-waitable-timer.py @@ -0,0 +1,207 @@ +# Sandbox for exploring the Windows "waitable timer" API. +# Cf https://github.com/python-trio/trio/issues/173 +# +# Observations: +# - if you set a timer in the far future, then block in +# WaitForMultipleObjects, then set the computer's clock forward by a few +# years (past the target sleep time), then the timer immediately wakes up +# (which is good!) +# - if you set a timer in the past, then it wakes up immediately + +# Random thoughts: +# - top-level API sleep_until_datetime +# - portable manages the heap of outstanding sleeps, runs a system task to +# wait for the next one, wakes up tasks when their deadline arrives, etc. +# - non-portable code: async def sleep_until_datetime_raw, which simply blocks +# until the given time using system-specific methods. Can assume that there +# is only one call to this method at a time. +# Actually, this should be a method, so it can hold persistent state (e.g. +# timerfd). +# Can assume that the datetime passed in has tzinfo=timezone.utc +# Need a way to override this object for testing. +# +# should we expose wake-system-on-alarm functionality? 
windows and linux both +# make this fairly straightforward, but you obviously need to use a separate +# time source + +import cffi +from datetime import datetime, timedelta, timezone +import time + +import trio +from trio._core._windows_cffi import (ffi, kernel32, raise_winerror) + +try: + ffi.cdef( + """ +typedef struct _PROCESS_LEAP_SECOND_INFO { + ULONG Flags; + ULONG Reserved; +} PROCESS_LEAP_SECOND_INFO, *PPROCESS_LEAP_SECOND_INFO; + +typedef struct _SYSTEMTIME { + WORD wYear; + WORD wMonth; + WORD wDayOfWeek; + WORD wDay; + WORD wHour; + WORD wMinute; + WORD wSecond; + WORD wMilliseconds; +} SYSTEMTIME, *PSYSTEMTIME, *LPSYSTEMTIME; +""" + ) +except cffi.CDefError: + pass + +ffi.cdef( + """ +typedef LARGE_INTEGER FILETIME; +typedef FILETIME* LPFILETIME; + +HANDLE CreateWaitableTimerW( + LPSECURITY_ATTRIBUTES lpTimerAttributes, + BOOL bManualReset, + LPCWSTR lpTimerName +); + +BOOL SetWaitableTimer( + HANDLE hTimer, + const LPFILETIME lpDueTime, + LONG lPeriod, + void* pfnCompletionRoutine, + LPVOID lpArgToCompletionRoutine, + BOOL fResume +); + +BOOL SetProcessInformation( + HANDLE hProcess, + /* Really an enum, PROCESS_INFORMATION_CLASS */ + int32_t ProcessInformationClass, + LPVOID ProcessInformation, + DWORD ProcessInformationSize +); + +void GetSystemTimeAsFileTime( + LPFILETIME lpSystemTimeAsFileTime +); + +BOOL SystemTimeToFileTime( + const SYSTEMTIME *lpSystemTime, + LPFILETIME lpFileTime +); +""", + override=True +) + +ProcessLeapSecondInfo = 8 +PROCESS_LEAP_SECOND_INFO_FLAG_ENABLE_SIXTY_SECOND = 1 + + +def set_leap_seconds_enabled(enabled): + plsi = ffi.new("PROCESS_LEAP_SECOND_INFO*") + if enabled: + plsi.Flags = PROCESS_LEAP_SECOND_INFO_FLAG_ENABLE_SIXTY_SECOND + else: + plsi.Flags = 0 + plsi.Reserved = 0 + if not kernel32.SetProcessInformation( + ffi.cast("HANDLE", -1), # current process + ProcessLeapSecondInfo, + plsi, + ffi.sizeof("PROCESS_LEAP_SECOND_INFO"), + ): + raise_winerror() + + +def now_as_filetime(): + ft = ffi.new("LARGE_INTEGER*") + 
kernel32.GetSystemTimeAsFileTime(ft) + return ft[0] + + +# "FILETIME" is a specific Windows time representation, that I guess was used +# for files originally but now gets used in all kinds of non-file-related +# places. Essentially: integer count of "ticks" since an epoch in 1601, where +# each tick is 100 nanoseconds, in UTC but pretending that leap seconds don't +# exist. (Fortunately, the Python datetime module also pretends that +# leapseconds don't exist, so we can use datetime arithmetic to compute +# FILETIME values.) +# +# https://docs.microsoft.com/en-us/windows/win32/sysinfo/file-times +# +# This page has FILETIME converters and can be useful for debugging: +# +# https://www.epochconverter.com/ldap +# +FILETIME_TICKS_PER_SECOND = 10**7 +FILETIME_EPOCH = datetime.strptime( + '1601-01-01 00:00:00 Z', '%Y-%m-%d %H:%M:%S %z' +) +# XXX THE ABOVE IS WRONG: +# +# https://techcommunity.microsoft.com/t5/networking-blog/leap-seconds-for-the-appdev-what-you-should-know/ba-p/339813# +# +# Sometimes Windows FILETIME does include leap seconds! It depends on Windows +# version, process-global state, environment state, registry settings, and who +# knows what else! +# +# So actually the only correct way to convert a YMDhms-style representation of +# a time into a FILETIME is to use SystemTimeToFileTime +# +# ...also I can't even run this test on my VM, because it's running an ancient +# version of Win10 that doesn't have leap second support. Also also, Windows +# only tracks leap seconds since they added leap second support, and there +# haven't been any, so right now things work correctly either way. +# +# It is possible to insert some fake leap seconds for testing, if you want. 
+ + +def py_datetime_to_win_filetime(dt): + # We'll want to call this on every datetime as it comes in + #dt = dt.astimezone(timezone.utc) + assert dt.tzinfo is timezone.utc + return round( + (dt - FILETIME_EPOCH).total_seconds() * FILETIME_TICKS_PER_SECOND + ) + + +async def main(): + h = kernel32.CreateWaitableTimerW(ffi.NULL, True, ffi.NULL) + if not h: + raise_winerror() + print(h) + + SECONDS = 2 + + wakeup = datetime.now(timezone.utc) + timedelta(seconds=SECONDS) + wakeup_filetime = py_datetime_to_win_filetime(wakeup) + wakeup_cffi = ffi.new("LARGE_INTEGER *") + wakeup_cffi[0] = wakeup_filetime + + print(wakeup_filetime, wakeup_cffi) + + print(f"Sleeping for {SECONDS} seconds (until {wakeup})") + + if not kernel32.SetWaitableTimer( + h, + wakeup_cffi, + 0, + ffi.NULL, + ffi.NULL, + False, + ): + raise_winerror() + + await trio.hazmat.WaitForSingleObject(h) + + print(f"Current FILETIME: {now_as_filetime()}") + set_leap_seconds_enabled(False) + print(f"Current FILETIME: {now_as_filetime()}") + set_leap_seconds_enabled(True) + print(f"Current FILETIME: {now_as_filetime()}") + set_leap_seconds_enabled(False) + print(f"Current FILETIME: {now_as_filetime()}") + + +trio.run(main) diff --git a/pyproject.toml b/pyproject.toml index 768a4766eb..445c40e28c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,18 +1,116 @@ +[tool.black] +target-version = ['py38'] + +[tool.codespell] +ignore-words-list = 'astroid,crasher,asend' + +[tool.flake8] +extend-ignore = ['D', 'E', 'W', 'F403', 'F405', 'F821', 'F822'] +per-file-ignores = [ + 'trio/__init__.py: F401', + 'trio/_core/__init__.py: F401', + 'trio/_core/_generated*.py: F401', + 'trio/_core/_tests/test_multierror_scripts/*: F401', + 'trio/abc.py: F401', + 'trio/lowlevel.py: F401', + 'trio/socket.py: F401', + 'trio/testing/__init__.py: F401' +] + +[tool.isort] +combine_as_imports = true +profile = "black" +skip_gitignore = true + +[tool.mypy] +python_version = "3.8" + +# Be flexible about dependencies that don't have 
stubs yet (like pytest) +ignore_missing_imports = true + +# Be strict about use of Mypy +warn_unused_ignores = true +warn_unused_configs = true +warn_redundant_casts = true +warn_return_any = true + +# Avoid subtle backsliding +#disallow_any_decorated = true +#disallow_incomplete_defs = true +#disallow_subclassing_any = true + +# Enable gradually / for new modules +check_untyped_defs = false +disallow_untyped_calls = false +disallow_untyped_defs = false + +# DO NOT use `ignore_errors`; it doesn't apply +# downstream and users have to deal with them. +[[tool.mypy.overrides]] +module = [ + "trio._socket", + "trio._core._local", + "trio._sync", + "trio._file_io", +] +disallow_incomplete_defs = true +disallow_untyped_defs = true +disallow_any_generics = true +disallow_any_decorated = true +disallow_subclassing_any = true + +[[tool.mypy.overrides]] +module = [ + "trio._path", +] +disallow_incomplete_defs = true +disallow_untyped_defs = true +#disallow_any_generics = true +#disallow_any_decorated = true +disallow_subclassing_any = true + +[tool.pytest.ini_options] +addopts = ["--strict-markers", "--strict-config"] +faulthandler_timeout = 60 +filterwarnings = [ + "error", + # https://gitter.im/python-trio/general?at=63bb8d0740557a3d5c688d67 + 'ignore:You are using cryptography on a 32-bit Python on a 64-bit Windows Operating System. 
Cryptography will be significantly faster if you switch to using a 64-bit Python.:UserWarning', + # this should remain until https://github.com/pytest-dev/pytest/pull/10894 is merged + 'ignore:ast.Str is deprecated:DeprecationWarning', + 'ignore:Attribute s is deprecated and will be removed:DeprecationWarning', + 'ignore:ast.NameConstant is deprecated:DeprecationWarning', + 'ignore:ast.Num is deprecated:DeprecationWarning', + # https://github.com/python/mypy/issues/15330 + 'ignore:ast.Ellipsis is deprecated:DeprecationWarning', + 'ignore:ast.Bytes is deprecated:DeprecationWarning' +] +junit_family = "xunit2" +markers = ["redistributors_should_skip: tests that should be skipped by downstream redistributors"] +xfail_strict = true + [tool.towncrier] +directory = "newsfragments" +filename = "docs/source/history.rst" +issue_format = "`#{issue} `__" # Usage: # - PRs should drop a file like "issuenumber.feature" in newsfragments -# (or "bugfix", "doc", "removal", "misc"; misc gets no text, we can -# customize this) +# (or "bugfix", "doc", "removal", "misc"; misc gets no text, we can +# customize this) # - At release time after bumping version number, run: towncrier -# (or towncrier --draft) +# (or towncrier --draft) package = "trio" -filename = "docs/source/history.rst" -directory = "newsfragments" underlines = ["-", "~", "^"] -issue_format = "`#{issue} `__" -# Unfortunately there's no way to simply override -# tool.towncrier.type.misc.showcontent +[[tool.towncrier.type]] +directory = "headline" +name = "Headline features" +showcontent = true + +[[tool.towncrier.type]] +directory = "breaking" +name = "Breaking changes" +showcontent = true [[tool.towncrier.type]] directory = "feature" @@ -26,16 +124,15 @@ showcontent = true [[tool.towncrier.type]] directory = "doc" -name = "Improved Documentation" +name = "Improved documentation" showcontent = true [[tool.towncrier.type]] -directory = "removal" -name = "Deprecations and Removals" +directory = "deprecated" +name = 
"Deprecations and removals" showcontent = true [[tool.towncrier.type]] directory = "misc" name = "Miscellaneous internal changes" showcontent = true - diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index e6579a4f07..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,5 +0,0 @@ -[tool:pytest] -xfail_strict = true -faulthandler_timeout=60 -markers = - redistributors_should_skip: tests that should be skipped by downstream redistributors diff --git a/setup.py b/setup.py index 9ef41cb3e0..2917f7c12e 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,9 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup exec(open("trio/_version.py", encoding="utf-8").read()) LONG_DESC = """\ -.. image:: https://cdn.rawgit.com/python-trio/trio/9b0bec646a31e0d0f67b8b6ecc6939726faf3e17/logo/logo-with-background.svg +.. image:: https://raw.githubusercontent.com/python-trio/trio/9b0bec646a31e0d0f67b8b6ecc6939726faf3e17/logo/logo-with-background.svg :width: 200px :align: right @@ -44,7 +44,7 @@ Vital statistics: * Supported environments: Linux, macOS, or Windows running some kind of Python - 3.5-or-better (either CPython or PyPy3 is fine). \\*BSD and illumos likely + 3.8-or-better (either CPython or PyPy3 is fine). \\*BSD and illumos likely work too, but are not tested. * Install: ``python3 -m pip install -U trio`` (or on Windows, maybe @@ -73,27 +73,30 @@ version=__version__, description="A friendly Python library for async concurrency and I/O", long_description=LONG_DESC, + long_description_content_type="text/x-rst", author="Nathaniel J. 
Smith", author_email="njs@pobox.com", url="https://github.com/python-trio/trio", - license="MIT -or- Apache License 2.0", + license="MIT OR Apache-2.0", packages=find_packages(), install_requires=[ - "attrs >= 19.2.0", # for eq + # attrs 19.2.0 adds `eq` option to decorators + # attrs 20.1.0 adds @frozen + "attrs >= 20.1.0", "sortedcontainers", - "async_generator >= 1.9", "idna", "outcome", "sniffio", # cffi 1.12 adds from_buffer(require_writable=True) and ffi.release() # cffi 1.14 fixes memory leak inside ffi.getwinerror() - "cffi>=1.14; os_name == 'nt'", # "cffi is required on windows" - "contextvars>=2.1; python_version < '3.7'" + # cffi is required on Windows, except on PyPy where it is built-in + "cffi>=1.14; os_name == 'nt' and implementation_name != 'pypy'", + "exceptiongroup >= 1.0.0rc9; python_version < '3.11'", ], # This means, just install *everything* you see under trio/, even if it # doesn't look like a source file, so long as it appears in MANIFEST.in: include_package_data=True, - python_requires=">=3.5", + python_requires=">=3.8", keywords=["async", "io", "networking", "trio"], classifiers=[ "Development Status :: 3 - Alpha", @@ -107,8 +110,11 @@ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: System :: Networking", "Framework :: Trio", ], diff --git a/test-requirements.in b/test-requirements.in index 077eb70dcf..1911b1bf11 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -1,31 +1,34 @@ # For tests -pytest >= 5.0 # for faulthandler in core -pytest-cov >= 2.6.0 -ipython # for the IPython traceback integration 
tests -pyOpenSSL # for the ssl tests -trustme # for the ssl tests -pylint # for pylint finding all symbols tests -jedi # for jedi code completion tests +pytest >= 5.0 # for faulthandler in core +coverage >= 7.2.5 +async_generator >= 1.9 +pyright +ipython # for the IPython traceback integration tests +pyOpenSSL >= 22.0.0 # for the ssl + DTLS tests +trustme # for the ssl + DTLS tests +pylint # for pylint finding all symbols tests +jedi # for jedi code completion tests +cryptography>=41.0.0 # cryptography<41 segfaults on pypy3.10 # Tools -yapf ==0.29.0 # formatting +black; implementation_name == "cpython" +mypy; implementation_name == "cpython" +types-pyOpenSSL; implementation_name == "cpython" # and annotations flake8 +flake8-pyproject astor # code generation +pip-tools >= 6.13.0 +codespell # https://github.com/python-trio/trio/pull/654#issuecomment-420518745 -typed_ast; python_version < "3.8" and implementation_name == "cpython" +mypy-extensions; implementation_name == "cpython" +typing-extensions # Trio's own dependencies cffi; os_name == "nt" -contextvars; python_version < "3.7" -attrs >= 19.2.0 +attrs >= 20.1.0 sortedcontainers -async_generator >= 1.9 idna outcome sniffio - -# Required by contextvars, but harmless to install everywhere. -# dependabot drops the contextvars dependency because it runs -# on 3.7. 
-immutables >= 0.6 +exceptiongroup >= 1.0.0rc9; python_version < "3.11" diff --git a/test-requirements.txt b/test-requirements.txt index f83f84ea1e..7e0d86e62e 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,52 +1,180 @@ # -# This file is autogenerated by pip-compile -# To update, run: +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: # -# pip-compile --output-file test-requirements.txt test-requirements.in +# pip-compile test-requirements.in # -astor==0.8.1 # via -r test-requirements.in -astroid==2.3.3 # via pylint -async-generator==1.10 # via -r test-requirements.in -attrs==19.3.0 # via -r test-requirements.in, outcome, pytest -backcall==0.1.0 # via ipython -cffi==1.14.0 # via cryptography -coverage==5.0.3 # via pytest-cov -cryptography==2.8 # via pyopenssl, trustme -decorator==4.4.2 # via ipython, traitlets -entrypoints==0.3 # via flake8 -flake8==3.7.9 # via -r test-requirements.in -idna==2.9 # via -r test-requirements.in, trustme -immutables==0.11 # via -r test-requirements.in -ipython-genutils==0.2.0 # via traitlets -ipython==7.9.0 # via -r test-requirements.in -isort==4.3.21 # via pylint -jedi==0.16.0 # via -r test-requirements.in, ipython -lazy-object-proxy==1.4.3 # via astroid -mccabe==0.6.1 # via flake8, pylint -more-itertools==8.2.0 # via pytest -outcome==1.0.1 # via -r test-requirements.in -packaging==20.3 # via pytest -parso==0.6.2 # via jedi -pexpect==4.8.0 # via ipython -pickleshare==0.7.5 # via ipython -pluggy==0.13.1 # via pytest -prompt-toolkit==2.0.10 # via ipython -ptyprocess==0.6.0 # via pexpect -py==1.8.1 # via pytest -pycodestyle==2.5.0 # via flake8 -pycparser==2.20 # via cffi -pyflakes==2.1.1 # via flake8 -pygments==2.6.1 # via ipython -pylint==2.4.2 # via -r test-requirements.in -pyopenssl==19.1.0 # via -r test-requirements.in -pyparsing==2.4.6 # via packaging -pytest-cov==2.8.1 # via -r test-requirements.in -pytest==5.3.5 # via -r test-requirements.in, pytest-cov -six==1.14.0 # 
via astroid, cryptography, packaging, prompt-toolkit, pyopenssl, traitlets -sniffio==1.1.0 # via -r test-requirements.in -sortedcontainers==2.1.0 # via -r test-requirements.in -traitlets==4.3.3 # via ipython -trustme==0.6.0 # via -r test-requirements.in -wcwidth==0.1.8 # via prompt-toolkit, pytest -wrapt==1.11.2 # via astroid -yapf==0.29.0 # via -r test-requirements.in +astor==0.8.1 + # via -r test-requirements.in +astroid==2.15.6 + # via pylint +asttokens==2.2.1 + # via stack-data +async-generator==1.10 + # via -r test-requirements.in +attrs==23.1.0 + # via + # -r test-requirements.in + # outcome +backcall==0.2.0 + # via ipython +black==23.7.0 ; implementation_name == "cpython" + # via -r test-requirements.in +build==0.10.0 + # via pip-tools +cffi==1.15.1 + # via cryptography +click==8.1.5 + # via + # black + # pip-tools +codespell==2.2.5 + # via -r test-requirements.in +coverage==7.2.7 + # via -r test-requirements.in +cryptography==41.0.2 + # via + # -r test-requirements.in + # pyopenssl + # trustme + # types-pyopenssl +decorator==5.1.1 + # via ipython +dill==0.3.6 + # via pylint +exceptiongroup==1.1.2 ; python_version < "3.11" + # via + # -r test-requirements.in + # pytest +executing==1.2.0 + # via stack-data +flake8==6.0.0 + # via + # -r test-requirements.in + # flake8-pyproject +flake8-pyproject==1.2.3 + # via -r test-requirements.in +idna==3.4 + # via + # -r test-requirements.in + # trustme +iniconfig==2.0.0 + # via pytest +ipython==8.12.2 + # via -r test-requirements.in +isort==5.12.0 + # via pylint +jedi==0.18.2 + # via + # -r test-requirements.in + # ipython +lazy-object-proxy==1.9.0 + # via astroid +matplotlib-inline==0.1.6 + # via ipython +mccabe==0.7.0 + # via + # flake8 + # pylint +mypy==1.4.1 ; implementation_name == "cpython" + # via -r test-requirements.in +mypy-extensions==1.0.0 ; implementation_name == "cpython" + # via + # -r test-requirements.in + # black + # mypy +nodeenv==1.8.0 + # via pyright +outcome==1.2.0 + # via -r test-requirements.in 
+packaging==23.1 + # via + # black + # build + # pytest +parso==0.8.3 + # via jedi +pathspec==0.11.1 + # via black +pexpect==4.8.0 + # via ipython +pickleshare==0.7.5 + # via ipython +pip-tools==7.0.0 + # via -r test-requirements.in +platformdirs==3.9.1 + # via + # black + # pylint +pluggy==1.2.0 + # via pytest +prompt-toolkit==3.0.39 + # via ipython +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.2 + # via stack-data +pycodestyle==2.10.0 + # via flake8 +pycparser==2.21 + # via cffi +pyflakes==3.0.1 + # via flake8 +pygments==2.15.1 + # via ipython +pylint==2.17.4 + # via -r test-requirements.in +pyopenssl==23.2.0 + # via -r test-requirements.in +pyproject-hooks==1.0.0 + # via build +pyright==1.1.317 + # via -r test-requirements.in +pytest==7.4.0 + # via -r test-requirements.in +six==1.16.0 + # via asttokens +sniffio==1.3.0 + # via -r test-requirements.in +sortedcontainers==2.4.0 + # via -r test-requirements.in +stack-data==0.6.2 + # via ipython +tomli==2.0.1 + # via + # black + # build + # flake8-pyproject + # mypy + # pip-tools + # pylint + # pyproject-hooks + # pytest +tomlkit==0.11.8 + # via pylint +traitlets==5.9.0 + # via + # ipython + # matplotlib-inline +trustme==1.1.0 + # via -r test-requirements.in +types-pyopenssl==23.2.0.1 ; implementation_name == "cpython" + # via -r test-requirements.in +typing-extensions==4.7.1 + # via + # -r test-requirements.in + # astroid + # black + # ipython + # mypy + # pylint +wcwidth==0.2.6 + # via prompt-toolkit +wheel==0.40.0 + # via pip-tools +wrapt==1.15.0 + # via astroid + +# The following packages are considered to be unsafe in a requirements file: +# pip +# setuptools diff --git a/trio/__init__.py b/trio/__init__.py index 663475c0b6..799e699404 100644 --- a/trio/__init__.py +++ b/trio/__init__.py @@ -12,146 +12,137 @@ # # This file pulls together the friendly public API, by re-exporting the more # innocuous bits of the _core API + the higher-level tools from trio/*.py. 
+# +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) -from ._version import __version__ +# must be imported early to avoid circular import +from ._core import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED # isort: skip +# Submodules imported by default +from . import abc, from_thread, lowlevel, socket, to_thread +from ._channel import ( + MemoryReceiveChannel as MemoryReceiveChannel, + MemorySendChannel as MemorySendChannel, + open_memory_channel as open_memory_channel, +) from ._core import ( - TrioInternalError, RunFinishedError, WouldBlock, Cancelled, - BusyResourceError, ClosedResourceError, MultiError, run, open_nursery, - CancelScope, open_cancel_scope, current_effective_deadline, - TASK_STATUS_IGNORED, current_time, BrokenResourceError, EndOfChannel, - Nursery + BrokenResourceError as BrokenResourceError, + BusyResourceError as BusyResourceError, + Cancelled as Cancelled, + CancelScope as CancelScope, + ClosedResourceError as ClosedResourceError, + EndOfChannel as EndOfChannel, + Nursery as Nursery, + RunFinishedError as RunFinishedError, + TaskStatus as TaskStatus, + TrioInternalError as TrioInternalError, + WouldBlock as WouldBlock, + current_effective_deadline as current_effective_deadline, + current_time as current_time, + open_nursery as open_nursery, + run as run, ) - -from ._timeouts import ( - move_on_at, move_on_after, sleep_forever, sleep_until, sleep, fail_at, - fail_after, TooSlowError +from ._core._multierror import ( + MultiError as _MultiError, + NonBaseMultiError as _NonBaseMultiError, ) - -from ._sync import ( - Event, CapacityLimiter, Semaphore, Lock, StrictFIFOLock, Condition +from ._deprecate import TrioDeprecationWarning as TrioDeprecationWarning +from ._dtls import ( + DTLSChannel as DTLSChannel, + DTLSChannelStatistics as DTLSChannelStatistics, + DTLSEndpoint as DTLSEndpoint, ) - -from ._threads import BlockingTrioPortal as _BlockingTrioPortal - -from ._highlevel_generic import aclose_forcefully, 
StapledStream - -from ._channel import ( - open_memory_channel, MemorySendChannel, MemoryReceiveChannel +from ._file_io import open_file as open_file, wrap_file as wrap_file +from ._highlevel_generic import ( + StapledStream as StapledStream, + aclose_forcefully as aclose_forcefully, ) - -from ._signals import open_signal_receiver - -from ._highlevel_socket import SocketStream, SocketListener - -from ._file_io import open_file, wrap_file - -from ._path import Path - -from ._subprocess import Process, open_process, run_process - -from ._ssl import SSLStream, SSLListener, NeedHandshakeError - -from ._highlevel_serve_listeners import serve_listeners - -from ._highlevel_open_tcp_stream import open_tcp_stream - -from ._highlevel_open_tcp_listeners import open_tcp_listeners, serve_tcp - -from ._highlevel_open_unix_stream import open_unix_socket - +from ._highlevel_open_tcp_listeners import ( + open_tcp_listeners as open_tcp_listeners, + serve_tcp as serve_tcp, +) +from ._highlevel_open_tcp_stream import open_tcp_stream as open_tcp_stream from ._highlevel_open_unix_listeners import open_unix_listeners, serve_unix - +from ._highlevel_open_unix_stream import open_unix_socket as open_unix_socket +from ._highlevel_serve_listeners import serve_listeners as serve_listeners +from ._highlevel_socket import ( + SocketListener as SocketListener, + SocketStream as SocketStream, +) from ._highlevel_ssl_helpers import ( - open_ssl_over_tcp_stream, open_ssl_over_tcp_listeners, serve_ssl_over_tcp + open_ssl_over_tcp_listeners as open_ssl_over_tcp_listeners, + open_ssl_over_tcp_stream as open_ssl_over_tcp_stream, + serve_ssl_over_tcp as serve_ssl_over_tcp, +) +from ._path import Path as Path +from ._signals import open_signal_receiver as open_signal_receiver +from ._ssl import ( + NeedHandshakeError as NeedHandshakeError, + SSLListener as SSLListener, + SSLStream as SSLStream, +) +from ._subprocess import Process as Process, run_process as run_process +from ._sync import ( + 
CapacityLimiter as CapacityLimiter, + CapacityLimiterStatistics as CapacityLimiterStatistics, + Condition as Condition, + ConditionStatistics as ConditionStatistics, + Event as Event, + EventStatistics as EventStatistics, + Lock as Lock, + LockStatistics as LockStatistics, + Semaphore as Semaphore, + StrictFIFOLock as StrictFIFOLock, +) +from ._timeouts import ( + TooSlowError as TooSlowError, + fail_after as fail_after, + fail_at as fail_at, + move_on_after as move_on_after, + move_on_at as move_on_at, + sleep as sleep, + sleep_forever as sleep_forever, + sleep_until as sleep_until, ) -from ._deprecate import TrioDeprecationWarning +# pyright explicitly does not care about `__version__` +# see https://github.com/microsoft/pyright/blob/main/docs/typed-libraries.md#type-completeness +from ._version import __version__ -# Submodules imported by default -from . import hazmat -from . import socket -from . import abc -from . import from_thread -from . import to_thread # Not imported by default, but mentioned here so static analysis tools like # pylint will know that it exists. if False: from . import testing -from . import _deprecated_ssl_reexports -from . import _deprecated_subprocess_reexports +from . 
import _deprecate as _deprecate _deprecate.enable_attribute_deprecations(__name__) -__deprecated_attributes__ = { - "ssl": - _deprecate.DeprecatedAttribute( - _deprecated_ssl_reexports, - "0.11.0", - issue=852, - instead=( - "trio.SSLStream, trio.SSLListener, trio.NeedHandshakeError, " - "and the standard library 'ssl' module (minus SSLSocket and " - "wrap_socket())" - ), - ), - "subprocess": - _deprecate.DeprecatedAttribute( - _deprecated_subprocess_reexports, - "0.11.0", - issue=852, - instead=( - "trio.Process and the constants in the standard " - "library 'subprocess' module" - ), - ), - "run_sync_in_worker_thread": - _deprecate.DeprecatedAttribute( - to_thread.run_sync, - "0.12.0", - issue=810, - ), - "current_default_worker_thread_limiter": - _deprecate.DeprecatedAttribute( - to_thread.current_default_thread_limiter, - "0.12.0", - issue=810, - ), - "BlockingTrioPortal": - _deprecate.DeprecatedAttribute( - _BlockingTrioPortal, - "0.12.0", - issue=810, - instead=from_thread, - ), -} -_deprecate.enable_attribute_deprecations(hazmat.__name__) -hazmat.__deprecated_attributes__ = { - "wait_socket_readable": - _deprecate.DeprecatedAttribute( - hazmat.wait_readable, - "0.12.0", - issue=878, - ), - "wait_socket_writable": - _deprecate.DeprecatedAttribute( - hazmat.wait_writable, - "0.12.0", - issue=878, - ), - "notify_socket_close": - _deprecate.DeprecatedAttribute( - hazmat.notify_closing, - "0.12.0", - issue=878, +__deprecated_attributes__ = { + "open_process": _deprecate.DeprecatedAttribute( + value=lowlevel.open_process, + version="0.20.0", + issue=1104, + instead="trio.lowlevel.open_process", + ), + "MultiError": _deprecate.DeprecatedAttribute( + value=_MultiError, + version="0.22.0", + issue=2211, + instead=( + "BaseExceptionGroup (on Python 3.11 and later) or " + "exceptiongroup.BaseExceptionGroup (earlier versions)" ), - "notify_fd_close": - _deprecate.DeprecatedAttribute( - hazmat.notify_closing, - "0.12.0", - issue=878, + ), + "NonBaseMultiError": 
_deprecate.DeprecatedAttribute( + value=_NonBaseMultiError, + version="0.22.0", + issue=2211, + instead=( + "ExceptionGroup (on Python 3.11 and later) or " + "exceptiongroup.ExceptionGroup (earlier versions)" ), + ), } # Having the public path in .__module__ attributes is important for: @@ -161,21 +152,11 @@ # - pickle # - probably other stuff from ._util import fixup_module_metadata + fixup_module_metadata(__name__, globals()) -fixup_module_metadata(hazmat.__name__, hazmat.__dict__) +fixup_module_metadata(lowlevel.__name__, lowlevel.__dict__) fixup_module_metadata(socket.__name__, socket.__dict__) fixup_module_metadata(abc.__name__, abc.__dict__) fixup_module_metadata(from_thread.__name__, from_thread.__dict__) fixup_module_metadata(to_thread.__name__, to_thread.__dict__) -fixup_module_metadata(__name__ + ".ssl", _deprecated_ssl_reexports.__dict__) -fixup_module_metadata( - __name__ + ".subprocess", _deprecated_subprocess_reexports.__dict__ -) del fixup_module_metadata - -import sys -if sys.version_info < (3, 6): - _deprecate.warn_deprecated( - "Support for Python 3.5", "0.14", issue=75, instead="Python 3.6+" - ) -del sys diff --git a/trio/_abc.py b/trio/_abc.py index 88c2ff1f70..59454b794c 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,19 +1,30 @@ +from __future__ import annotations + +import socket from abc import ABCMeta, abstractmethod -from typing import Generic, TypeVar -from ._util import aiter_compat +from typing import TYPE_CHECKING, Generic, TypeVar + import trio +if TYPE_CHECKING: + from types import TracebackType + + from typing_extensions import Self + + # both of these introduce circular imports if outside a TYPE_CHECKING guard + from ._socket import _SocketType + from .lowlevel import Task + # We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a # __dict__ onto subclasses. class Clock(metaclass=ABCMeta): - """The interface for custom run loop clocks. 
+ """The interface for custom run loop clocks.""" - """ __slots__ = () @abstractmethod - def start_clock(self): + def start_clock(self) -> None: """Do any setup this clock might need. Called at the beginning of the run. @@ -21,7 +32,7 @@ def start_clock(self): """ @abstractmethod - def current_time(self): + def current_time(self) -> float: """Return the current time, according to this clock. This is used to implement functions like :func:`trio.current_time` and @@ -33,7 +44,7 @@ def current_time(self): """ @abstractmethod - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: """Compute the real time until the given deadline. This is called before we enter a system-specific wait function like @@ -64,62 +75,59 @@ class Instrument(metaclass=ABCMeta): of these methods are optional. This class serves mostly as documentation. """ - __slots__ = () - - def before_run(self): - """Called at the beginning of :func:`trio.run`. - """ + __slots__ = () - def after_run(self): - """Called just before :func:`trio.run` returns. + def before_run(self) -> None: + """Called at the beginning of :func:`trio.run`.""" - """ + def after_run(self) -> None: + """Called just before :func:`trio.run` returns.""" - def task_spawned(self, task): + def task_spawned(self, task: Task) -> None: """Called when the given task is created. Args: - task (trio.hazmat.Task): The new task. + task (trio.lowlevel.Task): The new task. """ - def task_scheduled(self, task): + def task_scheduled(self, task: Task) -> None: """Called when the given task becomes runnable. It may still be some time before it actually runs, if there are other runnable tasks ahead of it. Args: - task (trio.hazmat.Task): The task that became runnable. + task (trio.lowlevel.Task): The task that became runnable. """ - def before_task_step(self, task): + def before_task_step(self, task: Task) -> None: """Called immediately before we resume running the given task. 
Args: - task (trio.hazmat.Task): The task that is about to run. + task (trio.lowlevel.Task): The task that is about to run. """ - def after_task_step(self, task): + def after_task_step(self, task: Task) -> None: """Called when we return to the main run loop after a task has yielded. Args: - task (trio.hazmat.Task): The task that just ran. + task (trio.lowlevel.Task): The task that just ran. """ - def task_exited(self, task): + def task_exited(self, task: Task) -> None: """Called when the given task exits. Args: - task (trio.hazmat.Task): The finished task. + task (trio.lowlevel.Task): The finished task. """ - def before_io_wait(self, timeout): + def before_io_wait(self, timeout: float) -> None: """Called before blocking to wait for I/O readiness. Args: @@ -127,7 +135,7 @@ def before_io_wait(self, timeout): """ - def after_io_wait(self, timeout): + def after_io_wait(self, timeout: float) -> None: """Called after handling pending I/O. Args: @@ -145,12 +153,27 @@ class HostnameResolver(metaclass=ABCMeta): See :func:`trio.socket.set_custom_hostname_resolver`. """ + __slots__ = () @abstractmethod async def getaddrinfo( - self, host, port, family=0, type=0, proto=0, flags=0 - ): + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> list[ + tuple[ + socket.AddressFamily, + socket.SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] + ]: """A custom implementation of :func:`~trio.socket.getaddrinfo`. Called by :func:`trio.socket.getaddrinfo`. @@ -167,7 +190,9 @@ async def getaddrinfo( """ @abstractmethod - async def getnameinfo(self, sockaddr, flags): + async def getnameinfo( + self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int + ) -> tuple[str, str]: """A custom implementation of :func:`~trio.socket.getnameinfo`. Called by :func:`trio.socket.getnameinfo`. 
@@ -182,8 +207,14 @@ class SocketFactory(metaclass=ABCMeta): See :func:`trio.socket.set_custom_socket_factory`. """ + @abstractmethod - def socket(self, family=None, type=None, proto=None): + def socket( + self, + family: socket.AddressFamily | int | None = None, + type: socket.SocketKind | int | None = None, + proto: int | None = None, + ) -> _SocketType: """Create and return a socket object. Your socket object must inherit from :class:`trio.socket.SocketType`, @@ -225,10 +256,11 @@ class AsyncResource(metaclass=ABCMeta): ``__aenter__`` and ``__aexit__`` should be adequate for all subclasses. """ + __slots__ = () @abstractmethod - async def aclose(self): + async def aclose(self) -> None: """Close this resource, possibly blocking. IMPORTANT: This method may block in order to perform a "graceful" @@ -256,10 +288,15 @@ async def aclose(self): """ - async def __aenter__(self): + async def __aenter__(self) -> Self: return self - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: await self.aclose() @@ -278,10 +315,11 @@ class SendStream(AsyncResource): :class:`SendChannel`. """ + __slots__ = () @abstractmethod - async def send_all(self, data): + async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Sends the given data through the stream, blocking if necessary. Args: @@ -307,7 +345,7 @@ async def send_all(self, data): """ @abstractmethod - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: """Block until it's possible that :meth:`send_all` might not block. This method may return early: it's possible that after it returns, @@ -334,7 +372,7 @@ async def wait_send_all_might_not_block(self): This method is intended to aid in implementing protocols that want to delay choosing which data to send until the last moment. 
E.g., - suppose you're working on an implemention of a remote display server + suppose you're working on an implementation of a remote display server like `VNC `__, and the network connection is currently backed up so that if you call @@ -383,10 +421,11 @@ class ReceiveStream(AsyncResource): byte, and the loop automatically exits when reaching end-of-file. """ + __slots__ = () @abstractmethod - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray: """Wait until there is data available on this stream, and then return some of it. @@ -414,11 +453,10 @@ async def receive_some(self, max_bytes=None): """ - @aiter_compat - def __aiter__(self): + def __aiter__(self) -> Self: return self - async def __anext__(self): + async def __anext__(self) -> bytes | bytearray: data = await self.receive_some() if not data: raise StopAsyncIteration @@ -435,6 +473,7 @@ class Stream(SendStream, ReceiveStream): step further and implement :class:`HalfCloseableStream`. """ + __slots__ = () @@ -443,10 +482,11 @@ class HalfCloseableStream(Stream): part of the stream without closing the receive part. """ + __slots__ = () @abstractmethod - async def send_eof(self): + async def send_eof(self) -> None: """Send an end-of-file indication on this stream, if possible. The difference between :meth:`send_eof` and @@ -521,10 +561,11 @@ class Listener(AsyncResource, Generic[T_resource]): or using an ``async with`` block. """ + __slots__ = () @abstractmethod - async def accept(self): + async def accept(self) -> AsyncResource: """Wait until an incoming connection arrives, and then return it. Returns: @@ -562,6 +603,7 @@ class SendChannel(AsyncResource, Generic[SendType]): `SendStream`. """ + __slots__ = () @abstractmethod @@ -606,6 +648,7 @@ class ReceiveChannel(AsyncResource, Generic[ReceiveType]): `ReceiveStream`. 
""" + __slots__ = () @abstractmethod @@ -629,8 +672,7 @@ async def receive(self) -> ReceiveType: """ - @aiter_compat - def __aiter__(self): + def __aiter__(self) -> Self: return self async def __anext__(self) -> ReceiveType: diff --git a/trio/_channel.py b/trio/_channel.py index c902c485fa..c8d27695b8 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -1,18 +1,27 @@ -from collections import deque, OrderedDict +from __future__ import annotations + +from collections import OrderedDict, deque from math import inf +from types import TracebackType +from typing import Tuple # only needed for typechecking on <3.9 +from typing import TYPE_CHECKING, Generic, TypeVar import attr from outcome import Error, Value -from .abc import SendChannel, ReceiveChannel, Channel -from ._util import generic_function, NoPublicConstructor - import trio -from ._core import enable_ki_protection + +from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T +from ._core import Abort, RaiseCancelT, Task, enable_ki_protection +from ._util import NoPublicConstructor, generic_function + +# Temporary TypeVar needed until mypy release supports Self as a type +SelfT = TypeVar("SelfT") -@generic_function -def open_memory_channel(max_buffer_size): +def _open_memory_channel( + max_buffer_size: int | float, +) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: """Open a channel for passing objects between tasks within a process. 
Memory channels are lightweight, cheap to allocate, and entirely @@ -68,35 +77,57 @@ def open_memory_channel(max_buffer_size): raise TypeError("max_buffer_size must be an integer or math.inf") if max_buffer_size < 0: raise ValueError("max_buffer_size must be >= 0") - state = MemoryChannelState(max_buffer_size) + state: MemoryChannelState[T] = MemoryChannelState(max_buffer_size) return ( - MemorySendChannel._create(state), MemoryReceiveChannel._create(state) + MemorySendChannel[T]._create(state), + MemoryReceiveChannel[T]._create(state), ) +# This workaround requires python3.9+, once older python versions are not supported +# or there's a better way of achieving type-checking on a generic factory function, +# it could replace the normal function header +if TYPE_CHECKING: + # written as a class so you can say open_memory_channel[int](5) + # Need to use Tuple instead of tuple due to CI check running on 3.8 + class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]): + def __new__( # type: ignore[misc] # "must return a subtype" + cls, max_buffer_size: int | float + ) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]: + return _open_memory_channel(max_buffer_size) + + def __init__(self, max_buffer_size: int | float): + ... + +else: + # apply the generic_function decorator to make open_memory_channel indexable + # so it's valid to say e.g. 
``open_memory_channel[bytes](5)`` at runtime + open_memory_channel = generic_function(_open_memory_channel) + + @attr.s(frozen=True, slots=True) class MemoryChannelStats: - current_buffer_used = attr.ib() - max_buffer_size = attr.ib() - open_send_channels = attr.ib() - open_receive_channels = attr.ib() - tasks_waiting_send = attr.ib() - tasks_waiting_receive = attr.ib() + current_buffer_used: int = attr.ib() + max_buffer_size: int | float = attr.ib() + open_send_channels: int = attr.ib() + open_receive_channels: int = attr.ib() + tasks_waiting_send: int = attr.ib() + tasks_waiting_receive: int = attr.ib() @attr.s(slots=True) -class MemoryChannelState: - max_buffer_size = attr.ib() - data = attr.ib(factory=deque) +class MemoryChannelState(Generic[T]): + max_buffer_size: int | float = attr.ib() + data: deque[T] = attr.ib(factory=deque) # Counts of open endpoints using this state - open_send_channels = attr.ib(default=0) - open_receive_channels = attr.ib(default=0) + open_send_channels: int = attr.ib(default=0) + open_receive_channels: int = attr.ib(default=0) # {task: value} - send_tasks = attr.ib(factory=OrderedDict) + send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict) # {task: None} - receive_tasks = attr.ib(factory=OrderedDict) + receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict) - def statistics(self): + def statistics(self) -> MemoryChannelStats: return MemoryChannelStats( current_buffer_used=len(self.data), max_buffer_size=self.max_buffer_size, @@ -108,30 +139,28 @@ def statistics(self): @attr.s(eq=False, repr=False) -class MemorySendChannel(SendChannel, metaclass=NoPublicConstructor): - _state = attr.ib() - _closed = attr.ib(default=False) +class MemorySendChannel(SendChannel[SendType], metaclass=NoPublicConstructor): + _state: MemoryChannelState[SendType] = attr.ib() + _closed: bool = attr.ib(default=False) # This is just the tasks waiting on *this* object. 
As compared to # self._state.send_tasks, which includes tasks from this object and # all clones. - _tasks = attr.ib(factory=set) + _tasks: set[Task] = attr.ib(factory=set) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._state.open_send_channels += 1 - def __repr__(self): - return ( - "".format( - id(self), id(self._state) - ) + def __repr__(self) -> str: + return "".format( + id(self), id(self._state) ) - def statistics(self): + def statistics(self) -> MemoryChannelStats: # XX should we also report statistics specific to this object? return self._state.statistics() @enable_ki_protection - def send_nowait(self, value): + def send_nowait(self, value: SendType) -> None: """Like `~trio.abc.SendChannel.send`, but if the channel's buffer is full, raises `WouldBlock` instead of blocking. @@ -144,42 +173,43 @@ def send_nowait(self, value): assert not self._state.data task, _ = self._state.receive_tasks.popitem(last=False) task.custom_sleep_data._tasks.remove(task) - trio.hazmat.reschedule(task, Value(value)) + trio.lowlevel.reschedule(task, Value(value)) elif len(self._state.data) < self._state.max_buffer_size: self._state.data.append(value) else: raise trio.WouldBlock @enable_ki_protection - async def send(self, value): + async def send(self, value: SendType) -> None: """See `SendChannel.send `. Memory channels allow multiple tasks to call `send` at the same time. 
""" - await trio.hazmat.checkpoint_if_cancelled() + await trio.lowlevel.checkpoint_if_cancelled() try: self.send_nowait(value) except trio.WouldBlock: pass else: - await trio.hazmat.cancel_shielded_checkpoint() + await trio.lowlevel.cancel_shielded_checkpoint() return - task = trio.hazmat.current_task() + task = trio.lowlevel.current_task() self._tasks.add(task) self._state.send_tasks[task] = value task.custom_sleep_data = self - def abort_fn(_): + def abort_fn(_: RaiseCancelT) -> Abort: self._tasks.remove(task) del self._state.send_tasks[task] - return trio.hazmat.Abort.SUCCEEDED + return trio.lowlevel.Abort.SUCCEEDED - await trio.hazmat.wait_task_rescheduled(abort_fn) + await trio.lowlevel.wait_task_rescheduled(abort_fn) + # Return type must be stringified or use a TypeVar @enable_ki_protection - def clone(self): + def clone(self) -> MemorySendChannel[SendType]: """Clone this send channel object. This returns a new `MemorySendChannel` object, which acts as a @@ -207,14 +237,35 @@ def clone(self): raise trio.ClosedResourceError return MemorySendChannel._create(self._state) + def __enter__(self: SelfT) -> SelfT: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + @enable_ki_protection - async def aclose(self): + def close(self) -> None: + """Close this send channel object synchronously. + + All channel objects have an asynchronous `~.AsyncResource.aclose` method. + Memory channels can also be closed synchronously. This has the same + effect on the channel and other tasks using it, but `close` is not a + trio checkpoint. This simplifies cleaning up in cancelled tasks. + + Using ``with send_channel:`` will close the channel object on leaving + the with block. 
+ + """ if self._closed: - await trio.hazmat.checkpoint() return self._closed = True for task in self._tasks: - trio.hazmat.reschedule(task, Error(trio.ClosedResourceError())) + trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError())) del self._state.send_tasks[task] self._tasks.clear() self._state.open_send_channels -= 1 @@ -222,30 +273,34 @@ async def aclose(self): assert not self._state.send_tasks for task in self._state.receive_tasks: task.custom_sleep_data._tasks.remove(task) - trio.hazmat.reschedule(task, Error(trio.EndOfChannel())) + trio.lowlevel.reschedule(task, Error(trio.EndOfChannel())) self._state.receive_tasks.clear() - await trio.hazmat.checkpoint() + + @enable_ki_protection + async def aclose(self) -> None: + self.close() + await trio.lowlevel.checkpoint() @attr.s(eq=False, repr=False) -class MemoryReceiveChannel(ReceiveChannel, metaclass=NoPublicConstructor): - _state = attr.ib() - _closed = attr.ib(default=False) - _tasks = attr.ib(factory=set) +class MemoryReceiveChannel(ReceiveChannel[ReceiveType], metaclass=NoPublicConstructor): + _state: MemoryChannelState[ReceiveType] = attr.ib() + _closed: bool = attr.ib(default=False) + _tasks: set[trio._core._run.Task] = attr.ib(factory=set) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._state.open_receive_channels += 1 - def statistics(self): + def statistics(self) -> MemoryChannelStats: return self._state.statistics() - def __repr__(self): + def __repr__(self) -> str: return "".format( id(self), id(self._state) ) @enable_ki_protection - def receive_nowait(self): + def receive_nowait(self) -> ReceiveType: """Like `~trio.abc.ReceiveChannel.receive`, but if there's nothing ready to receive, raises `WouldBlock` instead of blocking. 
@@ -255,7 +310,7 @@ def receive_nowait(self): if self._state.send_tasks: task, value = self._state.send_tasks.popitem(last=False) task.custom_sleep_data._tasks.remove(task) - trio.hazmat.reschedule(task) + trio.lowlevel.reschedule(task) self._state.data.append(value) # Fall through if self._state.data: @@ -265,7 +320,7 @@ def receive_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def receive(self): + async def receive(self) -> ReceiveType: """See `ReceiveChannel.receive `. Memory channels allow multiple tasks to call `receive` at the same @@ -273,29 +328,31 @@ async def receive(self): will get the second item sent, and so on. """ - await trio.hazmat.checkpoint_if_cancelled() + await trio.lowlevel.checkpoint_if_cancelled() try: value = self.receive_nowait() except trio.WouldBlock: pass else: - await trio.hazmat.cancel_shielded_checkpoint() + await trio.lowlevel.cancel_shielded_checkpoint() return value - task = trio.hazmat.current_task() + task = trio.lowlevel.current_task() self._tasks.add(task) self._state.receive_tasks[task] = None task.custom_sleep_data = self - def abort_fn(_): + def abort_fn(_: RaiseCancelT) -> Abort: self._tasks.remove(task) del self._state.receive_tasks[task] - return trio.hazmat.Abort.SUCCEEDED + return trio.lowlevel.Abort.SUCCEEDED - return await trio.hazmat.wait_task_rescheduled(abort_fn) + # Not strictly guaranteed to return ReceiveType, but will do so unless + # you intentionally reschedule with a bad value. + return await trio.lowlevel.wait_task_rescheduled(abort_fn) # type: ignore[no-any-return] @enable_ki_protection - def clone(self): + def clone(self) -> MemoryReceiveChannel[ReceiveType]: """Clone this receive channel object. 
This returns a new `MemoryReceiveChannel` object, which acts as a @@ -326,14 +383,35 @@ def clone(self): raise trio.ClosedResourceError return MemoryReceiveChannel._create(self._state) + def __enter__(self: SelfT) -> SelfT: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + @enable_ki_protection - async def aclose(self): + def close(self) -> None: + """Close this receive channel object synchronously. + + All channel objects have an asynchronous `~.AsyncResource.aclose` method. + Memory channels can also be closed synchronously. This has the same + effect on the channel and other tasks using it, but `close` is not a + trio checkpoint. This simplifies cleaning up in cancelled tasks. + + Using ``with receive_channel:`` will close the channel object on + leaving the with block. + + """ if self._closed: - await trio.hazmat.checkpoint() return self._closed = True for task in self._tasks: - trio.hazmat.reschedule(task, Error(trio.ClosedResourceError())) + trio.lowlevel.reschedule(task, Error(trio.ClosedResourceError())) del self._state.receive_tasks[task] self._tasks.clear() self._state.open_receive_channels -= 1 @@ -341,7 +419,11 @@ async def aclose(self): assert not self._state.receive_tasks for task in self._state.send_tasks: task.custom_sleep_data._tasks.remove(task) - trio.hazmat.reschedule(task, Error(trio.BrokenResourceError())) + trio.lowlevel.reschedule(task, Error(trio.BrokenResourceError())) self._state.send_tasks.clear() self._state.data.clear() - await trio.hazmat.checkpoint() + + @enable_ki_protection + async def aclose(self) -> None: + self.close() + await trio.lowlevel.checkpoint() diff --git a/trio/_core/__init__.py b/trio/_core/__init__.py index fda027e193..aa898fffe0 100644 --- a/trio/_core/__init__.py +++ b/trio/_core/__init__.py @@ -1,56 +1,81 @@ """ This namespace represents the core functionality that has to be built-in 
and deal with private internal data structures. Things in this namespace -are publicly available in either trio, trio.hazmat, or trio.testing. +are publicly available in either trio, trio.lowlevel, or trio.testing. """ -from ._exceptions import ( - TrioInternalError, RunFinishedError, WouldBlock, Cancelled, - BusyResourceError, ClosedResourceError, BrokenResourceError, EndOfChannel -) +import sys -from ._multierror import MultiError - -from ._ki import ( - enable_ki_protection, disable_ki_protection, currently_ki_protected +from ._entry_queue import TrioToken +from ._exceptions import ( + BrokenResourceError, + BusyResourceError, + Cancelled, + ClosedResourceError, + EndOfChannel, + RunFinishedError, + TrioInternalError, + WouldBlock, ) +from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection +from ._local import RunVar +from ._mock_clock import MockClock +from ._parking_lot import ParkingLot, ParkingLotStatistics # Imports that always exist from ._run import ( - Task, CancelScope, run, open_nursery, open_cancel_scope, checkpoint, - current_task, current_effective_deadline, checkpoint_if_cancelled, - TASK_STATUS_IGNORED, current_statistics, current_trio_token, reschedule, - remove_instrument, add_instrument, current_clock, current_root_task, - spawn_system_task, current_time, wait_all_tasks_blocked, wait_readable, - wait_writable, notify_closing, Nursery + TASK_STATUS_IGNORED, + CancelScope, + Nursery, + Task, + TaskStatus, + add_instrument, + checkpoint, + checkpoint_if_cancelled, + current_clock, + current_effective_deadline, + current_root_task, + current_statistics, + current_task, + current_time, + current_trio_token, + notify_closing, + open_nursery, + remove_instrument, + reschedule, + run, + spawn_system_task, + start_guest_run, + wait_all_tasks_blocked, + wait_readable, + wait_writable, ) +from ._thread_cache import start_thread_soon # Has to come after _run to resolve a circular import from ._traps import ( - 
cancel_shielded_checkpoint, Abort, wait_task_rescheduled, - temporarily_detach_coroutine_object, permanently_detach_coroutine_object, - reattach_detached_coroutine_object + Abort, + RaiseCancelT, + cancel_shielded_checkpoint, + permanently_detach_coroutine_object, + reattach_detached_coroutine_object, + temporarily_detach_coroutine_object, + wait_task_rescheduled, ) - -from ._entry_queue import TrioToken - -from ._parking_lot import ParkingLot - from ._unbounded_queue import UnboundedQueue -from ._local import RunVar - -# Kqueue imports -try: - from ._run import (current_kqueue, monitor_kevent, wait_kevent) -except ImportError: - pass - # Windows imports -try: +if sys.platform == "win32": from ._run import ( - monitor_completion_key, current_iocp, register_with_iocp, - wait_overlapped, write_overlapped, readinto_overlapped + current_iocp, + monitor_completion_key, + readinto_overlapped, + register_with_iocp, + wait_overlapped, + write_overlapped, ) -except ImportError: - pass +# Kqueue imports +elif sys.platform != "linux" and sys.platform != "win32": + from ._run import current_kqueue, monitor_kevent, wait_kevent + +del sys # It would be better to import sys as _sys, but mypy does not understand it diff --git a/trio/_core/_asyncgens.py b/trio/_core/_asyncgens.py new file mode 100644 index 0000000000..5f02ebe76d --- /dev/null +++ b/trio/_core/_asyncgens.py @@ -0,0 +1,194 @@ +import logging +import sys +import warnings +import weakref + +import attr + +from .. import _core +from .._util import name_asyncgen +from . import _run + +# Used to log exceptions in async generator finalizers +ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors") + + +@attr.s(eq=False, slots=True) +class AsyncGenerators: + # Async generators are added to this set when first iterated. Any + # left after the main task exits will be closed before trio.run() + # returns. During most of the run, this is a WeakSet so GC works. 
+ # During shutdown, when we're finalizing all the remaining + # asyncgens after the system nursery has been closed, it's a + # regular set so we don't have to deal with GC firing at + # unexpected times. + alive = attr.ib(factory=weakref.WeakSet) + + # This collects async generators that get garbage collected during + # the one-tick window between the system nursery closing and the + # init task starting end-of-run asyncgen finalization. + trailing_needs_finalize = attr.ib(factory=set) + + prev_hooks = attr.ib(init=False) + + def install_hooks(self, runner): + def firstiter(agen): + if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"): + self.alive.add(agen) + else: + # An async generator first iterated outside of a Trio + # task doesn't belong to Trio. Probably we're in guest + # mode and the async generator belongs to our host. + # The locals dictionary is the only good place to + # remember this fact, at least until + # https://bugs.python.org/issue40916 is implemented. + agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True + if self.prev_hooks.firstiter is not None: + self.prev_hooks.firstiter(agen) + + def finalize_in_trio_context(agen, agen_name): + try: + runner.spawn_system_task( + self._finalize_one, + agen, + agen_name, + name=f"close asyncgen {agen_name} (abandoned)", + ) + except RuntimeError: + # There is a one-tick window where the system nursery + # is closed but the init task hasn't yet made + # self.asyncgens a strong set to disable GC. We seem to + # have hit it. + self.trailing_needs_finalize.add(agen) + + def finalizer(agen): + agen_name = name_asyncgen(agen) + try: + is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen") + except AttributeError: # pragma: no cover + is_ours = True + + if is_ours: + runner.entry_queue.run_sync_soon( + finalize_in_trio_context, agen, agen_name + ) + + # Do this last, because it might raise an exception + # depending on the user's warnings filter. 
(That + # exception will be printed to the terminal and + # ignored, since we're running in GC context.) + warnings.warn( + f"Async generator {agen_name!r} was garbage collected before it " + "had been exhausted. Surround its use in 'async with " + "aclosing(...):' to ensure that it gets cleaned up as soon as " + "you're done using it.", + ResourceWarning, + stacklevel=2, + source=agen, + ) + else: + # Not ours -> forward to the host loop's async generator finalizer + if self.prev_hooks.finalizer is not None: + self.prev_hooks.finalizer(agen) + else: + # Host has no finalizer. Reimplement the default + # Python behavior with no hooks installed: throw in + # GeneratorExit, step once, raise RuntimeError if + # it doesn't exit. + closer = agen.aclose() + try: + # If the next thing is a yield, this will raise RuntimeError + # which we allow to propagate + closer.send(None) + except StopIteration: + pass + else: + # If the next thing is an await, we get here. Give a nicer + # error than the default "async generator ignored GeneratorExit" + raise RuntimeError( + f"Non-Trio async generator {agen_name!r} awaited something " + "during finalization; install a finalization hook to " + "support this, or wrap it in 'async with aclosing(...):'" + ) + + self.prev_hooks = sys.get_asyncgen_hooks() + sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer) + + async def finalize_remaining(self, runner): + # This is called from init after shutting down the system nursery. + # The only tasks running at this point are init and + # the run_sync_soon task, and since the system nursery is closed, + # there's no way for user code to spawn more. + assert _core.current_task() is runner.init_task + assert len(runner.tasks) == 2 + + # To make async generator finalization easier to reason + # about, we'll shut down asyncgen garbage collection by turning + # the alive WeakSet into a regular set. 
+ self.alive = set(self.alive) + + # Process all pending run_sync_soon callbacks, in case one of + # them was an asyncgen finalizer that snuck in under the wire. + runner.entry_queue.run_sync_soon(runner.reschedule, runner.init_task) + await _core.wait_task_rescheduled( + lambda _: _core.Abort.FAILED # pragma: no cover + ) + self.alive.update(self.trailing_needs_finalize) + self.trailing_needs_finalize.clear() + + # None of the still-living tasks use async generators, so + # every async generator must be suspended at a yield point -- + # there's no one to be doing the iteration. That's good, + # because aclose() only works on an asyncgen that's suspended + # at a yield point. (If it's suspended at an event loop trap, + # because someone is in the middle of iterating it, then you + # get a RuntimeError on 3.8+, and a nasty surprise on earlier + # versions due to https://bugs.python.org/issue32526.) + # + # However, once we start aclose() of one async generator, it + # might start fetching the next value from another, thus + # preventing us from closing that other (at least until + # aclose() of the first one is complete). This constraint + # effectively requires us to finalize the remaining asyncgens + # in arbitrary order, rather than doing all of them at the + # same time. On 3.8+ we could defer any generator with + # ag_running=True to a later batch, but that only catches + # the case where our aclose() starts after the user's + # asend()/etc. If our aclose() starts first, then the + # user's asend()/etc will raise RuntimeError, since they're + # probably not checking ag_running. + # + # It might be possible to allow some parallelized cleanup if + # we can determine that a certain set of asyncgens have no + # interdependencies, using gc.get_referents() and such. + # But just doing one at a time will typically work well enough + # (since each aclose() executes in a cancelled scope) and + # is much easier to reason about. 
+ + # It's possible that that cleanup code will itself create + # more async generators, so we iterate repeatedly until + # all are gone. + while self.alive: + batch = self.alive + self.alive = set() + for agen in batch: + await self._finalize_one(agen, name_asyncgen(agen)) + + def close(self): + sys.set_asyncgen_hooks(*self.prev_hooks) + + async def _finalize_one(self, agen, name): + try: + # This shield ensures that finalize_asyncgen never exits + # with an exception, not even a Cancelled. The inside + # is cancelled so there's no deadlock risk. + with _core.CancelScope(shield=True) as cancel_scope: + cancel_scope.cancel() + await agen.aclose() + except BaseException: + ASYNCGEN_LOGGER.exception( + "Exception ignored during finalization of async generator %r -- " + "surround your use of the generator in 'async with aclosing(...):' " + "to raise exceptions like this in the context where they're generated", + name, + ) diff --git a/trio/_core/_entry_queue.py b/trio/_core/_entry_queue.py index 97b1c56fa4..878506bb2b 100644 --- a/trio/_core/_entry_queue.py +++ b/trio/_core/_entry_queue.py @@ -1,13 +1,12 @@ -from collections import deque import threading +from collections import deque import attr from .. import _core +from .._util import NoPublicConstructor from ._wakeup_socketpair import WakeupSocketpair -__all__ = ["TrioToken"] - @attr.s(slots=True) class EntryQueue: @@ -16,8 +15,8 @@ class EntryQueue: # not signal-safe. deque is implemented in C, so each operation is atomic # WRT threads (and this is guaranteed in the docs), AND each operation is # atomic WRT signal delivery (signal handlers can run on either side, but - # not *during* a deque operation). dict makes similar guarantees - and on - # CPython 3.6 and PyPy, it's even ordered! + # not *during* a deque operation). dict makes similar guarantees - and + # it's even ordered! 
queue = attr.ib(factory=deque) idempotent_queue = attr.ib(factory=dict) @@ -57,7 +56,15 @@ def run_cb(job): async def kill_everything(exc): raise exc - _core.spawn_system_task(kill_everything, exc) + try: + _core.spawn_system_task(kill_everything, exc) + except RuntimeError: + # We're quite late in the shutdown process and the + # system nursery is already closed. + # TODO(2020-06): this is a gross hack and should + # be fixed soon when we address #1607. + _core.current_task().parent_nursery.start_soon(kill_everything, exc) + return True # This has to be carefully written to be safe in the face of new items @@ -103,10 +110,6 @@ def close(self): def size(self): return len(self.queue) + len(self.idempotent_queue) - def spawn(self): - name = "" - _core.spawn_system_task(self.task, name=name) - def run_sync_soon(self, sync_fn, *args, idempotent=False): with self.lock: if self.done: @@ -123,7 +126,8 @@ def run_sync_soon(self, sync_fn, *args, idempotent=False): self.wakeup.wakeup_thread_and_signal_safe() -class TrioToken: +@attr.s(eq=False, hash=False, slots=True) +class TrioToken(metaclass=NoPublicConstructor): """An opaque object representing a single call to :func:`trio.run`. It has no public constructor; instead, see :func:`current_trio_token`. @@ -142,10 +146,7 @@ class TrioToken: """ - __slots__ = ('_reentry_queue',) - - def __init__(self, reentry_queue): - self._reentry_queue = reentry_queue + _reentry_queue = attr.ib() def run_sync_soon(self, sync_fn, *args, idempotent=False): """Schedule a call to ``sync_fn(*args)`` to occur in the context of a @@ -160,12 +161,12 @@ def run_sync_soon(self, sync_fn, *args, idempotent=False): If you need this, you'll have to build your own. The call is effectively run as part of a system task (see - :func:`~trio.hazmat.spawn_system_task`). In particular this means + :func:`~trio.lowlevel.spawn_system_task`). 
In particular this means that: * :exc:`KeyboardInterrupt` protection is *enabled* by default; if you want ``sync_fn`` to be interruptible by control-C, then you - need to use :func:`~trio.hazmat.disable_ki_protection` + need to use :func:`~trio.lowlevel.disable_ki_protection` explicitly. * If ``sync_fn`` raises an exception, then it's converted into a @@ -178,9 +179,7 @@ def run_sync_soon(self, sync_fn, *args, idempotent=False): If ``idempotent=True``, then ``sync_fn`` and ``args`` must be hashable, and Trio will make a best-effort attempt to discard any call submission which is equal to an already-pending call. Trio - will make an attempt to process these in first-in first-out order, - but no guarantees. (Currently processing is FIFO on CPython 3.6 and - PyPy, but not CPython 3.5.) + will process these in first-in first-out order. Any ordering guarantees apply separately to ``idempotent=False`` and ``idempotent=True`` calls; there's no rule for how calls in the @@ -193,6 +192,4 @@ def run_sync_soon(self, sync_fn, *args, idempotent=False): exits.) """ - self._reentry_queue.run_sync_soon( - sync_fn, *args, idempotent=idempotent - ) + self._reentry_queue.run_sync_soon(sync_fn, *args, idempotent=idempotent) diff --git a/trio/_core/_exceptions.py b/trio/_core/_exceptions.py index 3cdaa2ca8e..bdc7b31c21 100644 --- a/trio/_core/_exceptions.py +++ b/trio/_core/_exceptions.py @@ -1,11 +1,9 @@ -import attr - from trio._util import NoPublicConstructor class TrioInternalError(Exception): """Raised by :func:`run` if we encounter a bug in Trio, or (possibly) a - misuse of one of the low-level :mod:`trio.hazmat` APIs. + misuse of one of the low-level :mod:`trio.lowlevel` APIs. This should never happen! If you get this error, please file a bug. @@ -25,9 +23,7 @@ class RunFinishedError(RuntimeError): class WouldBlock(Exception): - """Raised by ``X_nowait`` functions if ``X`` would block. 
- - """ + """Raised by ``X_nowait`` functions if ``X`` would block.""" class Cancelled(BaseException, metaclass=NoPublicConstructor): @@ -62,7 +58,8 @@ class Cancelled(BaseException, metaclass=NoPublicConstructor): everywhere. """ - def __str__(self): + + def __str__(self) -> str: return "Cancelled" diff --git a/trio/_core/_generated_instrumentation.py b/trio/_core/_generated_instrumentation.py new file mode 100644 index 0000000000..30c2f26b4e --- /dev/null +++ b/trio/_core/_generated_instrumentation.py @@ -0,0 +1,48 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +# isort: skip +from ._instrumentation import Instrument +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT + +# fmt: off + + +def add_instrument(instrument: Instrument) ->None: + """Start instrumenting the current run loop with the given instrument. + + Args: + instrument (trio.abc.Instrument): The instrument to activate. + + If ``instrument`` is already active, does nothing. + + """ + locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument) + except AttributeError: + raise RuntimeError("must be called from async context") + + +def remove_instrument(instrument: Instrument) ->None: + """Stop instrumenting the current run loop with the given instrument. + + Args: + instrument (trio.abc.Instrument): The instrument to de-activate. + + Raises: + KeyError: if the instrument is not currently active. This could + occur either because you never added it, or because you added it + and then it raised an unhandled exception and was automatically + deactivated. 
+ + """ + locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument) + except AttributeError: + raise RuntimeError("must be called from async context") + + +# fmt: on diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py index fe63a6ee0c..02fb3bc348 100644 --- a/trio/_core/_generated_io_epoll.py +++ b/trio/_core/_generated_io_epoll.py @@ -1,28 +1,36 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +# isort: skip +from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT + +# fmt: off - async def wait_readable(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_writable(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def notify_closing(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + + +# fmt: on diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py index 059a8a95d1..94e819769c 100644 --- a/trio/_core/_generated_io_kqueue.py +++ 
b/trio/_core/_generated_io_kqueue.py @@ -1,49 +1,60 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +# isort: skip +from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT + +# fmt: off - def current_kqueue(): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() + return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def monitor_kevent(ident, filter): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) + return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_kevent(ident, filter, abort_func): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(ident, filter, abort_func) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_readable(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_writable(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) except AttributeError: - 
raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def notify_closing(fd): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + + +# fmt: on diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py index 78dd30db19..26b4da697d 100644 --- a/trio/_core/_generated_io_windows.py +++ b/trio/_core/_generated_io_windows.py @@ -1,70 +1,84 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +# isort: skip +from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT + +# fmt: off - async def wait_readable(sock): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_writable(sock): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def notify_closing(handle): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) except AttributeError: - raise RuntimeError('must be called from async 
context') + raise RuntimeError("must be called from async context") + def register_with_iocp(handle): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) + return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def wait_overlapped(handle, lpOverlapped): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(handle, lpOverlapped) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def write_overlapped(handle, data, file_offset=0): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(handle, data, file_offset) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + async def readinto_overlapped(handle, buffer, file_offset=0): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(handle, buffer, file_offset) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_iocp(): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() + return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def monitor_completion_key(): locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() + return 
GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + + +# fmt: on diff --git a/trio/_core/_generated_run.py b/trio/_core/_generated_run.py index 834346c0bf..d1e74a93f4 100644 --- a/trio/_core/_generated_run.py +++ b/trio/_core/_generated_run.py @@ -1,10 +1,13 @@ # *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +# isort: skip +from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT + +# fmt: off - def current_statistics(): """Returns an object containing run-loop-level debugging information. @@ -22,7 +25,7 @@ def current_statistics(): :data:`~math.inf` if there are no pending deadlines. * ``run_sync_soon_queue_size`` (int): The number of unprocessed callbacks queued via - :meth:`trio.hazmat.TrioToken.run_sync_soon`. + :meth:`trio.lowlevel.TrioToken.run_sync_soon`. * ``io_statistics`` (object): Some statistics from Trio's I/O backend. This always has an attribute ``backend`` which is a string naming which operating-system-specific I/O backend is in use; the @@ -31,9 +34,10 @@ def current_statistics(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_statistics() + return GLOBAL_RUN_CONTEXT.runner.current_statistics() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_time(): """Returns the current time according to Trio's internal clock. 
@@ -47,19 +51,19 @@ def current_time(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_time() + return GLOBAL_RUN_CONTEXT.runner.current_time() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") -def current_clock(): - """Returns the current :class:`~trio.abc.Clock`. - """ +def current_clock(): + """Returns the current :class:`~trio.abc.Clock`.""" locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_clock() + return GLOBAL_RUN_CONTEXT.runner.current_clock() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_root_task(): """Returns the current root :class:`Task`. @@ -69,9 +73,10 @@ def current_root_task(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_root_task() + return GLOBAL_RUN_CONTEXT.runner.current_root_task() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def reschedule(task, next_send=_NO_SEND): """Reschedule the given task with the given @@ -85,7 +90,7 @@ def reschedule(task, next_send=_NO_SEND): to calling :func:`reschedule` once.) Args: - task (trio.hazmat.Task): the task to be rescheduled. Must be blocked + task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked in a call to :func:`wait_task_rescheduled`. next_send (outcome.Outcome): the value (or error) to return (or raise) from :func:`wait_task_rescheduled`. 
@@ -93,11 +98,12 @@ def reschedule(task, next_send=_NO_SEND): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) + return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") -def spawn_system_task(async_fn, *args, name=None): + +def spawn_system_task(async_fn, *args, name=None, context=None): """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -121,6 +127,15 @@ def spawn_system_task(async_fn, *args, name=None): * System tasks do not inherit context variables from their creator. + Towards the end of a call to :meth:`trio.run`, after the main + task and all system tasks have exited, the system nursery + becomes closed. At this point, new calls to + :func:`spawn_system_task` will raise ``RuntimeError("Nursery + is closed to new arrivals")`` instead of creating a system + task. It's possible to encounter this state either in + a ``finally`` block in an async generator, or in a callback + passed to :meth:`TrioToken.run_sync_soon` at the right moment. + Args: async_fn: An async callable. args: Positional arguments for ``async_fn``. If you want to pass @@ -131,6 +146,10 @@ def spawn_system_task(async_fn, *args, name=None): case is if you're wrapping a function before spawning a new task, you might pass the original function as the ``name=`` to make debugging easier. + context: An optional ``contextvars.Context`` object with context variables + to use for this task. You would normally get a copy of the current + context with ``context = contextvars.copy_context()`` and then you would + pass that ``context`` object here. 
Returns: Task: the newly spawned task @@ -138,9 +157,10 @@ def spawn_system_task(async_fn, *args, name=None): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(async_fn, *args, name=name) + return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(async_fn, *args, name=name, context=context) except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + def current_trio_token(): """Retrieve the :class:`TrioToken` for the current call to @@ -149,11 +169,12 @@ def current_trio_token(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return GLOBAL_RUN_CONTEXT.runner.current_trio_token() + return GLOBAL_RUN_CONTEXT.runner.current_trio_token() except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") + -async def wait_all_tasks_blocked(cushion=0.0, tiebreaker=0): +async def wait_all_tasks_blocked(cushion=0.0): """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a @@ -171,9 +192,7 @@ async def wait_all_tasks_blocked(cushion=0.0, tiebreaker=0): then the one with the shortest ``cushion`` is the one woken (and this task becoming unblocked resets the timers for the remaining tasks). If there are multiple tasks that have exactly the same - ``cushion``, then the one with the lowest ``tiebreaker`` value is - woken first. And if there are multiple tasks with the same ``cushion`` - and the same ``tiebreaker``, then all are woken. + ``cushion``, then all are woken. 
You should also consider :class:`trio.testing.Sequencer`, which provides a more explicit way to control execution ordering within a @@ -196,18 +215,18 @@ async def test_lock_fairness(): nursery.start_soon(lock_taker, lock) # child hasn't run yet, we have the lock assert lock.locked() - assert lock._owner is trio.hazmat.current_task() + assert lock._owner is trio.lowlevel.current_task() await trio.testing.wait_all_tasks_blocked() # now the child has run and is blocked on lock.acquire(), we # still have the lock assert lock.locked() - assert lock._owner is trio.hazmat.current_task() + assert lock._owner is trio.lowlevel.current_task() lock.release() try: # The child has a prior claim, so we can't have it lock.acquire_nowait() except trio.WouldBlock: - assert lock._owner is not trio.hazmat.current_task() + assert lock._owner is not trio.lowlevel.current_task() print("PASS") else: print("FAIL") @@ -215,40 +234,9 @@ async def test_lock_fairness(): """ locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion, tiebreaker) - except AttributeError: - raise RuntimeError('must be called from async context') - -def add_instrument(instrument): - """Start instrumenting the current run loop with the given instrument. - - Args: - instrument (trio.abc.Instrument): The instrument to activate. - - If ``instrument`` is already active, does nothing. - - """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.add_instrument(instrument) + return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion) except AttributeError: - raise RuntimeError('must be called from async context') - -def remove_instrument(instrument): - """Stop instrumenting the current run loop with the given instrument. + raise RuntimeError("must be called from async context") - Args: - instrument (trio.abc.Instrument): The instrument to de-activate. 
- - Raises: - KeyError: if the instrument is not currently active. This could - occur either because you never added it, or because you added it - and then it raised an unhandled exception and was automatically - deactivated. - """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.remove_instrument(instrument) - except AttributeError: - raise RuntimeError('must be called from async context') +# fmt: on diff --git a/trio/_core/_instrumentation.py b/trio/_core/_instrumentation.py new file mode 100644 index 0000000000..a0757a5b83 --- /dev/null +++ b/trio/_core/_instrumentation.py @@ -0,0 +1,108 @@ +import logging +import types +from typing import Any, Callable, Dict, Sequence, TypeVar + +from .._abc import Instrument + +# Used to log exceptions in instruments +INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument") + + +F = TypeVar("F", bound=Callable[..., Any]) + + +# Decorator to mark methods public. This does nothing by itself, but +# trio/_tools/gen_exports.py looks for it. +def _public(fn: F) -> F: + return fn + + +class Instruments(Dict[str, Dict[Instrument, None]]): + """A collection of `trio.abc.Instrument` organized by hook. + + Instrumentation calls are rather expensive, and we don't want a + rarely-used instrument (like before_run()) to slow down hot + operations (like before_task_step()). Thus, we cache the set of + instruments to be called for each hook, and skip the instrumentation + call if there's nothing currently installed for that hook. + """ + + __slots__ = () + + def __init__(self, incoming: Sequence[Instrument]): + self["_all"] = {} + for instrument in incoming: + self.add_instrument(instrument) + + @_public + def add_instrument(self, instrument: Instrument) -> None: + """Start instrumenting the current run loop with the given instrument. + + Args: + instrument (trio.abc.Instrument): The instrument to activate. + + If ``instrument`` is already active, does nothing. 
+ + """ + if instrument in self["_all"]: + return + self["_all"][instrument] = None + try: + for name in dir(instrument): + if name.startswith("_"): + continue + try: + prototype = getattr(Instrument, name) + except AttributeError: + continue + impl = getattr(instrument, name) + if isinstance(impl, types.MethodType) and impl.__func__ is prototype: + # Inherited unchanged from _abc.Instrument + continue + self.setdefault(name, {})[instrument] = None + except: + self.remove_instrument(instrument) + raise + + @_public + def remove_instrument(self, instrument: Instrument) -> None: + """Stop instrumenting the current run loop with the given instrument. + + Args: + instrument (trio.abc.Instrument): The instrument to de-activate. + + Raises: + KeyError: if the instrument is not currently active. This could + occur either because you never added it, or because you added it + and then it raised an unhandled exception and was automatically + deactivated. + + """ + # If instrument isn't present, the KeyError propagates out + self["_all"].pop(instrument) + for hookname, instruments in list(self.items()): + if instrument in instruments: + del instruments[instrument] + if not instruments: + del self[hookname] + + def call(self, hookname: str, *args: Any) -> None: + """Call hookname(*args) on each applicable instrument. + + You must first check whether there are any instruments installed for + that hook, e.g.:: + + if "before_task_step" in instruments: + instruments.call("before_task_step", task) + """ + for instrument in list(self[hookname]): + try: + getattr(instrument, hookname)(*args) + except: + self.remove_instrument(instrument) + INSTRUMENT_LOGGER.exception( + "Exception raised when calling %r on instrument %r. 
" + "Instrument has been disabled.", + hookname, + instrument, + ) diff --git a/trio/_core/_io_common.py b/trio/_core/_io_common.py index 9891849bc9..b141474fda 100644 --- a/trio/_core/_io_common.py +++ b/trio/_core/_io_common.py @@ -1,5 +1,7 @@ import copy + import outcome + from .. import _core diff --git a/trio/_core/_io_epoll.py b/trio/_core/_io_epoll.py index 5d73a58c84..376dd18a4e 100644 --- a/trio/_core/_io_epoll.py +++ b/trio/_core/_io_epoll.py @@ -1,10 +1,16 @@ import select -import attr +import sys from collections import defaultdict +from typing import TYPE_CHECKING, Dict + +import attr from .. import _core -from ._run import _public from ._io_common import wake_all +from ._run import _public +from ._wakeup_socketpair import WakeupSocketpair + +assert not TYPE_CHECKING or sys.platform == "linux" @attr.s(slots=True, eq=False, frozen=True) @@ -183,7 +189,15 @@ class EpollWaiters: class EpollIOManager: _epoll = attr.ib(factory=select.epoll) # {fd: EpollWaiters} - _registered = attr.ib(factory=lambda: defaultdict(EpollWaiters)) + _registered = attr.ib( + factory=lambda: defaultdict(EpollWaiters), type=Dict[int, EpollWaiters] + ) + _force_wakeup = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd = attr.ib(default=None) + + def __attrs_post_init__(self): + self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) + self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() def statistics(self): tasks_waiting_read = 0 @@ -200,13 +214,26 @@ def statistics(self): def close(self): self._epoll.close() + self._force_wakeup.close() - # Called internally by the task runner: - def handle_io(self, timeout): + def force_wakeup(self): + self._force_wakeup.wakeup_thread_and_signal_safe() + + # Return value must be False-y IFF the timeout expired, NOT if any I/O + # happened or force_wakeup was called. Otherwise it can be anything; gets + # passed straight through to process_events. 
+ def get_events(self, timeout): # max_events must be > 0 or epoll gets cranky + # accessing self._registered from a thread looks dangerous, but it's + # OK because it doesn't matter if our value is a little bit off. max_events = max(1, len(self._registered)) - events = self._epoll.poll(timeout, max_events) + return self._epoll.poll(timeout, max_events) + + def process_events(self, events): for fd, flags in events: + if fd == self._force_wakeup_fd: + self._force_wakeup.drain() + continue waiters = self._registered[fd] # EPOLLONESHOT always clears the flags when an event is delivered waiters.current_flags = 0 @@ -235,9 +262,7 @@ def _update_registrations(self, fd): self._epoll.modify(fd, wanted_flags | select.EPOLLONESHOT) except OSError: # If that fails, it might be a new fd; try EPOLL_CTL_ADD - self._epoll.register( - fd, wanted_flags | select.EPOLLONESHOT - ) + self._epoll.register(fd, wanted_flags | select.EPOLLONESHOT) waiters.current_flags = wanted_flags except OSError as exc: # If everything fails, probably it's a bad fd, e.g. because @@ -284,7 +309,7 @@ def notify_closing(self, fd): fd = fd.fileno() wake_all( self._registered[fd], - _core.ClosedResourceError("another task closed this fd") + _core.ClosedResourceError("another task closed this fd"), ) del self._registered[fd] try: diff --git a/trio/_core/_io_kqueue.py b/trio/_core/_io_kqueue.py index e3989152ea..d1151843e8 100644 --- a/trio/_core/_io_kqueue.py +++ b/trio/_core/_io_kqueue.py @@ -1,11 +1,17 @@ +import errno import select - -import outcome +import sys from contextlib import contextmanager +from typing import TYPE_CHECKING + import attr +import outcome from .. 
import _core from ._run import _public +from ._wakeup_socketpair import WakeupSocketpair + +assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32") @attr.s(slots=True, eq=False, frozen=True) @@ -20,6 +26,15 @@ class KqueueIOManager: _kqueue = attr.ib(factory=select.kqueue) # {(ident, filter): Task or UnboundedQueue} _registered = attr.ib(factory=dict) + _force_wakeup = attr.ib(factory=WakeupSocketpair) + _force_wakeup_fd = attr.ib(default=None) + + def __attrs_post_init__(self): + force_wakeup_event = select.kevent( + self._force_wakeup.wakeup_sock, select.KQ_FILTER_READ, select.KQ_EV_ADD + ) + self._kqueue.control([force_wakeup_event], 0) + self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno() def statistics(self): tasks_waiting = 0 @@ -29,15 +44,16 @@ def statistics(self): tasks_waiting += 1 else: monitors += 1 - return _KqueueStatistics( - tasks_waiting=tasks_waiting, - monitors=monitors, - ) + return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors) def close(self): self._kqueue.close() + self._force_wakeup.close() + + def force_wakeup(self): + self._force_wakeup.wakeup_thread_and_signal_safe() - def handle_io(self, timeout): + def get_events(self, timeout): # max_events must be > 0 or kqueue gets cranky # and we generally want this to be strictly larger than the actual # number of events we get, so that we can tell that we've gotten @@ -52,8 +68,14 @@ def handle_io(self, timeout): else: timeout = 0 # and loop back to the start + return events + + def process_events(self, events): for event in events: key = (event.ident, event.filter) + if event.ident == self._force_wakeup_fd: + self._force_wakeup.drain() + continue receiver = self._registered[key] if event.flags & select.KQ_EV_ONESHOT: del self._registered[key] @@ -83,8 +105,7 @@ def monitor_kevent(self, ident, filter): key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( - "attempt to register multiple listeners for same " - 
"ident/filter pair" + "attempt to register multiple listeners for same ident/filter pair" ) q = _core.UnboundedQueue() self._registered[key] = q @@ -98,8 +119,7 @@ async def wait_kevent(self, ident, filter, abort_func): key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( - "attempt to register multiple listeners for same " - "ident/filter pair" + "attempt to register multiple listeners for same ident/filter pair" ) self._registered[key] = _core.current_task() @@ -122,16 +142,22 @@ def abort(_): event = select.kevent(fd, filter, select.KQ_EV_DELETE) try: self._kqueue.control([event], 0) - except FileNotFoundError: + except OSError as exc: # kqueue tracks individual fds (*not* the underlying file # object, see _io_epoll.py for a long discussion of why this # distinction matters), and automatically deregisters an event # if the fd is closed. So if kqueue.control says that it # doesn't know about this event, then probably it's because - # the fd was closed behind our backs. (Too bad it doesn't tell - # us that this happened... oh well, you can't have - # everything.) - pass + # the fd was closed behind our backs. (Too bad we can't ask it + # to wake us up when this happens, versus discovering it after + # the fact... oh well, you can't have everything.) + # + # FreeBSD reports this using EBADF. macOS uses ENOENT. + if exc.errno in (errno.EBADF, errno.ENOENT): # pragma: no branch + pass + else: # pragma: no cover + # As far as we know, this branch can't happen. 
+ raise return _core.Abort.SUCCEEDED await self.wait_kevent(fd, filter, abort) diff --git a/trio/_core/_io_windows.py b/trio/_core/_io_windows.py index 7e452177d1..4084f72b6e 100644 --- a/trio/_core/_io_windows.py +++ b/trio/_core/_io_windows.py @@ -1,30 +1,34 @@ -import itertools -from contextlib import contextmanager import enum +import itertools import socket +import sys +from contextlib import contextmanager +from typing import TYPE_CHECKING import attr +from outcome import Value from .. import _core -from ._run import _public from ._io_common import wake_all - +from ._run import _public from ._windows_cffi import ( - ffi, - kernel32, - ntdll, - ws2_32, INVALID_HANDLE_VALUE, - raise_winerror, - _handle, - ErrorCodes, - FileFlags, AFDPollFlags, - WSAIoctls, CompletionModes, + ErrorCodes, + FileFlags, IoControlCodes, + WSAIoctls, + _handle, + ffi, + kernel32, + ntdll, + raise_winerror, + ws2_32, ) +assert not TYPE_CHECKING or sys.platform == "win32" + # There's a lot to be said about the overall design of a Windows event # loop. See # @@ -171,7 +175,8 @@ class CKeys(enum.IntEnum): AFD_POLL = 0 WAIT_OVERLAPPED = 1 LATE_CANCEL = 2 - USER_DEFINED = 3 # and above + FORCE_WAKEUP = 3 + USER_DEFINED = 4 # and above def _check(success): @@ -180,7 +185,7 @@ def _check(success): return success -def _get_base_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): +def _get_underlying_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): if hasattr(sock, "fileno"): sock = sock.fileno() base_ptr = ffi.new("HANDLE *") @@ -202,6 +207,53 @@ def _get_base_socket(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): return base_ptr[0] +def _get_base_socket(sock): + # There is a development kit for LSPs called Komodia Redirector. + # It does some unusual (some might say evil) things like intercepting + # SIO_BASE_HANDLE (fails) and SIO_BSP_HANDLE_SELECT (returns the same + # socket) in a misguided attempt to prevent bypassing it. 
It's been used + # in malware including the infamous Lenovo Superfish incident from 2015, + # but unfortunately is also used in some legitimate products such as + # parental control tools and Astrill VPN. Komodia happens to not + # block SIO_BSP_HANDLE_POLL, so we'll try SIO_BASE_HANDLE and fall back + # to SIO_BSP_HANDLE_POLL if it doesn't work. + # References: + # - https://github.com/piscisaureus/wepoll/blob/0598a791bf9cbbf480793d778930fc635b044980/wepoll.c#L2223 + # - https://github.com/tokio-rs/mio/issues/1314 + + while True: + try: + # If this is not a Komodia-intercepted socket, we can just use + # SIO_BASE_HANDLE. + return _get_underlying_socket(sock) + except OSError as ex: + if ex.winerror == ErrorCodes.ERROR_NOT_SOCKET: + # SIO_BASE_HANDLE might fail even without LSP intervention, + # if we get something that's not a socket. + raise + if hasattr(sock, "fileno"): + sock = sock.fileno() + sock = _handle(sock) + next_sock = _get_underlying_socket( + sock, which=WSAIoctls.SIO_BSP_HANDLE_POLL + ) + if next_sock == sock: + # If BSP_HANDLE_POLL returns the same socket we already had, + # then there's no layering going on and we need to fail + # to prevent an infinite loop. + raise RuntimeError( + "Unexpected network configuration detected: " + "SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't " + "return a different socket. Please file a bug at " + "https://github.com/python-trio/trio/issues/new, " + "and include the output of running: " + "netsh winsock show catalog" + ) + # Otherwise we've gotten at least one layer deeper, so + # loop back around to keep digging. + sock = next_sock + + def _afd_helper_handle(): # The "AFD" driver is exposed at the NT path "\Device\Afd". We're using # the Win32 CreateFile, though, so we have to pass a Win32 path. \\.\ is @@ -291,6 +343,24 @@ class AFDPollOp: lpOverlapped = attr.ib() poll_info = attr.ib() waiters = attr.ib() + afd_group = attr.ib() + + +# The Windows kernel has a weird issue when using AFD handles. 
If you have N +# instances of wait_readable/wait_writable registered with a single AFD handle, +# then cancelling any one of them takes something like O(N**2) time. So if we +# used just a single AFD handle, then cancellation would quickly become very +# expensive, e.g. a program with N active sockets would take something like +# O(N**3) time to unwind after control-C. The solution is to spread our sockets +# out over multiple AFD handles, so that N doesn't grow too large for any +# individual handle. +MAX_AFD_GROUP_SIZE = 500 # at 1000, the cubic scaling is just starting to bite + + +@attr.s(slots=True, eq=False) +class AFDGroup: + size = attr.ib() + handle = attr.ib() @attr.s(slots=True, eq=False, frozen=True) @@ -322,17 +392,14 @@ def __init__(self): # touches to safe values up front, before we do anything that can # fail. self._iocp = None - self._afd = None + self._all_afd_handles = [] self._iocp = _check( - kernel32.CreateIoCompletionPort( - INVALID_HANDLE_VALUE, ffi.NULL, 0, 0 - ) + kernel32.CreateIoCompletionPort(INVALID_HANDLE_VALUE, ffi.NULL, 0, 0) ) self._events = ffi.new("OVERLAPPED_ENTRY[]", MAX_EVENTS) - self._afd = _afd_helper_handle() - self._register_with_iocp(self._afd, CKeys.AFD_POLL) + self._vacant_afd_groups = set() # {lpOverlapped: AFDPollOp} self._afd_ops = {} # {socket handle: AFDWaiters} @@ -346,21 +413,41 @@ def __init__(self): self._completion_key_counter = itertools.count(CKeys.USER_DEFINED) with socket.socket() as s: - # LSPs can't override this. - base_handle = _get_base_socket(s, which=WSAIoctls.SIO_BASE_HANDLE) + # We assume we're not working with any LSP that changes + # how select() is supposed to work. Validate this by + # ensuring that the result of SIO_BSP_HANDLE_SELECT (the + # LSP-hookable mechanism for "what should I use for + # select()?") matches that of SIO_BASE_HANDLE ("what is + # the real non-hooked underlying socket here?"). 
+ # + # This doesn't work for Komodia-based LSPs; see the comments + # in _get_base_socket() for details. But we have special + # logic for those, so we just skip this check if + # SIO_BASE_HANDLE fails. + # LSPs can in theory override this, but we believe that it never - # actually happens in the wild. - select_handle = _get_base_socket( + # actually happens in the wild (except Komodia) + select_handle = _get_underlying_socket( s, which=WSAIoctls.SIO_BSP_HANDLE_SELECT ) - if base_handle != select_handle: # pragma: no cover - raise RuntimeError( - "Unexpected network configuration detected. " - "Please file a bug at " - "https://github.com/python-trio/trio/issues/new, " - "and include the output of running: " - "netsh winsock show catalog" - ) + try: + # LSPs shouldn't override this... + base_handle = _get_underlying_socket(s, which=WSAIoctls.SIO_BASE_HANDLE) + except OSError: + # But Komodia-based LSPs do anyway, in a way that causes + # a failure with WSAEFAULT. We have special handling for + # them in _get_base_socket(). Make sure it works. + _get_base_socket(s) + else: + if base_handle != select_handle: + raise RuntimeError( + "Unexpected network configuration detected: " + "SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ. 
" + "Please file a bug at " + "https://github.com/python-trio/trio/issues/new, " + "and include the output of running: " + "netsh winsock show catalog" + ) def close(self): try: @@ -369,10 +456,9 @@ def close(self): self._iocp = None _check(kernel32.CloseHandle(iocp)) finally: - if self._afd is not None: - afd = self._afd - self._afd = None - _check(kernel32.CloseHandle(afd)) + while self._all_afd_handles: + afd_handle = self._all_afd_handles.pop() + _check(kernel32.CloseHandle(afd_handle)) def __del__(self): self.close() @@ -392,7 +478,14 @@ def statistics(self): completion_key_monitors=len(self._completion_key_queues), ) - def handle_io(self, timeout): + def force_wakeup(self): + _check( + kernel32.PostQueuedCompletionStatus( + self._iocp, 0, CKeys.FORCE_WAKEUP, ffi.NULL + ) + ) + + def get_events(self, timeout): received = ffi.new("PULONG") milliseconds = round(1000 * timeout) if timeout > 0 and milliseconds == 0: @@ -400,15 +493,17 @@ def handle_io(self, timeout): try: _check( kernel32.GetQueuedCompletionStatusEx( - self._iocp, self._events, MAX_EVENTS, received, - milliseconds, 0 + self._iocp, self._events, MAX_EVENTS, received, milliseconds, 0 ) ) except OSError as exc: if exc.winerror != ErrorCodes.WAIT_TIMEOUT: # pragma: no cover raise - return - for i in range(received[0]): + return 0 + return received[0] + + def process_events(self, received): + for i in range(received): entry = self._events[i] if entry.lpCompletionKey == CKeys.AFD_POLL: lpo = entry.lpOverlapped @@ -435,7 +530,12 @@ def handle_io(self, timeout): elif entry.lpCompletionKey == CKeys.WAIT_OVERLAPPED: # Regular I/O event, dispatch on lpOverlapped waiter = self._overlapped_waiters.pop(entry.lpOverlapped) - _core.reschedule(waiter) + overlapped = entry.lpOverlapped + transferred = entry.dwNumberOfBytesTransferred + info = CompletionKeyEventInfo( + lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred + ) + _core.reschedule(waiter, Value(info)) elif entry.lpCompletionKey == 
CKeys.LATE_CANCEL: # Post made by a regular I/O event's abort_fn # after it failed to cancel the I/O. If we still @@ -470,24 +570,21 @@ def handle_io(self, timeout): # try changing this line to # _core.reschedule(waiter, outcome.Error(exc)) raise exc + elif entry.lpCompletionKey == CKeys.FORCE_WAKEUP: + pass else: # dispatch on lpCompletionKey queue = self._completion_key_queues[entry.lpCompletionKey] overlapped = int(ffi.cast("uintptr_t", entry.lpOverlapped)) transferred = entry.dwNumberOfBytesTransferred info = CompletionKeyEventInfo( - lpOverlapped=overlapped, - dwNumberOfBytesTransferred=transferred, + lpOverlapped=overlapped, dwNumberOfBytesTransferred=transferred ) queue.put_nowait(info) def _register_with_iocp(self, handle, completion_key): handle = _handle(handle) - _check( - kernel32.CreateIoCompletionPort( - handle, self._iocp, completion_key, 0 - ) - ) + _check(kernel32.CreateIoCompletionPort(handle, self._iocp, completion_key, 0)) # Supposedly this makes things slightly faster, by disabling the # ability to do WaitForSingleObject(handle). We would never want to do # that anyway, so might as well get the extra speed (if any). @@ -505,10 +602,11 @@ def _register_with_iocp(self, handle, completion_key): def _refresh_afd(self, base_handle): waiters = self._afd_waiters[base_handle] if waiters.current_op is not None: + afd_group = waiters.current_op.afd_group try: _check( kernel32.CancelIoEx( - self._afd, waiters.current_op.lpOverlapped + afd_group.handle, waiters.current_op.lpOverlapped ) ) except OSError as exc: @@ -517,6 +615,8 @@ def _refresh_afd(self, base_handle): # crash noisily. 
raise # pragma: no cover waiters.current_op = None + afd_group.size -= 1 + self._vacant_afd_groups.add(afd_group) flags = 0 if waiters.read_task is not None: @@ -527,6 +627,14 @@ def _refresh_afd(self, base_handle): if not flags: del self._afd_waiters[base_handle] else: + try: + afd_group = self._vacant_afd_groups.pop() + except KeyError: + afd_group = AFDGroup(0, _afd_helper_handle()) + self._register_with_iocp(afd_group.handle, CKeys.AFD_POLL) + self._all_afd_handles.append(afd_group.handle) + self._vacant_afd_groups.add(afd_group) + lpOverlapped = ffi.new("LPOVERLAPPED") poll_info = ffi.new("AFD_POLL_INFO *") @@ -540,7 +648,7 @@ def _refresh_afd(self, base_handle): try: _check( kernel32.DeviceIoControl( - self._afd, + afd_group.handle, IoControlCodes.IOCTL_AFD_POLL, poll_info, ffi.sizeof("AFD_POLL_INFO"), @@ -560,9 +668,12 @@ def _refresh_afd(self, base_handle): # Do this last, because it could raise. wake_all(waiters, exc) return - op = AFDPollOp(lpOverlapped, poll_info, waiters) + op = AFDPollOp(lpOverlapped, poll_info, waiters, afd_group) waiters.current_op = op self._afd_ops[lpOverlapped] = op + afd_group.size += 1 + if afd_group.size >= MAX_AFD_GROUP_SIZE: + self._vacant_afd_groups.remove(afd_group) async def _afd_poll(self, sock, mode): base_handle = _get_base_socket(sock) @@ -656,7 +767,7 @@ def abort(raise_cancel_): ) from exc return _core.Abort.FAILED - await _core.wait_task_rescheduled(abort) + info = await _core.wait_task_rescheduled(abort) if lpOverlapped.Internal != 0: # the lpOverlapped reports the error as an NT status code, # which we must convert back to a Win32 error code before @@ -669,11 +780,10 @@ def abort(raise_cancel_): # We didn't request this cancellation, so assume # it happened due to the underlying handle being # closed before the operation could complete. 
- raise _core.ClosedResourceError( - "another task closed this resource" - ) + raise _core.ClosedResourceError("another task closed this resource") else: raise_winerror(code) + return info async def _perform_overlapped(self, handle, submit_fn): # submit_fn(lpOverlapped) submits some I/O @@ -700,7 +810,7 @@ async def write_overlapped(self, handle, data, file_offset=0): def submit_write(lpOverlapped): # yes, these are the real documented names offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME - offset_fields.Offset = file_offset & 0xffffffff + offset_fields.Offset = file_offset & 0xFFFFFFFF offset_fields.OffsetHigh = file_offset >> 32 _check( kernel32.WriteFile( @@ -722,7 +832,7 @@ async def readinto_overlapped(self, handle, buffer, file_offset=0): def submit_read(lpOverlapped): offset_fields = lpOverlapped.DUMMYUNIONNAME.DUMMYSTRUCTNAME - offset_fields.Offset = file_offset & 0xffffffff + offset_fields.Offset = file_offset & 0xFFFFFFFF offset_fields.OffsetHigh = file_offset >> 32 _check( kernel32.ReadFile( diff --git a/trio/_core/_ki.py b/trio/_core/_ki.py index a3f64c5dca..cc05ef9177 100644 --- a/trio/_core/_ki.py +++ b/trio/_core/_ki.py @@ -1,22 +1,19 @@ +from __future__ import annotations + import inspect import signal import sys -from contextlib import contextmanager from functools import wraps +from typing import TYPE_CHECKING -import async_generator +import attr from .._util import is_main_thread -if False: - from typing import Any, TypeVar, Callable - F = TypeVar('F', bound=Callable[..., Any]) +if TYPE_CHECKING: + from typing import Any, Callable, TypeVar -__all__ = [ - "enable_ki_protection", - "disable_ki_protection", - "currently_ki_protected", -] + F = TypeVar("F", bound=Callable[..., Any]) # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. 
@@ -60,7 +57,7 @@ # # If this raises a KeyboardInterrupt, it might be because the coroutine got # interrupted and has unwound... or it might be the KeyboardInterrupt -# arrived just *after* 'send' returned, so the coroutine is still running +# arrived just *after* 'send' returned, so the coroutine is still running, # but we just lost the message it sent. (And worse, in our actual task # runner, the send is hidden inside a utility function etc.) # @@ -83,7 +80,7 @@ # We use this special string as a unique key into the frame locals dictionary. # The @ ensures it is not a valid identifier and can't clash with any possible # real local name. See: https://github.com/python-trio/trio/issues/469 -LOCALS_KEY_KI_PROTECTION_ENABLED = '@TRIO_KI_PROTECTION_ENABLED' +LOCALS_KEY_KI_PROTECTION_ENABLED = "@TRIO_KI_PROTECTION_ENABLED" # NB: according to the signal.signal docs, 'frame' can be None on entry to @@ -95,7 +92,7 @@ def ki_protection_enabled(frame): if frame.f_code.co_name == "__del__": return True frame = frame.f_back - return False + return True def currently_ki_protected(): @@ -114,6 +111,14 @@ def currently_ki_protected(): return ki_protection_enabled(sys._getframe()) +# This is to support the async_generator package necessary for aclosing on <3.10 +# functions decorated @async_generator are given this magic property that's a +# reference to the object itself +# see python-trio/async_generator/async_generator/_impl.py +def legacy_isasyncgenfunction(obj): + return getattr(obj, "_async_gen_function", None) == id(obj) + + def _ki_protection_decorator(enabled): def decorator(fn): # In some version of Python, isgeneratorfunction returns true for @@ -125,8 +130,7 @@ def decorator(fn): def wrapper(*args, **kwargs): # See the comment for regular generators below coro = fn(*args, **kwargs) - coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled + coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return coro return wrapper @@ -143,19 +147,17 @@ 
def wrapper(*args, **kwargs): # thrown into! See: # https://bugs.python.org/issue29590 gen = fn(*args, **kwargs) - gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled + gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return gen return wrapper - elif async_generator.isasyncgenfunction(fn): + elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): @wraps(fn) def wrapper(*args, **kwargs): # See the comment for regular generators above agen = fn(*args, **kwargs) - agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED - ] = enabled + agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return agen return wrapper @@ -171,35 +173,38 @@ def wrapper(*args, **kwargs): return decorator -enable_ki_protection = _ki_protection_decorator(True) # type: Callable[[F], F] +enable_ki_protection: Callable[[F], F] = _ki_protection_decorator(True) enable_ki_protection.__name__ = "enable_ki_protection" -disable_ki_protection = _ki_protection_decorator( - False -) # type: Callable[[F], F] +disable_ki_protection: Callable[[F], F] = _ki_protection_decorator(False) disable_ki_protection.__name__ = "disable_ki_protection" -@contextmanager -def ki_manager(deliver_cb, restrict_keyboard_interrupt_to_checkpoints): - if ( - not is_main_thread() - or signal.getsignal(signal.SIGINT) != signal.default_int_handler - ): - yield - return - - def handler(signum, frame): - assert signum == signal.SIGINT - protection_enabled = ki_protection_enabled(frame) - if protection_enabled or restrict_keyboard_interrupt_to_checkpoints: - deliver_cb() - else: - raise KeyboardInterrupt - - signal.signal(signal.SIGINT, handler) - try: - yield - finally: - if signal.getsignal(signal.SIGINT) is handler: - signal.signal(signal.SIGINT, signal.default_int_handler) +@attr.s +class KIManager: + handler = attr.ib(default=None) + + def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints): + assert self.handler is None + if ( + not is_main_thread() + 
or signal.getsignal(signal.SIGINT) != signal.default_int_handler + ): + return + + def handler(signum, frame): + assert signum == signal.SIGINT + protection_enabled = ki_protection_enabled(frame) + if protection_enabled or restrict_keyboard_interrupt_to_checkpoints: + deliver_cb() + else: + raise KeyboardInterrupt + + self.handler = handler + signal.signal(signal.SIGINT, handler) + + def close(self): + if self.handler is not None: + if signal.getsignal(signal.SIGINT) is self.handler: + signal.signal(signal.SIGINT, signal.default_int_handler) + self.handler = None diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 7ff3757356..7f2c632153 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,25 +1,34 @@ +from __future__ import annotations + +from typing import Generic, TypeVar, final + # Runvar implementations +import attr + +from .._util import Final, NoPublicConstructor from . import _run -__all__ = ["RunVar"] +T = TypeVar("T") -class _RunVarToken: - _no_value = object() +@final +class _NoValue(metaclass=Final): + ... - __slots__ = ("_var", "previous_value", "redeemed") - @classmethod - def empty(cls, var): - return cls(var, value=cls._no_value) +@attr.s(eq=False, hash=False, slots=False) +class RunVarToken(Generic[T], metaclass=NoPublicConstructor): + _var: RunVar[T] = attr.ib() + previous_value: T | type[_NoValue] = attr.ib(default=_NoValue) + redeemed: bool = attr.ib(default=False, init=False) - def __init__(self, var, value): - self._var = var - self.previous_value = value - self.redeemed = False + @classmethod + def _empty(cls, var: RunVar[T]) -> RunVarToken[T]: + return cls._create(var) -class RunVar: +@attr.s(eq=False, hash=False, slots=True) +class RunVar(Generic[T], metaclass=Final): """The run-local variant of a context variable. 
:class:`RunVar` objects are similar to context variable objects, @@ -28,31 +37,28 @@ class RunVar: """ - _NO_DEFAULT = object() - __slots__ = ("_name", "_default") - - def __init__(self, name, default=_NO_DEFAULT): - self._name = name - self._default = default + _name: str = attr.ib() + _default: T | type[_NoValue] = attr.ib(default=_NoValue) - def get(self, default=_NO_DEFAULT): + def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" try: - return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] + # not typed yet + return _run.GLOBAL_RUN_CONTEXT.runner._locals[self] # type: ignore[return-value, index] except AttributeError: - raise RuntimeError("Cannot be used outside of a run context") \ - from None + raise RuntimeError("Cannot be used outside of a run context") from None except KeyError: # contextvars consistency - if default is not self._NO_DEFAULT: - return default + # `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released + if default is not _NoValue: + return default # type: ignore[return-value] - if self._default is not self._NO_DEFAULT: - return self._default + if self._default is not _NoValue: + return self._default # type: ignore[return-value] raise LookupError(self) from None - def set(self, value): + def set(self, value: T) -> RunVarToken[T]: """Sets the value of this :class:`RunVar` for this current run call. @@ -60,16 +66,16 @@ def set(self, value): try: old_value = self.get() except LookupError: - token = _RunVarToken.empty(self) + token = RunVarToken._empty(self) else: - token = _RunVarToken(self, old_value) + token = RunVarToken[T]._create(self, old_value) # This can't fail, because if we weren't in Trio context then the # get() above would have failed. 
- _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value # type: ignore[assignment, index] return token - def reset(self, token): + def reset(self, token: RunVarToken[T]) -> None: """Resets the value of this :class:`RunVar` to what it was previously specified by the token. @@ -85,14 +91,14 @@ def reset(self, token): previous = token.previous_value try: - if previous is _RunVarToken._no_value: - _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) + if previous is _NoValue: + _run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self) # type: ignore[arg-type] else: - _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous + _run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous # type: ignore[index, assignment] except AttributeError: raise RuntimeError("Cannot be used outside of a run context") token.redeemed = True - def __repr__(self): - return ("".format(self._name)) + def __repr__(self) -> str: + return f"" diff --git a/trio/_core/_mock_clock.py b/trio/_core/_mock_clock.py new file mode 100644 index 0000000000..fe35298631 --- /dev/null +++ b/trio/_core/_mock_clock.py @@ -0,0 +1,165 @@ +import time +from math import inf + +from .. import _core +from .._abc import Clock +from .._util import Final +from ._run import GLOBAL_RUN_CONTEXT + +################################################################ +# The glorious MockClock +################################################################ + + +# Prior art: +# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html +# https://github.com/ztellman/manifold/issues/57 +class MockClock(Clock, metaclass=Final): + """A user-controllable clock suitable for writing tests. + + Args: + rate (float): the initial :attr:`rate`. + autojump_threshold (float): the initial :attr:`autojump_threshold`. + + .. attribute:: rate + + How many seconds of clock time pass per second of real time. Default is + 0.0, i.e. 
the clock only advances through manual calls to :meth:`jump` + or when the :attr:`autojump_threshold` is triggered. You can assign to + this attribute to change it. + + .. attribute:: autojump_threshold + + The clock keeps an eye on the run loop, and if at any point it detects + that all tasks have been blocked for this many real seconds (i.e., + according to the actual clock, not this clock), then the clock + automatically jumps ahead to the run loop's next scheduled + timeout. Default is :data:`math.inf`, i.e., to never autojump. You can + assign to this attribute to change it. + + Basically the idea is that if you have code or tests that use sleeps + and timeouts, you can use this to make it run much faster, totally + automatically. (At least, as long as those sleeps/timeouts are + happening inside Trio; if your test involves talking to an external + service and waiting for it to timeout then obviously we can't help you + there.) + + You should set this to the smallest value that lets you reliably avoid + "false alarms" where some I/O is in flight (e.g. between two halves of + a socketpair) but the threshold gets triggered and time gets advanced + anyway. This will depend on the details of your tests and test + environment. If you aren't doing any I/O (like in our sleeping example + above) then just set it to zero, and the clock will jump whenever all + tasks are blocked. + + .. note:: If you use ``autojump_threshold`` and + `wait_all_tasks_blocked` at the same time, then you might wonder how + they interact, since they both cause things to happen after the run + loop goes idle for some time. The answer is: + `wait_all_tasks_blocked` takes priority. If there's a task blocked + in `wait_all_tasks_blocked`, then the autojump feature treats that + as an active task and does *not* jump the clock. 
+ + """ + + def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): + # when the real clock said 'real_base', the virtual time was + # 'virtual_base', and since then it's advanced at 'rate' virtual + # seconds per real second. + self._real_base = 0.0 + self._virtual_base = 0.0 + self._rate = 0.0 + self._autojump_threshold = 0.0 + # kept as an attribute so that our tests can monkeypatch it + self._real_clock = time.perf_counter + + # use the property update logic to set initial values + self.rate = rate + self.autojump_threshold = autojump_threshold + + def __repr__(self) -> str: + return "".format( + self.current_time(), self._rate, id(self) + ) + + @property + def rate(self) -> float: + return self._rate + + @rate.setter + def rate(self, new_rate: float) -> None: + if new_rate < 0: + raise ValueError("rate must be >= 0") + else: + real = self._real_clock() + virtual = self._real_to_virtual(real) + self._virtual_base = virtual + self._real_base = real + self._rate = float(new_rate) + + @property + def autojump_threshold(self) -> float: + return self._autojump_threshold + + @autojump_threshold.setter + def autojump_threshold(self, new_autojump_threshold: float) -> None: + self._autojump_threshold = float(new_autojump_threshold) + self._try_resync_autojump_threshold() + + # runner.clock_autojump_threshold is an internal API that isn't easily + # usable by custom third-party Clock objects. If you need access to this + # functionality, let us know, and we'll figure out how to make a public + # API. Discussion: + # + # https://github.com/python-trio/trio/issues/1587 + def _try_resync_autojump_threshold(self) -> None: + try: + runner = GLOBAL_RUN_CONTEXT.runner + if runner.is_guest: + runner.force_guest_tick_asap() + except AttributeError: + pass + else: + runner.clock_autojump_threshold = self._autojump_threshold + + # Invoked by the run loop when runner.clock_autojump_threshold is + # exceeded. 
+ def _autojump(self) -> None: + statistics = _core.current_statistics() + jump = statistics.seconds_to_next_deadline + if 0 < jump < inf: + self.jump(jump) + + def _real_to_virtual(self, real: float) -> float: + real_offset = real - self._real_base + virtual_offset = self._rate * real_offset + return self._virtual_base + virtual_offset + + def start_clock(self) -> None: + self._try_resync_autojump_threshold() + + def current_time(self) -> float: + return self._real_to_virtual(self._real_clock()) + + def deadline_to_sleep_time(self, deadline: float) -> float: + virtual_timeout = deadline - self.current_time() + if virtual_timeout <= 0: + return 0 + elif self._rate > 0: + return virtual_timeout / self._rate + else: + return 999999999 + + def jump(self, seconds) -> None: + """Manually advance the clock by the given number of seconds. + + Args: + seconds (float): the number of seconds to jump the clock forward. + + Raises: + ValueError: if you try to pass a negative value for ``seconds``. + + """ + if seconds < 0: + raise ValueError("time can't go backwards") + self._virtual_base += seconds diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py index a29f19f60d..3c6ebb789f 100644 --- a/trio/_core/_multierror.py +++ b/trio/_core/_multierror.py @@ -1,19 +1,20 @@ +from __future__ import annotations + import sys -import traceback -import textwrap import warnings +from typing import TYPE_CHECKING import attr -__all__ = ["MultiError"] +from trio._deprecate import warn_deprecated -# python traceback.TracebackException < 3.6.4 does not support unhashable exceptions -# see https://github.com/python/cpython/pull/4014 for details -if sys.version_info < (3, 6, 4): - exc_key = lambda exc: exc +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup, ExceptionGroup, print_exception else: - exc_key = id + from traceback import print_exception +if TYPE_CHECKING: + from types import TracebackType 
################################################################ # MultiError ################################################################ @@ -116,6 +117,9 @@ def push_tb_down(tb, exc, preserved): preserved = set() new_root_exc = filter_tree(root_exc, preserved) push_tb_down(None, root_exc, preserved) + # Delete the local functions to avoid a reference cycle (see + # test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage) + del filter_tree, push_tb_down return new_root_exc @@ -131,10 +135,16 @@ class MultiErrorCatcher: def __enter__(self): pass - def __exit__(self, etype, exc, tb): - if exc is not None: - filtered_exc = MultiError.filter(self._handler, exc) - if filtered_exc is exc: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + if exc_value is not None: + filtered_exc = _filter_impl(self._handler, exc_value) + + if filtered_exc is exc_value: # Let the interpreter re-raise it return False if filtered_exc is None: @@ -151,9 +161,13 @@ def __exit__(self, etype, exc, tb): _, value, _ = sys.exc_info() assert value is filtered_exc value.__context__ = old_context + # delete references from locals to avoid creating cycles + # see test_MultiError_catch_doesnt_create_cyclic_garbage + del _, filtered_exc, value + return False -class MultiError(BaseException): +class MultiError(BaseExceptionGroup): """An exception that contains other exceptions; also known as an "inception". @@ -175,23 +189,24 @@ class MultiError(BaseException): :exc:`BaseException`. """ - def __init__(self, exceptions): - # Avoid recursion when exceptions[0] returned by __new__() happens - # to be a MultiError and subsequently __init__() is called. 
- if hasattr(self, "exceptions"): - # __init__ was already called on this object - assert len(exceptions) == 1 and exceptions[0] is self + + def __init__(self, exceptions, *, _collapse=True): + self.collapse = _collapse + + # Avoid double initialization when _collapse is True and exceptions[0] returned + # by __new__() happens to be a MultiError and subsequently __init__() is called. + if _collapse and getattr(self, "exceptions", None) is not None: + # This exception was already initialized. return - self.exceptions = exceptions - def __new__(cls, exceptions): + super().__init__("multiple tasks failed", exceptions) + + def __new__(cls, exceptions, *, _collapse=True): exceptions = list(exceptions) for exc in exceptions: if not isinstance(exc, BaseException): - raise TypeError( - "Expected an exception object, not {!r}".format(exc) - ) - if len(exceptions) == 1: + raise TypeError(f"Expected an exception object, not {exc!r}") + if _collapse and len(exceptions) == 1: # If this lone object happens to itself be a MultiError, then # Python will implicitly call our __init__ on it again. See # special handling in __init__. @@ -203,13 +218,30 @@ def __new__(cls, exceptions): # In an earlier version of the code, we didn't define __init__ and # simply set the `exceptions` attribute directly on the new object. # However, linters expect attributes to be initialized in __init__. 
- return BaseException.__new__(cls, exceptions) + if all(isinstance(exc, Exception) for exc in exceptions): + cls = NonBaseMultiError + + return super().__new__(cls, "multiple tasks failed", exceptions) + + def __reduce__(self): + return ( + self.__new__, + (self.__class__, list(self.exceptions)), + {"collapse": self.collapse}, + ) def __str__(self): return ", ".join(repr(exc) for exc in self.exceptions) def __repr__(self): - return "".format(self) + return f"" + + def derive(self, __excs): + # We use _collapse=False here to get ExceptionGroup semantics, since derive() + # is part of the PEP 654 API + exc = MultiError(__excs, _collapse=False) + exc.collapse = self.collapse + return exc @classmethod def filter(cls, handler, root_exc): @@ -227,7 +259,12 @@ def filter(cls, handler, root_exc): ``handler`` returned None for all the inputs, returns None. """ - + warn_deprecated( + "MultiError.filter()", + "0.22.0", + instead="BaseExceptionGroup.split()", + issue=2211, + ) return _filter_impl(handler, root_exc) @classmethod @@ -239,12 +276,23 @@ def catch(cls, handler): handler: as for :meth:`filter` """ + warn_deprecated( + "MultiError.catch", + "0.22.0", + instead="except* or exceptiongroup.catch()", + issue=2211, + ) return MultiErrorCatcher(handler) +class NonBaseMultiError(MultiError, ExceptionGroup): + pass + + # Clean up exception printing: MultiError.__module__ = "trio" +NonBaseMultiError.__module__ = "trio" ################################################################ # concat_tb @@ -282,16 +330,19 @@ def controller(operation): # no missing test we could add, and no value in coverage nagging # us about adding one. if operation.opname in [ - "__getattribute__", "__getattr__" + "__getattribute__", + "__getattr__", ]: # pragma: no cover if operation.args[0] == "tb_next": return tb_next return operation.delegate() return tputil.make_proxy(controller, type(base_tb), base_tb) + else: # ctypes it is import ctypes + # How to handle refcounting? 
I don't want to use ctypes.py_object because # I don't understand or trust it, and I don't want to use # ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code @@ -338,7 +389,12 @@ def copy_tb(base_tb, tb_next): c_new_tb.tb_lasti = base_tb.tb_lasti c_new_tb.tb_lineno = base_tb.tb_lineno - return new_tb + try: + return new_tb + finally: + # delete references from locals to avoid creating cycles + # see test_MultiError_catch_doesnt_create_cyclic_garbage + del new_tb, old_tb_frame def concat_tb(head, tail): @@ -356,118 +412,58 @@ def concat_tb(head, tail): return current_head -################################################################ -# MultiError traceback formatting -# -# What follows is terrible, terrible monkey patching of -# traceback.TracebackException to add support for handling -# MultiErrors -################################################################ - -traceback_exception_original_init = traceback.TracebackException.__init__ - - -def traceback_exception_init( - self, - exc_type, - exc_value, - exc_traceback, - *, - limit=None, - lookup_lines=True, - capture_locals=False, - _seen=None -): - if _seen is None: - _seen = set() - - # Capture the original exception and its cause and context as TracebackExceptions - traceback_exception_original_init( - self, - exc_type, - exc_value, - exc_traceback, - limit=limit, - lookup_lines=lookup_lines, - capture_locals=capture_locals, - _seen=_seen - ) - - # Capture each of the exceptions in the MultiError along with each of their causes and contexts - if isinstance(exc_value, MultiError): - embedded = [] - for exc in exc_value.exceptions: - if exc_key(exc) not in _seen: - embedded.append( - traceback.TracebackException.from_exception( - exc, - limit=limit, - lookup_lines=lookup_lines, - capture_locals=capture_locals, - # copy the set of _seen exceptions so that duplicates - # shared between sub-exceptions are not omitted - _seen=set(_seen) - ) - ) - self.embedded = embedded - else: - 
self.embedded = [] - - -traceback.TracebackException.__init__ = traceback_exception_init -traceback_exception_original_format = traceback.TracebackException.format - - -def traceback_exception_format(self, *, chain=True): - yield from traceback_exception_original_format(self, chain=chain) - - for i, exc in enumerate(self.embedded): - yield "\nDetails of embedded exception {}:\n\n".format(i + 1) - yield from ( - textwrap.indent(line, " " * 2) for line in exc.format(chain=chain) - ) - - -traceback.TracebackException.format = traceback_exception_format - - -def trio_excepthook(etype, value, tb): - for chunk in traceback.format_exception(etype, value, tb): - sys.stderr.write(chunk) - - -IPython_handler_installed = False -warning_given = False +# Remove when IPython gains support for exception groups +# (https://github.com/ipython/ipython/issues/13753) if "IPython" in sys.modules: import IPython + ip = IPython.get_ipython() if ip is not None: if ip.custom_exceptions != (): warnings.warn( "IPython detected, but you already have a custom exception " "handler installed. I'll skip installing Trio's custom " - "handler, but this means MultiErrors will not show full " + "handler, but this means exception groups will not show full " "tracebacks.", - category=RuntimeWarning + category=RuntimeWarning, ) - warning_given = True else: def trio_show_traceback(self, etype, value, tb, tb_offset=None): # XX it would be better to integrate with IPython's fancy # exception formatting stuff (and not ignore tb_offset) - trio_excepthook(etype, value, tb) + print_exception(value) - ip.set_custom_exc((MultiError,), trio_show_traceback) - IPython_handler_installed = True + ip.set_custom_exc((BaseExceptionGroup,), trio_show_traceback) -if sys.excepthook is sys.__excepthook__: - sys.excepthook = trio_excepthook -else: - if not IPython_handler_installed and not warning_given: - warnings.warn( - "You seem to already have a custom sys.excepthook handler " - "installed. 
I'll skip installing Trio's custom handler, but this " - "means MultiErrors will not show full tracebacks.", - category=RuntimeWarning - ) + +# Ubuntu's system Python has a sitecustomize.py file that import +# apport_python_hook and replaces sys.excepthook. +# +# The custom hook captures the error for crash reporting, and then calls +# sys.__excepthook__ to actually print the error. +# +# We don't mind it capturing the error for crash reporting, but we want to +# take over printing the error. So we monkeypatch the apport_python_hook +# module so that instead of calling sys.__excepthook__, it calls our custom +# hook. +# +# More details: https://github.com/python-trio/trio/issues/1065 +if ( + sys.version_info < (3, 11) + and getattr(sys.excepthook, "__name__", None) == "apport_excepthook" +): + from types import ModuleType + + import apport_python_hook + from exceptiongroup import format_exception + + assert sys.excepthook is apport_python_hook.apport_excepthook + + def replacement_excepthook(etype, value, tb): + sys.stderr.write("".join(format_exception(etype, value, tb))) + + fake_sys = ModuleType("trio_fake_sys") + fake_sys.__dict__.update(sys.__dict__) + fake_sys.__excepthook__ = replacement_excepthook # type: ignore + apport_python_hook.sys = fake_sys diff --git a/trio/_core/_parking_lot.py b/trio/_core/_parking_lot.py index d14e95e6ce..74708433da 100644 --- a/trio/_core/_parking_lot.py +++ b/trio/_core/_parking_lot.py @@ -69,25 +69,38 @@ # unpark is called. # # See: https://github.com/python-trio/trio/issues/53 +from __future__ import annotations -from itertools import count -import attr +import math from collections import OrderedDict +from collections.abc import Iterator +from typing import TYPE_CHECKING + +import attr from .. 
import _core +from .._util import Final + +if TYPE_CHECKING: + from ._run import Task -__all__ = ["ParkingLot"] -_counter = count() +@attr.s(frozen=True, slots=True) +class ParkingLotStatistics: + """An object containing debugging information for a ParkingLot. + Currently the following fields are defined: + + * ``tasks_waiting`` (int): The number of tasks blocked on this lot's + :meth:`trio.lowlevel.ParkingLot.park` method. + + """ -@attr.s(frozen=True) -class _ParkingLotStatistics: - tasks_waiting = attr.ib() + tasks_waiting: int = attr.ib() -@attr.s(eq=False, hash=False) -class ParkingLot: +@attr.s(eq=False, hash=False, slots=True) +class ParkingLot(metaclass=Final): """A fair wait queue with cancellation and requeueing. This class encapsulates the tricky parts of implementing a wait @@ -102,18 +115,14 @@ class ParkingLot: # {task: None}, we just want a deque where we can quickly delete random # items - _parked = attr.ib(factory=OrderedDict, init=False) + _parked: OrderedDict[Task, None] = attr.ib(factory=OrderedDict, init=False) - def __len__(self): - """Returns the number of parked tasks. - - """ + def __len__(self) -> int: + """Returns the number of parked tasks.""" return len(self._parked) - def __bool__(self): - """True if there are parked tasks, False otherwise. - - """ + def __bool__(self) -> bool: + """True if there are parked tasks, False otherwise.""" return bool(self._parked) # XX this currently returns None @@ -121,7 +130,7 @@ def __bool__(self): # line (for false wakeups), then we could have it return a ticket that # abstracts the "place in line" concept. @_core.enable_ki_protection - async def park(self): + async def park(self) -> None: """Park the current task until woken by a call to :meth:`unpark` or :meth:`unpark_all`. 
@@ -136,13 +145,20 @@ def abort_fn(_): await _core.wait_task_rescheduled(abort_fn) - def _pop_several(self, count): - for _ in range(min(count, len(self._parked))): + def _pop_several(self, count: int | float) -> Iterator[Task]: + if isinstance(count, float): + if math.isinf(count): + count = len(self._parked) + else: + raise ValueError("Cannot pop a non-integer number of tasks.") + else: + count = min(count, len(self._parked)) + for _ in range(count): task, _ = self._parked.popitem(last=False) yield task @_core.enable_ki_protection - def unpark(self, *, count=1): + def unpark(self, *, count: int | float = 1) -> list[Task]: """Unpark one or more tasks. This wakes up ``count`` tasks that are blocked in :meth:`park`. If @@ -150,7 +166,7 @@ def unpark(self, *, count=1): are available and then returns successfully. Args: - count (int): the number of tasks to unpark. + count (int | math.inf): the number of tasks to unpark. """ tasks = list(self._pop_several(count)) @@ -158,14 +174,12 @@ def unpark(self, *, count=1): _core.reschedule(task) return tasks - def unpark_all(self): - """Unpark all parked tasks. - - """ + def unpark_all(self) -> list[Task]: + """Unpark all parked tasks.""" return self.unpark(count=len(self)) @_core.enable_ki_protection - def repark(self, new_lot, *, count=1): + def repark(self, new_lot: ParkingLot, *, count: int | float = 1) -> None: """Move parked tasks from one :class:`ParkingLot` object to another. This dequeues ``count`` tasks from one lot, and requeues them on @@ -177,8 +191,8 @@ async def parker(lot): print("woken") async def main(): - lot1 = trio.hazmat.ParkingLot() - lot2 = trio.hazmat.ParkingLot() + lot1 = trio.lowlevel.ParkingLot() + lot2 = trio.lowlevel.ParkingLot() async with trio.open_nursery() as nursery: nursery.start_soon(parker, lot1) await trio.testing.wait_all_tasks_blocked() @@ -195,7 +209,7 @@ async def main(): Args: new_lot (ParkingLot): the parking lot to move tasks to. - count (int): the number of tasks to move. 
+ count (int|math.inf): the number of tasks to move. """ if not isinstance(new_lot, ParkingLot): @@ -204,7 +218,7 @@ async def main(): new_lot._parked[task] = None task.custom_sleep_data = new_lot - def repark_all(self, new_lot): + def repark_all(self, new_lot: ParkingLot) -> None: """Move all parked tasks from one :class:`ParkingLot` object to another. @@ -213,7 +227,7 @@ def repark_all(self, new_lot): """ return self.repark(new_lot, count=len(self)) - def statistics(self): + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -222,4 +236,4 @@ def statistics(self): :meth:`park` method. """ - return _ParkingLotStatistics(tasks_waiting=len(self._parked)) + return ParkingLotStatistics(tasks_waiting=len(self._parked)) diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 9bb115dadc..ce8feb2827 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -1,49 +1,68 @@ +from __future__ import annotations + +import enum import functools +import gc import itertools -import logging -import os import random import select import sys import threading +import warnings from collections import deque -import collections.abc -from contextlib import contextmanager, closing - +from collections.abc import Callable, Coroutine, Iterator +from contextlib import AbstractAsyncContextManager, contextmanager from contextvars import copy_context +from heapq import heapify, heappop, heappush from math import inf from time import perf_counter - -from sniffio import current_async_library_cvar +from types import TracebackType +from typing import TYPE_CHECKING, Any, NoReturn, TypeVar import attr -from async_generator import isasyncgen +from outcome import Error, Outcome, Value, capture +from sniffio import current_async_library_cvar from sortedcontainers import SortedDict -from outcome import Error, Value, capture +from .. 
import _core +from .._util import Final, NoPublicConstructor, coroutine_or_error +from ._asyncgens import AsyncGenerators from ._entry_queue import EntryQueue, TrioToken -from ._exceptions import (TrioInternalError, RunFinishedError, Cancelled) -from ._ki import ( - LOCALS_KEY_KI_PROTECTION_ENABLED, ki_manager, enable_ki_protection -) -from ._multierror import MultiError +from ._exceptions import Cancelled, RunFinishedError, TrioInternalError +from ._instrumentation import Instruments +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED, KIManager, enable_ki_protection +from ._multierror import MultiError, concat_tb +from ._thread_cache import start_thread_soon from ._traps import ( Abort, - wait_task_rescheduled, CancelShieldedCheckpoint, PermanentlyDetachCoroutineObject, WaitTaskRescheduled, + cancel_shielded_checkpoint, + wait_task_rescheduled, ) -from .. import _core -from .._deprecate import deprecated -from .._util import Final, NoPublicConstructor -_NO_SEND = object() +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + +from types import FrameType + +if TYPE_CHECKING: + import contextvars + + # An unfortunate name collision here with trio._util.Final + from typing import Final as FinalT + +DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: FinalT = 1000 + +_NO_SEND: FinalT = object() + +FnT = TypeVar("FnT", bound="Callable[..., Any]") # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn): +def _public(fn: FnT) -> FnT: return fn @@ -52,23 +71,31 @@ def _public(fn): # variable to True, and registers the Random instance _r for Hypothesis # to manage for each test case, which together should make Trio's task # scheduling loop deterministic. We have a test for that, of course. 
-_ALLOW_DETERMINISTIC_SCHEDULING = False +_ALLOW_DETERMINISTIC_SCHEDULING: FinalT = False _r = random.Random() -# Used to log exceptions in instruments -INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument") +def _count_context_run_tb_frames() -> int: + """Count implementation dependent traceback frames from Context.run() -# On 3.7+, Context.run() is implemented in C and doesn't show up in -# tracebacks. On 3.6 and earlier, we use the contextvars backport, which is -# currently implemented in Python and adds 1 frame to tracebacks. So this -# function is a super-overkill version of "0 if sys.version_info >= (3, 7) -# else 1". But if Context.run ever changes, we'll be ready! -# -# This can all be removed once we drop support for 3.6. -def _count_context_run_tb_frames(): - def function_with_unique_name_xyzzy(): - 1 / 0 + On CPython, Context.run() is implemented in C and doesn't show up in + tracebacks. On PyPy, it is implemented in Python and adds 1 frame to + tracebacks. + + Returns: + int: Traceback frame count + + """ + + def function_with_unique_name_xyzzy() -> NoReturn: + try: + 1 / 0 + except ZeroDivisionError: + raise + else: # pragma: no cover + raise TrioInternalError( + "A ZeroDivisionError should have been raised, but it wasn't." + ) ctx = copy_context() try: @@ -76,15 +103,20 @@ def function_with_unique_name_xyzzy(): except ZeroDivisionError as exc: tb = exc.__traceback__ # Skip the frame where we caught it - tb = tb.tb_next + tb = tb.tb_next # type: ignore[union-attr] count = 0 - while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy": - tb = tb.tb_next + while tb.tb_frame.f_code.co_name != "function_with_unique_name_xyzzy": # type: ignore[union-attr] + tb = tb.tb_next # type: ignore[union-attr] count += 1 return count + else: # pragma: no cover + raise TrioInternalError( + f"The purpose of {function_with_unique_name_xyzzy.__name__} is " + "to raise a ZeroDivisionError, but it didn't." 
+ ) -CONTEXT_RUN_TB_FRAMES = _count_context_run_tb_frames() +CONTEXT_RUN_TB_FRAMES: FinalT = _count_context_run_tb_frames() @attr.s(frozen=True, slots=True) @@ -92,26 +124,125 @@ class SystemClock: # Add a large random offset to our clock to ensure that if people # accidentally call time.perf_counter() directly or start comparing clocks # between different runs, then they'll notice the bug quickly: - offset = attr.ib(factory=lambda: _r.uniform(10000, 200000)) + offset: float = attr.ib(factory=lambda: _r.uniform(10000, 200000)) - def start_clock(self): + def start_clock(self) -> None: pass # In cPython 3, on every platform except Windows, perf_counter is # exactly the same as time.monotonic; and on Windows, it uses # QueryPerformanceCounter instead of GetTickCount64. - def current_time(self): + def current_time(self) -> float: return self.offset + perf_counter() - def deadline_to_sleep_time(self, deadline): + def deadline_to_sleep_time(self, deadline: float) -> float: return deadline - self.current_time() +class IdlePrimedTypes(enum.Enum): + WAITING_FOR_IDLE = 1 + AUTOJUMP_CLOCK = 2 + + ################################################################ # CancelScope and friends ################################################################ +def collapse_exception_group(excgroup): + """Recursively collapse any single-exception groups into that single contained + exception. 
+ + """ + exceptions = list(excgroup.exceptions) + modified = False + for i, exc in enumerate(exceptions): + if isinstance(exc, BaseExceptionGroup): + new_exc = collapse_exception_group(exc) + if new_exc is not exc: + modified = True + exceptions[i] = new_exc + + if len(exceptions) == 1 and isinstance(excgroup, MultiError) and excgroup.collapse: + exceptions[0].__traceback__ = concat_tb( + excgroup.__traceback__, exceptions[0].__traceback__ + ) + return exceptions[0] + elif modified: + return excgroup.derive(exceptions) + else: + return excgroup + + +@attr.s(eq=False, slots=True) +class Deadlines: + """A container of deadlined cancel scopes. + + Only contains scopes with non-infinite deadlines that are currently + attached to at least one task. + + """ + + # Heap of (deadline, id(CancelScope), CancelScope) + _heap = attr.ib(factory=list) + # Count of active deadlines (those that haven't been changed) + _active = attr.ib(default=0) + + def add(self, deadline, cancel_scope): + heappush(self._heap, (deadline, id(cancel_scope), cancel_scope)) + self._active += 1 + + def remove(self, deadline, cancel_scope): + self._active -= 1 + + def next_deadline(self): + while self._heap: + deadline, _, cancel_scope = self._heap[0] + if deadline == cancel_scope._registered_deadline: + return deadline + else: + # This entry is stale; discard it and try again + heappop(self._heap) + return inf + + def _prune(self): + # In principle, it's possible for a cancel scope to toggle back and + # forth repeatedly between the same two deadlines, and end up with + # lots of stale entries that *look* like they're still active, because + # their deadline is correct, but in fact are redundant. So when + # pruning we have to eliminate entries with the wrong deadline, *and* + # eliminate duplicates. 
+ seen = set() + pruned_heap = [] + for deadline, tiebreaker, cancel_scope in self._heap: + if deadline == cancel_scope._registered_deadline: + if cancel_scope in seen: + continue + seen.add(cancel_scope) + pruned_heap.append((deadline, tiebreaker, cancel_scope)) + # See test_cancel_scope_deadline_duplicates for a test that exercises + # this assert: + assert len(pruned_heap) == self._active + heapify(pruned_heap) + self._heap = pruned_heap + + def expire(self, now): + did_something = False + while self._heap and self._heap[0][0] <= now: + deadline, _, cancel_scope = heappop(self._heap) + if deadline == cancel_scope._registered_deadline: + did_something = True + # This implicitly calls self.remove(), so we don't need to + # decrement _active here + cancel_scope.cancel() + # If we've accumulated too many stale entries, then prune the heap to + # keep it under control. (We only do this occasionally in a batch, to + # keep the amortized cost down) + if len(self._heap) > self._active * 2 + DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: + self._prune() + return did_something + + @attr.s(eq=False, slots=True) class CancelStatus: """Tracks the cancellation status for a contiguous extent @@ -145,7 +276,7 @@ class CancelStatus: # Our associated cancel scope. Can be any object with attributes # `deadline`, `shield`, and `cancel_called`, but in current usage # is always a CancelScope object. Must not be None. - _scope = attr.ib() + _scope: CancelScope = attr.ib() # True iff the tasks in self._tasks should receive cancellations # when they checkpoint. Always True when scope.cancel_called is True; @@ -155,31 +286,31 @@ class CancelStatus: # effectively cancelled due to the cancel scope two levels out # becoming cancelled, but then the cancel scope one level out # becomes shielded so we're not effectively cancelled anymore. 
- effectively_cancelled = attr.ib(default=False) + effectively_cancelled: bool = attr.ib(default=False) # The CancelStatus whose cancellations can propagate to us; we # become effectively cancelled when they do, unless scope.shield # is True. May be None (for the outermost CancelStatus in a call # to trio.run(), briefly during TaskStatus.started(), or during # recovery from mis-nesting of cancel scopes). - _parent = attr.ib(default=None, repr=False) + _parent: CancelStatus | None = attr.ib(default=None, repr=False) # All of the CancelStatuses that have this CancelStatus as their parent. - _children = attr.ib(factory=set, init=False, repr=False) + _children: set[CancelStatus] = attr.ib(factory=set, init=False, repr=False) # Tasks whose cancellation state is currently tied directly to # the cancellation state of this CancelStatus object. Don't modify # this directly; instead, use Task._activate_cancel_status(). # Invariant: all(task._cancel_status is self for task in self._tasks) - _tasks = attr.ib(factory=set, init=False, repr=False) + _tasks: set[Task] = attr.ib(factory=set, init=False, repr=False) # Set to True on still-active cancel statuses that are children # of a cancel status that's been closed. This is used to permit # recovery from mis-nested cancel scopes (well, at least enough # recovery to show a useful traceback). 
- abandoned_by_misnesting = attr.ib(default=False, init=False, repr=False) + abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: if self._parent is not None: self._parent._children.add(self) self.recalculate() @@ -187,11 +318,11 @@ def __attrs_post_init__(self): # parent/children/tasks accessors are used by TaskStatus.started() @property - def parent(self): + def parent(self) -> CancelStatus | None: return self._parent @parent.setter - def parent(self, parent): + def parent(self, parent: CancelStatus) -> None: if self._parent is not None: self._parent._children.remove(self) self._parent = parent @@ -200,14 +331,14 @@ def parent(self, parent): self.recalculate() @property - def children(self): + def children(self) -> frozenset[CancelStatus]: return frozenset(self._children) @property - def tasks(self): + def tasks(self) -> frozenset[Task]: return frozenset(self._tasks) - def encloses(self, other): + def encloses(self, other: CancelStatus | None) -> bool: """Returns true if this cancel status is a direct or indirect parent of cancel status *other*, or if *other* is *self*. 
""" @@ -217,7 +348,7 @@ def encloses(self, other): other = other.parent return False - def close(self): + def close(self) -> None: self.parent = None # now we're not a child of self.parent anymore if self._tasks or self._children: # Cancel scopes weren't exited in opposite order of being @@ -248,7 +379,8 @@ def close(self): @property def parent_cancellation_is_visible_to_us(self): return ( - self._parent is not None and not self._scope.shield + self._parent is not None + and not self._scope.shield and self._parent.effectively_cancelled ) @@ -276,7 +408,7 @@ def _mark_abandoned(self): for child in self._children: child._mark_abandoned() - def effective_deadline(self): + def effective_deadline(self) -> float: if self.effectively_cancelled: return -inf if self._parent is None or self._scope.shield: @@ -348,15 +480,15 @@ class CancelScope(metaclass=Final): has been entered yet, and changes take immediate effect. """ - _cancel_status = attr.ib(default=None, init=False) - _has_been_entered = attr.ib(default=False, init=False) - _registered_deadline = attr.ib(default=inf, init=False) - _cancel_called = attr.ib(default=False, init=False) - cancelled_caught = attr.ib(default=False, init=False) + _cancel_status: CancelStatus | None = attr.ib(default=None, init=False) + _has_been_entered: bool = attr.ib(default=False, init=False) + _registered_deadline: float = attr.ib(default=inf, init=False) + _cancel_called: bool = attr.ib(default=False, init=False) + cancelled_caught: bool = attr.ib(default=False, init=False) # Constructor arguments: - _deadline = attr.ib(default=inf, kw_only=True) - _shield = attr.ib(default=False, kw_only=True) + _deadline: float = attr.ib(default=inf, kw_only=True) + _shield: bool = attr.ib(default=False, kw_only=True) @enable_ki_protection def __enter__(self): @@ -369,18 +501,10 @@ def __enter__(self): if current_time() >= self._deadline: self.cancel() with self._might_change_registered_deadline(): - self._cancel_status = CancelStatus( - scope=self, 
parent=task._cancel_status - ) + self._cancel_status = CancelStatus(scope=self, parent=task._cancel_status) task._activate_cancel_status(self._cancel_status) return self - def _exc_filter(self, exc): - if isinstance(exc, Cancelled): - self.cancelled_caught = True - return None - return exc - def _close(self, exc): if self._cancel_status is None: new_exc = RuntimeError( @@ -422,8 +546,10 @@ def _close(self, exc): new_exc = RuntimeError( "Cancel scope stack corrupted: attempted to exit {!r} " "in {!r} that's still within its child {!r}\n{}".format( - self, scope_task, scope_task._cancel_status._scope, - MISNESTING_ADVICE + self, + scope_task, + scope_task._cancel_status._scope, + MISNESTING_ADVICE, ) ) new_exc.__context__ = exc @@ -432,21 +558,40 @@ def _close(self, exc): else: scope_task._activate_cancel_status(self._cancel_status.parent) if ( - exc is not None and self._cancel_status.effectively_cancelled + exc is not None + and self._cancel_status.effectively_cancelled and not self._cancel_status.parent_cancellation_is_visible_to_us ): - exc = MultiError.filter(self._exc_filter, exc) + if isinstance(exc, Cancelled): + self.cancelled_caught = True + exc = None + elif isinstance(exc, BaseExceptionGroup): + matched, exc = exc.split(Cancelled) + if matched: + self.cancelled_caught = True + + if exc: + exc = collapse_exception_group(exc) + self._cancel_status.close() with self._might_change_registered_deadline(): self._cancel_status = None return exc - @enable_ki_protection - def __exit__(self, etype, exc, tb): + def __exit__( + self, + etype: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: # NB: NurseryManager calls _close() directly rather than __exit__(), # so __exit__() must be just _close() plus this logic for adapting # the exception-filtering result to the context manager API. 
+ # This inlines the enable_ki_protection decorator so we can fix + # f_locals *locally* below to avoid reference cycles + locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + # Tracebacks show the 'raise' line below out of context, so let's give # this variable a name that makes sense out of context. remaining_error_after_cancel_scope = self._close(exc) @@ -464,8 +609,15 @@ def __exit__(self, etype, exc, tb): _, value, _ = sys.exc_info() assert value is remaining_error_after_cancel_scope value.__context__ = old_context - - def __repr__(self): + # delete references from locals to avoid creating cycles + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage + del remaining_error_after_cancel_scope, value, _, exc + # deep magic to remove refs via f_locals + locals() + # TODO: check if PEP558 changes the need for this call + # https://github.com/python/cpython/pull/3640 + + def __repr__(self) -> str: if self._cancel_status is not None: binding = "active" elif self._has_been_entered: @@ -485,16 +637,14 @@ def __repr__(self): else: state = ", deadline is {:.2f} seconds {}".format( abs(self._deadline - now), - "from now" if self._deadline >= now else "ago" + "from now" if self._deadline >= now else "ago", ) - return "".format( - id(self), binding, state - ) + return f"" @contextmanager @enable_ki_protection - def _might_change_registered_deadline(self): + def _might_change_registered_deadline(self) -> Iterator[None]: try: yield finally: @@ -506,13 +656,19 @@ def _might_change_registered_deadline(self): if old != new: self._registered_deadline = new runner = GLOBAL_RUN_CONTEXT.runner + if runner.is_guest: + old_next_deadline = runner.deadlines.next_deadline() if old != inf: - del runner.deadlines[old, id(self)] + runner.deadlines.remove(old, self) if new != inf: - runner.deadlines[new, id(self)] = self + runner.deadlines.add(new, self) + if runner.is_guest: + new_next_deadline = runner.deadlines.next_deadline() + if old_next_deadline != new_next_deadline: + 
runner.force_guest_tick_asap() @property - def deadline(self): + def deadline(self) -> float: """Read-write, :class:`float`. An absolute time on the current run's clock at which this scope will automatically become cancelled. You can adjust the deadline by modifying this @@ -538,12 +694,12 @@ def deadline(self): return self._deadline @deadline.setter - def deadline(self, new_deadline): + def deadline(self, new_deadline: float) -> None: with self._might_change_registered_deadline(): self._deadline = float(new_deadline) @property - def shield(self): + def shield(self) -> bool: """Read-write, :class:`bool`, default :data:`False`. So long as this is set to :data:`True`, then the code inside this scope will not receive :exc:`~trio.Cancelled` exceptions from scopes @@ -568,7 +724,7 @@ def shield(self): @shield.setter @enable_ki_protection - def shield(self, new_value): + def shield(self, new_value: bool) -> None: if not isinstance(new_value, bool): raise TypeError("shield must be a bool") self._shield = new_value @@ -576,7 +732,7 @@ def shield(self, new_value): self._cancel_status.recalculate() @enable_ki_protection - def cancel(self): + def cancel(self) -> None: """Cancels this scope immediately. This method is idempotent, i.e., if the scope was already @@ -590,7 +746,7 @@ def cancel(self): self._cancel_status.recalculate() @property - def cancel_called(self): + def cancel_called(self) -> bool: """Readonly :class:`bool`. Records whether cancellation has been requested for this scope, either by an explicit call to :meth:`cancel` or by the deadline expiring. 
@@ -619,12 +775,6 @@ def cancel_called(self): return self._cancel_called -@deprecated("0.11.0", issue=607, instead="trio.CancelScope") -def open_cancel_scope(*, deadline=inf, shield=False): - """Returns a context manager which creates a new cancellation scope.""" - return CancelScope(deadline=deadline, shield=shield) - - ################################################################ # Nursery and friends ################################################################ @@ -633,20 +783,18 @@ def open_cancel_scope(*, deadline=inf, shield=False): # This code needs to be read alongside the code from Nursery.start to make # sense. @attr.s(eq=False, hash=False, repr=False) -class _TaskStatus: +class TaskStatus(metaclass=Final): _old_nursery = attr.ib() _new_nursery = attr.ib() _called_started = attr.ib(default=False) _value = attr.ib(default=None) def __repr__(self): - return "".format(id(self)) + return f"" def started(self, value=None): if self._called_started: - raise RuntimeError( - "called 'started' twice on the same task status" - ) + raise RuntimeError("called 'started' twice on the same task status") self._called_started = True self._value = value @@ -666,6 +814,7 @@ def started(self, value=None): self._old_nursery._children = set() for task in tasks: task._parent_nursery = self._new_nursery + task._eventual_parent_nursery = None self._new_nursery._children.add(task) # Move all children of the old nursery's cancel status object @@ -696,6 +845,7 @@ def started(self, value=None): self._old_nursery._check_nursery_closed() +@attr.s class NurseryManager: """Nursery context manager. @@ -705,15 +855,25 @@ class NurseryManager: and StopAsyncIteration. 
""" + + strict_exception_groups: bool = attr.ib(default=False) + @enable_ki_protection - async def __aenter__(self): + async def __aenter__(self) -> Nursery: self._scope = CancelScope() self._scope.__enter__() - self._nursery = Nursery._create(current_task(), self._scope) + self._nursery = Nursery._create( + current_task(), self._scope, self.strict_exception_groups + ) return self._nursery @enable_ki_protection - async def __aexit__(self, etype, exc, tb): + async def __aexit__( + self, + etype: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool: new_exc = await self._nursery._nested_child_finished(exc) # Tracebacks show the 'raise' line below out of context, so let's give # this variable a name that makes sense out of context. @@ -732,25 +892,46 @@ async def __aexit__(self, etype, exc, tb): _, value, _ = sys.exc_info() assert value is combined_error_from_nursery value.__context__ = old_context + # delete references from locals to avoid creating cycles + # see test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage + del _, combined_error_from_nursery, value, new_exc - def __enter__(self): - raise RuntimeError( - "use 'async with open_nursery(...)', not 'with open_nursery(...)'" - ) + # make sure these raise errors in static analysis if called + if not TYPE_CHECKING: + + def __enter__(self) -> NoReturn: + raise RuntimeError( + "use 'async with open_nursery(...)', not 'with open_nursery(...)'" + ) - def __exit__(self): # pragma: no cover - assert False, """Never called, but should be defined""" + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> NoReturn: # pragma: no cover + raise AssertionError("Never called, but should be defined") -def open_nursery(): +def open_nursery( + strict_exception_groups: bool | None = None, +) -> AbstractAsyncContextManager[Nursery]: """Returns an async context manager which must be used to create a 
new `Nursery`. It does not block on entry; on exit it blocks until all child tasks have exited. + Args: + strict_exception_groups (bool): If true, even a single raised exception will be + wrapped in an exception group. This will eventually become the default + behavior. If not specified, uses the value passed to :func:`run`. + """ - return NurseryManager() + if strict_exception_groups is None: + strict_exception_groups = GLOBAL_RUN_CONTEXT.runner.strict_exception_groups + + return NurseryManager(strict_exception_groups=strict_exception_groups) class Nursery(metaclass=NoPublicConstructor): @@ -774,8 +955,15 @@ class Nursery(metaclass=NoPublicConstructor): other things, e.g. if you want to explicitly cancel all children in response to some external event. """ - def __init__(self, parent_task, cancel_scope): + + def __init__( + self, + parent_task: Task, + cancel_scope: CancelScope, + strict_exception_groups: bool, + ): self._parent_task = parent_task + self._strict_exception_groups = strict_exception_groups parent_task._child_nurseries.append(self) # the cancel status that children inherit - we take a snapshot, so it # won't be affected by any changes in the parent. @@ -784,8 +972,8 @@ def __init__(self, parent_task, cancel_scope): # children. self.cancel_scope = cancel_scope assert self.cancel_scope._cancel_status is self._cancel_status - self._children = set() - self._pending_excs = [] + self._children: set[Task] = set() + self._pending_excs: list[BaseException] = [] # The "nested child" is how this code refers to the contents of the # nursery's 'async with' block, which acts like a child Task in all # the ways we can make it. 
@@ -795,24 +983,22 @@ def __init__(self, parent_task, cancel_scope): self._closed = False @property - def child_tasks(self): - """(`frozenset`): Contains all the child :class:`~trio.hazmat.Task` + def child_tasks(self) -> frozenset[Task]: + """(`frozenset`): Contains all the child :class:`~trio.lowlevel.Task` objects which are still running.""" return frozenset(self._children) @property - def parent_task(self): - "(`~trio.hazmat.Task`): The Task that opened this nursery." + def parent_task(self) -> Task: + "(`~trio.lowlevel.Task`): The Task that opened this nursery." return self._parent_task - def _add_exc(self, exc): + def _add_exc(self, exc: BaseException) -> None: self._pending_excs.append(exc) self.cancel_scope.cancel() def _check_nursery_closed(self): - if not any( - [self._nested_child_running, self._children, self._pending_starts] - ): + if not any([self._nested_child_running, self._children, self._pending_starts]): self._closed = True if self._parent_waiting_in_aexit: self._parent_waiting_in_aexit = False @@ -820,12 +1006,13 @@ def _check_nursery_closed(self): def _child_finished(self, task, outcome): self._children.remove(task) - if type(outcome) is Error: + if isinstance(outcome, Error): self._add_exc(outcome.error) self._check_nursery_closed() async def _nested_child_finished(self, nested_child_exc): - """Returns MultiError instance if there are pending exceptions.""" + # Returns MultiError instance (or any exception if the nursery is in loose mode + # and there is just one contained exception) if there are pending exceptions if nested_child_exc is not None: self._add_exc(nested_child_exc) self._nested_child_running = False @@ -853,23 +1040,29 @@ def aborted(raise_cancel): popped = self._parent_task._child_nurseries.pop() assert popped is self if self._pending_excs: - return MultiError(self._pending_excs) + try: + return MultiError( + self._pending_excs, _collapse=not self._strict_exception_groups + ) + finally: + # avoid a garbage cycle + # (see 
test_nursery_cancel_doesnt_create_cyclic_garbage) + del self._pending_excs def start_soon(self, async_fn, *args, name=None): """Creates a child task, scheduling ``await async_fn(*args)``. - This and :meth:`start` are the two fundamental methods for + If you want to run a function and immediately wait for its result, + then you don't need a nursery; just use ``await async_fn(*args)``. + If you want to wait for the task to initialize itself before + continuing, see :meth:`start`, the other fundamental method for creating concurrent tasks in Trio. Note that this is *not* an async function and you don't use await when calling it. It sets up the new task, but then returns - immediately, *before* it has a chance to run. The new task won’t - actually get a chance to do anything until some later point when - you execute a checkpoint and the scheduler decides to run it. - If you want to run a function and immediately wait for its result, - then you don't need a nursery; just use ``await async_fn(*args)``. - If you want to wait for the task to initialize itself before - continuing, see :meth:`start()`. + immediately, *before* the new task has a chance to do anything. + New tasks may start running in any order, and at any checkpoint the + scheduler chooses - at latest when the nursery is waiting to exit. It's possible to pass a nursery object into another task, which allows that task to start new child tasks in the first task's @@ -891,9 +1084,6 @@ def start_soon(self, async_fn, *args, name=None): original function as the ``name=`` to make debugging easier. - Returns: - True if successful, False otherwise. - Raises: RuntimeError: If this nursery is no longer open (i.e. its ``async with`` block has @@ -902,7 +1092,7 @@ def start_soon(self, async_fn, *args, name=None): GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name) async def start(self, async_fn, *args, name=None): - r""" Creates and initalizes a child task. + r"""Creates and initializes a child task. 
Like :meth:`start_soon`, but blocks until the new task has finished initializing itself, and optionally returns some @@ -914,10 +1104,10 @@ async def start(self, async_fn, *args, name=None): The conventional way to define ``async_fn`` is like:: - async def async_fn(arg1, arg2, \*, task_status=trio.TASK_STATUS_IGNORED): - ... + async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): + ... # Caller is blocked waiting for this code to run task_status.started() - ... + ... # This async code can be interleaved with the caller :attr:`trio.TASK_STATUS_IGNORED` is a special global object with a do-nothing ``started`` method. This way your function supports @@ -935,8 +1125,8 @@ async def async_fn(arg1, arg2, \*, task_status=trio.TASK_STATUS_IGNORED): :meth:`start` is cancelled, then the child task is also cancelled. - When the child calls ``task_status.started()``, it's moved from - out from underneath :meth:`start` and into the given nursery. + When the child calls ``task_status.started()``, it's moved out + from underneath :meth:`start` and into the given nursery. If the child task passes a value to ``task_status.started(value)``, then :meth:`start` returns this @@ -947,24 +1137,25 @@ async def async_fn(arg1, arg2, \*, task_status=trio.TASK_STATUS_IGNORED): try: self._pending_starts += 1 async with open_nursery() as old_nursery: - task_status = _TaskStatus(old_nursery, self) + task_status = TaskStatus(old_nursery, self) thunk = functools.partial(async_fn, task_status=task_status) - old_nursery.start_soon(thunk, *args, name=name) - # Wait for either _TaskStatus.started or an exception to + task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( + thunk, args, old_nursery, name + ) + task._eventual_parent_nursery = self + # Wait for either TaskStatus.started or an exception to # cancel this nursery: # If we get here, then the child either got reparented or exited - # normally. The complicated logic is all in _TaskStatus.started(). + # normally. 
The complicated logic is all in TaskStatus.started(). # (Any exceptions propagate directly out of the above.) if not task_status._called_started: - raise RuntimeError( - "child exited without calling task_status.started()" - ) + raise RuntimeError("child exited without calling task_status.started()") return task_status._value finally: self._pending_starts -= 1 self._check_nursery_closed() - def __del__(self): + def __del__(self) -> None: assert not self._children @@ -973,15 +1164,14 @@ def __del__(self): ################################################################ -@attr.s(eq=False, hash=False, repr=False) -class Task: - _parent_nursery = attr.ib() - coro = attr.ib() +@attr.s(eq=False, hash=False, repr=False, slots=True) +class Task(metaclass=NoPublicConstructor): + _parent_nursery: Nursery | None = attr.ib() + coro: Coroutine[Any, Outcome[object], Any] = attr.ib() _runner = attr.ib() - name = attr.ib() - # PEP 567 contextvars context - context = attr.ib() - _counter = attr.ib(init=False, factory=itertools.count().__next__) + name: str = attr.ib() + context: contextvars.Context = attr.ib() + _counter: int = attr.ib(init=False, factory=itertools.count().__next__) # Invariant: # - for unscheduled tasks, _next_send_fn and _next_send are both None @@ -996,23 +1186,26 @@ class Task: # Tasks start out unscheduled. 
_next_send_fn = attr.ib(default=None) _next_send = attr.ib(default=None) - _abort_func = attr.ib(default=None) - custom_sleep_data = attr.ib(default=None) + _abort_func: Callable[[Callable[[], NoReturn]], Abort] | None = attr.ib( + default=None + ) + custom_sleep_data: Any = attr.ib(default=None) # For introspection and nursery.start() - _child_nurseries = attr.ib(factory=list) + _child_nurseries: list[Nursery] = attr.ib(factory=list) + _eventual_parent_nursery: Nursery | None = attr.ib(default=None) # these are counts of how many cancel/schedule points this task has # executed, for assert{_no,}_checkpoints # XX maybe these should be exposed as part of a statistics() method? - _cancel_points = attr.ib(default=0) - _schedule_points = attr.ib(default=0) + _cancel_points: int = attr.ib(default=0) + _schedule_points: int = attr.ib(default=0) - def __repr__(self): - return ("".format(self.name, id(self))) + def __repr__(self) -> str: + return f"" @property - def parent_nursery(self): + def parent_nursery(self) -> Nursery | None: """The nursery this task is inside (or None if this is the "init" task). @@ -1023,7 +1216,19 @@ def parent_nursery(self): return self._parent_nursery @property - def child_nurseries(self): + def eventual_parent_nursery(self) -> Nursery | None: + """The nursery this task will be inside after it calls + ``task_status.started()``. + + If this task has already called ``started()``, or if it was not + spawned using `nursery.start() `, then + its `eventual_parent_nursery` is ``None``. + + """ + return self._eventual_parent_nursery + + @property + def child_nurseries(self) -> list[Nursery]: """The nurseries this task contains. This is a list, with outer nurseries before inner nurseries. 
@@ -1031,15 +1236,63 @@ def child_nurseries(self): """ return list(self._child_nurseries) + def iter_await_frames(self) -> Iterator[tuple[FrameType, int]]: + """Iterates recursively over the coroutine-like objects this + task is waiting on, yielding the frame and line number at each + frame. + + This is similar to `traceback.walk_stack` in a synchronous + context. Note that `traceback.walk_stack` returns frames from + the bottom of the call stack to the top, while this function + starts from `Task.coro ` and works it + way down. + + Example usage: extracting a stack trace:: + + import traceback + + def print_stack_for_task(task): + ss = traceback.StackSummary.extract(task.iter_await_frames()) + print("".join(ss.format())) + + """ + # ignore static typing as we're doing lots of dynamic introspection + coro: Any = self.coro + while coro is not None: + if hasattr(coro, "cr_frame"): + # A real coroutine + yield coro.cr_frame, coro.cr_frame.f_lineno + coro = coro.cr_await + elif hasattr(coro, "gi_frame"): + # A generator decorated with @types.coroutine + yield coro.gi_frame, coro.gi_frame.f_lineno + coro = coro.gi_yieldfrom + elif coro.__class__.__name__ in [ + "async_generator_athrow", + "async_generator_asend", + ]: + # cannot extract the generator directly, see https://github.com/python/cpython/issues/76991 + # we can however use the gc to look through the object + for referent in gc.get_referents(coro): + if hasattr(referent, "ag_frame"): + yield referent.ag_frame, referent.ag_frame.f_lineno + coro = referent.ag_await + break + else: + # either cpython changed or we are running on an alternative python implementation + return + else: + return + ################ # Cancellation ################ # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). 
- _cancel_status = attr.ib(default=None, repr=False) + _cancel_status: CancelStatus = attr.ib(default=None, repr=False) - def _activate_cancel_status(self, cancel_status): + def _activate_cancel_status(self, cancel_status: CancelStatus) -> None: if self._cancel_status is not None: self._cancel_status._tasks.remove(self) self._cancel_status = cancel_status @@ -1048,11 +1301,16 @@ def _activate_cancel_status(self, cancel_status): if self._cancel_status.effectively_cancelled: self._attempt_delivery_of_any_pending_cancel() - def _attempt_abort(self, raise_cancel): + def _attempt_abort(self, raise_cancel: Callable[[], NoReturn]) -> None: # Either the abort succeeds, in which case we will reschedule the # task, or else it fails, in which case it will worry about # rescheduling itself (hopefully eventually calling reraise to raise # the given exception, but not necessarily). + + # This is only called by the functions immediately below, which both check + # `self.abort_func is not None`. + assert self._abort_func is not None, "FATAL INTERNAL ERROR" + success = self._abort_func(raise_cancel) if type(success) is not Abort: raise TrioInternalError("abort function must return Abort enum") @@ -1062,7 +1320,7 @@ def _attempt_abort(self, raise_cancel): if success is Abort.SUCCEEDED: self._runner.reschedule(self, capture(raise_cancel)) - def _attempt_delivery_of_any_pending_cancel(self): + def _attempt_delivery_of_any_pending_cancel(self) -> None: if self._abort_func is None: return if not self._cancel_status.effectively_cancelled: @@ -1073,12 +1331,12 @@ def raise_cancel(): self._attempt_abort(raise_cancel) - def _attempt_delivery_of_pending_ki(self): + def _attempt_delivery_of_pending_ki(self) -> None: assert self._runner.ki_pending if self._abort_func is None: return - def raise_cancel(): + def raise_cancel() -> NoReturn: self._runner.ki_pending = False raise KeyboardInterrupt @@ -1089,7 +1347,13 @@ def raise_cancel(): # The central Runner object 
################################################################ -GLOBAL_RUN_CONTEXT = threading.local() + +class RunContext(threading.local): + runner: Runner + task: Task + + +GLOBAL_RUN_CONTEXT: FinalT = RunContext() @attr.s(frozen=True) @@ -1101,22 +1365,85 @@ class _RunStatistics: run_sync_soon_queue_size = attr.ib() -@attr.s(eq=False, hash=False) +# This holds all the state that gets trampolined back and forth between +# callbacks when we're running in guest mode. +# +# It has to be a separate object from Runner, and Runner *cannot* hold +# references to it (directly or indirectly)! +# +# The idea is that we want a chance to detect if our host loop quits and stops +# driving us forward. We detect that by unrolled_run_gen being garbage +# collected, and hitting its 'except GeneratorExit:' block. So this only +# happens if unrolled_run_gen is GCed. +# +# The Runner state is referenced from the global GLOBAL_RUN_CONTEXT. The only +# way it gets *un*referenced is by unrolled_run_gen completing, e.g. by being +# GCed. But if Runner has a direct or indirect reference to it, and the host +# loop has abandoned it, then this will never happen! +# +# So this object can reference Runner, but Runner can't reference it. The only +# references to it are the "in flight" callback chain on the host loop / +# worker thread. 
+@attr.s(eq=False, hash=False, slots=True) +class GuestState: + runner = attr.ib() + run_sync_soon_threadsafe = attr.ib() + run_sync_soon_not_threadsafe = attr.ib() + done_callback = attr.ib() + unrolled_run_gen = attr.ib() + _value_factory: Callable[[], Value] = lambda: Value(None) + unrolled_run_next_send = attr.ib(factory=_value_factory, type=Outcome) + + def guest_tick(self): + try: + timeout = self.unrolled_run_next_send.send(self.unrolled_run_gen) + except StopIteration: + self.done_callback(self.runner.main_task_outcome) + return + except TrioInternalError as exc: + self.done_callback(Error(exc)) + return + + # Optimization: try to skip going into the thread if we can avoid it + events_outcome = capture(self.runner.io_manager.get_events, 0) + if timeout <= 0 or isinstance(events_outcome, Error) or events_outcome.value: + # No need to go into the thread + self.unrolled_run_next_send = events_outcome + self.runner.guest_tick_scheduled = True + self.run_sync_soon_not_threadsafe(self.guest_tick) + else: + # Need to go into the thread and call get_events() there + self.runner.guest_tick_scheduled = False + + def get_events(): + return self.runner.io_manager.get_events(timeout) + + def deliver(events_outcome): + def in_main_thread(): + self.unrolled_run_next_send = events_outcome + self.runner.guest_tick_scheduled = True + self.guest_tick() + + self.run_sync_soon_threadsafe(in_main_thread) + + start_thread_soon(get_events, deliver) + + +@attr.s(eq=False, hash=False, slots=True) class Runner: clock = attr.ib() - instruments = attr.ib() + instruments: Instruments = attr.ib() io_manager = attr.ib() + ki_manager = attr.ib() + strict_exception_groups = attr.ib() # Run-local values, see _local.py _locals = attr.ib(factory=dict) - runq = attr.ib(factory=deque) + runq: deque[Task] = attr.ib(factory=deque) tasks = attr.ib(factory=set) - # {(deadline, id(CancelScope)): CancelScope} - # only contains scopes with non-infinite deadlines that are currently - # attached to at 
least one task - deadlines = attr.ib(factory=SortedDict) + deadlines = attr.ib(factory=Deadlines) init_task = attr.ib(default=None) system_nursery = attr.ib(default=None) @@ -1126,12 +1453,29 @@ class Runner: entry_queue = attr.ib(factory=EntryQueue) trio_token = attr.ib(default=None) + asyncgens = attr.ib(factory=AsyncGenerators) + + # If everything goes idle for this long, we call clock._autojump() + clock_autojump_threshold = attr.ib(default=inf) + + # Guest mode stuff + is_guest = attr.ib(default=False) + guest_tick_scheduled = attr.ib(default=False) + + def force_guest_tick_asap(self): + if self.guest_tick_scheduled: + return + self.guest_tick_scheduled = True + self.io_manager.force_wakeup() def close(self): self.io_manager.close() self.entry_queue.close() - if self.instruments: - self.instrument("after_run") + self.asyncgens.close() + if "after_run" in self.instruments: + self.instruments.call("after_run") + # This is where KI protection gets disabled, so we do it last + self.ki_manager.close() @_public def current_statistics(self): @@ -1150,18 +1494,14 @@ def current_statistics(self): :data:`~math.inf` if there are no pending deadlines. * ``run_sync_soon_queue_size`` (int): The number of unprocessed callbacks queued via - :meth:`trio.hazmat.TrioToken.run_sync_soon`. + :meth:`trio.lowlevel.TrioToken.run_sync_soon`. * ``io_statistics`` (object): Some statistics from Trio's I/O backend. This always has an attribute ``backend`` which is a string naming which operating-system-specific I/O backend is in use; the other attributes vary between backends. 
""" - if self.deadlines: - next_deadline, _ = self.deadlines.keys()[0] - seconds_to_next_deadline = next_deadline - self.current_time() - else: - seconds_to_next_deadline = float("inf") + seconds_to_next_deadline = self.deadlines.next_deadline() - self.current_time() return _RunStatistics( tasks_living=len(self.tasks), tasks_runnable=len(self.runq), @@ -1185,9 +1525,7 @@ def current_time(self): @_public def current_clock(self): - """Returns the current :class:`~trio.abc.Clock`. - - """ + """Returns the current :class:`~trio.abc.Clock`.""" return self.clock @_public @@ -1216,7 +1554,7 @@ def reschedule(self, task, next_send=_NO_SEND): to calling :func:`reschedule` once.) Args: - task (trio.hazmat.Task): the task to be rescheduled. Must be blocked + task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked in a call to :func:`wait_task_rescheduled`. next_send (outcome.Outcome): the value (or error) to return (or raise) from :func:`wait_task_rescheduled`. @@ -1231,12 +1569,15 @@ def reschedule(self, task, next_send=_NO_SEND): task._next_send = next_send task._abort_func = None task.custom_sleep_data = None + if not self.runq and self.is_guest: + self.force_guest_tick_asap() self.runq.append(task) - if self.instruments: - self.instrument("task_scheduled", task) - - def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False): + if "task_scheduled" in self.instruments: + self.instruments.call("task_scheduled", task) + def spawn_impl( + self, async_fn, args, nursery, name, *, system_task=False, context=None + ): ###### # Make sure the nursery is in working order ###### @@ -1250,92 +1591,23 @@ def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False): assert self.init_task is None ###### - # Call the function and get the coroutine object, while giving helpful - # errors for common mistakes. + # Propagate contextvars, and make sure that async_fn can use sniffio. 
###### - - def _return_value_looks_like_wrong_library(value): - # Returned by legacy @asyncio.coroutine functions, which includes - # a surprising proportion of asyncio builtins. - if isinstance(value, collections.abc.Generator): - return True - # The protocol for detecting an asyncio Future-like object - if getattr(value, "_asyncio_future_blocking", None) is not None: - return True - # asyncio.Future doesn't have _asyncio_future_blocking until - # 3.5.3. We don't want to import asyncio, but this janky check - # should work well enough for our purposes. And it also catches - # tornado Futures and twisted Deferreds. By the time we're calling - # this function, we already know something has gone wrong, so a - # heuristic is pretty safe. - if value.__class__.__name__ in ("Future", "Deferred"): - return True - return False - - try: - coro = async_fn(*args) - except TypeError: - # Give good error for: nursery.start_soon(trio.sleep(1)) - if isinstance(async_fn, collections.abc.Coroutine): - raise TypeError( - "Trio was expecting an async function, but instead it got " - "a coroutine object {async_fn!r}\n" - "\n" - "Probably you did something like:\n" - "\n" - " trio.run({async_fn.__name__}(...)) # incorrect!\n" - " nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n" - "\n" - "Instead, you want (notice the parentheses!):\n" - "\n" - " trio.run({async_fn.__name__}, ...) # correct!\n" - " nursery.start_soon({async_fn.__name__}, ...) # correct!" - .format(async_fn=async_fn) - ) from None - - # Give good error for: nursery.start_soon(future) - if _return_value_looks_like_wrong_library(async_fn): - raise TypeError( - "Trio was expecting an async function, but instead it got " - "{!r} – are you trying to use a library written for " - "asyncio/twisted/tornado or similar? That won't work " - "without some sort of compatibility shim." 
- .format(async_fn) - ) from None - - raise - - # We can't check iscoroutinefunction(async_fn), because that will fail - # for things like functools.partial objects wrapping an async - # function. So we have to just call it and then check whether the - # return value is a coroutine object. - if not isinstance(coro, collections.abc.Coroutine): - # Give good error for: nursery.start_soon(func_returning_future) - if _return_value_looks_like_wrong_library(coro): - raise TypeError( - "start_soon got unexpected {!r} – are you trying to use a " - "library written for asyncio/twisted/tornado or similar? " - "That won't work without some sort of compatibility shim." - .format(coro) - ) - - if isasyncgen(coro): - raise TypeError( - "start_soon expected an async function but got an async " - "generator {!r}".format(coro) - ) - - # Give good error for: nursery.start_soon(some_sync_fn) - raise TypeError( - "Trio expected an async function, but {!r} appears to be " - "synchronous".format( - getattr(async_fn, "__qualname__", async_fn) - ) - ) + if context is None: + if system_task: + context = self.system_context.copy() + else: + context = copy_context() + # start_soon() or spawn_system_task() might have been invoked + # from a different async library; make sure the new task + # understands it's Trio-flavored. + context.run(current_async_library_cvar.set, "trio") ###### - # Set up the Task object + # Call the function and get the coroutine object, while giving helpful + # errors for common mistakes. 
###### + coro = context.run(coroutine_or_error, async_fn, *args) if name is None: name = async_fn @@ -1343,31 +1615,23 @@ def _return_value_looks_like_wrong_library(value): name = name.func if not isinstance(name, str): try: - name = "{}.{}".format(name.__module__, name.__qualname__) + name = f"{name.__module__}.{name.__qualname__}" except AttributeError: name = repr(name) - if system_task: - context = self.system_context.copy() - else: - context = copy_context() - if not hasattr(coro, "cr_frame"): # This async function is implemented in C or Cython async def python_wrapper(orig_coro): return await orig_coro coro = python_wrapper(coro) - coro.cr_frame.f_locals.setdefault( - LOCALS_KEY_KI_PROTECTION_ENABLED, system_task - ) + coro.cr_frame.f_locals.setdefault(LOCALS_KEY_KI_PROTECTION_ENABLED, system_task) - task = Task( - coro=coro, - parent_nursery=nursery, - runner=self, - name=name, - context=context, + ###### + # Set up the Task object + ###### + task = Task._create( + coro=coro, parent_nursery=nursery, runner=self, name=name, context=context ) self.tasks.add(task) @@ -1375,8 +1639,8 @@ async def python_wrapper(orig_coro): nursery._children.add(task) task._activate_cancel_status(nursery._cancel_status) - if self.instruments: - self.instrument("task_spawned", task) + if "task_spawned" in self.instruments: + self.instruments.call("task_spawned", task) # Special case: normally next_send should be an Outcome, but for the # very first send we have to send a literal unboxed None. self.reschedule(task, None) @@ -1408,11 +1672,7 @@ def task_exited(self, task, outcome): task._activate_cancel_status(None) self.tasks.remove(task) - if task is self.main_task: - self.main_task_outcome = outcome - self.system_nursery.cancel_scope.cancel() - self.system_nursery._child_finished(task, Value(None)) - elif task is self.init_task: + if task is self.init_task: # If the init task crashed, then something is very wrong and we # let the error propagate. 
(It'll eventually be wrapped in a # TrioInternalError.) @@ -1422,17 +1682,20 @@ def task_exited(self, task, outcome): if self.tasks: # pragma: no cover raise TrioInternalError else: + if task is self.main_task: + self.main_task_outcome = outcome + outcome = Value(None) task._parent_nursery._child_finished(task, outcome) - if self.instruments: - self.instrument("task_exited", task) + if "task_exited" in self.instruments: + self.instruments.call("task_exited", task) ################ # System tasks and init ################ @_public - def spawn_system_task(self, async_fn, *args, name=None): + def spawn_system_task(self, async_fn, *args, name=None, context=None): """Spawn a "system" task. System tasks have a few differences from regular tasks: @@ -1456,6 +1719,15 @@ def spawn_system_task(self, async_fn, *args, name=None): * System tasks do not inherit context variables from their creator. + Towards the end of a call to :meth:`trio.run`, after the main + task and all system tasks have exited, the system nursery + becomes closed. At this point, new calls to + :func:`spawn_system_task` will raise ``RuntimeError("Nursery + is closed to new arrivals")`` instead of creating a system + task. It's possible to encounter this state either in + a ``finally`` block in an async generator, or in a callback + passed to :meth:`TrioToken.run_sync_soon` at the right moment. + Args: async_fn: An async callable. args: Positional arguments for ``async_fn``. If you want to pass @@ -1466,26 +1738,56 @@ def spawn_system_task(self, async_fn, *args, name=None): case is if you're wrapping a function before spawning a new task, you might pass the original function as the ``name=`` to make debugging easier. + context: An optional ``contextvars.Context`` object with context variables + to use for this task. You would normally get a copy of the current + context with ``context = contextvars.copy_context()`` and then you would + pass that ``context`` object here. 
Returns: Task: the newly spawned task """ return self.spawn_impl( - async_fn, args, self.system_nursery, name, system_task=True + async_fn, + args, + self.system_nursery, + name, + system_task=True, + context=context, ) async def init(self, async_fn, args): - async with open_nursery() as system_nursery: - self.system_nursery = system_nursery - try: - self.main_task = self.spawn_impl( - async_fn, args, system_nursery, None - ) - except BaseException as exc: - self.main_task_outcome = Error(exc) - system_nursery.cancel_scope.cancel() - self.entry_queue.spawn() + # run_sync_soon task runs here: + async with open_nursery() as run_sync_soon_nursery: + # All other system tasks run here: + async with open_nursery() as self.system_nursery: + # Only the main task runs here: + async with open_nursery() as main_task_nursery: + try: + self.main_task = self.spawn_impl( + async_fn, args, main_task_nursery, None + ) + except BaseException as exc: + self.main_task_outcome = Error(exc) + return + self.spawn_impl( + self.entry_queue.task, + (), + run_sync_soon_nursery, + "", + system_task=True, + ) + + # Main task is done; start shutting down system tasks + self.system_nursery.cancel_scope.cancel() + + # System nursery is closed; finalize remaining async generators + await self.asyncgens.finalize_remaining(self) + + # There are no more asyncgens, which means no more user-provided + # code except possibly run_sync_soon callbacks. It's finally safe + # to stop the run_sync_soon task and exit run(). 
+ run_sync_soon_nursery.cancel_scope.cancel() ################ # Outside context problems @@ -1498,7 +1800,7 @@ def current_trio_token(self): """ if self.trio_token is None: - self.trio_token = TrioToken(self.entry_queue) + self.trio_token = TrioToken._create(self.entry_queue) return self.trio_token ################ @@ -1541,7 +1843,7 @@ def _deliver_ki_cb(self): waiting_for_idle = attr.ib(factory=SortedDict) @_public - async def wait_all_tasks_blocked(self, cushion=0.0, tiebreaker=0): + async def wait_all_tasks_blocked(self, cushion=0.0): """Block until there are no runnable tasks. This is useful in testing code when you want to give other tasks a @@ -1559,9 +1861,7 @@ async def wait_all_tasks_blocked(self, cushion=0.0, tiebreaker=0): then the one with the shortest ``cushion`` is the one woken (and this task becoming unblocked resets the timers for the remaining tasks). If there are multiple tasks that have exactly the same - ``cushion``, then the one with the lowest ``tiebreaker`` value is - woken first. And if there are multiple tasks with the same ``cushion`` - and the same ``tiebreaker``, then all are woken. + ``cushion``, then all are woken. 
You should also consider :class:`trio.testing.Sequencer`, which provides a more explicit way to control execution ordering within a @@ -1584,25 +1884,25 @@ async def test_lock_fairness(): nursery.start_soon(lock_taker, lock) # child hasn't run yet, we have the lock assert lock.locked() - assert lock._owner is trio.hazmat.current_task() + assert lock._owner is trio.lowlevel.current_task() await trio.testing.wait_all_tasks_blocked() # now the child has run and is blocked on lock.acquire(), we # still have the lock assert lock.locked() - assert lock._owner is trio.hazmat.current_task() + assert lock._owner is trio.lowlevel.current_task() lock.release() try: # The child has a prior claim, so we can't have it lock.acquire_nowait() except trio.WouldBlock: - assert lock._owner is not trio.hazmat.current_task() + assert lock._owner is not trio.lowlevel.current_task() print("PASS") else: print("FAIL") """ task = current_task() - key = (cushion, tiebreaker, id(task)) + key = (cushion, id(task)) self.waiting_for_idle[key] = task def abort(_): @@ -1611,66 +1911,110 @@ def abort(_): await wait_task_rescheduled(abort) - ################ - # Instrumentation - ################ - - def instrument(self, method_name, *args): - if not self.instruments: - return - - for instrument in list(self.instruments): - try: - method = getattr(instrument, method_name) - except AttributeError: - continue - try: - method(*args) - except: - self.instruments.remove(instrument) - INSTRUMENT_LOGGER.exception( - "Exception raised when calling %r on instrument %r. " - "Instrument has been disabled.", method_name, instrument - ) - - @_public - def add_instrument(self, instrument): - """Start instrumenting the current run loop with the given instrument. - - Args: - instrument (trio.abc.Instrument): The instrument to activate. - - If ``instrument`` is already active, does nothing. 
- - """ - if instrument not in self.instruments: - self.instruments.append(instrument) - - @_public - def remove_instrument(self, instrument): - """Stop instrumenting the current run loop with the given instrument. - Args: - instrument (trio.abc.Instrument): The instrument to de-activate. +################################################################ +# run +################################################################ +# +# Trio's core task scheduler and coroutine runner is in 'unrolled_run'. It's +# called that because it has an unusual feature: it's actually a generator. +# Whenever it needs to fetch IO events from the OS, it yields, and waits for +# its caller to send the IO events back in. So the loop is "unrolled" into a +# sequence of generator send() calls. +# +# The reason for this unusual design is to support two different modes of +# operation, where the IO is handled differently. +# +# In normal mode using trio.run, the scheduler and IO run in the same thread: +# +# Main thread: +# +# +---------------------------+ +# | Run tasks | +# | (unrolled_run) | +# +---------------------------+ +# | Block waiting for I/O | +# | (io_manager.get_events) | +# +---------------------------+ +# | Run tasks | +# | (unrolled_run) | +# +---------------------------+ +# | Block waiting for I/O | +# | (io_manager.get_events) | +# +---------------------------+ +# : +# +# +# In guest mode using trio.lowlevel.start_guest_run, the scheduler runs on the +# main thread as a host loop callback, but blocking for IO gets pushed into a +# worker thread: +# +# Main thread executing host loop: Trio I/O thread: +# +# +---------------------------+ +# | Run Trio tasks | +# | (unrolled_run) | +# +---------------------------+ --------------+ +# v +# +---------------------------+ +----------------------------+ +# | Host loop does whatever | | Block waiting for Trio I/O | +# | it wants | | (io_manager.get_events) | +# +---------------------------+ +----------------------------+ +# | +# 
+---------------------------+ <-------------+ +# | Run Trio tasks | +# | (unrolled_run) | +# +---------------------------+ --------------+ +# v +# +---------------------------+ +----------------------------+ +# | Host loop does whatever | | Block waiting for Trio I/O | +# | it wants | | (io_manager.get_events) | +# +---------------------------+ +----------------------------+ +# : : +# +# Most of Trio's internals don't need to care about this difference. The main +# complication it creates is that in guest mode, we might need to wake up not +# just due to OS-reported IO events, but also because of code running on the +# host loop calling reschedule() or changing task deadlines. Search for +# 'is_guest' to see the special cases we need to handle this. + + +def setup_runner( + clock, + instruments, + restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups, +): + """Create a Runner object and install it as the GLOBAL_RUN_CONTEXT.""" + # It wouldn't be *hard* to support nested calls to run(), but I can't + # think of a single good reason for it, so let's be conservative for + # now: + if hasattr(GLOBAL_RUN_CONTEXT, "runner"): + raise RuntimeError("Attempted to call run() from inside a run()") - Raises: - KeyError: if the instrument is not currently active. This could - occur either because you never added it, or because you added it - and then it raised an unhandled exception and was automatically - deactivated. + if clock is None: + clock = SystemClock() + instruments = Instruments(instruments) + io_manager = TheIOManager() + system_context = copy_context() + ki_manager = KIManager() - """ - # We're moving 'instruments' to being a set, so raise KeyError like - # set.remove does. 
- try: - self.instruments.remove(instrument) - except ValueError as exc: - raise KeyError(*exc.args) + runner = Runner( + clock=clock, + instruments=instruments, + io_manager=io_manager, + system_context=system_context, + ki_manager=ki_manager, + strict_exception_groups=strict_exception_groups, + ) + runner.asyncgens.install_hooks(runner) + # This is where KI protection gets enabled, so we want to do it early - in + # particular before we start modifying global state like GLOBAL_RUN_CONTEXT + ki_manager.install(runner.deliver_ki, restrict_keyboard_interrupt_to_checkpoints) -################################################################ -# run -################################################################ + GLOBAL_RUN_CONTEXT.runner = runner + return runner def run( @@ -1678,7 +2022,8 @@ def run( *args, clock=None, instruments=(), - restrict_keyboard_interrupt_to_checkpoints=False + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = False, ): """Run a Trio-flavored async function, and return the result. @@ -1735,6 +2080,10 @@ def run( main thread (this is a Python limitation), or if you use :func:`open_signal_receiver` to catch SIGINT. + strict_exception_groups (bool): If true, nurseries will always wrap even a single + raised exception in an exception group. This can be overridden on the level of + individual nurseries. This will eventually become the default behavior. + Returns: Whatever ``async_fn`` returns. 
@@ -1750,226 +2099,361 @@ def run( __tracebackhide__ = True - # Do error-checking up front, before we enter the TrioInternalError - # try/catch - # - # It wouldn't be *hard* to support nested calls to run(), but I can't - # think of a single good reason for it, so let's be conservative for - # now: - if hasattr(GLOBAL_RUN_CONTEXT, "runner"): - raise RuntimeError("Attempted to call run() from inside a run()") - - if clock is None: - clock = SystemClock() - instruments = list(instruments) - io_manager = TheIOManager() - system_context = copy_context() - system_context.run(current_async_library_cvar.set, "trio") - runner = Runner( - clock=clock, - instruments=instruments, - io_manager=io_manager, - system_context=system_context, + runner = setup_runner( + clock, + instruments, + restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups, ) - GLOBAL_RUN_CONTEXT.runner = runner - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - # KI handling goes outside the core try/except/finally to avoid a window - # where KeyboardInterrupt would be allowed and converted into an - # TrioInternalError: - try: - with ki_manager( - runner.deliver_ki, restrict_keyboard_interrupt_to_checkpoints - ): - try: - with closing(runner): - with runner.entry_queue.wakeup.wakeup_on_signals(): - # The main reason this is split off into its own - # function is just to get rid of this extra - # indentation. - run_impl(runner, async_fn, args) - except TrioInternalError: - raise - except BaseException as exc: - raise TrioInternalError( - "internal error in Trio - please file a bug!" - ) from exc - finally: - GLOBAL_RUN_CONTEXT.__dict__.clear() - # Inlined copy of runner.main_task_outcome.unwrap() to avoid - # cluttering every single Trio traceback with an extra frame. 
- if type(runner.main_task_outcome) is Value: - return runner.main_task_outcome.value - else: - raise runner.main_task_outcome.error - finally: - # To guarantee that we never swallow a KeyboardInterrupt, we have to - # check for pending ones once more after leaving the context manager: - if runner.ki_pending: - # Implicitly chains with any exception from outcome.unwrap(): - raise KeyboardInterrupt + gen = unrolled_run(runner, async_fn, args) + next_send = None + while True: + try: + timeout = gen.send(next_send) + except StopIteration: + break + next_send = runner.io_manager.get_events(timeout) + # Inlined copy of runner.main_task_outcome.unwrap() to avoid + # cluttering every single Trio traceback with an extra frame. + if isinstance(runner.main_task_outcome, Value): + return runner.main_task_outcome.value + else: + raise runner.main_task_outcome.error + + +def start_guest_run( + async_fn, + *args, + run_sync_soon_threadsafe, + done_callback, + run_sync_soon_not_threadsafe=None, + host_uses_signal_set_wakeup_fd: bool = False, + clock=None, + instruments=(), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = False, +): + """Start a "guest" run of Trio on top of some other "host" event loop. + + Each host loop can only have one guest run at a time. + + You should always let the Trio run finish before stopping the host loop; + if not, it may leave Trio's internal data structures in an inconsistent + state. You might be able to get away with it if you immediately exit the + program, but it's safest not to go there in the first place. + + Generally, the best way to do this is wrap this in a function that starts + the host loop and then immediately starts the guest run, and then shuts + down the host when the guest run completes. + + Args: + + run_sync_soon_threadsafe: An arbitrary callable, which will be passed a + function as its sole argument:: + + def my_run_sync_soon_threadsafe(fn): + ... 
+ + This callable should schedule ``fn()`` to be run by the host on its + next pass through its loop. **Must support being called from + arbitrary threads.** + + done_callback: An arbitrary callable:: + + def my_done_callback(run_outcome): + ... + + When the Trio run has finished, Trio will invoke this callback to let + you know. The argument is an `outcome.Outcome`, reporting what would + have been returned or raised by `trio.run`. This function can do + anything you want, but commonly you'll want it to shut down the + host loop, unwrap the outcome, etc. + + run_sync_soon_not_threadsafe: Like ``run_sync_soon_threadsafe``, but + will only be called from inside the host loop's main thread. + Optional, but if your host loop allows you to implement this more + efficiently than ``run_sync_soon_threadsafe`` then passing it will + make things a bit faster. + + host_uses_signal_set_wakeup_fd (bool): Pass `True` if your host loop + uses `signal.set_wakeup_fd`, and `False` otherwise. For more details, + see :ref:`guest-run-implementation`. + + For the meaning of other arguments, see `trio.run`. + + """ + runner = setup_runner( + clock, + instruments, + restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups, + ) + runner.is_guest = True + runner.guest_tick_scheduled = True + + if run_sync_soon_not_threadsafe is None: + run_sync_soon_not_threadsafe = run_sync_soon_threadsafe + + guest_state = GuestState( + runner=runner, + run_sync_soon_threadsafe=run_sync_soon_threadsafe, + run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, + done_callback=done_callback, + unrolled_run_gen=unrolled_run( + runner, + async_fn, + args, + host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, + ), + ) + run_sync_soon_not_threadsafe(guest_state.guest_tick) # 24 hours is arbitrary, but it avoids issues like people setting timeouts of # 10**20 and then getting integer overflows in the underlying system calls. 
-_MAX_TIMEOUT = 24 * 60 * 60 +_MAX_TIMEOUT: FinalT = 24 * 60 * 60 -def run_impl(runner, async_fn, args): +# Weird quirk: this is written as a generator in order to support "guest +# mode", where our core event loop gets unrolled into a series of callbacks on +# the host loop. If you're doing a regular trio.run then this gets run +# straight through. +def unrolled_run( + runner: Runner, + async_fn, + args, + host_uses_signal_set_wakeup_fd: bool = False, +): + locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True - if runner.instruments: - runner.instrument("before_run") - runner.clock.start_clock() - runner.init_task = runner.spawn_impl( - runner.init, - (async_fn, args), - None, - "", - system_task=True, - ) + try: + if not host_uses_signal_set_wakeup_fd: + runner.entry_queue.wakeup.wakeup_on_signals() + + if "before_run" in runner.instruments: + runner.instruments.call("before_run") + runner.clock.start_clock() + runner.init_task = runner.spawn_impl( + runner.init, (async_fn, args), None, "", system_task=True + ) - # You know how people talk about "event loops"? 
This 'while' loop right - # here is our event loop: - while runner.tasks: - if runner.runq: - timeout = 0 - elif runner.deadlines: - deadline, _ = runner.deadlines.keys()[0] - timeout = runner.clock.deadline_to_sleep_time(deadline) - else: - timeout = _MAX_TIMEOUT - timeout = min(max(0, timeout), _MAX_TIMEOUT) - - idle_primed = False - if runner.waiting_for_idle: - cushion, tiebreaker, _ = runner.waiting_for_idle.keys()[0] - if cushion < timeout: - timeout = cushion - idle_primed = True - - if runner.instruments: - runner.instrument("before_io_wait", timeout) - - runner.io_manager.handle_io(timeout) - - if runner.instruments: - runner.instrument("after_io_wait", timeout) - - # Process cancellations due to deadline expiry - now = runner.clock.current_time() - while runner.deadlines: - (deadline, _), cancel_scope = runner.deadlines.peekitem(0) - if deadline <= now: - # This removes the given scope from runner.deadlines: - cancel_scope.cancel() - idle_primed = False + # You know how people talk about "event loops"? This 'while' loop right + # here is our event loop: + while runner.tasks: + if runner.runq: + timeout: float = 0 else: - break - - if not runner.runq and idle_primed: - while runner.waiting_for_idle: - key, task = runner.waiting_for_idle.peekitem(0) - if key[:2] == (cushion, tiebreaker): - del runner.waiting_for_idle[key] - runner.reschedule(task) + deadline = runner.deadlines.next_deadline() + timeout = runner.clock.deadline_to_sleep_time(deadline) + timeout = min(max(0, timeout), _MAX_TIMEOUT) + + idle_primed = None + if runner.waiting_for_idle: + cushion, _ = runner.waiting_for_idle.keys()[0] + if cushion < timeout: + timeout = cushion + idle_primed = IdlePrimedTypes.WAITING_FOR_IDLE + # We use 'elif' here because if there are tasks in + # wait_all_tasks_blocked, then those tasks will wake up without + # jumping the clock, so we don't need to autojump. 
+ elif runner.clock_autojump_threshold < timeout: + timeout = runner.clock_autojump_threshold + idle_primed = IdlePrimedTypes.AUTOJUMP_CLOCK + + if "before_io_wait" in runner.instruments: + runner.instruments.call("before_io_wait", timeout) + + # Driver will call io_manager.get_events(timeout) and pass it back + # in through the yield + events = yield timeout + runner.io_manager.process_events(events) + + if "after_io_wait" in runner.instruments: + runner.instruments.call("after_io_wait", timeout) + + # Process cancellations due to deadline expiry + now = runner.clock.current_time() + if runner.deadlines.expire(now): + idle_primed = None + + # idle_primed != None means: if the IO wait hit the timeout, and + # still nothing is happening, then we should start waking up + # wait_all_tasks_blocked tasks or autojump the clock. But there + # are some subtleties in defining "nothing is happening". + # + # 'not runner.runq' means that no tasks are currently runnable. + # 'not events' means that the last IO wait call hit its full + # timeout. These are very similar, and if idle_primed != None and + # we're running in regular mode then they always go together. But, + # in *guest* mode, they can happen independently, even when + # idle_primed=True: + # + # - runner.runq=empty and events=True: the host loop adjusted a + # deadline and that forced an IO wakeup before the timeout expired, + # even though no actual tasks were scheduled. + # + # - runner.runq=nonempty and events=False: the IO wait hit its + # timeout, but then some code in the host thread rescheduled a task + # before we got here. + # + # So we need to check both. 
+ if idle_primed is not None and not runner.runq and not events: + if idle_primed is IdlePrimedTypes.WAITING_FOR_IDLE: + while runner.waiting_for_idle: + key, task = runner.waiting_for_idle.peekitem(0) + if key[0] == cushion: + del runner.waiting_for_idle[key] + runner.reschedule(task) + else: + break else: - break - - # Process all runnable tasks, but only the ones that are already - # runnable now. Anything that becomes runnable during this cycle needs - # to wait until the next pass. This avoids various starvation issues - # by ensuring that there's never an unbounded delay between successive - # checks for I/O. - # - # Also, we randomize the order of each batch to avoid assumptions - # about scheduling order sneaking in. In the long run, I suspect we'll - # either (a) use strict FIFO ordering and document that for - # predictability/determinism, or (b) implement a more sophisticated - # scheduler (e.g. some variant of fair queueing), for better behavior - # under load. For now, this is the worst of both worlds - but it keeps - # our options open. (If we do decide to go all in on deterministic - # scheduling, then there are other things that will probably need to - # change too, like the deadlines tie-breaker and the non-deterministic - # ordering of task._notify_queues.) - batch = list(runner.runq) - if _ALLOW_DETERMINISTIC_SCHEDULING: - # We're running under Hypothesis, and pytest-trio has patched this - # in to make the scheduler deterministic and avoid flaky tests. - # It's not worth the (small) performance cost in normal operation, - # since we'll shuffle the list and _r is only seeded for tests. 
- batch.sort(key=lambda t: t._counter) - runner.runq.clear() - _r.shuffle(batch) - while batch: - task = batch.pop() - GLOBAL_RUN_CONTEXT.task = task - - if runner.instruments: - runner.instrument("before_task_step", task) - - next_send_fn = task._next_send_fn - next_send = task._next_send - task._next_send_fn = task._next_send = None - final_outcome = None - try: - # We used to unwrap the Outcome object here and send/throw its - # contents in directly, but it turns out that .throw() is - # buggy, at least on CPython 3.6 and earlier: - # https://bugs.python.org/issue29587 - # https://bugs.python.org/issue29590 - # So now we send in the Outcome object and unwrap it on the - # other side. - msg = task.context.run(next_send_fn, next_send) - except StopIteration as stop_iteration: - final_outcome = Value(stop_iteration.value) - except BaseException as task_exc: - # Store for later, removing uninteresting top frames: 1 frame - # we always remove, because it's this function catching it, - # and then in addition we remove however many more Context.run - # adds. - tb = task_exc.__traceback__.tb_next - for _ in range(CONTEXT_RUN_TB_FRAMES): - tb = tb.tb_next - final_outcome = Error(task_exc.with_traceback(tb)) - - if final_outcome is not None: - # We can't call this directly inside the except: blocks above, - # because then the exceptions end up attaching themselves to - # other exceptions as __context__ in unwanted ways. - runner.task_exited(task, final_outcome) + assert idle_primed is IdlePrimedTypes.AUTOJUMP_CLOCK + runner.clock._autojump() + + # Process all runnable tasks, but only the ones that are already + # runnable now. Anything that becomes runnable during this cycle + # needs to wait until the next pass. This avoids various + # starvation issues by ensuring that there's never an unbounded + # delay between successive checks for I/O. + # + # Also, we randomize the order of each batch to avoid assumptions + # about scheduling order sneaking in. 
In the long run, I suspect + # we'll either (a) use strict FIFO ordering and document that for + # predictability/determinism, or (b) implement a more + # sophisticated scheduler (e.g. some variant of fair queueing), + # for better behavior under load. For now, this is the worst of + # both worlds - but it keeps our options open. (If we do decide to + # go all in on deterministic scheduling, then there are other + # things that will probably need to change too, like the deadlines + # tie-breaker and the non-deterministic ordering of + # task._notify_queues.) + batch = list(runner.runq) + runner.runq.clear() + if _ALLOW_DETERMINISTIC_SCHEDULING: + # We're running under Hypothesis, and pytest-trio has patched + # this in to make the scheduler deterministic and avoid flaky + # tests. It's not worth the (small) performance cost in normal + # operation, since we'll shuffle the list and _r is only + # seeded for tests. + batch.sort(key=lambda t: t._counter) + _r.shuffle(batch) else: - task._schedule_points += 1 - if msg is CancelShieldedCheckpoint: - runner.reschedule(task) - elif type(msg) is WaitTaskRescheduled: - task._cancel_points += 1 - task._abort_func = msg.abort_func - # KI is "outside" all cancel scopes, so check for it - # before checking for regular cancellation: - if runner.ki_pending and task is runner.main_task: - task._attempt_delivery_of_pending_ki() - task._attempt_delivery_of_any_pending_cancel() - elif type(msg) is PermanentlyDetachCoroutineObject: - # Pretend the task just exited with the given outcome - runner.task_exited(task, msg.final_outcome) + # 50% chance of reversing the batch, this way each task + # can appear before/after any other task. 
+ if _r.random() < 0.5: + batch.reverse() + while batch: + task = batch.pop() + GLOBAL_RUN_CONTEXT.task = task + + if "before_task_step" in runner.instruments: + runner.instruments.call("before_task_step", task) + + next_send_fn = task._next_send_fn + next_send = task._next_send + task._next_send_fn = task._next_send = None + final_outcome = None + try: + # We used to unwrap the Outcome object here and send/throw + # its contents in directly, but it turns out that .throw() + # is buggy, at least before CPython 3.9: + # https://bugs.python.org/issue29587 + # https://bugs.python.org/issue29590 + # So now we send in the Outcome object and unwrap it on the + # other side. + msg = task.context.run(next_send_fn, next_send) + except StopIteration as stop_iteration: + final_outcome = Value(stop_iteration.value) + except BaseException as task_exc: + # Store for later, removing uninteresting top frames: 1 + # frame we always remove, because it's this function + # catching it, and then in addition we remove however many + # more Context.run adds. + tb = task_exc.__traceback__ + for _ in range(1 + CONTEXT_RUN_TB_FRAMES): + if tb is None: + break + tb = tb.tb_next + final_outcome = Error(task_exc.with_traceback(tb)) + # Remove local refs so that e.g. cancelled coroutine locals + # are not kept alive by this frame until another exception + # comes along. + del tb + + if final_outcome is not None: + # We can't call this directly inside the except: blocks + # above, because then the exceptions end up attaching + # themselves to other exceptions as __context__ in + # unwanted ways. + runner.task_exited(task, final_outcome) + # final_outcome may contain a traceback ref. It's not as + # crucial compared to the above, but this will allow more + # prompt release of resources in coroutine locals. + final_outcome = None else: - exc = TypeError( - "trio.run received unrecognized yield message {!r}. " - "Are you trying to use a library written for some " - "other framework like asyncio? 
That won't work " - "without some kind of compatibility shim.".format(msg) - ) - # The foreign library probably doesn't adhere to our - # protocol of unwrapping whatever outcome gets sent in. - # Instead, we'll arrange to throw `exc` in directly, - # which works for at least asyncio and curio. - runner.reschedule(task, exc) - task._next_send_fn = task.coro.throw - - if runner.instruments: - runner.instrument("after_task_step", task) - del GLOBAL_RUN_CONTEXT.task + task._schedule_points += 1 + if msg is CancelShieldedCheckpoint: + runner.reschedule(task) + elif type(msg) is WaitTaskRescheduled: + task._cancel_points += 1 + task._abort_func = msg.abort_func + # KI is "outside" all cancel scopes, so check for it + # before checking for regular cancellation: + if runner.ki_pending and task is runner.main_task: + task._attempt_delivery_of_pending_ki() + task._attempt_delivery_of_any_pending_cancel() + elif type(msg) is PermanentlyDetachCoroutineObject: + # Pretend the task just exited with the given outcome + runner.task_exited(task, msg.final_outcome) + else: + exc = TypeError( + "trio.run received unrecognized yield message {!r}. " + "Are you trying to use a library written for some " + "other framework like asyncio? That won't work " + "without some kind of compatibility shim.".format(msg) + ) + # The foreign library probably doesn't adhere to our + # protocol of unwrapping whatever outcome gets sent in. + # Instead, we'll arrange to throw `exc` in directly, + # which works for at least asyncio and curio. 
+ runner.reschedule(task, exc) + task._next_send_fn = task.coro.throw + # prevent long-lived reference + # TODO: develop test for this deletion + del msg + + if "after_task_step" in runner.instruments: + runner.instruments.call("after_task_step", task) + del GLOBAL_RUN_CONTEXT.task + # prevent long-lived references + # TODO: develop test for these deletions + del task, next_send, next_send_fn + + except GeneratorExit: + # The run-loop generator has been garbage collected without finishing + warnings.warn( + RuntimeWarning( + "Trio guest run got abandoned without properly finishing... " + "weird stuff might happen" + ) + ) + except TrioInternalError: + raise + except BaseException as exc: + raise TrioInternalError("internal error in Trio - please file a bug!") from exc + finally: + GLOBAL_RUN_CONTEXT.__dict__.clear() + runner.close() + # Have to do this after runner.close() has disabled KI protection, + # because otherwise there's a race where ki_pending could get set + # after we check it. + if runner.ki_pending: + ki = KeyboardInterrupt() + if isinstance(runner.main_task_outcome, Error): + ki.__context__ = runner.main_task_outcome.error + runner.main_task_outcome = Error(ki) ################################################################ @@ -1978,17 +2462,17 @@ def run_impl(runner, async_fn, args): class _TaskStatusIgnored: - def __repr__(self): + def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value=None): + def started(self, value: object = None) -> None: pass -TASK_STATUS_IGNORED = _TaskStatusIgnored() +TASK_STATUS_IGNORED: FinalT = _TaskStatusIgnored() -def current_task(): +def current_task() -> Task: """Return the :class:`Task` object representing the current task. Returns: @@ -2002,7 +2486,7 @@ def current_task(): raise RuntimeError("must be called from async context") from None -def current_effective_deadline(): +def current_effective_deadline() -> float: """Returns the current effective deadline for the current task. 
This function examines all the cancellation scopes that are currently in @@ -2029,7 +2513,7 @@ def current_effective_deadline(): return current_task()._cancel_status.effective_deadline() -async def checkpoint(): +async def checkpoint() -> None: """A pure :ref:`checkpoint `. This checks for cancellation and allows other tasks to be scheduled, @@ -2043,18 +2527,27 @@ async def checkpoint(): :func:`checkpoint`.) """ - with CancelScope(deadline=-inf): - await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) + # The scheduler is what checks timeouts and converts them into + # cancellations. So by doing the schedule point first, we ensure that the + # cancel point has the most up-to-date info. + await cancel_shielded_checkpoint() + task = current_task() + task._cancel_points += 1 + if task._cancel_status.effectively_cancelled or ( + task is task._runner.main_task and task._runner.ki_pending + ): + with CancelScope(deadline=-inf): + await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -async def checkpoint_if_cancelled(): +async def checkpoint_if_cancelled() -> None: """Issue a :ref:`checkpoint ` if the calling context has been cancelled. Equivalent to (but potentially more efficient than):: - if trio.current_deadline() == -inf: - await trio.hazmat.checkpoint() + if trio.current_effective_deadline() == -inf: + await trio.lowlevel.checkpoint() This is either a no-op, or else it allow other tasks to be scheduled and then raises :exc:`trio.Cancelled`. 
@@ -2063,25 +2556,25 @@ async def checkpoint_if_cancelled(): """ task = current_task() - if ( - task._cancel_status.effectively_cancelled or - (task is task._runner.main_task and task._runner.ki_pending) + if task._cancel_status.effectively_cancelled or ( + task is task._runner.main_task and task._runner.ki_pending ): await _core.checkpoint() assert False # pragma: no cover task._cancel_points += 1 -if os.name == "nt": - from ._io_windows import WindowsIOManager as TheIOManager +if sys.platform == "win32": from ._generated_io_windows import * -elif hasattr(select, "epoll"): - from ._io_epoll import EpollIOManager as TheIOManager + from ._io_windows import WindowsIOManager as TheIOManager +elif sys.platform == "linux" or (not TYPE_CHECKING and hasattr(select, "epoll")): from ._generated_io_epoll import * -elif hasattr(select, "kqueue"): - from ._io_kqueue import KqueueIOManager as TheIOManager + from ._io_epoll import EpollIOManager as TheIOManager +elif TYPE_CHECKING or hasattr(select, "kqueue"): from ._generated_io_kqueue import * + from ._io_kqueue import KqueueIOManager as TheIOManager else: # pragma: no cover raise NotImplementedError("unsupported platform") +from ._generated_instrumentation import * from ._generated_run import * diff --git a/trio/_core/tests/__init__.py b/trio/_core/_tests/__init__.py similarity index 100% rename from trio/_core/tests/__init__.py rename to trio/_core/_tests/__init__.py diff --git a/trio/_core/_tests/test_asyncgen.py b/trio/_core/_tests/test_asyncgen.py new file mode 100644 index 0000000000..f72d5c6859 --- /dev/null +++ b/trio/_core/_tests/test_asyncgen.py @@ -0,0 +1,322 @@ +import contextlib +import sys +import weakref +from math import inf + +import pytest + +from ... 
import _core +from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="no aclosing() in stdlib<3.10") +def test_asyncgen_basics(): + collected = [] + + async def example(cause): + try: + try: + yield 42 + except GeneratorExit: + pass + await _core.checkpoint() + except _core.Cancelled: + assert "exhausted" not in cause + task_name = _core.current_task().name + assert cause in task_name or task_name == "" + assert _core.current_effective_deadline() == -inf + with pytest.raises(_core.Cancelled): + await _core.checkpoint() + collected.append(cause) + else: + assert "async_main" in _core.current_task().name + assert "exhausted" in cause + assert _core.current_effective_deadline() == inf + await _core.checkpoint() + collected.append(cause) + + saved = [] + + async def async_main(): + # GC'ed before exhausted + with pytest.warns( + ResourceWarning, match="Async generator.*collected before.*exhausted" + ): + assert 42 == await example("abandoned").asend(None) + gc_collect_harder() + await _core.wait_all_tasks_blocked() + assert collected.pop() == "abandoned" + + # aclosing() ensures it's cleaned up at point of use + async with contextlib.aclosing(example("exhausted 1")) as aiter: + assert 42 == await aiter.asend(None) + assert collected.pop() == "exhausted 1" + + # Also fine if you exhaust it at point of use + async for val in example("exhausted 2"): + assert val == 42 + assert collected.pop() == "exhausted 2" + + gc_collect_harder() + + # No problems saving the geniter when using either of these patterns + async with contextlib.aclosing(example("exhausted 3")) as aiter: + saved.append(aiter) + assert 42 == await aiter.asend(None) + assert collected.pop() == "exhausted 3" + + # Also fine if you exhaust it at point of use + saved.append(example("exhausted 4")) + async for val in saved[-1]: + assert val == 42 + assert collected.pop() == "exhausted 4" + + # Leave one 
referenced-but-unexhausted and make sure it gets cleaned up + if buggy_pypy_asyncgens: + collected.append("outlived run") + else: + saved.append(example("outlived run")) + assert 42 == await saved[-1].asend(None) + assert collected == [] + + _core.run(async_main) + assert collected.pop() == "outlived run" + for agen in saved: + assert agen.ag_frame is None # all should now be exhausted + + +async def test_asyncgen_throws_during_finalization(caplog): + record = [] + + async def agen(): + try: + yield 1 + finally: + await _core.cancel_shielded_checkpoint() + record.append("crashing") + raise ValueError("oops") + + with restore_unraisablehook(): + await agen().asend(None) + gc_collect_harder() + await _core.wait_all_tasks_blocked() + assert record == ["crashing"] + exc_type, exc_value, exc_traceback = caplog.records[0].exc_info + assert exc_type is ValueError + assert str(exc_value) == "oops" + assert "during finalization of async generator" in caplog.records[0].message + + +@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") +def test_firstiter_after_closing(): + saved = [] + record = [] + + async def funky_agen(): + try: + yield 1 + except GeneratorExit: + record.append("cleanup 1") + raise + try: + yield 2 + finally: + record.append("cleanup 2") + await funky_agen().asend(None) + + async def async_main(): + aiter = funky_agen() + saved.append(aiter) + assert 1 == await aiter.asend(None) + assert 2 == await aiter.asend(None) + + _core.run(async_main) + assert record == ["cleanup 2", "cleanup 1"] + + +@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") +def test_interdependent_asyncgen_cleanup_order(): + saved = [] + record = [] + + async def innermost(): + try: + yield 1 + finally: + await _core.cancel_shielded_checkpoint() + record.append("innermost") + + async def agen(label, inner): + try: + yield await inner.asend(None) + finally: + # Either `inner` has already been cleaned up, or + # we're about to exhaust it. 
Either way, we wind + # up with `record` containing the labels in + # innermost-to-outermost order. + with pytest.raises(StopAsyncIteration): + await inner.asend(None) + record.append(label) + + async def async_main(): + # This makes a chain of 101 interdependent asyncgens: + # agen(99)'s cleanup will iterate agen(98)'s will iterate + # ... agen(0)'s will iterate innermost()'s + ag_chain = innermost() + for idx in range(100): + ag_chain = agen(idx, ag_chain) + saved.append(ag_chain) + assert 1 == await ag_chain.asend(None) + assert record == [] + + _core.run(async_main) + assert record == ["innermost"] + list(range(100)) + + +@restore_unraisablehook() +def test_last_minute_gc_edge_case(): + saved = [] + record = [] + needs_retry = True + + async def agen(): + try: + yield 1 + finally: + record.append("cleaned up") + + def collect_at_opportune_moment(token): + runner = _core._run.GLOBAL_RUN_CONTEXT.runner + if runner.system_nursery._closed and isinstance( + runner.asyncgens.alive, weakref.WeakSet + ): + saved.clear() + record.append("final collection") + gc_collect_harder() + record.append("done") + else: + try: + token.run_sync_soon(collect_at_opportune_moment, token) + except _core.RunFinishedError: # pragma: no cover + nonlocal needs_retry + needs_retry = True + + async def async_main(): + token = _core.current_trio_token() + token.run_sync_soon(collect_at_opportune_moment, token) + saved.append(agen()) + await saved[-1].asend(None) + + # Actually running into the edge case requires that the run_sync_soon task + # execute in between the system nursery's closure and the strong-ification + # of runner.asyncgens. There's about a 25% chance that it doesn't + # (if the run_sync_soon task runs before init on one tick and after init + # on the next tick); if we try enough times, we can make the chance of + # failure as small as we want. 
+ for attempt in range(50): + needs_retry = False + del record[:] + del saved[:] + _core.run(async_main) + if needs_retry: # pragma: no cover + if not buggy_pypy_asyncgens: + assert record == ["cleaned up"] + else: + assert record == ["final collection", "done", "cleaned up"] + break + else: # pragma: no cover + pytest.fail( + "Didn't manage to hit the trailing_finalizer_asyncgens case " + f"despite trying {attempt} times" + ) + + +async def step_outside_async_context(aiter): + # abort_fns run outside of task context, at least if they're + # triggered by a deadline expiry rather than a direct + # cancellation. Thus, an asyncgen first iterated inside one + # will appear non-Trio, and since no other hooks were installed, + # will use the last-ditch fallback handling (that tries to mimic + # CPython's behavior with no hooks). + # + # NB: the strangeness with aiter being an attribute of abort_fn is + # to make it as easy as possible to ensure we don't hang onto a + # reference to aiter inside the guts of the run loop. 
+ def abort_fn(_): + with pytest.raises(StopIteration, match="42"): + abort_fn.aiter.asend(None).send(None) + del abort_fn.aiter + return _core.Abort.SUCCEEDED + + abort_fn.aiter = aiter + + async with _core.open_nursery() as nursery: + nursery.start_soon(_core.wait_task_rescheduled, abort_fn) + await _core.wait_all_tasks_blocked() + nursery.cancel_scope.deadline = _core.current_time() + + +@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") +async def test_fallback_when_no_hook_claims_it(capsys): + async def well_behaved(): + yield 42 + + async def yields_after_yield(): + with pytest.raises(GeneratorExit): + yield 42 + yield 100 + + async def awaits_after_yield(): + with pytest.raises(GeneratorExit): + yield 42 + await _core.cancel_shielded_checkpoint() + + with restore_unraisablehook(): + await step_outside_async_context(well_behaved()) + gc_collect_harder() + assert capsys.readouterr().err == "" + + await step_outside_async_context(yields_after_yield()) + gc_collect_harder() + assert "ignored GeneratorExit" in capsys.readouterr().err + + await step_outside_async_context(awaits_after_yield()) + gc_collect_harder() + assert "awaited something during finalization" in capsys.readouterr().err + + +@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") +def test_delegation_to_existing_hooks(): + record = [] + + def my_firstiter(agen): + record.append("firstiter " + agen.ag_frame.f_locals["arg"]) + + def my_finalizer(agen): + record.append("finalizer " + agen.ag_frame.f_locals["arg"]) + + async def example(arg): + try: + yield 42 + finally: + with pytest.raises(_core.Cancelled): + await _core.checkpoint() + record.append("trio collected " + arg) + + async def async_main(): + await step_outside_async_context(example("theirs")) + assert 42 == await example("ours").asend(None) + gc_collect_harder() + assert record == ["firstiter theirs", "finalizer theirs"] + record[:] = [] + await _core.wait_all_tasks_blocked() + assert record == 
["trio collected ours"] + + with restore_unraisablehook(): + old_hooks = sys.get_asyncgen_hooks() + sys.set_asyncgen_hooks(my_firstiter, my_finalizer) + try: + _core.run(async_main) + finally: + assert sys.get_asyncgen_hooks() == (my_firstiter, my_finalizer) + sys.set_asyncgen_hooks(*old_hooks) diff --git a/trio/_core/_tests/test_guest_mode.py b/trio/_core/_tests/test_guest_mode.py new file mode 100644 index 0000000000..7b004cf04d --- /dev/null +++ b/trio/_core/_tests/test_guest_mode.py @@ -0,0 +1,550 @@ +import asyncio +import contextvars +import queue +import signal +import socket +import sys +import threading +import time +import traceback +import warnings +from functools import partial +from math import inf + +import pytest + +import trio +import trio.testing + +from ..._util import signal_raise +from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook + + +# The simplest possible "host" loop. +# Nice features: +# - we can run code "outside" of trio using the schedule function passed to +# our main +# - final result is returned +# - any unhandled exceptions cause an immediate crash +def trivial_guest_run(trio_fn, **start_guest_run_kwargs): + todo = queue.Queue() + + host_thread = threading.current_thread() + + def run_sync_soon_threadsafe(fn): + if host_thread is threading.current_thread(): # pragma: no cover + crash = partial( + pytest.fail, "run_sync_soon_threadsafe called from host thread" + ) + todo.put(("run", crash)) + todo.put(("run", fn)) + + def run_sync_soon_not_threadsafe(fn): + if host_thread is not threading.current_thread(): # pragma: no cover + crash = partial( + pytest.fail, "run_sync_soon_not_threadsafe called from worker thread" + ) + todo.put(("run", crash)) + todo.put(("run", fn)) + + def done_callback(outcome): + todo.put(("unwrap", outcome)) + + trio.lowlevel.start_guest_run( + trio_fn, + run_sync_soon_not_threadsafe, + run_sync_soon_threadsafe=run_sync_soon_threadsafe, + 
run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, + done_callback=done_callback, + **start_guest_run_kwargs, + ) + + try: + while True: + op, obj = todo.get() + if op == "run": + obj() + elif op == "unwrap": + return obj.unwrap() + else: # pragma: no cover + assert False + finally: + # Make sure that exceptions raised here don't capture these, so that + # if an exception does cause us to abandon a run then the Trio state + # has a chance to be GC'ed and warn about it. + del todo, run_sync_soon_threadsafe, done_callback + + +def test_guest_trivial(): + async def trio_return(in_host): + await trio.sleep(0) + return "ok" + + assert trivial_guest_run(trio_return) == "ok" + + async def trio_fail(in_host): + raise KeyError("whoopsiedaisy") + + with pytest.raises(KeyError, match="whoopsiedaisy"): + trivial_guest_run(trio_fail) + + +def test_guest_can_do_io(): + async def trio_main(in_host): + record = [] + a, b = trio.socket.socketpair() + with a, b: + async with trio.open_nursery() as nursery: + + async def do_receive(): + record.append(await a.recv(1)) + + nursery.start_soon(do_receive) + await trio.testing.wait_all_tasks_blocked() + + await b.send(b"x") + + assert record == [b"x"] + + trivial_guest_run(trio_main) + + +def test_host_can_directly_wake_trio_task(): + async def trio_main(in_host): + ev = trio.Event() + in_host(ev.set) + await ev.wait() + return "ok" + + assert trivial_guest_run(trio_main) == "ok" + + +def test_host_altering_deadlines_wakes_trio_up(): + def set_deadline(cscope, new_deadline): + cscope.deadline = new_deadline + + async def trio_main(in_host): + with trio.CancelScope() as cscope: + in_host(lambda: set_deadline(cscope, -inf)) + await trio.sleep_forever() + assert cscope.cancelled_caught + + with trio.CancelScope() as cscope: + # also do a change that doesn't affect the next deadline, just to + # exercise that path + in_host(lambda: set_deadline(cscope, 1e6)) + in_host(lambda: set_deadline(cscope, -inf)) + await trio.sleep(999) + 
assert cscope.cancelled_caught + + return "ok" + + assert trivial_guest_run(trio_main) == "ok" + + +def test_warn_set_wakeup_fd_overwrite(): + assert signal.set_wakeup_fd(-1) == -1 + + async def trio_main(in_host): + return "ok" + + a, b = socket.socketpair() + with a, b: + a.setblocking(False) + + # Warn if there's already a wakeup fd + signal.set_wakeup_fd(a.fileno()) + try: + with pytest.warns(RuntimeWarning, match="signal handling code.*collided"): + assert trivial_guest_run(trio_main) == "ok" + finally: + assert signal.set_wakeup_fd(-1) == a.fileno() + + signal.set_wakeup_fd(a.fileno()) + try: + with pytest.warns(RuntimeWarning, match="signal handling code.*collided"): + assert ( + trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=False) + == "ok" + ) + finally: + assert signal.set_wakeup_fd(-1) == a.fileno() + + # Don't warn if there isn't already a wakeup fd + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert trivial_guest_run(trio_main) == "ok" + + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert ( + trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=True) + == "ok" + ) + + # If there's already a wakeup fd, but we've been told to trust it, + # then it's left alone and there's no warning + signal.set_wakeup_fd(a.fileno()) + try: + + async def trio_check_wakeup_fd_unaltered(in_host): + fd = signal.set_wakeup_fd(-1) + assert fd == a.fileno() + signal.set_wakeup_fd(fd) + return "ok" + + with warnings.catch_warnings(): + warnings.simplefilter("error") + assert ( + trivial_guest_run( + trio_check_wakeup_fd_unaltered, + host_uses_signal_set_wakeup_fd=True, + ) + == "ok" + ) + finally: + assert signal.set_wakeup_fd(-1) == a.fileno() + + +def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked(): + # This is designed to hit the branch in unrolled_run where: + # idle_primed=True + # runner.runq is empty + # events is Truth-y + # ...and confirm that in this case, wait_all_tasks_blocked does not get 
+ # triggered. + def set_deadline(cscope, new_deadline): + print(f"setting deadline {new_deadline}") + cscope.deadline = new_deadline + + async def trio_main(in_host): + async def sit_in_wait_all_tasks_blocked(watb_cscope): + with watb_cscope: + # Overall point of this test is that this + # wait_all_tasks_blocked should *not* return normally, but + # only by cancellation. + await trio.testing.wait_all_tasks_blocked(cushion=9999) + assert False # pragma: no cover + assert watb_cscope.cancelled_caught + + async def get_woken_by_host_deadline(watb_cscope): + with trio.CancelScope() as cscope: + print("scheduling stuff to happen") + + # Altering the deadline from the host, to something in the + # future, will cause the run loop to wake up, but then + # discover that there is nothing to do and go back to sleep. + # This should *not* trigger wait_all_tasks_blocked. + # + # So the 'before_io_wait' here will wait until we're blocking + # with the wait_all_tasks_blocked primed, and then schedule a + # deadline change. The critical test is that this should *not* + # wake up 'sit_in_wait_all_tasks_blocked'. + # + # The after we've had a chance to wake up + # 'sit_in_wait_all_tasks_blocked', we want the test to + # actually end. So in after_io_wait we schedule a second host + # call to tear things down. 
+ class InstrumentHelper: + def __init__(self): + self.primed = False + + def before_io_wait(self, timeout): + print(f"before_io_wait({timeout})") + if timeout == 9999: # pragma: no branch + assert not self.primed + in_host(lambda: set_deadline(cscope, 1e9)) + self.primed = True + + def after_io_wait(self, timeout): + if self.primed: # pragma: no branch + print("instrument triggered") + in_host(lambda: cscope.cancel()) + trio.lowlevel.remove_instrument(self) + + trio.lowlevel.add_instrument(InstrumentHelper()) + await trio.sleep_forever() + assert cscope.cancelled_caught + watb_cscope.cancel() + + async with trio.open_nursery() as nursery: + watb_cscope = trio.CancelScope() + nursery.start_soon(sit_in_wait_all_tasks_blocked, watb_cscope) + await trio.testing.wait_all_tasks_blocked() + nursery.start_soon(get_woken_by_host_deadline, watb_cscope) + + return "ok" + + assert trivial_guest_run(trio_main) == "ok" + + +@restore_unraisablehook() +def test_guest_warns_if_abandoned(): + # This warning is emitted from the garbage collector. So we have to make + # sure that our abandoned run is garbage. The easiest way to do this is to + # put it into a function, so that we're sure all the local state, + # traceback frames, etc. are garbage once it returns. + def do_abandoned_guest_run(): + async def abandoned_main(in_host): + in_host(lambda: 1 / 0) + while True: + await trio.sleep(0) + + with pytest.raises(ZeroDivisionError): + trivial_guest_run(abandoned_main) + + with pytest.warns(RuntimeWarning, match="Trio guest run got abandoned"): + do_abandoned_guest_run() + gc_collect_harder() + + # If you have problems some day figuring out what's holding onto a + # reference to the unrolled_run generator and making this test fail, + # then this might be useful to help track it down. (It assumes you + # also hack start_guest_run so that it does 'global W; W = + # weakref(unrolled_run_gen)'.) 
+ # + # import gc + # print(trio._core._run.W) + # targets = [trio._core._run.W()] + # for i in range(15): + # new_targets = [] + # for target in targets: + # new_targets += gc.get_referrers(target) + # new_targets.remove(targets) + # print("#####################") + # print(f"depth {i}: {len(new_targets)}") + # print(new_targets) + # targets = new_targets + + with pytest.raises(RuntimeError): + trio.current_time() + + +def aiotrio_run(trio_fn, *, pass_not_threadsafe=True, **start_guest_run_kwargs): + loop = asyncio.new_event_loop() + + async def aio_main(): + trio_done_fut = loop.create_future() + + def trio_done_callback(main_outcome): + print(f"trio_fn finished: {main_outcome!r}") + trio_done_fut.set_result(main_outcome) + + if pass_not_threadsafe: + start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon + + trio.lowlevel.start_guest_run( + trio_fn, + run_sync_soon_threadsafe=loop.call_soon_threadsafe, + done_callback=trio_done_callback, + **start_guest_run_kwargs, + ) + + return (await trio_done_fut).unwrap() + + try: + return loop.run_until_complete(aio_main()) + finally: + loop.close() + + +def test_guest_mode_on_asyncio(): + async def trio_main(): + print("trio_main!") + + to_trio, from_aio = trio.open_memory_channel(float("inf")) + from_trio = asyncio.Queue() + + aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio)) + + # Make sure we have at least one tick where we don't need to go into + # the thread + await trio.sleep(0) + + from_trio.put_nowait(0) + + async for n in from_aio: + print(f"trio got: {n}") + from_trio.put_nowait(n + 1) + if n >= 10: + aio_task.cancel() + return "trio-main-done" + + async def aio_pingpong(from_trio, to_trio): + print("aio_pingpong!") + + try: + while True: + n = await from_trio.get() + print(f"aio got: {n}") + to_trio.send_nowait(n + 1) + except asyncio.CancelledError: + raise + except: # pragma: no cover + traceback.print_exc() + raise + + assert ( + aiotrio_run( + trio_main, + # Not all 
versions of asyncio we test on can actually be trusted, + # but this test doesn't care about signal handling, and it's + # easier to just avoid the warnings. + host_uses_signal_set_wakeup_fd=True, + ) + == "trio-main-done" + ) + + assert ( + aiotrio_run( + trio_main, + # Also check that passing only call_soon_threadsafe works, via the + # fallback path where we use it for everything. + pass_not_threadsafe=False, + host_uses_signal_set_wakeup_fd=True, + ) + == "trio-main-done" + ) + + +def test_guest_mode_internal_errors(monkeypatch, recwarn): + with monkeypatch.context() as m: + + async def crash_in_run_loop(in_host): + m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI") + await trio.sleep(1) + + with pytest.raises(trio.TrioInternalError): + trivial_guest_run(crash_in_run_loop) + + with monkeypatch.context() as m: + + async def crash_in_io(in_host): + m.setattr("trio._core._run.TheIOManager.get_events", None) + await trio.sleep(0) + + with pytest.raises(trio.TrioInternalError): + trivial_guest_run(crash_in_io) + + with monkeypatch.context() as m: + + async def crash_in_worker_thread_io(in_host): + t = threading.current_thread() + old_get_events = trio._core._run.TheIOManager.get_events + + def bad_get_events(*args): + if threading.current_thread() is not t: + raise ValueError("oh no!") + else: + return old_get_events(*args) + + m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events) + + await trio.sleep(1) + + with pytest.raises(trio.TrioInternalError): + trivial_guest_run(crash_in_worker_thread_io) + + gc_collect_harder() + + +def test_guest_mode_ki(): + assert signal.getsignal(signal.SIGINT) is signal.default_int_handler + + # Check SIGINT in Trio func and in host func + async def trio_main(in_host): + with pytest.raises(KeyboardInterrupt): + signal_raise(signal.SIGINT) + + # Host SIGINT should get injected into Trio + in_host(partial(signal_raise, signal.SIGINT)) + await trio.sleep(10) + + with pytest.raises(KeyboardInterrupt) as 
excinfo: + trivial_guest_run(trio_main) + assert excinfo.value.__context__ is None + # Signal handler should be restored properly on exit + assert signal.getsignal(signal.SIGINT) is signal.default_int_handler + + # Also check chaining in the case where KI is injected after main exits + final_exc = KeyError("whoa") + + async def trio_main_raising(in_host): + in_host(partial(signal_raise, signal.SIGINT)) + raise final_exc + + with pytest.raises(KeyboardInterrupt) as excinfo: + trivial_guest_run(trio_main_raising) + assert excinfo.value.__context__ is final_exc + + assert signal.getsignal(signal.SIGINT) is signal.default_int_handler + + +def test_guest_mode_autojump_clock_threshold_changing(): + # This is super obscure and probably no-one will ever notice, but + # technically mutating the MockClock.autojump_threshold from the host + # should wake up the guest, so let's test it. + + clock = trio.testing.MockClock() + + DURATION = 120 + + async def trio_main(in_host): + assert trio.current_time() == 0 + in_host(lambda: setattr(clock, "autojump_threshold", 0)) + await trio.sleep(DURATION) + assert trio.current_time() == DURATION + + start = time.monotonic() + trivial_guest_run(trio_main, clock=clock) + end = time.monotonic() + # Should be basically instantaneous, but we'll leave a generous buffer to + # account for any CI weirdness + assert end - start < DURATION / 2 + + +@pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") +@pytest.mark.xfail( + sys.implementation.name == "pypy", + reason="async generator issue under investigation", +) +@restore_unraisablehook() +def test_guest_mode_asyncgens(): + import sniffio + + record = set() + + async def agen(label): + assert sniffio.current_async_library() == label + try: + yield 1 + finally: + library = sniffio.current_async_library() + try: + await sys.modules[library].sleep(0) + except trio.Cancelled: + pass + record.add((label, library)) + + async def iterate_in_aio(): + # "trio" gets inherited from our Trio 
caller if we don't set this + sniffio.current_async_library_cvar.set("asyncio") + await agen("asyncio").asend(None) + + async def trio_main(): + task = asyncio.ensure_future(iterate_in_aio()) + done_evt = trio.Event() + task.add_done_callback(lambda _: done_evt.set()) + with trio.fail_after(1): + await done_evt.wait() + + await agen("trio").asend(None) + + gc_collect_harder() + + # Ensure we don't pollute the thread-level context if run under + # an asyncio without contextvars support (3.6) + context = contextvars.copy_context() + context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) + + assert record == {("asyncio", "asyncio"), ("trio", "trio")} diff --git a/trio/_core/_tests/test_instrumentation.py b/trio/_core/_tests/test_instrumentation.py new file mode 100644 index 0000000000..498a3eb272 --- /dev/null +++ b/trio/_core/_tests/test_instrumentation.py @@ -0,0 +1,254 @@ +import attr +import pytest + +from ... import _abc, _core +from .tutil import check_sequence_matches + + +@attr.s(eq=False, hash=False) +class TaskRecorder: + record = attr.ib(factory=list) + + def before_run(self): + self.record.append(("before_run",)) + + def task_scheduled(self, task): + self.record.append(("schedule", task)) + + def before_task_step(self, task): + assert task is _core.current_task() + self.record.append(("before", task)) + + def after_task_step(self, task): + assert task is _core.current_task() + self.record.append(("after", task)) + + def after_run(self): + self.record.append(("after_run",)) + + def filter_tasks(self, tasks): + for item in self.record: + if item[0] in ("schedule", "before", "after") and item[1] in tasks: + yield item + if item[0] in ("before_run", "after_run"): + yield item + + +def test_instruments(recwarn): + r1 = TaskRecorder() + r2 = TaskRecorder() + r3 = TaskRecorder() + + task = None + + # We use a child task for this, because the main task does some extra + # bookkeeping stuff that can leak into the instrument results, and we + # 
don't want to deal with it. + async def task_fn(): + nonlocal task + task = _core.current_task() + + for _ in range(4): + await _core.checkpoint() + # replace r2 with r3, to test that we can manipulate them as we go + _core.remove_instrument(r2) + with pytest.raises(KeyError): + _core.remove_instrument(r2) + # add is idempotent + _core.add_instrument(r3) + _core.add_instrument(r3) + for _ in range(1): + await _core.checkpoint() + + async def main(): + async with _core.open_nursery() as nursery: + nursery.start_soon(task_fn) + + _core.run(main, instruments=[r1, r2]) + + # It sleeps 5 times, so it runs 6 times. Note that checkpoint() + # reschedules the task immediately upon yielding, before the + # after_task_step event fires. + expected = ( + [("before_run",), ("schedule", task)] + + [("before", task), ("schedule", task), ("after", task)] * 5 + + [("before", task), ("after", task), ("after_run",)] + ) + assert r1.record == r2.record + r3.record + assert list(r1.filter_tasks([task])) == expected + + +def test_instruments_interleave(): + tasks = {} + + async def two_step1(): + tasks["t1"] = _core.current_task() + await _core.checkpoint() + + async def two_step2(): + tasks["t2"] = _core.current_task() + await _core.checkpoint() + + async def main(): + async with _core.open_nursery() as nursery: + nursery.start_soon(two_step1) + nursery.start_soon(two_step2) + + r = TaskRecorder() + _core.run(main, instruments=[r]) + + expected = [ + ("before_run",), + ("schedule", tasks["t1"]), + ("schedule", tasks["t2"]), + { + ("before", tasks["t1"]), + ("schedule", tasks["t1"]), + ("after", tasks["t1"]), + ("before", tasks["t2"]), + ("schedule", tasks["t2"]), + ("after", tasks["t2"]), + }, + { + ("before", tasks["t1"]), + ("after", tasks["t1"]), + ("before", tasks["t2"]), + ("after", tasks["t2"]), + }, + ("after_run",), + ] + print(list(r.filter_tasks(tasks.values()))) + check_sequence_matches(list(r.filter_tasks(tasks.values())), expected) + + +def test_null_instrument(): + # 
undefined instrument methods are skipped + class NullInstrument: + def something_unrelated(self): + pass # pragma: no cover + + async def main(): + await _core.checkpoint() + + _core.run(main, instruments=[NullInstrument()]) + + +def test_instrument_before_after_run(): + record = [] + + class BeforeAfterRun: + def before_run(self): + record.append("before_run") + + def after_run(self): + record.append("after_run") + + async def main(): + pass + + _core.run(main, instruments=[BeforeAfterRun()]) + assert record == ["before_run", "after_run"] + + +def test_instrument_task_spawn_exit(): + record = [] + + class SpawnExitRecorder: + def task_spawned(self, task): + record.append(("spawned", task)) + + def task_exited(self, task): + record.append(("exited", task)) + + async def main(): + return _core.current_task() + + main_task = _core.run(main, instruments=[SpawnExitRecorder()]) + assert ("spawned", main_task) in record + assert ("exited", main_task) in record + + +# This test also tests having a crash before the initial task is even spawned, +# which is very difficult to handle. +def test_instruments_crash(caplog): + record = [] + + class BrokenInstrument: + def task_scheduled(self, task): + record.append("scheduled") + raise ValueError("oops") + + def close(self): + # Shouldn't be called -- tests that the instrument disabling logic + # works right. 
+ record.append("closed") # pragma: no cover + + async def main(): + record.append("main ran") + return _core.current_task() + + r = TaskRecorder() + main_task = _core.run(main, instruments=[r, BrokenInstrument()]) + assert record == ["scheduled", "main ran"] + # the TaskRecorder kept going throughout, even though the BrokenInstrument + # was disabled + assert ("after", main_task) in r.record + assert ("after_run",) in r.record + # And we got a log message + exc_type, exc_value, exc_traceback = caplog.records[0].exc_info + assert exc_type is ValueError + assert str(exc_value) == "oops" + assert "Instrument has been disabled" in caplog.records[0].message + + +def test_instruments_monkeypatch(): + class NullInstrument(_abc.Instrument): + pass + + instrument = NullInstrument() + + async def main(): + record = [] + + # Changing the set of hooks implemented by an instrument after + # it's installed doesn't make them start being called right away + instrument.before_task_step = record.append + await _core.checkpoint() + await _core.checkpoint() + assert len(record) == 0 + + # But if we remove and re-add the instrument, the new hooks are + # picked up + _core.remove_instrument(instrument) + _core.add_instrument(instrument) + await _core.checkpoint() + await _core.checkpoint() + assert record.count(_core.current_task()) == 2 + + _core.remove_instrument(instrument) + await _core.checkpoint() + await _core.checkpoint() + assert record.count(_core.current_task()) == 2 + + _core.run(main, instruments=[instrument]) + + +def test_instrument_that_raises_on_getattr(): + class EvilInstrument: + def task_exited(self, task): + assert False # pragma: no cover + + @property + def after_run(self): + raise ValueError("oops") + + async def main(): + with pytest.raises(ValueError): + _core.add_instrument(EvilInstrument()) + + # Make sure the instrument is fully removed from the per-method lists + runner = _core.current_task()._runner + assert "after_run" not in runner.instruments + assert 
"task_exited" not in runner.instruments + + _core.run(main) diff --git a/trio/_core/tests/test_io.py b/trio/_core/_tests/test_io.py similarity index 91% rename from trio/_core/tests/test_io.py rename to trio/_core/_tests/test_io.py index 2adb3f9a65..21a954941c 100644 --- a/trio/_core/tests/test_io.py +++ b/trio/_core/_tests/test_io.py @@ -1,15 +1,14 @@ -import pytest - -import socket as stdlib_socket -import select import random -import errno +import socket as stdlib_socket from contextlib import suppress -from ... import _core -from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints +import pytest + import trio +from ... import _core +from ...testing import assert_checkpoints, wait_all_tasks_blocked + # Cross-platform tests for IO handling @@ -43,17 +42,19 @@ def using_fileno(fn): def fileno_wrapper(fileobj): return fn(fileobj.fileno()) - name = "<{} on fileno>".format(fn.__name__) + name = f"<{fn.__name__} on fileno>" fileno_wrapper.__name__ = fileno_wrapper.__qualname__ = name return fileno_wrapper -wait_readable_options = [trio.hazmat.wait_readable] -wait_writable_options = [trio.hazmat.wait_writable] -notify_closing_options = [trio.hazmat.notify_closing] +wait_readable_options = [trio.lowlevel.wait_readable] +wait_writable_options = [trio.lowlevel.wait_writable] +notify_closing_options = [trio.lowlevel.notify_closing] for options_list in [ - wait_readable_options, wait_writable_options, notify_closing_options + wait_readable_options, + wait_writable_options, + notify_closing_options, ]: options_list += [using_fileno(f) for f in options_list] @@ -196,9 +197,7 @@ async def writer(): @read_socket_test @write_socket_test -async def test_socket_simultaneous_read_write( - socketpair, wait_readable, wait_writable -): +async def test_socket_simultaneous_read_write(socketpair, wait_readable, wait_writable): record = [] async def r_task(sock): @@ -226,9 +225,7 @@ async def w_task(sock): @read_socket_test @write_socket_test -async def 
test_socket_actual_streaming( - socketpair, wait_readable, wait_writable -): +async def test_socket_actual_streaming(socketpair, wait_readable, wait_writable): a, b = socketpair # Use a small send buffer on one of the sockets to increase the chance of @@ -285,7 +282,7 @@ async def test_notify_closing_on_invalid_object(): got_oserror = False got_no_error = False try: - trio.hazmat.notify_closing(-1) + trio.lowlevel.notify_closing(-1) except OSError: got_oserror = True else: @@ -296,7 +293,7 @@ async def test_notify_closing_on_invalid_object(): async def test_wait_on_invalid_object(): # We definitely want to raise an error everywhere if you pass in an # invalid fd to wait_* - for wait in [trio.hazmat.wait_readable, trio.hazmat.wait_writable]: + for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]: with stdlib_socket.socket() as s: fileno = s.fileno() # We just closed the socket and don't do anything else in between, so @@ -377,7 +374,7 @@ async def allow_OSError(async_func, *args): with stdlib_socket.socket() as s: async with trio.open_nursery() as nursery: - nursery.start_soon(allow_OSError, trio.hazmat.wait_readable, s) + nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s) await wait_all_tasks_blocked() s.close() await wait_all_tasks_blocked() @@ -389,7 +386,7 @@ async def allow_OSError(async_func, *args): # wait_readable pending until cancelled). 
with stdlib_socket.socket() as s, s.dup() as s2: # noqa: F841 async with trio.open_nursery() as nursery: - nursery.start_soon(allow_OSError, trio.hazmat.wait_readable, s) + nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s) await wait_all_tasks_blocked() s.close() await wait_all_tasks_blocked() @@ -408,8 +405,8 @@ async def allow_OSError(async_func, *args): b.setblocking(False) fill_socket(a) async with trio.open_nursery() as nursery: - nursery.start_soon(allow_OSError, trio.hazmat.wait_readable, a) - nursery.start_soon(allow_OSError, trio.hazmat.wait_writable, a) + nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a) + nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a) await wait_all_tasks_blocked() a.close() nursery.cancel_scope.cancel() @@ -419,7 +416,7 @@ async def allow_OSError(async_func, *args): # handle_io context rather than abort context. a, b = stdlib_socket.socketpair() with a, b, a.dup() as a2: # noqa: F841 - print("a={}, b={}, a2={}".format(a.fileno(), b.fileno(), a2.fileno())) + print(f"a={a.fileno()}, b={b.fileno()}, a2={a2.fileno()}") a.setblocking(False) b.setblocking(False) fill_socket(a) @@ -432,12 +429,12 @@ async def allow_OSError(async_func, *args): # definitely arrive, and when it does then we can assume that whatever # notification was going to arrive for 'a' has also arrived. 
async def wait_readable_a2_then_set(): - await trio.hazmat.wait_readable(a2) + await trio.lowlevel.wait_readable(a2) e.set() async with trio.open_nursery() as nursery: - nursery.start_soon(allow_OSError, trio.hazmat.wait_readable, a) - nursery.start_soon(allow_OSError, trio.hazmat.wait_writable, a) + nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a) + nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a) nursery.start_soon(wait_readable_a2_then_set) await wait_all_tasks_blocked() a.close() diff --git a/trio/_core/tests/test_ki.py b/trio/_core/_tests/test_ki.py similarity index 66% rename from trio/_core/tests/test_ki.py rename to trio/_core/_tests/test_ki.py index e0d3b97c50..fdbada4624 100644 --- a/trio/_core/tests/test_ki.py +++ b/trio/_core/_tests/test_ki.py @@ -1,21 +1,20 @@ -import outcome -import pytest -import sys -import os +import contextlib +import inspect import signal import threading -import contextlib -import time -from async_generator import ( - async_generator, yield_, isasyncgenfunction, asynccontextmanager -) +import outcome +import pytest + +try: + from async_generator import async_generator, yield_ +except ImportError: # pragma: no cover + async_generator = yield_ = None from ... 
import _core -from ...testing import wait_all_tasks_blocked -from ..._util import signal_raise, is_main_thread from ..._timeouts import sleep -from .tutil import slow +from ..._util import signal_raise +from ...testing import wait_all_tasks_blocked def ki_self(): @@ -107,6 +106,7 @@ async def unprotected(): async def child(expected): import traceback + traceback.print_stack() assert _core.currently_ki_protected() == expected await _core.checkpoint() @@ -138,7 +138,8 @@ def protected_manager(): raise KeyError -async def test_agen_protection(): +@pytest.mark.skipif(async_generator is None, reason="async_generator not installed") +async def test_async_generator_agen_protection(): @_core.enable_ki_protection @async_generator async def agen_protected1(): @@ -176,31 +177,54 @@ async def agen_unprotected2(): finally: assert not _core.currently_ki_protected() - for agen_fn in [ - agen_protected1, - agen_protected2, - agen_unprotected1, - agen_unprotected2, - ]: - async for _ in agen_fn(): # noqa + await _check_agen(agen_protected1) + await _check_agen(agen_protected2) + await _check_agen(agen_unprotected1) + await _check_agen(agen_unprotected2) + + +async def test_native_agen_protection(): + # Native async generators + @_core.enable_ki_protection + async def agen_protected(): + assert _core.currently_ki_protected() + try: + yield + finally: + assert _core.currently_ki_protected() + + @_core.disable_ki_protection + async def agen_unprotected(): + assert not _core.currently_ki_protected() + try: + yield + finally: assert not _core.currently_ki_protected() - # asynccontextmanager insists that the function passed must itself be an - # async gen function, not a wrapper around one - if isasyncgenfunction(agen_fn): - async with asynccontextmanager(agen_fn)(): - assert not _core.currently_ki_protected() + await _check_agen(agen_protected) + await _check_agen(agen_unprotected) + + +async def _check_agen(agen_fn): + async for _ in agen_fn(): # noqa + assert not 
_core.currently_ki_protected() + + # asynccontextmanager insists that the function passed must itself be an + # async gen function, not a wrapper around one + if inspect.isasyncgenfunction(agen_fn): + async with contextlib.asynccontextmanager(agen_fn)(): + assert not _core.currently_ki_protected() - # Another case that's tricky due to: - # https://bugs.python.org/issue29590 - with pytest.raises(KeyError): - async with asynccontextmanager(agen_fn)(): - raise KeyError + # Another case that's tricky due to: + # https://bugs.python.org/issue29590 + with pytest.raises(KeyError): + async with contextlib.asynccontextmanager(agen_fn)(): + raise KeyError # Test the case where there's no magic local anywhere in the call stack -def test_ki_enabled_out_of_context(): - assert not _core.currently_ki_protected() +def test_ki_disabled_out_of_context(): + assert _core.currently_ki_protected() def test_ki_disabled_in_del(): @@ -211,8 +235,15 @@ def __del__(): assert _core.currently_ki_protected() assert nestedfunction() + @_core.disable_ki_protection + def outerfunction(): + assert not _core.currently_ki_protected() + assert not nestedfunction() + __del__() + __del__() - assert not nestedfunction() + outerfunction() + assert nestedfunction() def test_ki_protection_works(): @@ -240,9 +271,7 @@ async def raiser(name, record): # If we didn't raise (b/c protected), then we *should* get # cancelled at the next opportunity try: - await _core.wait_task_rescheduled( - lambda _: _core.Abort.SUCCEEDED - ) + await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) except _core.Cancelled: record.add(name + " cancel ok") @@ -269,9 +298,7 @@ async def check_protected_kill(): async with _core.open_nursery() as nursery: nursery.start_soon(sleeper, "s1", record) nursery.start_soon(sleeper, "s2", record) - nursery.start_soon( - _core.enable_ki_protection(raiser), "r1", record - ) + nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record) # __aexit__ blocks, and then receives the 
KI with pytest.raises(KeyboardInterrupt): @@ -468,131 +495,8 @@ def test_ki_with_broken_threads(): @_core.enable_ki_protection async def inner(): - assert signal.getsignal( - signal.SIGINT - ) != signal.default_int_handler + assert signal.getsignal(signal.SIGINT) != signal.default_int_handler _core.run(inner) finally: threading._active[thread.ident] = original - - -# For details on why this test is non-trivial, see: -# https://github.com/python-trio/trio/issues/42 -# https://github.com/python-trio/trio/issues/109 -@slow -def test_ki_wakes_us_up(): - assert is_main_thread() - - # This test is flaky due to a race condition on Windows; see: - # https://github.com/python-trio/trio/issues/119 - # https://bugs.python.org/issue30038 - # I think the only fix is to wait for fixed CPython to be released, so in - # the mean time, on affected versions we send two signals (equivalent to - # hitting control-C twice). This works because the problem is that the C - # level signal handler does - # - # write-to-fd -> set-flags - # - # and we need - # - # set-flags -> write-to-fd - # - # so running the C level signal handler twice does - # - # write-to-fd -> set-flags -> write-to-fd -> set-flags - # - # which contains the desired sequence. - # - # Affected version of CPython include: - # - all versions of 3.5 (fix will not be backported) - # - 3.6.1 and earlier - # It's fixed in 3.6.2 and 3.7+ - # - # PyPy was never affected. - # - # The problem technically can occur on Unix as well, if a signal is - # delivered to a non-main thread, though we haven't observed this in - # practice. 
- # - # There's also this theoretical problem, but hopefully it won't actually - # bite us in practice: - # https://bugs.python.org/issue31119 - # https://bitbucket.org/pypy/pypy/issues/2623 - import platform - buggy_wakeup_fd = ( - platform.python_implementation() == "CPython" and sys.version_info < - (3, 6, 2) - ) - - # lock is only needed to avoid an annoying race condition where the - # *second* ki_self() call arrives *after* the first one woke us up and its - # KeyboardInterrupt was caught, and then generates a second - # KeyboardInterrupt that aborts the test run. The kill_soon thread holds - # the lock while doing the calls to ki_self, which means that it holds it - # while the C-level signal handler is running. Then in the main thread, - # when we're woken up we know that ki_self() has been run at least once; - # if we then take the lock it guaranteeds that ki_self() has been run - # twice, so if a second KeyboardInterrupt is going to arrive it should - # arrive by the time we've acquired the lock. This lets us force it to - # happen inside the pytest.raises block. - # - # It will be very nice when the buggy_wakeup_fd bug is fixed. - lock = threading.Lock() - - def kill_soon(): - # We want the signal to be raised after the main thread has entered - # the IO manager blocking primitive. There really is no way to - # deterministically interlock with that, so we have to use sleep and - # hope it's long enough. 
- time.sleep(1.1) - with lock: - print("thread doing ki_self()") - ki_self() - if buggy_wakeup_fd: - print("buggy_wakeup_fd =", buggy_wakeup_fd) - ki_self() - - async def main(): - thread = threading.Thread(target=kill_soon) - print("Starting thread") - thread.start() - try: - with pytest.raises(KeyboardInterrupt): - # To limit the damage on CI if this does get broken (as - # compared to sleep_forever()) - print("Going to sleep") - try: - await sleep(20) - print("Woke without raising?!") # pragma: no cover - # The only purpose of this finally: block is to soak up the - # second KeyboardInterrupt that might arrive on - # buggy_wakeup_fd platforms. So it might get aborted at any - # moment randomly on some runs, so pragma: no cover avoids - # coverage flapping: - finally: # pragma: no cover - print("waiting for lock") - with lock: - print("got lock") - # And then we want to force a PyErr_CheckSignals. Which is - # not so easy on Windows. Weird kluge: builtin_repr calls - # PyObject_Repr, which does an unconditional - # PyErr_CheckSignals for some reason. - print(repr(None)) - # And finally, it's possible that the signal was delivered - # but at a moment when we had KI protection enabled, so we - # need to execute a checkpoint to ensure it's delivered - # before we exit main(). 
- await _core.checkpoint() - finally: - print("joining thread", sys.exc_info()) - thread.join() - - start = time.perf_counter() - try: - _core.run(main) - finally: - end = time.perf_counter() - print("duration", end - start) - print("sys.exc_info", sys.exc_info()) - assert 1.0 <= (end - start) < 2 diff --git a/trio/_core/tests/test_local.py b/trio/_core/_tests/test_local.py similarity index 97% rename from trio/_core/tests/test_local.py rename to trio/_core/_tests/test_local.py index 7f403168ea..619dcd20d4 100644 --- a/trio/_core/tests/test_local.py +++ b/trio/_core/_tests/test_local.py @@ -80,7 +80,7 @@ async def task2(tok): with pytest.raises(LookupError): t1.get() - t1.set("cod") + t1.set("haddock") async with _core.open_nursery() as n: token = t1.set("cod") @@ -92,7 +92,7 @@ async def task2(tok): n.start_soon(task2, token) await _core.wait_all_tasks_blocked() - assert t1.get() == "cod" + assert t1.get() == "haddock" _core.run(sync_check) diff --git a/trio/_core/_tests/test_mock_clock.py b/trio/_core/_tests/test_mock_clock.py new file mode 100644 index 0000000000..9c74df3334 --- /dev/null +++ b/trio/_core/_tests/test_mock_clock.py @@ -0,0 +1,171 @@ +import time +from math import inf + +import pytest + +from trio import sleep + +from ... import _core +from .. 
import wait_all_tasks_blocked +from .._mock_clock import MockClock +from .tutil import slow + + +def test_mock_clock(): + REAL_NOW = 123.0 + c = MockClock() + c._real_clock = lambda: REAL_NOW + repr(c) # smoke test + assert c.rate == 0 + assert c.current_time() == 0 + c.jump(1.2) + assert c.current_time() == 1.2 + with pytest.raises(ValueError): + c.jump(-1) + assert c.current_time() == 1.2 + assert c.deadline_to_sleep_time(1.1) == 0 + assert c.deadline_to_sleep_time(1.2) == 0 + assert c.deadline_to_sleep_time(1.3) > 999999 + + with pytest.raises(ValueError): + c.rate = -1 + assert c.rate == 0 + + c.rate = 2 + assert c.current_time() == 1.2 + REAL_NOW += 1 + assert c.current_time() == 3.2 + assert c.deadline_to_sleep_time(3.1) == 0 + assert c.deadline_to_sleep_time(3.2) == 0 + assert c.deadline_to_sleep_time(4.2) == 0.5 + + c.rate = 0.5 + assert c.current_time() == 3.2 + assert c.deadline_to_sleep_time(3.1) == 0 + assert c.deadline_to_sleep_time(3.2) == 0 + assert c.deadline_to_sleep_time(4.2) == 2.0 + + c.jump(0.8) + assert c.current_time() == 4.0 + REAL_NOW += 1 + assert c.current_time() == 4.5 + + c2 = MockClock(rate=3) + assert c2.rate == 3 + assert c2.current_time() < 10 + + +async def test_mock_clock_autojump(mock_clock): + assert mock_clock.autojump_threshold == inf + + mock_clock.autojump_threshold = 0 + assert mock_clock.autojump_threshold == 0 + + real_start = time.perf_counter() + + virtual_start = _core.current_time() + for i in range(10): + print(f"sleeping {10 * i} seconds") + await sleep(10 * i) + print("woke up!") + assert virtual_start + 10 * i == _core.current_time() + virtual_start = _core.current_time() + + real_duration = time.perf_counter() - real_start + print(f"Slept {10 * sum(range(10))} seconds in {real_duration} seconds") + assert real_duration < 1 + + mock_clock.autojump_threshold = 0.02 + t = _core.current_time() + # this should wake up before the autojump threshold triggers, so time + # shouldn't change + await wait_all_tasks_blocked() 
+ assert t == _core.current_time() + # this should too + await wait_all_tasks_blocked(0.01) + assert t == _core.current_time() + + # set up a situation where the autojump task is blocked for a long long + # time, to make sure that cancel-and-adjust-threshold logic is working + mock_clock.autojump_threshold = 10000 + await wait_all_tasks_blocked() + mock_clock.autojump_threshold = 0 + # if the above line didn't take affect immediately, then this would be + # bad: + await sleep(100000) + + +async def test_mock_clock_autojump_interference(mock_clock): + mock_clock.autojump_threshold = 0.02 + + mock_clock2 = MockClock() + # messing with the autojump threshold of a clock that isn't actually + # installed in the run loop shouldn't do anything. + mock_clock2.autojump_threshold = 0.01 + + # if the autojump_threshold of 0.01 were in effect, then the next line + # would block forever, as the autojump task kept waking up to try to + # jump the clock. + await wait_all_tasks_blocked(0.015) + + # but the 0.02 limit does apply + await sleep(100000) + + +def test_mock_clock_autojump_preset(): + # Check that we can set the autojump_threshold before the clock is + # actually in use, and it gets picked up + mock_clock = MockClock(autojump_threshold=0.1) + mock_clock.autojump_threshold = 0.01 + real_start = time.perf_counter() + _core.run(sleep, 10000, clock=mock_clock) + assert time.perf_counter() - real_start < 1 + + +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(mock_clock): + # Checks that autojump_threshold=0 doesn't interfere with + # calling wait_all_tasks_blocked with the default cushion=0. 
+ + mock_clock.autojump_threshold = 0 + + record = [] + + async def sleeper(): + await sleep(100) + record.append("yawn") + + async def waiter(): + await wait_all_tasks_blocked() + record.append("waiter woke") + await sleep(1000) + record.append("waiter done") + + async with _core.open_nursery() as nursery: + nursery.start_soon(sleeper) + nursery.start_soon(waiter) + + assert record == ["waiter woke", "yawn", "waiter done"] + + +@slow +async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(mock_clock): + # Checks that autojump_threshold=0 doesn't interfere with + # calling wait_all_tasks_blocked with a non-zero cushion. + + mock_clock.autojump_threshold = 0 + + record = [] + + async def sleeper(): + await sleep(100) + record.append("yawn") + + async def waiter(): + await wait_all_tasks_blocked(1) + record.append("waiter done") + + async with _core.open_nursery() as nursery: + nursery.start_soon(sleeper) + nursery.start_soon(waiter) + + assert record == ["waiter done", "yawn"] diff --git a/trio/_core/tests/test_multierror.py b/trio/_core/_tests/test_multierror.py similarity index 55% rename from trio/_core/tests/test_multierror.py rename to trio/_core/_tests/test_multierror.py index 83894f358d..7a8bd2f9a8 100644 --- a/trio/_core/tests/test_multierror.py +++ b/trio/_core/_tests/test_multierror.py @@ -1,17 +1,22 @@ -import logging -import pytest - -from traceback import extract_tb, print_exception, format_exception, _cause_message -import sys +import gc import os +import pickle import re -from pathlib import Path import subprocess +import sys +import warnings +from pathlib import Path +from traceback import extract_tb, print_exception -from .tutil import slow +import pytest -from .._multierror import MultiError, concat_tb +from ... 
import TrioDeprecationWarning from ..._core import open_nursery +from .._multierror import MultiError, NonBaseMultiError, concat_tb +from .tutil import slow + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup class NotHashableException(Exception): @@ -66,12 +71,7 @@ def get_tb(raiser): return get_exc(raiser).__traceback__ -def einfo(exc): - return (type(exc), exc, exc.__traceback__) - - def test_concat_tb(): - tb1 = get_tb(raiser1) tb2 = get_tb(raiser2) @@ -102,7 +102,7 @@ def test_MultiError(): assert MultiError([exc1]) is exc1 m = MultiError([exc1, exc2]) - assert m.exceptions == [exc1, exc2] + assert m.exceptions == (exc1, exc2) assert "ValueError" in str(m) assert "ValueError" in repr(m) @@ -115,7 +115,7 @@ def test_MultiError(): def test_MultiErrorOfSingleMultiError(): # For MultiError([MultiError]), ensure there is no bad recursion by the # constructor where __init__ is called if __new__ returns a bare MultiError. - exceptions = [KeyError(), ValueError()] + exceptions = (KeyError(), ValueError()) a = MultiError(exceptions) b = MultiError([a]) assert b == a @@ -144,25 +144,10 @@ def handle_ValueError(exc): else: return exc - filtered_excs = MultiError.filter(handle_ValueError, excs) - assert isinstance(filtered_excs, NotHashableException) - + with pytest.warns(TrioDeprecationWarning): + filtered_excs = MultiError.filter(handle_ValueError, excs) -def test_traceback_recursion(): - exc1 = RuntimeError() - exc2 = KeyError() - exc3 = NotHashableException(42) - # Note how this creates a loop, where exc1 refers to exc1 - # This could trigger an infinite recursion; the 'seen' set is supposed to prevent - # this. 
- exc1.__cause__ = MultiError([exc1, exc2, exc3]) - # python traceback.TracebackException < 3.6.4 does not support unhashable exceptions - # and raises a TypeError exception - if sys.version_info < (3, 6, 4): - with pytest.raises(TypeError): - format_exception(*einfo(exc1)) - else: - format_exception(*einfo(exc1)) + assert isinstance(filtered_excs, NotHashableException) def make_tree(): @@ -206,7 +191,9 @@ def null_handler(exc): m = make_tree() assert_tree_eq(m, m) - assert MultiError.filter(null_handler, m) is m + with pytest.warns(TrioDeprecationWarning): + assert MultiError.filter(null_handler, m) is m + assert_tree_eq(m, make_tree()) # Make sure we don't pick up any detritus if run in a context where @@ -215,7 +202,8 @@ def null_handler(exc): try: raise ValueError except ValueError: - assert MultiError.filter(null_handler, m) is m + with pytest.warns(TrioDeprecationWarning): + assert MultiError.filter(null_handler, m) is m assert_tree_eq(m, make_tree()) def simple_filter(exc): @@ -225,7 +213,9 @@ def simple_filter(exc): return RuntimeError() return exc - new_m = MultiError.filter(simple_filter, make_tree()) + with pytest.warns(TrioDeprecationWarning): + new_m = MultiError.filter(simple_filter, make_tree()) + assert isinstance(new_m, MultiError) assert len(new_m.exceptions) == 2 # was: [[ValueError, KeyError], NameError] @@ -243,9 +233,9 @@ def simple_filter(exc): assert isinstance(orig.exceptions[0].exceptions[1], KeyError) # get original traceback summary orig_extracted = ( - extract_tb(orig.__traceback__) + - extract_tb(orig.exceptions[0].__traceback__) + - extract_tb(orig.exceptions[0].exceptions[1].__traceback__) + extract_tb(orig.__traceback__) + + extract_tb(orig.exceptions[0].__traceback__) + + extract_tb(orig.exceptions[0].exceptions[1].__traceback__) ) def p(exc): @@ -267,7 +257,8 @@ def filter_NameError(exc): return exc m = make_tree() - new_m = MultiError.filter(filter_NameError, m) + with pytest.warns(TrioDeprecationWarning): + new_m = 
MultiError.filter(filter_NameError, m) # with the NameError gone, the other branch gets promoted assert new_m is m.exceptions[0] @@ -275,7 +266,8 @@ def filter_NameError(exc): def filter_all(exc): return None - assert MultiError.filter(filter_all, make_tree()) is None + with pytest.warns(TrioDeprecationWarning): + assert MultiError.filter(filter_all, make_tree()) is None def test_MultiError_catch(): @@ -284,13 +276,13 @@ def test_MultiError_catch(): def noop(_): pass # pragma: no cover - with MultiError.catch(noop): + with pytest.warns(TrioDeprecationWarning), MultiError.catch(noop): pass # Simple pass-through of all exceptions m = make_tree() with pytest.raises(MultiError) as excinfo: - with MultiError.catch(lambda exc: exc): + with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda exc: exc): raise m assert excinfo.value is m # Should be unchanged, except that we added a traceback frame by raising @@ -302,7 +294,7 @@ def noop(_): assert_tree_eq(m, make_tree()) # Swallows everything - with MultiError.catch(lambda _: None): + with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda _: None): raise make_tree() def simple_filter(exc): @@ -313,7 +305,7 @@ def simple_filter(exc): return exc with pytest.raises(MultiError) as excinfo: - with MultiError.catch(simple_filter): + with pytest.warns(TrioDeprecationWarning), MultiError.catch(simple_filter): raise make_tree() new_m = excinfo.value assert isinstance(new_m, MultiError) @@ -331,7 +323,7 @@ def simple_filter(exc): v = ValueError() v.__cause__ = KeyError() with pytest.raises(ValueError) as excinfo: - with MultiError.catch(lambda exc: exc): + with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda exc: exc): raise v assert isinstance(excinfo.value.__cause__, KeyError) @@ -339,7 +331,7 @@ def simple_filter(exc): context = KeyError() v.__context__ = context with pytest.raises(ValueError) as excinfo: - with MultiError.catch(lambda exc: exc): + with pytest.warns(TrioDeprecationWarning), 
MultiError.catch(lambda exc: exc): raise v assert excinfo.value.__context__ is context assert not excinfo.value.__suppress_context__ @@ -358,12 +350,45 @@ def catch_RuntimeError(exc): else: return exc - with MultiError.catch(catch_RuntimeError): - raise MultiError([v, distractor]) + with pytest.warns(TrioDeprecationWarning): + with MultiError.catch(catch_RuntimeError): + raise MultiError([v, distractor]) assert excinfo.value.__context__ is context assert excinfo.value.__suppress_context__ == suppress_context +@pytest.mark.skipif( + sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" +) +def test_MultiError_catch_doesnt_create_cyclic_garbage(): + # https://github.com/python-trio/trio/pull/2063 + gc.collect() + old_flags = gc.get_debug() + + def make_multi(): + # make_tree creates cycles itself, so a simple + raise MultiError([get_exc(raiser1), get_exc(raiser2)]) + + def simple_filter(exc): + if isinstance(exc, ValueError): + return Exception() + if isinstance(exc, KeyError): + return RuntimeError() + assert False, "only ValueError and KeyError should exist" # pragma: no cover + + try: + gc.set_debug(gc.DEBUG_SAVEALL) + with pytest.raises(MultiError): + # covers MultiErrorCatcher.__exit__ and _multierror.copy_tb + with pytest.warns(TrioDeprecationWarning), MultiError.catch(simple_filter): + raise make_multi() + gc.collect() + assert not gc.garbage + finally: + gc.set_debug(old_flags) + gc.garbage.clear() + + def assert_match_in_seq(pattern_list, string): offset = 0 print("looking for pattern matches...") @@ -382,205 +407,30 @@ def test_assert_match_in_seq(): assert_match_in_seq(["a", "b"], "xx b xx a xx") -def test_format_exception(): - exc = get_exc(raiser1) - formatted = "".join(format_exception(*einfo(exc))) - assert "raiser1_string" in formatted - assert "in raiser1_3" in formatted - assert "raiser2_string" not in formatted - assert "in raiser2_2" not in formatted - assert "direct cause" not in formatted - assert "During handling" 
not in formatted - - exc = get_exc(raiser1) - exc.__cause__ = get_exc(raiser2) - formatted = "".join(format_exception(*einfo(exc))) - assert "raiser1_string" in formatted - assert "in raiser1_3" in formatted - assert "raiser2_string" in formatted - assert "in raiser2_2" in formatted - assert "direct cause" in formatted - assert "During handling" not in formatted - # ensure cause included - assert _cause_message in formatted - - exc = get_exc(raiser1) - exc.__context__ = get_exc(raiser2) - formatted = "".join(format_exception(*einfo(exc))) - assert "raiser1_string" in formatted - assert "in raiser1_3" in formatted - assert "raiser2_string" in formatted - assert "in raiser2_2" in formatted - assert "direct cause" not in formatted - assert "During handling" in formatted - - exc.__suppress_context__ = True - formatted = "".join(format_exception(*einfo(exc))) - assert "raiser1_string" in formatted - assert "in raiser1_3" in formatted - assert "raiser2_string" not in formatted - assert "in raiser2_2" not in formatted - assert "direct cause" not in formatted - assert "During handling" not in formatted - - # chain=False - exc = get_exc(raiser1) - exc.__context__ = get_exc(raiser2) - formatted = "".join(format_exception(*einfo(exc), chain=False)) - assert "raiser1_string" in formatted - assert "in raiser1_3" in formatted - assert "raiser2_string" not in formatted - assert "in raiser2_2" not in formatted - assert "direct cause" not in formatted - assert "During handling" not in formatted - - # limit - exc = get_exc(raiser1) - exc.__context__ = get_exc(raiser2) - # get_exc adds a frame that counts against the limit, so limit=2 means we - # get 1 deep into the raiser stack - formatted = "".join(format_exception(*einfo(exc), limit=2)) - print(formatted) - assert "raiser1_string" in formatted - assert "in raiser1" in formatted - assert "in raiser1_2" not in formatted - assert "raiser2_string" in formatted - assert "in raiser2" in formatted - assert "in raiser2_2" not in 
formatted - - exc = get_exc(raiser1) - exc.__context__ = get_exc(raiser2) - formatted = "".join(format_exception(*einfo(exc), limit=1)) - print(formatted) - assert "raiser1_string" in formatted - assert "in raiser1" not in formatted - assert "raiser2_string" in formatted - assert "in raiser2" not in formatted - - # handles loops - exc = get_exc(raiser1) - exc.__cause__ = exc - formatted = "".join(format_exception(*einfo(exc))) - assert "raiser1_string" in formatted - assert "in raiser1_3" in formatted - assert "raiser2_string" not in formatted - assert "in raiser2_2" not in formatted - # ensure duplicate exception is not included as cause - assert _cause_message not in formatted - - # MultiError - formatted = "".join(format_exception(*einfo(make_tree()))) - print(formatted) - - assert_match_in_seq( - [ - # Outer exception is MultiError - r"MultiError:", - # First embedded exception is the embedded MultiError - r"\nDetails of embedded exception 1", - # Which has a single stack frame from make_tree raising it - r"in make_tree", - # Then it has two embedded exceptions - r" Details of embedded exception 1", - r"in raiser1_2", - # for some reason ValueError has no quotes - r"ValueError: raiser1_string", - r" Details of embedded exception 2", - r"in raiser2_2", - # But KeyError does have quotes - r"KeyError: 'raiser2_string'", - # And finally the NameError, which is a sibling of the embedded - # MultiError - r"\nDetails of embedded exception 2:", - r"in raiser3", - r"NameError", - ], - formatted - ) - - # Prints duplicate exceptions in sub-exceptions - exc1 = get_exc(raiser1) - - def raise1_raiser1(): - try: - raise exc1 - except: - raise ValueError("foo") - - def raise2_raiser1(): - try: - raise exc1 - except: - raise KeyError("bar") +def test_base_multierror(): + """ + Test that MultiError() with at least one base exception will return a MultiError + object. 
+ """ - exc2 = get_exc(raise1_raiser1) - exc3 = get_exc(raise2_raiser1) + exc = MultiError([ZeroDivisionError(), KeyboardInterrupt()]) + assert type(exc) is MultiError - try: - raise MultiError([exc2, exc3]) - except MultiError as e: - exc = e - formatted = "".join(format_exception(*einfo(exc))) - print(formatted) +def test_non_base_multierror(): + """ + Test that MultiError() without base exceptions will return a NonBaseMultiError + object. + """ - assert_match_in_seq( - [ - r"Traceback", - # Outer exception is MultiError - r"MultiError:", - # First embedded exception is the embedded ValueError with cause of raiser1 - r"\nDetails of embedded exception 1", - # Print details of exc1 - r" Traceback", - r"in get_exc", - r"in raiser1", - r"ValueError: raiser1_string", - # Print details of exc2 - r"\n During handling of the above exception, another exception occurred:", - r" Traceback", - r"in get_exc", - r"in raise1_raiser1", - r" ValueError: foo", - # Second embedded exception is the embedded KeyError with cause of raiser1 - r"\nDetails of embedded exception 2", - # Print details of exc1 again - r" Traceback", - r"in get_exc", - r"in raiser1", - r"ValueError: raiser1_string", - # Print details of exc3 - r"\n During handling of the above exception, another exception occurred:", - r" Traceback", - r"in get_exc", - r"in raise2_raiser1", - r" KeyError: 'bar'", - ], - formatted - ) - - -def test_logging(caplog): - exc1 = get_exc(raiser1) - exc2 = get_exc(raiser2) - - m = MultiError([exc1, exc2]) - - message = "test test test" - try: - raise m - except MultiError as exc: - logging.getLogger().exception(message) - # Join lines together - formatted = "".join( - format_exception(type(exc), exc, exc.__traceback__) - ) - assert message in caplog.text - assert formatted in caplog.text + exc = MultiError([ZeroDivisionError(), ValueError()]) + assert type(exc) is NonBaseMultiError + assert isinstance(exc, ExceptionGroup) def run_script(name, use_ipython=False): import trio + 
trio_path = Path(trio.__file__).parent.parent script_path = Path(__file__).parent / "test_multierror_scripts" / name @@ -596,7 +446,11 @@ def run_script(name, use_ipython=False): print("subprocess PYTHONPATH:", env.get("PYTHONPATH")) if use_ipython: - lines = [script_path.open().read(), "exit()"] + lines = [ + "import runpy", + f"runpy.run_path(r'{script_path}', run_name='trio.fake')", + "exit()", + ] cmd = [ sys.executable, @@ -605,7 +459,7 @@ def run_script(name, use_ipython=False): "IPython", # no startup files "--quick", - "--TerminalIPythonApp.code_to_run=" + '\n'.join(lines), + "--TerminalIPythonApp.code_to_run=" + "\n".join(lines), ] else: cmd = [sys.executable, "-u", str(script_path)] @@ -623,49 +477,19 @@ def check_simple_excepthook(completed): [ "in ", "MultiError", - "Details of embedded exception 1", + "--- 1 ---", "in exc1_fn", "ValueError", - "Details of embedded exception 2", + "--- 2 ---", "in exc2_fn", "KeyError", - ], completed.stdout.decode("utf-8") - ) - - -def test_simple_excepthook(): - completed = run_script("simple_excepthook.py") - check_simple_excepthook(completed) - - -def test_custom_excepthook(): - # Check that user-defined excepthooks aren't overridden - completed = run_script("custom_excepthook.py") - assert_match_in_seq( - [ - # The warning - "RuntimeWarning", - "already have a custom", - # The message printed by the custom hook, proving we didn't - # override it - "custom running!", - # The MultiError - "MultiError:", ], - completed.stdout.decode("utf-8") + completed.stdout.decode("utf-8"), ) -# This warning is triggered by ipython 7.5.0 on python 3.8 -import warnings -warnings.filterwarnings( - "ignore", - message=".*\"@coroutine\" decorator is deprecated", - category=DeprecationWarning, - module="IPython.*" -) try: - import IPython + import IPython # noqa: F401 except ImportError: # pragma: no cover have_ipython = False else: @@ -705,7 +529,59 @@ def test_ipython_custom_exc_handler(): "ValueError", "KeyError", ], - 
completed.stdout.decode("utf-8") + completed.stdout.decode("utf-8"), ) # Make sure our other warning doesn't show up assert "custom sys.excepthook" not in completed.stdout.decode("utf-8") + + +@slow +@pytest.mark.skipif( + not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(), + reason="need Ubuntu with python3-apport installed", +) +def test_apport_excepthook_monkeypatch_interaction(): + completed = run_script("apport_excepthook.py") + stdout = completed.stdout.decode("utf-8") + + # No warning + assert "custom sys.excepthook" not in stdout + + # Proper traceback + assert_match_in_seq( + ["--- 1 ---", "KeyError", "--- 2 ---", "ValueError"], + stdout, + ) + + +@pytest.mark.parametrize("protocol", range(0, pickle.HIGHEST_PROTOCOL + 1)) +def test_pickle_multierror(protocol) -> None: + # use trio.MultiError to make sure that pickle works through the deprecation layer + import trio + + my_except = ZeroDivisionError() + + try: + 1 / 0 + except ZeroDivisionError as e: + my_except = e + + # MultiError will collapse into different classes depending on the errors + for cls, errors in ( + (ZeroDivisionError, [my_except]), + (NonBaseMultiError, [my_except, ValueError()]), + (MultiError, [BaseException(), my_except]), + ): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", TrioDeprecationWarning) + me = trio.MultiError(errors) # type: ignore[attr-defined] + dump = pickle.dumps(me, protocol=protocol) + load = pickle.loads(dump) + assert repr(me) == repr(load) + assert me.__class__ == load.__class__ == cls + + assert me.__dict__.keys() == load.__dict__.keys() + for me_val, load_val in zip(me.__dict__.values(), load.__dict__.values()): + # tracebacks etc are not preserved through pickling for the default + # exceptions, so we only check that the repr stays the same + assert repr(me_val) == repr(load_val) diff --git a/trio/_core/tests/test_multierror_scripts/__init__.py b/trio/_core/_tests/test_multierror_scripts/__init__.py similarity index 
100% rename from trio/_core/tests/test_multierror_scripts/__init__.py rename to trio/_core/_tests/test_multierror_scripts/__init__.py diff --git a/trio/_core/tests/test_multierror_scripts/_common.py b/trio/_core/_tests/test_multierror_scripts/_common.py similarity index 100% rename from trio/_core/tests/test_multierror_scripts/_common.py rename to trio/_core/_tests/test_multierror_scripts/_common.py diff --git a/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py new file mode 100644 index 0000000000..3e1d23ca8e --- /dev/null +++ b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py @@ -0,0 +1,15 @@ +# The apport_python_hook package is only installed as part of Ubuntu's system +# python, and not available in venvs. So before we can import it we have to +# make sure it's on sys.path. +import sys + +import _common + +sys.path.append("/usr/lib/python3/dist-packages") +import apport_python_hook + +apport_python_hook.install() + +import trio + +raise trio.MultiError([KeyError("key_error"), ValueError("value_error")]) diff --git a/trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py b/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py similarity index 99% rename from trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py rename to trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py index 017c5ea059..80e42b6a2c 100644 --- a/trio/_core/tests/test_multierror_scripts/ipython_custom_exc.py +++ b/trio/_core/_tests/test_multierror_scripts/ipython_custom_exc.py @@ -1,10 +1,10 @@ -import _common - # Override the regular excepthook too -- it doesn't change anything either way # because ipython doesn't use it, but we want to make sure Trio doesn't warn # about it. 
import sys +import _common + def custom_excepthook(*args): print("custom running!") @@ -14,6 +14,7 @@ def custom_excepthook(*args): sys.excepthook = custom_excepthook import IPython + ip = IPython.get_ipython() diff --git a/trio/_core/tests/test_multierror_scripts/simple_excepthook.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py similarity index 100% rename from trio/_core/tests/test_multierror_scripts/simple_excepthook.py rename to trio/_core/_tests/test_multierror_scripts/simple_excepthook.py diff --git a/trio/_core/tests/test_multierror_scripts/simple_excepthook_IPython.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook_IPython.py similarity index 99% rename from trio/_core/tests/test_multierror_scripts/simple_excepthook_IPython.py rename to trio/_core/_tests/test_multierror_scripts/simple_excepthook_IPython.py index 6aa12493b0..51a88c96ce 100644 --- a/trio/_core/tests/test_multierror_scripts/simple_excepthook_IPython.py +++ b/trio/_core/_tests/test_multierror_scripts/simple_excepthook_IPython.py @@ -3,5 +3,4 @@ # To tickle the "is IPython loaded?" logic, make sure that Trio tolerates # IPython loaded but not actually in use import IPython - import simple_excepthook diff --git a/trio/_core/tests/test_parking_lot.py b/trio/_core/_tests/test_parking_lot.py similarity index 87% rename from trio/_core/tests/test_parking_lot.py rename to trio/_core/_tests/test_parking_lot.py index 95e4a96b50..3f03fdbade 100644 --- a/trio/_core/tests/test_parking_lot.py +++ b/trio/_core/_tests/test_parking_lot.py @@ -2,17 +2,17 @@ from ... 
import _core from ...testing import wait_all_tasks_blocked -from .tutil import check_sequence_matches from .._parking_lot import ParkingLot +from .tutil import check_sequence_matches async def test_parking_lot_basic(): record = [] async def waiter(i, lot): - record.append("sleep {}".format(i)) + record.append(f"sleep {i}") await lot.park() - record.append("wake {}".format(i)) + record.append(f"wake {i}") async with _core.open_nursery() as nursery: lot = ParkingLot() @@ -32,10 +32,7 @@ async def waiter(i, lot): assert len(record) == 6 check_sequence_matches( - record, [ - {"sleep 0", "sleep 1", "sleep 2"}, - {"wake 0", "wake 1", "wake 2"}, - ] + record, [{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"}] ) async with _core.open_nursery() as nursery: @@ -71,26 +68,24 @@ async def waiter(i, lot): lot.unpark(count=2) await wait_all_tasks_blocked() check_sequence_matches( - record, [ - "sleep 0", - "sleep 1", - "sleep 2", - {"wake 0", "wake 1"}, - ] + record, ["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"}] ) lot.unpark_all() + with pytest.raises(ValueError): + lot.unpark(count=1.5) + async def cancellable_waiter(name, lot, scopes, record): with _core.CancelScope() as scope: scopes[name] = scope - record.append("sleep {}".format(name)) + record.append(f"sleep {name}") try: await lot.park() except _core.Cancelled: - record.append("cancelled {}".format(name)) + record.append(f"cancelled {name}") else: - record.append("wake {}".format(name)) + record.append(f"wake {name}") async def test_parking_lot_cancel(): @@ -115,13 +110,7 @@ async def test_parking_lot_cancel(): assert len(record) == 6 check_sequence_matches( - record, [ - "sleep 1", - "sleep 2", - "sleep 3", - "cancelled 2", - {"wake 1", "wake 3"}, - ] + record, ["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"}] ) @@ -160,13 +149,22 @@ async def test_parking_lot_repark(): await wait_all_tasks_blocked() assert len(lot2) == 1 assert record == [ - "sleep 1", "sleep 2", "sleep 
3", "wake 1", "cancelled 2" + "sleep 1", + "sleep 2", + "sleep 3", + "wake 1", + "cancelled 2", ] lot2.unpark_all() await wait_all_tasks_blocked() assert record == [ - "sleep 1", "sleep 2", "sleep 3", "wake 1", "cancelled 2", "wake 3" + "sleep 1", + "sleep 2", + "sleep 3", + "wake 1", + "cancelled 2", + "wake 3", ] diff --git a/trio/_core/tests/test_run.py b/trio/_core/_tests/test_run.py similarity index 82% rename from trio/_core/tests/test_run.py rename to trio/_core/_tests/test_run.py index fc03197133..81c3b73cc4 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/_tests/test_run.py @@ -1,32 +1,37 @@ import contextvars import functools -import platform +import gc import sys import threading import time import types -import warnings -from contextlib import contextmanager, ExitStack +import weakref +from contextlib import ExitStack from math import inf -from textwrap import dedent -import attr import outcome -import sniffio import pytest -from async_generator import async_generator +import sniffio -from .tutil import slow, check_sequence_matches, gc_collect_harder from ... 
import _core +from ..._core._multierror import MultiError, NonBaseMultiError from ..._threads import to_thread_run_sync -from ..._timeouts import sleep, fail_after -from ..._util import aiter_compat -from ...testing import ( - wait_all_tasks_blocked, - Sequencer, - assert_checkpoints, +from ..._timeouts import fail_after, sleep +from ...testing import Sequencer, assert_checkpoints, wait_all_tasks_blocked +from .._run import DEADLINE_HEAP_MIN_PRUNE_THRESHOLD +from .tutil import ( + buggy_pypy_asyncgens, + check_sequence_matches, + create_asyncio_future_in_new_loop, + gc_collect_harder, + ignore_coroutine_never_awaited_warnings, + restore_unraisablehook, + slow, ) +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + # slightly different from _timeouts.sleep_forever because it returns the value # its rescheduled with, which is really only useful for tests of @@ -35,26 +40,6 @@ async def sleep_forever(): return await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) -# Some of our tests need to leak coroutines, and thus trigger the -# "RuntimeWarning: coroutine '...' was never awaited" message. This context -# manager should be used anywhere this happens to hide those messages, because -# (a) when expected they're clutter, (b) on CPython 3.5.x where x < 3, this -# warning can trigger a segfault if we run with warnings turned into errors: -# https://bugs.python.org/issue27811 -@contextmanager -def ignore_coroutine_never_awaited_warnings(): - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="coroutine '.*' was never awaited" - ) - try: - yield - finally: - # Make sure to trigger any coroutine __del__ methods now, before - # we leave the context manager. 
- gc_collect_harder() - - def test_basic(): async def trivial(x): return x @@ -146,8 +131,7 @@ async def looper(whoami, record): nursery.start_soon(looper, "b", record) check_sequence_matches( - record, - [{("a", 0), ("b", 0)}, {("a", 1), ("b", 1)}, {("a", 2), ("b", 2)}] + record, [{("a", 0), ("b", 0)}, {("a", 1), ("b", 1)}, {("a", 2), ("b", 2)}] ) @@ -188,11 +172,13 @@ async def main(): nursery.start_soon(crasher) raise KeyError - with pytest.raises(_core.MultiError) as excinfo: + with pytest.raises(MultiError) as excinfo: _core.run(main) print(excinfo.value) - assert {type(exc) - for exc in excinfo.value.exceptions} == {ValueError, KeyError} + assert {type(exc) for exc in excinfo.value.exceptions} == { + ValueError, + KeyError, + } def test_two_child_crashes(): @@ -204,10 +190,12 @@ async def main(): nursery.start_soon(crasher, KeyError) nursery.start_soon(crasher, ValueError) - with pytest.raises(_core.MultiError) as excinfo: + with pytest.raises(MultiError) as excinfo: _core.run(main) - assert {type(exc) - for exc in excinfo.value.exceptions} == {ValueError, KeyError} + assert {type(exc) for exc in excinfo.value.exceptions} == { + ValueError, + KeyError, + } async def test_child_crash_wakes_parent(): @@ -285,7 +273,7 @@ async def child(): async def test_root_task(): root = _core.current_root_task() - assert root.parent_nursery is None + assert root.parent_nursery is root.eventual_parent_nursery is None def test_out_of_context(): @@ -347,202 +335,6 @@ async def child(): assert stats.seconds_to_next_deadline == inf -@attr.s(eq=False, hash=False) -class TaskRecorder: - record = attr.ib(factory=list) - - def before_run(self): - self.record.append(("before_run",)) - - def task_scheduled(self, task): - self.record.append(("schedule", task)) - - def before_task_step(self, task): - assert task is _core.current_task() - self.record.append(("before", task)) - - def after_task_step(self, task): - assert task is _core.current_task() - self.record.append(("after", task)) - 
- def after_run(self): - self.record.append(("after_run",)) - - def filter_tasks(self, tasks): - for item in self.record: - if item[0] in ("schedule", "before", "after") and item[1] in tasks: - yield item - if item[0] in ("before_run", "after_run"): - yield item - - -def test_instruments(recwarn): - r1 = TaskRecorder() - r2 = TaskRecorder() - r3 = TaskRecorder() - - task = None - - # We use a child task for this, because the main task does some extra - # bookkeeping stuff that can leak into the instrument results, and we - # don't want to deal with it. - async def task_fn(): - nonlocal task - task = _core.current_task() - - for _ in range(4): - await _core.checkpoint() - # replace r2 with r3, to test that we can manipulate them as we go - _core.remove_instrument(r2) - with pytest.raises(KeyError): - _core.remove_instrument(r2) - # add is idempotent - _core.add_instrument(r3) - _core.add_instrument(r3) - for _ in range(1): - await _core.checkpoint() - - async def main(): - async with _core.open_nursery() as nursery: - nursery.start_soon(task_fn) - - _core.run(main, instruments=[r1, r2]) - - # It sleeps 5 times, so it runs 6 times. Note that checkpoint() - # reschedules the task immediately upon yielding, before the - # after_task_step event fires. 
- expected = ( - [("before_run",), ("schedule", task)] + - [("before", task), ("schedule", task), ("after", task)] * 5 + - [("before", task), ("after", task), ("after_run",)] - ) - assert len(r1.record) > len(r2.record) > len(r3.record) - assert r1.record == r2.record + r3.record - assert list(r1.filter_tasks([task])) == expected - - -def test_instruments_interleave(): - tasks = {} - - async def two_step1(): - tasks["t1"] = _core.current_task() - await _core.checkpoint() - - async def two_step2(): - tasks["t2"] = _core.current_task() - await _core.checkpoint() - - async def main(): - async with _core.open_nursery() as nursery: - nursery.start_soon(two_step1) - nursery.start_soon(two_step2) - - r = TaskRecorder() - _core.run(main, instruments=[r]) - - expected = [ - ("before_run",), - ("schedule", tasks["t1"]), - ("schedule", tasks["t2"]), - { - ("before", tasks["t1"]), - ("schedule", tasks["t1"]), - ("after", tasks["t1"]), - ("before", tasks["t2"]), - ("schedule", tasks["t2"]), - ("after", tasks["t2"]) - }, - { - ("before", tasks["t1"]), - ("after", tasks["t1"]), - ("before", tasks["t2"]), - ("after", tasks["t2"]) - }, - ("after_run",), - ] # yapf: disable - print(list(r.filter_tasks(tasks.values()))) - check_sequence_matches(list(r.filter_tasks(tasks.values())), expected) - - -def test_null_instrument(): - # undefined instrument methods are skipped - class NullInstrument: - pass - - async def main(): - await _core.checkpoint() - - _core.run(main, instruments=[NullInstrument()]) - - -def test_instrument_before_after_run(): - record = [] - - class BeforeAfterRun: - def before_run(self): - record.append("before_run") - - def after_run(self): - record.append("after_run") - - async def main(): - pass - - _core.run(main, instruments=[BeforeAfterRun()]) - assert record == ["before_run", "after_run"] - - -def test_instrument_task_spawn_exit(): - record = [] - - class SpawnExitRecorder: - def task_spawned(self, task): - record.append(("spawned", task)) - - def 
task_exited(self, task): - record.append(("exited", task)) - - async def main(): - return _core.current_task() - - main_task = _core.run(main, instruments=[SpawnExitRecorder()]) - assert ("spawned", main_task) in record - assert ("exited", main_task) in record - - -# This test also tests having a crash before the initial task is even spawned, -# which is very difficult to handle. -def test_instruments_crash(caplog): - record = [] - - class BrokenInstrument: - def task_scheduled(self, task): - record.append("scheduled") - raise ValueError("oops") - - def close(self): - # Shouldn't be called -- tests that the instrument disabling logic - # works right. - record.append("closed") # pragma: no cover - - async def main(): - record.append("main ran") - return _core.current_task() - - r = TaskRecorder() - main_task = _core.run(main, instruments=[r, BrokenInstrument()]) - assert record == ["scheduled", "main ran"] - # the TaskRecorder kept going throughout, even though the BrokenInstrument - # was disabled - assert ("after", main_task) in r.record - assert ("after_run",) in r.record - # And we got a log message - exc_type, exc_value, exc_traceback = caplog.records[0].exc_info - assert exc_type is ValueError - assert str(exc_value) == "oops" - assert "Instrument has been disabled" in caplog.records[0].message - - async def test_cancel_scope_repr(mock_clock): scope = _core.CancelScope() assert "unbound" in repr(scope) @@ -635,7 +427,7 @@ async def crasher(): # And one that raises a different error nursery.start_soon(crasher) # t4 # and then our __aexit__ also receives an outer Cancelled - except _core.MultiError as multi_exc: + except MultiError as multi_exc: # Since the outer scope became cancelled before the # nursery block exited, all cancellations inside the # nursery block continue propagating to reach the @@ -650,7 +442,7 @@ async def crasher(): except AssertionError: # pragma: no cover raise except BaseException as exc: - # This is ouside the outer scope, so all the 
Cancelled + # This is outside the outer scope, so all the Cancelled # exceptions should have been absorbed, leaving just a regular # KeyError from crasher() assert type(exc) is KeyError @@ -805,13 +597,6 @@ async def test_basic_timeout(mock_clock): await _core.checkpoint() -@pytest.mark.filterwarnings( - "ignore:.*trio.open_cancel_scope:trio.TrioDeprecationWarning" -) -async def test_cancel_scope_deprecated(recwarn): - assert isinstance(_core.open_cancel_scope(), _core.CancelScope) - - async def test_cancel_scope_nesting(): # Nested scopes: if two triggering at once, the outer one wins with _core.CancelScope() as scope1: @@ -981,15 +766,13 @@ async def task2(): with pytest.raises(RuntimeError) as exc_info: await nursery_mgr.__aexit__(*sys.exc_info()) assert "which had already been exited" in str(exc_info.value) - assert type(exc_info.value.__context__) is _core.MultiError + assert type(exc_info.value.__context__) is NonBaseMultiError assert len(exc_info.value.__context__.exceptions) == 3 cancelled_in_context = False for exc in exc_info.value.__context__.exceptions: assert isinstance(exc, RuntimeError) assert "closed before the task exited" in str(exc) - cancelled_in_context |= isinstance( - exc.__context__, _core.Cancelled - ) + cancelled_in_context |= isinstance(exc.__context__, _core.Cancelled) assert cancelled_in_context # for the sleep_forever # Trying to exit a cancel scope from an unrelated task raises an error @@ -1058,6 +841,7 @@ async def stubborn_sleeper(): assert record == ["sleep", "woke", "cancelled"] +@restore_unraisablehook() def test_broken_abort(): async def main(): # These yields are here to work around an annoying warning -- we're @@ -1084,6 +868,7 @@ async def main(): gc_collect_harder() +@restore_unraisablehook() def test_error_in_run_loop(): # Blow stuff up real good to check we at least get a TrioInternalError async def main(): @@ -1142,7 +927,7 @@ async def main(): _core.run(main) me = excinfo.value.__cause__ - assert isinstance(me, 
_core.MultiError) + assert isinstance(me, MultiError) assert len(me.exceptions) == 2 for exc in me.exceptions: assert isinstance(exc, (KeyError, ValueError)) @@ -1255,13 +1040,18 @@ async def child2(): nursery.start_soon(child2) assert record == [ - "child1 raise", "child1 sleep", "child2 wake", "child2 sleep again", - "child1 re-raise", "child1 success", "child2 re-raise", - "child2 success" + "child1 raise", + "child1 sleep", + "child2 wake", + "child2 sleep again", + "child1 re-raise", + "child1 success", + "child2 re-raise", + "child2 success", ] -# At least as of CPython 3.6, using .throw() to raise an exception inside a +# Before CPython 3.9, using .throw() to raise an exception inside a # coroutine/generator causes the original exc_info state to be lost, so things # like re-raising and exception chaining are broken. # @@ -1316,7 +1106,7 @@ async def test_nursery_exception_chaining_doesnt_make_context_loops(): async def crasher(): raise KeyError - with pytest.raises(_core.MultiError) as excinfo: + with pytest.raises(MultiError) as excinfo: async with _core.open_nursery() as nursery: nursery.start_soon(crasher) raise ValueError @@ -1387,13 +1177,7 @@ def cb(x): for i in range(100): token.run_sync_soon(cb, i, idempotent=True) await wait_all_tasks_blocked() - if ( - sys.version_info < (3, 6) - and platform.python_implementation() == "CPython" - ): - # no order guarantees - record.sort() - # Otherwise, we guarantee FIFO + # We guarantee FIFO assert record == list(range(100)) @@ -1549,6 +1333,33 @@ def cb(i): assert counter[0] == COUNT +@pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") +def test_TrioToken_run_sync_soon_late_crash(): + # Crash after system nursery is closed -- easiest way to do that is + # from an async generator finalizer. 
+ record = [] + saved = [] + + async def agen(): + token = _core.current_trio_token() + try: + yield 1 + finally: + token.run_sync_soon(lambda: {}["nope"]) + token.run_sync_soon(lambda: record.append("2nd ran")) + + async def main(): + saved.append(agen()) + await saved[-1].asend(None) + record.append("main exiting") + + with pytest.raises(_core.TrioInternalError) as excinfo: + _core.run(main) + + assert type(excinfo.value.__cause__) is KeyError + assert record == ["main exiting", "2nd ran"] + + async def test_slow_abort_basic(): with _core.CancelScope() as scope: scope.cancel() @@ -1611,7 +1422,7 @@ async def test_task_tree_introspection(): tasks = {} nurseries = {} - async def parent(): + async def parent(task_status=_core.TASK_STATUS_IGNORED): tasks["parent"] = _core.current_task() assert tasks["parent"].child_nurseries == [] @@ -1624,7 +1435,7 @@ async def parent(): async with _core.open_nursery() as nursery: nurseries["parent"] = nursery - nursery.start_soon(child1) + await nursery.start(child1) # Upward links survive after tasks/nurseries exit assert nurseries["parent"].parent_task is tasks["parent"] @@ -1647,8 +1458,19 @@ async def child2(): assert nurseries["child1"].child_tasks == frozenset({tasks["child2"]}) assert tasks["child2"].child_nurseries == [] - async def child1(): - tasks["child1"] = _core.current_task() + async def child1(task_status=_core.TASK_STATUS_IGNORED): + me = tasks["child1"] = _core.current_task() + assert me.parent_nursery.parent_task is tasks["parent"] + assert me.parent_nursery is not nurseries["parent"] + assert me.eventual_parent_nursery is nurseries["parent"] + task_status.started() + assert me.parent_nursery is nurseries["parent"] + assert me.eventual_parent_nursery is None + + # Wait for the start() call to return and close its internal nursery, to + # ensure consistent results in child2: + await _core.wait_all_tasks_blocked() + async with _core.open_nursery() as nursery: nurseries["child1"] = nursery nursery.start_soon(child2) 
@@ -1656,6 +1478,11 @@ async def child1(): async with _core.open_nursery() as nursery: nursery.start_soon(parent) + # There are no pending starts, so no one should have a non-None + # eventual_parent_nursery + for task in tasks.values(): + assert task.eventual_parent_nursery is None + async def test_nursery_closure(): async def child1(nursery): @@ -1713,8 +1540,6 @@ async def test_current_effective_deadline(mock_clock): assert _core.current_effective_deadline() == inf -# @coroutine is deprecated since python 3.8, which is fine with us. -@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning") def test_nice_error_on_bad_calls_to_run_or_spawn(): def bad_call_run(*args): _core.run(*args) @@ -1726,66 +1551,26 @@ async def main(): _core.run(main) - class Deferred: - "Just kidding" - - with ignore_coroutine_never_awaited_warnings(): - for bad_call in bad_call_run, bad_call_spawn: - - async def f(): # pragma: no cover - pass - - with pytest.raises(TypeError) as excinfo: - bad_call(f()) - assert "expecting an async function" in str(excinfo.value) - - import asyncio - - @asyncio.coroutine - def generator_based_coro(): # pragma: no cover - yield from asyncio.sleep(1) - - with pytest.raises(TypeError) as excinfo: - bad_call(generator_based_coro()) - assert "asyncio" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - bad_call(asyncio.Future()) - assert "asyncio" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - bad_call(lambda: asyncio.Future()) - assert "asyncio" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - bad_call(Deferred()) - assert "twisted" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - bad_call(lambda: Deferred()) - assert "twisted" in str(excinfo.value) + for bad_call in bad_call_run, bad_call_spawn: - with pytest.raises(TypeError) as excinfo: - bad_call(len, [1, 2, 3]) - assert "appears to be synchronous" in str(excinfo.value) + async def f(): # pragma: no cover + 
pass - @async_generator - async def async_gen(arg): # pragma: no cover - pass + with pytest.raises(TypeError, match="expecting an async function"): + bad_call(f()) - with pytest.raises(TypeError) as excinfo: - bad_call(async_gen, 0) - msg = "expected an async function but got an async generator" - assert msg in str(excinfo.value) + async def async_gen(arg): # pragma: no cover + yield arg - # Make sure no references are kept around to keep anything alive - del excinfo + with pytest.raises( + TypeError, match="expected an async function but got an async generator" + ): + bad_call(async_gen, 0) def test_calling_asyncio_function_gives_nice_error(): async def child_xyzzy(): - import asyncio - await asyncio.Future() + await create_asyncio_future_in_new_loop() async def misguided(): await child_xyzzy() @@ -1804,9 +1589,8 @@ async def test_asyncio_function_inside_nursery_does_not_explode(): # Regression test for https://github.com/python-trio/trio/issues/552 with pytest.raises(TypeError) as excinfo: async with _core.open_nursery() as nursery: - import asyncio nursery.start_soon(sleep_forever) - await asyncio.Future() + await create_asyncio_future_in_new_loop() assert "asyncio" in str(excinfo.value) @@ -1824,14 +1608,14 @@ async def test_trivial_yields(): with _core.CancelScope() as cancel_scope: cancel_scope.cancel() - with pytest.raises(_core.MultiError) as excinfo: + with pytest.raises(MultiError) as excinfo: async with _core.open_nursery(): raise KeyError assert len(excinfo.value.exceptions) == 2 - assert {type(e) - for e in excinfo.value.exceptions} == { - KeyError, _core.Cancelled - } + assert {type(e) for e in excinfo.value.exceptions} == { + KeyError, + _core.Cancelled, + } async def test_nursery_start(autojump_clock): @@ -1843,9 +1627,7 @@ async def no_args(): # pragma: no cover with pytest.raises(TypeError): await nursery.start(no_args) - async def sleep_then_start( - seconds, *, task_status=_core.TASK_STATUS_IGNORED - ): + async def sleep_then_start(seconds, *, 
task_status=_core.TASK_STATUS_IGNORED): repr(task_status) # smoke test await sleep(seconds) task_status.started(seconds) @@ -1909,19 +1691,19 @@ async def just_started(task_status=_core.TASK_STATUS_IGNORED): # and if after the no-op started(), the child crashes, the error comes out # of start() - async def raise_keyerror_after_started( - task_status=_core.TASK_STATUS_IGNORED - ): + async def raise_keyerror_after_started(task_status=_core.TASK_STATUS_IGNORED): task_status.started() raise KeyError("whoopsiedaisy") async with _core.open_nursery() as nursery: with _core.CancelScope() as cs: cs.cancel() - with pytest.raises(_core.MultiError) as excinfo: + with pytest.raises(MultiError) as excinfo: await nursery.start(raise_keyerror_after_started) - assert {type(e) - for e in excinfo.value.exceptions} == {_core.Cancelled, KeyError} + assert {type(e) for e in excinfo.value.exceptions} == { + _core.Cancelled, + KeyError, + } # trying to start in a closed nursery raises an error immediately async with _core.open_nursery() as closed_nursery: @@ -2031,7 +1813,7 @@ async def fail(): async with _core.open_nursery() as nursery: nursery.start_soon(fail) raise StopIteration - except _core.MultiError as e: + except MultiError as e: assert tuple(map(type, e.exceptions)) == (StopIteration, ValueError) @@ -2056,32 +1838,17 @@ def __init__(self, *largs): async def _accumulate(self, f, items, i): items[i] = await f() - @aiter_compat def __aiter__(self): return self async def __anext__(self): nexts = self.nexts - items = [ - None, - ] * len(nexts) - got_stop = False - - def handle(exc): - nonlocal got_stop - if isinstance(exc, StopAsyncIteration): - got_stop = True - return None - else: # pragma: no cover - return exc - - with _core.MultiError.catch(handle): - async with _core.open_nursery() as nursery: - for i, f in enumerate(nexts): - nursery.start_soon(self._accumulate, f, items, i) + items = [None] * len(nexts) + + async with _core.open_nursery() as nursery: + for i, f in 
enumerate(nexts): + nursery.start_soon(self._accumulate, f, items, i) - if got_stop: - raise StopAsyncIteration return items result = [] @@ -2103,7 +1870,7 @@ async def my_child_task(): async with _core.open_nursery() as nursery: nursery.start_soon(my_child_task) nursery.start_soon(my_child_task) - except _core.MultiError as exc: + except MultiError as exc: first_exc = exc.exceptions[0] assert isinstance(first_exc, KeyError) # The top frame in the exception traceback should be inside the child @@ -2154,7 +1921,7 @@ async def t2(): def test_system_task_contexts(): - cvar = contextvars.ContextVar('qwilfish') + cvar = contextvars.ContextVar("qwilfish") cvar.set("water") async def system_task(): @@ -2174,11 +1941,7 @@ async def inner(): def test_Nursery_init(): - check_Nursery_error = pytest.raises( - TypeError, match='no public constructor available' - ) - - with check_Nursery_error: + with pytest.raises(TypeError): _core._run.Nursery(None, None) @@ -2189,23 +1952,17 @@ async def test_Nursery_private_init(): def test_Nursery_subclass(): - with pytest.raises( - TypeError, match='`Nursery` does not support subclassing' - ): + with pytest.raises(TypeError): class Subclass(_core._run.Nursery): pass def test_Cancelled_init(): - check_Cancelled_error = pytest.raises( - TypeError, match='no public constructor available' - ) - - with check_Cancelled_error: + with pytest.raises(TypeError): raise _core.Cancelled - with check_Cancelled_error: + with pytest.raises(TypeError): _core.Cancelled() # private constructor should not raise @@ -2214,22 +1971,18 @@ def test_Cancelled_init(): def test_Cancelled_str(): cancelled = _core.Cancelled._create() - assert str(cancelled) == 'Cancelled' + assert str(cancelled) == "Cancelled" def test_Cancelled_subclass(): - with pytest.raises( - TypeError, match='`Cancelled` does not support subclassing' - ): + with pytest.raises(TypeError): class Subclass(_core.Cancelled): pass def test_CancelScope_subclass(): - with pytest.raises( - TypeError, 
match='`CancelScope` does not support subclassing' - ): + with pytest.raises(TypeError): class Subclass(_core.CancelScope): pass @@ -2242,11 +1995,25 @@ def test_sniffio_integration(): async def check_inside_trio(): assert sniffio.current_async_library() == "trio" + def check_function_returning_coroutine(): + assert sniffio.current_async_library() == "trio" + return check_inside_trio() + _core.run(check_inside_trio) with pytest.raises(sniffio.AsyncLibraryNotFoundError): sniffio.current_async_library() + async def check_new_task_resets_sniffio_library(): + sniffio.current_async_library_cvar.set("nullio") + _core.spawn_system_task(check_inside_trio) + async with _core.open_nursery() as nursery: + nursery.start_soon(check_inside_trio) + nursery.start_soon(check_function_returning_coroutine) + assert sniffio.current_async_library() == "nullio" + + _core.run(check_new_task_resets_sniffio_library) + async def test_Task_custom_sleep_data(): task = _core.current_task() @@ -2276,9 +2043,7 @@ async def detachable_coroutine(task_outcome, yield_value): await async_yield(yield_value) async with _core.open_nursery() as nursery: - nursery.start_soon( - detachable_coroutine, outcome.Value(None), "I'm free!" - ) + nursery.start_soon(detachable_coroutine, outcome.Value(None), "I'm free!") # If we get here then Trio thinks the task has exited... 
but the coroutine # is still iterable @@ -2293,9 +2058,7 @@ async def detachable_coroutine(task_outcome, yield_value): pdco_outcome = None with pytest.raises(KeyError): async with _core.open_nursery() as nursery: - nursery.start_soon( - detachable_coroutine, outcome.Error(KeyError()), "uh oh" - ) + nursery.start_soon(detachable_coroutine, outcome.Error(KeyError()), "uh oh") throw_in = ValueError() assert task.coro.throw(throw_in) == "uh oh" assert pdco_outcome == outcome.Error(throw_in) @@ -2305,9 +2068,7 @@ async def detachable_coroutine(task_outcome, yield_value): async def bad_detach(): async with _core.open_nursery(): with pytest.raises(RuntimeError) as excinfo: - await _core.permanently_detach_coroutine_object( - outcome.Value(None) - ) + await _core.permanently_detach_coroutine_object(outcome.Value(None)) assert "open nurser" in str(excinfo.value) async with _core.open_nursery() as nursery: @@ -2338,9 +2099,7 @@ def abort_fn(_): # pragma: no cover await async_yield(2) with pytest.raises(RuntimeError) as excinfo: - await _core.reattach_detached_coroutine_object( - unrelated_task, None - ) + await _core.reattach_detached_coroutine_object(unrelated_task, None) assert "does not match" in str(excinfo.value) await _core.reattach_detached_coroutine_object(task, "byebye") @@ -2392,27 +2151,15 @@ def abort_fn(_): assert abort_fn_called +@restore_unraisablehook() def test_async_function_implemented_in_C(): # These used to crash because we'd try to mutate the coroutine object's # cr_frame, but C functions don't have Python frames. 
- ns = {"_core": _core} - try: - exec( - dedent( - """ - async def agen_fn(record): - assert not _core.currently_ki_protected() - record.append("the generator ran") - yield - """ - ), - ns, - ) - except SyntaxError: - pytest.skip("Requires Python 3.6+") - else: - agen_fn = ns["agen_fn"] + async def agen_fn(record): + assert not _core.currently_ki_protected() + record.append("the generator ran") + yield run_record = [] agen = agen_fn(run_record) @@ -2437,3 +2184,254 @@ async def test_very_deep_cancel_scope_nesting(): for _ in range(5000): exit_stack.enter_context(_core.CancelScope()) outermost_scope.cancel() + + +async def test_cancel_scope_deadline_duplicates(): + # This exercises an assert in Deadlines._prune, by intentionally creating + # duplicate entries in the deadline heap. + now = _core.current_time() + with _core.CancelScope() as cscope: + for _ in range(DEADLINE_HEAP_MIN_PRUNE_THRESHOLD * 2): + cscope.deadline = now + 9998 + cscope.deadline = now + 9999 + await sleep(0.01) + + +@pytest.mark.skipif( + sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" +) +async def test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage(): + # https://github.com/python-trio/trio/issues/1770 + gc.collect() + + async def do_a_cancel(): + with _core.CancelScope() as cscope: + cscope.cancel() + await sleep_forever() + + async def crasher(): + raise ValueError + + old_flags = gc.get_debug() + try: + gc.collect() + gc.set_debug(gc.DEBUG_SAVEALL) + + # cover outcome.Error.unwrap + # (See https://github.com/python-trio/outcome/pull/29) + await do_a_cancel() + # cover outcome.Error.unwrap if unrolled_run hangs on to exception refs + # (See https://github.com/python-trio/trio/pull/1864) + await do_a_cancel() + + with pytest.raises(ValueError): + async with _core.open_nursery() as nursery: + # cover NurseryManager.__aexit__ + nursery.start_soon(crasher) + + gc.collect() + assert not gc.garbage + finally: + gc.set_debug(old_flags) + 
gc.garbage.clear() + + +@pytest.mark.skipif( + sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" +) +async def test_cancel_scope_exit_doesnt_create_cyclic_garbage(): + # https://github.com/python-trio/trio/pull/2063 + gc.collect() + + async def crasher(): + raise ValueError + + old_flags = gc.get_debug() + try: + with pytest.raises(ValueError), _core.CancelScope() as outer: + async with _core.open_nursery() as nursery: + gc.collect() + gc.set_debug(gc.DEBUG_SAVEALL) + # One child that gets cancelled by the outer scope + nursery.start_soon(sleep_forever) + outer.cancel() + # And one that raises a different error + nursery.start_soon(crasher) + # so that outer filters a Cancelled from the MultiError and + # covers CancelScope.__exit__ (and NurseryManager.__aexit__) + # (See https://github.com/python-trio/trio/pull/2063) + + gc.collect() + assert not gc.garbage + finally: + gc.set_debug(old_flags) + gc.garbage.clear() + + +@pytest.mark.xfail( + sys.version_info >= (3, 12), + reason="Waiting on https://github.com/python/cpython/issues/100964", +) +@pytest.mark.skipif( + sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" +) +async def test_nursery_cancel_doesnt_create_cyclic_garbage(): + # https://github.com/python-trio/trio/issues/1770#issuecomment-730229423 + def toggle_collected(): + nonlocal collected + collected = True + + collected = False + gc.collect() + old_flags = gc.get_debug() + try: + gc.set_debug(0) + gc.collect() + gc.set_debug(gc.DEBUG_SAVEALL) + + # cover Nursery._nested_child_finished + async with _core.open_nursery() as nursery: + nursery.cancel_scope.cancel() + + weakref.finalize(nursery, toggle_collected) + del nursery + # a checkpoint clears the nursery from the internals, apparently + # TODO: stop event loop from hanging on to the nursery at this point + await _core.checkpoint() + + assert collected + gc.collect() + assert not gc.garbage + finally: + gc.set_debug(old_flags) + 
gc.garbage.clear() + + +@pytest.mark.skipif( + sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" +) +async def test_locals_destroyed_promptly_on_cancel(): + destroyed = False + + def finalizer(): + nonlocal destroyed + destroyed = True + + class A: + pass + + async def task(): + a = A() + weakref.finalize(a, finalizer) + await _core.checkpoint() + + async with _core.open_nursery() as nursery: + nursery.start_soon(task) + nursery.cancel_scope.cancel() + assert destroyed + + +def test_run_strict_exception_groups(): + """ + Test that nurseries respect the global context setting of strict_exception_groups. + """ + + async def main(): + async with _core.open_nursery(): + raise Exception("foo") + + with pytest.raises(MultiError) as exc: + _core.run(main, strict_exception_groups=True) + + assert len(exc.value.exceptions) == 1 + assert type(exc.value.exceptions[0]) is Exception + assert exc.value.exceptions[0].args == ("foo",) + + +def test_run_strict_exception_groups_nursery_override(): + """ + Test that a nursery can override the global context setting of + strict_exception_groups. + """ + + async def main(): + async with _core.open_nursery(strict_exception_groups=False): + raise Exception("foo") + + with pytest.raises(Exception, match="foo"): + _core.run(main, strict_exception_groups=True) + + +async def test_nursery_strict_exception_groups(): + """Test that strict exception groups can be enabled on a per-nursery basis.""" + with pytest.raises(MultiError) as exc: + async with _core.open_nursery(strict_exception_groups=True): + raise Exception("foo") + + assert len(exc.value.exceptions) == 1 + assert type(exc.value.exceptions[0]) is Exception + assert exc.value.exceptions[0].args == ("foo",) + + +async def test_nursery_collapse_strict(): + """ + Test that a single exception from a nested nursery with strict semantics doesn't get + collapsed when CancelledErrors are stripped from it. 
+ """ + + async def raise_error(): + raise RuntimeError("test error") + + with pytest.raises(MultiError) as exc: + async with _core.open_nursery() as nursery: + nursery.start_soon(sleep_forever) + nursery.start_soon(raise_error) + async with _core.open_nursery(strict_exception_groups=True) as nursery2: + nursery2.start_soon(sleep_forever) + nursery2.start_soon(raise_error) + nursery.cancel_scope.cancel() + + exceptions = exc.value.exceptions + assert len(exceptions) == 2 + assert isinstance(exceptions[0], RuntimeError) + assert isinstance(exceptions[1], MultiError) + assert len(exceptions[1].exceptions) == 1 + assert isinstance(exceptions[1].exceptions[0], RuntimeError) + + +async def test_nursery_collapse_loose(): + """ + Test that a single exception from a nested nursery with loose semantics gets + collapsed when CancelledErrors are stripped from it. + """ + + async def raise_error(): + raise RuntimeError("test error") + + with pytest.raises(MultiError) as exc: + async with _core.open_nursery() as nursery: + nursery.start_soon(sleep_forever) + nursery.start_soon(raise_error) + async with _core.open_nursery() as nursery2: + nursery2.start_soon(sleep_forever) + nursery2.start_soon(raise_error) + nursery.cancel_scope.cancel() + + exceptions = exc.value.exceptions + assert len(exceptions) == 2 + assert isinstance(exceptions[0], RuntimeError) + assert isinstance(exceptions[1], RuntimeError) + + +async def test_cancel_scope_no_cancellederror(): + """ + Test that when a cancel scope encounters an exception group that does NOT contain + a Cancelled exception, it will NOT set the ``cancelled_caught`` flag. 
+ """ + + with pytest.raises(ExceptionGroup): + with _core.CancelScope() as scope: + scope.cancel() + raise ExceptionGroup("test", [RuntimeError(), RuntimeError()]) + + assert not scope.cancelled_caught diff --git a/trio/_core/_tests/test_thread_cache.py b/trio/_core/_tests/test_thread_cache.py new file mode 100644 index 0000000000..de78443f4e --- /dev/null +++ b/trio/_core/_tests/test_thread_cache.py @@ -0,0 +1,187 @@ +import threading +import time +from contextlib import contextmanager +from queue import Queue + +import pytest + +from .. import _thread_cache +from .._thread_cache import ThreadCache, start_thread_soon +from .tutil import gc_collect_harder, slow + + +def test_thread_cache_basics(): + q = Queue() + + def fn(): + raise RuntimeError("hi") + + def deliver(outcome): + q.put(outcome) + + start_thread_soon(fn, deliver) + + outcome = q.get() + with pytest.raises(RuntimeError, match="hi"): + outcome.unwrap() + + +def test_thread_cache_deref(): + res = [False] + + class del_me: + def __call__(self): + return 42 + + def __del__(self): + res[0] = True + + q = Queue() + + def deliver(outcome): + q.put(outcome) + + start_thread_soon(del_me(), deliver) + outcome = q.get() + assert outcome.unwrap() == 42 + + gc_collect_harder() + assert res[0] + + +@slow +def test_spawning_new_thread_from_deliver_reuses_starting_thread(): + # We know that no-one else is using the thread cache, so if we keep + # submitting new jobs the instant the previous one is finished, we should + # keep getting the same thread over and over. This tests both that the + # thread cache is LIFO, and that threads can be assigned new work *before* + # deliver exits. + + # Make sure there are a few threads running, so if we weren't LIFO then we + # could grab the wrong one. 
+ q = Queue() + COUNT = 5 + for _ in range(COUNT): + start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result)) + for _ in range(COUNT): + q.get().unwrap() + + seen_threads = set() + done = threading.Event() + + def deliver(n, _): + print(n) + seen_threads.add(threading.current_thread()) + if n == 0: + done.set() + else: + start_thread_soon(lambda: None, lambda _: deliver(n - 1, _)) + + start_thread_soon(lambda: None, lambda _: deliver(5, _)) + + done.wait() + + assert len(seen_threads) == 1 + + +@slow +def test_idle_threads_exit(monkeypatch): + # Temporarily set the idle timeout to something tiny, to speed up the + # test. (But non-zero, so that the worker loop will at least yield the + # CPU.) + monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) + + q = Queue() + start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread())) + seen_thread = q.get() + # Since the idle timeout is 0, after sleeping for 1 second, the thread + # should have exited + time.sleep(1) + assert not seen_thread.is_alive() + + +@contextmanager +def _join_started_threads(): + before = frozenset(threading.enumerate()) + try: + yield + finally: + for thread in threading.enumerate(): + if thread not in before: + thread.join(timeout=1.0) + assert not thread.is_alive() + + +def test_race_between_idle_exit_and_job_assignment(monkeypatch): + # This is a lock where the first few times you try to acquire it with a + # timeout, it waits until the lock is available and then pretends to time + # out. Using this in our thread cache implementation causes the following + # sequence: + # + # 1. start_thread_soon grabs the worker thread, assigns it a job, and + # releases its lock. + # 2. The worker thread wakes up (because the lock has been released), but + # the JankyLock lies to it and tells it that the lock timed out. So the + # worker thread tries to exit. + # 3. 
The worker thread checks for the race between exiting and being + # assigned a job, and discovers that it *is* in the process of being + # assigned a job, so it loops around and tries to acquire the lock + # again. + # 4. Eventually the JankyLock admits that the lock is available, and + # everything proceeds as normal. + + class JankyLock: + def __init__(self): + self._lock = threading.Lock() + self._counter = 3 + + def acquire(self, timeout=-1): + got_it = self._lock.acquire(timeout=timeout) + if timeout == -1: + return True + elif got_it: + if self._counter > 0: + self._counter -= 1 + self._lock.release() + return False + return True + else: + return False + + def release(self): + self._lock.release() + + monkeypatch.setattr(_thread_cache, "Lock", JankyLock) + + with _join_started_threads(): + tc = ThreadCache() + done = threading.Event() + tc.start_thread_soon(lambda: None, lambda _: done.set()) + done.wait() + # Let's kill the thread we started, so it doesn't hang around until the + # test suite finishes. Doesn't really do any harm, but it can be confusing + # to see it in debug output. 
+ monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001) + tc.start_thread_soon(lambda: None, lambda _: None) + + +def test_raise_in_deliver(capfd): + seen_threads = set() + + def track_threads(): + seen_threads.add(threading.current_thread()) + + def deliver(_): + done.set() + raise RuntimeError("don't do this") + + done = threading.Event() + start_thread_soon(track_threads, deliver) + done.wait() + done = threading.Event() + start_thread_soon(track_threads, lambda _: done.set()) + done.wait() + assert len(seen_threads) == 1 + err = capfd.readouterr().err + assert "don't do this" in err + assert "delivering result" in err diff --git a/trio/_core/tests/test_tutil.py b/trio/_core/_tests/test_tutil.py similarity index 100% rename from trio/_core/tests/test_tutil.py rename to trio/_core/_tests/test_tutil.py diff --git a/trio/_core/tests/test_unbounded_queue.py b/trio/_core/_tests/test_unbounded_queue.py similarity index 100% rename from trio/_core/tests/test_unbounded_queue.py rename to trio/_core/_tests/test_unbounded_queue.py diff --git a/trio/_core/tests/test_windows.py b/trio/_core/_tests/test_windows.py similarity index 70% rename from trio/_core/tests/test_windows.py rename to trio/_core/_tests/test_windows.py index 2fb8a97092..0dac94543c 100644 --- a/trio/_core/tests/test_windows.py +++ b/trio/_core/_tests/test_windows.py @@ -4,25 +4,28 @@ import pytest -on_windows = (os.name == "nt") +on_windows = os.name == "nt" # Mark all the tests in this file as being windows-only pytestmark = pytest.mark.skipif(not on_windows, reason="windows only") -from .tutil import slow, gc_collect_harder -from ... import _core, sleep, move_on_after +from ... 
import _core, sleep from ...testing import wait_all_tasks_blocked +from .tutil import gc_collect_harder, restore_unraisablehook, slow + if on_windows: from .._windows_cffi import ( - ffi, kernel32, INVALID_HANDLE_VALUE, raise_winerror, FileFlags + INVALID_HANDLE_VALUE, + FileFlags, + ffi, + kernel32, + raise_winerror, ) # The undocumented API that this is testing should be changed to stop using # UnboundedQueue (or just removed until we have time to redo it), but until # then we filter out the warning. -@pytest.mark.filterwarnings( - "ignore:.*UnboundedQueue:trio.TrioDeprecationWarning" -) +@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning") async def test_completion_key_listen(): async def post(key): iocp = ffi.cast("HANDLE", _core.current_iocp()) @@ -30,9 +33,7 @@ async def post(key): print("post", i) if i % 3 == 0: await _core.checkpoint() - success = kernel32.PostQueuedCompletionStatus( - iocp, i, key, ffi.NULL - ) + success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL) assert success with _core.monitor_completion_key() as (key, queue): @@ -80,9 +81,7 @@ async def test_readinto_overlapped(): async def read_region(start, end): await _core.readinto_overlapped( - handle, - buffer_view[start:end], - start, + handle, buffer_view[start:end], start ) _core.register_with_iocp(handle) @@ -92,7 +91,7 @@ async def read_region(start, end): assert buffer == data - with pytest.raises(BufferError): + with pytest.raises((BufferError, TypeError)): await _core.readinto_overlapped(handle, b"immutable") finally: kernel32.CloseHandle(handle) @@ -100,8 +99,8 @@ async def read_region(start, end): @contextmanager def pipe_with_overlapped_read(): - from asyncio.windows_utils import pipe import msvcrt + from asyncio.windows_utils import pipe read_handle, write_handle = pipe(overlapped=(True, False)) try: @@ -112,6 +111,7 @@ def pipe_with_overlapped_read(): kernel32.CloseHandle(ffi.cast("HANDLE", write_handle)) +@restore_unraisablehook() def 
test_forgot_to_register_with_iocp(): with pipe_with_overlapped_read() as (write_fp, read_handle): with write_fp: @@ -124,10 +124,7 @@ async def main(): try: async with _core.open_nursery() as nursery: nursery.start_soon( - _core.readinto_overlapped, - read_handle, - target, - name="xyz" + _core.readinto_overlapped, read_handle, target, name="xyz" ) await wait_all_tasks_blocked() nursery.cancel_scope.cancel() @@ -175,3 +172,47 @@ async def test_too_late_to_cancel(): # fallback completion that was posted when CancelIoEx failed. assert await _core.readinto_overlapped(read_handle, target) == 6 assert target[:6] == b"test2\n" + + +def test_lsp_that_hooks_select_gives_good_error(monkeypatch): + from .. import _io_windows + from .._windows_cffi import WSAIoctls, _handle + + def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): + if hasattr(sock, "fileno"): # pragma: no branch + sock = sock.fileno() + if which == WSAIoctls.SIO_BSP_HANDLE_SELECT: + return _handle(sock + 1) + else: + return _handle(sock) + + monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) + with pytest.raises( + RuntimeError, match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ" + ): + _core.run(sleep, 0) + + +def test_lsp_that_completely_hides_base_socket_gives_good_error(monkeypatch): + # This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns + # self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns + # self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to + # make sure we get an error rather than an infinite loop. + + from .. 
import _io_windows + from .._windows_cffi import WSAIoctls, _handle + + def patched_get_underlying(sock, *, which=WSAIoctls.SIO_BASE_HANDLE): + if hasattr(sock, "fileno"): # pragma: no branch + sock = sock.fileno() + if which == WSAIoctls.SIO_BASE_HANDLE: + raise OSError("nope") + else: + return _handle(sock) + + monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying) + with pytest.raises( + RuntimeError, + match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff", + ): + _core.run(sleep, 0) diff --git a/trio/_core/_tests/tutil.py b/trio/_core/_tests/tutil.py new file mode 100644 index 0000000000..b3aa73fb7d --- /dev/null +++ b/trio/_core/_tests/tutil.py @@ -0,0 +1,120 @@ +# Utilities for testing +import asyncio +import gc +import os +import socket as stdlib_socket +import sys +import warnings +from contextlib import closing, contextmanager +from typing import TYPE_CHECKING + +import pytest + +# See trio/_tests/conftest.py for the other half of this +from trio._tests.pytest_plugin import RUN_SLOW + +slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests") + +# PyPy 7.2 was released with a bug that just never called the async +# generator 'firstiter' hook at all. This impacts tests of end-of-run +# finalization (nothing gets added to runner.asyncgens) and tests of +# "foreign" async generator behavior (since the firstiter hook is what +# marks the asyncgen as foreign), but most tests of GC-mediated +# finalization still work. +buggy_pypy_asyncgens = ( + not TYPE_CHECKING + and sys.implementation.name == "pypy" + and sys.pypy_version_info < (7, 3) +) + +try: + s = stdlib_socket.socket(stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0) +except OSError: # pragma: no cover + # Some systems don't even support creating an IPv6 socket, let alone + # binding it. 
(ex: Linux with 'ipv6.disable=1' in the kernel command line) + # We don't have any of those in our CI, and there's nothing that gets + # tested _only_ if can_create_ipv6 = False, so we'll just no-cover this. + can_create_ipv6 = False + can_bind_ipv6 = False +else: + can_create_ipv6 = True + with s: + try: + s.bind(("::1", 0)) + except OSError: # pragma: no cover # since support for 3.7 was removed + can_bind_ipv6 = False + else: + can_bind_ipv6 = True + +creates_ipv6 = pytest.mark.skipif(not can_create_ipv6, reason="need IPv6") +binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6") + + +def gc_collect_harder(): + # In the test suite we sometimes want to call gc.collect() to make sure + # that any objects with noisy __del__ methods (e.g. unawaited coroutines) + # get collected before we continue, so their noise doesn't leak into + # unrelated tests. + # + # On PyPy, coroutine objects (for example) can survive at least 1 round of + # garbage collection, because executing their __del__ method to print the + # warning can cause them to be resurrected. So we call collect a few times + # to make sure. + for _ in range(5): + gc.collect() + + +# Some of our tests need to leak coroutines, and thus trigger the +# "RuntimeWarning: coroutine '...' was never awaited" message. This context +# manager should be used anywhere this happens to hide those messages, because +# when expected they're clutter. +@contextmanager +def ignore_coroutine_never_awaited_warnings(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited") + try: + yield + finally: + # Make sure to trigger any coroutine __del__ methods now, before + # we leave the context manager. 
+ gc_collect_harder() + + +def _noop(*args, **kwargs): + pass + + +@contextmanager +def restore_unraisablehook(): + sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook + try: + yield + finally: + sys.unraisablehook = prev + + +# template is like: +# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3] +def check_sequence_matches(seq, template): + i = 0 + for pattern in template: + if not isinstance(pattern, set): + pattern = {pattern} + got = set(seq[i : i + len(pattern)]) + assert got == pattern + i += len(got) + + +# https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350 +skip_if_fbsd_pipes_broken = pytest.mark.skipif( + sys.platform != "win32" # prevent mypy from complaining about missing uname + and hasattr(os, "uname") + and os.uname().sysname == "FreeBSD" + and os.uname().release[:4] < "12.2", + reason="hangs on FreeBSD 12.1 and earlier, due to FreeBSD bug #246350", +) + + +def create_asyncio_future_in_new_loop(): + with closing(asyncio.new_event_loop()) as loop: + return loop.create_future() diff --git a/trio/_core/_thread_cache.py b/trio/_core/_thread_cache.py new file mode 100644 index 0000000000..cc272fc92c --- /dev/null +++ b/trio/_core/_thread_cache.py @@ -0,0 +1,263 @@ +import ctypes +import ctypes.util +import sys +import traceback +from functools import partial +from itertools import count +from threading import Lock, Thread +from typing import Callable, Optional, Tuple + +import outcome + + +def _to_os_thread_name(name: str) -> bytes: + # ctypes handles the trailing \00 + return name.encode("ascii", errors="replace")[:15] + + +# used to construct the method used to set os thread name, or None, depending on platform. +# called once on import +def get_os_thread_name_func() -> Optional[Callable[[Optional[int], str], None]]: + def namefunc(setname: Callable[[int, bytes], int], ident: Optional[int], name: str): + # Thread.ident is None "if it has not been started". Unclear if that can happen + # with current usage. 
+ if ident is not None: # pragma: no cover + setname(ident, _to_os_thread_name(name)) + + # namefunc on mac also takes an ident, even if pthread_setname_np doesn't/can't use it + # so the caller don't need to care about platform. + def darwin_namefunc( + setname: Callable[[bytes], int], ident: Optional[int], name: str + ): + # I don't know if Mac can rename threads that hasn't been started, but default + # to no to be on the safe side. + if ident is not None: # pragma: no cover + setname(_to_os_thread_name(name)) + + # find the pthread library + # this will fail on windows + libpthread_path = ctypes.util.find_library("pthread") + if not libpthread_path: + return None + + # Sometimes windows can find the path, but gives a permission error when + # accessing it. Catching a wider exception in case of more esoteric errors. + # https://github.com/python-trio/trio/issues/2688 + try: + libpthread = ctypes.CDLL(libpthread_path) + except Exception: # pragma: no cover + return None + + # get the setname method from it + # afaik this should never fail + pthread_setname_np = getattr(libpthread, "pthread_setname_np", None) + if pthread_setname_np is None: # pragma: no cover + return None + + # specify function prototype + pthread_setname_np.restype = ctypes.c_int + + # on mac OSX pthread_setname_np does not take a thread id, + # it only lets threads name themselves, which is not a problem for us. + # Just need to make sure to call it correctly + if sys.platform == "darwin": + pthread_setname_np.argtypes = [ctypes.c_char_p] + return partial(darwin_namefunc, pthread_setname_np) + + # otherwise assume linux parameter conventions. 
Should also work on *BSD + pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + return partial(namefunc, pthread_setname_np) + + +# construct os thread name method +set_os_thread_name = get_os_thread_name_func() + +# The "thread cache" is a simple unbounded thread pool, i.e., it automatically +# spawns as many threads as needed to handle all the requests its given. Its +# only purpose is to cache worker threads so that they don't have to be +# started from scratch every time we want to delegate some work to a thread. +# It's expected that some higher-level code will track how many threads are in +# use to avoid overwhelming the system (e.g. the limiter= argument to +# trio.to_thread.run_sync). +# +# To maximize sharing, there's only one thread cache per process, even if you +# have multiple calls to trio.run. +# +# Guarantees: +# +# It's safe to call start_thread_soon simultaneously from +# multiple threads. +# +# Idle threads are chosen in LIFO order, i.e. we *don't* spread work evenly +# over all threads. Instead we try to let some threads do most of the work +# while others sit idle as much as possible. Compared to FIFO, this has better +# memory cache behavior, and it makes it easier to detect when we have too +# many threads, so idle ones can exit. +# +# This code assumes that 'dict' has the following properties: +# +# - __setitem__, __delitem__, and popitem are all thread-safe and atomic with +# respect to each other. This is guaranteed by the GIL. +# +# - popitem returns the most-recently-added item (i.e., __setitem__ + popitem +# give you a LIFO queue). This relies on dicts being insertion-ordered, like +# they are in py36+. + +# How long a thread will idle waiting for new work before gives up and exits. +# This value is pretty arbitrary; I don't think it matters too much. 
+IDLE_TIMEOUT = 10 # seconds + +name_counter = count() + + +class WorkerThread: + def __init__(self, thread_cache): + self._job: Optional[Tuple[Callable, Callable, str]] = None + self._thread_cache = thread_cache + # This Lock is used in an unconventional way. + # + # "Unlocked" means we have a pending job that's been assigned to us; + # "locked" means that we don't. + # + # Initially we have no job, so it starts out in locked state. + self._worker_lock = Lock() + self._worker_lock.acquire() + self._default_name = f"Trio thread {next(name_counter)}" + + self._thread = Thread(target=self._work, name=self._default_name, daemon=True) + + if set_os_thread_name: + set_os_thread_name(self._thread.ident, self._default_name) + self._thread.start() + + def _handle_job(self): + # Handle job in a separate method to ensure user-created + # objects are cleaned up in a consistent manner. + assert self._job is not None + fn, deliver, name = self._job + self._job = None + + # set name + if name is not None: + self._thread.name = name + if set_os_thread_name: + set_os_thread_name(self._thread.ident, name) + result = outcome.capture(fn) + + # reset name if it was changed + if name is not None: + self._thread.name = self._default_name + if set_os_thread_name: + set_os_thread_name(self._thread.ident, self._default_name) + + # Tell the cache that we're available to be assigned a new + # job. We do this *before* calling 'deliver', so that if + # 'deliver' triggers a new job, it can be assigned to us + # instead of spawning a new thread. + self._thread_cache._idle_workers[self] = None + try: + deliver(result) + except BaseException as e: + print("Exception while delivering result of thread", file=sys.stderr) + traceback.print_exception(type(e), e, e.__traceback__) + + def _work(self): + while True: + if self._worker_lock.acquire(timeout=IDLE_TIMEOUT): + # We got a job + self._handle_job() + else: + # Timeout acquiring lock, so we can probably exit. 
But, + # there's a race condition: we might be assigned a job *just* + # as we're about to exit. So we have to check. + try: + del self._thread_cache._idle_workers[self] + except KeyError: + # Someone else removed us from the idle worker queue, so + # they must be in the process of assigning us a job - loop + # around and wait for it. + continue + else: + # We successfully removed ourselves from the idle + # worker queue, so no more jobs are incoming; it's safe to + # exit. + return + + +class ThreadCache: + def __init__(self): + self._idle_workers = {} + + def start_thread_soon(self, fn, deliver, name: Optional[str] = None): + try: + worker, _ = self._idle_workers.popitem() + except KeyError: + worker = WorkerThread(self) + worker._job = (fn, deliver, name) + worker._worker_lock.release() + + +THREAD_CACHE = ThreadCache() + + +def start_thread_soon(fn, deliver, name: Optional[str] = None): + """Runs ``deliver(outcome.capture(fn))`` in a worker thread. + + Generally ``fn`` does some blocking work, and ``deliver`` delivers the + result back to whoever is interested. + + This is a low-level, no-frills interface, very similar to using + `threading.Thread` to spawn a thread directly. The main difference is + that this function tries to re-use threads when possible, so it can be + a bit faster than `threading.Thread`. + + Worker threads have the `~threading.Thread.daemon` flag set, which means + that if your main thread exits, worker threads will automatically be + killed. If you want to make sure that your ``fn`` runs to completion, then + you should make sure that the main thread remains alive until ``deliver`` + is called. + + It is safe to call this function simultaneously from multiple threads. + + Args: + + fn (sync function): Performs arbitrary blocking work. + + deliver (sync function): Takes the `outcome.Outcome` of ``fn``, and + delivers it. 
*Must not block.* + + Because worker threads are cached and reused for multiple calls, neither + function should mutate thread-level state, like `threading.local` objects + – or if they do, they should be careful to revert their changes before + returning. + + Note: + + The split between ``fn`` and ``deliver`` serves two purposes. First, + it's convenient, since most callers need something like this anyway. + + Second, it avoids a small race condition that could cause too many + threads to be spawned. Consider a program that wants to run several + jobs sequentially on a thread, so the main thread submits a job, waits + for it to finish, submits another job, etc. In theory, this program + should only need one worker thread. But what could happen is: + + 1. Worker thread: First job finishes, and calls ``deliver``. + + 2. Main thread: receives notification that the job finished, and calls + ``start_thread_soon``. + + 3. Main thread: sees that no worker threads are marked idle, so spawns + a second worker thread. + + 4. Original worker thread: marks itself as idle. + + To avoid this, threads mark themselves as idle *before* calling + ``deliver``. + + Is this potential extra thread a major problem? Maybe not, but it's + easy enough to avoid, and we figure that if the user is trying to + limit how many threads they're using then it's polite to respect that. + + """ + THREAD_CACHE.start_thread_soon(fn, deliver, name) diff --git a/trio/_core/_traps.py b/trio/_core/_traps.py index a24a2f1742..08a8ceac01 100644 --- a/trio/_core/_traps.py +++ b/trio/_core/_traps.py @@ -1,7 +1,8 @@ # These are the only functions that ever yield back to the task runner. 
-import types import enum +import types +from typing import Any, Callable, NoReturn import attr import outcome @@ -37,7 +38,7 @@ async def cancel_shielded_checkpoint(): Equivalent to (but potentially more efficient than):: with trio.CancelScope(shield=True): - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() """ return (await _async_yield(CancelShieldedCheckpoint)).unwrap() @@ -53,6 +54,7 @@ class Abort(enum.Enum): FAILED """ + SUCCEEDED = 1 FAILED = 2 @@ -63,11 +65,16 @@ class WaitTaskRescheduled: abort_func = attr.ib() -async def wait_task_rescheduled(abort_func): +RaiseCancelT = Callable[[], NoReturn] # TypeAlias + + +# Should always return the type a Task "expects", unless you willfully reschedule it +# with a bad value. +async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any: """Put the current task to sleep, with cancellation support. This is the lowest-level API for blocking in Trio. Every time a - :class:`~trio.hazmat.Task` blocks, it does so by calling this function + :class:`~trio.lowlevel.Task` blocks, it does so by calling this function (usually indirectly via some higher-level API). This is a tricky interface with no guard rails. If you can use @@ -96,7 +103,7 @@ async def wait_task_rescheduled(abort_func): def abort_func(raise_cancel): ... - return trio.hazmat.Abort.SUCCEEDED # or FAILED + return trio.lowlevel.Abort.SUCCEEDED # or FAILED It should attempt to clean up any state associated with this call, and in particular, arrange that :func:`reschedule` will *not* be called @@ -127,7 +134,7 @@ def abort_func(raise_cancel): # Catch the exception from raise_cancel and inject it into the task. # (This is what Trio does automatically for you if you return # Abort.SUCCEEDED.) 
- trio.hazmat.reschedule(task, outcome.capture(raise_cancel)) + trio.lowlevel.reschedule(task, outcome.capture(raise_cancel)) # Option 2: # wait to be woken by "someone", and then decide whether to raise @@ -137,7 +144,7 @@ def abort(inner_raise_cancel): nonlocal outer_raise_cancel outer_raise_cancel = inner_raise_cancel TRY_TO_CANCEL_OPERATION() - return trio.hazmat.Abort.FAILED + return trio.lowlevel.Abort.FAILED await wait_task_rescheduled(abort) if OPERATION_WAS_SUCCESSFULLY_CANCELLED: # raises the error diff --git a/trio/_core/_unbounded_queue.py b/trio/_core/_unbounded_queue.py index f5e2dda5c3..9c747749b4 100644 --- a/trio/_core/_unbounded_queue.py +++ b/trio/_core/_unbounded_queue.py @@ -1,10 +1,8 @@ import attr from .. import _core -from .._util import aiter_compat from .._deprecate import deprecated - -__all__ = ["UnboundedQueue"] +from .._util import Final @attr.s(frozen=True) @@ -13,7 +11,7 @@ class _UnboundedQueueStats: tasks_waiting = attr.ib() -class UnboundedQueue: +class UnboundedQueue(metaclass=Final): """An unbounded queue suitable for certain unusual forms of inter-task communication. @@ -42,11 +40,12 @@ class UnboundedQueue: ... """ + @deprecated( "0.9.0", issue=497, - thing="trio.hazmat.UnboundedQueue", - instead="trio.open_memory_channel(math.inf)" + thing="trio.lowlevel.UnboundedQueue", + instead="trio.open_memory_channel(math.inf)", ) def __init__(self): self._lot = _core.ParkingLot() @@ -55,12 +54,10 @@ def __init__(self): self._can_get = False def __repr__(self): - return "".format(len(self._data)) + return f"" def qsize(self): - """Returns the number of items currently in the queue. 
- - """ + """Returns the number of items currently in the queue.""" return len(self._data) def empty(self): @@ -142,11 +139,9 @@ def statistics(self): """ return _UnboundedQueueStats( - qsize=len(self._data), - tasks_waiting=self._lot.statistics().tasks_waiting + qsize=len(self._data), tasks_waiting=self._lot.statistics().tasks_waiting ) - @aiter_compat def __aiter__(self): return self diff --git a/trio/_core/_wakeup_socketpair.py b/trio/_core/_wakeup_socketpair.py index 0c37928a55..51a80ef024 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/trio/_core/_wakeup_socketpair.py @@ -1,16 +1,10 @@ -import socket -import sys -from contextlib import contextmanager import signal +import socket +import warnings from .. import _core from .._util import is_main_thread -if sys.version_info >= (3, 7): - HAVE_WARN_ON_FULL_BUFFER = True -else: - HAVE_WARN_ON_FULL_BUFFER = False - class WakeupSocketpair: def __init__(self): @@ -25,21 +19,15 @@ def __init__(self): # Windows 10: 525347 # Windows you're weird. (And on Windows setting SNDBUF to 0 makes send # blocking, even on non-blocking sockets, so don't do that.) - # - # But, if we're on an old Python and can't control the signal module's - # warn-on-full-buffer behavior, then we need to leave things alone, so - # the signal module won't spam the console with spurious warnings. - if HAVE_WARN_ON_FULL_BUFFER: - self.wakeup_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1) - self.write_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1) + self.wakeup_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1) + self.write_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1) # On Windows this is a TCP socket so this might matter. On other # platforms this fails b/c AF_UNIX sockets aren't actually TCP. 
try: - self.write_sock.setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 - ) + self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) except OSError: pass + self.old_wakeup_fd = None def wakeup_thread_and_signal_safe(self): try: @@ -58,21 +46,26 @@ def drain(self): except BlockingIOError: pass - @contextmanager def wakeup_on_signals(self): + assert self.old_wakeup_fd is None if not is_main_thread(): - yield return fd = self.write_sock.fileno() - if HAVE_WARN_ON_FULL_BUFFER: - old_wakeup_fd = signal.set_wakeup_fd(fd, warn_on_full_buffer=False) - else: - old_wakeup_fd = signal.set_wakeup_fd(fd) - try: - yield - finally: - signal.set_wakeup_fd(old_wakeup_fd) + self.old_wakeup_fd = signal.set_wakeup_fd(fd, warn_on_full_buffer=False) + if self.old_wakeup_fd != -1: + warnings.warn( + RuntimeWarning( + "It looks like Trio's signal handling code might have " + "collided with another library you're using. If you're " + "running Trio in guest mode, then this might mean you " + "should set host_uses_signal_set_wakeup_fd=True. " + "Otherwise, file a bug on Trio and we'll help you figure " + "out what's going on." 
+ ) + ) def close(self): self.wakeup_sock.close() self.write_sock.close() + if self.old_wakeup_fd is not None: + signal.set_wakeup_fd(self.old_wakeup_fd) diff --git a/trio/_core/_windows_cffi.py b/trio/_core/_windows_cffi.py index e2b95a9113..639e75b50e 100644 --- a/trio/_core/_windows_cffi.py +++ b/trio/_core/_windows_cffi.py @@ -1,10 +1,7 @@ -import cffi -import re import enum -try: - from enum import IntFlag -except ImportError: # python 3.5 - from enum import IntEnum as IntFlag +import re + +import cffi ################################################################ # Functions and types @@ -248,6 +245,7 @@ class ErrorCodes(enum.IntEnum): ERROR_INVALID_HANDLE = 6 ERROR_INVALID_PARMETER = 87 ERROR_NOT_FOUND = 1168 + ERROR_NOT_SOCKET = 10038 class FileFlags(enum.IntEnum): @@ -264,7 +262,7 @@ class FileFlags(enum.IntEnum): TRUNCATE_EXISTING = 5 -class AFDPollFlags(IntFlag): +class AFDPollFlags(enum.IntFlag): # These are drawn from a combination of: # https://github.com/piscisaureus/wepoll/blob/master/src/afd.h # https://github.com/reactos/reactos/blob/master/sdk/include/reactos/drivers/afd/shared.h @@ -287,9 +285,10 @@ class AFDPollFlags(IntFlag): class WSAIoctls(enum.IntEnum): SIO_BASE_HANDLE = 0x48000022 SIO_BSP_HANDLE_SELECT = 0x4800001C + SIO_BSP_HANDLE_POLL = 0x4800001D -class CompletionModes(IntFlag): +class CompletionModes(enum.IntFlag): FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 0x1 FILE_SKIP_SET_EVENT_ON_HANDLE = 0x2 diff --git a/trio/_core/tests/conftest.py b/trio/_core/tests/conftest.py deleted file mode 100644 index aca1f98a65..0000000000 --- a/trio/_core/tests/conftest.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest -import inspect - -# XX this should move into a global something -from ...testing import MockClock, trio_test - - -@pytest.fixture -def mock_clock(): - return MockClock() - - -@pytest.fixture -def autojump_clock(): - return MockClock(autojump_threshold=0) - - -# FIXME: split off into a package (or just make part of Trio's public -# 
interface?), with config file to enable? and I guess a mark option too; I -# guess it's useful with the class- and file-level marking machinery (where -# the raw @trio_test decorator isn't enough). -@pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): - if inspect.iscoroutinefunction(pyfuncitem.obj): - pyfuncitem.obj = trio_test(pyfuncitem.obj) diff --git a/trio/_core/tests/test_multierror_scripts/custom_excepthook.py b/trio/_core/tests/test_multierror_scripts/custom_excepthook.py deleted file mode 100644 index 564c5833b2..0000000000 --- a/trio/_core/tests/test_multierror_scripts/custom_excepthook.py +++ /dev/null @@ -1,18 +0,0 @@ -import _common - -import sys - - -def custom_excepthook(*args): - print("custom running!") - return sys.__excepthook__(*args) - - -sys.excepthook = custom_excepthook - -# Should warn that we'll get kinda-broken tracebacks -import trio - -# The custom excepthook should run, because Trio was polite and didn't -# override it -raise trio.MultiError([ValueError(), KeyError()]) diff --git a/trio/_core/tests/test_util.py b/trio/_core/tests/test_util.py deleted file mode 100644 index 5871ed8eef..0000000000 --- a/trio/_core/tests/test_util.py +++ /dev/null @@ -1 +0,0 @@ -import pytest diff --git a/trio/_core/tests/tutil.py b/trio/_core/tests/tutil.py deleted file mode 100644 index d5d22d525e..0000000000 --- a/trio/_core/tests/tutil.py +++ /dev/null @@ -1,63 +0,0 @@ -# Utilities for testing -import socket as stdlib_socket - -import pytest - -import gc - -# See trio/tests/conftest.py for the other half of this -from trio.tests.conftest import RUN_SLOW -slow = pytest.mark.skipif( - not RUN_SLOW, - reason="use --run-slow to run slow tests", -) - -try: - s = stdlib_socket.socket( - stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0 - ) -except OSError: # pragma: no cover - # Some systems don't even support creating an IPv6 socket, let alone - # binding it. 
(ex: Linux with 'ipv6.disable=1' in the kernel command line) - # We don't have any of those in our CI, and there's nothing that gets - # tested _only_ if can_create_ipv6 = False, so we'll just no-cover this. - can_create_ipv6 = False - can_bind_ipv6 = False -else: - can_create_ipv6 = True - with s: - try: - s.bind(('::1', 0)) - except OSError: - can_bind_ipv6 = False - else: - can_bind_ipv6 = True - -creates_ipv6 = pytest.mark.skipif(not can_create_ipv6, reason="need IPv6") -binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6") - - -def gc_collect_harder(): - # In the test suite we sometimes want to call gc.collect() to make sure - # that any objects with noisy __del__ methods (e.g. unawaited coroutines) - # get collected before we continue, so their noise doesn't leak into - # unrelated tests. - # - # On PyPy, coroutine objects (for example) can survive at least 1 round of - # garbage collection, because executing their __del__ method to print the - # warning can cause them to be resurrected. So we call collect a few times - # to make sure. 
- for _ in range(4): - gc.collect() - - -# template is like: -# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3] -def check_sequence_matches(seq, template): - i = 0 - for pattern in template: - if not isinstance(pattern, set): - pattern = {pattern} - got = set(seq[i:i + len(pattern)]) - assert got == pattern - i += len(got) diff --git a/trio/_deprecate.py b/trio/_deprecate.py index b1362cbc38..fe00192583 100644 --- a/trio/_deprecate.py +++ b/trio/_deprecate.py @@ -1,7 +1,7 @@ import sys +import warnings from functools import wraps from types import ModuleType -import warnings import attr @@ -30,24 +30,24 @@ class TrioDeprecationWarning(FutureWarning): def _url_for_issue(issue): - return "https://github.com/python-trio/trio/issues/{}".format(issue) + return f"https://github.com/python-trio/trio/issues/{issue}" def _stringify(thing): if hasattr(thing, "__module__") and hasattr(thing, "__qualname__"): - return "{}.{}".format(thing.__module__, thing.__qualname__) + return f"{thing.__module__}.{thing.__qualname__}" return str(thing) def warn_deprecated(thing, version, *, issue, instead, stacklevel=2): stacklevel += 1 - msg = "{} is deprecated since Trio {}".format(_stringify(thing), version) + msg = f"{_stringify(thing)} is deprecated since Trio {version}" if instead is None: msg += " with no replacement" else: - msg += "; use {} instead".format(_stringify(instead)) + msg += f"; use {_stringify(instead)} instead" if issue is not None: - msg += " ({})".format(_url_for_issue(issue)) + msg += f" ({_url_for_issue(issue)})" warnings.warn(TrioDeprecationWarning(msg), stacklevel=stacklevel) @@ -72,9 +72,9 @@ def wrapper(*args, **kwargs): doc = wrapper.__doc__ doc = doc.rstrip() doc += "\n\n" - doc += ".. deprecated:: {}\n".format(version) + doc += f".. 
deprecated:: {version}\n" if instead is not None: - doc += " Use {} instead.\n".format(_stringify(instead)) + doc += f" Use {_stringify(instead)} instead.\n" if issue is not None: doc += " For details, see `issue #{} <{}>`__.\n".format( issue, _url_for_issue(issue) @@ -116,10 +116,8 @@ def __getattr__(self, name): instead = info.instead if instead is DeprecatedAttribute._not_set: instead = info.value - thing = "{}.{}".format(self.__name__, name) - warn_deprecated( - thing, info.version, issue=info.issue, instead=instead - ) + thing = f"{self.__name__}.{name}" + warn_deprecated(thing, info.version, issue=info.issue, instead=instead) return info.value msg = "module '{}' has no attribute '{}'" diff --git a/trio/_deprecated_ssl_reexports.py b/trio/_deprecated_ssl_reexports.py deleted file mode 100644 index f86077e3fe..0000000000 --- a/trio/_deprecated_ssl_reexports.py +++ /dev/null @@ -1,100 +0,0 @@ -# This is a public namespace, so we don't want to expose any non-underscored -# attributes that aren't actually part of our public API. But it's very -# annoying to carefully always use underscored names for module-level -# temporaries, imports, etc. when implementing the module. So we put the -# implementation in an underscored module, and then re-export the public parts -# here. 
- -# Trio-specific symbols: -from ._ssl import SSLStream, SSLListener, NeedHandshakeError - -# Symbols re-exported from the stdlib ssl module: - -# Always available -from ssl import ( - cert_time_to_seconds, CertificateError, create_default_context, - DER_cert_to_PEM_cert, get_default_verify_paths, match_hostname, - PEM_cert_to_DER_cert, Purpose, SSLEOFError, SSLError, SSLSyscallError, - SSLZeroReturnError -) - -# Added in python 3.6 -try: - from ssl import AlertDescription, SSLErrorNumber, SSLSession, VerifyFlags, VerifyMode, Options # noqa -except ImportError: - pass - -# Added in python 3.7 -try: - from ssl import SSLCertVerificationError, TLSVersion # noqa -except ImportError: - pass - -# Windows-only -try: - from ssl import enum_certificates, enum_crls # noqa -except ImportError: - pass - -# Fake import to enable static analysis tools to catch the names -# (Real import is below) -try: - from ssl import ( - AF_INET, ALERT_DESCRIPTION_ACCESS_DENIED, - ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE, - ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE, - ALERT_DESCRIPTION_BAD_CERTIFICATE, ALERT_DESCRIPTION_BAD_RECORD_MAC, - ALERT_DESCRIPTION_CERTIFICATE_EXPIRED, - ALERT_DESCRIPTION_CERTIFICATE_REVOKED, - ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN, - ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE, - ALERT_DESCRIPTION_CLOSE_NOTIFY, ALERT_DESCRIPTION_DECODE_ERROR, - ALERT_DESCRIPTION_DECOMPRESSION_FAILURE, - ALERT_DESCRIPTION_DECRYPT_ERROR, ALERT_DESCRIPTION_HANDSHAKE_FAILURE, - ALERT_DESCRIPTION_ILLEGAL_PARAMETER, - ALERT_DESCRIPTION_INSUFFICIENT_SECURITY, - ALERT_DESCRIPTION_INTERNAL_ERROR, ALERT_DESCRIPTION_NO_RENEGOTIATION, - ALERT_DESCRIPTION_PROTOCOL_VERSION, ALERT_DESCRIPTION_RECORD_OVERFLOW, - ALERT_DESCRIPTION_UNEXPECTED_MESSAGE, ALERT_DESCRIPTION_UNKNOWN_CA, - ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY, - ALERT_DESCRIPTION_UNRECOGNIZED_NAME, - ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE, - ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION, - ALERT_DESCRIPTION_USER_CANCELLED, 
CERT_NONE, CERT_OPTIONAL, - CERT_REQUIRED, CHANNEL_BINDING_TYPES, HAS_ALPN, HAS_ECDH, - HAS_NEVER_CHECK_COMMON_NAME, HAS_NPN, HAS_SNI, OP_ALL, - OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION, OP_COOKIE_EXCHANGE, - OP_DONT_INSERT_EMPTY_FRAGMENTS, OP_EPHEMERAL_RSA, - OP_LEGACY_SERVER_CONNECT, OP_MICROSOFT_BIG_SSLV3_BUFFER, - OP_MICROSOFT_SESS_ID_BUG, OP_MSIE_SSLV2_RSA_PADDING, - OP_NETSCAPE_CA_DN_BUG, OP_NETSCAPE_CHALLENGE_BUG, - OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG, - OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG, OP_NO_QUERY_MTU, OP_PKCS1_CHECK_1, - OP_PKCS1_CHECK_2, OP_SSLEAY_080_CLIENT_DH_BUG, - OP_SSLREF2_REUSE_CERT_TYPE_BUG, OP_TLS_BLOCK_PADDING_BUG, - OP_TLS_D5_BUG, OP_TLS_ROLLBACK_BUG, SSL_ERROR_NONE, - SSL_ERROR_NO_SOCKET, OP_CIPHER_SERVER_PREFERENCE, OP_NO_COMPRESSION, - OP_NO_RENEGOTIATION, OP_NO_TICKET, OP_SINGLE_DH_USE, - OP_SINGLE_ECDH_USE, OPENSSL_VERSION_INFO, OPENSSL_VERSION_NUMBER, - OPENSSL_VERSION, PEM_FOOTER, PEM_HEADER, PROTOCOL_TLS_CLIENT, - PROTOCOL_TLS_SERVER, PROTOCOL_TLS, SO_TYPE, SOCK_STREAM, SOL_SOCKET, - SSL_ERROR_EOF, SSL_ERROR_INVALID_ERROR_CODE, SSL_ERROR_SSL, - SSL_ERROR_SYSCALL, SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_READ, - SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_X509_LOOKUP, - SSL_ERROR_ZERO_RETURN, VERIFY_CRL_CHECK_CHAIN, VERIFY_CRL_CHECK_LEAF, - VERIFY_DEFAULT, VERIFY_X509_STRICT, VERIFY_X509_TRUSTED_FIRST, - OP_ENABLE_MIDDLEBOX_COMPAT - ) -except ImportError: - pass - -# Dynamically re-export whatever constants this particular Python happens to -# have: -import ssl as _stdlib_ssl -globals().update( - { - _name: getattr(_stdlib_ssl, _name) - for _name in _stdlib_ssl.__dict__.keys() - if _name.isupper() and not _name.startswith('_') - } -) diff --git a/trio/_deprecated_subprocess_reexports.py b/trio/_deprecated_subprocess_reexports.py deleted file mode 100644 index b91e28784a..0000000000 --- a/trio/_deprecated_subprocess_reexports.py +++ /dev/null @@ -1,28 +0,0 @@ -from ._subprocess import Process - -# Reexport constants and exceptions from the stdlib 
subprocess module -from subprocess import ( - PIPE, STDOUT, DEVNULL, CalledProcessError, SubprocessError, TimeoutExpired, - CompletedProcess -) - -# Windows only -try: - from subprocess import ( - STARTUPINFO, STD_INPUT_HANDLE, STD_OUTPUT_HANDLE, STD_ERROR_HANDLE, - SW_HIDE, STARTF_USESTDHANDLES, STARTF_USESHOWWINDOW, - CREATE_NEW_CONSOLE, CREATE_NEW_PROCESS_GROUP - ) -except ImportError: - pass - -# Windows 3.7+ only -try: - from subprocess import ( - ABOVE_NORMAL_PRIORITY_CLASS, BELOW_NORMAL_PRIORITY_CLASS, - HIGH_PRIORITY_CLASS, IDLE_PRIORITY_CLASS, NORMAL_PRIORITY_CLASS, - REALTIME_PRIORITY_CLASS, CREATE_NO_WINDOW, DETACHED_PROCESS, - CREATE_DEFAULT_ERROR_MODE, CREATE_BREAKAWAY_FROM_JOB - ) -except ImportError: - pass diff --git a/trio/_dtls.py b/trio/_dtls.py new file mode 100644 index 0000000000..8675cb75b6 --- /dev/null +++ b/trio/_dtls.py @@ -0,0 +1,1353 @@ +# Implementation of DTLS 1.2, using pyopenssl +# https://datatracker.ietf.org/doc/html/rfc6347 +# +# OpenSSL's APIs for DTLS are extremely awkward and limited, which forces us to jump +# through a *lot* of hoops and implement important chunks of the protocol ourselves. +# Hopefully they fix this before implementing DTLS 1.3, because it's a very different +# protocol, and it's probably impossible to pull tricks like we do here. 
+ +from __future__ import annotations + +import enum +import errno +import hmac +import os +import struct +import warnings +import weakref +from itertools import count +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + Iterable, + Iterator, + TypeVar, + Union, +) +from weakref import ReferenceType, WeakValueDictionary + +import attr +from OpenSSL import SSL + +import trio + +from ._util import Final, NoPublicConstructor + +if TYPE_CHECKING: + from types import TracebackType + + from OpenSSL.SSL import Context + from typing_extensions import Self, TypeAlias + + from ._core._run import TaskStatus + from ._socket import Address, _SocketType + +MAX_UDP_PACKET_SIZE = 65527 + + +def packet_header_overhead(sock: _SocketType) -> int: + if sock.family == trio.socket.AF_INET: + return 28 + else: + return 48 + + +def worst_case_mtu(sock: _SocketType) -> int: + if sock.family == trio.socket.AF_INET: + return 576 - packet_header_overhead(sock) + else: + return 1280 - packet_header_overhead(sock) + + +def best_guess_mtu(sock: _SocketType) -> int: + return 1500 - packet_header_overhead(sock) + + +# There are a bunch of different RFCs that define these codes, so for a +# comprehensive collection look here: +# https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml +class ContentType(enum.IntEnum): + change_cipher_spec = 20 + alert = 21 + handshake = 22 + application_data = 23 + heartbeat = 24 + + +class HandshakeType(enum.IntEnum): + hello_request = 0 + client_hello = 1 + server_hello = 2 + hello_verify_request = 3 + new_session_ticket = 4 + end_of_early_data = 4 + encrypted_extensions = 8 + certificate = 11 + server_key_exchange = 12 + certificate_request = 13 + server_hello_done = 14 + certificate_verify = 15 + client_key_exchange = 16 + finished = 20 + certificate_url = 21 + certificate_status = 22 + supplemental_data = 23 + key_update = 24 + compressed_certificate = 25 + ekt_key = 26 + message_hash = 254 + + +class ProtocolVersion: 
+ DTLS10 = bytes([254, 255]) + DTLS12 = bytes([254, 253]) + + +EPOCH_MASK = 0xFFFF << (6 * 8) + + +# Conventions: +# - All functions that handle network data end in _untrusted. +# - All functions end in _untrusted MUST make sure that bad data from the +# network cannot *only* cause BadPacket to be raised. No IndexError or +# struct.error or whatever. +class BadPacket(Exception): + pass + + +# This checks that the DTLS 'epoch' field is 0, which is true iff we're in the +# initial handshake. It doesn't check the ContentType, because not all +# handshake messages have ContentType==handshake -- for example, +# ChangeCipherSpec is used during the handshake but has its own ContentType. +# +# Cannot fail. +def part_of_handshake_untrusted(packet: bytes) -> bool: + # If the packet is too short, then slicing will successfully return a + # short string, which will necessarily fail to match. + return packet[3:5] == b"\x00\x00" + + +# Cannot fail +def is_client_hello_untrusted(packet: bytes) -> bool: + try: + return ( + packet[0] == ContentType.handshake + and packet[13] == HandshakeType.client_hello + ) + except IndexError: + # Invalid DTLS record + return False + + +# DTLS records are: +# - 1 byte content type +# - 2 bytes version +# - 8 bytes epoch+seqno +# Technically this is 2 bytes epoch then 6 bytes seqno, but we treat it as +# a single 8-byte integer, where epoch changes are represented as jumping +# forward by 2**(6*8). 
+# - 2 bytes payload length (unsigned big-endian) +# - payload +RECORD_HEADER = struct.Struct("!B2sQH") + + +def to_hex(data: bytes) -> str: # pragma: no cover + return data.hex() + + +@attr.frozen +class Record: + content_type: int + version: bytes = attr.ib(repr=to_hex) + epoch_seqno: int + payload: bytes = attr.ib(repr=to_hex) + + +def records_untrusted(packet: bytes) -> Iterator[Record]: + i = 0 + while i < len(packet): + try: + ct, version, epoch_seqno, payload_len = RECORD_HEADER.unpack_from(packet, i) + # Marked as no-cover because at time of writing, this code is unreachable + # (records_untrusted only gets called on packets that are either trusted or that + # have passed is_client_hello_untrusted, which filters out short packets) + except struct.error as exc: # pragma: no cover + raise BadPacket("invalid record header") from exc + i += RECORD_HEADER.size + payload = packet[i : i + payload_len] + if len(payload) != payload_len: + raise BadPacket("short record") + i += payload_len + yield Record(ct, version, epoch_seqno, payload) + + +def encode_record(record: Record) -> bytes: + header = RECORD_HEADER.pack( + record.content_type, + record.version, + record.epoch_seqno, + len(record.payload), + ) + return header + record.payload + + +# Handshake messages are: +# - 1 byte message type +# - 3 bytes total message length +# - 2 bytes message sequence number +# - 3 bytes fragment offset +# - 3 bytes fragment length +HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s") + + +@attr.frozen +class HandshakeFragment: + msg_type: int + msg_len: int + msg_seq: int + frag_offset: int + frag_len: int + frag: bytes = attr.ib(repr=to_hex) + + +def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment: + # Raises BadPacket if decoding fails + try: + ( + msg_type, + msg_len_bytes, + msg_seq, + frag_offset_bytes, + frag_len_bytes, + ) = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload) + except struct.error as exc: + raise BadPacket("bad handshake message 
header") from exc + # 'struct' doesn't have built-in support for 24-bit integers, so we + # have to do it by hand. These can't fail. + msg_len = int.from_bytes(msg_len_bytes, "big") + frag_offset = int.from_bytes(frag_offset_bytes, "big") + frag_len = int.from_bytes(frag_len_bytes, "big") + frag = payload[HANDSHAKE_MESSAGE_HEADER.size :] + if len(frag) != frag_len: + raise BadPacket("handshake fragment length doesn't match record length") + return HandshakeFragment( + msg_type, + msg_len, + msg_seq, + frag_offset, + frag_len, + frag, + ) + + +def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes: + hs_header = HANDSHAKE_MESSAGE_HEADER.pack( + hsf.msg_type, + hsf.msg_len.to_bytes(3, "big"), + hsf.msg_seq, + hsf.frag_offset.to_bytes(3, "big"), + hsf.frag_len.to_bytes(3, "big"), + ) + return hs_header + hsf.frag + + +def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]: + # Raises BadPacket if parsing fails + # Returns (record epoch_seqno, cookie from the packet, data that should be + # hashed into cookie) + try: + # ClientHello has to be the first record in the packet + record = next(records_untrusted(packet)) + # no-cover because at time of writing, this is unreachable: + # decode_client_hello_untrusted is only called on packets that have passed + # is_client_hello_untrusted, which confirms the content type. + if record.content_type != ContentType.handshake: # pragma: no cover + raise BadPacket("not a handshake record") + fragment = decode_handshake_fragment_untrusted(record.payload) + if fragment.msg_type != HandshakeType.client_hello: + raise BadPacket("not a ClientHello") + # ClientHello can't be fragmented, because reassembly requires holding + # per-connection state, and we refuse to allocate per-connection state + # until after we get a valid ClientHello. 
+ if fragment.frag_offset != 0: + raise BadPacket("fragmented ClientHello") + if fragment.frag_len != fragment.msg_len: + raise BadPacket("fragmented ClientHello") + + # As per RFC 6347: + # + # When responding to a HelloVerifyRequest, the client MUST use the + # same parameter values (version, random, session_id, cipher_suites, + # compression_method) as it did in the original ClientHello. The + # server SHOULD use those values to generate its cookie and verify that + # they are correct upon cookie receipt. + # + # However, the record-layer framing can and will change (e.g. the + # second ClientHello will have a new record-layer sequence number). So + # we need to pull out the handshake message alone, discarding the + # record-layer stuff, and then we're going to hash all of it *except* + # the cookie. + + body = fragment.frag + # ClientHello is: + # + # - 2 bytes client_version + # - 32 bytes random + # - 1 byte session_id length + # - session_id + # - 1 byte cookie length + # - cookie + # - everything else + # + # So to find the cookie, so we need to figure out how long the + # session_id is and skip past it. + session_id_len = body[2 + 32] + cookie_len_offset = 2 + 32 + 1 + session_id_len + cookie_len = body[cookie_len_offset] + + cookie_start = cookie_len_offset + 1 + cookie_end = cookie_start + cookie_len + + before_cookie = body[:cookie_len_offset] + cookie = body[cookie_start:cookie_end] + after_cookie = body[cookie_end:] + + if len(cookie) != cookie_len: + raise BadPacket("short cookie") + return (record.epoch_seqno, cookie, before_cookie + after_cookie) + + except (struct.error, IndexError) as exc: + raise BadPacket("bad ClientHello") from exc + + +@attr.frozen +class HandshakeMessage: + record_version: bytes = attr.ib(repr=to_hex) + msg_type: HandshakeType + msg_seq: int + body: bytearray = attr.ib(repr=to_hex) + + +# ChangeCipherSpec is part of the handshake, but it's not a "handshake +# message" and can't be fragmented the same way. Sigh. 
+@attr.frozen +class PseudoHandshakeMessage: + record_version: bytes = attr.ib(repr=to_hex) + content_type: int + payload: bytes = attr.ib(repr=to_hex) + + +# The final record in a handshake is Finished, which is encrypted, can't be fragmented +# (at least by us), and keeps its record number (because it's in a new epoch). So we +# just pass it through unchanged. (Fortunately, the payload is only a single hash value, +# so the largest it will ever be is 64 bytes for a 512-bit hash. Which is small enough +# that it never requires fragmenting to fit into a UDP packet. +@attr.frozen +class OpaqueHandshakeMessage: + record: Record + + +_AnyHandshakeMessage: TypeAlias = Union[ + HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage +] + + +# This takes a raw outgoing handshake volley that openssl generated, and +# reconstructs the handshake messages inside it, so that we can repack them +# into records while retransmitting. So the data ought to be well-behaved -- +# it's not coming from the network. +def decode_volley_trusted( + volley: bytes, +) -> list[_AnyHandshakeMessage]: + messages: list[_AnyHandshakeMessage] = [] + messages_by_seq = {} + for record in records_untrusted(volley): + # ChangeCipherSpec isn't a handshake message, so it can't be fragmented. + # Handshake messages with epoch > 0 are encrypted, so we can't fragment them + # either. Fortunately, ChangeCipherSpec has a 1 byte payload, and the only + # encrypted handshake message is Finished, whose payload is a single hash value + # -- so 32 bytes for SHA-256, 64 for SHA-512, etc. Neither is going to be so + # large that it has to be fragmented to fit into a single packet. 
+ if record.epoch_seqno & EPOCH_MASK: + messages.append(OpaqueHandshakeMessage(record)) + elif record.content_type in (ContentType.change_cipher_spec, ContentType.alert): + messages.append( + PseudoHandshakeMessage( + record.version, record.content_type, record.payload + ) + ) + else: + assert record.content_type == ContentType.handshake + fragment = decode_handshake_fragment_untrusted(record.payload) + msg_type = HandshakeType(fragment.msg_type) + if fragment.msg_seq not in messages_by_seq: + msg = HandshakeMessage( + record.version, + msg_type, + fragment.msg_seq, + bytearray(fragment.msg_len), + ) + messages.append(msg) + messages_by_seq[fragment.msg_seq] = msg + else: + msg = messages_by_seq[fragment.msg_seq] + assert msg.msg_type == fragment.msg_type + assert msg.msg_seq == fragment.msg_seq + assert len(msg.body) == fragment.msg_len + + msg.body[ + fragment.frag_offset : fragment.frag_offset + fragment.frag_len + ] = fragment.frag + + return messages + + +class RecordEncoder: + def __init__(self) -> None: + self._record_seq = count() + + def set_first_record_number(self, n: int) -> None: + self._record_seq = count(n) + + def encode_volley( + self, + messages: Iterable[_AnyHandshakeMessage], + mtu: int, + ) -> list[bytearray]: + packets = [] + packet = bytearray() + for message in messages: + if isinstance(message, OpaqueHandshakeMessage): + encoded = encode_record(message.record) + if mtu - len(packet) - len(encoded) <= 0: + packets.append(packet) + packet = bytearray() + packet += encoded + assert len(packet) <= mtu + elif isinstance(message, PseudoHandshakeMessage): + space = mtu - len(packet) - RECORD_HEADER.size - len(message.payload) + if space <= 0: + packets.append(packet) + packet = bytearray() + packet += RECORD_HEADER.pack( + message.content_type, + message.record_version, + next(self._record_seq), + len(message.payload), + ) + packet += message.payload + assert len(packet) <= mtu + else: + msg_len_bytes = len(message.body).to_bytes(3, "big") + 
frag_offset = 0 + frags_encoded = 0 + # If message.body is empty, then we still want to encode it in one + # fragment, not zero. + while frag_offset < len(message.body) or not frags_encoded: + space = ( + mtu + - len(packet) + - RECORD_HEADER.size + - HANDSHAKE_MESSAGE_HEADER.size + ) + if space <= 0: + packets.append(packet) + packet = bytearray() + continue + frag = message.body[frag_offset : frag_offset + space] + frag_offset_bytes = frag_offset.to_bytes(3, "big") + frag_len_bytes = len(frag).to_bytes(3, "big") + frag_offset += len(frag) + + packet += RECORD_HEADER.pack( + ContentType.handshake, + message.record_version, + next(self._record_seq), + HANDSHAKE_MESSAGE_HEADER.size + len(frag), + ) + + packet += HANDSHAKE_MESSAGE_HEADER.pack( + message.msg_type, + msg_len_bytes, + message.msg_seq, + frag_offset_bytes, + frag_len_bytes, + ) + + packet += frag + + frags_encoded += 1 + assert len(packet) <= mtu + + if packet: + packets.append(packet) + + return packets + + +# This bit requires implementing a bona fide cryptographic protocol, so even though it's +# a simple one let's take a moment to discuss the design. +# +# Our goal is to force new incoming handshakes that claim to be coming from a +# given ip:port to prove that they can also receive packets sent to that +# ip:port. (There's nothing in UDP to stop someone from forging the return +# address, and it's often used for stuff like DoS reflection attacks, where +# an attacker tries to trick us into sending data at some innocent victim.) +# For more details, see: +# +# https://datatracker.ietf.org/doc/html/rfc6347#section-4.2.1 +# +# To do this, when we receive an initial ClientHello, we calculate a magic +# cookie, and send it back as a HelloVerifyRequest. Then the client sends us a +# second ClientHello, this time with the magic cookie included, and after we +# check that this cookie is valid we go ahead and start the handshake proper. 
+# +# So the magic cookie needs the following properties: +# - No-one can forge it without knowing our secret key +# - It ensures that the ip, port, and ClientHello contents from the response +# match those in the challenge +# - It expires after a short-ish period (so that if an attacker manages to steal one, it +# won't be useful for long) +# - It doesn't require storing any peer-specific state on our side +# +# To do that, we take the ip/port/ClientHello data and compute an HMAC of them, using a +# secret key we generate on startup. We also include: +# +# - The current time (using Trio's clock), rounded to the nearest 30 seconds +# - A random salt +# +# Then the cookie is the salt and the HMAC digest concatenated together. +# +# When verifying a cookie, we use the salt + new ip/port/ClientHello data to recompute +# the HMAC digest, for both the current time and the current time minus 30 seconds, and +# if either of them match, we consider the cookie good. +# +# Including the rounded-off time like this means that each cookie is good for at least +# 30 seconds, and possibly as much as 60 seconds. +# +# The salt is probably not necessary -- I'm pretty sure that all it does is make it hard +# for an attacker to figure out when our clock ticks over a 30 second boundary. Which is +# probably pretty harmless? But it's easier to add the salt than to convince myself that +# it's *completely* harmless, so, salt it is. + +COOKIE_REFRESH_INTERVAL = 30 # seconds +KEY_BYTES = 32 +COOKIE_HASH = "sha256" +SALT_BYTES = 8 +# 32 bytes was the maximum cookie length in DTLS 1.0. DTLS 1.2 raised it to 255. I doubt +# there are any DTLS 1.0 implementations still in the wild, but really 32 bytes is +# plenty, and it also gets rid of a confusing warning in Wireshark output. +# +# We truncate the cookie to 32 bytes, of which 8 bytes is salt, so that leaves 24 bytes +# of truncated HMAC = 192 bit security, which is still massive overkill. (TCP uses 32 +# *bits* for this.) 
HMAC truncation is explicitly noted as safe in RFC 2104: +# https://datatracker.ietf.org/doc/html/rfc2104#section-5 +COOKIE_LENGTH = 32 + + +def _current_cookie_tick() -> int: + return int(trio.current_time() / COOKIE_REFRESH_INTERVAL) + + +# Simple deterministic and invertible serializer -- i.e., a useful tool for converting +# structured data into something we can cryptographically sign. +def _signable(*fields: bytes) -> bytes: + out = [] + for field in fields: + out.append(struct.pack("!Q", len(field))) + out.append(field) + return b"".join(out) + + +def _make_cookie( + key: bytes, salt: bytes, tick: int, address: Address, client_hello_bits: bytes +) -> bytes: + assert len(salt) == SALT_BYTES + assert len(key) == KEY_BYTES + + signable_data = _signable( + salt, + struct.pack("!Q", tick), + # address is a mix of strings and ints, and variable length, so pack + # it into a single nested field + _signable(*(str(part).encode() for part in address)), + client_hello_bits, + ) + + return (salt + hmac.digest(key, signable_data, COOKIE_HASH))[:COOKIE_LENGTH] + + +def valid_cookie( + key: bytes, cookie: bytes, address: Address, client_hello_bits: bytes +) -> bool: + if len(cookie) > SALT_BYTES: + salt = cookie[:SALT_BYTES] + + tick = _current_cookie_tick() + + cur_cookie = _make_cookie(key, salt, tick, address, client_hello_bits) + old_cookie = _make_cookie( + key, salt, max(tick - 1, 0), address, client_hello_bits + ) + + # I doubt using a short-circuiting 'or' here would leak any meaningful + # information, but why risk it when '|' is just as easy. 
+ return hmac.compare_digest(cookie, cur_cookie) | hmac.compare_digest( + cookie, old_cookie + ) + else: + return False + + +def challenge_for( + key: bytes, address: Address, epoch_seqno: int, client_hello_bits: bytes +) -> bytes: + salt = os.urandom(SALT_BYTES) + tick = _current_cookie_tick() + cookie = _make_cookie(key, salt, tick, address, client_hello_bits) + + # HelloVerifyRequest body is: + # - 2 bytes version + # - length-prefixed cookie + # + # The DTLS 1.2 spec says that for this message specifically we should use + # the DTLS 1.0 version. + # + # (It also says the opposite of that, but that part is a mistake: + # https://www.rfc-editor.org/errata/eid4103 + # ). + # + # And I guess we use this for both the message-level and record-level + # ProtocolVersions, since we haven't negotiated anything else yet? + body = ProtocolVersion.DTLS10 + bytes([len(cookie)]) + cookie + + # RFC says have to copy the client's record number + # Errata says it should be handshake message number + # Openssl copies back record sequence number, and always sets message seq + # number 0. So I guess we'll follow openssl. 
+    hs = HandshakeFragment(
+        msg_type=HandshakeType.hello_verify_request,
+        msg_len=len(body),
+        msg_seq=0,
+        frag_offset=0,
+        frag_len=len(body),
+        frag=body,
+    )
+    payload = encode_handshake_fragment(hs)
+
+    packet = encode_record(
+        Record(ContentType.handshake, ProtocolVersion.DTLS10, epoch_seqno, payload)
+    )
+    return packet
+
+
+_T = TypeVar("_T")
+
+
+class _Queue(Generic[_T]):
+    def __init__(self, incoming_packets_buffer: int | float):
+        self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer)
+
+
+def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
+    chunks = []
+    while True:
+        try:
+            chunk = read_fn(2**14)  # max TLS record size
+        except SSL.WantReadError:
+            break
+        chunks.append(chunk)
+    return b"".join(chunks)
+
+
+async def handle_client_hello_untrusted(
+    endpoint: DTLSEndpoint, address: Address, packet: bytes
+) -> None:
+    if endpoint._listening_context is None:
+        return
+
+    try:
+        epoch_seqno, cookie, bits = decode_client_hello_untrusted(packet)
+    except BadPacket:
+        return
+
+    if endpoint._listening_key is None:
+        endpoint._listening_key = os.urandom(KEY_BYTES)
+
+    if not valid_cookie(endpoint._listening_key, cookie, address, bits):
+        challenge_packet = challenge_for(
+            endpoint._listening_key, address, epoch_seqno, bits
+        )
+        try:
+            async with endpoint._send_lock:
+                await endpoint.socket.sendto(challenge_packet, address)
+        except (OSError, trio.ClosedResourceError):
+            pass
+    else:
+        # We got a real, valid ClientHello!
+        stream = DTLSChannel._create(endpoint, address, endpoint._listening_context)
+        # Our HelloVerifyRequest had some sequence number. We need our future sequence
+        # numbers to be larger than it, so our peer knows that our future records aren't
+        # stale/duplicates. But, we don't know what this sequence number was.
 What we do
+        # know is:
+        # - the HelloVerifyRequest seqno was copied from the initial ClientHello
+        # - the new ClientHello has a higher seqno than the initial ClientHello
+        # So, if we copy the new ClientHello's seqno into our first real handshake
+        # record and increment from there, that should work.
+        stream._record_encoder.set_first_record_number(epoch_seqno)
+        # Process the ClientHello
+        try:
+            stream._ssl.bio_write(packet)
+            stream._ssl.DTLSv1_listen()
+        except SSL.Error:
+            # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello
+            # after all.
+            return
+
+        # Some old versions of OpenSSL have a bug with memory BIOs, where DTLSv1_listen
+        # consumes the ClientHello out of the BIO, but then do_handshake expects the
+        # ClientHello to still be in there. In particular, this is known to affect the
+        # OpenSSL v1.1.1 that ships with Ubuntu 18.04 (but not the one that ships with
+        # Ubuntu 20.04). To work around this, we deliver a second copy of the
+        # ClientHello after DTLSv1_listen has completed. This is safe to do
+        # unconditionally, because on newer versions of OpenSSL, the second ClientHello
+        # is treated as a duplicate packet, which is a normal thing that can happen over
+        # UDP. For more details, see:
+        #
+        # https://github.com/pyca/pyopenssl/blob/e84e7b57d1838de70ab7a27089fbee78ce0d2106/tests/test_ssl.py#L4226-L4293
+        #
+        # This was fixed in v1.1.1a, and all later versions. So maybe in 2024 or so we
+        # can delete this. The fix landed in OpenSSL master as 079ef6bd534d2, and then
+        # was backported to the 1.1.1 branch as d1bfd8076e28.
+        stream._ssl.bio_write(packet)
+
+        # Check if we have an existing association
+        old_stream = endpoint._streams.get(address)
+        if old_stream is not None:
+            if old_stream._client_hello == (cookie, bits):
+                # ...This was just a duplicate of the last ClientHello, so never mind.
+                return
+            else:
+                # Ok, this *really is* a new handshake; the old stream should go away.
+ old_stream._set_replaced() + stream._client_hello = (cookie, bits) + endpoint._streams[address] = stream + endpoint._incoming_connections_q.s.send_nowait(stream) + + +async def dtls_receive_loop( + endpoint_ref: ReferenceType[DTLSEndpoint], sock: _SocketType +) -> None: + try: + while True: + try: + packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE) + except OSError as exc: + if exc.errno == errno.ECONNRESET: + # Windows only: "On a UDP-datagram socket [ECONNRESET] + # indicates a previous send operation resulted in an ICMP Port + # Unreachable message" -- https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom + # + # This is totally useless -- there's nothing we can do with this + # information. So we just ignore it and retry the recv. + continue + else: + raise + endpoint = endpoint_ref() + try: + if endpoint is None: + return + if is_client_hello_untrusted(packet): + await handle_client_hello_untrusted(endpoint, address, packet) + elif address in endpoint._streams: + stream = endpoint._streams[address] + if stream._did_handshake and part_of_handshake_untrusted(packet): + # The peer just sent us more handshake messages, that aren't a + # ClientHello, and we thought the handshake was done. Some of + # the packets that we sent to finish the handshake must have + # gotten lost. So re-send them. We do this directly here instead + # of just putting it into the queue and letting the receiver do + # it, because there's no guarantee that anyone is reading from + # the queue, because we think the handshake is done! 
+ await stream._resend_final_volley() + else: + try: + # mypy for some reason cannot determine type of _q + stream._q.s.send_nowait(packet) # type:ignore[has-type] + except trio.WouldBlock: + stream._packets_dropped_in_trio += 1 + else: + # Drop packet + pass + finally: + del endpoint + except trio.ClosedResourceError: + # socket was closed + return + except OSError as exc: + if exc.errno in (errno.EBADF, errno.ENOTSOCK): + # socket was closed + return + else: # pragma: no cover + # ??? shouldn't happen + raise + + +@attr.frozen +class DTLSChannelStatistics: + """Currently this has only one attribute: + + - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of + incoming packets from this peer that Trio successfully received from the + network, but then got dropped because the internal channel buffer was full. If + this is non-zero, then you might want to call ``receive`` more often, or use a + larger ``incoming_packets_buffer``, or just not worry about it because your + UDP-based protocol should be able to handle the occasional lost packet, right? + + """ + + incoming_packets_dropped_in_trio: int + + +class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): + """A DTLS connection. + + This class has no public constructor – you get instances by calling + `DTLSEndpoint.serve` or `~DTLSEndpoint.connect`. + + .. attribute:: endpoint + + The `DTLSEndpoint` that this connection is using. + + .. attribute:: peer_address + + The IP/port of the remote peer that this connection is associated with. + + """ + + def __init__(self, endpoint: DTLSEndpoint, peer_address: Address, ctx: Context): + self.endpoint = endpoint + self.peer_address = peer_address + self._packets_dropped_in_trio = 0 + self._client_hello = None + self._did_handshake = False + # These are mandatory for all DTLS connections. 
OP_NO_QUERY_MTU is required to + # stop openssl from trying to query the memory BIO's MTU and then breaking, and + # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to + # support and isn't useful anyway -- especially for DTLS where it's equivalent + # to just performing a new handshake. + ctx.set_options( + ( + SSL.OP_NO_QUERY_MTU + | SSL.OP_NO_RENEGOTIATION # type: ignore[attr-defined] + ) + ) + self._ssl = SSL.Connection(ctx) + self._handshake_mtu = 0 + # This calls self._ssl.set_ciphertext_mtu, which is important, because if you + # don't call it then openssl doesn't work. + self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket)) + self._replaced = False + self._closed = False + self._q = _Queue[bytes](endpoint.incoming_packets_buffer) + self._handshake_lock = trio.Lock() + self._record_encoder: RecordEncoder = RecordEncoder() + + self._final_volley: list[_AnyHandshakeMessage] = [] + + def _set_replaced(self) -> None: + self._replaced = True + # Any packets we already received could maybe possibly still be processed, but + # there are no more coming. So we close this on the sender side. + self._q.s.close() + + def _check_replaced(self) -> None: + if self._replaced: + raise trio.BrokenResourceError( + "peer tore down this connection to start a new one" + ) + + # XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU + # estimate + + # XX should we send close-notify when closing? It seems particularly pointless for + # DTLS where packets are all independent and can be lost anyway. We do at least need + # to handle receiving it properly though, which might be easier if we send it... + + def close(self) -> None: + """Close this connection. + + `DTLSChannel`\\s don't actually own any OS-level resources – the + socket is owned by the `DTLSEndpoint`, not the individual connections. So + you don't really *have* to call this. 
But it will interrupt any other tasks + calling `receive` with a `ClosedResourceError`, and cause future attempts to use + this connection to fail. + + You can also use this object as a synchronous or asynchronous context manager. + + """ + if self._closed: + return + self._closed = True + if self.endpoint._streams.get(self.peer_address) is self: + del self.endpoint._streams[self.peer_address] + # Will wake any tasks waiting on self._q.get with a + # ClosedResourceError + self._q.r.close() + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self.close() + + async def aclose(self) -> None: + """Close this connection, but asynchronously. + + This is included to satisfy the `trio.abc.Channel` contract. It's + identical to `close`, but async. + + """ + self.close() + await trio.lowlevel.checkpoint() + + async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None: + packets = self._record_encoder.encode_volley( + volley_messages, self._handshake_mtu + ) + for packet in packets: + async with self.endpoint._send_lock: + await self.endpoint.socket.sendto(packet, self.peer_address) + + async def _resend_final_volley(self) -> None: + await self._send_volley(self._final_volley) + + async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None: + """Perform the handshake. + + Calling this is optional – if you don't, then it will be automatically called + the first time you call `send` or `receive`. But calling it explicitly can be + useful in case you want to control the retransmit timeout, use a cancel scope to + place an overall timeout on the handshake, or catch errors from the handshake + specifically. + + It's safe to call this multiple times, or call it simultaneously from multiple + tasks – the first call will perform the handshake, and the rest will be no-ops. 
+ + Args: + + initial_retransmit_timeout (float): Since UDP is an unreliable protocol, it's + possible that some of the packets we send during the handshake will get + lost. To handle this, DTLS uses a timer to automatically retransmit + handshake packets that don't receive a response. This lets you set the + timeout we use to detect packet loss. Ideally, it should be set to ~1.5 + times the round-trip time to your peer, but 1 second is a reasonable + default. There's `some useful guidance here + `__. + + This is the *initial* timeout, because if packets keep being lost then Trio + will automatically back off to longer values, to avoid overloading the + network. + + """ + async with self._handshake_lock: + if self._did_handshake: + return + + timeout = initial_retransmit_timeout + volley_messages: list[_AnyHandshakeMessage] = [] + volley_failed_sends = 0 + + def read_volley() -> list[_AnyHandshakeMessage]: + volley_bytes = _read_loop(self._ssl.bio_read) + new_volley_messages = decode_volley_trusted(volley_bytes) + if ( + new_volley_messages + and volley_messages + and isinstance(new_volley_messages[0], HandshakeMessage) + and isinstance(volley_messages[0], HandshakeMessage) + and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq + ): + # openssl decided to retransmit; discard because we handle + # retransmits ourselves + return [] + else: + return new_volley_messages + + # If we're a client, we send the initial volley. If we're a server, then + # the initial ClientHello has already been inserted into self._ssl's + # read BIO. So either way, we start by generating a new volley. + try: + self._ssl.do_handshake() + except SSL.WantReadError: + pass + volley_messages = read_volley() + # If we don't have messages to send in our initial volley, then something + # has gone very wrong. (I'm not sure this can actually happen without an + # error from OpenSSL, but we check just in case.) 
+ if not volley_messages: # pragma: no cover + raise SSL.Error("something wrong with peer's ClientHello") + + while True: + # -- at this point, we need to either send or re-send a volley -- + assert volley_messages + self._check_replaced() + await self._send_volley(volley_messages) + # -- then this is where we wait for a reply -- + self.endpoint._ensure_receive_loop() + with trio.move_on_after(timeout) as cscope: + async for packet in self._q.r: + self._ssl.bio_write(packet) + try: + self._ssl.do_handshake() + # We ignore generic SSL.Error here, because you can get those + # from random invalid packets + except (SSL.WantReadError, SSL.Error): + pass + else: + # No exception -> the handshake is done, and we can + # switch into data transfer mode. + self._did_handshake = True + # Might be empty, but that's ok -- we'll just send no + # packets. + self._final_volley = read_volley() + await self._send_volley(self._final_volley) + return + maybe_volley = read_volley() + if maybe_volley: + if ( + isinstance(maybe_volley[0], PseudoHandshakeMessage) + and maybe_volley[0].content_type == ContentType.alert + ): + # we're sending an alert (e.g. due to a corrupted + # packet). We want to send it once, but don't save it to + # retransmit -- keep the last volley as the current + # volley. + await self._send_volley(maybe_volley) + else: + # We managed to get all of the peer's volley and + # generate a new one ourselves! break out of the 'for' + # loop and restart the timer. + volley_messages = maybe_volley + # "Implementations SHOULD retain the current timer value + # until a transmission without loss occurs, at which + # time the value may be reset to the initial value." + if volley_failed_sends == 0: + timeout = initial_retransmit_timeout + volley_failed_sends = 0 + break + else: + assert self._replaced + self._check_replaced() + if cscope.cancelled_caught: + # Timeout expired. 
Double timeout for backoff, with a limit of 60 + # seconds (this matches what openssl does, and also the + # recommendation in draft-ietf-tls-dtls13). + timeout = min(2 * timeout, 60.0) + volley_failed_sends += 1 + if volley_failed_sends == 2: + # We tried sending this twice and they both failed. Maybe our + # PMTU estimate is wrong? Let's try dropping it to the minimum + # and hope that helps. + self._handshake_mtu = min( + self._handshake_mtu, worst_case_mtu(self.endpoint.socket) + ) + + async def send(self, data: bytes) -> None: + """Send a packet of data, securely.""" + + if self._closed: + raise trio.ClosedResourceError + if not data: + raise ValueError("openssl doesn't support sending empty DTLS packets") + if not self._did_handshake: + await self.do_handshake() + self._check_replaced() + self._ssl.write(data) + async with self.endpoint._send_lock: + await self.endpoint.socket.sendto( + _read_loop(self._ssl.bio_read), self.peer_address + ) + + async def receive(self) -> bytes: + """Fetch the next packet of data from this connection's peer, waiting if + necessary. + + This is safe to call from multiple tasks simultaneously, in case you have some + reason to do that. And more importantly, it's cancellation-safe, meaning that + cancelling a call to `receive` will never cause a packet to be lost or corrupt + the underlying connection. + + """ + if not self._did_handshake: + await self.do_handshake() + # If the packet isn't really valid, then openssl can decode it to the empty + # string (e.g. b/c it's a late-arriving handshake packet, or a duplicate copy of + # a data packet). Skip over these instead of returning them. 
+ while True: + try: + packet = await self._q.r.receive() + except trio.EndOfChannel: + assert self._replaced + self._check_replaced() + self._ssl.bio_write(packet) + cleartext = _read_loop(self._ssl.read) + if cleartext: + return cleartext + + def set_ciphertext_mtu(self, new_mtu: int) -> None: + """Tells Trio the `largest amount of data that can be sent in a single packet to + this peer `__. + + Trio doesn't actually enforce this limit – if you pass a huge packet to `send`, + then we'll dutifully encrypt it and attempt to send it. But calling this method + does have two useful effects: + + - If called before the handshake is performed, then Trio will automatically + fragment handshake messages to fit within the given MTU. It also might + fragment them even smaller, if it detects signs of packet loss, so setting + this should never be necessary to make a successful connection. But, the + packet loss detection only happens after multiple timeouts have expired, so if + you have reason to believe that a smaller MTU is required, then you can set + this to skip those timeouts and establish the connection more quickly. + + - It changes the value returned from `get_cleartext_mtu`. So if you have some + kind of estimate of the network-level MTU, then you can use this to figure out + how much overhead DTLS will need for hashes/padding/etc., and how much space + you have left for your application data. + + The MTU here is measuring the largest UDP *payload* you think can be sent, the + amount of encrypted data that can be handed to the operating system in a single + call to `send`. It should *not* include IP/UDP headers. Note that OS estimates + of the MTU often are link-layer MTUs, so you have to subtract off 28 bytes on + IPv4 and 48 bytes on IPv6 to get the ciphertext MTU. + + By default, Trio assumes an MTU of 1472 bytes on IPv4, and 1452 bytes on IPv6, + which correspond to the common Ethernet MTU of 1500 bytes after accounting for + IP/UDP overhead. 
+ + """ + self._handshake_mtu = new_mtu + self._ssl.set_ciphertext_mtu(new_mtu) + + def get_cleartext_mtu(self) -> int: + """Returns the largest number of bytes that you can pass in a single call to + `send` while still fitting within the network-level MTU. + + See `set_ciphertext_mtu` for more details. + + """ + if not self._did_handshake: + raise trio.NeedHandshakeError + return self._ssl.get_cleartext_mtu() # type: ignore[no-any-return] + + def statistics(self) -> DTLSChannelStatistics: + """Returns a `DTLSChannelStatistics` object with statistics about this connection.""" + return DTLSChannelStatistics(self._packets_dropped_in_trio) + + +class DTLSEndpoint(metaclass=Final): + """A DTLS endpoint. + + A single UDP socket can handle arbitrarily many DTLS connections simultaneously, + acting as a client or server as needed. A `DTLSEndpoint` object holds a UDP socket + and manages these connections, which are represented as `DTLSChannel` objects. + + Args: + socket: (trio.socket.SocketType): A ``SOCK_DGRAM`` socket. If you want to accept + incoming connections in server mode, then you should probably bind the socket to + some known port. + incoming_packets_buffer (int): Each `DTLSChannel` using this socket has its own + buffer that holds incoming packets until you call `~DTLSChannel.receive` to read + them. This lets you adjust the size of this buffer. `~DTLSChannel.statistics` + lets you check if the buffer has overflowed. + + .. attribute:: socket + incoming_packets_buffer + + Both constructor arguments are also exposed as attributes, in case you need to + access them later. + + """ + + def __init__(self, socket: _SocketType, *, incoming_packets_buffer: int = 10): + # We do this lazily on first construction, so only people who actually use DTLS + # have to install PyOpenSSL. 
+ global SSL + from OpenSSL import SSL + + # for __del__, in case the next line raises + self._initialized: bool = False + if socket.type != trio.socket.SOCK_DGRAM: + raise ValueError("DTLS requires a SOCK_DGRAM socket") + self._initialized = True + self.socket: _SocketType = socket + + self.incoming_packets_buffer = incoming_packets_buffer + self._token = trio.lowlevel.current_trio_token() + # We don't need to track handshaking vs non-handshake connections + # separately. We only keep one connection per remote address; as soon + # as a peer provides a valid cookie, we can immediately tear down the + # old connection. + # {remote address: DTLSChannel} + self._streams: WeakValueDictionary[Address, DTLSChannel] = WeakValueDictionary() + self._listening_context: Context | None = None + self._listening_key: bytes | None = None + self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) + self._send_lock = trio.Lock() + self._closed = False + self._receive_loop_spawned = False + + def _ensure_receive_loop(self) -> None: + # We have to spawn this lazily, because on Windows it will immediately error out + # if the socket isn't already bound -- which for clients might not happen until + # after we send our first packet. + if not self._receive_loop_spawned: + trio.lowlevel.spawn_system_task( + dtls_receive_loop, weakref.ref(self), self.socket + ) + self._receive_loop_spawned = True + + def __del__(self) -> None: + # Do nothing if this object was never fully constructed + if not self._initialized: + return + # Close the socket in Trio context (if our Trio context still exists), so that + # the background task gets notified about the closure and can exit. 
+ if not self._closed: + try: + self._token.run_sync_soon(self.close) + except RuntimeError: + pass + # Do this last, because it might raise an exception + warnings.warn( + f"unclosed DTLS endpoint {self!r}", ResourceWarning, source=self + ) + + def close(self) -> None: + """Close this socket, and all associated DTLS connections. + + This object can also be used as a context manager. + + """ + self._closed = True + self.socket.close() + for stream in list(self._streams.values()): + stream.close() + self._incoming_connections_q.s.close() + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + return self.close() + + def _check_closed(self) -> None: + if self._closed: + raise trio.ClosedResourceError + + # async_fn cannot be typed with ParamSpec, since we don't accept + # kwargs. Can be typed with TypeVarTuple once it's fully supported + # in mypy. + async def serve( + self, + ssl_context: Context, + async_fn: Callable[..., Awaitable[object]], + *args: Any, + task_status: TaskStatus = trio.TASK_STATUS_IGNORED, # type: ignore[has-type] + ) -> None: + """Listen for incoming connections, and spawn a handler for each using an + internal nursery. + + Similar to `~trio.serve_tcp`, this function never returns until cancelled, or + the `DTLSEndpoint` is closed and all handlers have exited. + + Usage commonly looks like:: + + async def handler(dtls_channel): + ... + + async with trio.open_nursery() as nursery: + await nursery.start(dtls_endpoint.serve, ssl_context, handler) + # ... do other things here ... + + The ``dtls_channel`` passed into the handler function has already performed the + "cookie exchange" part of the DTLS handshake, so the peer address is + trustworthy. 
But the actual cryptographic handshake doesn't happen until you + start using it, giving you a chance for any last minute configuration, and the + option to catch and handle handshake errors. + + Args: + ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for + incoming connections. + async_fn: The handler function that will be invoked for each incoming + connection. + + """ + self._check_closed() + if self._listening_context is not None: + raise trio.BusyResourceError("another task is already listening") + try: + self.socket.getsockname() + except OSError: + raise RuntimeError("DTLS socket must be bound before it can serve") + self._ensure_receive_loop() + # We do cookie verification ourselves, so tell OpenSSL not to worry about it. + # (See also _inject_client_hello_untrusted.) + ssl_context.set_cookie_verify_callback(lambda *_: True) + try: + self._listening_context = ssl_context + task_status.started() + + async def handler_wrapper(stream: DTLSChannel) -> None: + with stream: + await async_fn(stream, *args) + + async with trio.open_nursery() as nursery: + async for stream in self._incoming_connections_q.r: # pragma: no branch + nursery.start_soon(handler_wrapper, stream) + finally: + self._listening_context = None + + def connect(self, address: tuple[str, int], ssl_context: Context) -> DTLSChannel: + """Initiate an outgoing DTLS connection. + + Notice that this is a synchronous method. That's because it doesn't actually + initiate any I/O – it just sets up a `DTLSChannel` object. The actual handshake + doesn't occur until you start using the `DTLSChannel`. This gives you a chance + to do further configuration first, like setting MTU etc. + + Args: + address: The address to connect to. Usually a (host, port) tuple, like + ``("127.0.0.1", 12345)``. + ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for + this connection. 
+ + Returns: + DTLSChannel + + """ + # it would be nice if we could detect when 'address' is our own endpoint (a + # loopback connection), because that can't work + # but I don't see how to do it reliably + self._check_closed() + channel = DTLSChannel._create(self, address, ssl_context) + channel._ssl.set_connect_state() + old_channel = self._streams.get(address) + if old_channel is not None: + old_channel._set_replaced() + self._streams[address] = channel + return channel diff --git a/trio/_file_io.py b/trio/_file_io.py index 32468af46b..6b79ae25b5 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -1,112 +1,274 @@ -from functools import partial -import io +from __future__ import annotations -from .abc import AsyncResource -from ._util import aiter_compat, async_wraps, fspath +import io +from functools import partial +from typing import ( + IO, + TYPE_CHECKING, + Any, + AnyStr, + BinaryIO, + Callable, + Generic, + Iterable, + TypeVar, + Union, + overload, +) import trio -__all__ = ['open_file', 'wrap_file'] +from ._util import async_wraps +from .abc import AsyncResource + +if TYPE_CHECKING: + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + StrOrBytesPath, + ) + from typing_extensions import Literal # This list is also in the docs, make sure to keep them in sync -_FILE_SYNC_ATTRS = { - 'closed', - 'encoding', - 'errors', - 'fileno', - 'isatty', - 'newlines', - 'readable', - 'seekable', - 'writable', +_FILE_SYNC_ATTRS: set[str] = { + "closed", + "encoding", + "errors", + "fileno", + "isatty", + "newlines", + "readable", + "seekable", + "writable", # not defined in *IOBase: - 'buffer', - 'raw', - 'line_buffering', - 'closefd', - 'name', - 'mode', - 'getvalue', - 'getbuffer', + "buffer", + "raw", + "line_buffering", + "closefd", + "name", + "mode", + "getvalue", + "getbuffer", } # This list is also in the docs, make sure to keep them in sync -_FILE_ASYNC_METHODS = { - 'flush', 
- 'read', - 'read1', - 'readall', - 'readinto', - 'readline', - 'readlines', - 'seek', - 'tell', - 'truncate', - 'write', - 'writelines', +_FILE_ASYNC_METHODS: set[str] = { + "flush", + "read", + "read1", + "readall", + "readinto", + "readline", + "readlines", + "seek", + "tell", + "truncate", + "write", + "writelines", # not defined in *IOBase: - 'readinto1', - 'peek', + "readinto1", + "peek", } -class AsyncIOWrapper(AsyncResource): +FileT = TypeVar("FileT") +FileT_co = TypeVar("FileT_co", covariant=True) +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) +AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True) +AnyStr_contra = TypeVar("AnyStr_contra", str, bytes, contravariant=True) + +# This is a little complicated. IO objects have a lot of methods, and which are available on +# different types varies wildly. We want to match the interface of whatever file we're wrapping. +# This pile of protocols each has one sync method/property, meaning they're going to be compatible +# with a file class that supports that method/property. The ones parameterized with AnyStr take +# either str or bytes depending. + +# The wrapper is then a generic class, where the typevar is set to the type of the sync file we're +# wrapping. For generics, adding a type to self has a special meaning - properties/methods can be +# conditional - it's only valid to call them if the object you're accessing them on is compatible +# with that type hint. By using the protocols, the type checker will be checking to see if the +# wrapped type has that method, and only allow the methods that do to be called. We can then alter +# the signature however it needs to match runtime behaviour. +# More info: https://mypy.readthedocs.io/en/stable/more_types.html#advanced-uses-of-self-types +if TYPE_CHECKING: + from typing_extensions import Buffer, Protocol + + # fmt: off + + class _HasClosed(Protocol): + @property + def closed(self) -> bool: ... 
+ + class _HasEncoding(Protocol): + @property + def encoding(self) -> str: ... + + class _HasErrors(Protocol): + @property + def errors(self) -> str | None: ... + + class _HasFileNo(Protocol): + def fileno(self) -> int: ... + + class _HasIsATTY(Protocol): + def isatty(self) -> bool: ... + + class _HasNewlines(Protocol[T_co]): + # Type varies here - documented to be None, tuple of strings, strings. Typeshed uses Any. + @property + def newlines(self) -> T_co: ... + + class _HasReadable(Protocol): + def readable(self) -> bool: ... + + class _HasSeekable(Protocol): + def seekable(self) -> bool: ... + + class _HasWritable(Protocol): + def writable(self) -> bool: ... + + class _HasBuffer(Protocol): + @property + def buffer(self) -> BinaryIO: ... + + class _HasRaw(Protocol): + @property + def raw(self) -> io.RawIOBase: ... + + class _HasLineBuffering(Protocol): + @property + def line_buffering(self) -> bool: ... + + class _HasCloseFD(Protocol): + @property + def closefd(self) -> bool: ... + + class _HasName(Protocol): + @property + def name(self) -> str: ... + + class _HasMode(Protocol): + @property + def mode(self) -> str: ... + + class _CanGetValue(Protocol[AnyStr_co]): + def getvalue(self) -> AnyStr_co: ... + + class _CanGetBuffer(Protocol): + def getbuffer(self) -> memoryview: ... + + class _CanFlush(Protocol): + def flush(self) -> None: ... + + class _CanRead(Protocol[AnyStr_co]): + def read(self, size: int | None = ..., /) -> AnyStr_co: ... + + class _CanRead1(Protocol): + def read1(self, size: int | None = ..., /) -> bytes: ... + + class _CanReadAll(Protocol[AnyStr_co]): + def readall(self) -> AnyStr_co: ... + + class _CanReadInto(Protocol): + def readinto(self, buf: Buffer, /) -> int | None: ... + + class _CanReadInto1(Protocol): + def readinto1(self, buffer: Buffer, /) -> int: ... + + class _CanReadLine(Protocol[AnyStr_co]): + def readline(self, size: int = ..., /) -> AnyStr_co: ... + + class _CanReadLines(Protocol[AnyStr]): + def readlines(self, hint: int = ...) 
-> list[AnyStr]: ... + + class _CanSeek(Protocol): + def seek(self, target: int, whence: int = 0, /) -> int: ... + + class _CanTell(Protocol): + def tell(self) -> int: ... + + class _CanTruncate(Protocol): + def truncate(self, size: int | None = ..., /) -> int: ... + + class _CanWrite(Protocol[AnyStr_contra]): + def write(self, data: AnyStr_contra, /) -> int: ... + + class _CanWriteLines(Protocol[T_contra]): + # The lines parameter varies for bytes/str, so use a typevar to make the async match. + def writelines(self, lines: Iterable[T_contra], /) -> None: ... + + class _CanPeek(Protocol[AnyStr_co]): + def peek(self, size: int = 0, /) -> AnyStr_co: ... + + class _CanDetach(Protocol[T_co]): + # The T typevar will be the unbuffered/binary file this file wraps. + def detach(self) -> T_co: ... + + class _CanClose(Protocol): + def close(self) -> None: ... + + +# FileT needs to be covariant for the protocol trick to work - the real IO types are effectively a +# subtype of the protocols. +class AsyncIOWrapper(AsyncResource, Generic[FileT_co]): """A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous file object` interface. Wrapped methods that could block are executed in :meth:`trio.to_thread.run_sync`. - All properties and methods defined in in :mod:`~io` are exposed by this + All properties and methods defined in :mod:`~io` are exposed by this wrapper, if they exist in the wrapped file object. 
- """ - def __init__(self, file): + + def __init__(self, file: FileT_co) -> None: self._wrapped = file @property - def wrapped(self): - """object: A reference to the wrapped file object - - """ + def wrapped(self) -> FileT_co: + """object: A reference to the wrapped file object""" return self._wrapped - def __getattr__(self, name): - if name in _FILE_SYNC_ATTRS: - return getattr(self._wrapped, name) - if name in _FILE_ASYNC_METHODS: - meth = getattr(self._wrapped, name) + if not TYPE_CHECKING: - @async_wraps(self.__class__, self._wrapped.__class__, name) - async def wrapper(*args, **kwargs): - func = partial(meth, *args, **kwargs) - return await trio.to_thread.run_sync(func) + def __getattr__(self, name: str) -> object: + if name in _FILE_SYNC_ATTRS: + return getattr(self._wrapped, name) + if name in _FILE_ASYNC_METHODS: + meth = getattr(self._wrapped, name) - # cache the generated method - setattr(self, name, wrapper) - return wrapper + @async_wraps(self.__class__, self._wrapped.__class__, name) + async def wrapper(*args, **kwargs): + func = partial(meth, *args, **kwargs) + return await trio.to_thread.run_sync(func) - raise AttributeError(name) + # cache the generated method + setattr(self, name, wrapper) + return wrapper - def __dir__(self): + raise AttributeError(name) + + def __dir__(self) -> Iterable[str]: attrs = set(super().__dir__()) attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a)) - attrs.update( - a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a) - ) + attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a)) return attrs - @aiter_compat - def __aiter__(self): + def __aiter__(self) -> AsyncIOWrapper[FileT_co]: return self - async def __anext__(self): + async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr: line = await self.readline() if line: return line else: raise StopAsyncIteration - async def detach(self): + async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]: 
"""Like :meth:`io.BufferedIOBase.detach`, but async. This also re-wraps the result in a new :term:`asynchronous file object` @@ -117,7 +279,7 @@ async def detach(self): raw = await trio.to_thread.run_sync(self._wrapped.detach) return wrap_file(raw) - async def aclose(self): + async def aclose(self: AsyncIOWrapper[_CanClose]) -> None: """Like :meth:`io.IOBase.close`, but async. This is also shielded from cancellation; if a cancellation scope is @@ -129,20 +291,169 @@ async def aclose(self): with trio.CancelScope(shield=True): await trio.to_thread.run_sync(self._wrapped.close) - await trio.hazmat.checkpoint_if_cancelled() + await trio.lowlevel.checkpoint_if_cancelled() + + if TYPE_CHECKING: + # fmt: off + # Based on typing.IO and io stubs. + @property + def closed(self: AsyncIOWrapper[_HasClosed]) -> bool: ... + @property + def encoding(self: AsyncIOWrapper[_HasEncoding]) -> str: ... + @property + def errors(self: AsyncIOWrapper[_HasErrors]) -> str | None: ... + @property + def newlines(self: AsyncIOWrapper[_HasNewlines[T]]) -> T: ... + @property + def buffer(self: AsyncIOWrapper[_HasBuffer]) -> BinaryIO: ... + @property + def raw(self: AsyncIOWrapper[_HasRaw]) -> io.RawIOBase: ... + @property + def line_buffering(self: AsyncIOWrapper[_HasLineBuffering]) -> int: ... + @property + def closefd(self: AsyncIOWrapper[_HasCloseFD]) -> bool: ... + @property + def name(self: AsyncIOWrapper[_HasName]) -> str: ... + @property + def mode(self: AsyncIOWrapper[_HasMode]) -> str: ... + + def fileno(self: AsyncIOWrapper[_HasFileNo]) -> int: ... + def isatty(self: AsyncIOWrapper[_HasIsATTY]) -> bool: ... + def readable(self: AsyncIOWrapper[_HasReadable]) -> bool: ... + def seekable(self: AsyncIOWrapper[_HasSeekable]) -> bool: ... + def writable(self: AsyncIOWrapper[_HasWritable]) -> bool: ... + def getvalue(self: AsyncIOWrapper[_CanGetValue[AnyStr]]) -> AnyStr: ... + def getbuffer(self: AsyncIOWrapper[_CanGetBuffer]) -> memoryview: ... 
+ async def flush(self: AsyncIOWrapper[_CanFlush]) -> None: ... + async def read(self: AsyncIOWrapper[_CanRead[AnyStr]], size: int | None = -1, /) -> AnyStr: ... + async def read1(self: AsyncIOWrapper[_CanRead1], size: int | None = -1, /) -> bytes: ... + async def readall(self: AsyncIOWrapper[_CanReadAll[AnyStr]]) -> AnyStr: ... + async def readinto(self: AsyncIOWrapper[_CanReadInto], buf: Buffer, /) -> int | None: ... + async def readline(self: AsyncIOWrapper[_CanReadLine[AnyStr]], size: int = -1, /) -> AnyStr: ... + async def readlines(self: AsyncIOWrapper[_CanReadLines[AnyStr]]) -> list[AnyStr]: ... + async def seek(self: AsyncIOWrapper[_CanSeek], target: int, whence: int = 0, /) -> int: ... + async def tell(self: AsyncIOWrapper[_CanTell]) -> int: ... + async def truncate(self: AsyncIOWrapper[_CanTruncate], size: int | None = None, /) -> int: ... + async def write(self: AsyncIOWrapper[_CanWrite[AnyStr]], data: AnyStr, /) -> int: ... + async def writelines(self: AsyncIOWrapper[_CanWriteLines[T]], lines: Iterable[T], /) -> None: ... + async def readinto1(self: AsyncIOWrapper[_CanReadInto1], buffer: Buffer, /) -> int: ... + async def peek(self: AsyncIOWrapper[_CanPeek[AnyStr]], size: int = 0, /) -> AnyStr: ... + + +# Type hints are copied from builtin open. +_OpenFile = Union["StrOrBytesPath", int] +_Opener = Callable[[str, int], int] + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenTextMode = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.TextIOWrapper]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.FileIO]: + ... 
+ + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedRandom]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedWriter]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedReader]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryMode, + buffering: int, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[BinaryIO]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: str, + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[IO[Any]]: + ... async def open_file( - file, - mode='r', - buffering=-1, - encoding=None, - errors=None, - newline=None, - closefd=True, - opener=None -): - """Asynchronous version of :func:`io.open`. + file: _OpenFile, + mode: str = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[Any]: + """Asynchronous version of :func:`open`. 
Returns: An :term:`asynchronous file object` @@ -159,20 +470,15 @@ async def open_file( :func:`trio.Path.open` """ - # python3.5 compat - if isinstance(file, trio.Path): - file = fspath(file) - _file = wrap_file( await trio.to_thread.run_sync( - io.open, file, mode, buffering, encoding, errors, newline, closefd, - opener + io.open, file, mode, buffering, encoding, errors, newline, closefd, opener ) ) return _file -def wrap_file(file): +def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]: """This wraps any file object in a wrapper that provides an asynchronous file object interface. @@ -189,13 +495,14 @@ def wrap_file(file): assert await async_file.read() == 'asdf' """ - def has(attr): + + def has(attr: str) -> bool: return hasattr(file, attr) and callable(getattr(file, attr)) - if not (has('close') and (has('read') or has('write'))): + if not (has("close") and (has("read") or has("write"))): raise TypeError( - '{} does not implement required duck-file methods: ' - 'close and (read or write)'.format(file) + "{} does not implement required duck-file methods: " + "close and (read or write)".format(file) ) return AsyncIOWrapper(file) diff --git a/trio/_highlevel_generic.py b/trio/_highlevel_generic.py index 79a82d36d8..e1ac378c6a 100644 --- a/trio/_highlevel_generic.py +++ b/trio/_highlevel_generic.py @@ -1,10 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import attr import trio +from trio._util import Final + +if TYPE_CHECKING: + from .abc import SendStream, ReceiveStream, AsyncResource + from .abc import HalfCloseableStream -async def aclose_forcefully(resource): +async def aclose_forcefully(resource: AsyncResource) -> None: """Close an async resource or async generator immediately, without blocking to do any graceful cleanup. 
@@ -35,7 +44,7 @@ async def aclose_forcefully(resource): @attr.s(eq=False, hash=False) -class StapledStream(HalfCloseableStream): +class StapledStream(HalfCloseableStream, metaclass=Final): """This class `staples `__ together two unidirectional streams to make single bidirectional stream. @@ -69,22 +78,19 @@ class StapledStream(HalfCloseableStream): is delegated to this object. """ - send_stream = attr.ib() - receive_stream = attr.ib() - async def send_all(self, data): - """Calls ``self.send_stream.send_all``. + send_stream: SendStream = attr.ib() + receive_stream: ReceiveStream = attr.ib() - """ + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + """Calls ``self.send_stream.send_all``.""" return await self.send_stream.send_all(data) - async def wait_send_all_might_not_block(self): - """Calls ``self.send_stream.wait_send_all_might_not_block``. - - """ + async def wait_send_all_might_not_block(self) -> None: + """Calls ``self.send_stream.wait_send_all_might_not_block``.""" return await self.send_stream.wait_send_all_might_not_block() - async def send_eof(self): + async def send_eof(self) -> None: """Shuts down the send side of the stream. If ``self.send_stream.send_eof`` exists, then calls it. Otherwise, @@ -92,20 +98,19 @@ async def send_eof(self): """ if hasattr(self.send_stream, "send_eof"): - return await self.send_stream.send_eof() + # send_stream.send_eof() is not defined in Trio, this should maybe be + # redesigned so it's possible to type it. + return await self.send_stream.send_eof() # type: ignore[no-any-return] else: return await self.send_stream.aclose() - async def receive_some(self, max_bytes=None): - """Calls ``self.receive_stream.receive_some``. 
- - """ + # we intentionally accept more types from the caller than we support returning + async def receive_some(self, max_bytes: int | None = None) -> bytes: + """Calls ``self.receive_stream.receive_some``.""" return await self.receive_stream.receive_some(max_bytes) - async def aclose(self): - """Calls ``aclose`` on both underlying streams. - - """ + async def aclose(self) -> None: + """Calls ``aclose`` on both underlying streams.""" try: await self.send_stream.aclose() finally: diff --git a/trio/_highlevel_open_tcp_listeners.py b/trio/_highlevel_open_tcp_listeners.py index 2625238803..6211917254 100644 --- a/trio/_highlevel_open_tcp_listeners.py +++ b/trio/_highlevel_open_tcp_listeners.py @@ -3,9 +3,11 @@ from math import inf import trio + from . import socket as tsocket -__all__ = ["open_tcp_listeners", "serve_tcp"] +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup # Default backlog size: @@ -41,7 +43,7 @@ def _compute_backlog(backlog): # Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are # missing overflow protection, so we apply our own overflow protection. 
# https://github.com/golang/go/issues/5030 - return min(backlog, 0xffff) + return min(backlog, 0xFFFF) async def open_tcp_listeners(port, *, host=None, backlog=None): @@ -89,15 +91,12 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): # doesn't: # http://klickverbot.at/blog/2012/01/getaddrinfo-edge-case-behavior-on-windows-linux-and-osx/ if not isinstance(port, int): - raise TypeError("port must be an int not {!r}".format(port)) + raise TypeError(f"port must be an int not {port!r}") backlog = _compute_backlog(backlog) addresses = await tsocket.getaddrinfo( - host, - port, - type=tsocket.SOCK_STREAM, - flags=tsocket.AI_PASSIVE, + host, port, type=tsocket.SOCK_STREAM, flags=tsocket.AI_PASSIVE ) listeners = [] @@ -121,14 +120,10 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): try: # See https://github.com/python-trio/trio/issues/39 if sys.platform != "win32": - sock.setsockopt( - tsocket.SOL_SOCKET, tsocket.SO_REUSEADDR, 1 - ) + sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_REUSEADDR, 1) if family == tsocket.AF_INET6: - sock.setsockopt( - tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1 - ) + sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, 1) await sock.bind(sockaddr) sock.listen(backlog) @@ -143,11 +138,13 @@ async def open_tcp_listeners(port, *, host=None, backlog=None): raise if unsupported_address_families and not listeners: - raise OSError( - errno.EAFNOSUPPORT, + msg = ( "This system doesn't support any of the kinds of " "socket that that address could use" - ) from trio.MultiError(unsupported_address_families) + ) + raise OSError(errno.EAFNOSUPPORT, msg) from ExceptionGroup( + msg, unsupported_address_families + ) return listeners @@ -159,7 +156,7 @@ async def serve_tcp( host=None, backlog=None, handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED + task_status=trio.TASK_STATUS_IGNORED, ): """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. 
@@ -226,8 +223,5 @@ async def serve_tcp( """ listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog) await trio.serve_listeners( - handler, - listeners, - handler_nursery=handler_nursery, - task_status=task_status + handler, listeners, handler_nursery=handler_nursery, task_status=task_status ) diff --git a/trio/_highlevel_open_tcp_stream.py b/trio/_highlevel_open_tcp_stream.py index 71418b12d3..a2477104d9 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/trio/_highlevel_open_tcp_stream.py @@ -1,9 +1,13 @@ +import sys from contextlib import contextmanager import trio -from trio.socket import getaddrinfo, SOCK_STREAM, socket +from trio._core._multierror import MultiError +from trio.socket import SOCK_STREAM, getaddrinfo, socket + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup -__all__ = ["open_tcp_stream"] # Implementation of RFC 6555 "Happy eyeballs" # https://tools.ietf.org/html/rfc6555 @@ -116,8 +120,10 @@ def close_all(): sock.close() except BaseException as exc: errs.append(exc) - if errs: - raise trio.MultiError(errs) + if len(errs) == 1: + raise errs[0] + elif errs: + raise MultiError(errs) def reorder_for_rfc_6555_section_5_4(targets): @@ -141,9 +147,9 @@ def reorder_for_rfc_6555_section_5_4(targets): def format_host_port(host, port): host = host.decode("ascii") if isinstance(host, bytes) else host if ":" in host: - return "[{}]:{}".format(host, port) + return f"[{host}]:{port}" else: - return "{}:{}".format(host, port) + return f"{host}:{port}" # Twisted's HostnameEndpoint has a good set of configurables: @@ -167,11 +173,7 @@ def format_host_port(host, port): # AF_INET6: "..."} # this might be simpler after async def open_tcp_stream( - host, - port, - *, - # No trailing comma b/c bpo-9232 (fixed in py36) - happy_eyeballs_delay=DEFAULT_DELAY + host, port, *, happy_eyeballs_delay=DEFAULT_DELAY, local_address=None ): """Connect to the given host and port over TCP. 
@@ -207,13 +209,30 @@ async def open_tcp_stream( Args: host (str or bytes): The host to connect to. Can be an IPv4 address, IPv6 address, or a hostname. + port (int): The port to connect to. + happy_eyeballs_delay (float): How many seconds to wait for each connection attempt to succeed or fail before getting impatient and starting another one in parallel. Set to `math.inf` if you want to limit to only one connection attempt at a time (like :func:`socket.create_connection`). Default: 0.25 (250 ms). + local_address (None or str): The local IP address or hostname to use as + the source for outgoing connections. If ``None``, we let the OS pick + the source IP. + + This is useful in some exotic networking configurations where your + host has multiple IP addresses, and you want to force the use of a + specific one. + + Note that if you pass an IPv4 ``local_address``, then you won't be + able to connect to IPv6 hosts, and vice-versa. If you want to take + advantage of this to force the use of IPv4 or IPv6 without + specifying an exact source address, you can use the IPv4 wildcard + address ``local_address="0.0.0.0"``, or the IPv6 wildcard address + ``local_address="::"``. + Returns: SocketStream: a :class:`~trio.abc.Stream` connected to the given server. @@ -224,6 +243,7 @@ async def open_tcp_stream( open_ssl_over_tcp_stream """ + # To keep our public API surface smaller, rule out some cases that # getaddrinfo will accept in some circumstances, but that act weird or # have non-portable behavior or are just plain not useful. 
@@ -231,7 +251,7 @@ async def open_tcp_stream( if host is None: raise ValueError("host cannot be None") if not isinstance(port, int): - raise TypeError("port must be int, not {!r}".format(port)) + raise TypeError(f"port must be int, not {port!r}") if happy_eyeballs_delay is None: happy_eyeballs_delay = DEFAULT_DELAY @@ -270,6 +290,53 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed): sock = socket(*socket_args) open_sockets.add(sock) + if local_address is not None: + # TCP connections are identified by a 4-tuple: + # + # (local IP, local port, remote IP, remote port) + # + # So if a single local IP wants to make multiple connections + # to the same (remote IP, remote port) pair, then those + # connections have to use different local ports, or else TCP + # won't be able to tell them apart. OTOH, if you have multiple + # connections to different remote IP/ports, then those + # connections can share a local port. + # + # Normally, when you call bind(), the kernel will immediately + # assign a specific local port to your socket. At this point + # the kernel doesn't know which (remote IP, remote port) + # you're going to use, so it has to pick a local port that + # *no* other connection is using. That's the only way to + # guarantee that this local port will be usable later when we + # call connect(). (Alternatively, you can set SO_REUSEADDR to + # allow multiple nascent connections to share the same port, + # but then connect() might fail with EADDRNOTAVAIL if we get + # unlucky and our TCP 4-tuple ends up colliding with another + # unrelated connection.) + # + # So calling bind() before connect() works, but it disables + # sharing of local ports. This is inefficient: it makes you + # more likely to run out of local ports. + # + # But on some versions of Linux, we can re-enable sharing of + # local ports by setting a special flag. This flag tells + # bind() to only bind the IP, and not the port. 
That way, + # connect() is allowed to pick the the port, and it can do a + # better job of it because it knows the remote IP/port. + try: + sock.setsockopt( + trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT, 1 + ) + except (OSError, AttributeError): + pass + try: + await sock.bind((local_address, 0)) + except OSError: + raise OSError( + f"local_address={local_address!r} is incompatible " + f"with remote address {sockaddr}" + ) + await sock.connect(sockaddr) # Success! Save the winning socket and cancel all outstanding @@ -305,7 +372,7 @@ async def attempt_connect(socket_args, sockaddr, attempt_failed): msg = "all attempts to connect to {} failed".format( format_host_port(host, port) ) - raise OSError(msg) from trio.MultiError(oserrors) + raise OSError(msg) from ExceptionGroup(msg, oserrors) else: stream = trio.SocketStream(winning_socket) open_sockets.remove(winning_socket) diff --git a/trio/_highlevel_open_unix_listeners.py b/trio/_highlevel_open_unix_listeners.py index 3c0bb55850..35362be33e 100644 --- a/trio/_highlevel_open_unix_listeners.py +++ b/trio/_highlevel_open_unix_listeners.py @@ -7,6 +7,14 @@ from math import inf import trio +from trio import SocketListener + +try: + from trio.socket import AF_UNIX + + HAS_UNIX = True +except ImportError: + HAS_UNIX = False __all__ = ["open_unix_listeners", "serve_unix"] @@ -44,10 +52,10 @@ def _compute_backlog(backlog): # Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are # missing overflow protection, so we apply our own overflow protection. 
# https://github.com/golang/go/issues/5030 - return min(backlog, 0xffff) + return min(backlog, 0xFFFF) -class UnixSocketListener(trio.SocketListener): +class UnixSocketListener(SocketListener): @staticmethod def _inode(filename): """Return a (dev, inode) tuple uniquely identifying a file.""" @@ -89,6 +97,9 @@ def __init__(self, sock, path, inode): @staticmethod def _create(path, mode, backlog): + if not HAS_UNIX: + raise RuntimeError("Unix sockets are not supported on this platform") + # Sanitise and pre-verify socket path path = os.path.abspath(path) folder = os.path.dirname(path) @@ -101,7 +112,7 @@ def _create(path, mode, backlog): pass # Create new socket with a random temporary name tmp_path = f"{path}.{secrets.token_urlsafe()}" - sock = socket.socket(socket.AF_UNIX) + sock = socket.socket(AF_UNIX) try: # Critical section begins (filename races) sock.bind(tmp_path) @@ -179,7 +190,12 @@ async def open_unix_listeners(path, *, mode=0o666, backlog=None): Returns: list of :class:`SocketListener` + Raises: + RuntimeError: If AF_UNIX sockets are not supported. """ + if not HAS_UNIX: + raise RuntimeError("Unix sockets are not supported on this platform") + return [await UnixSocketListener.create(path, mode=mode, backlog=backlog)] @@ -189,7 +205,7 @@ async def serve_unix( *, backlog=None, handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED + task_status=trio.TASK_STATUS_IGNORED, ): """Listen for incoming UNIX connections, and for each one start a task running ``handler(stream)``. @@ -225,11 +241,13 @@ async def serve_unix( Returns: This function only returns when cancelled. + Raises: + RuntimeError: If AF_UNIX sockets are not supported. 
""" - listeners = await trio.open_unix_listeners(path, backlog=backlog) + if not HAS_UNIX: + raise RuntimeError("Unix sockets are not supported on this platform") + + listeners = await open_unix_listeners(path, backlog=backlog) await trio.serve_listeners( - handler, - listeners, - handler_nursery=handler_nursery, - task_status=task_status + handler, listeners, handler_nursery=handler_nursery, task_status=task_status ) diff --git a/trio/_highlevel_open_unix_stream.py b/trio/_highlevel_open_unix_stream.py index 59141ebc38..c2c3a3ca7c 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/trio/_highlevel_open_unix_stream.py @@ -1,16 +1,16 @@ +import os from contextlib import contextmanager import trio -from trio.socket import socket, SOCK_STREAM +from trio.socket import SOCK_STREAM, socket try: from trio.socket import AF_UNIX + has_unix = True except ImportError: has_unix = False -__all__ = ["open_unix_socket"] - @contextmanager def close_on_error(obj): @@ -21,7 +21,7 @@ def close_on_error(obj): raise -async def open_unix_socket(filename,): +async def open_unix_socket(filename): """Opens a connection to the specified `Unix domain socket `__. 
@@ -44,6 +44,6 @@ async def open_unix_socket(filename,): # possible location to connect to sock = socket(AF_UNIX, SOCK_STREAM) with close_on_error(sock): - await sock.connect(trio._util.fspath(filename)) + await sock.connect(os.fspath(filename)) return trio.SocketStream(sock) diff --git a/trio/_highlevel_serve_listeners.py b/trio/_highlevel_serve_listeners.py index 958624a6f4..0585fa516f 100644 --- a/trio/_highlevel_serve_listeners.py +++ b/trio/_highlevel_serve_listeners.py @@ -4,8 +4,6 @@ import trio -__all__ = ["serve_listeners"] - # Errors that accept(2) can return, and which indicate that the system is # overloaded ACCEPT_CAPACITY_ERRNOS = { @@ -41,7 +39,7 @@ async def _serve_one_listener(listener, handler_nursery, handler): errno.errorcode[exc.errno], os.strerror(exc.errno), SLEEP_TIME, - exc_info=True + exc_info=True, ) await trio.sleep(SLEEP_TIME) else: @@ -51,11 +49,7 @@ async def _serve_one_listener(listener, handler_nursery, handler): async def serve_listeners( - handler, - listeners, - *, - handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED + handler, listeners, *, handler_nursery=None, task_status=trio.TASK_STATUS_IGNORED ): r"""Listen for incoming connections on ``listeners``, and for each one start a task running ``handler(stream)``. @@ -120,9 +114,7 @@ async def serve_listeners( if handler_nursery is None: handler_nursery = nursery for listener in listeners: - nursery.start_soon( - _serve_one_listener, listener, handler_nursery, handler - ) + nursery.start_soon(_serve_one_listener, listener, handler_nursery, handler) # The listeners are already queueing connections when we're called, # but we wait until the end to call started() just in case we get an # error or whatever. 
diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index e816ea2dea..ce96153805 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -1,14 +1,18 @@ # "High-level" networking interface +from __future__ import annotations import errno from contextlib import contextmanager +from typing import TYPE_CHECKING import trio + from . import socket as tsocket -from ._util import ConflictDetector +from ._util import ConflictDetector, Final from .abc import HalfCloseableStream, Listener -__all__ = ["SocketStream", "SocketListener"] +if TYPE_CHECKING: + from ._socket import _SocketType as SocketType # XX TODO: this number was picked arbitrarily. We should do experiments to # tune it. (Or make it dynamic -- one idea is to start small and increase it @@ -30,16 +34,12 @@ def _translate_socket_errors_to_stream_errors(): yield except OSError as exc: if exc.errno in _closed_stream_errnos: - raise trio.ClosedResourceError( - "this socket was already closed" - ) from None + raise trio.ClosedResourceError("this socket was already closed") from None else: - raise trio.BrokenResourceError( - "socket connection broken: {}".format(exc) - ) from exc + raise trio.BrokenResourceError(f"socket connection broken: {exc}") from exc -class SocketStream(HalfCloseableStream): +class SocketStream(HalfCloseableStream, metaclass=Final): """An implementation of the :class:`trio.abc.HalfCloseableStream` interface based on a raw network socket. @@ -62,7 +62,8 @@ class SocketStream(HalfCloseableStream): The Trio socket object that this stream wraps. """ - def __init__(self, socket): + + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketStream requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: @@ -92,9 +93,7 @@ def __init__(self, socket): # http://devstreaming.apple.com/videos/wwdc/2015/719ui2k57m/719/719_your_app_and_next_generation_networks.pdf?dl=1 # ). 
The theory is that you want it to be bandwidth * # rescheduling interval. - self.setsockopt( - tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2**14 - ) + self.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NOTSENT_LOWAT, 2**14) except OSError: pass @@ -106,10 +105,8 @@ async def send_all(self, data): with memoryview(data) as data: if not data: if self.socket.fileno() == -1: - raise trio.ClosedResourceError( - "socket was already closed" - ) - await trio.hazmat.checkpoint() + raise trio.ClosedResourceError("socket was already closed") + await trio.lowlevel.checkpoint() return total_sent = 0 while total_sent < len(data): @@ -117,16 +114,16 @@ async def send_all(self, data): sent = await self.socket.send(remaining) total_sent += sent - async def wait_send_all_might_not_block(self): + async def wait_send_all_might_not_block(self) -> None: with self._send_conflict_detector: if self.socket.fileno() == -1: raise trio.ClosedResourceError with _translate_socket_errors_to_stream_errors(): await self.socket.wait_writable() - async def send_eof(self): + async def send_eof(self) -> None: with self._send_conflict_detector: - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() # On macOS, calling shutdown a second time raises ENOTCONN, but # send_eof needs to be idempotent. 
if self.socket.did_shutdown_SHUT_WR: @@ -134,7 +131,7 @@ async def send_eof(self): with _translate_socket_errors_to_stream_errors(): self.socket.shutdown(tsocket.SHUT_WR) - async def receive_some(self, max_bytes=None): + async def receive_some(self, max_bytes: int | None = None) -> bytes: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE if max_bytes < 1: @@ -142,9 +139,9 @@ async def receive_some(self, max_bytes=None): with _translate_socket_errors_to_stream_errors(): return await self.socket.recv(max_bytes) - async def aclose(self): + async def aclose(self) -> None: self.socket.close() - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() # __aenter__, __aexit__ inherited from HalfCloseableStream are OK @@ -322,7 +319,7 @@ def getsockopt(self, level, option, buffersize=0): pass -class SocketListener(Listener[SocketStream]): +class SocketListener(Listener[SocketStream], metaclass=Final): """A :class:`~trio.abc.Listener` that uses a listening socket to accept incoming connections as :class:`SocketStream` objects. @@ -338,15 +335,14 @@ class SocketListener(Listener[SocketStream]): The Trio socket object that this stream wraps. """ - def __init__(self, socket): + + def __init__(self, socket: SocketType): if not isinstance(socket, tsocket.SocketType): raise TypeError("SocketListener requires a Trio socket object") if socket.type != tsocket.SOCK_STREAM: raise ValueError("SocketListener requires a SOCK_STREAM socket") try: - listening = socket.getsockopt( - tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN - ) + listening = socket.getsockopt(tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN) except OSError: # SO_ACCEPTCONN fails on macOS; we just have to trust the user. pass @@ -356,7 +352,7 @@ def __init__(self, socket): self.socket = socket - async def accept(self): + async def accept(self) -> SocketStream: """Accept an incoming connection. 
Returns: @@ -384,9 +380,7 @@ async def accept(self): else: return SocketStream(sock) - async def aclose(self): - """Close this listener and its underlying socket. - - """ + async def aclose(self) -> None: + """Close this listener and its underlying socket.""" self.socket.close() - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() diff --git a/trio/_highlevel_ssl_helpers.py b/trio/_highlevel_ssl_helpers.py index 9b68e942f4..ad77a302f0 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/trio/_highlevel_ssl_helpers.py @@ -1,12 +1,8 @@ -import trio import ssl -from ._highlevel_open_tcp_stream import DEFAULT_DELAY +import trio -__all__ = [ - "open_ssl_over_tcp_stream", "open_ssl_over_tcp_listeners", - "serve_ssl_over_tcp" -] +from ._highlevel_open_tcp_stream import DEFAULT_DELAY # It might have been nice to take a ssl_protocols= argument here to set up @@ -14,12 +10,8 @@ # if it's one we created, but not OK if it's one that was passed in... and # the one major protocol using NPN/ALPN is HTTP/2, which mandates that you use # a specially configured SSLContext anyway! I also thought maybe we could copy -# the given SSLContext and then mutate the copy, but it's no good: -# copy.copy(SSLContext) seems to succeed, but the state is not transferred! -# For example, with CPython 3.5, we have: -# ctx = ssl.create_default_context() -# assert ctx.check_hostname == True -# assert copy.copy(ctx).check_hostname == False +# the given SSLContext and then mutate the copy, but it's no good as SSLContext +# objects can't be copied: https://bugs.python.org/issue33023. # So... let's punt on that for now. Hopefully we'll be getting a new Python # TLS API soon and can revisit this then. 
async def open_ssl_over_tcp_stream( @@ -28,8 +20,7 @@ async def open_ssl_over_tcp_stream( *, https_compatible=False, ssl_context=None, - # No trailing comma b/c bpo-9232 (fixed in py36) - happy_eyeballs_delay=DEFAULT_DELAY + happy_eyeballs_delay=DEFAULT_DELAY, ): """Make a TLS-encrypted Connection to the given host and port over TCP. @@ -58,17 +49,16 @@ async def open_ssl_over_tcp_stream( """ tcp_stream = await trio.open_tcp_stream( - host, - port, - happy_eyeballs_delay=happy_eyeballs_delay, + host, port, happy_eyeballs_delay=happy_eyeballs_delay ) if ssl_context is None: ssl_context = ssl.create_default_context() + + if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): + ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF + return trio.SSLStream( - tcp_stream, - ssl_context, - server_hostname=host, - https_compatible=https_compatible, + tcp_stream, ssl_context, server_hostname=host, https_compatible=https_compatible ) @@ -87,15 +77,10 @@ async def open_ssl_over_tcp_listeners( backlog (int or None): See :func:`open_tcp_listeners` for details. """ - tcp_listeners = await trio.open_tcp_listeners( - port, host=host, backlog=backlog - ) + tcp_listeners = await trio.open_tcp_listeners(port, host=host, backlog=backlog) ssl_listeners = [ - trio.SSLListener( - tcp_listener, - ssl_context, - https_compatible=https_compatible, - ) for tcp_listener in tcp_listeners + trio.SSLListener(tcp_listener, ssl_context, https_compatible=https_compatible) + for tcp_listener in tcp_listeners ] return ssl_listeners @@ -109,7 +94,7 @@ async def serve_ssl_over_tcp( https_compatible=False, backlog=None, handler_nursery=None, - task_status=trio.TASK_STATUS_IGNORED + task_status=trio.TASK_STATUS_IGNORED, ): """Listen for incoming TCP connections, and for each one start a task running ``handler(stream)``. 
@@ -163,11 +148,8 @@ async def serve_ssl_over_tcp( ssl_context, host=host, https_compatible=https_compatible, - backlog=backlog + backlog=backlog, ) await trio.serve_listeners( - handler, - listeners, - handler_nursery=handler_nursery, - task_status=task_status + handler, listeners, handler_nursery=handler_nursery, task_status=task_status ) diff --git a/trio/_path.py b/trio/_path.py index bbadf9d874..b7e6b16e4a 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -1,61 +1,96 @@ -from functools import wraps, partial +from __future__ import annotations + +import inspect import os -import types import pathlib +import sys +import types +from collections.abc import Awaitable, Callable, Iterable +from functools import partial, wraps +from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper +from typing import ( + IO, + TYPE_CHECKING, + Any, + BinaryIO, + ClassVar, + TypeVar, + Union, + cast, + overload, +) import trio -from trio._util import async_wraps, fspath +from trio._file_io import AsyncIOWrapper as _AsyncIOWrapper +from trio._util import Final, async_wraps -__all__ = ['Path'] +if TYPE_CHECKING: + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + ) + from typing_extensions import Concatenate, Literal, ParamSpec, TypeAlias + P = ParamSpec("P") -# python3.5 compat: __fspath__ does not exist in 3.5, so unwrap any trio.Path -# being passed to any wrapped method -def unwrap_paths(args): - new_args = [] - for arg in args: - if isinstance(arg, Path): - arg = arg._wrapped - new_args.append(arg) - return new_args +T = TypeVar("T") +StrPath: TypeAlias = Union[str, "os.PathLike[str]"] # Only subscriptable in 3.9+ # re-wrap return value from methods that return new instances of pathlib.Path -def rewrap_path(value): +def rewrap_path(value: T) -> T | Path: if isinstance(value, pathlib.Path): - value = Path(value) - return value + return Path(value) + else: + return 
value -def _forward_factory(cls, attr_name, attr): +def _forward_factory( + cls: AsyncAutoWrapperType, + attr_name: str, + attr: Callable[Concatenate[pathlib.Path, P], T], +) -> Callable[Concatenate[Path, P], T | Path]: @wraps(attr) - def wrapper(self, *args, **kwargs): - args = unwrap_paths(args) + def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> T | Path: attr = getattr(self._wrapped, attr_name) value = attr(*args, **kwargs) return rewrap_path(value) + # Assigning this makes inspect and therefore Sphinx show the original parameters. + # It's not defined on functions normally though, this is a custom attribute. + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) + return wrapper -def _forward_magic(cls, attr): +def _forward_magic( + cls: AsyncAutoWrapperType, attr: Callable[..., T] +) -> Callable[..., Path | T]: sentinel = object() @wraps(attr) - def wrapper(self, other=sentinel): + def wrapper(self: Path, other: object = sentinel) -> Path | T: if other is sentinel: return attr(self._wrapped) if isinstance(other, cls): - other = other._wrapped + other = cast(Path, other)._wrapped value = attr(self._wrapped, other) return rewrap_path(value) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) return wrapper -def iter_wrapper_factory(cls, meth_name): +def iter_wrapper_factory( + cls: AsyncAutoWrapperType, meth_name: str +) -> Callable[Concatenate[Path, P], Awaitable[Iterable[Path]]]: @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Iterable[Path]: meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) # Make sure that the full iteration is performed in the thread @@ -66,10 +101,11 @@ async def wrapper(self, *args, **kwargs): return wrapper -def thread_wrapper_factory(cls, meth_name): +def thread_wrapper_factory( + cls: 
AsyncAutoWrapperType, meth_name: str +) -> Callable[Concatenate[Path, P], Awaitable[Path]]: @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): - args = unwrap_paths(args) + async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Path: meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) @@ -78,21 +114,31 @@ async def wrapper(self, *args, **kwargs): return wrapper -def classmethod_wrapper_factory(cls, meth_name): - @classmethod +def classmethod_wrapper_factory( + cls: AsyncAutoWrapperType, meth_name: str +) -> classmethod: @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(cls, *args, **kwargs): - args = unwrap_paths(args) + async def wrapper(cls: type[Path], *args: Any, **kwargs: Any) -> Path: meth = getattr(cls._wraps, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) return rewrap_path(value) - return wrapper + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(getattr(cls._wraps, meth_name)) + return classmethod(wrapper) + +class AsyncAutoWrapperType(Final): + _forwards: type + _wraps: type + _forward_magic: list[str] + _wrap_iter: list[str] + _forward: list[str] -class AsyncAutoWrapperType(type): - def __init__(cls, name, bases, attrs): + def __init__( + cls, name: str, bases: tuple[type, ...], attrs: dict[str, object] + ) -> None: super().__init__(name, bases, attrs) cls._forward = [] @@ -101,10 +147,10 @@ def __init__(cls, name, bases, attrs): type(cls).generate_magic(cls, attrs) type(cls).generate_iter(cls, attrs) - def generate_forwards(cls, attrs): + def generate_forwards(cls, attrs: dict[str, object]) -> None: # forward functions of _forwards for attr_name, attr in cls._forwards.__dict__.items(): - if attr_name.startswith('_') or attr_name in attrs: + if attr_name.startswith("_") or attr_name in attrs: continue if isinstance(attr, property): @@ 
-115,33 +161,39 @@ def generate_forwards(cls, attrs): else: raise TypeError(attr_name, type(attr)) - def generate_wraps(cls, attrs): + def generate_wraps(cls, attrs: dict[str, object]) -> None: # generate wrappers for functions of _wraps + wrapper: classmethod | Callable for attr_name, attr in cls._wraps.__dict__.items(): # .z. exclude cls._wrap_iter - if attr_name.startswith('_') or attr_name in attrs: + if attr_name.startswith("_") or attr_name in attrs: continue if isinstance(attr, classmethod): wrapper = classmethod_wrapper_factory(cls, attr_name) setattr(cls, attr_name, wrapper) elif isinstance(attr, types.FunctionType): wrapper = thread_wrapper_factory(cls, attr_name) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) setattr(cls, attr_name, wrapper) else: raise TypeError(attr_name, type(attr)) - def generate_magic(cls, attrs): + def generate_magic(cls, attrs: dict[str, object]) -> None: # generate wrappers for magic for attr_name in cls._forward_magic: attr = getattr(cls._forwards, attr_name) wrapper = _forward_magic(cls, attr) setattr(cls, attr_name, wrapper) - def generate_iter(cls, attrs): + def generate_iter(cls, attrs: dict[str, object]) -> None: # generate wrappers for methods that return iterators + wrapper: Callable for attr_name, attr in cls._wraps.__dict__.items(): if attr_name in cls._wrap_iter: wrapper = iter_wrapper_factory(cls, attr_name) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) setattr(cls, attr_name, wrapper) @@ -151,44 +203,125 @@ class Path(metaclass=AsyncAutoWrapperType): """ - _wraps = pathlib.Path - _forwards = pathlib.PurePath - _forward_magic = [ - '__str__', - '__bytes__', - '__truediv__', - '__rtruediv__', - '__eq__', - '__lt__', - '__le__', - '__gt__', - '__ge__', - '__hash__', + _forward: ClassVar[list[str]] + _wraps: ClassVar[type] = pathlib.Path + _forwards: ClassVar[type] = pathlib.PurePath + _forward_magic: 
ClassVar[list[str]] = [ + "__str__", + "__bytes__", + "__truediv__", + "__rtruediv__", + "__eq__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__hash__", ] - _wrap_iter = ['glob', 'rglob', 'iterdir'] - - def __init__(self, *args): - args = unwrap_paths(args) + _wrap_iter: ClassVar[list[str]] = ["glob", "rglob", "iterdir"] + def __init__(self, *args: StrPath) -> None: self._wrapped = pathlib.Path(*args) - def __getattr__(self, name): - if name in self._forward: - value = getattr(self._wrapped, name) - return rewrap_path(value) - raise AttributeError(name) - - def __dir__(self): - return super().__dir__() + self._forward - - def __repr__(self): - return 'trio.Path({})'.format(repr(str(self))) - - def __fspath__(self): - return fspath(self._wrapped) - - @wraps(pathlib.Path.open) - async def open(self, *args, **kwargs): + # type checkers allow accessing any attributes on class instances with `__getattr__` + # so we hide it behind a type guard forcing it to rely on the hardcoded attribute + # list below. + if not TYPE_CHECKING: + + def __getattr__(self, name): + if name in self._forward: + value = getattr(self._wrapped, name) + return rewrap_path(value) + raise AttributeError(name) + + def __dir__(self) -> list[str]: + return [*super().__dir__(), *self._forward] + + def __repr__(self) -> str: + return f"trio.Path({repr(str(self))})" + + def __fspath__(self) -> str: + return os.fspath(self._wrapped) + + @overload + def open( + self, + mode: OpenTextMode = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> _AsyncIOWrapper[TextIOWrapper]: + ... + + @overload + def open( + self, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[FileIO]: + ... 
+ + @overload + def open( + self, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedRandom]: + ... + + @overload + def open( + self, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedWriter]: + ... + + @overload + def open( + self, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedReader]: + ... + + @overload + def open( + self, + mode: OpenBinaryMode, + buffering: int = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BinaryIO]: + ... + + @overload + def open( + self, + mode: str, + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> _AsyncIOWrapper[IO[Any]]: + ... + + @wraps(pathlib.Path.open) # type: ignore[misc] # Overload return mismatch. + async def open(self, *args: Any, **kwargs: Any) -> _AsyncIOWrapper[IO[Any]]: """Open the file pointed to by the path, like the :func:`trio.open_file` function does. @@ -198,6 +331,104 @@ async def open(self, *args, **kwargs): value = await trio.to_thread.run_sync(func) return trio.wrap_file(value) + if TYPE_CHECKING: + # the dunders listed in _forward_magic that aren't seen otherwise + # fmt: off + def __bytes__(self) -> bytes: ... + def __truediv__(self, other: StrPath) -> Path: ... + def __rtruediv__(self, other: StrPath) -> Path: ... + + # wrapped methods handled by __getattr__ + async def absolute(self) -> Path: ... + async def as_posix(self) -> str: ... + async def as_uri(self) -> str: ... + + if sys.version_info >= (3, 10): + async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result: ... 
+ async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None: ... + else: + async def stat(self) -> os.stat_result: ... + async def chmod(self, mode: int) -> None: ... + + @classmethod + async def cwd(self) -> Path: ... + + async def exists(self) -> bool: ... + async def expanduser(self) -> Path: ... + async def glob(self, pattern: str) -> Iterable[Path]: ... + async def home(self) -> Path: ... + async def is_absolute(self) -> bool: ... + async def is_block_device(self) -> bool: ... + async def is_char_device(self) -> bool: ... + async def is_dir(self) -> bool: ... + async def is_fifo(self) -> bool: ... + async def is_file(self) -> bool: ... + async def is_reserved(self) -> bool: ... + async def is_socket(self) -> bool: ... + async def is_symlink(self) -> bool: ... + async def iterdir(self) -> Iterable[Path]: ... + async def joinpath(self, *other: StrPath) -> Path: ... + async def lchmod(self, mode: int) -> None: ... + async def lstat(self) -> os.stat_result: ... + async def match(self, path_pattern: str) -> bool: ... + async def mkdir(self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False) -> None: ... + async def read_bytes(self) -> bytes: ... + async def read_text(self, encoding: str | None = None, errors: str | None = None) -> str: ... + async def relative_to(self, *other: StrPath) -> Path: ... + + if sys.version_info >= (3, 8): + def rename(self, target: str | pathlib.PurePath) -> Path: ... + def replace(self, target: str | pathlib.PurePath) -> Path: ... + else: + def rename(self, target: str | pathlib.PurePath) -> None: ... + def replace(self, target: str | pathlib.PurePath) -> None: ... + + async def resolve(self, strict: bool = False) -> Path: ... + async def rglob(self, pattern: str) -> Iterable[Path]: ... + async def rmdir(self) -> None: ... + async def samefile(self, other_path: str | bytes | int | Path) -> bool: ... + async def symlink_to(self, target: str | Path, target_is_directory: bool = False) -> None: ... 
+ async def touch(self, mode: int = 0o666, exist_ok: bool = True) -> None: ... + if sys.version_info >= (3, 8): + def unlink(self, missing_ok: bool = False) -> None: ... + else: + def unlink(self) -> None: ... + async def with_name(self, name: str) -> Path: ... + async def with_suffix(self, suffix: str) -> Path: ... + async def write_bytes(self, data: bytes) -> int: ... + + if sys.version_info >= (3, 10): + async def write_text( + self, data: str, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> int: ... + else: + async def write_text( + self, data: str, + encoding: str | None = None, + errors: str | None = None, + ) -> int: ... + + if sys.platform != "win32": + async def owner(self) -> str: ... + async def group(self) -> str: ... + async def is_mount(self) -> bool: ... + + if sys.version_info >= (3, 9): + async def is_relative_to(self, *other: StrPath) -> bool: ... + async def with_stem(self, stem: str) -> Path: ... + async def readlink(self) -> Path: ... + if sys.version_info >= (3, 10): + async def hardlink_to(self, target: str | pathlib.Path) -> None: ... + if sys.version_info < (3, 12): + async def link_to(self, target: StrPath | bytes) -> None: ... + if sys.version_info >= (3, 12): + async def is_junction(self) -> bool: ... + walk: Any # TODO + async def with_segments(self, *pathsegments: StrPath) -> Path: ... + Path.iterdir.__doc__ = """ Like :meth:`pathlib.Path.iterdir`, but async. @@ -219,6 +450,6 @@ async def open(self, *args, **kwargs): # sense than inventing our own special docstring for this. 
del Path.absolute.__doc__ -# python3.5 compat -if hasattr(os, 'PathLike'): - os.PathLike.register(Path) +# TODO: This is likely not supported by all the static tools out there, see discussion in +# https://github.com/python-trio/trio/pull/2631#discussion_r1185612528 +os.PathLike.register(Path) diff --git a/trio/_signals.py b/trio/_signals.py index 2ebb4a0a5a..fe2bde946e 100644 --- a/trio/_signals.py +++ b/trio/_signals.py @@ -1,13 +1,10 @@ import signal -from contextlib import contextmanager from collections import OrderedDict +from contextlib import contextmanager import trio -from ._util import ( - signal_raise, aiter_compat, is_main_thread, ConflictDetector -) -__all__ = ["open_signal_receiver"] +from ._util import ConflictDetector, is_main_thread, signal_raise # Discussion of signal handling strategies: # @@ -61,7 +58,7 @@ class SignalReceiver: def __init__(self): # {signal num: None} self._pending = OrderedDict() - self._lot = trio.hazmat.ParkingLot() + self._lot = trio.lowlevel.ParkingLot() self._conflict_detector = ConflictDetector( "only one task can iterate on a signal receiver at a time" ) @@ -96,7 +93,6 @@ def deliver_next(): def _pending_signal_count(self): return len(self._pending) - @aiter_compat def __aiter__(self): return self @@ -110,7 +106,7 @@ async def __anext__(self): if not self._pending: await self._lot.park() else: - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() signum, _ = self._pending.popitem(last=False) return signum @@ -134,6 +130,8 @@ def open_signal_receiver(*signals): signals: the signals to listen for. Raises: + TypeError: if no signals were provided. + RuntimeError: if you try to use this anywhere except Python's main thread. (This is a Python limitation.) 
@@ -149,12 +147,15 @@ def open_signal_receiver(*signals): reload_configuration() """ + if not signals: + raise TypeError("No signals were provided") + if not is_main_thread(): raise RuntimeError( "Sorry, open_signal_receiver is only possible when running in " "Python interpreter's main thread" ) - token = trio.hazmat.current_trio_token() + token = trio.lowlevel.current_trio_token() queue = SignalReceiver() def handler(signum, _): diff --git a/trio/_socket.py b/trio/_socket.py index 18403962d2..b0ec1d480d 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1,15 +1,49 @@ -import os as _os -import sys as _sys +from __future__ import annotations + +import os import select import socket as _stdlib_socket +import sys from functools import wraps as _wraps +from operator import index +from socket import AddressFamily, SocketKind +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + NoReturn, + SupportsIndex, + Tuple, + TypeVar, + Union, + overload, +) import idna as _idna import trio -from ._util import fspath + from . 
import _core +if TYPE_CHECKING: + from collections.abc import Iterable + from types import TracebackType + + from typing_extensions import Buffer, Concatenate, ParamSpec, Self, TypeAlias + + from ._abc import HostnameResolver, SocketFactory + + P = ParamSpec("P") + + +T = TypeVar("T") + +# must use old-style typing because it's evaluated at runtime +Address: TypeAlias = Union[ + str, bytes, Tuple[str, int], Tuple[str, int, int], Tuple[str, int, int, int] +] + # Usage: # @@ -20,50 +54,47 @@ # return await do_it_properly_with_a_check_point() # class _try_sync: - def __init__(self, blocking_exc_override=None): + def __init__( + self, blocking_exc_override: Callable[[BaseException], bool] | None = None + ): self._blocking_exc_override = blocking_exc_override - def _is_blocking_io_error(self, exc): + def _is_blocking_io_error(self, exc: BaseException) -> bool: if self._blocking_exc_override is None: return isinstance(exc, BlockingIOError) else: return self._blocking_exc_override(exc) - async def __aenter__(self): - await trio.hazmat.checkpoint_if_cancelled() + async def __aenter__(self) -> None: + await trio.lowlevel.checkpoint_if_cancelled() - async def __aexit__(self, etype, value, tb): - if value is not None and self._is_blocking_io_error(value): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool: + if exc_value is not None and self._is_blocking_io_error(exc_value): # Discard the exception and fall through to the code below the # block return True else: - await trio.hazmat.cancel_shielded_checkpoint() + await trio.lowlevel.cancel_shielded_checkpoint() # Let the return or exception propagate return False -################################################################ -# CONSTANTS -################################################################ - -try: - from socket import IPPROTO_IPV6 -except ImportError: - # As of at least 3.6, python on Windows is missing 
IPPROTO_IPV6 - # https://bugs.python.org/issue29515 - if _sys.platform == "win32": - IPPROTO_IPV6 = 41 - ################################################################ # Overrides ################################################################ -_resolver = _core.RunVar("hostname_resolver") -_socket_factory = _core.RunVar("socket_factory") +_resolver: _core.RunVar[HostnameResolver | None] = _core.RunVar("hostname_resolver") +_socket_factory: _core.RunVar[SocketFactory | None] = _core.RunVar("socket_factory") -def set_custom_hostname_resolver(hostname_resolver): +def set_custom_hostname_resolver( + hostname_resolver: HostnameResolver | None, +) -> HostnameResolver | None: """Set a custom hostname resolver. By default, Trio's :func:`getaddrinfo` and :func:`getnameinfo` functions @@ -95,7 +126,9 @@ def set_custom_hostname_resolver(hostname_resolver): return old -def set_custom_socket_factory(socket_factory): +def set_custom_socket_factory( + socket_factory: SocketFactory | None, +) -> SocketFactory | None: """Set a custom socket object factory. This function allows you to replace Trio's normal socket class with a @@ -129,7 +162,23 @@ def set_custom_socket_factory(socket_factory): _NUMERIC_ONLY = _stdlib_socket.AI_NUMERICHOST | _stdlib_socket.AI_NUMERICSERV -async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): +# It would be possible to @overload the return value depending on Literal[AddressFamily.INET/6], but should probably be added in typeshed first +async def getaddrinfo( + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[ + tuple[ + AddressFamily, + SocketKind, + int, + str, + tuple[str, int] | tuple[str, int, int, int], + ] +]: """Look up a numeric address given a name. 
Arguments and return values are identical to :func:`socket.getaddrinfo`, @@ -150,9 +199,11 @@ async def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): # skip the whole thread thing, which seems worthwhile. So we try first # with the _NUMERIC_ONLY flags set, and then only spawn a thread if that # fails with EAI_NONAME: - def numeric_only_failure(exc): - return isinstance(exc, _stdlib_socket.gaierror) and \ - exc.errno == _stdlib_socket.EAI_NONAME + def numeric_only_failure(exc: BaseException) -> bool: + return ( + isinstance(exc, _stdlib_socket.gaierror) + and exc.errno == _stdlib_socket.EAI_NONAME + ) async with _try_sync(numeric_only_failure): return _stdlib_socket.getaddrinfo( @@ -186,11 +237,13 @@ def numeric_only_failure(exc): type, proto, flags, - cancellable=True + cancellable=True, ) -async def getnameinfo(sockaddr, flags): +async def getnameinfo( + sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int +) -> tuple[str, str]: """Look up a name given a numeric address. Arguments and return values are identical to :func:`socket.getnameinfo`, @@ -209,7 +262,7 @@ async def getnameinfo(sockaddr, flags): ) -async def getprotobyname(name): +async def getprotobyname(name: str) -> int: """Look up a protocol number by name. (Rarely used.) Like :func:`socket.getprotobyname`, but async. @@ -228,8 +281,8 @@ async def getprotobyname(name): ################################################################ -def from_stdlib_socket(sock): - """Convert a standard library :func:`socket.socket` object into a Trio +def from_stdlib_socket(sock: _stdlib_socket.socket) -> _SocketType: + """Convert a standard library :class:`socket.socket` object into a Trio socket object. """ @@ -237,39 +290,58 @@ def from_stdlib_socket(sock): @_wraps(_stdlib_socket.fromfd, assigned=(), updated=()) -def fromfd(fd, family, type, proto=0): - """Like :func:`socket.fromfd`, but returns a Trio socket object. 
- - """ - family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fd) +def fromfd( + fd: SupportsIndex, + family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, +) -> _SocketType: + """Like :func:`socket.fromfd`, but returns a Trio socket object.""" + family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) -if hasattr(_stdlib_socket, "fromshare"): +if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket, "fromshare") +): @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) - def fromshare(*args, **kwargs): - return from_stdlib_socket(_stdlib_socket.fromshare(*args, **kwargs)) + def fromshare(info: bytes) -> _SocketType: + return from_stdlib_socket(_stdlib_socket.fromshare(info)) + + +if sys.platform == "win32": + FamilyT: TypeAlias = int + TypeT: TypeAlias = int + FamilyDefault = _stdlib_socket.AF_INET +else: + FamilyDefault = None + FamilyT: TypeAlias = Union[int, AddressFamily, None] + TypeT: TypeAlias = Union[_stdlib_socket.socket, int] @_wraps(_stdlib_socket.socketpair, assigned=(), updated=()) -def socketpair(*args, **kwargs): +def socketpair( + family: FamilyT = FamilyDefault, + type: TypeT = SocketKind.SOCK_STREAM, + proto: int = 0, +) -> tuple[_SocketType, _SocketType]: """Like :func:`socket.socketpair`, but returns a pair of Trio socket objects. """ - left, right = _stdlib_socket.socketpair(*args, **kwargs) + left, right = _stdlib_socket.socketpair(family, type, proto) return (from_stdlib_socket(left), from_stdlib_socket(right)) @_wraps(_stdlib_socket.socket, assigned=(), updated=()) def socket( - family=_stdlib_socket.AF_INET, - type=_stdlib_socket.SOCK_STREAM, - proto=0, - fileno=None -): - """Create a new Trio socket, like :func:`socket.socket`. 
+ family: AddressFamily | int = _stdlib_socket.AF_INET, + type: SocketKind | int = _stdlib_socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, +) -> _SocketType: + """Create a new Trio socket, like :class:`socket.socket`. This function's behavior can be customized using :func:`set_custom_socket_factory`. @@ -280,29 +352,30 @@ def socket( if sf is not None: return sf.socket(family, type, proto) else: - family, type, proto = _sniff_sockopts_for_fileno( - family, type, proto, fileno - ) + family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fileno) stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno) return from_stdlib_socket(stdlib_socket) -def _sniff_sockopts_for_fileno(family, type, proto, fileno): - """Correct SOCKOPTS for given fileno, falling back to provided values. - - """ +def _sniff_sockopts_for_fileno( + family: AddressFamily | int, + type: SocketKind | int, + proto: int, + fileno: int | None, +) -> tuple[AddressFamily | int, SocketKind | int, int]: + """Correct SOCKOPTS for given fileno, falling back to provided values.""" # Wrap the raw fileno into a Python socket object # This object might have the wrong metadata, but it lets us easily call getsockopt # and then we'll throw it away and construct a new one with the correct metadata. - if not _sys.platform == "linux": + if sys.platform != "linux": return family, type, proto - try: - from socket import SO_DOMAIN, SO_PROTOCOL - except ImportError: - # Only available on 3.6 and above: - SO_PROTOCOL = 38 - SO_DOMAIN = 39 - from socket import SOL_SOCKET, SO_TYPE + from socket import ( # type: ignore[attr-defined] + SO_DOMAIN, + SO_PROTOCOL, + SO_TYPE, + SOL_SOCKET, + ) + sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno) try: family = sockobj.getsockopt(SOL_SOCKET, SO_DOMAIN) @@ -331,37 +404,127 @@ def _sniff_sockopts_for_fileno(family, type, proto, fileno): ) -# Note that this is *not* in __all__. 
-# -# This function will modify the given socket to match the behavior in python -# 3.7. This will become unecessary and can be removed when support for versions -# older than 3.7 is dropped. -def real_socket_type(type_num): - return type_num & _SOCK_TYPE_MASK - - -def _make_simple_sock_method_wrapper(methname, wait_fn, maybe_avail=False): - fn = getattr(_stdlib_socket.socket, methname) - +def _make_simple_sock_method_wrapper( + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], + maybe_avail: bool = False, +) -> Callable[Concatenate[_SocketType, P], Awaitable[T]]: @_wraps(fn, assigned=("__name__",), updated=()) - async def wrapper(self, *args, **kwargs): - return await self._nonblocking_helper(fn, args, kwargs, wait_fn) + async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T: + return await self._nonblocking_helper(wait_fn, fn, *args, **kwargs) - wrapper.__doc__ = ( - """Like :meth:`socket.socket.{}`, but async. + wrapper.__doc__ = f"""Like :meth:`socket.socket.{fn.__name__}`, but async. - """.format(methname) - ) + """ if maybe_avail: wrapper.__doc__ += ( - "Only available on platforms where :meth:`socket.socket.{}` " - "is available.".format(methname) + f"Only available on platforms where :meth:`socket.socket.{fn.__name__}` is " + "available." ) return wrapper +# Helpers to work with the (hostname, port) language that Python uses for socket +# addresses everywhere. Split out into a standalone function so it can be reused by +# FakeNet. + + +# Take an address in Python's representation, and returns a new address in +# the same representation, but with names resolved to numbers, +# etc. +# +# local=True means that the address is being used with bind() or similar +# local=False means that the address is being used with connect() or sendto() or +# similar. 
+#
+
+
+# Using a TypeVar to indicate we return the same type of address appears to give errors
+# when passed a union of address types.
+# @overload likely works, but is extremely verbose.
+# NOTE: this function does not always checkpoint
+async def _resolve_address_nocp(
+    type: int,
+    family: AddressFamily,
+    proto: int,
+    *,
+    ipv6_v6only: bool | int,
+    address: Address,
+    local: bool,
+) -> Address:
+    # Do some pre-checking (or exit early for non-IP sockets)
+    if family == _stdlib_socket.AF_INET:
+        if not isinstance(address, tuple) or not len(address) == 2:
+            raise ValueError("address should be a (host, port) tuple")
+    elif family == _stdlib_socket.AF_INET6:
+        if not isinstance(address, tuple) or not 2 <= len(address) <= 4:
+            raise ValueError(
+                "address should be a (host, port, [flowinfo, [scopeid]]) tuple"
+            )
+    elif family == getattr(_stdlib_socket, "AF_UNIX"):
+        # unwrap path-likes
+        assert isinstance(address, (str, bytes))
+        return os.fspath(address)
+    else:
+        return address
+
+    # -- From here on we know we have IPv4 or IPV6 --
+    host: str | None
+    host, port, *_ = address
+    # Fast path for the simple case: already-resolved IP address,
+    # already-resolved port. This is particularly important for UDP, since
+    # every sendto call goes through here.
+    if isinstance(port, int):
+        try:
+            _stdlib_socket.inet_pton(family, address[0])
+        except (OSError, TypeError):
+            pass
+        else:
+            return address
+    # Special cases to match the stdlib, see gh-277
+    if host == "":
+        host = None
+    if host == "<broadcast>":
+        host = "255.255.255.255"
+    flags = 0
+    if local:
+        flags |= _stdlib_socket.AI_PASSIVE
+    # Since we always pass in an explicit family here, AI_ADDRCONFIG
+    # doesn't add any value -- if we have no ipv6 connectivity and are
+    # working with an ipv6 socket, then things will break soon enough! And
+    # if we do enable it, then it makes it impossible to even run tests
+    # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has
+    # no ipv6.
+ # flags |= AI_ADDRCONFIG + if family == _stdlib_socket.AF_INET6 and not ipv6_v6only: + flags |= _stdlib_socket.AI_V4MAPPED + gai_res = await getaddrinfo(host, port, family, type, proto, flags) + # AFAICT from the spec it's not possible for getaddrinfo to return an + # empty list. + assert len(gai_res) >= 1 + # Address is the last item in the first entry + (*_, normed), *_ = gai_res + # The above ignored any flowid and scopeid in the passed-in address, + # so restore them if present: + if family == _stdlib_socket.AF_INET6: + list_normed = list(normed) + assert len(normed) == 4 + # typechecking certainly doesn't like this logic, but given just how broad + # Address is, it's quite cumbersome to write the below without type: ignore + if len(address) >= 3: + list_normed[2] = address[2] # type: ignore + if len(address) >= 4: + list_normed[3] = address[3] # type: ignore + return tuple(list_normed) # type: ignore + return normed + + +# TODO: stopping users from initializing this type should be done in a different way, +# so SocketType can be used as a type. Note that this is *far* from trivial without +# breaking subclasses of SocketType. Can maybe add abstract methods to SocketType, +# or rename _SocketType. class SocketType: - def __init__(self): + def __init__(self) -> NoReturn: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" @@ -369,14 +532,12 @@ def __init__(self): class _SocketType(SocketType): - def __init__(self, sock): + def __init__(self, sock: _stdlib_socket.socket): if type(sock) is not _stdlib_socket.socket: # For example, ssl.SSLSocket subclasses socket.socket, but we # certainly don't want to blindly wrap one of those. 
raise TypeError( - "expected object of type 'socket.socket', not '{}".format( - type(sock).__name__ - ) + f"expected object of type 'socket.socket', not '{type(sock).__name__}'" ) self._sock = sock self._sock.setblocking(False) @@ -386,81 +547,126 @@ def __init__(self, sock): # Simple + portable methods and attributes ################################################################ - # NB this doesn't work because for loops don't create a scope - # for _name in [ - # ]: - # _meth = getattr(_stdlib_socket.socket, _name) - # @_wraps(_meth, assigned=("__name__", "__doc__"), updated=()) - # def _wrapped(self, *args, **kwargs): - # return getattr(self._sock, _meth)(*args, **kwargs) - # locals()[_meth] = _wrapped - # del _name, _meth, _wrapped - - _forward = { - "detach", - "get_inheritable", - "set_inheritable", - "fileno", - "getpeername", - "getsockname", - "getsockopt", - "setsockopt", - "listen", - "share", - } - - def __getattr__(self, name): - if name in self._forward: - return getattr(self._sock, name) - raise AttributeError(name) - - def __dir__(self): - return super().__dir__() + list(self._forward) - - def __enter__(self): + # forwarded methods + def detach(self) -> int: + return self._sock.detach() + + def fileno(self) -> int: + return self._sock.fileno() + + def getpeername(self) -> Any: + return self._sock.getpeername() + + def getsockname(self) -> Any: + return self._sock.getsockname() + + @overload + def getsockopt(self, /, level: int, optname: int) -> int: + ... + + @overload + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: + ... + + def getsockopt( + self, /, level: int, optname: int, buflen: int | None = None + ) -> int | bytes: + if buflen is None: + return self._sock.getsockopt(level, optname) + return self._sock.getsockopt(level, optname, buflen) + + @overload + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: + ... 
+
+    @overload
+    def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None:
+        ...
+
+    def setsockopt(
+        self,
+        /,
+        level: int,
+        optname: int,
+        value: int | Buffer | None,
+        optlen: int | None = None,
+    ) -> None:
+        if optlen is None:
+            if value is None:
+                raise TypeError(
+                    "invalid value for argument 'value', must not be None when specifying optlen"
+                )
+            return self._sock.setsockopt(level, optname, value)
+        if value is not None:
+            raise TypeError(
+                f"invalid value for argument 'value': {value!r}, must be None when specifying optlen"
+            )
+
+        # Note: PyPy may crash here due to setsockopt only supporting
+        # four parameters.
+        return self._sock.setsockopt(level, optname, value, optlen)
+
+    def listen(self, /, backlog: int = min(_stdlib_socket.SOMAXCONN, 128)) -> None:
+        return self._sock.listen(backlog)
+
+    def get_inheritable(self) -> bool:
+        return self._sock.get_inheritable()
+
+    def set_inheritable(self, inheritable: bool) -> None:
+        return self._sock.set_inheritable(inheritable)
+
+    if sys.platform == "win32" or (
+        not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "share")
+    ):
+
+        def share(self, /, process_id: int) -> bytes:
+            return self._sock.share(process_id)
+
+    def __enter__(self) -> Self:
         return self
 
-    def __exit__(self, *exc_info):
-        return self._sock.__exit__(*exc_info)
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        return self._sock.__exit__(exc_type, exc_value, traceback)
 
     @property
-    def family(self):
+    def family(self) -> AddressFamily:
         return self._sock.family
 
     @property
-    def type(self):
-        # Modify the socket type do match what is done on python 3.7.
When - # support for versions older than 3.7 is dropped, this can be updated - # to just return self._sock.type - return real_socket_type(self._sock.type) + def type(self) -> SocketKind: + return self._sock.type @property - def proto(self): + def proto(self) -> int: return self._sock.proto @property - def did_shutdown_SHUT_WR(self): + def did_shutdown_SHUT_WR(self) -> bool: return self._did_shutdown_SHUT_WR - def __repr__(self): + def __repr__(self) -> str: return repr(self._sock).replace("socket.socket", "trio.socket.socket") - def dup(self): - """Same as :meth:`socket.socket.dup`. - - """ + def dup(self) -> _SocketType: + """Same as :meth:`socket.socket.dup`.""" return _SocketType(self._sock.dup()) - def close(self): + def close(self) -> None: if self._sock.fileno() != -1: - trio.hazmat.notify_closing(self._sock) + trio.lowlevel.notify_closing(self._sock) self._sock.close() - async def bind(self, address): - address = await self._resolve_local_address(address) + async def bind(self, address: Address) -> None: + address = await self._resolve_address_nocp(address, local=True) if ( hasattr(_stdlib_socket, "AF_UNIX") - and self.family == _stdlib_socket.AF_UNIX and address[0] + and self.family == _stdlib_socket.AF_UNIX + and address[0] ): # Use a thread for the filesystem traversal (unless it's an # abstract domain socket) @@ -470,100 +676,62 @@ async def bind(self, address): # complete asynchronously, like connect. But in practice AFAICT # there aren't yet any real systems that do this, so we'll worry # about it when it happens. 
+ await trio.lowlevel.checkpoint() return self._sock.bind(address) - def shutdown(self, flag): + def shutdown(self, flag: int) -> None: # no need to worry about return value b/c always returns None: self._sock.shutdown(flag) # only do this if the call succeeded: if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]: self._did_shutdown_SHUT_WR = True - def is_readable(self): + def is_readable(self) -> bool: # use select.select on Windows, and select.poll everywhere else - if _sys.platform == "win32": + if sys.platform == "win32": rready, _, _ = select.select([self._sock], [], [], 0) return bool(rready) p = select.poll() p.register(self._sock, select.POLLIN) return bool(p.poll(0)) - async def wait_writable(self): + async def wait_writable(self) -> None: await _core.wait_writable(self._sock) - ################################################################ - # Address handling - ################################################################ - - # Take an address in Python's representation, and returns a new address in - # the same representation, but with names resolved to numbers, - # etc. 
-    async def _resolve_address(self, address, flags):
-        # Do some pre-checking (or exit early for non-IP sockets)
-        if self._sock.family == _stdlib_socket.AF_INET:
-            if not isinstance(address, tuple) or not len(address) == 2:
-                raise ValueError("address should be a (host, port) tuple")
-        elif self._sock.family == _stdlib_socket.AF_INET6:
-            if not isinstance(address, tuple) or not 2 <= len(address) <= 4:
-                raise ValueError(
-                    "address should be a (host, port, [flowinfo, [scopeid]]) "
-                    "tuple"
-                )
-        elif self._sock.family == _stdlib_socket.AF_UNIX:
-            await trio.hazmat.checkpoint()
-            # unwrap path-likes
-            return fspath(address)
-
+    async def _resolve_address_nocp(
+        self,
+        address: Address,
+        *,
+        local: bool,
+    ) -> Address:
+        if self.family == _stdlib_socket.AF_INET6:
+            ipv6_v6only = self._sock.getsockopt(
+                _stdlib_socket.IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY
+            )
         else:
-            await trio.hazmat.checkpoint()
-            return address
-        host, port, *_ = address
-        # Special cases to match the stdlib, see gh-277
-        if host == "":
-            host = None
-        if host == "<broadcast>":
-            host = "255.255.255.255"
-        # Since we always pass in an explicit family here, AI_ADDRCONFIG
-        # doesn't add any value -- if we have no ipv6 connectivity and are
-        # working with an ipv6 socket, then things will break soon enough! And
-        # if we do enable it, then it makes it impossible to even run tests
-        # for ipv6 address resolution on travis-ci, which as of 2017-03-07 has
-        # no ipv6.
-        # flags |= AI_ADDRCONFIG
-        if self._sock.family == _stdlib_socket.AF_INET6:
-            if not self._sock.getsockopt(
-                IPPROTO_IPV6, _stdlib_socket.IPV6_V6ONLY
-            ):
-                flags |= _stdlib_socket.AI_V4MAPPED
-        gai_res = await getaddrinfo(
-            host, port, self._sock.family, self.type, self._sock.proto, flags
+            ipv6_v6only = False
+        return await _resolve_address_nocp(
+            self.type,
+            self.family,
+            self.proto,
+            ipv6_v6only=ipv6_v6only,
+            address=address,
+            local=local,
         )
-        # AFAICT from the spec it's not possible for getaddrinfo to return an
-        # empty list.
- assert len(gai_res) >= 1 - # Address is the last item in the first entry - (*_, normed), *_ = gai_res - # The above ignored any flowid and scopeid in the passed-in address, - # so restore them if present: - if self._sock.family == _stdlib_socket.AF_INET6: - normed = list(normed) - assert len(normed) == 4 - if len(address) >= 3: - normed[2] = address[2] - if len(address) >= 4: - normed[3] = address[3] - normed = tuple(normed) - return normed - - # Returns something appropriate to pass to bind() - async def _resolve_local_address(self, address): - return await self._resolve_address(address, _stdlib_socket.AI_PASSIVE) - - # Returns something appropriate to pass to connect()/sendto()/sendmsg() - async def _resolve_remote_address(self, address): - return await self._resolve_address(address, 0) - - async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): + + # args and kwargs must be starred, otherwise pyright complains: + # '"args" member of ParamSpec is valid only when used with *args parameter' + # '"kwargs" member of ParamSpec is valid only when used with **kwargs parameter' + # wait_fn and fn must also be first in the signature + # 'Keyword parameter cannot appear in signature after ParamSpec args parameter' + + async def _nonblocking_helper( + self, + wait_fn: Callable[[_stdlib_socket.socket], Awaitable[None]], + fn: Callable[Concatenate[_stdlib_socket.socket, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: # We have to reconcile two conflicting goals: # - We want to make it look like we always blocked in doing these # operations. The obvious way is to always do an IO wait before @@ -599,12 +767,12 @@ async def _nonblocking_helper(self, fn, args, kwargs, wait_fn): # accept ################################################################ - _accept = _make_simple_sock_method_wrapper("accept", _core.wait_readable) - - async def accept(self): - """Like :meth:`socket.socket.accept`, but async. 
+ _accept = _make_simple_sock_method_wrapper( + _stdlib_socket.socket.accept, _core.wait_readable + ) - """ + async def accept(self) -> tuple[_SocketType, object]: + """Like :meth:`socket.socket.accept`, but async.""" sock, addr = await self._accept() return from_stdlib_socket(sock), addr @@ -612,13 +780,13 @@ async def accept(self): # connect ################################################################ - async def connect(self, address): + async def connect(self, address: Address) -> None: # nonblocking connect is weird -- you call it to start things # off, then the socket becomes writable as a completion # notification. This means it isn't really cancellable... we close the # socket if cancelled, to avoid confusion. try: - address = await self._resolve_remote_address(address) + address = await self._resolve_address_nocp(address, local=False) async with _try_sync(): # An interesting puzzle: can a non-blocking connect() return EINTR # (= raise InterruptedError)? PEP 475 specifically left this as @@ -678,40 +846,73 @@ async def connect(self, address): self._sock.close() raise # Okay, the connect finished, but it might have failed: - err = self._sock.getsockopt( - _stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR - ) + err = self._sock.getsockopt(_stdlib_socket.SOL_SOCKET, _stdlib_socket.SO_ERROR) if err != 0: - raise OSError(err, "Error in connect: " + _os.strerror(err)) + raise OSError(err, f"Error connecting to {address!r}: {os.strerror(err)}") ################################################################ # recv ################################################################ - recv = _make_simple_sock_method_wrapper("recv", _core.wait_readable) + # Not possible to typecheck with a Callable (due to DefaultArg), nor with a + # callback Protocol (https://github.com/python/typing/discussions/1040) + # but this seems to work. 
If not explicitly defined then pyright --verifytypes will + # complain about AmbiguousType + if TYPE_CHECKING: + + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: + ... + + # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct + # this requires that we refrain from using `/` to specify pos-only + # args, or mypy thinks the signature differs from typeshed. + recv = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recv, _core.wait_readable + ) ################################################################ # recv_into ################################################################ - recv_into = _make_simple_sock_method_wrapper( - "recv_into", _core.wait_readable + if TYPE_CHECKING: + + def recv_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[int]: + ... + + recv_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recv_into, _core.wait_readable ) ################################################################ # recvfrom ################################################################ - recvfrom = _make_simple_sock_method_wrapper( - "recvfrom", _core.wait_readable + if TYPE_CHECKING: + # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] + def recvfrom( + __self, __bufsize: int, __flags: int = 0 + ) -> Awaitable[tuple[bytes, Address]]: + ... + + recvfrom = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvfrom, _core.wait_readable ) ################################################################ # recvfrom_into ################################################################ - recvfrom_into = _make_simple_sock_method_wrapper( - "recvfrom_into", _core.wait_readable + if TYPE_CHECKING: + # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] + def recvfrom_into( + __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 + ) -> Awaitable[tuple[int, Address]]: + ... 
+ + recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvfrom_into, _core.wait_readable ) ################################################################ @@ -719,8 +920,15 @@ async def connect(self, address): ################################################################ if hasattr(_stdlib_socket.socket, "recvmsg"): - recvmsg = _make_simple_sock_method_wrapper( - "recvmsg", _core.wait_readable, maybe_avail=True + if TYPE_CHECKING: + + def recvmsg( + __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0 + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + ... + + recvmsg = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvmsg, _core.wait_readable, maybe_avail=True ) ################################################################ @@ -728,54 +936,91 @@ async def connect(self, address): ################################################################ if hasattr(_stdlib_socket.socket, "recvmsg_into"): - recvmsg_into = _make_simple_sock_method_wrapper( - "recvmsg_into", _core.wait_readable, maybe_avail=True + if TYPE_CHECKING: + + def recvmsg_into( + __self, + __buffers: Iterable[Buffer], + __ancbufsize: int = 0, + __flags: int = 0, + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + ... + + recvmsg_into = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.recvmsg_into, _core.wait_readable, maybe_avail=True ) ################################################################ # send ################################################################ - send = _make_simple_sock_method_wrapper("send", _core.wait_writable) + if TYPE_CHECKING: + + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: + ... 
+ + send = _make_simple_sock_method_wrapper( # noqa: F811 + _stdlib_socket.socket.send, _core.wait_writable + ) ################################################################ # sendto ################################################################ - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) - async def sendto(self, *args): - """Similar to :meth:`socket.socket.sendto`, but async. - - """ + @overload + async def sendto( + self, __data: Buffer, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @overload + async def sendto( + self, __data: Buffer, __flags: int, __address: tuple[Any, ...] | str | Buffer + ) -> int: + ... + + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] + async def sendto(self, *args: Any) -> int: + """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address) # and kwargs are not accepted - args = list(args) - args[-1] = await self._resolve_remote_address(args[-1]) + args_list = list(args) + args_list[-1] = await self._resolve_address_nocp(args[-1], local=False) return await self._nonblocking_helper( - _stdlib_socket.socket.sendto, args, {}, _core.wait_writable + _core.wait_writable, _stdlib_socket.socket.sendto, *args_list ) ################################################################ # sendmsg ################################################################ - if hasattr(_stdlib_socket.socket, "sendmsg"): + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(_stdlib_socket.socket, "sendmsg") + ): @_wraps(_stdlib_socket.socket.sendmsg, assigned=(), updated=()) - async def sendmsg(self, *args): + async def sendmsg( + self, + __buffers: Iterable[Buffer], + __ancdata: Iterable[tuple[int, int, Buffer]] = (), + __flags: int = 0, + __address: Address | None = None, + ) -> int: """Similar to :meth:`socket.socket.sendmsg`, but async. Only available on platforms where :meth:`socket.socket.sendmsg` is available. 
""" - # args is: buffers[, ancdata[, flags[, address]]] - # and kwargs are not accepted - if len(args) == 4 and args[-1] is not None: - args = list(args) - args[-1] = await self._resolve_remote_address(args[-1]) + if __address is not None: + __address = await self._resolve_address_nocp(__address, local=False) return await self._nonblocking_helper( - _stdlib_socket.socket.sendmsg, args, {}, _core.wait_writable + _core.wait_writable, + _stdlib_socket.socket.sendmsg, + __buffers, + __ancdata, + __flags, + __address, ) ################################################################ diff --git a/trio/_ssl.py b/trio/_ssl.py index 21093b54dc..bd8b3b06b6 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -69,7 +69,7 @@ # able to use this to figure out the key. Is this a real practical problem? I # have no idea, I'm not a cryptographer. In any case, some people worry that # it's a problem, so their TLS libraries are designed to automatically trigger -# a renegotation every once in a while on some sort of timer. +# a renegotiation every once in a while on some sort of timer. # # The end result is that you might be going along, minding your own business, # and then *bam*! a wild renegotiation appears! And you just have to cope. @@ -155,11 +155,10 @@ import trio -from .abc import Stream, Listener -from ._highlevel_generic import aclose_forcefully from . import _sync -from ._util import ConflictDetector -from ._deprecate import warn_deprecated +from ._highlevel_generic import aclose_forcefully +from ._util import ConflictDetector, Final +from .abc import Listener, Stream ################################################################ # SSLStream @@ -191,6 +190,16 @@ STARTING_RECEIVE_SIZE = 16384 +def _is_eof(exc): + # There appears to be a bug on Python 3.10, where SSLErrors + # aren't properly translated into SSLEOFErrors. + # This stringly-typed error check is borrowed from the AnyIO + # project. 
+ return isinstance(exc, _stdlib_ssl.SSLEOFError) or ( + hasattr(exc, "strerror") and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror + ) + + class NeedHandshakeError(Exception): """Some :class:`SSLStream` methods can't return any meaningful data until after the handshake. If you call them before the handshake, they raise @@ -224,7 +233,7 @@ def done(self): _State = _Enum("_State", ["OK", "BROKEN", "CLOSED"]) -class SSLStream(Stream): +class SSLStream(Stream, metaclass=Final): r"""Encrypted communication using SSL/TLS. :class:`SSLStream` wraps an arbitrary :class:`~trio.abc.Stream`, and @@ -328,14 +337,9 @@ def __init__( server_hostname=None, server_side=False, https_compatible=False, - max_refill_bytes="unused and deprecated" ): self.transport_stream = transport_stream self._state = _State.OK - if max_refill_bytes != "unused and deprecated": - warn_deprecated( - "max_refill_bytes=...", "0.12.0", issue=959, instead=None - ) self._https_compatible = https_compatible self._outgoing = _stdlib_ssl.MemoryBIO() self._delayed_outgoing = None @@ -344,7 +348,7 @@ def __init__( self._incoming, self._outgoing, server_side=server_side, - server_hostname=server_hostname + server_hostname=server_hostname, ) # Tracks whether we've already done the initial handshake self._handshook = _Once(self._do_handshake) @@ -398,9 +402,7 @@ def __init__( def __getattr__(self, name): if name in self._forwarded: if name in self._after_handshake and not self._handshook.done: - raise NeedHandshakeError( - "call do_handshake() before calling {!r}".format(name) - ) + raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") return getattr(self._ssl_object, name) else: @@ -429,10 +431,8 @@ def _check_status(self): # comments, though, just make sure to think carefully if you ever have to # touch it. The big comment at the top of this file will help explain # too. 
- async def _retry( - self, fn, *args, ignore_want_read=False, is_handshake=False - ): - await trio.hazmat.checkpoint_if_cancelled() + async def _retry(self, fn, *args, ignore_want_read=False, is_handshake=False): + await trio.lowlevel.checkpoint_if_cancelled() yielded = False finished = False while not finished: @@ -495,7 +495,9 @@ async def _retry( # # https://github.com/python-trio/trio/issues/819#issuecomment-517529763 if ( - is_handshake and not want_read and self._ssl_object.server_side + is_handshake + and not want_read + and self._ssl_object.server_side and self._ssl_object.version() == "TLSv1.3" ): assert self._delayed_outgoing is None @@ -544,7 +546,7 @@ async def _retry( # We could do something tricky to keep track of whether a # receive_some happens while we're sending, but the case where # we have to do both is very unusual (only during a - # renegotation), so it's better to keep things simple. So we + # renegotiation), so it's better to keep things simple. So we # do just one potentially-blocking operation, then check again # for fresh information. # @@ -598,7 +600,7 @@ async def _retry( self._incoming.write(data) self._inner_recv_count += 1 if not yielded: - await trio.hazmat.cancel_shielded_checkpoint() + await trio.lowlevel.cancel_shielded_checkpoint() return ret async def _do_handshake(self): @@ -615,7 +617,7 @@ async def do_handshake(self): certificates, select cryptographic keys, and so forth, before any actual data can be sent or received. You don't have to call this method; if you don't, then :class:`SSLStream` will automatically - peform the handshake as needed, the first time you try to send or + perform the handshake as needed, the first time you try to send or receive data. But if you want to trigger it manually – for example, because you want to look at the peer's certificate before you start talking to them – then you can call this method. 
@@ -664,13 +666,11 @@ async def receive_some(self, max_bytes=None): # For some reason, EOF before handshake sometimes raises # SSLSyscallError instead of SSLEOFError (e.g. on my linux # laptop, but not on appveyor). Thanks openssl. - if ( - self._https_compatible and isinstance( - exc.__cause__, - (_stdlib_ssl.SSLEOFError, _stdlib_ssl.SSLSyscallError) - ) + if self._https_compatible and ( + isinstance(exc.__cause__, _stdlib_ssl.SSLSyscallError) + or _is_eof(exc.__cause__) ): - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() return b"" else: raise @@ -678,9 +678,7 @@ async def receive_some(self, max_bytes=None): # If we somehow have more data already in our pending buffer # than the estimate receive size, bump up our size a bit for # this read only. - max_bytes = max( - self._estimated_receive_size, self._incoming.pending - ) + max_bytes = max(self._estimated_receive_size, self._incoming.pending) else: max_bytes = _operator.index(max_bytes) if max_bytes < 1: @@ -693,11 +691,9 @@ async def receive_some(self, max_bytes=None): # BROKEN. But that's actually fine, because after getting an # EOF on TLS then the only thing you can do is close the # stream, and closing doesn't care about the state. - if ( - self._https_compatible - and isinstance(exc.__cause__, _stdlib_ssl.SSLEOFError) - ): - await trio.hazmat.checkpoint() + + if self._https_compatible and _is_eof(exc.__cause__): + await trio.lowlevel.checkpoint() return b"" else: raise @@ -719,7 +715,7 @@ async def send_all(self, data): # SSLObject interprets write(b"") as an EOF for some reason, which # is not what we want. if not data: - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() return await self._retry(self._ssl_object.write, data) @@ -740,8 +736,7 @@ async def unwrap(self): ``transport_stream.receive_some(...)``. 
""" - with self._outer_recv_conflict_detector, \ - self._outer_send_conflict_detector: + with self._outer_recv_conflict_detector, self._outer_send_conflict_detector: self._check_status() await self._handshook.ensure(checkpoint=False) await self._retry(self._ssl_object.unwrap) @@ -763,7 +758,7 @@ async def aclose(self): """ if self._state is _State.CLOSED: - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() return if self._state is _State.BROKEN or self._https_compatible: self._state = _State.CLOSED @@ -824,9 +819,7 @@ async def aclose(self): # going to be able to do a clean shutdown. If that happens, we'll # just do an unclean shutdown. try: - await self._retry( - self._ssl_object.unwrap, ignore_want_read=True - ) + await self._retry(self._ssl_object.unwrap, ignore_want_read=True) except (trio.BrokenResourceError, trio.BusyResourceError): pass except: @@ -840,9 +833,7 @@ async def aclose(self): self._state = _State.CLOSED async def wait_send_all_might_not_block(self): - """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`. - - """ + """See :meth:`trio.abc.SendStream.wait_send_all_might_not_block`.""" # This method's implementation is deceptively simple. # # First, we take the outer send lock, because of Trio's standard @@ -882,7 +873,7 @@ async def wait_send_all_might_not_block(self): await self.transport_stream.wait_send_all_might_not_block() -class SSLListener(Listener[SSLStream]): +class SSLListener(Listener[SSLStream], metaclass=Final): """A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers. :class:`SSLListener` wraps around another Listener, and converts @@ -903,18 +894,14 @@ class SSLListener(Listener[SSLStream]): passed to ``__init__``. 
""" + def __init__( self, transport_listener, ssl_context, *, https_compatible=False, - max_refill_bytes="unused and deprecated" ): - if max_refill_bytes != "unused and deprecated": - warn_deprecated( - "max_refill_bytes=...", "0.12.0", issue=959, instead=None - ) self.transport_listener = transport_listener self._ssl_context = ssl_context self._https_compatible = https_compatible @@ -934,7 +921,5 @@ async def accept(self): ) async def aclose(self): - """Close the transport listener. - - """ + """Close the transport listener.""" await self.transport_listener.aclose() diff --git a/trio/_subprocess.py b/trio/_subprocess.py index 809af3f28a..1f8d0a8253 100644 --- a/trio/_subprocess.py +++ b/trio/_subprocess.py @@ -1,36 +1,87 @@ import os import subprocess -from typing import Optional +import sys +import warnings +from contextlib import ExitStack +from functools import partial +from typing import TYPE_CHECKING, Optional -from ._abc import AsyncResource, SendStream, ReceiveStream +import trio + +from ._abc import AsyncResource, ReceiveStream, SendStream +from ._core import ClosedResourceError +from ._deprecate import deprecated from ._highlevel_generic import StapledStream -from ._sync import Lock from ._subprocess_platform import ( - wait_child_exiting, create_pipe_to_child_stdin, - create_pipe_from_child_output + create_pipe_from_child_output, + create_pipe_to_child_stdin, + wait_child_exiting, ) -import trio +from ._sync import Lock +from ._util import NoPublicConstructor + +# Linux-specific, but has complex lifetime management stuff so we hard-code it +# here instead of hiding it behind the _subprocess_platform abstraction +can_try_pidfd_open: bool +if TYPE_CHECKING: + + def pidfd_open(fd: int, flags: int) -> int: + ... 
+ + from ._subprocess_platform import ClosableReceiveStream, ClosableSendStream + +else: + can_try_pidfd_open = True + try: + from os import pidfd_open + except ImportError: + if sys.platform == "linux": + import ctypes + + _cdll_for_pidfd_open = ctypes.CDLL(None, use_errno=True) + _cdll_for_pidfd_open.syscall.restype = ctypes.c_long + # pid and flags are actually int-sized, but the syscall() function + # always takes longs. (Except on x32 where long is 32-bits and syscall + # takes 64-bit arguments. But in the unlikely case that anyone is + # using x32, this will still work, b/c we only need to pass in 32 bits + # of data, and the C ABI doesn't distinguish between passing 32-bit vs + # 64-bit integers; our 32-bit values will get loaded into 64-bit + # registers where syscall() will find them.) + _cdll_for_pidfd_open.syscall.argtypes = [ + ctypes.c_long, # syscall number + ctypes.c_long, # pid + ctypes.c_long, # flags + ] + __NR_pidfd_open = 434 + + def pidfd_open(fd: int, flags: int) -> int: + result = _cdll_for_pidfd_open.syscall(__NR_pidfd_open, fd, flags) + if result < 0: + err = ctypes.get_errno() + raise OSError(err, os.strerror(err)) + return result + + else: + can_try_pidfd_open = False -class Process(AsyncResource): +class Process(AsyncResource, metaclass=NoPublicConstructor): r"""A child process. Like :class:`subprocess.Popen`, but async. - This class has no public constructor. To create a child process, use - `open_process`:: + This class has no public constructor. The most common way to get a + `Process` object is to combine `Nursery.start` with `run_process`:: - process = await trio.open_process(...) + process_object = await nursery.start(run_process, ...) - `Process` implements the `~trio.abc.AsyncResource` interface. 
In order to - make sure your process doesn't end up getting abandoned by mistake or - after an exception, you can use ``async with``:: + This way, `run_process` supervises the process and makes sure that it is + cleaned up properly, while optionally checking the return value, feeding + it input, and so on. - async with await trio.open_process(...) as process: - ... + If you need more control – for example, because you want to spawn a child + process that outlives your program – then another option is to use + `trio.lowlevel.open_process`:: - "Closing" a :class:`Process` will close any pipes to the child and wait - for it to exit; if cancelled, the child will be forcibly killed and we - will ensure it has finished exiting before allowing the cancellation to - propagate. + process_object = await trio.lowlevel.open_process(...) Attributes: args (str or list): The ``command`` passed at construction time, @@ -66,111 +117,51 @@ class Process(AsyncResource): # arbitrarily many threads if wait() keeps getting cancelled. _wait_for_exit_data = None - # After the deprecation period: - # - delete __init__ and _create - # - add metaclass=NoPublicConstructor - # - rename _init to __init__ - # - move most of the code into open_process() - # - put the subprocess.Popen(...) 
call into a thread - def __init__(self, *args, **kwargs): - trio._deprecate.warn_deprecated( - "directly constructing Process objects", - "0.12.0", - issue=1109, - instead="trio.open_process" - ) - self._init(*args, **kwargs) - - @classmethod - def _create(cls, *args, **kwargs): - self = cls.__new__(cls) - self._init(*args, **kwargs) - return self + def __init__(self, popen, stdin, stdout, stderr): + self._proc = popen + self.stdin: Optional[SendStream] = stdin + self.stdout: Optional[ReceiveStream] = stdout + self.stderr: Optional[ReceiveStream] = stderr - def _init( - self, command, *, stdin=None, stdout=None, stderr=None, **options - ): - for key in ( - 'universal_newlines', 'text', 'encoding', 'errors', 'bufsize' - ): - if options.get(key): - raise TypeError( - "trio.Process only supports communicating over " - "unbuffered byte streams; the '{}' option is not supported" - .format(key) - ) - - self.stdin = None # type: Optional[SendStream] - self.stdout = None # type: Optional[ReceiveStream] - self.stderr = None # type: Optional[ReceiveStream] - self.stdio = None # type: Optional[StapledStream] - - if os.name == "posix": - if isinstance(command, str) and not options.get("shell"): - raise TypeError( - "command must be a sequence (not a string) if shell=False " - "on UNIX systems" - ) - if not isinstance(command, str) and options.get("shell"): - raise TypeError( - "command must be a string (not a sequence) if shell=True " - "on UNIX systems" - ) + self.stdio: Optional[StapledStream] = None + if self.stdin is not None and self.stdout is not None: + self.stdio = StapledStream(self.stdin, self.stdout) self._wait_lock = Lock() - if stdin == subprocess.PIPE: - self.stdin, stdin = create_pipe_to_child_stdin() - if stdout == subprocess.PIPE: - self.stdout, stdout = create_pipe_from_child_output() - if stderr == subprocess.STDOUT: - # If we created a pipe for stdout, pass the same pipe for - # stderr. 
If stdout was some non-pipe thing (DEVNULL or a - # given FD), pass the same thing. If stdout was passed as - # None, keep stderr as STDOUT to allow subprocess to dup - # our stdout. Regardless of which of these is applicable, - # don't create a new Trio stream for stderr -- if stdout - # is piped, stderr will be intermixed on the stdout stream. - if stdout is not None: - stderr = stdout - elif stderr == subprocess.PIPE: - self.stderr, stderr = create_pipe_from_child_output() - - try: - self._proc = subprocess.Popen( - command, stdin=stdin, stdout=stdout, stderr=stderr, **options - ) - finally: - # Close the parent's handle for each child side of a pipe; - # we want the child to have the only copy, so that when - # it exits we can read EOF on our side. - if self.stdin is not None: - os.close(stdin) - if self.stdout is not None: - os.close(stdout) - if self.stderr is not None: - os.close(stderr) - - if self.stdin is not None and self.stdout is not None: - self.stdio = StapledStream(self.stdin, self.stdout) + self._pidfd = None + if can_try_pidfd_open: + try: + fd = pidfd_open(self._proc.pid, 0) + except OSError: + # Well, we tried, but it didn't work (probably because we're + # running on an older kernel, or in an older sandbox, that + # hasn't been updated to support pidfd_open). We'll fall back + # on waitid instead. + pass + else: + # It worked! Wrap the raw fd up in a Python file object to + # make sure it'll get closed. 
+ self._pidfd = open(fd) self.args = self._proc.args self.pid = self._proc.pid def __repr__(self): - if self.returncode is None: - status = "running with PID {}".format(self.pid) + returncode = self.returncode + if returncode is None: + status = f"running with PID {self.pid}" else: - if self.returncode < 0: - status = "exited with signal {}".format(-self.returncode) + if returncode < 0: + status = f"exited with signal {-returncode}" else: - status = "exited with status {}".format(self.returncode) - return "".format(self.args, status) + status = f"exited with status {returncode}" + return f"" @property def returncode(self): - """The exit status of the process (an integer), or ``None`` if it is - not yet known to have exited. + """The exit status of the process (an integer), or ``None`` if it's + still running. By convention, a return code of zero indicates success. On UNIX, negative values indicate termination due to a signal, @@ -178,11 +169,29 @@ def returncode(self): Windows, a process that exits due to a call to :meth:`Process.terminate` will have an exit status of 1. - Accessing this attribute does not check for termination; - use :meth:`poll` or :meth:`wait` for that. + Unlike the standard library `subprocess.Popen.returncode`, you don't + have to call `poll` or `wait` to update this attribute; it's + automatically updated as needed, and will always give you the latest + information. + """ - return self._proc.returncode + result = self._proc.poll() + if result is not None: + self._close_pidfd() + return result + + @deprecated( + "0.20.0", + thing="using trio.Process as an async context manager", + issue=1104, + instead="run_process or nursery.start(run_process, ...)", + ) + async def __aenter__(self): + return self + @deprecated( + "0.20.0", issue=1104, instead="run_process or nursery.start(run_process, ...)" + ) async def aclose(self): """Close any pipes we have to the process (both input and output) and wait for it to exit. 
@@ -200,34 +209,56 @@ async def aclose(self): try: await self.wait() finally: - if self.returncode is None: + if self._proc.returncode is None: self.kill() with trio.CancelScope(shield=True): await self.wait() + def _close_pidfd(self): + if self._pidfd is not None: + trio.lowlevel.notify_closing(self._pidfd.fileno()) + self._pidfd.close() + self._pidfd = None + async def wait(self): """Block until the process exits. Returns: The exit status of the process; see :attr:`returncode`. """ - if self.poll() is None: - async with self._wait_lock: - if self.poll() is None: + async with self._wait_lock: + if self.poll() is None: + if self._pidfd is not None: + try: + await trio.lowlevel.wait_readable(self._pidfd) + except ClosedResourceError: + # something else (probably a call to poll) already closed the + # pidfd + pass + else: await wait_child_exiting(self) - self._proc.wait() - else: - await trio.hazmat.checkpoint() - return self.returncode + # We have to use .wait() here, not .poll(), because on macOS + # (and maybe other systems, who knows), there's a race + # condition inside the kernel that creates a tiny window where + # kqueue reports that the process has exited, but + # waitpid(WNOHANG) can't yet reap it. So this .wait() may + # actually block for a tiny fraction of a second. + self._proc.wait() + self._close_pidfd() + assert self._proc.returncode is not None + return self._proc.returncode def poll(self): - """Check if the process has exited yet. + """Returns the exit status of the process (an integer), or ``None`` if + it's still running. + + Note that on Trio (unlike the standard library `subprocess.Popen`), + ``process.poll()`` and ``process.returncode`` always give the same + result. See `returncode` for more details. This method is only + included to make it easier to port code from `subprocess`. - Returns: - The exit status of the process, or ``None`` if it is still - running; see :attr:`returncode`. 
""" - return self._proc.poll() + return self.returncode def send_signal(self, sig): """Send signal ``sig`` to the process. @@ -268,17 +299,20 @@ async def open_process( ) -> Process: r"""Execute a child program in a new process. - After construction, you can interact with the child process by writing - data to its `~Process.stdin` stream (a `~trio.abc.SendStream`), reading - data from its `~Process.stdout` and/or `~Process.stderr` streams (both - `~trio.abc.ReceiveStream`\s), sending it signals using - `~Process.terminate`, `~Process.kill`, or `~Process.send_signal`, and - waiting for it to exit using `~Process.wait`. See `Process` for details. + After construction, you can interact with the child process by writing data to its + `~trio.Process.stdin` stream (a `~trio.abc.SendStream`), reading data from its + `~trio.Process.stdout` and/or `~trio.Process.stderr` streams (both + `~trio.abc.ReceiveStream`\s), sending it signals using `~trio.Process.terminate`, + `~trio.Process.kill`, or `~trio.Process.send_signal`, and waiting for it to exit + using `~trio.Process.wait`. See `trio.Process` for details. + + Each standard stream is only available if you specify that a pipe should be created + for it. For example, if you pass ``stdin=subprocess.PIPE``, you can write to the + `~trio.Process.stdin` stream, else `~trio.Process.stdin` will be ``None``. - Each standard stream is only available if you specify that a pipe should - be created for it. For example, if you pass ``stdin=subprocess.PIPE``, you - can write to the `~Process.stdin` stream, else `~Process.stdin` will be - ``None``. + Unlike `trio.run_process`, this function doesn't do any kind of automatic + management of the child process. It's up to you to implement whatever semantics you + want. Args: command (list or str): The command to run. Typically this is a @@ -306,19 +340,103 @@ async def open_process( are also accepted. Returns: - A new `Process` object. + A new `trio.Process` object. 
Raises: OSError: if the process spawning fails, for example because the specified command could not be found. """ - # XX FIXME: move the process creation into a thread as soon as we're done - # deprecating Process(...) - await trio.hazmat.checkpoint() - return Process._create( - command, stdin=stdin, stdout=stdout, stderr=stderr, **options - ) + for key in ("universal_newlines", "text", "encoding", "errors", "bufsize"): + if options.get(key): + raise TypeError( + "trio.Process only supports communicating over " + "unbuffered byte streams; the '{}' option is not supported".format(key) + ) + + if os.name == "posix": + if isinstance(command, str) and not options.get("shell"): + raise TypeError( + "command must be a sequence (not a string) if shell=False " + "on UNIX systems" + ) + if not isinstance(command, str) and options.get("shell"): + raise TypeError( + "command must be a string (not a sequence) if shell=True " + "on UNIX systems" + ) + + trio_stdin: Optional[ClosableSendStream] = None + trio_stdout: Optional[ClosableReceiveStream] = None + trio_stderr: Optional[ClosableReceiveStream] = None + # Close the parent's handle for each child side of a pipe; we want the child to + # have the only copy, so that when it exits we can read EOF on our side. The + # trio ends of pipes will be transferred to the Process object, which will be + # responsible for their lifetime. 
If process spawning fails, though, we still + # want to close them before letting the failure bubble out + with ExitStack() as always_cleanup, ExitStack() as cleanup_on_fail: + if stdin == subprocess.PIPE: + trio_stdin, stdin = create_pipe_to_child_stdin() + always_cleanup.callback(os.close, stdin) + cleanup_on_fail.callback(trio_stdin.close) + if stdout == subprocess.PIPE: + trio_stdout, stdout = create_pipe_from_child_output() + always_cleanup.callback(os.close, stdout) + cleanup_on_fail.callback(trio_stdout.close) + if stderr == subprocess.STDOUT: + # If we created a pipe for stdout, pass the same pipe for + # stderr. If stdout was some non-pipe thing (DEVNULL or a + # given FD), pass the same thing. If stdout was passed as + # None, keep stderr as STDOUT to allow subprocess to dup + # our stdout. Regardless of which of these is applicable, + # don't create a new Trio stream for stderr -- if stdout + # is piped, stderr will be intermixed on the stdout stream. + if stdout is not None: + stderr = stdout + elif stderr == subprocess.PIPE: + trio_stderr, stderr = create_pipe_from_child_output() + always_cleanup.callback(os.close, stderr) + cleanup_on_fail.callback(trio_stderr.close) + + popen = await trio.to_thread.run_sync( + partial( + subprocess.Popen, + command, + stdin=stdin, + stdout=stdout, + stderr=stderr, + **options, + ) + ) + # We did not fail, so dismiss the stack for the trio ends + cleanup_on_fail.pop_all() + + return Process._create(popen, trio_stdin, trio_stdout, trio_stderr) + + +async def _windows_deliver_cancel(p): + try: + p.terminate() + except OSError as exc: + warnings.warn(RuntimeWarning(f"TerminateProcess on {p!r} failed with: {exc!r}")) + + +async def _posix_deliver_cancel(p): + try: + p.terminate() + await trio.sleep(5) + warnings.warn( + RuntimeWarning( + f"process {p!r} ignored SIGTERM for 5 seconds. " + "(Maybe you should pass a custom deliver_cancel?) " + "Trying SIGKILL." 
+ ) + ) + p.kill() + except OSError as exc: + warnings.warn( + RuntimeWarning(f"tried to kill process {p!r}, but failed with: {exc!r}") + ) async def run_process( @@ -328,45 +446,73 @@ async def run_process( capture_stdout=False, capture_stderr=False, check=True, - **options + deliver_cancel=None, + task_status=trio.TASK_STATUS_IGNORED, + **options, ): - """Run ``command`` in a subprocess, wait for it to complete, and - return a :class:`subprocess.CompletedProcess` instance describing - the results. - - If cancelled, :func:`run_process` terminates the subprocess and - waits for it to exit before propagating the cancellation, like - :meth:`Process.aclose`. - - **Input:** The subprocess's standard input stream is set up to - receive the bytes provided as ``stdin``. Once the given input has - been fully delivered, or if none is provided, the subprocess will - receive end-of-file when reading from its standard input. - Alternatively, if you want the subprocess to read its - standard input from the same place as the parent Trio process, you - can pass ``stdin=None``. + """Run ``command`` in a subprocess and wait for it to complete. + + This function can be called in two different ways. + + One option is a direct call, like:: + + completed_process_info = await trio.run_process(...) + + In this case, it returns a :class:`subprocess.CompletedProcess` instance + describing the results. Use this if you want to treat a process like a + function call. + + The other option is to run it as a task using `Nursery.start` – the enhanced version + of `~Nursery.start_soon` that lets a task pass back a value during startup:: + + process = await nursery.start(trio.run_process, ...) + + In this case, `~Nursery.start` returns a `Process` object that you can use + to interact with the process while it's running. Use this if you want to + treat a process like a background task. 
+ + Either way, `run_process` makes sure that the process has exited before + returning, handles cancellation, optionally checks for errors, and + provides some convenient shorthands for dealing with the child's + input/output. + + **Input:** `run_process` supports all the same ``stdin=`` arguments as + `subprocess.Popen`. In addition, if you simply want to pass in some fixed + data, you can pass a plain `bytes` object, and `run_process` will take + care of setting up a pipe, feeding in the data you gave, and then sending + end-of-file. The default is ``b""``, which means that the child will receive + an empty stdin. If you want the child to instead read from the parent's + stdin, use ``stdin=None``. **Output:** By default, any output produced by the subprocess is passed through to the standard output and error streams of the - parent Trio process. If you would like to capture this output and - do something with it, you can pass ``capture_stdout=True`` to - capture the subprocess's standard output, and/or - ``capture_stderr=True`` to capture its standard error. Captured - data is provided as the + parent Trio process. + + When calling `run_process` directly, you can capture the subprocess's output by + passing ``capture_stdout=True`` to capture the subprocess's standard output, and/or + ``capture_stderr=True`` to capture its standard error. Captured data is collected up + by Trio into an in-memory buffer, and then provided as the :attr:`~subprocess.CompletedProcess.stdout` and/or - :attr:`~subprocess.CompletedProcess.stderr` attributes of the - returned :class:`~subprocess.CompletedProcess` object. The value - for any stream that was not captured will be ``None``. - + :attr:`~subprocess.CompletedProcess.stderr` attributes of the returned + :class:`~subprocess.CompletedProcess` object. The value for any stream that was not + captured will be ``None``. 
+ If you want to capture both stdout and stderr while keeping them separate, pass ``capture_stdout=True, capture_stderr=True``. - + If you want to capture both stdout and stderr but mixed together in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``. This directs the child's stderr into its stdout, so the combined output will be available in the `~subprocess.CompletedProcess.stdout` attribute. + If you're using ``await nursery.start(trio.run_process, ...)`` and want to capture + the subprocess's output for further processing, then use ``stdout=subprocess.PIPE`` + and then make sure to read the data out of the `Process.stdout` stream. If you want + to capture stderr separately, use ``stderr=subprocess.PIPE``. If you want to capture + both, but mixed together in the correct order, use ``stdout=subprocess.PIPE, + stderr=subprocess.STDOUT``. + **Error checking:** If the subprocess exits with a nonzero status code, indicating failure, :func:`run_process` raises a :exc:`subprocess.CalledProcessError` exception rather than @@ -374,8 +520,28 @@ async def run_process( the :attr:`~subprocess.CalledProcessError.stdout` and :attr:`~subprocess.CalledProcessError.stderr` attributes of that exception. To disable this behavior, so that :func:`run_process` - returns normally even if the subprocess exits abnormally, pass - ``check=False``. + returns normally even if the subprocess exits abnormally, pass ``check=False``. + + Note that this can make the ``capture_stdout`` and ``capture_stderr`` + arguments useful even when starting `run_process` as a task: if you only + care about the output if the process fails, then you can enable capturing + and then read the output off of the `~subprocess.CalledProcessError`. + + **Cancellation:** If cancelled, `run_process` sends a termination + request to the subprocess, then waits for it to fully exit. The + ``deliver_cancel`` argument lets you control how the process is terminated. + + .. 
note:: `run_process` is intentionally similar to the standard library + `subprocess.run`, but some of the defaults are different. Specifically, we + default to: + + - ``check=True``, because `"errors should never pass silently / unless + explicitly silenced" `__. + + - ``stdin=b""``, because it produces less-confusing results if a subprocess + unexpectedly tries to read from stdin. + + To get the `subprocess.run` semantics, use ``check=False, stdin=None``. Args: command (list or str): The command to run. Typically this is a @@ -385,26 +551,64 @@ async def run_process( ``**options``, or on Windows, ``command`` may alternatively be a string, which will be parsed following platform-dependent :ref:`quoting rules `. - stdin (:obj:`bytes`, file descriptor, or None): The bytes to provide to - the subprocess on its standard input stream, or ``None`` if the - subprocess's standard input should come from the same place as - the parent Trio process's standard input. As is the case with - the :mod:`subprocess` module, you can also pass a - file descriptor or an object with a ``fileno()`` method, - in which case the subprocess's standard input will come from - that file. + + stdin (:obj:`bytes`, subprocess.PIPE, file descriptor, or None): The + bytes to provide to the subprocess on its standard input stream, or + ``None`` if the subprocess's standard input should come from the + same place as the parent Trio process's standard input. As is the + case with the :mod:`subprocess` module, you can also pass a file + descriptor or an object with a ``fileno()`` method, in which case + the subprocess's standard input will come from that file. + + When starting `run_process` as a background task, you can also use + ``stdin=subprocess.PIPE``, in which case `Process.stdin` will be a + `~trio.abc.SendStream` that you can use to send data to the child. 
+ capture_stdout (bool): If true, capture the bytes that the subprocess writes to its standard output stream and return them in the - :attr:`~subprocess.CompletedProcess.stdout` attribute - of the returned :class:`~subprocess.CompletedProcess` object. + `~subprocess.CompletedProcess.stdout` attribute of the returned + `subprocess.CompletedProcess` or `subprocess.CalledProcessError`. + capture_stderr (bool): If true, capture the bytes that the subprocess writes to its standard error stream and return them in the - :attr:`~subprocess.CompletedProcess.stderr` attribute - of the returned :class:`~subprocess.CompletedProcess` object. + `~subprocess.CompletedProcess.stderr` attribute of the returned + `~subprocess.CompletedProcess` or `subprocess.CalledProcessError`. + check (bool): If false, don't validate that the subprocess exits successfully. You should be sure to check the ``returncode`` attribute of the returned object if you pass ``check=False``, so that errors don't pass silently. + + deliver_cancel (async function or None): If `run_process` is cancelled, + then it needs to kill the child process. There are multiple ways to + do this, so we let you customize it. + + If you pass None (the default), then the behavior depends on the + platform: + + - On Windows, Trio calls ``TerminateProcess``, which should kill the + process immediately. + + - On Unix-likes, the default behavior is to send a ``SIGTERM``, wait + 5 seconds, and send a ``SIGKILL``. + + Alternatively, you can customize this behavior by passing in an + arbitrary async function, which will be called with the `Process` + object as an argument. 
For example, the default Unix behavior could + be implemented like this:: + + async def my_deliver_cancel(process): + process.send_signal(signal.SIGTERM) + await trio.sleep(5) + process.send_signal(signal.SIGKILL) + + When the process actually exits, the ``deliver_cancel`` function + will automatically be cancelled – so if the process exits after + ``SIGTERM``, then we'll never reach the ``SIGKILL``. + + In any case, `run_process` will always wait for the child process to + exit before raising `Cancelled`. + **options: :func:`run_process` also accepts any :ref:`general subprocess options ` and passes them on to the :class:`~trio.Process` constructor. This includes the @@ -413,8 +617,11 @@ async def run_process( ``stdout=subprocess.DEVNULL``, or file descriptors. Returns: - A :class:`subprocess.CompletedProcess` instance describing the - return code and outputs. + + When called normally – a `subprocess.CompletedProcess` instance + describing the return code and outputs. + + When called via `Nursery.start` – a `trio.Process` instance. 
Raises: UnicodeError: if ``stdin`` is specified as a Unicode string, rather @@ -437,12 +644,23 @@ async def run_process( if isinstance(stdin, str): raise UnicodeError("process stdin must be bytes, not str") - if stdin == subprocess.PIPE: - raise ValueError( - "stdin=subprocess.PIPE doesn't make sense since the pipe " - "is internal to run_process(); pass the actual data you " - "want to send over that pipe instead" - ) + if task_status is trio.TASK_STATUS_IGNORED: + if stdin is subprocess.PIPE: + raise ValueError( + "stdin=subprocess.PIPE is only valid with nursery.start, " + "since that's the only way to access the pipe; use nursery.start " + "or pass the data you want to write directly" + ) + if options.get("stdout") is subprocess.PIPE: + raise ValueError( + "stdout=subprocess.PIPE is only valid with nursery.start, " + "since that's the only way to access the pipe" + ) + if options.get("stderr") is subprocess.PIPE: + raise ValueError( + "stderr=subprocess.PIPE is only valid with nursery.start, " + "since that's the only way to access the pipe" + ) if isinstance(stdin, (bytes, bytearray, memoryview)): input = stdin options["stdin"] = subprocess.PIPE @@ -462,40 +680,63 @@ async def run_process( raise ValueError("can't specify both stderr and capture_stderr") options["stderr"] = subprocess.PIPE + if deliver_cancel is None: + if os.name == "nt": + deliver_cancel = _windows_deliver_cancel + else: + assert os.name == "posix" + deliver_cancel = _posix_deliver_cancel + stdout_chunks = [] stderr_chunks = [] - async with await open_process(command, **options) as proc: + async def feed_input(stream): + async with stream: + try: + await stream.send_all(input) + except trio.BrokenResourceError: + pass - async def feed_input(): - async with proc.stdin: - try: - await proc.stdin.send_all(input) - except trio.BrokenResourceError: - pass + async def read_output(stream, chunks): + async with stream: + async for chunk in stream: + chunks.append(chunk) - async def 
read_output(stream, chunks): - async with stream: - async for chunk in stream: - chunks.append(chunk) - - async with trio.open_nursery() as nursery: - if proc.stdin is not None: - nursery.start_soon(feed_input) - if proc.stdout is not None: + async with trio.open_nursery() as nursery: + proc = await open_process(command, **options) + try: + if input is not None: + nursery.start_soon(feed_input, proc.stdin) + proc.stdin = None + proc.stdio = None + if capture_stdout: nursery.start_soon(read_output, proc.stdout, stdout_chunks) - if proc.stderr is not None: + proc.stdout = None + proc.stdio = None + if capture_stderr: nursery.start_soon(read_output, proc.stderr, stderr_chunks) + proc.stderr = None + task_status.started(proc) await proc.wait() + except BaseException: + with trio.CancelScope(shield=True): + killer_cscope = trio.CancelScope(shield=True) - stdout = b"".join(stdout_chunks) if proc.stdout is not None else None - stderr = b"".join(stderr_chunks) if proc.stderr is not None else None + async def killer(): + with killer_cscope: + await deliver_cancel(proc) + + nursery.start_soon(killer) + await proc.wait() + killer_cscope.cancel() + raise + + stdout = b"".join(stdout_chunks) if capture_stdout else None + stderr = b"".join(stderr_chunks) if capture_stderr else None if proc.returncode and check: raise subprocess.CalledProcessError( proc.returncode, proc.args, output=stdout, stderr=stderr ) else: - return subprocess.CompletedProcess( - proc.args, proc.returncode, stdout, stderr - ) + return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr) diff --git a/trio/_subprocess_platform/__init__.py b/trio/_subprocess_platform/__init__.py index b1db8499c6..b6767af8f5 100644 --- a/trio/_subprocess_platform/__init__.py +++ b/trio/_subprocess_platform/__init__.py @@ -1,10 +1,27 @@ # Platform-specific subprocess bits'n'pieces. import os -from typing import Tuple +import sys +from typing import TYPE_CHECKING, Optional, Tuple + +import trio from .. 
import _core, _subprocess -from .._abc import SendStream, ReceiveStream +from .._abc import ReceiveStream, SendStream + +_wait_child_exiting_error: Optional[ImportError] = None +_create_child_pipe_error: Optional[ImportError] = None + + +if TYPE_CHECKING: + # internal types for the pipe representations used in type checking only + class ClosableSendStream(SendStream): + def close(self) -> None: + ... + + class ClosableReceiveStream(ReceiveStream): + def close(self) -> None: + ... # Fallback versions of the functions provided -- implementations @@ -21,10 +38,10 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: consumed by this call, since :class:`~subprocess.Popen` wants to be able to do that itself. """ - raise NotImplementedError from wait_child_exiting._error # pragma: no cover + raise NotImplementedError from _wait_child_exiting_error # pragma: no cover -def create_pipe_to_child_stdin() -> Tuple[SendStream, int]: +def create_pipe_to_child_stdin() -> Tuple["ClosableSendStream", int]: """Create a new pipe suitable for sending data from this process to the standard input of a child we're about to spawn. @@ -34,12 +51,10 @@ def create_pipe_to_child_stdin() -> Tuple[SendStream, int]: something suitable for passing as the ``stdin`` argument of :class:`subprocess.Popen`. """ - raise NotImplementedError from ( # pragma: no cover - create_pipe_to_child_stdin._error - ) + raise NotImplementedError from _create_child_pipe_error # pragma: no cover -def create_pipe_from_child_output() -> Tuple[ReceiveStream, int]: +def create_pipe_from_child_output() -> Tuple["ClosableReceiveStream", int]: """Create a new pipe suitable for receiving data into this process from the standard output or error stream of a child we're about to spawn. @@ -50,35 +65,37 @@ def create_pipe_from_child_output() -> Tuple[ReceiveStream, int]: something suitable for passing as the ``stdin`` argument of :class:`subprocess.Popen`. 
""" - raise NotImplementedError from ( # pragma: no cover - create_pipe_to_child_stdin._error - ) + raise NotImplementedError from _create_child_pipe_error # pragma: no cover try: - if os.name == "nt": + if sys.platform == "win32": from .windows import wait_child_exiting # noqa: F811 - elif hasattr(_core, "wait_kevent"): + elif sys.platform != "linux" and (TYPE_CHECKING or hasattr(_core, "wait_kevent")): from .kqueue import wait_child_exiting # noqa: F811 else: - from .waitid import wait_child_exiting # noqa: F811 + # noqa'd as it's an exported symbol + from .waitid import wait_child_exiting # noqa: F811, F401 except ImportError as ex: # pragma: no cover - wait_child_exiting._error = ex + _wait_child_exiting_error = ex try: - if os.name == "posix": - from ..hazmat import FdStream + if TYPE_CHECKING: + # Not worth type checking these definitions + pass + + elif os.name == "posix": def create_pipe_to_child_stdin(): # noqa: F811 rfd, wfd = os.pipe() - return FdStream(wfd), rfd + return trio.lowlevel.FdStream(wfd), rfd def create_pipe_from_child_output(): # noqa: F811 rfd, wfd = os.pipe() - return FdStream(rfd), wfd + return trio.lowlevel.FdStream(rfd), wfd elif os.name == "nt": - from .._windows_pipes import PipeSendStream, PipeReceiveStream + import msvcrt # This isn't exported or documented, but it's also not # underscore-prefixed, and seems kosher to use. The asyncio docs @@ -87,7 +104,8 @@ def create_pipe_from_child_output(): # noqa: F811 # when asyncio.windows_utils.socketpair was removed in 3.7, the # removal was mentioned in the release notes. 
from asyncio.windows_utils import pipe as windows_pipe - import msvcrt + + from .._windows_pipes import PipeReceiveStream, PipeSendStream def create_pipe_to_child_stdin(): # noqa: F811 # for stdin, we want the write end (our end) to use overlapped I/O @@ -103,5 +121,4 @@ def create_pipe_from_child_output(): # noqa: F811 raise ImportError("pipes not implemented on this platform") except ImportError as ex: # pragma: no cover - create_pipe_to_child_stdin._error = ex - create_pipe_from_child_output._error = ex + _create_child_pipe_error = ex diff --git a/trio/_subprocess_platform/kqueue.py b/trio/_subprocess_platform/kqueue.py index 837b556fed..9839fd046b 100644 --- a/trio/_subprocess_platform/kqueue.py +++ b/trio/_subprocess_platform/kqueue.py @@ -1,6 +1,11 @@ import select +import sys +from typing import TYPE_CHECKING + from .. import _core, _subprocess +assert (sys.platform != "win32" and sys.platform != "linux") or not TYPE_CHECKING + async def wait_child_exiting(process: "_subprocess.Process") -> None: kqueue = _core.current_kqueue() @@ -13,16 +18,11 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: KQ_NOTE_EXIT = 0x80000000 make_event = lambda flags: select.kevent( - process.pid, - filter=select.KQ_FILTER_PROC, - flags=flags, - fflags=KQ_NOTE_EXIT + process.pid, filter=select.KQ_FILTER_PROC, flags=flags, fflags=KQ_NOTE_EXIT ) try: - kqueue.control( - [make_event(select.KQ_EV_ADD | select.KQ_EV_ONESHOT)], 0 - ) + kqueue.control([make_event(select.KQ_EV_ADD | select.KQ_EV_ONESHOT)], 0) except ProcessLookupError: # pragma: no cover # This can supposedly happen if the process is in the process # of exiting, and it can even be the case that kqueue says the diff --git a/trio/_subprocess_platform/waitid.py b/trio/_subprocess_platform/waitid.py index 030c546f88..ad69017219 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/trio/_subprocess_platform/waitid.py @@ -17,6 +17,7 @@ def sync_wait_reapable(pid): # pypy doesn't define os.waitid so we 
need to pull it out ourselves # using cffi: https://bitbucket.org/pypy/pypy/issues/2922/ import cffi + waitid_ffi = cffi.FFI() # Believe it or not, siginfo_t starts with fields in the @@ -43,7 +44,7 @@ def sync_wait_reapable(pid): def sync_wait_reapable(pid): P_PID = 1 WEXITED = 0x00000004 - if sys.platform == 'darwin': # pragma: no cover + if sys.platform == "darwin": # pragma: no cover # waitid() is not exposed on Python on Darwin but does # work through CFFI; note that we typically won't get # here since Darwin also defines kqueue @@ -75,10 +76,7 @@ async def _waitid_system_task(pid: int, event: Event) -> None: try: await to_thread_run_sync( - sync_wait_reapable, - pid, - cancellable=True, - limiter=waitid_limiter, + sync_wait_reapable, pid, cancellable=True, limiter=waitid_limiter ) except OSError: # If waitid fails, waitpid will fail too, so it still makes @@ -103,6 +101,7 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: # process. if process._wait_for_exit_data is None: - process._wait_for_exit_data = event = Event() + process._wait_for_exit_data = event = Event() # type: ignore _core.spawn_system_task(_waitid_system_task, process.pid, event) + assert isinstance(process._wait_for_exit_data, Event) await process._wait_for_exit_data.wait() diff --git a/trio/_sync.py b/trio/_sync.py index f34acd0858..bd2122858e 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1,26 +1,39 @@ +from __future__ import annotations + import math +from typing import TYPE_CHECKING import attr -import outcome import trio -from ._util import aiter_compat -from ._core import enable_ki_protection, ParkingLot -from ._deprecate import deprecated +from . 
import _core +from ._core import Abort, ParkingLot, RaiseCancelT, enable_ki_protection +from ._util import Final + +if TYPE_CHECKING: + from types import TracebackType + + from ._core import Task + from ._core._parking_lot import ParkingLotStatistics + + +@attr.s(frozen=True, slots=True) +class EventStatistics: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``tasks_waiting``: The number of tasks blocked on this event's + :meth:`trio.Event.wait` method. + + """ -__all__ = [ - "Event", - "CapacityLimiter", - "Semaphore", - "Lock", - "StrictFIFOLock", - "Condition", -] + tasks_waiting: int = attr.ib() -@attr.s(repr=False, eq=False, hash=False) -class Event: +@attr.s(repr=False, eq=False, hash=False, slots=True) +class Event(metaclass=Final): """A waitable boolean value useful for inter-task synchronization, inspired by :class:`threading.Event`. @@ -35,54 +48,52 @@ class Event: lost wakeups: it doesn't matter whether :meth:`set` gets called just before or after :meth:`wait`. If you want a lower-level wakeup primitive that doesn't have this protection, consider :class:`Condition` - or :class:`trio.hazmat.ParkingLot`. + or :class:`trio.lowlevel.ParkingLot`. .. note:: Unlike `threading.Event`, `trio.Event` has no `~threading.Event.clear` method. In Trio, once an `Event` has happened, it cannot un-happen. If you need to represent a series of events, consider creating a new `Event` object for each one (they're cheap!), or other synchronization methods like :ref:`channels ` or - `trio.hazmat.ParkingLot`. + `trio.lowlevel.ParkingLot`. """ - _lot = attr.ib(factory=ParkingLot, init=False) - _flag = attr.ib(default=False, init=False) + _tasks: set[Task] = attr.ib(factory=set, init=False) + _flag: bool = attr.ib(default=False, init=False) - def is_set(self): - """Return the current value of the internal flag. 
- - """ + def is_set(self) -> bool: + """Return the current value of the internal flag.""" return self._flag @enable_ki_protection - def set(self): - """Set the internal flag value to True, and wake any waiting tasks. - - """ - self._flag = True - self._lot.unpark_all() - - @deprecated( - "0.12.0", - issue=637, - instead="multiple Event objects or other synchronization primitives" - ) - def clear(self): - self._flag = False - - async def wait(self): + def set(self) -> None: + """Set the internal flag value to True, and wake any waiting tasks.""" + if not self._flag: + self._flag = True + for task in self._tasks: + _core.reschedule(task) + self._tasks.clear() + + async def wait(self) -> None: """Block until the internal flag value becomes True. If it's already True, then this method returns immediately. """ if self._flag: - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() else: - await self._lot.park() + task = _core.current_task() + self._tasks.add(task) - def statistics(self): + def abort_fn(_: RaiseCancelT) -> Abort: + self._tasks.remove(task) + return _core.Abort.SUCCEEDED + + await _core.wait_task_rescheduled(abort_fn) + + def statistics(self) -> EventStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -91,36 +102,55 @@ def statistics(self): :meth:`wait` method. 
""" - return self._lot.statistics() + return EventStatistics(tasks_waiting=len(self._tasks)) -def async_cm(cls): +# TODO: type this with a Protocol to get rid of type: ignore, see +# https://github.com/python-trio/trio/pull/2682#discussion_r1259097422 +class AsyncContextManagerMixin: @enable_ki_protection - async def __aenter__(self): - await self.acquire() - - __aenter__.__qualname__ = cls.__qualname__ + ".__aenter__" - cls.__aenter__ = __aenter__ + async def __aenter__(self) -> None: + await self.acquire() # type: ignore[attr-defined] @enable_ki_protection - async def __aexit__(self, *args): - self.release() - - __aexit__.__qualname__ = cls.__qualname__ + ".__aexit__" - cls.__aexit__ = __aexit__ - return cls + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.release() # type: ignore[attr-defined] + + +@attr.s(frozen=True, slots=True) +class CapacityLimiterStatistics: + """An object containing debugging information. + + Currently the following fields are defined: + + * ``borrowed_tokens``: The number of tokens currently borrowed from + the sack. + * ``total_tokens``: The total number of tokens in the sack. Usually + this will be larger than ``borrowed_tokens``, but it's possibly for + it to be smaller if :attr:`trio.CapacityLimiter.total_tokens` was recently decreased. + * ``borrowers``: A list of all tasks or other entities that currently + hold a token. + * ``tasks_waiting``: The number of tasks blocked on this + :class:`CapacityLimiter`\'s :meth:`trio.CapacityLimiter.acquire` or + :meth:`trio.CapacityLimiter.acquire_on_behalf_of` methods. 
+ """ -@attr.s(frozen=True) -class _CapacityLimiterStatistics: - borrowed_tokens = attr.ib() - total_tokens = attr.ib() - borrowers = attr.ib() - tasks_waiting = attr.ib() + borrowed_tokens: int = attr.ib() + total_tokens: int | float = attr.ib() + borrowers: list[Task | object] = attr.ib() + tasks_waiting: int = attr.ib() -@async_cm -class CapacityLimiter: +# Can be a generic type with a default of Task if/when PEP 696 is released +# and implemented in type checkers. Making it fully generic would currently +# introduce a lot of unnecessary hassle. +class CapacityLimiter(AsyncContextManagerMixin, metaclass=Final): """An object for controlling access to a resource with limited capacity. Sometimes you need to put a limit on how many tasks can do something at @@ -173,25 +203,24 @@ class CapacityLimiter: just borrowed and then put back. """ - def __init__(self, total_tokens): + + # total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing + def __init__(self, total_tokens: int | float): self._lot = ParkingLot() - self._borrowers = set() + self._borrowers: set[Task | object] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of - self._pending_borrowers = {} + self._pending_borrowers: dict[Task, Task | object] = {} # invoke the property setter for validation - self.total_tokens = total_tokens + self.total_tokens: int | float = total_tokens assert self._total_tokens == total_tokens - def __repr__(self): - return ( - "".format( - id(self), len(self._borrowers), self._total_tokens, - len(self._lot) - ) + def __repr__(self) -> str: + return "".format( + id(self), len(self._borrowers), self._total_tokens, len(self._lot) ) @property - def total_tokens(self): + def total_tokens(self) -> int | float: """The total capacity available. You can change :attr:`total_tokens` by assigning to this attribute. 
If @@ -206,37 +235,31 @@ def total_tokens(self): return self._total_tokens @total_tokens.setter - def total_tokens(self, new_total_tokens): - if not isinstance( - new_total_tokens, int - ) and new_total_tokens != math.inf: + def total_tokens(self, new_total_tokens: int | float) -> None: + if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf: raise TypeError("total_tokens must be an int or math.inf") if new_total_tokens < 1: raise ValueError("total_tokens must be >= 1") self._total_tokens = new_total_tokens self._wake_waiters() - def _wake_waiters(self): + def _wake_waiters(self) -> None: available = self._total_tokens - len(self._borrowers) for woken in self._lot.unpark(count=available): self._borrowers.add(self._pending_borrowers.pop(woken)) @property - def borrowed_tokens(self): - """The amount of capacity that's currently in use. - - """ + def borrowed_tokens(self) -> int: + """The amount of capacity that's currently in use.""" return len(self._borrowers) @property - def available_tokens(self): - """The amount of capacity that's available to use. - - """ + def available_tokens(self) -> int | float: + """The amount of capacity that's available to use.""" return self.total_tokens - self.borrowed_tokens @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Borrow a token from the sack, without blocking. Raises: @@ -245,15 +268,15 @@ def acquire_nowait(self): tokens. """ - self.acquire_on_behalf_of_nowait(trio.hazmat.current_task()) + self.acquire_on_behalf_of_nowait(trio.lowlevel.current_task()) @enable_ki_protection - def acquire_on_behalf_of_nowait(self, borrower): + def acquire_on_behalf_of_nowait(self, borrower: Task | object) -> None: """Borrow a token from the sack on behalf of ``borrower``, without blocking. Args: - borrower: A :class:`trio.hazmat.Task` or arbitrary opaque object + borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object used to record who is borrowing this token. 
This is used by :func:`trio.to_thread.run_sync` to allow threads to "hold tokens", with the intention in the future of using it to `allow @@ -268,8 +291,7 @@ def acquire_on_behalf_of_nowait(self, borrower): """ if borrower in self._borrowers: raise RuntimeError( - "this borrower is already holding one of this " - "CapacityLimiter's tokens" + "this borrower is already holding one of this CapacityLimiter's tokens" ) if len(self._borrowers) < self._total_tokens and not self._lot: self._borrowers.add(borrower) @@ -277,7 +299,7 @@ def acquire_on_behalf_of_nowait(self, borrower): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Borrow a token from the sack, blocking if necessary. Raises: @@ -285,15 +307,15 @@ async def acquire(self): tokens. """ - await self.acquire_on_behalf_of(trio.hazmat.current_task()) + await self.acquire_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - async def acquire_on_behalf_of(self, borrower): + async def acquire_on_behalf_of(self, borrower: Task | object) -> None: """Borrow a token from the sack on behalf of ``borrower``, blocking if necessary. Args: - borrower: A :class:`trio.hazmat.Task` or arbitrary opaque object + borrower: A :class:`trio.lowlevel.Task` or arbitrary opaque object used to record who is borrowing this token; see :meth:`acquire_on_behalf_of_nowait` for details. @@ -302,11 +324,11 @@ async def acquire_on_behalf_of(self, borrower): tokens. 
""" - await trio.hazmat.checkpoint_if_cancelled() + await trio.lowlevel.checkpoint_if_cancelled() try: self.acquire_on_behalf_of_nowait(borrower) except trio.WouldBlock: - task = trio.hazmat.current_task() + task = trio.lowlevel.current_task() self._pending_borrowers[task] = borrower try: await self._lot.park() @@ -314,10 +336,10 @@ async def acquire_on_behalf_of(self, borrower): self._pending_borrowers.pop(task) raise else: - await trio.hazmat.cancel_shielded_checkpoint() + await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Put a token back into the sack. Raises: @@ -325,10 +347,10 @@ def release(self): sack's tokens. """ - self.release_on_behalf_of(trio.hazmat.current_task()) + self.release_on_behalf_of(trio.lowlevel.current_task()) @enable_ki_protection - def release_on_behalf_of(self, borrower): + def release_on_behalf_of(self, borrower: Task | object) -> None: """Put a token back into the sack on behalf of ``borrower``. Raises: @@ -338,13 +360,12 @@ def release_on_behalf_of(self, borrower): """ if borrower not in self._borrowers: raise RuntimeError( - "this borrower isn't holding any of this CapacityLimiter's " - "tokens" + "this borrower isn't holding any of this CapacityLimiter's tokens" ) self._borrowers.remove(borrower) self._wake_waiters() - def statistics(self): + def statistics(self) -> CapacityLimiterStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -361,7 +382,7 @@ def statistics(self): :meth:`acquire_on_behalf_of` methods. """ - return _CapacityLimiterStatistics( + return CapacityLimiterStatistics( borrowed_tokens=len(self._borrowers), total_tokens=self._total_tokens, # Use a list instead of a frozenset just in case we start to allow @@ -371,8 +392,7 @@ def statistics(self): ) -@async_cm -class Semaphore: +class Semaphore(AsyncContextManagerMixin, metaclass=Final): """A `semaphore `__. 
A semaphore holds an integer value, which can be incremented by @@ -398,7 +418,8 @@ class Semaphore: ``max_value``. """ - def __init__(self, initial_value, *, max_value=None): + + def __init__(self, initial_value: int, *, max_value: int | None = None): if not isinstance(initial_value, int): raise TypeError("initial_value must be an int") if initial_value < 0: @@ -412,37 +433,31 @@ def __init__(self, initial_value, *, max_value=None): # Invariants: # bool(self._lot) implies self._value == 0 # (or equivalently: self._value > 0 implies not self._lot) - self._lot = trio.hazmat.ParkingLot() + self._lot = trio.lowlevel.ParkingLot() self._value = initial_value self._max_value = max_value - def __repr__(self): + def __repr__(self) -> str: if self._max_value is None: max_value_str = "" else: - max_value_str = ", max_value={}".format(self._max_value) - return ( - "".format( - self._value, max_value_str, id(self) - ) + max_value_str = f", max_value={self._max_value}" + return "".format( + self._value, max_value_str, id(self) ) @property - def value(self): - """The current value of the semaphore. - - """ + def value(self) -> int: + """The current value of the semaphore.""" return self._value @property - def max_value(self): - """The maximum allowed value. May be None to indicate no limit. - - """ + def max_value(self) -> int | None: + """The maximum allowed value. May be None to indicate no limit.""" return self._max_value @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to decrement the semaphore value, without blocking. Raises: @@ -456,21 +471,21 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): + async def acquire(self) -> None: """Decrement the semaphore value, blocking if necessary to avoid letting it drop below zero. 
""" - await trio.hazmat.checkpoint_if_cancelled() + await trio.lowlevel.checkpoint_if_cancelled() try: self.acquire_nowait() except trio.WouldBlock: await self._lot.park() else: - await trio.hazmat.cancel_shielded_checkpoint() + await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Increment the semaphore value, possibly waking a task blocked in :meth:`acquire`. @@ -487,7 +502,7 @@ def release(self): raise ValueError("semaphore released too many times") self._value += 1 - def statistics(self): + def statistics(self) -> ParkingLotStatistics: """Return an object containing debugging information. Currently the following fields are defined: @@ -499,45 +514,42 @@ def statistics(self): return self._lot.statistics() -@attr.s(frozen=True) -class _LockStatistics: - locked = attr.ib() - owner = attr.ib() - tasks_waiting = attr.ib() +@attr.s(frozen=True, slots=True) +class LockStatistics: + """An object containing debugging information for a Lock. + Currently the following fields are defined: -@async_cm -@attr.s(eq=False, hash=False, repr=False) -class Lock: - """A classic `mutex - `__. + * ``locked`` (boolean): indicating whether the lock is held. + * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock, + or None if the lock is not held. + * ``tasks_waiting`` (int): The number of tasks blocked on this lock's + :meth:`trio.Lock.acquire` method. - This is a non-reentrant, single-owner lock. Unlike - :class:`threading.Lock`, only the owner of the lock is allowed to release - it. + """ - A :class:`Lock` object can be used as an async context manager; it - blocks on entry but not on exit. 
+ locked: bool = attr.ib() + owner: Task | None = attr.ib() + tasks_waiting: int = attr.ib() - """ - _lot = attr.ib(factory=ParkingLot, init=False) - _owner = attr.ib(default=None, init=False) +@attr.s(eq=False, hash=False, repr=False) +class _LockImpl(AsyncContextManagerMixin): + _lot: ParkingLot = attr.ib(factory=ParkingLot, init=False) + _owner: Task | None = attr.ib(default=None, init=False) - def __repr__(self): + def __repr__(self) -> str: if self.locked(): s1 = "locked" - s2 = " with {} waiters".format(len(self._lot)) + s2 = f" with {len(self._lot)} waiters" else: s1 = "unlocked" s2 = "" - return ( - "<{} {} object at {:#x}{}>".format( - s1, self.__class__.__name__, id(self), s2 - ) + return "<{} {} object at {:#x}{}>".format( + s1, self.__class__.__name__, id(self), s2 ) - def locked(self): + def locked(self) -> bool: """Check whether the lock is currently held. Returns: @@ -547,7 +559,7 @@ def locked(self): return self._owner is not None @enable_ki_protection - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the lock, without blocking. Raises: @@ -555,7 +567,7 @@ def acquire_nowait(self): """ - task = trio.hazmat.current_task() + task = trio.lowlevel.current_task() if self._owner is task: raise RuntimeError("attempt to re-acquire an already held Lock") elif self._owner is None and not self._lot: @@ -565,11 +577,9 @@ def acquire_nowait(self): raise trio.WouldBlock @enable_ki_protection - async def acquire(self): - """Acquire the lock, blocking if necessary. - - """ - await trio.hazmat.checkpoint_if_cancelled() + async def acquire(self) -> None: + """Acquire the lock, blocking if necessary.""" + await trio.lowlevel.checkpoint_if_cancelled() try: self.acquire_nowait() except trio.WouldBlock: @@ -578,17 +588,17 @@ async def acquire(self): # lock as well. 
await self._lot.park() else: - await trio.hazmat.cancel_shielded_checkpoint() + await trio.lowlevel.cancel_shielded_checkpoint() @enable_ki_protection - def release(self): + def release(self) -> None: """Release the lock. Raises: RuntimeError: if the calling task does not hold the lock. """ - task = trio.hazmat.current_task() + task = trio.lowlevel.current_task() if task is not self._owner: raise RuntimeError("can't release a Lock you don't own") if self._lot: @@ -596,26 +606,38 @@ def release(self): else: self._owner = None - def statistics(self): + def statistics(self) -> LockStatistics: """Return an object containing debugging information. Currently the following fields are defined: * ``locked``: boolean indicating whether the lock is held. - * ``owner``: the :class:`trio.hazmat.Task` currently holding the lock, + * ``owner``: the :class:`trio.lowlevel.Task` currently holding the lock, or None if the lock is not held. * ``tasks_waiting``: The number of tasks blocked on this lock's :meth:`acquire` method. """ - return _LockStatistics( - locked=self.locked(), - owner=self._owner, - tasks_waiting=len(self._lot), + return LockStatistics( + locked=self.locked(), owner=self._owner, tasks_waiting=len(self._lot) ) -class StrictFIFOLock(Lock): +class Lock(_LockImpl, metaclass=Final): + """A classic `mutex + `__. + + This is a non-reentrant, single-owner lock. Unlike + :class:`threading.Lock`, only the owner of the lock is allowed to release + it. + + A :class:`Lock` object can be used as an async context manager; it + blocks on entry but not on exit. + + """ + + +class StrictFIFOLock(_LockImpl, metaclass=Final): r"""A variant of :class:`Lock` where tasks are guaranteed to acquire the lock in strict first-come-first-served order. @@ -667,7 +689,7 @@ class StrictFIFOLock(Lock): :class:`StrictFIFOLock` guarantees that each task will send its data in the same order that the state machine generated it. 
- Currently, :class:`StrictFIFOLock` is simply an alias for :class:`Lock`, + Currently, :class:`StrictFIFOLock` is identical to :class:`Lock`, but (a) this may not always be true in the future, especially if Trio ever implements `more sophisticated scheduling policies `__, and (b) the above code @@ -678,14 +700,23 @@ class StrictFIFOLock(Lock): """ -@attr.s(frozen=True) -class _ConditionStatistics: - tasks_waiting = attr.ib() - lock_statistics = attr.ib() +@attr.s(frozen=True, slots=True) +class ConditionStatistics: + r"""An object containing debugging information for a Condition. + Currently the following fields are defined: -@async_cm -class Condition: + * ``tasks_waiting`` (int): The number of tasks blocked on this condition's + :meth:`trio.Condition.wait` method. + * ``lock_statistics``: The result of calling the underlying + :class:`Lock`\s :meth:`~Lock.statistics` method. + + """ + tasks_waiting: int = attr.ib() + lock_statistics: LockStatistics = attr.ib() + + +class Condition(AsyncContextManagerMixin, metaclass=Final): """A classic `condition variable `__, similar to :class:`threading.Condition`. @@ -699,15 +730,16 @@ class Condition: and used. """ - def __init__(self, lock=None): + + def __init__(self, lock: Lock | None = None): if lock is None: lock = Lock() if not type(lock) is Lock: raise TypeError("lock must be a trio.Lock") self._lock = lock - self._lot = trio.hazmat.ParkingLot() + self._lot = trio.lowlevel.ParkingLot() - def locked(self): + def locked(self) -> bool: """Check whether the underlying lock is currently held. Returns: @@ -716,7 +748,7 @@ def locked(self): """ return self._lock.locked() - def acquire_nowait(self): + def acquire_nowait(self) -> None: """Attempt to acquire the underlying lock, without blocking. Raises: @@ -725,20 +757,16 @@ def acquire_nowait(self): """ return self._lock.acquire_nowait() - async def acquire(self): - """Acquire the underlying lock, blocking if necessary. 
- - """ + async def acquire(self) -> None: + """Acquire the underlying lock, blocking if necessary.""" await self._lock.acquire() - def release(self): - """Release the underlying lock. - - """ + def release(self) -> None: + """Release the underlying lock.""" self._lock.release() @enable_ki_protection - async def wait(self): + async def wait(self) -> None: """Wait for another task to call :meth:`notify` or :meth:`notify_all`. @@ -761,7 +789,7 @@ async def wait(self): RuntimeError: if the calling task does not hold the lock. """ - if trio.hazmat.current_task() is not self._lock._owner: + if trio.lowlevel.current_task() is not self._lock._owner: raise RuntimeError("must hold the lock to wait") self.release() # NOTE: we go to sleep on self._lot, but we'll wake up on @@ -773,7 +801,7 @@ async def wait(self): await self.acquire() raise - def notify(self, n=1): + def notify(self, n: int = 1) -> None: """Wake one or more tasks that are blocked in :meth:`wait`. Args: @@ -783,22 +811,22 @@ def notify(self, n=1): RuntimeError: if the calling task does not hold the lock. """ - if trio.hazmat.current_task() is not self._lock._owner: + if trio.lowlevel.current_task() is not self._lock._owner: raise RuntimeError("must hold the lock to notify") self._lot.repark(self._lock._lot, count=n) - def notify_all(self): + def notify_all(self) -> None: """Wake all tasks that are currently blocked in :meth:`wait`. Raises: RuntimeError: if the calling task does not hold the lock. """ - if trio.hazmat.current_task() is not self._lock._owner: + if trio.lowlevel.current_task() is not self._lock._owner: raise RuntimeError("must hold the lock to notify") self._lot.repark_all(self._lock._lot) - def statistics(self): + def statistics(self) -> ConditionStatistics: r"""Return an object containing debugging information. Currently the following fields are defined: @@ -809,7 +837,6 @@ def statistics(self): :class:`Lock`\s :meth:`~Lock.statistics` method. 
""" - return _ConditionStatistics( - tasks_waiting=len(self._lot), - lock_statistics=self._lock.statistics(), + return ConditionStatistics( + tasks_waiting=len(self._lot), lock_statistics=self._lock.statistics() ) diff --git a/trio/tests/__init__.py b/trio/_tests/__init__.py similarity index 100% rename from trio/tests/__init__.py rename to trio/_tests/__init__.py diff --git a/trio/_tests/astrill-codesigning-cert.cer b/trio/_tests/astrill-codesigning-cert.cer new file mode 100644 index 0000000000..58cc0c05fa Binary files /dev/null and b/trio/_tests/astrill-codesigning-cert.cer differ diff --git a/trio/_tests/check_type_completeness.py b/trio/_tests/check_type_completeness.py new file mode 100755 index 0000000000..7a65a4249e --- /dev/null +++ b/trio/_tests/check_type_completeness.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# this file is not run as part of the tests, instead it's run standalone from check.sh +import argparse +import json +import subprocess +import sys +from pathlib import Path + +# the result file is not marked in MANIFEST.in so it's not included in the package +RESULT_FILE = Path(__file__).parent / "verify_types.json" +failed = False + + +# TODO: consider checking manually without `--ignoreexternal`, and/or +# removing it from the below call later on. +def run_pyright(): + return subprocess.run( + [ + "pyright", + # Specify a platform and version to keep imported modules consistent. 
+ "--pythonplatform=Linux", + "--pythonversion=3.8", + "--verifytypes=trio", + "--outputjson", + "--ignoreexternal", + ], + capture_output=True, + ) + + +def check_less_than(key, current_dict, last_dict, /, invert=False): + global failed + current = current_dict[key] + last = last_dict[key] + assert isinstance(current, (float, int)) + assert isinstance(last, (float, int)) + if current == last: + return + if (current > last) ^ invert: + failed = True + print("ERROR: ", end="") + if isinstance(current, float): + strcurrent = f"{current:.4}" + strlast = f"{last:.4}" + else: + strcurrent = str(current) + strlast = str(last) + print( + f"{key} has gone {'down' if current int: + print("*" * 20, "\nChecking type completeness hasn't gone down...") + + res = run_pyright() + current_result = json.loads(res.stdout) + py_typed_file: Path | None = None + + # check if py.typed file was missing + if ( + current_result["generalDiagnostics"] + and current_result["generalDiagnostics"][0]["message"] + == "No py.typed file found" + ): + print("creating py.typed") + py_typed_file = ( + Path(current_result["typeCompleteness"]["packageRootDirectory"]) + / "py.typed" + ) + py_typed_file.write_text("") + + res = run_pyright() + current_result = json.loads(res.stdout) + + if res.stderr: + print(res.stderr) + + if args.full_diagnostics_file is not None: + with open(args.full_diagnostics_file, "w") as file: + json.dump( + [ + sym + for sym in current_result["typeCompleteness"]["symbols"] + if sym["diagnostics"] + ], + file, + sort_keys=True, + indent=2, + ) + + last_result = json.loads(RESULT_FILE.read_text()) + + for key in "errorCount", "warningCount", "informationCount": + check_zero(key, current_result["summary"]) + + for key, invert in ( + ("missingFunctionDocStringCount", False), + ("missingClassDocStringCount", False), + ("missingDefaultParamCount", False), + ("completenessScore", True), + ): + check_less_than( + key, + current_result["typeCompleteness"], + 
last_result["typeCompleteness"], + invert=invert, + ) + + for key, invert in ( + ("withUnknownType", False), + ("withAmbiguousType", False), + ("withKnownType", True), + ): + check_less_than( + key, + current_result["typeCompleteness"]["exportedSymbolCounts"], + last_result["typeCompleteness"]["exportedSymbolCounts"], + invert=invert, + ) + + assert ( + res.returncode != 0 + ), "Fully type complete! Delete this script and instead directly run `pyright --verifytypes=trio` (consider `--ignoreexternal`) in CI and checking exit code." + + if args.overwrite_file: + print("Overwriting file") + + # don't care about differences in time taken + del current_result["time"] + del current_result["summary"]["timeInSec"] + + # don't fail on version diff so pyright updates can be automerged + del current_result["version"] + + for key in ( + # don't save path (because that varies between machines) + "moduleRootDirectory", + "packageRootDirectory", + "pyTypedPath", + ): + del current_result["typeCompleteness"][key] + + # prune the symbols to only be the name of the symbols with + # errors, instead of saving a huge file. + new_symbols = [] + for symbol in current_result["typeCompleteness"]["symbols"]: + if symbol["diagnostics"]: + new_symbols.append(symbol["name"]) + continue + + # Ensure order of arrays does not affect result. 
+ new_symbols.sort() + current_result["generalDiagnostics"].sort() + current_result["typeCompleteness"]["modules"].sort( + key=lambda module: module.get("name", "") + ) + + current_result["typeCompleteness"]["symbols"] = new_symbols + + with open(RESULT_FILE, "w") as file: + json.dump(current_result, file, sort_keys=True, indent=2) + # add newline at end of file so it's easier to manually modify + file.write("\n") + + if py_typed_file is not None: + print("deleting py.typed") + py_typed_file.unlink() + + print("*" * 20) + + return int(failed) + + +parser = argparse.ArgumentParser() +parser.add_argument("--overwrite-file", action="store_true", default=False) +parser.add_argument("--full-diagnostics-file", type=Path, default=None) +args = parser.parse_args() + +assert __name__ == "__main__", "This script should be run standalone" +sys.exit(main(args)) diff --git a/trio/tests/module_with_deprecations.py b/trio/_tests/module_with_deprecations.py similarity index 59% rename from trio/tests/module_with_deprecations.py rename to trio/_tests/module_with_deprecations.py index d194b6a5bd..73184d11e8 100644 --- a/trio/tests/module_with_deprecations.py +++ b/trio/_tests/module_with_deprecations.py @@ -8,22 +8,14 @@ # attributes in between calling enable_attribute_deprecations and defining # __deprecated_attributes__: import sys + this_mod = sys.modules[__name__] assert this_mod.regular == "hi" assert not hasattr(this_mod, "dep1") __deprecated_attributes__ = { - "dep1": - _deprecate.DeprecatedAttribute( - "value1", - "1.1", - issue=1, - ), - "dep2": - _deprecate.DeprecatedAttribute( - "value2", - "1.2", - issue=1, - instead="instead-string", - ), + "dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1), + "dep2": _deprecate.DeprecatedAttribute( + "value2", "1.2", issue=1, instead="instead-string" + ), } diff --git a/trio/tests/conftest.py b/trio/_tests/pytest_plugin.py similarity index 76% rename from trio/tests/conftest.py rename to trio/_tests/pytest_plugin.py index 
772486e1eb..c6d73e25ea 100644 --- a/trio/tests/conftest.py +++ b/trio/_tests/pytest_plugin.py @@ -1,13 +1,8 @@ -# XX this does not belong here -- b/c it's here, these things only apply to -# the tests in trio/_core/tests, not in trio/tests. For now there's some -# copy-paste... -# -# this stuff should become a proper pytest plugin +import inspect import pytest -import inspect -from ..testing import trio_test, MockClock +from ..testing import MockClock, trio_test RUN_SLOW = True diff --git a/trio/tests/test_abc.py b/trio/_tests/test_abc.py similarity index 96% rename from trio/tests/test_abc.py rename to trio/_tests/test_abc.py index c445c97103..2b0b7088b0 100644 --- a/trio/tests/test_abc.py +++ b/trio/_tests/test_abc.py @@ -1,8 +1,6 @@ -import pytest - import attr +import pytest -from ..testing import assert_checkpoints from .. import abc as tabc diff --git a/trio/tests/test_channel.py b/trio/_tests/test_channel.py similarity index 84% rename from trio/tests/test_channel.py rename to trio/_tests/test_channel.py index b43466dd7d..4478c523f5 100644 --- a/trio/tests/test_channel.py +++ b/trio/_tests/test_channel.py @@ -1,8 +1,9 @@ import pytest -from ..testing import wait_all_tasks_blocked, assert_checkpoints import trio -from trio import open_memory_channel, EndOfChannel +from trio import EndOfChannel, open_memory_channel + +from ..testing import assert_checkpoints, wait_all_tasks_blocked async def test_channel(): @@ -157,6 +158,61 @@ async def receive_block(r): await r.receive() +async def test_close_sync(): + async def send_block(s, expect): + with pytest.raises(expect): + await s.send(None) + + # closing send -> other send gets ClosedResourceError + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(send_block, s, trio.ClosedResourceError) + await wait_all_tasks_blocked() + s.close() + + # and it's persistent + with pytest.raises(trio.ClosedResourceError): + s.send_nowait(None) + with 
pytest.raises(trio.ClosedResourceError): + await s.send(None) + + # and receive gets EndOfChannel + with pytest.raises(EndOfChannel): + r.receive_nowait() + with pytest.raises(EndOfChannel): + await r.receive() + + # closing receive -> send gets BrokenResourceError + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(send_block, s, trio.BrokenResourceError) + await wait_all_tasks_blocked() + r.close() + + # and it's persistent + with pytest.raises(trio.BrokenResourceError): + s.send_nowait(None) + with pytest.raises(trio.BrokenResourceError): + await s.send(None) + + # closing receive -> other receive gets ClosedResourceError + async def receive_block(r): + with pytest.raises(trio.ClosedResourceError): + await r.receive() + + s, r = open_memory_channel(0) + async with trio.open_nursery() as nursery: + nursery.start_soon(receive_block, r) + await wait_all_tasks_blocked() + r.close() + + # and it's persistent + with pytest.raises(trio.ClosedResourceError): + r.receive_nowait() + with pytest.raises(trio.ClosedResourceError): + await r.receive() + + async def test_receive_channel_clone_and_close(): s, r = open_memory_channel(10) @@ -165,7 +221,7 @@ async def test_receive_channel_clone_and_close(): s.send_nowait(None) await r.aclose() - async with r2: + with r2: pass with pytest.raises(trio.ClosedResourceError): @@ -230,7 +286,7 @@ async def test_inf_capacity(): s, r = open_memory_channel(float("inf")) # It's accepted, and we can send all day without blocking - async with s: + with s: for i in range(10): s.send_nowait(i) @@ -291,7 +347,6 @@ async def test_statistics(): async def test_channel_fairness(): - # We can remove an item we just sent, and send an item back in after, if # no-one else is waiting. 
s, r = open_memory_channel(1) diff --git a/trio/_tests/test_contextvars.py b/trio/_tests/test_contextvars.py new file mode 100644 index 0000000000..63853f5171 --- /dev/null +++ b/trio/_tests/test_contextvars.py @@ -0,0 +1,52 @@ +import contextvars + +from .. import _core + +trio_testing_contextvar = contextvars.ContextVar("trio_testing_contextvar") + + +async def test_contextvars_default(): + trio_testing_contextvar.set("main") + record = [] + + async def child(): + value = trio_testing_contextvar.get() + record.append(value) + + async with _core.open_nursery() as nursery: + nursery.start_soon(child) + assert record == ["main"] + + +async def test_contextvars_set(): + trio_testing_contextvar.set("main") + record = [] + + async def child(): + trio_testing_contextvar.set("child") + value = trio_testing_contextvar.get() + record.append(value) + + async with _core.open_nursery() as nursery: + nursery.start_soon(child) + value = trio_testing_contextvar.get() + assert record == ["child"] + assert value == "main" + + +async def test_contextvars_copy(): + trio_testing_contextvar.set("main") + context = contextvars.copy_context() + trio_testing_contextvar.set("second_main") + record = [] + + async def child(): + value = trio_testing_contextvar.get() + record.append(value) + + async with _core.open_nursery() as nursery: + context.run(nursery.start_soon, child) + nursery.start_soon(child) + value = trio_testing_contextvar.get() + assert set(record) == {"main", "second_main"} + assert value == "second_main" diff --git a/trio/tests/test_deprecate.py b/trio/_tests/test_deprecate.py similarity index 77% rename from trio/tests/test_deprecate.py rename to trio/_tests/test_deprecate.py index 6ecd00003e..33c05ffd25 100644 --- a/trio/tests/test_deprecate.py +++ b/trio/_tests/test_deprecate.py @@ -1,18 +1,23 @@ -import pytest - import inspect import warnings +import pytest + from .._deprecate import ( - TrioDeprecationWarning, warn_deprecated, deprecated, deprecated_alias + 
TrioDeprecationWarning, + deprecated, + deprecated_alias, + warn_deprecated, ) - from . import module_with_deprecations @pytest.fixture def recwarn_always(recwarn): warnings.simplefilter("always") + # ResourceWarnings about unclosed sockets can occur nondeterministically + # (during GC) which throws off the tests in this file + warnings.simplefilter("ignore", ResourceWarning) return recwarn @@ -25,8 +30,8 @@ def test_warn_deprecated(recwarn_always): def deprecated_thing(): warn_deprecated("ice", "1.2", issue=1, instead="water") - filename, lineno = _here() # https://github.com/google/yapf/issues/447 deprecated_thing() + filename, lineno = _here() assert len(recwarn_always) == 1 got = recwarn_always.pop(TrioDeprecationWarning) assert "ice is deprecated" in got.message.args[0] @@ -34,7 +39,7 @@ def deprecated_thing(): assert "water instead" in got.message.args[0] assert "/issues/1" in got.message.args[0] assert got.filename == filename - assert got.lineno == lineno + 1 + assert got.lineno == lineno - 1 def test_warn_deprecated_no_instead_or_issue(recwarn_always): @@ -54,7 +59,7 @@ def nested1(): def nested2(): warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3) - filename, lineno = _here() # https://github.com/google/yapf/issues/447 + filename, lineno = _here() nested1() got = recwarn_always.pop(TrioDeprecationWarning) assert got.filename == filename @@ -154,67 +159,71 @@ def test_deprecated_alias_method(recwarn_always): @deprecated("2.1", issue=1, instead="hi") def docstring_test1(): # pragma: no cover - """Hello! - - """ + """Hello!""" @deprecated("2.1", issue=None, instead="hi") def docstring_test2(): # pragma: no cover - """Hello! - - """ + """Hello!""" @deprecated("2.1", issue=1, instead=None) def docstring_test3(): # pragma: no cover - """Hello! - - """ + """Hello!""" @deprecated("2.1", issue=None, instead=None) def docstring_test4(): # pragma: no cover - """Hello! 
- - """ + """Hello!""" def test_deprecated_docstring_munging(): - assert docstring_test1.__doc__ == """Hello! + assert ( + docstring_test1.__doc__ + == """Hello! .. deprecated:: 2.1 Use hi instead. For details, see `issue #1 `__. """ + ) - assert docstring_test2.__doc__ == """Hello! + assert ( + docstring_test2.__doc__ + == """Hello! .. deprecated:: 2.1 Use hi instead. """ + ) - assert docstring_test3.__doc__ == """Hello! + assert ( + docstring_test3.__doc__ + == """Hello! .. deprecated:: 2.1 For details, see `issue #1 `__. """ + ) - assert docstring_test4.__doc__ == """Hello! + assert ( + docstring_test4.__doc__ + == """Hello! .. deprecated:: 2.1 """ + ) def test_module_with_deprecations(recwarn_always): assert module_with_deprecations.regular == "hi" assert len(recwarn_always) == 0 - filename, lineno = _here() # https://github.com/google/yapf/issues/447 + filename, lineno = _here() assert module_with_deprecations.dep1 == "value1" got = recwarn_always.pop(TrioDeprecationWarning) assert got.filename == filename @@ -231,3 +240,32 @@ def test_module_with_deprecations(recwarn_always): with pytest.raises(AttributeError): module_with_deprecations.asdf + + +def test_tests_is_deprecated1() -> None: + with pytest.warns(TrioDeprecationWarning): + from trio import tests # warning on import + + # warning on access of any member + with pytest.warns(TrioDeprecationWarning): + assert tests.test_abc # type: ignore[attr-defined] + + +def test_tests_is_deprecated2() -> None: + # warning on direct import of test since that accesses `__spec__` + with pytest.warns(TrioDeprecationWarning): + import trio.tests + + with pytest.warns(TrioDeprecationWarning): + assert trio.tests.test_deprecate # type: ignore[attr-defined] + + +def test_tests_is_deprecated3() -> None: + import trio + + # no warning on accessing the submodule + assert trio.tests + + # only when accessing a submodule member + with pytest.warns(TrioDeprecationWarning): + assert trio.tests.test_abc # type: ignore[attr-defined] 
diff --git a/trio/_tests/test_dtls.py b/trio/_tests/test_dtls.py new file mode 100644 index 0000000000..b8c32c6d5f --- /dev/null +++ b/trio/_tests/test_dtls.py @@ -0,0 +1,870 @@ +import random +from contextlib import asynccontextmanager +from itertools import count + +import attr +import pytest +import trustme +from OpenSSL import SSL + +import trio +import trio.testing +from trio import DTLSEndpoint +from trio.testing._fake_net import FakeNet + +from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow + +ca = trustme.CA() +server_cert = ca.issue_cert("example.com") + +server_ctx = SSL.Context(SSL.DTLS_METHOD) +server_cert.configure_cert(server_ctx) + +client_ctx = SSL.Context(SSL.DTLS_METHOD) +ca.configure_trust(client_ctx) + + +parametrize_ipv6 = pytest.mark.parametrize( + "ipv6", [False, pytest.param(True, marks=binds_ipv6)], ids=["ipv4", "ipv6"] +) + + +def endpoint(**kwargs): + ipv6 = kwargs.pop("ipv6", False) + if ipv6: + family = trio.socket.AF_INET6 + else: + family = trio.socket.AF_INET + sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family) + return DTLSEndpoint(sock, **kwargs) + + +@asynccontextmanager +async def dtls_echo_server(*, autocancel=True, mtu=None, ipv6=False): + with endpoint(ipv6=ipv6) as server: + if ipv6: + localhost = "::1" + else: + localhost = "127.0.0.1" + await server.socket.bind((localhost, 0)) + async with trio.open_nursery() as nursery: + + async def echo_handler(dtls_channel): + print( + "echo handler started: " + f"server {dtls_channel.endpoint.socket.getsockname()} " + f"client {dtls_channel.peer_address}" + ) + if mtu is not None: + dtls_channel.set_ciphertext_mtu(mtu) + try: + print("server starting do_handshake") + await dtls_channel.do_handshake() + print("server finished do_handshake") + async for packet in dtls_channel: + print(f"echoing {packet} -> {dtls_channel.peer_address}") + await dtls_channel.send(packet) + except trio.BrokenResourceError: # pragma: no cover + print("echo handler channel 
broken") + + await nursery.start(server.serve, server_ctx, echo_handler) + + yield server, server.socket.getsockname() + + if autocancel: + nursery.cancel_scope.cancel() + + +@parametrize_ipv6 +async def test_smoke(ipv6): + async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address): + with endpoint(ipv6=ipv6) as client_endpoint: + client_channel = client_endpoint.connect(address, client_ctx) + with pytest.raises(trio.NeedHandshakeError): + client_channel.get_cleartext_mtu() + + await client_channel.do_handshake() + await client_channel.send(b"hello") + assert await client_channel.receive() == b"hello" + await client_channel.send(b"goodbye") + assert await client_channel.receive() == b"goodbye" + + with pytest.raises(ValueError): + await client_channel.send(b"") + + client_channel.set_ciphertext_mtu(1234) + cleartext_mtu_1234 = client_channel.get_cleartext_mtu() + client_channel.set_ciphertext_mtu(4321) + assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234 + client_channel.set_ciphertext_mtu(1234) + assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234 + + +@slow +async def test_handshake_over_terrible_network(autojump_clock): + HANDSHAKES = 100 + r = random.Random(0) + fn = FakeNet() + fn.enable() + # avoid spurious timeouts on slow machines + autojump_clock.autojump_threshold = 0.001 + + async with dtls_echo_server() as (_, address): + async with trio.open_nursery() as nursery: + + async def route_packet(packet): + while True: + op = r.choices( + ["deliver", "drop", "dupe", "delay"], + weights=[0.7, 0.1, 0.1, 0.1], + )[0] + print(f"{packet.source} -> {packet.destination}: {op}") + if op == "drop": + return + elif op == "dupe": + fn.send_packet(packet) + elif op == "delay": + await trio.sleep(r.random() * 3) + # I wanted to test random packet corruption too, but it turns out + # openssl has a bug in the following scenario: + # + # - client sends ClientHello + # - server sends HelloVerifyRequest with cookie -- but cookie is + # invalid 
b/c either the ClientHello or HelloVerifyRequest was + # corrupted + # - client re-sends ClientHello with invalid cookie + # - server replies with new HelloVerifyRequest and correct cookie + # + # At this point, the client *should* switch to the new, valid + # cookie. But OpenSSL doesn't; it stubbornly insists on re-sending + # the original, invalid cookie over and over. In theory we could + # work around this by detecting cookie changes and starting over + # with a whole new SSL object, but (a) it doesn't seem worth it, (b) + # when I tried then I ran into another issue where OpenSSL got stuck + # in an infinite loop sending alerts over and over, which I didn't + # dig into because see (a). + # + # elif op == "distort": + # payload = bytearray(packet.payload) + # payload[r.randrange(len(payload))] ^= 1 << r.randrange(8) + # packet = attr.evolve(packet, payload=payload) + else: + assert op == "deliver" + print( + f"{packet.source} -> {packet.destination}: delivered" + f" {packet.payload.hex()}" + ) + fn.deliver_packet(packet) + break + + def route_packet_wrapper(packet): + try: + nursery.start_soon(route_packet, packet) + except RuntimeError: # pragma: no cover + # We're exiting the nursery, so any remaining packets can just get + # dropped + pass + + fn.route_packet = route_packet_wrapper + + for i in range(HANDSHAKES): + print("#" * 80) + print("#" * 80) + print("#" * 80) + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + print("client starting do_handshake") + await client.do_handshake() + print("client finished do_handshake") + msg = str(i).encode() + # Make multiple attempts to send data, because the network might + # drop it + while True: + with trio.move_on_after(10) as cscope: + await client.send(msg) + assert await client.receive() == msg + if not cscope.cancelled_caught: + break + + +async def test_implicit_handshake(): + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + 
client = client_endpoint.connect(address, client_ctx) + + # Implicit handshake + await client.send(b"xyz") + assert await client.receive() == b"xyz" + + +async def test_full_duplex(): + # Tests simultaneous send/receive, and also multiple methods implicitly invoking + # do_handshake simultaneously. + with endpoint() as server_endpoint, endpoint() as client_endpoint: + await server_endpoint.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as server_nursery: + + async def handler(channel): + async with trio.open_nursery() as nursery: + nursery.start_soon(channel.send, b"from server") + nursery.start_soon(channel.receive) + + await server_nursery.start(server_endpoint.serve, server_ctx, handler) + + client = client_endpoint.connect( + server_endpoint.socket.getsockname(), client_ctx + ) + async with trio.open_nursery() as nursery: + nursery.start_soon(client.send, b"from client") + nursery.start_soon(client.receive) + + server_nursery.cancel_scope.cancel() + + +async def test_channel_closing(): + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake() + client.close() + + with pytest.raises(trio.ClosedResourceError): + await client.send(b"abc") + with pytest.raises(trio.ClosedResourceError): + await client.receive() + + # close is idempotent + client.close() + # can also aclose + await client.aclose() + + +async def test_serve_exits_cleanly_on_close(): + async with dtls_echo_server(autocancel=False) as (server_endpoint, address): + server_endpoint.close() + # Testing that the nursery exits even without being cancelled + # close is idempotent + server_endpoint.close() + + +async def test_client_multiplex(): + async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2): + with endpoint() as client_endpoint: + client1 = client_endpoint.connect(address1, client_ctx) + client2 = client_endpoint.connect(address2, 
client_ctx) + + await client1.send(b"abc") + await client2.send(b"xyz") + assert await client2.receive() == b"xyz" + assert await client1.receive() == b"abc" + + client_endpoint.close() + + with pytest.raises(trio.ClosedResourceError): + await client1.send("xxx") + with pytest.raises(trio.ClosedResourceError): + await client2.receive() + with pytest.raises(trio.ClosedResourceError): + client_endpoint.connect(address1, client_ctx) + + async with trio.open_nursery() as nursery: + with pytest.raises(trio.ClosedResourceError): + + async def null_handler(_): # pragma: no cover + pass + + await nursery.start(client_endpoint.serve, server_ctx, null_handler) + + +async def test_dtls_over_dgram_only(): + with trio.socket.socket() as s: + with pytest.raises(ValueError): + DTLSEndpoint(s) + + +async def test_double_serve(): + async def null_handler(_): # pragma: no cover + pass + + with endpoint() as server_endpoint: + await server_endpoint.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as nursery: + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + with pytest.raises(trio.BusyResourceError): + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + + nursery.cancel_scope.cancel() + + async with trio.open_nursery() as nursery: + await nursery.start(server_endpoint.serve, server_ctx, null_handler) + nursery.cancel_scope.cancel() + + +async def test_connect_to_non_server(autojump_clock): + fn = FakeNet() + fn.enable() + with endpoint() as client1, endpoint() as client2: + await client1.socket.bind(("127.0.0.1", 0)) + # This should just time out + with trio.move_on_after(100) as cscope: + channel = client2.connect(client1.socket.getsockname(), client_ctx) + await channel.do_handshake() + assert cscope.cancelled_caught + + +async def test_incoming_buffer_overflow(autojump_clock): + fn = FakeNet() + fn.enable() + for buffer_size in [10, 20]: + async with dtls_echo_server() as (_, address): + with 
endpoint(incoming_packets_buffer=buffer_size) as client_endpoint: + assert client_endpoint.incoming_packets_buffer == buffer_size + client = client_endpoint.connect(address, client_ctx) + for i in range(buffer_size + 15): + await client.send(str(i).encode()) + await trio.sleep(1) + stats = client.statistics() + assert stats.incoming_packets_dropped_in_trio == 15 + for i in range(buffer_size): + assert await client.receive() == str(i).encode() + await client.send(b"buffer clear now") + assert await client.receive() == b"buffer clear now" + + +async def test_server_socket_doesnt_crash_on_garbage(autojump_clock): + fn = FakeNet() + fn.enable() + + from trio._dtls import ( + ContentType, + HandshakeFragment, + HandshakeType, + ProtocolVersion, + Record, + encode_handshake_fragment, + encode_record, + ) + + client_hello = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=10, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ), + ) + ) + + client_hello_extended = client_hello + b"\x00" + client_hello_short = client_hello[:-1] + # cuts off in middle of handshake message header + client_hello_really_short = client_hello[:14] + client_hello_corrupt_record_len = bytearray(client_hello) + client_hello_corrupt_record_len[11] = 0xFF + + client_hello_fragmented = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + msg_type=HandshakeType.client_hello, + msg_len=20, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ), + ) + ) + + client_hello_trailing_data_in_record = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=encode_handshake_fragment( + HandshakeFragment( + 
msg_type=HandshakeType.client_hello, + msg_len=20, + msg_seq=0, + frag_offset=0, + frag_len=10, + frag=bytes(10), + ) + ) + + b"\x00", + ) + ) + + handshake_empty = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=b"", + ) + ) + + client_hello_truncated_in_cookie = encode_record( + Record( + content_type=ContentType.handshake, + version=ProtocolVersion.DTLS10, + epoch_seqno=0, + payload=bytes(2 + 32 + 1) + b"\xff", + ) + ) + + async with dtls_echo_server() as (_, address): + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock: + for bad_packet in [ + b"", + b"xyz", + client_hello_extended, + client_hello_short, + client_hello_really_short, + client_hello_corrupt_record_len, + client_hello_fragmented, + client_hello_trailing_data_in_record, + handshake_empty, + client_hello_truncated_in_cookie, + ]: + await sock.sendto(bad_packet, address) + await trio.sleep(1) + + +async def test_invalid_cookie_rejected(autojump_clock): + fn = FakeNet() + fn.enable() + + from trio._dtls import BadPacket, decode_client_hello_untrusted + + with trio.CancelScope() as cscope: + # the first 11 bytes of ClientHello aren't protected by the cookie, so only test + # corrupting bytes after that. + offset_to_corrupt = count(11) + + def route_packet(packet): + try: + _, cookie, _ = decode_client_hello_untrusted(packet.payload) + except BadPacket: + pass + else: + if len(cookie) != 0: + # this is a challenge response packet + # let's corrupt the next offset so the handshake should fail + payload = bytearray(packet.payload) + offset = next(offset_to_corrupt) + if offset >= len(payload): + # We've tried all offsets. Clamp offset to the end of the + # payload, and terminate the test. 
+ offset = len(payload) - 1 + cscope.cancel() + payload[offset] ^= 0x01 + packet = attr.evolve(packet, payload=payload) + + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + while True: + with endpoint() as client: + channel = client.connect(address, client_ctx) + await channel.do_handshake() + assert cscope.cancelled_caught + + +async def test_client_cancels_handshake_and_starts_new_one(autojump_clock): + # if a client disappears during the handshake, and then starts a new handshake from + # scratch, then the first handler's channel should fail, and a new handler get + # started + fn = FakeNet() + fn.enable() + + with endpoint() as server, endpoint() as client: + await server.socket.bind(("127.0.0.1", 0)) + async with trio.open_nursery() as nursery: + first_time = True + + async def handler(channel): + nonlocal first_time + if first_time: + first_time = False + print("handler: first time, cancelling connect") + connect_cscope.cancel() + await trio.sleep(0.5) + print("handler: handshake should fail now") + with pytest.raises(trio.BrokenResourceError): + await channel.do_handshake() + else: + print("handler: not first time, sending hello") + await channel.send(b"hello") + + await nursery.start(server.serve, server_ctx, handler) + + print("client: starting first connect") + with trio.CancelScope() as connect_cscope: + channel = client.connect(server.socket.getsockname(), client_ctx) + await channel.do_handshake() + assert connect_cscope.cancelled_caught + + print("client: starting second connect") + channel = client.connect(server.socket.getsockname(), client_ctx) + assert await channel.receive() == b"hello" + + # Give handlers a chance to finish + await trio.sleep(10) + nursery.cancel_scope.cancel() + + +async def test_swap_client_server(): + with endpoint() as a, endpoint() as b: + await a.socket.bind(("127.0.0.1", 0)) + await b.socket.bind(("127.0.0.1", 0)) + + async def echo_handler(channel): + async 
for packet in channel: + await channel.send(packet) + + async def crashing_echo_handler(channel): + with pytest.raises(trio.BrokenResourceError): + await echo_handler(channel) + + async with trio.open_nursery() as nursery: + await nursery.start(a.serve, server_ctx, crashing_echo_handler) + await nursery.start(b.serve, server_ctx, echo_handler) + + b_to_a = b.connect(a.socket.getsockname(), client_ctx) + await b_to_a.send(b"b as client") + assert await b_to_a.receive() == b"b as client" + + a_to_b = a.connect(b.socket.getsockname(), client_ctx) + await a_to_b.do_handshake() + with pytest.raises(trio.BrokenResourceError): + await b_to_a.send(b"association broken") + await a_to_b.send(b"a as client") + assert await a_to_b.receive() == b"a as client" + + nursery.cancel_scope.cancel() + + +@slow +async def test_openssl_retransmit_doesnt_break_stuff(): + # can't use autojump_clock here, because the point of the test is to wait for + # openssl's built-in retransmit timer to expire, which is hard-coded to use + # wall-clock time. 
+ fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + if blackholed: + print("dropped packet", packet) + return + print("delivered packet", packet) + # packets.append( + # scapy.all.IP( + # src=packet.source.ip.compressed, dst=packet.destination.ip.compressed + # ) + # / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port) + # / packet.payload + # ) + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (server_endpoint, address): + with endpoint() as client_endpoint: + async with trio.open_nursery() as nursery: + + async def connecter(): + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake(initial_retransmit_timeout=1.5) + await client.send(b"hi") + assert await client.receive() == b"hi" + + nursery.start_soon(connecter) + + # openssl's default timeout is 1 second, so this ensures that it thinks + # the timeout has expired + await trio.sleep(1.1) + # disable blackholing and send a garbage packet to wake up openssl so it + # notices the timeout has expired + blackholed = False + await server_endpoint.socket.sendto( + b"xxx", client_endpoint.socket.getsockname() + ) + # now the client task should finish connecting and exit cleanly + + # scapy.all.wrpcap("/tmp/trace.pcap", packets) + + +async def test_initial_retransmit_timeout_configuration(autojump_clock): + fn = FakeNet() + fn.enable() + + blackholed = True + + def route_packet(packet): + nonlocal blackholed + if blackholed: + blackholed = False + else: + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + for t in [1, 2, 4]: + with endpoint() as client: + before = trio.current_time() + blackholed = True + channel = client.connect(address, client_ctx) + await channel.do_handshake(initial_retransmit_timeout=t) + after = trio.current_time() + assert after - before == t + + +async def test_explicit_tiny_mtu_is_respected(): + # 
ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to + # be larger than that. (300 is still smaller than any real network though.) + MTU = 300 + + fn = FakeNet() + fn.enable() + + def route_packet(packet): + print(f"delivering {packet}") + print(f"payload size: {len(packet.payload)}") + assert len(packet.payload) <= MTU + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + async with dtls_echo_server(mtu=MTU) as (server, address): + with endpoint() as client: + channel = client.connect(address, client_ctx) + channel.set_ciphertext_mtu(MTU) + await channel.do_handshake() + await channel.send(b"hi") + assert await channel.receive() == b"hi" + + +@parametrize_ipv6 +async def test_handshake_handles_minimum_network_mtu(ipv6, autojump_clock): + # Fake network that has the minimum allowable MTU for whatever protocol we're using. + fn = FakeNet() + fn.enable() + + if ipv6: + mtu = 1280 - 48 + else: + mtu = 576 - 28 + + def route_packet(packet): + if len(packet.payload) > mtu: + print(f"dropping {packet}") + else: + print(f"delivering {packet}") + fn.deliver_packet(packet) + + fn.route_packet = route_packet + + # See if we can successfully do a handshake -- some of the volleys will get dropped, + # and the retransmit logic should detect this and back off the MTU to something + # smaller until it succeeds. 
+ async with dtls_echo_server(ipv6=ipv6) as (_, address): + with endpoint(ipv6=ipv6) as client_endpoint: + client = client_endpoint.connect(address, client_ctx) + # the handshake mtu backoff shouldn't affect the return value from + # get_cleartext_mtu, b/c that's under the user's control via + # set_ciphertext_mtu + client.set_ciphertext_mtu(9999) + await client.send(b"xyz") + assert await client.receive() == b"xyz" + assert client.get_cleartext_mtu() > 9000 # as vegeta said + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_system_task_cleaned_up_on_gc(): + before_tasks = trio.lowlevel.current_statistics().tasks_living + + # We put this into a sub-function so that everything automatically becomes garbage + # when the frame exits. For some reason just doing 'del e' wasn't enough on pypy + # with coverage enabled -- I think we were hitting this bug: + # https://foss.heptapod.net/pypy/pypy/-/issues/3656 + async def start_and_forget_endpoint(): + e = endpoint() + + # This connection/handshake attempt can't succeed. The only purpose is to force + # the endpoint to set up a receive loop. 
+ with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: + await s.bind(("127.0.0.1", 0)) + c = e.connect(s.getsockname(), client_ctx) + async with trio.open_nursery() as nursery: + nursery.start_soon(c.do_handshake) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + during_tasks = trio.lowlevel.current_statistics().tasks_living + return during_tasks + + with pytest.warns(ResourceWarning): + during_tasks = await start_and_forget_endpoint() + await trio.testing.wait_all_tasks_blocked() + gc_collect_harder() + + await trio.testing.wait_all_tasks_blocked() + + after_tasks = trio.lowlevel.current_statistics().tasks_living + assert before_tasks < during_tasks + assert before_tasks == after_tasks + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_gc_before_system_task_starts(): + e = endpoint() + + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + await trio.testing.wait_all_tasks_blocked() + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +async def test_gc_as_packet_received(): + fn = FakeNet() + fn.enable() + + e = endpoint() + await e.socket.bind(("127.0.0.1", 0)) + e._ensure_receive_loop() + + await trio.testing.wait_all_tasks_blocked() + + with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s: + await s.sendto(b"xxx", e.socket.getsockname()) + # At this point, the endpoint's receive loop has been marked runnable because it + # just received a packet; closing the endpoint socket won't interrupt that. But by + # the time it wakes up to process the packet, the endpoint will be gone. + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + +@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning") +def test_gc_after_trio_exits(): + async def main(): + # We use fakenet just to make sure no real sockets can leak out of the test + # case - on pypy somehow the socket was outliving the gc_collect_harder call + # below. 
Since the test is just making sure DTLSEndpoint.__del__ doesn't explode + # when called after trio exits, it doesn't need a real socket. + fn = FakeNet() + fn.enable() + return endpoint() + + e = trio.run(main) + with pytest.warns(ResourceWarning): + del e + gc_collect_harder() + + +async def test_already_closed_socket_doesnt_crash(): + with endpoint() as e: + # We close the socket before checkpointing, so the socket will already be closed + # when the system task starts up + e.socket.close() + # Now give it a chance to start up, and hopefully not crash + await trio.testing.wait_all_tasks_blocked() + + +async def test_socket_closed_while_processing_clienthello(autojump_clock): + fn = FakeNet() + fn.enable() + + # Check what happens if the socket is discovered to be closed when sending a + # HelloVerifyRequest, since that has its own sending logic + async with dtls_echo_server() as (server, address): + + def route_packet(packet): + fn.deliver_packet(packet) + server.socket.close() + + fn.route_packet = route_packet + + with endpoint() as client_endpoint: + with trio.move_on_after(10): + client = client_endpoint.connect(address, client_ctx) + await client.do_handshake() + + +async def test_association_replaced_while_handshake_running(autojump_clock): + fn = FakeNet() + fn.enable() + + def route_packet(packet): + pass + + fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + c1 = client_endpoint.connect(address, client_ctx) + async with trio.open_nursery() as nursery: + + async def doomed_handshake(): + with pytest.raises(trio.BrokenResourceError): + await c1.do_handshake() + + nursery.start_soon(doomed_handshake) + + await trio.sleep(10) + + client_endpoint.connect(address, client_ctx) + + +async def test_association_replaced_before_handshake_starts(): + fn = FakeNet() + fn.enable() + + # This test shouldn't send any packets + def route_packet(packet): # pragma: no cover + assert False + + 
fn.route_packet = route_packet + + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + c1 = client_endpoint.connect(address, client_ctx) + client_endpoint.connect(address, client_ctx) + with pytest.raises(trio.BrokenResourceError): + await c1.do_handshake() + + +async def test_send_to_closed_local_port(): + # On Windows, sending a UDP packet to a closed local port can cause a weird + # ECONNRESET error later, inside the receive task. Make sure we're handling it + # properly. + async with dtls_echo_server() as (_, address): + with endpoint() as client_endpoint: + async with trio.open_nursery() as nursery: + for i in range(1, 10): + channel = client_endpoint.connect(("127.0.0.1", i), client_ctx) + nursery.start_soon(channel.do_handshake) + channel = client_endpoint.connect(address, client_ctx) + await channel.send(b"xxx") + assert await channel.receive() == b"xxx" + nursery.cancel_scope.cancel() diff --git a/trio/_tests/test_exports.py b/trio/_tests/test_exports.py new file mode 100644 index 0000000000..b5d0a44088 --- /dev/null +++ b/trio/_tests/test_exports.py @@ -0,0 +1,499 @@ +import enum +import functools +import importlib +import inspect +import json +import socket as stdlib_socket +import sys +from pathlib import Path +from types import ModuleType + +import attrs +import pytest + +import trio +import trio.testing + +from .. import _core, _util +from .._core._tests.tutil import slow +from .pytest_plugin import RUN_SLOW + +mypy_cache_updated = False + + +def _ensure_mypy_cache_updated(): + # This pollutes the `empty` dir. Should this be changed? 
+ from mypy.api import run + + global mypy_cache_updated + if not mypy_cache_updated: + # mypy cache was *probably* already updated by the other tests, + # but `pytest -k ...` might run just this test on its own + result = run( + [ + "--config-file=", + "--cache-dir=./.mypy_cache", + "--no-error-summary", + "-c", + "import trio", + ] + ) + assert not result[1] # stderr + assert not result[0] # stdout + mypy_cache_updated = True + + +def test_core_is_properly_reexported(): + # Each export from _core should be re-exported by exactly one of these + # three modules: + sources = [trio, trio.lowlevel, trio.testing] + for symbol in dir(_core): + if symbol.startswith("_"): + continue + found = 0 + for source in sources: + if symbol in dir(source) and getattr(source, symbol) is getattr( + _core, symbol + ): + found += 1 + print(symbol, found) + assert found == 1 + + +def public_modules(module): + yield module + for name, class_ in module.__dict__.items(): + if name.startswith("_"): # pragma: no cover + continue + if not isinstance(class_, ModuleType): + continue + if not class_.__name__.startswith(module.__name__): # pragma: no cover + continue + if class_ is module: # pragma: no cover + continue + yield from public_modules(class_) + + +PUBLIC_MODULES = list(public_modules(trio)) +PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES] + + +# It doesn't make sense for downstream redistributors to run this test, since +# they might be using a newer version of Python with additional symbols which +# won't be reflected in trio.socket, and this shouldn't cause downstream test +# runs to start failing. +@pytest.mark.redistributors_should_skip +# Static analysis tools often have trouble with alpha releases, where Python's +# internals are in flux, grammar may not have settled down, etc. 
+@pytest.mark.skipif( + sys.version_info.releaselevel == "alpha", + reason="skip static introspection tools on Python dev/alpha releases", +) +@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES) +@pytest.mark.parametrize("tool", ["pylint", "jedi", "mypy", "pyright_verifytypes"]) +@pytest.mark.filterwarnings( + # https://github.com/pypa/setuptools/issues/3274 + "ignore:module 'sre_constants' is deprecated:DeprecationWarning", +) +def test_static_tool_sees_all_symbols(tool, modname, tmpdir): + module = importlib.import_module(modname) + + def no_underscores(symbols): + return {symbol for symbol in symbols if not symbol.startswith("_")} + + runtime_names = no_underscores(dir(module)) + + # ignore deprecated module `tests` being invisible + if modname == "trio": + runtime_names.discard("tests") + + if tool in ("mypy", "pyright_verifytypes"): + # create py.typed file + py_typed_path = Path(trio.__file__).parent / "py.typed" + py_typed_exists = py_typed_path.exists() + if not py_typed_exists: # pragma: no branch + py_typed_path.write_text("") + + if tool == "pylint": + from pylint.lint import PyLinter + + linter = PyLinter() + ast = linter.get_ast(module.__file__, modname) + static_names = no_underscores(ast) + elif tool == "jedi": + import jedi + + # Simulate typing "import trio; trio." 
+ script = jedi.Script(f"import {modname}; {modname}.") + completions = script.complete() + static_names = no_underscores(c.name for c in completions) + elif tool == "mypy": + if not RUN_SLOW: # pragma: no cover + pytest.skip("use --run-slow to check against mypy") + if sys.implementation.name != "cpython": + pytest.skip("mypy not installed in tests on pypy") + + cache = Path.cwd() / ".mypy_cache" + + _ensure_mypy_cache_updated() + + trio_cache = next(cache.glob("*/trio")) + _, modname = (modname + ".").split(".", 1) + modname = modname[:-1] + mod_cache = trio_cache / modname if modname else trio_cache + if mod_cache.is_dir(): + mod_cache = mod_cache / "__init__.data.json" + else: + mod_cache = trio_cache / (modname + ".data.json") + + assert mod_cache.exists() and mod_cache.is_file() + with mod_cache.open() as cache_file: + cache_json = json.loads(cache_file.read()) + static_names = no_underscores( + key + for key, value in cache_json["names"].items() + if not key.startswith(".") and value["kind"] == "Gdef" + ) + elif tool == "pyright_verifytypes": + if not RUN_SLOW: # pragma: no cover + pytest.skip("use --run-slow to check against mypy") + import subprocess + + res = subprocess.run( + ["pyright", f"--verifytypes={modname}", "--outputjson"], + capture_output=True, + ) + current_result = json.loads(res.stdout) + + static_names = { + x["name"][len(modname) + 1 :] + for x in current_result["typeCompleteness"]["symbols"] + if x["name"].startswith(modname) + } + + # pyright ignores the symbol defined behind `if False` + if modname == "trio": + static_names.add("testing") + + # these are hidden behind `if sys.platform != "win32" or not TYPE_CHECKING` + # so presumably pyright is parsing that if statement, in which case we don't + # care about them being missing. 
+ if modname == "trio.socket" and sys.platform == "win32": + ignored_missing_names = {"if_indextoname", "if_nameindex", "if_nametoindex"} + assert static_names.isdisjoint(ignored_missing_names) + static_names.update(ignored_missing_names) + + else: # pragma: no cover + assert False + + # remove py.typed file + if tool in ("mypy", "pyright_verifytypes") and not py_typed_exists: + py_typed_path.unlink() + + # mypy handles errors with an `assert` in its branch + if tool == "mypy": + return + + # It's expected that the static set will contain more names than the + # runtime set: + # - static tools are sometimes sloppy and include deleted names + # - some symbols are platform-specific at runtime, but always show up in + # static analysis (e.g. in trio.socket or trio.lowlevel) + # So we check that the runtime names are a subset of the static names. + missing_names = runtime_names - static_names + + # ignore warnings about deprecated module tests + missing_names -= {"tests"} + + if missing_names: # pragma: no cover + print(f"{tool} can't see the following names in {modname}:") + print() + for name in sorted(missing_names): + print(f" {name}") + assert False + + +# this could be sped up by only invoking mypy once per module, or even once for all +# modules, instead of once per class. +@slow +# see comment on test_static_tool_sees_all_symbols +@pytest.mark.redistributors_should_skip +# Static analysis tools often have trouble with alpha releases, where Python's +# internals are in flux, grammar may not have settled down, etc. 
+@pytest.mark.skipif( + sys.version_info.releaselevel == "alpha", + reason="skip static introspection tools on Python dev/alpha releases", +) +@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES) +@pytest.mark.parametrize("tool", ["jedi", "mypy"]) +def test_static_tool_sees_class_members(tool, module_name, tmpdir) -> None: + module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)] + + # ignore hidden, but not dunder, symbols + def no_hidden(symbols): + return { + symbol + for symbol in symbols + if (not symbol.startswith("_")) or symbol.startswith("__") + } + + py_typed_path = Path(trio.__file__).parent / "py.typed" + py_typed_exists = py_typed_path.exists() + + if tool == "mypy": + if sys.implementation.name != "cpython": + pytest.skip("mypy not installed in tests on pypy") + # create py.typed file + # remove this logic when trio is marked with py.typed proper + if not py_typed_exists: # pragma: no branch + py_typed_path.write_text("") + + cache = Path.cwd() / ".mypy_cache" + + _ensure_mypy_cache_updated() + + trio_cache = next(cache.glob("*/trio")) + modname = module_name + _, modname = (modname + ".").split(".", 1) + modname = modname[:-1] + mod_cache = trio_cache / modname if modname else trio_cache + if mod_cache.is_dir(): + mod_cache = mod_cache / "__init__.data.json" + else: + mod_cache = trio_cache / (modname + ".data.json") + + assert mod_cache.exists() and mod_cache.is_file() + with mod_cache.open() as cache_file: + cache_json = json.loads(cache_file.read()) + + # skip a bunch of file-system activity (probably can un-memoize?) 
+ @functools.lru_cache + def lookup_symbol(symbol): + topname, *modname, name = symbol.split(".") + version = next(cache.glob("3.*/")) + mod_cache = version / topname + if not mod_cache.is_dir(): + mod_cache = version / (topname + ".data.json") + + if modname: + for piece in modname[:-1]: + mod_cache /= piece + next_cache = mod_cache / modname[-1] + if next_cache.is_dir(): + mod_cache = next_cache / "__init__.data.json" + else: + mod_cache = mod_cache / (modname[-1] + ".data.json") + + with mod_cache.open() as f: + return json.loads(f.read())["names"][name] + + errors: dict[str, object] = {} + for class_name, class_ in module.__dict__.items(): + if not isinstance(class_, type): + continue + if module_name == "trio.socket" and class_name in dir(stdlib_socket): + continue + # Deprecated classes are exported with a leading underscore + # We don't care about errors in _MultiError as that's on its way out anyway + if class_name.startswith("_"): # pragma: no cover + continue + + # dir() and inspect.getmembers doesn't display properties from the metaclass + # also ignore some dunder methods that tend to differ but are of no consequence + ignore_names = set(dir(type(class_))) | { + "__annotations__", + "__attrs_attrs__", + "__attrs_own_setattr__", + "__class_getitem__", + "__getstate__", + "__match_args__", + "__order__", + "__orig_bases__", + "__parameters__", + "__setstate__", + "__slots__", + "__weakref__", + } + + # pypy seems to have some additional dunders that differ + if sys.implementation.name == "pypy": + ignore_names |= { + "__basicsize__", + "__dictoffset__", + "__itemsize__", + "__sizeof__", + "__weakrefoffset__", + "__unicode__", + } + + # inspect.getmembers sees `name` and `value` in Enums, otherwise + # it behaves the same way as `dir` + # runtime_names = no_underscores(dir(class_)) + runtime_names = ( + no_hidden(x[0] for x in inspect.getmembers(class_)) - ignore_names + ) + + if tool == "jedi": + import jedi + + script = jedi.Script( + f"from 
{module_name} import {class_name}; {class_name}." + ) + completions = script.complete() + static_names = no_hidden(c.name for c in completions) - ignore_names + + elif tool == "mypy": + # load the cached type information + cached_type_info = cache_json["names"][class_name] + if "node" not in cached_type_info: + cached_type_info = lookup_symbol(cached_type_info["cross_ref"]) + + assert "node" in cached_type_info + node = cached_type_info["node"] + static_names = no_hidden(k for k in node["names"] if not k.startswith(".")) + for symbol in node["mro"][1:]: + node = lookup_symbol(symbol)["node"] + static_names |= no_hidden( + k for k in node["names"] if not k.startswith(".") + ) + static_names -= ignore_names + + else: # pragma: no cover + assert False, "unknown tool" + + missing = runtime_names - static_names + extra = static_names - runtime_names + + # using .remove() instead of .delete() to get an error in case they start not + # being missing + + if ( + tool == "jedi" + and BaseException in class_.__mro__ + and sys.version_info >= (3, 11) + ): + missing.remove("add_note") + + if ( + tool == "mypy" + and BaseException in class_.__mro__ + and sys.version_info >= (3, 11) + ): + extra.remove("__notes__") + + if tool == "mypy" and attrs.has(class_): + # e.g. __trio__core__run_CancelScope_AttrsAttributes__ + before = len(extra) + extra = {e for e in extra if not e.endswith("AttrsAttributes__")} + assert len(extra) == before - 1 + + # TODO: this *should* be visible via `dir`!! + if tool == "mypy" and class_ == trio.Nursery: + extra.remove("cancel_scope") + + # TODO: I'm not so sure about these, but should still be looked at. 
+ EXTRAS = { + trio.DTLSChannel: {"peer_address", "endpoint"}, + trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"}, + trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"}, + trio.SSLListener: {"transport_listener"}, + trio.SSLStream: {"transport_stream"}, + trio.SocketListener: {"socket"}, + trio.SocketStream: {"socket"}, + trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"}, + trio.testing.MemorySendStream: { + "close_hook", + "send_all_hook", + "wait_send_all_might_not_block_hook", + }, + } + if tool == "mypy" and class_ in EXTRAS: + before = len(extra) + extra -= EXTRAS[class_] + assert len(extra) == before - len(EXTRAS[class_]) + + # probably an issue with mypy.... + if tool == "mypy" and class_ == trio.Path and sys.platform == "win32": + before = len(missing) + missing -= {"owner", "group", "is_mount"} + assert len(missing) == before - 3 + + # TODO: why is this? Is it a problem? + # see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916 + if class_ == trio.StapledStream: + extra.remove("receive_stream") + extra.remove("send_stream") + + # I have not researched why these are missing, should maybe create an issue + # upstream with jedi + if tool == "jedi" and sys.version_info >= (3, 11): + if class_ in ( + trio.DTLSChannel, + trio.MemoryReceiveChannel, + trio.MemorySendChannel, + trio.SSLListener, + trio.SocketListener, + ): + missing.remove("__aenter__") + missing.remove("__aexit__") + if class_ in (trio.DTLSChannel, trio.MemoryReceiveChannel): + missing.remove("__aiter__") + missing.remove("__anext__") + + # intentionally hidden behind type guard + if class_ == trio.Path: + missing.remove("__getattr__") + + if missing or extra: # pragma: no cover + errors[f"{module_name}.{class_name}"] = { + "missing": missing, + "extra": extra, + } + + # clean up created py.typed file + if tool == "mypy" and not py_typed_exists: + py_typed_path.unlink() + + # `assert not errors` will not print the full content of 
errors, even with + # `--verbose`, so we manually print it + if errors: # pragma: no cover + from pprint import pprint + + print(f"\n{tool} can't see the following symbols in {module_name}:") + pprint(errors) + assert not errors + + +def test_classes_are_final(): + for module in PUBLIC_MODULES: + for name, class_ in module.__dict__.items(): + if not isinstance(class_, type): + continue + # Deprecated classes are exported with a leading underscore + if name.startswith("_"): # pragma: no cover + continue + + # Abstract classes can be subclassed, because that's the whole + # point of ABCs + if inspect.isabstract(class_): + continue + # Exceptions are allowed to be subclassed, because exception + # subclassing isn't used to inherit behavior. + if issubclass(class_, BaseException): + continue + # These are classes that are conceptually abstract, but + # inspect.isabstract returns False for boring reasons. + if class_ in {trio.abc.Instrument, trio.socket.SocketType}: + continue + # Enums have their own metaclass, so we can't use our metaclasses. + # And I don't think there's a lot of risk from people subclassing + # enums... + if issubclass(class_, enum.Enum): + continue + # ... insert other special cases here ... 
+ + # don't care about the *Statistics classes + if name.endswith("Statistics"): + continue + + assert isinstance(class_, _util.Final) diff --git a/trio/_tests/test_fakenet.py b/trio/_tests/test_fakenet.py new file mode 100644 index 0000000000..bc691c9db5 --- /dev/null +++ b/trio/_tests/test_fakenet.py @@ -0,0 +1,44 @@ +import pytest + +import trio +from trio.testing._fake_net import FakeNet + + +def fn(): + fn = FakeNet() + fn.enable() + return fn + + +async def test_basic_udp(): + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + await s1.bind(("127.0.0.1", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + assert data == b"xyz" + assert addr == s2.getsockname() + await s1.sendto(b"abc", s2.getsockname()) + data, addr = await s2.recvfrom(10) + assert data == b"abc" + assert addr == s1.getsockname() + + +async def test_msg_trunc(): + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s1.bind(("127.0.0.1", 0)) + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + + +async def test_basic_tcp(): + fn() + with pytest.raises(NotImplementedError): + trio.socket.socket() diff --git a/trio/tests/test_file_io.py b/trio/_tests/test_file_io.py similarity index 63% rename from trio/tests/test_file_io.py rename to trio/_tests/test_file_io.py index 8f96e75aac..bae426cf48 100644 --- a/trio/tests/test_file_io.py +++ b/trio/_tests/test_file_io.py @@ -1,18 +1,21 @@ +import importlib import io - -import pytest +import os +import re +from typing import List, Tuple from unittest import mock from unittest.mock import sentinel +import pytest + import trio -from trio import _core -from trio._util import fspath -from trio._file_io import AsyncIOWrapper, _FILE_SYNC_ATTRS, _FILE_ASYNC_METHODS +from trio import 
_core, _file_io +from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper @pytest.fixture def path(tmpdir): - return fspath(tmpdir.join('test')) + return os.fspath(tmpdir.join("test")) @pytest.fixture @@ -27,7 +30,7 @@ def async_file(wrapped): def test_wrap_invalid(): with pytest.raises(TypeError): - trio.wrap_file(str()) + trio.wrap_file("") def test_wrap_non_iobase(): @@ -58,9 +61,7 @@ def test_dir_matches_wrapped(async_file, wrapped): attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS) # all supported attrs in wrapped should be available in async_file - assert all( - attr in dir(async_file) for attr in attrs if attr in dir(wrapped) - ) + assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped)) # all supported attrs not in wrapped should not be available in async_file assert not any( attr in dir(async_file) for attr in attrs if attr not in dir(wrapped) @@ -74,10 +75,50 @@ def unsupported_attr(self): # pragma: no cover async_file = trio.wrap_file(FakeFile()) - assert hasattr(async_file.wrapped, 'unsupported_attr') + assert hasattr(async_file.wrapped, "unsupported_attr") with pytest.raises(AttributeError): - getattr(async_file, 'unsupported_attr') + getattr(async_file, "unsupported_attr") + + +def test_type_stubs_match_lists() -> None: + """Check the manual stubs match the list of wrapped methods.""" + # Fetch the module's source code. + assert _file_io.__spec__ is not None + loader = _file_io.__spec__.loader + assert isinstance(loader, importlib.abc.SourceLoader) + source = io.StringIO(loader.get_source("trio._file_io")) + + # Find the class, then find the TYPE_CHECKING block. 
+ for line in source: + if "class AsyncIOWrapper" in line: + break + else: # pragma: no cover - should always find this + pytest.fail("No class definition line?") + + for line in source: + if "if TYPE_CHECKING" in line: + break + else: # pragma: no cover - should always find this + pytest.fail("No TYPE CHECKING line?") + + # Now we should be at the type checking block. + found: List[Tuple[str, str]] = [] + for line in source: # pragma: no branch - expected to break early + if line.strip() and not line.startswith(" " * 8): + break # Dedented out of the if TYPE_CHECKING block. + match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line) + if match is not None: + kind = "async" if match.group(1) is not None else "sync" + found.append((match.group(2), kind)) + + # Compare two lists so that we can easily see duplicates, and see what is different overall. + expected = [(fname, "async") for fname in _FILE_ASYNC_METHODS] + expected += [(fname, "sync") for fname in _FILE_SYNC_ATTRS] + # Ignore order, error if duplicates are present. 
+ found.sort() + expected.sort() + assert found == expected def test_sync_attrs_forwarded(async_file, wrapped): @@ -110,10 +151,10 @@ def test_async_methods_generated_once(async_file): def test_async_methods_signature(async_file): # use read as a representative of all async methods - assert async_file.read.__name__ == 'read' - assert async_file.read.__qualname__ == 'AsyncIOWrapper.read' + assert async_file.read.__name__ == "read" + assert async_file.read.__qualname__ == "AsyncIOWrapper.read" - assert 'io.StringIO.read' in async_file.read.__doc__ + assert "io.StringIO.read" in async_file.read.__doc__ async def test_async_methods_wrap(async_file, wrapped): @@ -147,7 +188,7 @@ async def test_async_methods_match_wrapper(async_file, wrapped): async def test_open(path): - f = await trio.open_file(path, 'w') + f = await trio.open_file(path, "w") assert isinstance(f, AsyncIOWrapper) @@ -155,7 +196,7 @@ async def test_open(path): async def test_open_context_manager(path): - async with await trio.open_file(path, 'w') as f: + async with await trio.open_file(path, "w") as f: assert isinstance(f, AsyncIOWrapper) assert not f.closed @@ -163,7 +204,7 @@ async def test_open_context_manager(path): async def test_async_iter(): - async_file = trio.wrap_file(io.StringIO('test\nfoo\nbar')) + async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar")) expected = list(async_file.wrapped) result = [] async_file.wrapped.seek(0) @@ -176,11 +217,11 @@ async def test_async_iter(): async def test_aclose_cancelled(path): with _core.CancelScope() as cscope: - f = await trio.open_file(path, 'w') + f = await trio.open_file(path, "w") cscope.cancel() with pytest.raises(_core.Cancelled): - await f.write('a') + await f.write("a") with pytest.raises(_core.Cancelled): await f.aclose() diff --git a/trio/tests/test_highlevel_generic.py b/trio/_tests/test_highlevel_generic.py similarity index 98% rename from trio/tests/test_highlevel_generic.py rename to trio/_tests/test_highlevel_generic.py index 
df2b2cecf7..38bcedee25 100644 --- a/trio/tests/test_highlevel_generic.py +++ b/trio/_tests/test_highlevel_generic.py @@ -1,9 +1,8 @@ -import pytest - import attr +import pytest -from ..abc import SendStream, ReceiveStream from .._highlevel_generic import StapledStream +from ..abc import ReceiveStream, SendStream @attr.s diff --git a/trio/tests/test_highlevel_open_tcp_listeners.py b/trio/_tests/test_highlevel_open_tcp_listeners.py similarity index 68% rename from trio/tests/test_highlevel_open_tcp_listeners.py rename to trio/_tests/test_highlevel_open_tcp_listeners.py index 2bee358ee5..e58cbd13cc 100644 --- a/trio/tests/test_highlevel_open_tcp_listeners.py +++ b/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -1,17 +1,19 @@ -import pytest - -import socket as stdlib_socket import errno +import socket as stdlib_socket +import sys import attr +import pytest import trio -from trio import ( - open_tcp_listeners, serve_tcp, SocketListener, open_tcp_stream -) +from trio import SocketListener, open_tcp_listeners, open_tcp_stream, serve_tcp from trio.testing import open_stream_to_socket_listener + from .. import socket as tsocket -from .._core.tests.tutil import slow, creates_ipv6, binds_ipv6 +from .._core._tests.tutil import binds_ipv6 + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup async def test_open_tcp_listeners_basic(): @@ -54,67 +56,15 @@ async def test_open_tcp_listeners_specific_port_specific_host(): assert listener.socket.getsockname() == (host, port) -# Warning: this sleeps, and needs to use a real sleep -- MockClock won't -# work. -# -# Also, this measurement technique often works, but not always: sometimes SYN -# cookies get triggered, and then the backlog measured this way effectively -# becomes infinite. (In particular, this has been observed happening on -# Travis-CI.) 
To avoid this blowing up and eating all FDs / ephemeral ports, -# we put an upper limit on the number of connections we attempt, and if we hit -# it then we return the magic string "lots". Then -# test_open_tcp_listeners_backlog uses a special path to handle this, treating -# it as a success -- but at least we'll see in coverage if none of our test -# runs are actually running the test properly. -async def measure_backlog(listener, limit): - client_streams = [] - try: - while True: - # Generally the response to the listen buffer being full is that - # the SYN gets dropped, and the client retries after 1 second. So - # we assume that any connect() call to localhost that takes >0.5 - # seconds indicates a dropped SYN. - with trio.move_on_after(0.5) as cancel_scope: - client_stream = await open_stream_to_socket_listener(listener) - client_streams.append(client_stream) - if cancel_scope.cancelled_caught: - break - if len(client_streams) >= limit: # pragma: no cover - return "lots" - finally: - # The need for "no cover" here is subtle: see - # https://github.com/python-trio/trio/issues/522 - for client_stream in client_streams: # pragma: no cover - await client_stream.aclose() - - return len(client_streams) - - -@slow -async def test_open_tcp_listeners_backlog(): - # Operating systems don't necessarily use the exact backlog you pass - async def check_backlog(nominal, required_min, required_max): - listeners = await open_tcp_listeners(0, backlog=nominal) - actual = await measure_backlog(listeners[0], required_max + 10) - for listener in listeners: - await listener.aclose() - print("nominal", nominal, "actual", actual) - if actual == "lots": # pragma: no cover - return - assert required_min <= actual <= required_max - - await check_backlog(nominal=1, required_min=1, required_max=10) - await check_backlog(nominal=11, required_min=11, required_max=20) - - @binds_ipv6 async def test_open_tcp_listeners_ipv6_v6only(): # Check IPV6_V6ONLY is working properly (ipv6_listener,) = 
await open_tcp_listeners(0, host="::1") - _, port, *_ = ipv6_listener.socket.getsockname() + async with ipv6_listener: + _, port, *_ = ipv6_listener.socket.getsockname() - with pytest.raises(OSError): - await open_tcp_stream("127.0.0.1", port) + with pytest.raises(OSError): + await open_tcp_stream("127.0.0.1", port) async def test_open_tcp_listeners_rebind(): @@ -123,10 +73,10 @@ async def test_open_tcp_listeners_rebind(): # Plain old rebinding while it's still there should fail, even if we have # SO_REUSEADDR set - probe = stdlib_socket.socket() - probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1) - with pytest.raises(OSError): - probe.bind(sockaddr1) + with stdlib_socket.socket() as probe: + probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1) + with pytest.raises(OSError): + probe.bind(sockaddr1) # Now use the first listener to set up some connections in various states, # and make sure that they don't create any obstacle to rebinding a second @@ -170,6 +120,7 @@ class FakeSocket(tsocket.SocketType): closed = attr.ib(default=False) poison_listen = attr.ib(default=False) + backlog = attr.ib(default=None) def getsockopt(self, level, option): if (level, option) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN): @@ -183,6 +134,9 @@ async def bind(self, sockaddr): pass def listen(self, backlog): + assert self.backlog is None + assert backlog is not None + self.backlog = backlog if self.poison_listen: raise FakeOSError("whoops") @@ -264,28 +218,18 @@ async def handler(stream): @pytest.mark.parametrize( - "try_families", [ - {tsocket.AF_INET}, - {tsocket.AF_INET6}, - {tsocket.AF_INET, tsocket.AF_INET6}, - ] + "try_families", + [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}], ) @pytest.mark.parametrize( - "fail_families", [ - {tsocket.AF_INET}, - {tsocket.AF_INET6}, - {tsocket.AF_INET, tsocket.AF_INET6}, - ] + "fail_families", + [{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, 
tsocket.AF_INET6}], ) async def test_open_tcp_listeners_some_address_families_unavailable( try_families, fail_families ): fsf = FakeSocketFactory( - 10, - raise_on_family={ - family: errno.EAFNOSUPPORT - for family in fail_families - } + 10, raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families} ) tsocket.set_custom_socket_factory(fsf) tsocket.set_custom_hostname_resolver( @@ -299,7 +243,7 @@ async def test_open_tcp_listeners_some_address_families_unavailable( await open_tcp_listeners(80, host="example.org") assert "This system doesn't support" in str(exc_info.value) - if isinstance(exc_info.value.__cause__, trio.MultiError): + if isinstance(exc_info.value.__cause__, BaseExceptionGroup): for subexc in exc_info.value.__cause__.exceptions: assert "nope" in str(subexc) else: @@ -318,13 +262,11 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport(): raise_on_family={ tsocket.AF_INET: errno.EAFNOSUPPORT, tsocket.AF_INET6: errno.EINVAL, - } + }, ) tsocket.set_custom_socket_factory(fsf) tsocket.set_custom_hostname_resolver( - FakeHostnameResolver( - [(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")] - ) + FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]) ) with pytest.raises(OSError) as exc_info: @@ -332,3 +274,26 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport(): assert exc_info.value.errno == errno.EINVAL assert exc_info.value.__cause__ is None assert "nope" in str(exc_info.value) + + +# We used to have an elaborate test that opened a real TCP listening socket +# and then tried to measure its backlog by making connections to it. And most +# of the time, it worked. But no matter what we tried, it was always fragile, +# because it had to do things like use timeouts to guess when the listening +# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there +# effectively is no backlog), sometimes the host might not be enough resources +# to give us the full requested backlog... 
it was a mess. So now we just check +# that the backlog argument is passed through correctly. +async def test_open_tcp_listeners_backlog(): + fsf = FakeSocketFactory(99) + tsocket.set_custom_socket_factory(fsf) + for given, expected in [ + (None, 0xFFFF), + (99999999, 0xFFFF), + (10, 10), + (1, 1), + ]: + listeners = await open_tcp_listeners(0, backlog=given) + assert listeners + for listener in listeners: + assert listener.socket.backlog == expected diff --git a/trio/tests/test_highlevel_open_tcp_stream.py b/trio/_tests/test_highlevel_open_tcp_stream.py similarity index 83% rename from trio/tests/test_highlevel_open_tcp_stream.py rename to trio/_tests/test_highlevel_open_tcp_stream.py index 0c28450e5e..24f82bddd5 100644 --- a/trio/tests/test_highlevel_open_tcp_stream.py +++ b/trio/_tests/test_highlevel_open_tcp_stream.py @@ -1,15 +1,20 @@ -import pytest +import socket +import sys import attr +import pytest import trio -from trio.socket import AF_INET, AF_INET6, SOCK_STREAM, IPPROTO_TCP from trio._highlevel_open_tcp_stream import ( - reorder_for_rfc_6555_section_5_4, close_all, - open_tcp_stream, format_host_port, + open_tcp_stream, + reorder_for_rfc_6555_section_5_4, ) +from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup def test_close_all(): @@ -46,14 +51,18 @@ def close(self): def test_reorder_for_rfc_6555_section_5_4(): def fake4(i): return ( - AF_INET, SOCK_STREAM, IPPROTO_TCP, "", ("10.0.0.{}".format(i), 80) + AF_INET, + SOCK_STREAM, + IPPROTO_TCP, + "", + (f"10.0.0.{i}", 80), ) def fake6(i): - return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", ("::{}".format(i), 80)) + return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80)) for fake in fake4, fake6: - # No effect on homogenous lists + # No effect on homogeneous lists targets = [fake(0), fake(1), fake(2)] reorder_for_rfc_6555_section_5_4(targets) assert targets == [fake(0), fake(1), fake(2)] @@ -108,6 
+117,62 @@ async def test_open_tcp_stream_input_validation(): await open_tcp_stream("127.0.0.1", b"80") +def can_bind_127_0_0_2(): + with socket.socket() as s: + try: + s.bind(("127.0.0.2", 0)) + except OSError: + return False + return s.getsockname()[0] == "127.0.0.2" + + +async def test_local_address_real(): + with trio.socket.socket() as listener: + await listener.bind(("127.0.0.1", 0)) + listener.listen() + + # It's hard to test local_address properly, because you need multiple + # local addresses that you can bind to. Fortunately, on most Linux + # systems, you can bind to any 127.*.*.* address, and they all go + # through the loopback interface. So we can use a non-standard + # loopback address. On other systems, the only address we know for + # certain we have is 127.0.0.1, so we can't really test local_address= + # properly -- passing local_address=127.0.0.1 is indistinguishable + # from not passing local_address= at all. But, we can still do a smoke + # test to make sure the local_address= code doesn't crash. 
+ if can_bind_127_0_0_2(): + local_address = "127.0.0.2" + else: + local_address = "127.0.0.1" + + async with await open_tcp_stream( + *listener.getsockname(), local_address=local_address + ) as client_stream: + assert client_stream.socket.getsockname()[0] == local_address + if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"): + assert client_stream.socket.getsockopt( + trio.socket.IPPROTO_IP, trio.socket.IP_BIND_ADDRESS_NO_PORT + ) + + server_sock, remote_addr = await listener.accept() + await client_stream.aclose() + server_sock.close() + assert remote_addr[0] == local_address + + # Trying to connect to an ipv4 address with the ipv6 wildcard + # local_address should fail + with pytest.raises(OSError): + await open_tcp_stream(*listener.getsockname(), local_address="::") + + # But the ipv4 wildcard address should work + async with await open_tcp_stream( + *listener.getsockname(), local_address="0.0.0.0" + ) as client_stream: + server_sock, remote_addr = await listener.accept() + server_sock.close() + assert remote_addr == client_stream.socket.getsockname() + + # Now, thorough tests using fake sockets @@ -225,7 +290,7 @@ async def run_scenario( # If this is True, we require there to be an exception, and return # (exception, scenario object) expect_error=(), - **kwargs + **kwargs, ): supported_families = set() if ipv4_supported: @@ -278,8 +343,7 @@ async def test_one_host_slow_fail(autojump_clock): async def test_one_host_failed_after_connect(autojump_clock): exc, scenario = await run_scenario( - 83, [("1.2.3.4", 1, "postconnect_fail")], - expect_error=KeyboardInterrupt + 83, [("1.2.3.4", 1, "postconnect_fail")], expect_error=KeyboardInterrupt ) assert isinstance(exc, KeyboardInterrupt) @@ -375,7 +439,7 @@ async def test_all_fail(autojump_clock): expect_error=OSError, ) assert isinstance(exc, OSError) - assert isinstance(exc.__cause__, trio.MultiError) + assert isinstance(exc.__cause__, BaseExceptionGroup) assert len(exc.__cause__.exceptions) == 4 assert 
trio.current_time() == (0.1 + 0.2 + 10) assert scenario.connect_times == { @@ -495,7 +559,7 @@ async def test_cancel(autojump_clock): ("3.3.3.3", 10, "success"), ("4.4.4.4", 10, "success"), ], - expect_error=trio.MultiError, + expect_error=BaseExceptionGroup, ) # What comes out should be 1 or more Cancelled errors that all belong # to this cancel_scope; this is the easiest way to check that diff --git a/trio/tests/test_highlevel_open_unix_stream.py b/trio/_tests/test_highlevel_open_unix_stream.py similarity index 91% rename from trio/tests/test_highlevel_open_unix_stream.py rename to trio/_tests/test_highlevel_open_unix_stream.py index 872a43dd6d..64a15f9e9d 100644 --- a/trio/tests/test_highlevel_open_unix_stream.py +++ b/trio/_tests/test_highlevel_open_unix_stream.py @@ -4,10 +4,8 @@ import pytest -from trio import open_unix_socket, Path -from trio._highlevel_open_unix_stream import ( - close_on_error, -) +from trio import Path, open_unix_socket +from trio._highlevel_open_unix_stream import close_on_error if not hasattr(socket, "AF_UNIX"): pytestmark = pytest.mark.skip("Needs unix socket support") @@ -30,7 +28,7 @@ def close(self): assert c.closed -@pytest.mark.parametrize('filename', [4, 4.5]) +@pytest.mark.parametrize("filename", [4, 4.5]) async def test_open_with_bad_filename_type(filename): with pytest.raises(TypeError): await open_unix_socket(filename) diff --git a/trio/tests/test_highlevel_serve_listeners.py b/trio/_tests/test_highlevel_serve_listeners.py similarity index 95% rename from trio/tests/test_highlevel_serve_listeners.py rename to trio/_tests/test_highlevel_serve_listeners.py index ead37cac87..4385263899 100644 --- a/trio/tests/test_highlevel_serve_listeners.py +++ b/trio/_tests/test_highlevel_serve_listeners.py @@ -1,9 +1,8 @@ -import pytest - -from functools import partial import errno +from functools import partial import attr +import pytest import trio from trio.testing import memory_stream_pair, wait_all_tasks_blocked @@ -23,7 +22,7 @@ async 
def connect(self): return client async def accept(self): - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() assert not self.closed if self.accept_hook is not None: await self.accept_hook() @@ -33,7 +32,7 @@ async def accept(self): async def aclose(self): self.closed = True - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() async def test_serve_listeners_basic(): @@ -136,8 +135,9 @@ async def connection_watcher(*, task_status=trio.TASK_STATUS_IGNORED): await nursery.start( partial( trio.serve_listeners, - handler, [listener], - handler_nursery=handler_nursery + handler, + [listener], + handler_nursery=handler_nursery, ) ) for _ in range(10): diff --git a/trio/tests/test_highlevel_socket.py b/trio/_tests/test_highlevel_socket.py similarity index 96% rename from trio/tests/test_highlevel_socket.py rename to trio/_tests/test_highlevel_socket.py index dc19219e3e..14143affe2 100644 --- a/trio/tests/test_highlevel_socket.py +++ b/trio/_tests/test_highlevel_socket.py @@ -1,15 +1,16 @@ -import pytest - -import sys -import socket as stdlib_socket import errno +import socket as stdlib_socket +import sys + +import pytest -from .. import _core +from .. import _core, socket as tsocket +from .._highlevel_socket import * from ..testing import ( - check_half_closeable_stream, wait_all_tasks_blocked, assert_checkpoints + assert_checkpoints, + check_half_closeable_stream, + wait_all_tasks_blocked, ) -from .._highlevel_socket import * -from .. 
import socket as tsocket async def test_SocketStream_basics(): @@ -257,8 +258,9 @@ async def accept(self): async def test_socket_stream_works_when_peer_has_already_closed(): sock_a, sock_b = tsocket.socketpair() - await sock_b.send(b"x") - sock_b.close() - stream = SocketStream(sock_a) - assert await stream.receive_some(1) == b"x" - assert await stream.receive_some(1) == b"" + with sock_a, sock_b: + await sock_b.send(b"x") + sock_b.close() + stream = SocketStream(sock_a) + assert await stream.receive_some(1) == b"x" + assert await stream.receive_some(1) == b"" diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/trio/_tests/test_highlevel_ssl_helpers.py new file mode 100644 index 0000000000..f6eda0b578 --- /dev/null +++ b/trio/_tests/test_highlevel_ssl_helpers.py @@ -0,0 +1,114 @@ +from functools import partial + +import attr +import pytest + +import trio +import trio.testing +from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM + +from .._highlevel_ssl_helpers import ( + open_ssl_over_tcp_listeners, + open_ssl_over_tcp_stream, + serve_ssl_over_tcp, +) + +# noqa is needed because flake8 doesn't understand how pytest fixtures work. +from .test_ssl import SERVER_CTX, client_ctx # noqa: F401 + + +async def echo_handler(stream): + async with stream: + try: + while True: + data = await stream.receive_some(10000) + if not data: + break + await stream.send_all(data) + except trio.BrokenResourceError: + pass + + +# Resolver that always returns the given sockaddr, no matter what host/port +# you ask for. +@attr.s +class FakeHostnameResolver(trio.abc.HostnameResolver): + sockaddr = attr.ib() + + async def getaddrinfo(self, *args): + return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)] + + async def getnameinfo(self, *args): # pragma: no cover + raise NotImplementedError + + +# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... +# noqa is needed because flake8 doesn't understand how pytest fixtures work. 
+async def test_open_ssl_over_tcp_stream_and_everything_else(client_ctx): # noqa: F811 + async with trio.open_nursery() as nursery: + (listener,) = await nursery.start( + partial(serve_ssl_over_tcp, echo_handler, 0, SERVER_CTX, host="127.0.0.1") + ) + async with listener: + sockaddr = listener.transport_listener.socket.getsockname() + hostname_resolver = FakeHostnameResolver(sockaddr) + trio.socket.set_custom_hostname_resolver(hostname_resolver) + + # We don't have the right trust set up + # (checks that ssl_context=None is doing some validation) + stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80) + async with stream: + with pytest.raises(trio.BrokenResourceError): + await stream.do_handshake() + + # We have the trust but not the hostname + # (checks custom ssl_context + hostname checking) + stream = await open_ssl_over_tcp_stream( + "xyzzy.example.org", 80, ssl_context=client_ctx + ) + async with stream: + with pytest.raises(trio.BrokenResourceError): + await stream.do_handshake() + + # This one should work! 
+ stream = await open_ssl_over_tcp_stream( + "trio-test-1.example.org", 80, ssl_context=client_ctx + ) + async with stream: + assert isinstance(stream, trio.SSLStream) + assert stream.server_hostname == "trio-test-1.example.org" + await stream.send_all(b"x") + assert await stream.receive_some(1) == b"x" + + # Check https_compatible settings are being passed through + assert not stream._https_compatible + stream = await open_ssl_over_tcp_stream( + "trio-test-1.example.org", + 80, + ssl_context=client_ctx, + https_compatible=True, + # also, smoke test happy_eyeballs_delay + happy_eyeballs_delay=1, + ) + async with stream: + assert stream._https_compatible + + # Stop the echo server + nursery.cancel_scope.cancel() + + +async def test_open_ssl_over_tcp_listeners(): + (listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1") + async with listener: + assert isinstance(listener, trio.SSLListener) + tl = listener.transport_listener + assert isinstance(tl, trio.SocketListener) + assert tl.socket.getsockname()[0] == "127.0.0.1" + + assert not listener._https_compatible + + (listener,) = await open_ssl_over_tcp_listeners( + 0, SERVER_CTX, host="127.0.0.1", https_compatible=True + ) + async with listener: + assert listener._https_compatible diff --git a/trio/tests/test_path.py b/trio/_tests/test_path.py similarity index 61% rename from trio/tests/test_path.py rename to trio/_tests/test_path.py index 67d7c2957e..bfef1aaf2c 100644 --- a/trio/tests/test_path.py +++ b/trio/_tests/test_path.py @@ -4,14 +4,13 @@ import pytest import trio -from trio._path import AsyncAutoWrapperType as Type -from trio._util import fspath from trio._file_io import AsyncIOWrapper +from trio._path import AsyncAutoWrapperType as Type @pytest.fixture def path(tmpdir): - p = str(tmpdir.join('test')) + p = str(tmpdir.join("test")) return trio.Path(p) @@ -22,32 +21,33 @@ def method_pair(path, method_name): async def test_open_is_async_context_manager(path): - async with await 
path.open('w') as f: + async with await path.open("w") as f: assert isinstance(f, AsyncIOWrapper) assert f.closed async def test_magic(): - path = trio.Path('test') + path = trio.Path("test") - assert str(path) == 'test' - assert bytes(path) == b'test' + assert str(path) == "test" + assert bytes(path) == b"test" cls_pairs = [ - (trio.Path, pathlib.Path), (pathlib.Path, trio.Path), - (trio.Path, trio.Path) + (trio.Path, pathlib.Path), + (pathlib.Path, trio.Path), + (trio.Path, trio.Path), ] -@pytest.mark.parametrize('cls_a,cls_b', cls_pairs) +@pytest.mark.parametrize("cls_a,cls_b", cls_pairs) async def test_cmp_magic(cls_a, cls_b): - a, b = cls_a(''), cls_b('') + a, b = cls_a(""), cls_b("") assert a == b assert not a != b - a, b = cls_a('a'), cls_b('b') + a, b = cls_a("a"), cls_b("b") assert a < b assert b > a @@ -57,28 +57,30 @@ async def test_cmp_magic(cls_a, cls_b): assert not b == None # noqa -# upstream python3.5 bug: we should also test (pathlib.Path, trio.Path), but +# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but # __*div__ does not properly raise NotImplementedError like the other comparison # magic, so trio.Path's implementation does not get dispatched cls_pairs = [ - (trio.Path, pathlib.Path), (trio.Path, trio.Path), (trio.Path, str), - (str, trio.Path) + (trio.Path, pathlib.Path), + (trio.Path, trio.Path), + (trio.Path, str), + (str, trio.Path), ] -@pytest.mark.parametrize('cls_a,cls_b', cls_pairs) +@pytest.mark.parametrize("cls_a,cls_b", cls_pairs) async def test_div_magic(cls_a, cls_b): - a, b = cls_a('a'), cls_b('b') + a, b = cls_a("a"), cls_b("b") result = a / b assert isinstance(result, trio.Path) - assert str(result) == os.path.join('a', 'b') + assert str(result) == os.path.join("a", "b") @pytest.mark.parametrize( - 'cls_a,cls_b', [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] + "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] ) -@pytest.mark.parametrize('path', ["foo", "foo/bar/baz", "./foo"]) 
+@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) async def test_hash_magic(cls_a, cls_b, path): a, b = cls_a(path), cls_b(path) assert hash(a) == hash(b) @@ -87,23 +89,22 @@ async def test_hash_magic(cls_a, cls_b, path): async def test_forwarded_properties(path): # use `name` as a representative of forwarded properties - assert 'name' in dir(path) - assert path.name == 'test' + assert "name" in dir(path) + assert path.name == "test" async def test_async_method_signature(path): # use `resolve` as a representative of wrapped methods - assert path.resolve.__name__ == 'resolve' - assert path.resolve.__qualname__ == 'Path.resolve' + assert path.resolve.__name__ == "resolve" + assert path.resolve.__qualname__ == "Path.resolve" - assert 'pathlib.Path.resolve' in path.resolve.__doc__ + assert "pathlib.Path.resolve" in path.resolve.__doc__ -@pytest.mark.parametrize('method_name', ['is_dir', 'is_file']) +@pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) async def test_compare_async_stat_methods(method_name): - - method, async_method = method_pair('.', method_name) + method, async_method = method_pair(".", method_name) result = method() async_result = await async_method() @@ -113,13 +114,12 @@ async def test_compare_async_stat_methods(method_name): async def test_invalid_name_not_wrapped(path): with pytest.raises(AttributeError): - getattr(path, 'invalid_fake_attr') + getattr(path, "invalid_fake_attr") -@pytest.mark.parametrize('method_name', ['absolute', 'resolve']) +@pytest.mark.parametrize("method_name", ["absolute", "resolve"]) async def test_async_methods_rewrap(method_name): - - method, async_method = method_pair('.', method_name) + method, async_method = method_pair(".", method_name) result = method() async_result = await async_method() @@ -129,13 +129,13 @@ async def test_async_methods_rewrap(method_name): async def test_forward_methods_rewrap(path, tmpdir): - with_name = path.with_name('foo') - with_suffix = path.with_suffix('.py') + 
with_name = path.with_name("foo") + with_suffix = path.with_suffix(".py") assert isinstance(with_name, trio.Path) - assert with_name == tmpdir.join('foo') + assert with_name == tmpdir.join("foo") assert isinstance(with_suffix, trio.Path) - assert with_suffix == tmpdir.join('test.py') + assert with_suffix == tmpdir.join("test.py") async def test_forward_properties_rewrap(path): @@ -145,18 +145,18 @@ async def test_forward_properties_rewrap(path): async def test_forward_methods_without_rewrap(path, tmpdir): path = await path.parent.resolve() - assert path.as_uri().startswith('file:///') + assert path.as_uri().startswith("file:///") async def test_repr(): - path = trio.Path('.') + path = trio.Path(".") assert repr(path) == "trio.Path('.')" class MockWrapped: - unsupported = 'unsupported' - _private = 'private' + unsupported = "unsupported" + _private = "private" class MockWrapper: @@ -175,18 +175,18 @@ async def test_type_wraps_unsupported(): async def test_type_forwards_private(): - Type.generate_forwards(MockWrapper, {'unsupported': None}) + Type.generate_forwards(MockWrapper, {"unsupported": None}) - assert not hasattr(MockWrapper, '_private') + assert not hasattr(MockWrapper, "_private") async def test_type_wraps_private(): - Type.generate_wraps(MockWrapper, {'unsupported': None}) + Type.generate_wraps(MockWrapper, {"unsupported": None}) - assert not hasattr(MockWrapper, '_private') + assert not hasattr(MockWrapper, "_private") -@pytest.mark.parametrize('meth', [trio.Path.__init__, trio.Path.joinpath]) +@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath]) async def test_path_wraps_path(path, meth): wrapped = await path.absolute() result = meth(path, wrapped) @@ -202,22 +202,22 @@ async def test_path_nonpath(): async def test_open_file_can_open_path(path): - async with await trio.open_file(path, 'w') as f: - assert f.name == fspath(path) + async with await trio.open_file(path, "w") as f: + assert f.name == os.fspath(path) async def 
test_globmethods(path): # Populate a directory tree await path.mkdir() - await (path / 'foo').mkdir() - await (path / 'foo' / '_bar.txt').write_bytes(b'') - await (path / 'bar.txt').write_bytes(b'') - await (path / 'bar.dat').write_bytes(b'') + await (path / "foo").mkdir() + await (path / "foo" / "_bar.txt").write_bytes(b"") + await (path / "bar.txt").write_bytes(b"") + await (path / "bar.dat").write_bytes(b"") # Path.glob for _pattern, _results in { - '*.txt': {'bar.txt'}, - '**/*.txt': {'_bar.txt', 'bar.txt'}, + "*.txt": {"bar.txt"}, + "**/*.txt": {"_bar.txt", "bar.txt"}, }.items(): entries = set() for entry in await path.glob(_pattern): @@ -228,32 +228,32 @@ async def test_globmethods(path): # Path.rglob entries = set() - for entry in await path.rglob('*.txt'): + for entry in await path.rglob("*.txt"): assert isinstance(entry, trio.Path) entries.add(entry.name) - assert entries == {'_bar.txt', 'bar.txt'} + assert entries == {"_bar.txt", "bar.txt"} async def test_iterdir(path): # Populate a directory await path.mkdir() - await (path / 'foo').mkdir() - await (path / 'bar.txt').write_bytes(b'') + await (path / "foo").mkdir() + await (path / "bar.txt").write_bytes(b"") entries = set() for entry in await path.iterdir(): assert isinstance(entry, trio.Path) entries.add(entry.name) - assert entries == {'bar.txt', 'foo'} + assert entries == {"bar.txt", "foo"} async def test_classmethods(): assert isinstance(await trio.Path.home(), trio.Path) # pathlib.Path has only two classmethods - assert str(await trio.Path.home()) == os.path.expanduser('~') + assert str(await trio.Path.home()) == os.path.expanduser("~") assert str(await trio.Path.cwd()) == os.getcwd() # Wrapped method has docstring diff --git a/trio/tests/test_scheduler_determinism.py b/trio/_tests/test_scheduler_determinism.py similarity index 89% rename from trio/tests/test_scheduler_determinism.py rename to trio/_tests/test_scheduler_determinism.py index ba5f469396..e2d3167e45 100644 --- 
a/trio/tests/test_scheduler_determinism.py +++ b/trio/_tests/test_scheduler_determinism.py @@ -6,7 +6,7 @@ async def scheduler_trace(): trace = [] async def tracer(name): - for i in range(10): + for i in range(50): trace.append((name, i)) await trio.sleep(0) @@ -26,9 +26,7 @@ def test_the_trio_scheduler_is_not_deterministic(): def test_the_trio_scheduler_is_deterministic_if_seeded(monkeypatch): - monkeypatch.setattr( - trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True - ) + monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True) traces = [] for _ in range(10): state = trio._core._run._r.getstate() diff --git a/trio/tests/test_signals.py b/trio/_tests/test_signals.py similarity index 95% rename from trio/tests/test_signals.py rename to trio/_tests/test_signals.py index 7ae930403c..313cce259f 100644 --- a/trio/tests/test_signals.py +++ b/trio/_tests/test_signals.py @@ -3,9 +3,10 @@ import pytest import trio + from .. import _core +from .._signals import _signal_handler, open_signal_receiver from .._util import signal_raise -from .._signals import open_signal_receiver, _signal_handler async def test_open_signal_receiver(): @@ -41,6 +42,12 @@ async def test_open_signal_receiver_restore_handler_after_one_bad_signal(): assert signal.getsignal(signal.SIGILL) is orig +async def test_open_signal_receiver_empty_fail(): + with pytest.raises(TypeError, match="No signals were provided"): + with open_signal_receiver(): + pass + + async def test_open_signal_receiver_restore_handler_after_duplicate_signal(): orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL, signal.SIGILL): @@ -102,6 +109,7 @@ async def test_open_signal_receiver_no_starvation(): # open_signal_receiver block might cause the signal to be # redelivered and give us a core dump instead of a traceback... 
import traceback + traceback.print_exc() @@ -158,9 +166,7 @@ def raise_handler(signum, _): with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler): with pytest.raises(RuntimeError) as excinfo: - with open_signal_receiver( - signal.SIGILL, signal.SIGFPE - ) as receiver: + with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() diff --git a/trio/tests/test_socket.py b/trio/_tests/test_socket.py similarity index 75% rename from trio/tests/test_socket.py rename to trio/_tests/test_socket.py index c03c8bb8ff..e9baff436a 100644 --- a/trio/tests/test_socket.py +++ b/trio/_tests/test_socket.py @@ -1,15 +1,15 @@ -import pytest -import attr - +import errno +import inspect import os import socket as stdlib_socket -import inspect +import sys import tempfile -import sys as _sys -from .._core.tests.tutil import creates_ipv6, binds_ipv6 -from .. import _core -from .. import _socket as _tsocket -from .. import socket as tsocket + +import attr +import pytest + +from .. 
import _core, socket as tsocket +from .._core._tests.tutil import binds_ipv6, creates_ipv6 from .._socket import _NUMERIC_ONLY, _try_sync from ..testing import assert_checkpoints, wait_all_tasks_blocked @@ -44,9 +44,7 @@ def getaddrinfo(self, *args, **kwargs): elif bound[-1] & stdlib_socket.AI_NUMERICHOST: return self._orig_getaddrinfo(*args, **kwargs) else: - raise RuntimeError( - "gai called with unexpected arguments {}".format(bound) - ) + raise RuntimeError(f"gai called with unexpected arguments {bound}") @pytest.fixture @@ -101,37 +99,51 @@ def test_socket_has_some_reexports(): async def test_getaddrinfo(monkeygai): def check(got, expected): # win32 returns 0 for the proto field - def without_proto(gai_tup): - return gai_tup[:2] + (0,) + gai_tup[3:] + # musl and glibc have inconsistent handling of the canonical name + # field (https://github.com/python-trio/trio/issues/1499) + # Neither field gets used much and there isn't much opportunity for us + # to mess them up, so we don't bother checking them here + def interesting_fields(gai_tup): + # (family, type, proto, canonname, sockaddr) + family, type, proto, canonname, sockaddr = gai_tup + return (family, type, sockaddr) - expected2 = [without_proto(gt) for gt in expected] - assert got == expected or got == expected2 + def filtered(gai_list): + return [interesting_fields(gai_tup) for gai_tup in gai_list] + + assert filtered(got) == filtered(expected) # Simple non-blocking non-error cases, ipv4 and ipv6: with assert_checkpoints(): - res = await tsocket.getaddrinfo( - "127.0.0.1", "12345", type=tsocket.SOCK_STREAM - ) - - check(res, [ - (tsocket.AF_INET, # 127.0.0.1 is ipv4 - tsocket.SOCK_STREAM, - tsocket.IPPROTO_TCP, - "", - ("127.0.0.1", 12345)), - ]) # yapf: disable + res = await tsocket.getaddrinfo("127.0.0.1", "12345", type=tsocket.SOCK_STREAM) + + check( + res, + [ + ( + tsocket.AF_INET, # 127.0.0.1 is ipv4 + tsocket.SOCK_STREAM, + tsocket.IPPROTO_TCP, + "", + ("127.0.0.1", 12345), + ), + ], + ) with 
assert_checkpoints(): - res = await tsocket.getaddrinfo( - "::1", "12345", type=tsocket.SOCK_DGRAM - ) - check(res, [ - (tsocket.AF_INET6, - tsocket.SOCK_DGRAM, - tsocket.IPPROTO_UDP, - "", - ("::1", 12345, 0, 0)), - ]) # yapf: disable + res = await tsocket.getaddrinfo("::1", "12345", type=tsocket.SOCK_DGRAM) + check( + res, + [ + ( + tsocket.AF_INET6, + tsocket.SOCK_DGRAM, + tsocket.IPPROTO_UDP, + "", + ("::1", 12345, 0, 0), + ), + ], + ) monkeygai.set("x", b"host", "port", family=0, type=0, proto=0, flags=0) with assert_checkpoints(): @@ -143,8 +155,10 @@ def without_proto(gai_tup): with assert_checkpoints(): with pytest.raises(tsocket.gaierror) as excinfo: await tsocket.getaddrinfo("::1", "12345", type=-1) - # Linux, Windows + # Linux + glibc, Windows expected_errnos = {tsocket.EAI_SOCKTYPE} + # Linux + musl + expected_errnos.add(tsocket.EAI_SERVICE) # macOS if hasattr(tsocket, "EAI_BADHINTS"): expected_errnos.add(tsocket.EAI_BADHINTS) @@ -210,9 +224,9 @@ async def test_from_stdlib_socket(): class MySocket(stdlib_socket.socket): pass - mysock = MySocket() - with pytest.raises(TypeError): - tsocket.from_stdlib_socket(mysock) + with MySocket() as mysock: + with pytest.raises(TypeError): + tsocket.from_stdlib_socket(mysock) async def test_from_fd(): @@ -263,9 +277,10 @@ async def test_socket_v6(): assert s.family == tsocket.AF_INET6 -@pytest.mark.skipif(not _sys.platform == "linux", reason="linux only") +@pytest.mark.skipif(not sys.platform == "linux", reason="linux only") async def test_sniff_sockopts(): from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM + # generate the combinations of families/types we're testing: sockets = [] for family in [AF_INET, AF_INET6]: @@ -277,12 +292,15 @@ async def test_sniff_sockopts(): # check family / type for correctness: assert tsocket_socket.family == socket.family assert tsocket_socket.type == socket.type + tsocket_socket.detach() # fromfd constructor tsocket_from_fd = tsocket.fromfd(socket.fileno(), AF_INET, 
SOCK_STREAM) # check family / type for correctness: assert tsocket_from_fd.family == socket.family assert tsocket_from_fd.type == socket.type + tsocket_from_fd.close() + socket.close() @@ -336,12 +354,32 @@ async def test_SocketType_basics(): # type family proto stdlib_sock = stdlib_socket.socket() sock = tsocket.from_stdlib_socket(stdlib_sock) - assert sock.type == _tsocket.real_socket_type(stdlib_sock.type) + assert sock.type == stdlib_sock.type assert sock.family == stdlib_sock.family assert sock.proto == stdlib_sock.proto sock.close() +async def test_SocketType_setsockopt(): + sock = tsocket.socket() + with sock as _: + # specifying optlen. Not supported on pypy, and I couldn't find + # valid calls on darwin or win32. + if hasattr(tsocket, "SO_BINDTODEVICE"): + sock.setsockopt(tsocket.SOL_SOCKET, tsocket.SO_BINDTODEVICE, None, 0) + + # specifying value + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False) + + # specifying both + with pytest.raises(TypeError, match="invalid value for argument 'value'"): + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False, 5) # type: ignore[call-overload] + + # specifying neither + with pytest.raises(TypeError, match="invalid value for argument 'value'"): + sock.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, None) # type: ignore[call-overload] + + async def test_SocketType_dup(): a, b = tsocket.socketpair() with a, b: @@ -382,10 +420,11 @@ async def test_SocketType_shutdown(): @pytest.mark.parametrize( - "address, socket_type", [ - ('127.0.0.1', tsocket.AF_INET), - pytest.param('::1', tsocket.AF_INET6, marks=binds_ipv6) - ] + "address, socket_type", + [ + ("127.0.0.1", tsocket.AF_INET), + pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), + ], ) async def test_SocketType_simple_server(address, socket_type): # listen, bind, accept, connect, getpeername, getsockname @@ -433,12 +472,12 @@ class Addresses: localhost = attr.ib() arbitrary = attr.ib() broadcast = attr.ib() - extra = attr.ib() # 
Direct thorough tests of the implicit resolver helpers @pytest.mark.parametrize( - "socket_type, addrs", [ + "socket_type, addrs", + [ ( tsocket.AF_INET, Addresses( @@ -446,7 +485,6 @@ class Addresses: localhost="127.0.0.1", arbitrary="1.2.3.4", broadcast="255.255.255.255", - extra=(), ), ), pytest.param( @@ -456,83 +494,104 @@ class Addresses: localhost="::1", arbitrary="1::2", broadcast="::ffff:255.255.255.255", - extra=(0, 0), ), marks=creates_ipv6, ), - ] + ], ) async def test_SocketType_resolve(socket_type, addrs): - v6 = (socket_type == tsocket.AF_INET6) - - # For some reason the stdlib special-cases "" to pass NULL to getaddrinfo - # They also error out on None, but whatever, None is much more consistent, - # so we accept it too. - for null in [None, ""]: - sock = tsocket.socket(family=socket_type) - got = await sock._resolve_local_address((null, 80)) - assert got == (addrs.bind_all, 80, *addrs.extra) - got = await sock._resolve_remote_address((null, 80)) - assert got == (addrs.localhost, 80, *addrs.extra) - - # AI_PASSIVE only affects the wildcard address, so for everything else - # _resolve_local_address and _resolve_remote_address should work the same: - for resolver in ["_resolve_local_address", "_resolve_remote_address"]: - - async def res(*args): - return await getattr(sock, resolver)(*args) - - # yapf: disable - assert await res((addrs.arbitrary, - "http")) == (addrs.arbitrary, 80, *addrs.extra) - if v6: - assert await res(("1::2", 80, 1)) == ("1::2", 80, 1, 0) - assert await res(("1::2", 80, 1, 2)) == ("1::2", 80, 1, 2) - - # V4 mapped addresses resolved if V6ONLY is False - sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False) - assert await res(("1.2.3.4", - "http")) == ("::ffff:1.2.3.4", 80, 0, 0) - - # Check the special case, because why not - assert await res(("", - 123)) == (addrs.broadcast, 123, *addrs.extra) - # yapf: enable - - # But not if it's true (at least on systems where getaddrinfo works - # correctly) - if v6 and not 
gai_without_v4mapped_is_buggy(): - sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True) - with pytest.raises(tsocket.gaierror) as excinfo: - await res(("1.2.3.4", 80)) - # Windows, macOS - expected_errnos = {tsocket.EAI_NONAME} - # Linux - if hasattr(tsocket, "EAI_ADDRFAMILY"): - expected_errnos.add(tsocket.EAI_ADDRFAMILY) - assert excinfo.value.errno in expected_errnos - - # A family where we know nothing about the addresses, so should just - # pass them through. This should work on Linux, which is enough to - # smoke test the basic functionality... - try: - netlink_sock = tsocket.socket( - family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM - ) - except (AttributeError, OSError): - pass - else: - assert await getattr(netlink_sock, resolver)("asdf") == "asdf" + v6 = socket_type == tsocket.AF_INET6 - with pytest.raises(ValueError): - await res("1.2.3.4") - with pytest.raises(ValueError): - await res(("1.2.3.4",)) - with pytest.raises(ValueError): + def pad(addr): + if v6: + while len(addr) < 4: + addr += (0,) + return addr + + def assert_eq(actual, expected): + assert pad(expected) == pad(actual) + + with tsocket.socket(family=socket_type) as sock: + # For some reason the stdlib special-cases "" to pass NULL to + # getaddrinfo. They also error out on None, but whatever, None is much + # more consistent, so we accept it too. 
+ for null in [None, ""]: + got = await sock._resolve_address_nocp((null, 80), local=True) + assert_eq(got, (addrs.bind_all, 80)) + got = await sock._resolve_address_nocp((null, 80), local=False) + assert_eq(got, (addrs.localhost, 80)) + + # AI_PASSIVE only affects the wildcard address, so for everything else + # local=True/local=False should work the same: + for local in [False, True]: + + async def res(*args): + return await sock._resolve_address_nocp(*args, local=local) + + assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: - await res(("1.2.3.4", 80, 0, 0, 0)) + # Check handling of different length ipv6 address tuples + assert_eq(await res(("1::2", 80)), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", 80, 0)), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", 80, 0, 0)), ("1::2", 80, 0, 0)) + # Non-zero flowinfo/scopeid get passed through + assert_eq(await res(("1::2", 80, 1)), ("1::2", 80, 1, 0)) + assert_eq(await res(("1::2", 80, 1, 2)), ("1::2", 80, 1, 2)) + + # And again with a string port, as a trick to avoid the + # already-resolved address fastpath and make sure we call + # getaddrinfo + assert_eq(await res(("1::2", "80")), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", "80", 0)), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", "80", 0, 0)), ("1::2", 80, 0, 0)) + assert_eq(await res(("1::2", "80", 1)), ("1::2", 80, 1, 0)) + assert_eq(await res(("1::2", "80", 1, 2)), ("1::2", 80, 1, 2)) + + # V4 mapped addresses resolved if V6ONLY is False + sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, False) + assert_eq(await res(("1.2.3.4", "http")), ("::ffff:1.2.3.4", 80)) + + # Check the special case, because why not + assert_eq(await res(("", 123)), (addrs.broadcast, 123)) + + # But not if it's true (at least on systems where getaddrinfo works + # correctly) + if v6 and not gai_without_v4mapped_is_buggy(): + sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True) + with pytest.raises(tsocket.gaierror) as 
excinfo: + await res(("1.2.3.4", 80)) + # Windows, macOS + expected_errnos = {tsocket.EAI_NONAME} + # Linux + if hasattr(tsocket, "EAI_ADDRFAMILY"): + expected_errnos.add(tsocket.EAI_ADDRFAMILY) + assert excinfo.value.errno in expected_errnos + + # A family where we know nothing about the addresses, so should just + # pass them through. This should work on Linux, which is enough to + # smoke test the basic functionality... + try: + netlink_sock = tsocket.socket( + family=tsocket.AF_NETLINK, type=tsocket.SOCK_DGRAM + ) + except (AttributeError, OSError): + pass else: - await res(("1.2.3.4", 80, 0, 0)) + assert ( + await netlink_sock._resolve_address_nocp("asdf", local=local) + == "asdf" + ) + netlink_sock.close() + + with pytest.raises(ValueError): + await res("1.2.3.4") + with pytest.raises(ValueError): + await res(("1.2.3.4",)) + with pytest.raises(ValueError): + if v6: + await res(("1.2.3.4", 80, 0, 0, 0)) + else: + await res(("1.2.3.4", 80, 0, 0)) async def test_SocketType_unresolved_names(): @@ -567,7 +626,7 @@ async def test_SocketType_non_blocking_paths(): with assert_checkpoints(): with pytest.raises(_core.Cancelled): await ta.recv(10) - # immedate success (also checks that the previous attempt didn't + # immediate success (also checks that the previous attempt didn't # actually read anything) with assert_checkpoints(): await ta.recv(10) == b"1" @@ -682,19 +741,29 @@ def connect(self, *args, **kwargs): await sock.connect(("127.0.0.1", 2)) -async def test_resolve_remote_address_exception_closes_socket(): +# Fix issue #1810 +async def test_address_in_socket_error(): + address = "127.0.0.1" + with tsocket.socket() as sock: + try: + await sock.connect((address, 2)) + except OSError as e: + assert any(address in str(arg) for arg in e.args) + + +async def test_resolve_address_exception_in_connect_closes_socket(): # Here we are testing issue 247, any cancellation will leave the socket closed with _core.CancelScope() as cancel_scope: with tsocket.socket() as sock: - 
async def _resolve_remote_address(self, *args, **kwargs): + async def _resolve_address_nocp(self, *args, **kwargs): cancel_scope.cancel() await _core.checkpoint() - sock._resolve_remote_address = _resolve_remote_address + sock._resolve_address_nocp = _resolve_address_nocp with assert_checkpoints(): with pytest.raises(_core.Cancelled): - await sock.connect('') + await sock.connect("") assert sock.fileno() == -1 @@ -840,9 +909,11 @@ async def getnameinfo(self, sockaddr, flags): (0, 0, tsocket.IPPROTO_TCP, 0), (0, 0, 0, tsocket.AI_CANONNAME), ]: - assert ( - await tsocket.getaddrinfo("localhost", "foo", *vals) == - ("custom_gai", b"localhost", "foo", *vals) + assert await tsocket.getaddrinfo("localhost", "foo", *vals) == ( + "custom_gai", + b"localhost", + "foo", + *vals, ) # IDNA encoding is handled before calling the special object @@ -850,7 +921,7 @@ async def getnameinfo(self, sockaddr, flags): expected = ("custom_gai", b"xn--f-1gaa", "foo", 0, 0, 0, 0) assert got == expected - assert (await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0)) + assert await tsocket.getnameinfo("a", 0) == ("custom_gni", "a", 0) # We can set it back to None assert tsocket.set_custom_hostname_resolver(None) is cr @@ -893,9 +964,7 @@ async def test_SocketType_is_abstract(): tsocket.SocketType() -@pytest.mark.skipif( - not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets" -) +@pytest.mark.skipif(not hasattr(tsocket, "AF_UNIX"), reason="no unix domain sockets") async def test_unix_domain_socket(): # Bind has a special branch to use a thread, since it has to do filesystem # traversal. Maybe connect should too? Not sure. 
@@ -907,13 +976,14 @@ async def check_AF_UNIX(path): with tsocket.socket(family=tsocket.AF_UNIX) as csock: await csock.connect(path) ssock, _ = await lsock.accept() - await csock.send(b"x") - assert await ssock.recv(1) == b"x" + with ssock: + await csock.send(b"x") + assert await ssock.recv(1) == b"x" # Can't use tmpdir fixture, because we can exceed the maximum AF_UNIX path # length on macOS. with tempfile.TemporaryDirectory() as tmpdir: - path = "{}/sock".format(tmpdir) + path = f"{tmpdir}/sock" await check_AF_UNIX(path) try: @@ -952,3 +1022,24 @@ async def receiver(): nursery.start_soon(receiver) await wait_all_tasks_blocked() a.close() + + +async def test_many_sockets(): + total = 5000 # Must be more than MAX_AFD_GROUP_SIZE + sockets = [] + for x in range(total // 2): + try: + a, b = stdlib_socket.socketpair() + except OSError as e: # pragma: no cover + assert e.errno in (errno.EMFILE, errno.ENFILE) + break + sockets += [a, b] + async with _core.open_nursery() as nursery: + for s in sockets: + nursery.start_soon(_core.wait_readable, s) + await _core.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + for sock in sockets: + sock.close() + if x != total // 2 - 1: # pragma: no cover + print(f"Unable to open more than {(x-1)*2} sockets.") diff --git a/trio/tests/test_ssl.py b/trio/_tests/test_ssl.py similarity index 89% rename from trio/tests/test_ssl.py rename to trio/_tests/test_ssl.py index 0ac217bb4c..f91cea8549 100644 --- a/trio/tests/test_ssl.py +++ b/trio/_tests/test_ssl.py @@ -1,33 +1,33 @@ -import pytest +from __future__ import annotations -import threading +import os import socket as stdlib_socket import ssl -from contextlib import contextmanager +import sys +import threading +from contextlib import asynccontextmanager, contextmanager from functools import partial -from OpenSSL import SSL +import pytest import trustme -from async_generator import async_generator, yield_, asynccontextmanager +from OpenSSL import SSL import trio -from .. 
import _core -from .._highlevel_socket import SocketStream, SocketListener + +from .. import _core, socket as tsocket +from .._core import BrokenResourceError, ClosedResourceError +from .._core._tests.tutil import slow from .._highlevel_generic import aclose_forcefully -from .._core import ClosedResourceError, BrokenResourceError from .._highlevel_open_tcp_stream import open_tcp_stream -from .. import socket as tsocket -from .._ssl import SSLStream, SSLListener, NeedHandshakeError +from .._highlevel_socket import SocketListener, SocketStream +from .._ssl import NeedHandshakeError, SSLListener, SSLStream, _is_eof from .._util import ConflictDetector - -from .._core.tests.tutil import slow - from ..testing import ( - assert_checkpoints, Sequencer, - memory_stream_pair, - lockstep_stream_pair, + assert_checkpoints, check_two_way_stream, + lockstep_stream_pair, + memory_stream_pair, ) # We have two different kinds of echo server fixtures we use for testing. The @@ -52,32 +52,30 @@ TRIO_TEST_1_CERT = TRIO_TEST_CA.issue_server_cert("trio-test-1.example.org") SERVER_CTX = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): + SERVER_CTX.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF + TRIO_TEST_1_CERT.configure_cert(SERVER_CTX) + # TLS 1.3 has a lot of changes from previous versions. So we want to run tests # with both TLS 1.3, and TLS 1.2. -if hasattr(ssl, "OP_NO_TLSv1_3"): - # "tls13" means that we're willing to negotiate TLS 1.3. Usually that's - # what will happen, but the renegotiation tests explicitly force a - # downgrade on the server side. "tls12" means we refuse to negotiate TLS - # 1.3, so we'll almost certainly use TLS 1.2. - client_ctx_params = ["tls13", "tls12"] -else: - # We can't control whether we use TLS 1.3, so we just have to accept - # whatever openssl wants to use. 
This might be TLS 1.2 (if openssl is - # old), or it might be TLS 1.3 (if openssl is new, but our python version - # is too old to expose the configuration knobs). - client_ctx_params = ["default"] - - -@pytest.fixture(scope="module", params=client_ctx_params) +# "tls13" means that we're willing to negotiate TLS 1.3. Usually that's +# what will happen, but the renegotiation tests explicitly force a +# downgrade on the server side. "tls12" means we refuse to negotiate TLS +# 1.3, so we'll almost certainly use TLS 1.2. +@pytest.fixture(scope="module", params=["tls13", "tls12"]) def client_ctx(request): ctx = ssl.create_default_context() + + if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): + ctx.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF + TRIO_TEST_CA.configure_trust(ctx) if request.param in ["default", "tls13"]: return ctx elif request.param == "tls12": - ctx.options |= ssl.OP_NO_TLSv1_3 + ctx.maximum_version = ssl.TLSVersion.TLSv1_2 return ctx else: # pragma: no cover assert False @@ -89,24 +87,21 @@ def ssl_echo_serve_sync(sock, *, expect_fail=False): wrapped = SERVER_CTX.wrap_socket( sock, server_side=True, suppress_ragged_eofs=False ) - wrapped.do_handshake() - while True: - data = wrapped.recv(4096) - if not data: - # other side has initiated a graceful shutdown; we try to - # respond in kind but it's legal for them to have already gone - # away. - exceptions = (BrokenPipeError, ssl.SSLZeroReturnError) - # Under unclear conditions, CPython sometimes raises - # SSLWantWriteError here. This is a bug (bpo-32219), but it's - # not our bug, so ignore it. - exceptions += (ssl.SSLWantWriteError,) - try: - wrapped.unwrap() - except exceptions: - pass - return - wrapped.sendall(data) + with wrapped: + wrapped.do_handshake() + while True: + data = wrapped.recv(4096) + if not data: + # other side has initiated a graceful shutdown; we try to + # respond in kind but it's legal for them to have already + # gone away. 
+ exceptions = (BrokenPipeError, ssl.SSLZeroReturnError) + try: + wrapped.unwrap() + except exceptions: + pass + return + wrapped.sendall(data) # This is an obscure workaround for an openssl bug. In server mode, in # some versions, openssl sends some extra data at the end of do_handshake # that it shouldn't send. Normally this is harmless, but, if the other @@ -132,13 +127,14 @@ def ssl_echo_serve_sync(sock, *, expect_fail=False): else: if expect_fail: # pragma: no cover raise RuntimeError("failed to fail?") + finally: + sock.close() # Fixture that gives a raw socket connected to a trio-test-1 echo server # (running in a thread). Useful for testing making connections with different # SSLContexts. @asynccontextmanager -@async_generator async def ssl_echo_server_raw(**kwargs): a, b = stdlib_socket.socketpair() async with trio.open_nursery() as nursery: @@ -147,24 +143,18 @@ async def ssl_echo_server_raw(**kwargs): # nursery context manager to exit too. with a, b: nursery.start_soon( - trio.to_thread.run_sync, - partial(ssl_echo_serve_sync, b, **kwargs) + trio.to_thread.run_sync, partial(ssl_echo_serve_sync, b, **kwargs) ) - await yield_(SocketStream(tsocket.from_stdlib_socket(a))) + yield SocketStream(tsocket.from_stdlib_socket(a)) # Fixture that gives a properly set up SSLStream connected to a trio-test-1 # echo server (running in a thread) @asynccontextmanager -@async_generator async def ssl_echo_server(client_ctx, **kwargs): async with ssl_echo_server_raw(**kwargs) as sock: - await yield_( - SSLStream( - sock, client_ctx, server_hostname="trio-test-1.example.org" - ) - ) + yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") # The weird in-memory server ... thing. @@ -175,13 +165,12 @@ def __init__(self, sleeper=None): ctx = SSL.Context(SSL.SSLv23_METHOD) # TLS 1.3 removes renegotiation support. 
Which is great for them, but # we still have to support versions before that, and that means we - # need to test renegotation support, which means we need to force this + # need to test renegotiation support, which means we need to force this # to use a lower version where this test server can trigger # renegotiations. Of course TLS 1.3 support isn't released yet, but # I'm told that this will work once it is. (And once it is we can - # remove the pragma: no cover too.) Alternatively, once we drop - # support for CPython 3.5 on macOS, then we could switch to using - # TLSv1_2_METHOD. + # remove the pragma: no cover too.) Alternatively, we could switch to + # using TLSv1_2_METHOD. # # Discussion: https://github.com/pyca/pyopenssl/issues/624 @@ -195,6 +184,7 @@ def __init__(self, sleeper=None): # Fortunately pyopenssl uses cryptography under the hood, so we can be # confident that they're using the same version of openssl from cryptography.hazmat.bindings.openssl.binding import Binding + b = Binding() if hasattr(b.lib, "SSL_OP_NO_TLSv1_3"): ctx.set_options(b.lib.SSL_OP_NO_TLSv1_3) @@ -235,7 +225,7 @@ def renegotiate_pending(self): return self._conn.renegotiate_pending() def renegotiate(self): - # Returns false if a renegotation is already in progress, meaning + # Returns false if a renegotiation is already in progress, meaning # nothing happens. assert self._conn.renegotiate() @@ -327,7 +317,7 @@ async def receive_some(self, nbytes=None): async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): # Make sure that PyOpenSSLEchoStream complains if two tasks call send_all # at the same time, or ditto for receive_some. The tricky cases where SSLStream - # might accidentally do this are during renegotation, which we test using + # might accidentally do this are during renegotiation, which we test using # PyOpenSSLEchoStream, so this makes sure that if we do have a bug then # PyOpenSSLEchoStream will notice and complain. 
@@ -363,9 +353,7 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors(): @contextmanager def virtual_ssl_echo_server(client_ctx, **kwargs): fakesock = PyOpenSSLEchoStream(**kwargs) - yield SSLStream( - fakesock, client_ctx, server_hostname="trio-test-1.example.org" - ) + yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") def ssl_wrap_pair( @@ -374,13 +362,13 @@ def ssl_wrap_pair( server_transport, *, client_kwargs={}, - server_kwargs={} + server_kwargs={}, ): client_ssl = SSLStream( client_transport, client_ctx, server_hostname="trio-test-1.example.org", - **client_kwargs + **client_kwargs, ) server_ssl = SSLStream( server_transport, SERVER_CTX, server_side=True, **server_kwargs @@ -390,16 +378,12 @@ def ssl_wrap_pair( def ssl_memory_stream_pair(client_ctx, **kwargs): client_transport, server_transport = memory_stream_pair() - return ssl_wrap_pair( - client_ctx, client_transport, server_transport, **kwargs - ) + return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) def ssl_lockstep_stream_pair(client_ctx, **kwargs): client_transport, server_transport = lockstep_stream_pair() - return ssl_wrap_pair( - client_ctx, client_transport, server_transport, **kwargs - ) + return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) # Simple smoke test for handshake/send/receive/shutdown talking to a @@ -416,9 +400,7 @@ async def test_ssl_client_basics(client_ctx): # Didn't configure the CA file, should fail async with ssl_echo_server_raw(expect_fail=True) as sock: bad_client_ctx = ssl.create_default_context() - s = SSLStream( - sock, bad_client_ctx, server_hostname="trio-test-1.example.org" - ) + s = SSLStream(sock, bad_client_ctx, server_hostname="trio-test-1.example.org") assert not s.server_side with pytest.raises(BrokenResourceError) as excinfo: await s.send_all(b"x") @@ -426,9 +408,7 @@ async def test_ssl_client_basics(client_ctx): # Trusted CA, but wrong host name async with 
ssl_echo_server_raw(expect_fail=True) as sock: - s = SSLStream( - sock, client_ctx, server_hostname="trio-test-2.example.org" - ) + s = SSLStream(sock, client_ctx, server_hostname="trio-test-2.example.org") assert not s.server_side with pytest.raises(BrokenResourceError) as excinfo: await s.send_all(b"x") @@ -445,13 +425,13 @@ async def test_ssl_server_basics(client_ctx): assert server_transport.server_side def client(): - client_sock = client_ctx.wrap_socket( + with client_ctx.wrap_socket( a, server_hostname="trio-test-1.example.org" - ) - client_sock.sendall(b"x") - assert client_sock.recv(1) == b"y" - client_sock.sendall(b"z") - client_sock.unwrap() + ) as client_sock: + client_sock.sendall(b"x") + assert client_sock.recv(1) == b"y" + client_sock.sendall(b"z") + client_sock.unwrap() t = threading.Thread(target=client) t.start() @@ -469,9 +449,7 @@ async def test_attributes(client_ctx): async with ssl_echo_server_raw(expect_fail=True) as sock: good_ctx = client_ctx bad_ctx = ssl.create_default_context() - s = SSLStream( - sock, good_ctx, server_hostname="trio-test-1.example.org" - ) + s = SSLStream(sock, good_ctx, server_hostname="trio-test-1.example.org") assert s.transport_stream is sock @@ -598,6 +576,7 @@ async def test_renegotiation_randomized(mock_clock, client_ctx): mock_clock.autojump_threshold = 0 import random + r = random.Random(0) async def sleeper(_): @@ -634,8 +613,8 @@ async def expect(expected): await clear() for i in range(100): - b1 = bytes([i % 0xff]) - b2 = bytes([(2 * i) % 0xff]) + b1 = bytes([i % 0xFF]) + b2 = bytes([(2 * i) % 0xFF]) s.transport_stream.renegotiate() async with _core.open_nursery() as nursery: nursery.start_soon(send, b1) @@ -646,8 +625,8 @@ async def expect(expected): await clear() for i in range(100): - b1 = bytes([i % 0xff]) - b2 = bytes([(2 * i) % 0xff]) + b1 = bytes([i % 0xFF]) + b2 = bytes([(2 * i) % 0xFF]) await send(b1) s.transport_stream.renegotiate() await expect(b1) @@ -673,9 +652,7 @@ async def 
sleep_then_wait_writable(): await trio.sleep(1000) await s.wait_send_all_might_not_block() - with virtual_ssl_echo_server( - client_ctx, sleeper=sleeper_with_slow_send_all - ) as s: + with virtual_ssl_echo_server(client_ctx, sleeper=sleeper_with_slow_send_all) as s: await send(b"x") s.transport_stream.renegotiate() async with _core.open_nursery() as nursery: @@ -764,6 +741,10 @@ async def wait_send_all_might_not_block(self): assert record == ["ok"] +@pytest.mark.skipif( + os.name == "nt" and sys.version_info >= (3, 10), + reason="frequently fails on Windows + Python 3.10", +) async def test_checkpoints(client_ctx): async with ssl_echo_server(client_ctx) as s: with assert_checkpoints(): @@ -1028,7 +1009,7 @@ async def test_ssl_bad_shutdown_but_its_ok(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, - client_kwargs={"https_compatible": True} + client_kwargs={"https_compatible": True}, ) async with _core.open_nursery() as nursery: @@ -1052,9 +1033,7 @@ async def test_ssl_handshake_failure_during_aclose(): async with ssl_echo_server_raw(expect_fail=True) as sock: # Don't configure trust correctly client_ctx = ssl.create_default_context() - s = SSLStream( - sock, client_ctx, server_hostname="trio-test-1.example.org" - ) + s = SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org") # It's a little unclear here whether aclose should swallow the error # or let it escape. 
We *do* swallow the error if it arrives when we're # sending close_notify, because both sides closing the connection @@ -1094,7 +1073,7 @@ async def test_ssl_https_compatibility_disagreement(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": False}, - client_kwargs={"https_compatible": True} + client_kwargs={"https_compatible": True}, ) async with _core.open_nursery() as nursery: @@ -1106,7 +1085,8 @@ async def test_ssl_https_compatibility_disagreement(client_ctx): async def receive_and_expect_error(): with pytest.raises(BrokenResourceError) as excinfo: await server.receive_some(10) - assert isinstance(excinfo.value.__cause__, ssl.SSLEOFError) + + assert _is_eof(excinfo.value.__cause__) async with _core.open_nursery() as nursery: nursery.start_soon(client.aclose) @@ -1117,7 +1097,7 @@ async def test_https_mode_eof_before_handshake(client_ctx): client, server = ssl_memory_stream_pair( client_ctx, server_kwargs={"https_compatible": True}, - client_kwargs={"https_compatible": True} + client_kwargs={"https_compatible": True}, ) async def server_expect_clean_eof(): @@ -1179,7 +1159,7 @@ async def test_selected_alpn_protocol_before_handshake(client_ctx): async def test_selected_alpn_protocol_when_not_set(client_ctx): - # ALPN protocol still returns None when it's not ser, + # ALPN protocol still returns None when it's not set, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1190,8 +1170,7 @@ async def test_selected_alpn_protocol_when_not_set(client_ctx): assert client.selected_alpn_protocol() is None assert server.selected_alpn_protocol() is None - assert client.selected_alpn_protocol() == \ - server.selected_alpn_protocol() + assert client.selected_alpn_protocol() == server.selected_alpn_protocol() async def test_selected_npn_protocol_before_handshake(client_ctx): @@ -1204,8 +1183,12 @@ async def test_selected_npn_protocol_before_handshake(client_ctx): 
server.selected_npn_protocol() +@pytest.mark.filterwarnings( + r"ignore: ssl module. NPN is deprecated, use ALPN instead:UserWarning", + r"ignore:ssl NPN is deprecated, use ALPN instead:DeprecationWarning", +) async def test_selected_npn_protocol_when_not_set(client_ctx): - # NPN protocol still returns None when it's not ser, + # NPN protocol still returns None when it's not set, # instead of raising an exception client, server = ssl_memory_stream_pair(client_ctx) @@ -1216,8 +1199,7 @@ async def test_selected_npn_protocol_when_not_set(client_ctx): assert client.selected_npn_protocol() is None assert server.selected_npn_protocol() is None - assert client.selected_npn_protocol() == \ - server.selected_npn_protocol() + assert client.selected_npn_protocol() == server.selected_npn_protocol() async def test_get_channel_binding_before_handshake(client_ctx): @@ -1240,8 +1222,7 @@ async def test_get_channel_binding_after_handshake(client_ctx): assert client.get_channel_binding() is not None assert server.get_channel_binding() is not None - assert client.get_channel_binding() == \ - server.get_channel_binding() + assert client.get_channel_binding() == server.get_channel_binding() async def test_getpeercert(client_ctx): @@ -1254,10 +1235,7 @@ async def test_getpeercert(client_ctx): assert server.getpeercert() is None print(client.getpeercert()) - assert ( - ("DNS", - "trio-test-1.example.org") in client.getpeercert()["subjectAltName"] - ) + assert ("DNS", "trio-test-1.example.org") in client.getpeercert()["subjectAltName"] async def test_SSLListener(client_ctx): @@ -1270,9 +1248,7 @@ async def setup(**kwargs): transport_client = await open_tcp_stream(*listen_sock.getsockname()) ssl_client = SSLStream( - transport_client, - client_ctx, - server_hostname="trio-test-1.example.org" + transport_client, client_ctx, server_hostname="trio-test-1.example.org" ) return listen_sock, ssl_listener, ssl_client @@ -1305,15 +1281,3 @@ async def setup(**kwargs): await 
aclose_forcefully(ssl_listener) await aclose_forcefully(ssl_client) await aclose_forcefully(ssl_server) - - -async def test_deprecated_max_refill_bytes(client_ctx): - stream1, stream2 = memory_stream_pair() - with pytest.warns(trio.TrioDeprecationWarning): - SSLStream(stream1, client_ctx, max_refill_bytes=100) - with pytest.warns(trio.TrioDeprecationWarning): - # passing None is wrong here, but I'm too lazy to make a fake Listener - # and we get away with it for now. And this test will be deleted in a - # release or two anyway, so hopefully we'll keep getting away with it - # for long enough. - SSLListener(None, client_ctx, max_refill_bytes=100) diff --git a/trio/tests/test_subprocess.py b/trio/_tests/test_subprocess.py similarity index 62% rename from trio/tests/test_subprocess.py rename to trio/_tests/test_subprocess.py index 7fe64564c7..4dfaef4c7f 100644 --- a/trio/tests/test_subprocess.py +++ b/trio/_tests/test_subprocess.py @@ -1,16 +1,28 @@ import os +import random import signal import subprocess import sys +from contextlib import asynccontextmanager +from functools import partial +from pathlib import Path as SyncPath + import pytest -import random from .. 
import ( - _core, move_on_after, fail_after, sleep, sleep_forever, Process, - open_process, run_process, TrioDeprecationWarning + ClosedResourceError, + Event, + Process, + _core, + fail_after, + move_on_after, + run_process, + sleep, + sleep_forever, ) -from .._core.tests.tutil import slow -from ..testing import wait_all_tasks_blocked +from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow +from ..lowlevel import open_process +from ..testing import assert_no_checkpoints, wait_all_tasks_blocked posix = os.name == "posix" if posix: @@ -29,7 +41,11 @@ def python(code): EXIT_TRUE = python("sys.exit(0)") EXIT_FALSE = python("sys.exit(1)") CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())") -SLEEP = lambda seconds: python("import time; time.sleep({})".format(seconds)) + +if posix: + SLEEP = lambda seconds: ["/bin/sleep", str(seconds)] +else: + SLEEP = lambda seconds: python(f"import time; time.sleep({seconds})") def got_signal(proc, sig): @@ -39,35 +55,65 @@ def got_signal(proc, sig): return proc.returncode != 0 -async def test_basic(): - repr_template = "".format(EXIT_TRUE) - async with await open_process(EXIT_TRUE) as proc: - assert isinstance(proc, Process) - assert proc.returncode is None - assert repr(proc) == repr_template.format( - "running with PID {}".format(proc.pid) - ) +@asynccontextmanager +async def open_process_then_kill(*args, **kwargs): + proc = await open_process(*args, **kwargs) + try: + yield proc + finally: + proc.kill() + await proc.wait() + + +@asynccontextmanager +async def run_process_in_nursery(*args, **kwargs): + async with _core.open_nursery() as nursery: + kwargs.setdefault("check", False) + proc = await nursery.start(partial(run_process, *args, **kwargs)) + yield proc + nursery.cancel_scope.cancel() + + +background_process_param = pytest.mark.parametrize( + "background_process", + [open_process_then_kill, run_process_in_nursery], + ids=["open_process", "run_process in nursery"], +) + + +@background_process_param 
+async def test_basic(background_process): + async with background_process(EXIT_TRUE) as proc: + await proc.wait() + assert isinstance(proc, Process) + assert proc._pidfd is None assert proc.returncode == 0 - assert repr(proc) == repr_template.format("exited with status 0") + assert repr(proc) == f"" - async with await open_process(EXIT_FALSE) as proc: - pass + async with background_process(EXIT_FALSE) as proc: + await proc.wait() assert proc.returncode == 1 assert repr(proc) == "".format( EXIT_FALSE, "exited with status 1" ) -# Delete this test when we remove direct Process construction -async def test_deprecated_Process_init(): - with pytest.warns(TrioDeprecationWarning): - async with Process(EXIT_TRUE) as proc: - assert isinstance(proc, Process) - assert proc.returncode == 0 +@background_process_param +async def test_auto_update_returncode(background_process): + async with background_process(SLEEP(9999)) as p: + assert p.returncode is None + assert "running" in repr(p) + p.kill() + p._proc.wait() + assert p.returncode is not None + assert "exited" in repr(p) + assert p._pidfd is None + assert p.returncode is not None -async def test_multi_wait(): - async with await open_process(SLEEP(10)) as proc: +@background_process_param +async def test_multi_wait(background_process): + async with background_process(SLEEP(10)) as proc: # Check that wait (including multi-wait) tolerates being cancelled async with _core.open_nursery() as nursery: nursery.start_soon(proc.wait) @@ -85,7 +131,21 @@ async def test_multi_wait(): proc.kill() -async def test_kill_when_context_cancelled(): +# Test for deprecated 'async with process:' semantics +async def test_async_with_basics_deprecated(recwarn): + async with await open_process( + CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE + ) as proc: + pass + assert proc.returncode is not None + with pytest.raises(ClosedResourceError): + await proc.stdin.send_all(b"x") + with pytest.raises(ClosedResourceError): + await 
proc.stdout.receive_some() + + +# Test for deprecated 'async with process:' semantics +async def test_kill_when_context_cancelled(recwarn): with move_on_after(100) as scope: async with await open_process(SLEEP(10)) as proc: assert proc.poll() is None @@ -105,8 +165,9 @@ async def test_kill_when_context_cancelled(): ) -async def test_pipes(): - async with await open_process( +@background_process_param +async def test_pipes(background_process): + async with background_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, stdout=subprocess.PIPE, @@ -125,8 +186,8 @@ async def check_output(stream, expected): assert seen == expected async with _core.open_nursery() as nursery: - # fail quickly if something is broken - nursery.cancel_scope.deadline = _core.current_time() + 3.0 + # fail eventually if something is broken + nursery.cancel_scope.deadline = _core.current_time() + 30.0 nursery.start_soon(feed_input) nursery.start_soon(check_output, proc.stdout, msg) nursery.start_soon(check_output, proc.stderr, msg[::-1]) @@ -135,7 +196,8 @@ async def check_output(stream, expected): assert 0 == await proc.wait() -async def test_interactive(): +@background_process_param +async def test_interactive(background_process): # Test some back-and-forth with a subprocess. 
This one works like so: # in: 32\n # out: 0000...0000\n (32 zeroes) @@ -147,7 +209,7 @@ async def test_interactive(): # out: EOF # err: EOF - async with await open_process( + async with background_process( python( "idx = 0\n" "while True:\n" @@ -162,7 +224,6 @@ async def test_interactive(): stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) as proc: - newline = b"\n" if posix else b"\r\n" async def expect(idx, request): @@ -171,17 +232,13 @@ async def expect(idx, request): async def drain_one(stream, count, digit): while count > 0: result = await stream.receive_some(count) - assert result == ( - "{}".format(digit).encode("utf-8") * len(result) - ) + assert result == (f"{digit}".encode() * len(result)) count -= len(result) assert count == 0 assert await stream.receive_some(len(newline)) == newline nursery.start_soon(drain_one, proc.stdout, request, idx * 2) - nursery.start_soon( - drain_one, proc.stderr, request * 2, idx * 2 + 1 - ) + nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1) with fail_after(5): await proc.stdin.send_all(b"12") @@ -202,6 +259,8 @@ async def drain_one(stream, count, digit): await proc.stdin.aclose() assert await proc.stdout.receive_some(1) == b"" assert await proc.stderr.receive_some(1) == b"" + await proc.wait() + assert proc.returncode == 0 @@ -238,6 +297,10 @@ async def test_run(): await run_process(CAT, stdin="oh no, it's text") with pytest.raises(ValueError): await run_process(CAT, stdin=subprocess.PIPE) + with pytest.raises(ValueError): + await run_process(CAT, stdout=subprocess.PIPE) + with pytest.raises(ValueError): + await run_process(CAT, stderr=subprocess.PIPE) with pytest.raises(ValueError): await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL) with pytest.raises(ValueError): @@ -262,17 +325,18 @@ async def test_run_check(): assert result.returncode == 1 +@skip_if_fbsd_pipes_broken async def test_run_with_broken_pipe(): result = await run_process( - [sys.executable, "-c", "import sys; 
sys.stdin.close()"], - stdin=b"x" * 131072, + [sys.executable, "-c", "import sys; sys.stdin.close()"], stdin=b"x" * 131072 ) assert result.returncode == 0 assert result.stdout is result.stderr is None -async def test_stderr_stdout(): - async with await open_process( +@background_process_param +async def test_stderr_stdout(background_process): + async with background_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, stdout=subprocess.PIPE, @@ -305,19 +369,20 @@ async def test_stderr_stdout(): # this one hits the branch where stderr=STDOUT but stdout # is not redirected - async with await open_process( + async with background_process( CAT, stdin=subprocess.PIPE, stderr=subprocess.STDOUT ) as proc: assert proc.stdout is None assert proc.stderr is None await proc.stdin.aclose() + await proc.wait() assert proc.returncode == 0 if posix: try: r, w = os.pipe() - async with await open_process( + async with background_process( COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR, stdin=subprocess.PIPE, stdout=w, @@ -349,11 +414,13 @@ async def test_errors(): await open_process("ls", shell=False) -async def test_signals(): +@background_process_param +async def test_signals(background_process): async def test_one_signal(send_it, signum): with move_on_after(1.0) as scope: - async with await open_process(SLEEP(3600)) as proc: + async with background_process(SLEEP(3600)) as proc: send_it(proc) + await proc.wait() assert not scope.cancelled_caught if posix: assert proc.returncode == -signum @@ -374,13 +441,14 @@ async def test_one_signal(send_it, signum): @pytest.mark.skipif(not posix, reason="POSIX specific") -async def test_wait_reapable_fails(): +@background_process_param +async def test_wait_reapable_fails(background_process): old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN) try: # With SIGCHLD disabled, the wait() syscall will wait for the # process to exit but then fail with ECHILD. 
Make sure we # support this case as the stdlib subprocess module does. - async with await open_process(SLEEP(3600)) as proc: + async with background_process(SLEEP(3600)) as proc: async with _core.open_nursery() as nursery: nursery.start_soon(proc.wait) await wait_all_tasks_blocked() @@ -397,6 +465,7 @@ def test_waitid_eintr(): # This only matters on PyPy (where we're coding EINTR handling # ourselves) but the test works on all waitid platforms. from .._subprocess_platform import wait_child_exiting + if not wait_child_exiting.__module__.endswith("waitid"): pytest.skip("waitid only") from .._subprocess_platform.waitid import sync_wait_reapable @@ -421,3 +490,110 @@ def on_alarm(sig, frame): sleeper.kill() sleeper.wait() signal.signal(signal.SIGALRM, old_sigalrm) + + +async def test_custom_deliver_cancel(): + custom_deliver_cancel_called = False + + async def custom_deliver_cancel(proc): + nonlocal custom_deliver_cancel_called + custom_deliver_cancel_called = True + proc.terminate() + # Make sure this does get cancelled when the process exits, and that + # the process really exited. 
+ try: + await sleep_forever() + finally: + assert proc.returncode is not None + + async with _core.open_nursery() as nursery: + nursery.start_soon( + partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel) + ) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + assert custom_deliver_cancel_called + + +async def test_warn_on_failed_cancel_terminate(monkeypatch): + original_terminate = Process.terminate + + def broken_terminate(self): + original_terminate(self) + raise OSError("whoops") + + monkeypatch.setattr(Process, "terminate", broken_terminate) + + with pytest.warns(RuntimeWarning, match=".*whoops.*"): + async with _core.open_nursery() as nursery: + nursery.start_soon(run_process, SLEEP(9999)) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + +@pytest.mark.skipif(not posix, reason="posix only") +async def test_warn_on_cancel_SIGKILL_escalation(autojump_clock, monkeypatch): + monkeypatch.setattr(Process, "terminate", lambda *args: None) + + with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"): + async with _core.open_nursery() as nursery: + nursery.start_soon(run_process, SLEEP(9999)) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + +# the background_process_param exercises a lot of run_process cases, but it uses +# check=False, so lets have a test that uses check=True as well +async def test_run_process_background_fail(): + with pytest.raises(subprocess.CalledProcessError): + async with _core.open_nursery() as nursery: + proc = await nursery.start(run_process, EXIT_FALSE) + assert proc.returncode == 1 + + +@pytest.mark.skipif( + not SyncPath("/dev/fd").exists(), + reason="requires a way to iterate through open files", +) +async def test_for_leaking_fds(): + starting_fds = set(SyncPath("/dev/fd").iterdir()) + await run_process(EXIT_TRUE) + assert set(SyncPath("/dev/fd").iterdir()) == starting_fds + + with pytest.raises(subprocess.CalledProcessError): + await run_process(EXIT_FALSE) 
+ assert set(SyncPath("/dev/fd").iterdir()) == starting_fds + + with pytest.raises(PermissionError): + await run_process(["/dev/fd/0"]) + assert set(SyncPath("/dev/fd").iterdir()) == starting_fds + + +# regression test for #2209 +async def test_subprocess_pidfd_unnotified(): + noticed_exit = None + + async def wait_and_tell(proc) -> None: + nonlocal noticed_exit + noticed_exit = Event() + await proc.wait() + noticed_exit.set() + + proc = await open_process(SLEEP(9999)) + async with _core.open_nursery() as nursery: + nursery.start_soon(wait_and_tell, proc) + await wait_all_tasks_blocked() + assert isinstance(noticed_exit, Event) + proc.terminate() + # without giving trio a chance to do so, + with assert_no_checkpoints(): + # wait until the process has actually exited; + proc._proc.wait() + # force a call to poll (that closes the pidfd on linux) + proc.poll() + with move_on_after(5): + # Some platforms use threads to wait for exit, so it might take a bit + # for everything to notice + await noticed_exit.wait() + assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK" diff --git a/trio/tests/test_sync.py b/trio/_tests/test_sync.py similarity index 96% rename from trio/tests/test_sync.py rename to trio/_tests/test_sync.py index ab5b26d06e..7de42b86f9 100644 --- a/trio/tests/test_sync.py +++ b/trio/_tests/test_sync.py @@ -1,13 +1,11 @@ -import pytest - import weakref -from ..testing import wait_all_tasks_blocked, assert_checkpoints +import pytest from .. import _core -from .. 
import _timeouts -from .._timeouts import sleep_forever, move_on_after from .._sync import * +from .._timeouts import sleep_forever +from ..testing import assert_checkpoints, wait_all_tasks_blocked async def test_Event(): @@ -40,16 +38,6 @@ async def child(): assert record == ["sleeping", "sleeping", "woken", "woken"] -# When we remove clear() then this test can be removed too -def test_Event_clear(recwarn): - e = Event() - assert not e.is_set() - e.set() - assert e.is_set() - e.clear() - assert not e.is_set() - - async def test_CapacityLimiter(): with pytest.raises(TypeError): CapacityLimiter(1.0) @@ -121,6 +109,7 @@ async def test_CapacityLimiter(): async def test_CapacityLimiter_inf(): from math import inf + c = CapacityLimiter(inf) repr(c) # smoke test assert c.total_tokens == inf @@ -250,9 +239,7 @@ async def test_Semaphore_bounded(): assert bs.value == 1 -@pytest.mark.parametrize( - "lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__ -) +@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__) async def test_Lock_and_StrictFIFOLock(lockcls): l = lockcls() # noqa assert not l.locked() @@ -412,15 +399,14 @@ async def waiter(i): assert c.locked() -from .._sync import async_cm from .._channel import open_memory_channel +from .._sync import AsyncContextManagerMixin # Three ways of implementing a Lock in terms of a channel. Used to let us put # the channel through the generic lock tests. 
-@async_cm -class ChannelLock1: +class ChannelLock1(AsyncContextManagerMixin): def __init__(self, capacity): self.s, self.r = open_memory_channel(capacity) for _ in range(capacity - 1): @@ -436,8 +422,7 @@ def release(self): self.r.receive_nowait() -@async_cm -class ChannelLock2: +class ChannelLock2(AsyncContextManagerMixin): def __init__(self): self.s, self.r = open_memory_channel(10) self.s.send_nowait(None) @@ -452,8 +437,7 @@ def release(self): self.s.send_nowait(None) -@async_cm -class ChannelLock3: +class ChannelLock3(AsyncContextManagerMixin): def __init__(self): self.s, self.r = open_memory_channel(0) # self.acquired is true when one task acquires the lock and @@ -558,7 +542,7 @@ async def loopy(name, lock_like): # The first three could be in any order due to scheduling randomness, # but after that they should repeat in the same order for i in range(LOOPS): - assert record[3 * i:3 * (i + 1)] == initial_order + assert record[3 * i : 3 * (i + 1)] == initial_order @generic_lock_test diff --git a/trio/tests/test_testing.py b/trio/_tests/test_testing.py similarity index 74% rename from trio/tests/test_testing.py rename to trio/_tests/test_testing.py index e73624a67c..3b5a57d3ec 100644 --- a/trio/tests/test_testing.py +++ b/trio/_tests/test_testing.py @@ -1,20 +1,16 @@ # XX this should get broken up, like testing.py did -import time -from math import inf import tempfile import pytest -from .._core.tests.tutil import can_bind_ipv6 -from .. import sleep -from .. import _core +from .. import _core, sleep, socket as tsocket +from .._core._tests.tutil import can_bind_ipv6 from .._highlevel_generic import aclose_forcefully +from .._highlevel_socket import SocketListener from ..testing import * from ..testing._check_streams import _assert_raises from ..testing._memory_streams import _UnboundedByteQueue -from .. 
import socket as tsocket -from .._highlevel_socket import SocketListener async def test_wait_all_tasks_blocked(): @@ -105,32 +101,6 @@ async def wait_big_cushion(): ] -async def test_wait_all_tasks_blocked_with_tiebreaker(): - record = [] - - async def do_wait(cushion, tiebreaker): - await wait_all_tasks_blocked(cushion=cushion, tiebreaker=tiebreaker) - record.append((cushion, tiebreaker)) - - async with _core.open_nursery() as nursery: - nursery.start_soon(do_wait, 0, 0) - nursery.start_soon(do_wait, 0, -1) - nursery.start_soon(do_wait, 0, 1) - nursery.start_soon(do_wait, 0, -1) - nursery.start_soon(do_wait, 0.0001, 10) - nursery.start_soon(do_wait, 0.0001, -10) - - assert record == sorted(record) - assert record == [ - (0, -1), - (0, -1), - (0, 0), - (0, 1), - (0.0001, -10), - (0.0001, 10), - ] - - ################################################################ @@ -146,7 +116,8 @@ async def test_assert_checkpoints(recwarn): # if you have a schedule point but not a cancel point, or vice-versa, then # that's not a checkpoint. for partial_yield in [ - _core.checkpoint_if_cancelled, _core.cancel_shielded_checkpoint + _core.checkpoint_if_cancelled, + _core.cancel_shielded_checkpoint, ]: print(partial_yield) with pytest.raises(AssertionError): @@ -171,7 +142,8 @@ async def test_assert_no_checkpoints(recwarn): # if you have a schedule point but not a cancel point, or vice-versa, then # that doesn't make *either* version of assert_{no_,}yields happy. 
for partial_yield in [ - _core.checkpoint_if_cancelled, _core.cancel_shielded_checkpoint + _core.checkpoint_if_cancelled, + _core.cancel_shielded_checkpoint, ]: print(partial_yield) with pytest.raises(AssertionError): @@ -215,9 +187,7 @@ async def f2(seq): nursery.start_soon(f2, seq) async with seq(5): await wait_all_tasks_blocked() - assert record == [ - ("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4) - ] + assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)] seq = Sequencer() # Catches us if we try to re-use a sequence point: @@ -241,7 +211,7 @@ async def child(i): async with seq(i): pass # pragma: no cover except RuntimeError: - record.append("seq({}) RuntimeError".format(i)) + record.append(f"seq({i}) RuntimeError") async with _core.open_nursery() as nursery: nursery.start_soon(child, 1) @@ -260,160 +230,6 @@ async def child(i): ################################################################ -def test_mock_clock(): - REAL_NOW = 123.0 - c = MockClock() - c._real_clock = lambda: REAL_NOW - repr(c) # smoke test - assert c.rate == 0 - assert c.current_time() == 0 - c.jump(1.2) - assert c.current_time() == 1.2 - with pytest.raises(ValueError): - c.jump(-1) - assert c.current_time() == 1.2 - assert c.deadline_to_sleep_time(1.1) == 0 - assert c.deadline_to_sleep_time(1.2) == 0 - assert c.deadline_to_sleep_time(1.3) > 999999 - - with pytest.raises(ValueError): - c.rate = -1 - assert c.rate == 0 - - c.rate = 2 - assert c.current_time() == 1.2 - REAL_NOW += 1 - assert c.current_time() == 3.2 - assert c.deadline_to_sleep_time(3.1) == 0 - assert c.deadline_to_sleep_time(3.2) == 0 - assert c.deadline_to_sleep_time(4.2) == 0.5 - - c.rate = 0.5 - assert c.current_time() == 3.2 - assert c.deadline_to_sleep_time(3.1) == 0 - assert c.deadline_to_sleep_time(3.2) == 0 - assert c.deadline_to_sleep_time(4.2) == 2.0 - - c.jump(0.8) - assert c.current_time() == 4.0 - REAL_NOW += 1 - assert c.current_time() == 4.5 - - c2 = MockClock(rate=3) - assert c2.rate 
== 3 - assert c2.current_time() < 10 - - -async def test_mock_clock_autojump(mock_clock): - assert mock_clock.autojump_threshold == inf - - mock_clock.autojump_threshold = 0 - assert mock_clock.autojump_threshold == 0 - - real_start = time.perf_counter() - - virtual_start = _core.current_time() - for i in range(10): - print("sleeping {} seconds".format(10 * i)) - await sleep(10 * i) - print("woke up!") - assert virtual_start + 10 * i == _core.current_time() - virtual_start = _core.current_time() - - real_duration = time.perf_counter() - real_start - print( - "Slept {} seconds in {} seconds".format( - 10 * sum(range(10)), real_duration - ) - ) - assert real_duration < 1 - - mock_clock.autojump_threshold = 0.02 - t = _core.current_time() - # this should wake up before the autojump threshold triggers, so time - # shouldn't change - await wait_all_tasks_blocked() - assert t == _core.current_time() - # this should too - await wait_all_tasks_blocked(0.01) - assert t == _core.current_time() - - # This should wake up at the same time as the autojump_threshold, and - # confuse things. There is no deadline, so it shouldn't actually jump - # the clock. But does it handle the situation gracefully? 
- await wait_all_tasks_blocked(cushion=0.02, tiebreaker=float("inf")) - # And again with threshold=0, because that has some special - # busy-wait-avoidance logic: - mock_clock.autojump_threshold = 0 - await wait_all_tasks_blocked(tiebreaker=float("inf")) - - # set up a situation where the autojump task is blocked for a long long - # time, to make sure that cancel-and-adjust-threshold logic is working - mock_clock.autojump_threshold = 10000 - await wait_all_tasks_blocked() - mock_clock.autojump_threshold = 0 - # if the above line didn't take affect immediately, then this would be - # bad: - await sleep(100000) - - -async def test_mock_clock_autojump_interference(mock_clock): - mock_clock.autojump_threshold = 0.02 - - mock_clock2 = MockClock() - # messing with the autojump threshold of a clock that isn't actually - # installed in the run loop shouldn't do anything. - mock_clock2.autojump_threshold = 0.01 - - # if the autojump_threshold of 0.01 were in effect, then the next line - # would block forever, as the autojump task kept waking up to try to - # jump the clock. - await wait_all_tasks_blocked(0.015) - - # but the 0.02 limit does apply - await sleep(100000) - - -def test_mock_clock_autojump_preset(): - # Check that we can set the autojump_threshold before the clock is - # actually in use, and it gets picked up - mock_clock = MockClock(autojump_threshold=0.1) - mock_clock.autojump_threshold = 0.01 - real_start = time.perf_counter() - _core.run(sleep, 10000, clock=mock_clock) - assert time.perf_counter() - real_start < 1 - - -async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked(mock_clock): - # Checks that autojump_threshold=0 doesn't interfere with - # calling wait_all_tasks_blocked with the default cushion=0 and arbitrary - # tiebreakers. 
- - mock_clock.autojump_threshold = 0 - - record = [] - - async def sleeper(): - await sleep(100) - record.append("yawn") - - async def waiter(): - for i in range(10): - await wait_all_tasks_blocked(tiebreaker=i) - record.append(i) - await sleep(1000) - record.append("waiter done") - - async with _core.open_nursery() as nursery: - nursery.start_soon(sleeper) - nursery.start_soon(waiter) - - assert record == list(range(10)) + ["yawn", "waiter done"] - - -################################################################ - - async def test__assert_raises(): with pytest.raises(AssertionError): with _assert_raises(RuntimeError): @@ -571,9 +387,7 @@ def close_hook(): record.append("close_hook") mss2 = MemorySendStream( - send_all_hook, - wait_send_all_might_not_block_hook, - close_hook, + send_all_hook, wait_send_all_might_not_block_hook, close_hook ) assert mss2.send_all_hook is send_all_hook @@ -835,7 +649,7 @@ async def check(listener): # can't use pytest's tmpdir; if we try then macOS says "OSError: # AF_UNIX path too long" with tempfile.TemporaryDirectory() as tmpdir: - path = "{}/sock".format(tmpdir) + path = f"{tmpdir}/sock" await sock.bind(path) sock.listen(10) await check(SocketListener(sock)) diff --git a/trio/tests/test_threads.py b/trio/_tests/test_threads.py similarity index 55% rename from trio/tests/test_threads.py rename to trio/_tests/test_threads.py index 29d44adc4a..21eb7b12e8 100644 --- a/trio/tests/test_threads.py +++ b/trio/_tests/test_threads.py @@ -1,19 +1,28 @@ -import threading +import contextvars import queue as stdlib_queue +import re +import sys +import threading import time +import weakref +from functools import partial +from typing import Callable, Optional import pytest +from sniffio import current_async_library_cvar -from .. import _core -from .. import Event, CapacityLimiter, sleep -from ..testing import wait_all_tasks_blocked +from trio._core import TrioToken, current_trio_token + +from .. 
import CapacityLimiter, Event, _core, sleep +from .._core._tests.test_ki import ki_self +from .._core._tests.tutil import buggy_pypy_asyncgens from .._threads import ( - to_thread_run_sync, current_default_thread_limiter, from_thread_run, - from_thread_run_sync, BlockingTrioPortal + current_default_thread_limiter, + from_thread_run, + from_thread_run_sync, + to_thread_run_sync, ) - -from .._core.tests.test_ki import ki_self -from .._core.tests.tutil import slow +from ..testing import wait_all_tasks_blocked async def test_do_in_trio_thread(): @@ -36,9 +45,7 @@ def threadfn(): while child_thread.is_alive(): print("yawn") await sleep(0.01) - assert record == [ - ("start", child_thread), ("f", trio_thread), expected - ] + assert record == [("start", child_thread), ("f", trio_thread), expected] token = _core.current_trio_token() @@ -54,9 +61,7 @@ def f(record): record.append(("f", threading.current_thread())) raise ValueError - await check_case( - from_thread_run_sync, f, ("error", ValueError), trio_token=token - ) + await check_case(from_thread_run_sync, f, ("error", ValueError), trio_token=token) async def f(record): assert not _core.currently_ki_protected() @@ -102,6 +107,7 @@ def trio_thread_fn(): ki_self() finally: import sys + print("finally", sys.exc_info()) async def trio_thread_afn(): @@ -161,6 +167,117 @@ async def main(): assert record == ["sleeping", "cancelled"] +async def test_named_thread(): + ending = " from trio._tests.test_threads.test_named_thread" + + def inner(name="inner" + ending) -> threading.Thread: + assert threading.current_thread().name == name + return threading.current_thread() + + def f(name: str) -> Callable[[None], threading.Thread]: + return partial(inner, name) + + # test defaults + await to_thread_run_sync(inner) + await to_thread_run_sync(inner, thread_name=None) + + # functools.partial doesn't have __name__, so defaults to None + await to_thread_run_sync(f("None" + ending)) + + # test that you can set a custom name, and that it's 
reset afterwards + async def test_thread_name(name: str): + thread = await to_thread_run_sync(f(name), thread_name=name) + assert re.match("Trio thread [0-9]*", thread.name) + + await test_thread_name("") + await test_thread_name("fobiedoo") + await test_thread_name("name_longer_than_15_characters") + + await test_thread_name("💙") + + +def _get_thread_name(ident: Optional[int] = None) -> Optional[str]: + import ctypes + import ctypes.util + + libpthread_path = ctypes.util.find_library("pthread") + if not libpthread_path: + print(f"no pthread on {sys.platform})") + return None + libpthread = ctypes.CDLL(libpthread_path) + + pthread_getname_np = getattr(libpthread, "pthread_getname_np", None) + + # this should never fail on any platforms afaik + assert pthread_getname_np + + # thankfully getname signature doesn't differ between platforms + pthread_getname_np.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_size_t, + ] + pthread_getname_np.restype = ctypes.c_int + + name_buffer = ctypes.create_string_buffer(b"", size=16) + if ident is None: + ident = threading.get_ident() + assert pthread_getname_np(ident, name_buffer, 16) == 0 + try: + return name_buffer.value.decode() + except UnicodeDecodeError as e: # pragma: no cover + # used for debugging when testing via CI + pytest.fail(f"value: {name_buffer.value!r}, exception: {e}") + + +# test os thread naming +# this depends on pthread being available, which is the case on 99.9% of linux machines +# and most mac machines. So unless the platform is linux it will just skip +# in case it fails to fetch the os thread name. 
+async def test_named_thread_os(): + def inner(name) -> threading.Thread: + os_thread_name = _get_thread_name() + if os_thread_name is None and sys.platform != "linux": + pytest.skip(f"no pthread OS support on {sys.platform}") + else: + assert os_thread_name == name[:15] + + return threading.current_thread() + + def f(name: str) -> Callable[[None], threading.Thread]: + return partial(inner, name) + + # test defaults + default = "None from trio._tests.test_threads.test_named_thread" + await to_thread_run_sync(f(default)) + await to_thread_run_sync(f(default), thread_name=None) + + # test that you can set a custom name, and that it's reset afterwards + async def test_thread_name(name: str, expected: Optional[str] = None): + if expected is None: + expected = name + thread = await to_thread_run_sync(f(expected), thread_name=name) + + os_thread_name = _get_thread_name(thread.ident) + assert os_thread_name is not None, "should skip earlier if this is the case" + assert re.match("Trio thread [0-9]*", os_thread_name) + + await test_thread_name("") + await test_thread_name("fobiedoo") + await test_thread_name("name_longer_than_15_characters") + + await test_thread_name("💙", expected="?") + + +async def test_has_pthread_setname_np(): + from trio._core._thread_cache import get_os_thread_name_func + + k = get_os_thread_name_func() + if k is None: + assert sys.platform != "linux" + pytest.skip(f"no pthread_setname_np on {sys.platform}") + + async def test_run_in_worker_thread(): trio_thread = threading.current_thread() @@ -240,7 +357,9 @@ async def child(q, cancellable): # Make sure that if trio.run exits, and then the thread finishes, then that's # handled gracefully. (Requires that the thread result machinery be prepared # for call_soon to raise RunFinishedError.) 
-def test_run_in_worker_thread_abandoned(capfd): +def test_run_in_worker_thread_abandoned(capfd, monkeypatch): + monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) + q1 = stdlib_queue.Queue() q2 = stdlib_queue.Queue() @@ -269,7 +388,8 @@ async def child(): # Make sure we don't have a "Exception in thread ..." dump to the console: out, err = capfd.readouterr() - assert not out and not err + assert "Exception in thread" not in out + assert "Exception in thread" not in err @pytest.mark.parametrize("MAX", [3, 5, 10]) @@ -296,8 +416,8 @@ async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter): try: # We used to use regular variables and 'nonlocal' here, but it turns # out that it's not safe to assign to closed-over variables that are - # visible in multiple threads, at least as of CPython 3.6 and PyPy - # 5.8: + # visible in multiple threads, at least as of CPython 3.10 and PyPy + # 7.3: # # https://bugs.python.org/issue30744 # https://bitbucket.org/pypy/pypy/issues/2591/ @@ -333,15 +453,9 @@ def thread_fn(cancel_scope): async def run_thread(event): with _core.CancelScope() as cancel_scope: await to_thread_run_sync( - thread_fn, - cancel_scope, - limiter=limiter_arg, - cancellable=cancel + thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel ) - print( - "run_thread finished, cancelled:", - cancel_scope.cancelled_caught - ) + print("run_thread finished, cancelled:", cancel_scope.cancelled_caught) event.set() async with _core.open_nursery() as nursery: @@ -433,10 +547,10 @@ def release_on_behalf_of(self, borrower): async def test_run_in_worker_thread_fail_to_spawn(monkeypatch): # Test the unlikely but possible case where trying to spawn a thread fails - def bad_start(self): + def bad_start(self, *args): raise RuntimeError("the engines canna take it captain") - monkeypatch.setattr(threading.Thread, "start", bad_start) + monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start) limiter = 
current_default_thread_limiter() assert limiter.borrowed_tokens == 0 @@ -461,6 +575,62 @@ def thread_fn(): assert callee_token == caller_token +async def test_trio_to_thread_run_sync_expected_error(): + # Test correct error when passed async function + async def async_fn(): # pragma: no cover + pass + + with pytest.raises(TypeError, match="expected a sync function"): + await to_thread_run_sync(async_fn) + + +trio_test_contextvar = contextvars.ContextVar("trio_test_contextvar") + + +async def test_trio_to_thread_run_sync_contextvars(): + trio_thread = threading.current_thread() + trio_test_contextvar.set("main") + + def f(): + value = trio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + return (value, sniffio_cvar_value, threading.current_thread()) + + value, sniffio_cvar_value, child_thread = await to_thread_run_sync(f) + assert value == "main" + assert sniffio_cvar_value == None + assert child_thread != trio_thread + + def g(): + parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("worker") + inner_value = trio_test_contextvar.get() + sniffio_cvar_value = current_async_library_cvar.get() + return ( + parent_value, + inner_value, + sniffio_cvar_value, + threading.current_thread(), + ) + + ( + parent_value, + inner_value, + sniffio_cvar_value, + child_thread, + ) = await to_thread_run_sync(g) + current_value = trio_test_contextvar.get() + sniffio_outer_value = current_async_library_cvar.get() + assert parent_value == "main" + assert inner_value == "worker" + assert current_value == "main", ( + "The contextvar value set on the worker would not propagate back to the main" + " thread" + ) + assert sniffio_cvar_value is None + assert sniffio_outer_value == "trio" + + async def test_trio_from_thread_run_sync(): # Test that to_thread_run_sync correctly "hands off" the trio token to # trio.from_thread.run_sync() @@ -471,6 +641,16 @@ def thread_fn(): trio_time = await to_thread_run_sync(thread_fn) assert isinstance(trio_time, 
float) + # Test correct error when passed async function + async def async_fn(): # pragma: no cover + pass + + def thread_fn(): + from_thread_run_sync(async_fn) + + with pytest.raises(TypeError, match="expected a sync function"): + await to_thread_run_sync(thread_fn) + async def test_trio_from_thread_run(): # Test that to_thread_run_sync correctly "hands off" the trio token to @@ -488,6 +668,13 @@ def thread_fn(): await to_thread_run_sync(thread_fn) assert record == ["in thread", "back in trio"] + # Test correct error when passed sync function + def sync_fn(): # pragma: no cover + pass + + with pytest.raises(TypeError, match="appears to be synchronous"): + await to_thread_run_sync(from_thread_run, sync_fn) + async def test_trio_from_thread_token(): # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() @@ -505,9 +692,7 @@ async def test_trio_from_thread_token_kwarg(): # Test that to_thread_run_sync and spawned trio.from_thread.run_sync() can # use an explicitly defined token def thread_fn(token): - callee_token = from_thread_run_sync( - _core.current_trio_token, trio_token=token - ) + callee_token = from_thread_run_sync(_core.current_trio_token, trio_token=token) return callee_token caller_token = _core.current_trio_token() @@ -523,46 +708,159 @@ async def test_from_thread_no_token(): from_thread_run_sync(_core.current_time) -def test_run_fn_as_system_task_catched_badly_typed_token(): - with pytest.raises(RuntimeError): - from_thread_run_sync( - _core.current_time, trio_token="Not TrioTokentype" +async def test_trio_from_thread_run_sync_contextvars(): + trio_test_contextvar.set("main") + + def thread_fn(): + thread_parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("worker") + thread_current_value = trio_test_contextvar.get() + sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + + def back_in_main(): + back_parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("back_in_main") + back_current_value = 
trio_test_contextvar.get() + sniffio_cvar_back_value = current_async_library_cvar.get() + return back_parent_value, back_current_value, sniffio_cvar_back_value + + ( + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = from_thread_run_sync(back_in_main) + thread_after_value = trio_test_contextvar.get() + sniffio_cvar_thread_after_value = current_async_library_cvar.get() + return ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, ) + ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = await to_thread_run_sync(thread_fn) + current_value = trio_test_contextvar.get() + sniffio_cvar_out_value = current_async_library_cvar.get() + assert current_value == thread_parent_value == "main" + assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert back_current_value == "back_in_main" + assert sniffio_cvar_out_value == sniffio_cvar_back_value == "trio" + assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None + + +async def test_trio_from_thread_run_contextvars(): + trio_test_contextvar.set("main") + + def thread_fn(): + thread_parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("worker") + thread_current_value = trio_test_contextvar.get() + sniffio_cvar_thread_pre_value = current_async_library_cvar.get() + + async def async_back_in_main(): + back_parent_value = trio_test_contextvar.get() + trio_test_contextvar.set("back_in_main") + back_current_value = trio_test_contextvar.get() + sniffio_cvar_back_value = current_async_library_cvar.get() + return back_parent_value, back_current_value, sniffio_cvar_back_value + + ( + back_parent_value, + back_current_value, + 
sniffio_cvar_back_value, + ) = from_thread_run(async_back_in_main) + thread_after_value = trio_test_contextvar.get() + sniffio_cvar_thread_after_value = current_async_library_cvar.get() + return ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) -async def test_do_in_trio_thread_from_trio_thread_legacy(): - # This check specifically confirms that a RuntimeError will be raised if - # the old BlockingTrIoPortal API calls into a trio loop while already - # running inside of one. - portal = BlockingTrioPortal() + ( + thread_parent_value, + thread_current_value, + thread_after_value, + sniffio_cvar_thread_pre_value, + sniffio_cvar_thread_after_value, + back_parent_value, + back_current_value, + sniffio_cvar_back_value, + ) = await to_thread_run_sync(thread_fn) + current_value = trio_test_contextvar.get() + assert current_value == thread_parent_value == "main" + assert thread_current_value == back_parent_value == thread_after_value == "worker" + assert back_current_value == "back_in_main" + assert sniffio_cvar_thread_pre_value == sniffio_cvar_thread_after_value == None + assert sniffio_cvar_back_value == "trio" + +def test_run_fn_as_system_task_catched_badly_typed_token(): with pytest.raises(RuntimeError): - portal.run_sync(lambda: None) # pragma: no branch + from_thread_run_sync(_core.current_time, trio_token="Not TrioTokentype") - async def foo(): # pragma: no cover - pass +async def test_from_thread_inside_trio_thread(): + def not_called(): # pragma: no cover + assert False + + trio_token = _core.current_trio_token() with pytest.raises(RuntimeError): - portal.run(foo) + from_thread_run_sync(not_called, trio_token=trio_token) -async def test_BlockingTrioPortal_with_explicit_TrioToken(): - # This tests the deprecated BlockingTrioPortal with a token passed in to - # confirm that both methods of making a portal 
are supported by - # trio.from_thread - token = _core.current_trio_token() +@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") +def test_from_thread_run_during_shutdown(): + save = [] + record = [] + + async def agen(): + try: + yield + finally: + with pytest.raises(_core.RunFinishedError), _core.CancelScope(shield=True): + await to_thread_run_sync(from_thread_run, sleep, 0) + record.append("ok") + + async def main(): + save.append(agen()) + await save[-1].asend(None) + + _core.run(main) + assert record == ["ok"] + - def worker_thread(token): - with pytest.raises(RuntimeError): - BlockingTrioPortal() - portal = BlockingTrioPortal(token) - return portal.run_sync(threading.current_thread) +async def test_trio_token_weak_referenceable(): + token = current_trio_token() + assert isinstance(token, TrioToken) + weak_reference = weakref.ref(token) + assert token is weak_reference() - t = await to_thread_run_sync(worker_thread, token) - assert t == threading.current_thread() +async def test_unsafe_cancellable_kwarg(): + # This is a stand in for a numpy ndarray or other objects + # that (maybe surprisingly) lack a notion of truthiness + class BadBool: + def __bool__(self): + raise NotImplementedError -def test_BlockingTrioPortal_deprecated_export(recwarn): - import trio - btp = trio.BlockingTrioPortal - assert btp is BlockingTrioPortal + with pytest.raises(NotImplementedError): + await to_thread_run_sync(int, cancellable=BadBool()) diff --git a/trio/tests/test_timeouts.py b/trio/_tests/test_timeouts.py similarity index 81% rename from trio/tests/test_timeouts.py rename to trio/_tests/test_timeouts.py index 382c015b1d..9507d88a78 100644 --- a/trio/tests/test_timeouts.py +++ b/trio/_tests/test_timeouts.py @@ -1,11 +1,12 @@ +import time + import outcome import pytest -import time -from .._core.tests.tutil import slow from .. 
import _core -from ..testing import assert_checkpoints +from .._core._tests.tutil import slow from .._timeouts import * +from ..testing import assert_checkpoints async def check_takes_about(f, expected_dur): @@ -53,9 +54,6 @@ async def sleep_2(): await check_takes_about(sleep_2, TARGET) - with pytest.raises(ValueError): - await sleep(-1) - with assert_checkpoints(): await sleep(0) # This also serves as a test of the trivial move_on_at @@ -66,10 +64,6 @@ async def sleep_2(): @slow async def test_move_on_after(): - with pytest.raises(ValueError): - with move_on_after(-1): - pass # pragma: no cover - async def sleep_3(): with move_on_after(TARGET): await sleep(100) @@ -99,6 +93,29 @@ async def sleep_5(): with fail_after(100): await sleep(0) - with pytest.raises(ValueError): - with fail_after(-1): - pass # pragma: no cover + +async def test_timeouts_raise_value_error(): + # deadlines are allowed to be negative, but not delays. + # neither delays nor deadlines are allowed to be NaN + + nan = float("nan") + + for fun, val in ( + (sleep, -1), + (sleep, nan), + (sleep_until, nan), + ): + with pytest.raises(ValueError): + await fun(val) + + for cm, val in ( + (fail_after, -1), + (fail_after, nan), + (fail_at, nan), + (move_on_after, -1), + (move_on_after, nan), + (move_on_at, nan), + ): + with pytest.raises(ValueError): + with cm(val): + pass # pragma: no cover diff --git a/trio/_tests/test_tracing.py b/trio/_tests/test_tracing.py new file mode 100644 index 0000000000..07d1ff7609 --- /dev/null +++ b/trio/_tests/test_tracing.py @@ -0,0 +1,59 @@ +import trio + + +async def coro1(event: trio.Event): + event.set() + await trio.sleep_forever() + + +async def coro2(event: trio.Event): + await coro1(event) + + +async def coro3(event: trio.Event): + await coro2(event) + + +async def coro2_async_gen(event: trio.Event): + yield await trio.lowlevel.checkpoint() + yield await coro1(event) + yield await trio.lowlevel.checkpoint() + + +async def coro3_async_gen(event: trio.Event): + 
async for x in coro2_async_gen(event): + pass + + +async def test_task_iter_await_frames(): + async with trio.open_nursery() as nursery: + event = trio.Event() + nursery.start_soon(coro3, event) + await event.wait() + + (task,) = nursery.child_tasks + + assert [frame.f_code.co_name for frame, _ in task.iter_await_frames()][:3] == [ + "coro3", + "coro2", + "coro1", + ] + + nursery.cancel_scope.cancel() + + +async def test_task_iter_await_frames_async_gen(): + async with trio.open_nursery() as nursery: + event = trio.Event() + nursery.start_soon(coro3_async_gen, event) + await event.wait() + + (task,) = nursery.child_tasks + + assert [frame.f_code.co_name for frame, _ in task.iter_await_frames()][:3] == [ + "coro3_async_gen", + "coro2_async_gen", + "coro1", + ] + + nursery.cancel_scope.cancel() diff --git a/trio/tests/test_unix_pipes.py b/trio/_tests/test_unix_pipes.py similarity index 82% rename from trio/tests/test_unix_pipes.py rename to trio/_tests/test_unix_pipes.py index 9349139e23..acee75aafb 100644 --- a/trio/tests/test_unix_pipes.py +++ b/trio/_tests/test_unix_pipes.py @@ -1,13 +1,13 @@ import errno -import select import os -import tempfile +import select +import sys import pytest -from .._core.tests.tutil import gc_collect_harder -from .. import _core, move_on_after -from ..testing import wait_all_tasks_blocked, check_one_way_stream +from .. 
import _core +from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken +from ..testing import check_one_way_stream, wait_all_tasks_blocked posix = os.name == "posix" pytestmark = pytest.mark.skipif(not posix, reason="posix only") @@ -194,9 +194,7 @@ async def patched_wait_readable(*args, **kwargs): await orig_wait_readable(*args, **kwargs) await r.aclose() - monkeypatch.setattr( - _core._run.TheIOManager, "wait_readable", patched_wait_readable - ) + monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable) s, r = await make_pipe() async with s, r: async with _core.open_nursery() as nursery: @@ -224,18 +222,35 @@ async def patched_wait_writable(*args, **kwargs): await orig_wait_writable(*args, **kwargs) await s.aclose() - monkeypatch.setattr( - _core._run.TheIOManager, "wait_writable", patched_wait_writable - ) + monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable) s, r = await make_clogged_pipe() async with s, r: async with _core.open_nursery() as nursery: nursery.start_soon(expect_closedresourceerror) await wait_all_tasks_blocked() - # Trigger everything by waking up the sender - await r.receive_some(10000) - - + # Trigger everything by waking up the sender. On ppc64el, PIPE_BUF + # is 8192 but make_clogged_pipe() ends up writing a total of + # 1048576 bytes before the pipe is full, and then a subsequent + # receive_some(10000) isn't sufficient for orig_wait_writable() to + # return for our subsequent aclose() call. It's necessary to empty + # the pipe further before this happens. So we loop here until the + # pipe is empty to make sure that the sender wakes up even in this + # case. Otherwise patched_wait_writable() never gets to the + # aclose(), so expect_closedresourceerror() never returns, the + # nursery never finishes all tasks and this test hangs. 
+ received_data = await r.receive_some(10000) + while received_data: + received_data = await r.receive_some(10000) + + +# On FreeBSD, directories are readable, and we haven't found any other trick +# for making an unreadable fd, so there's no way to run this test. Fortunately +# the logic this is testing doesn't depend on the platform, so testing on +# other platforms is probably good enough. +@pytest.mark.skipif( + sys.platform.startswith("freebsd"), + reason="no way to make read() return a bizarro error on FreeBSD", +) async def test_bizarro_OSError_from_receive(): # Make sure that if the read syscall returns some bizarro error, then we # get a BrokenResourceError. This is incredibly unlikely; there's almost @@ -255,5 +270,6 @@ async def test_bizarro_OSError_from_receive(): os.close(dir_fd) +@skip_if_fbsd_pipes_broken async def test_pipe_fully(): await check_one_way_stream(make_pipe, make_clogged_pipe) diff --git a/trio/_tests/test_util.py b/trio/_tests/test_util.py new file mode 100644 index 0000000000..a4df6d35b4 --- /dev/null +++ b/trio/_tests/test_util.py @@ -0,0 +1,194 @@ +import signal +import sys + +import pytest + +import trio + +from .. 
import _core +from .._core._tests.tutil import ( + create_asyncio_future_in_new_loop, + ignore_coroutine_never_awaited_warnings, +) +from .._util import ( + ConflictDetector, + Final, + NoPublicConstructor, + coroutine_or_error, + generic_function, + is_main_thread, + signal_raise, +) +from ..testing import wait_all_tasks_blocked + + +def test_signal_raise(): + record = [] + + def handler(signum, _): + record.append(signum) + + old = signal.signal(signal.SIGFPE, handler) + try: + signal_raise(signal.SIGFPE) + finally: + signal.signal(signal.SIGFPE, old) + assert record == [signal.SIGFPE] + + +async def test_ConflictDetector(): + ul1 = ConflictDetector("ul1") + ul2 = ConflictDetector("ul2") + + with ul1: + with ul2: + print("ok") + + with pytest.raises(_core.BusyResourceError) as excinfo: + with ul1: + with ul1: + pass # pragma: no cover + assert "ul1" in str(excinfo.value) + + async def wait_with_ul1(): + with ul1: + await wait_all_tasks_blocked() + + with pytest.raises(_core.BusyResourceError) as excinfo: + async with _core.open_nursery() as nursery: + nursery.start_soon(wait_with_ul1) + nursery.start_soon(wait_with_ul1) + assert "ul1" in str(excinfo.value) + + +def test_module_metadata_is_fixed_up(): + import trio + import trio.testing + + assert trio.Cancelled.__module__ == "trio" + assert trio.open_nursery.__module__ == "trio" + assert trio.abc.Stream.__module__ == "trio.abc" + assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel" + assert trio.testing.trio_test.__module__ == "trio.testing" + + # Also check methods + assert trio.lowlevel.ParkingLot.__init__.__module__ == "trio.lowlevel" + assert trio.abc.Stream.send_all.__module__ == "trio.abc" + + # And names + assert trio.Cancelled.__name__ == "Cancelled" + assert trio.Cancelled.__qualname__ == "Cancelled" + assert trio.abc.SendStream.send_all.__name__ == "send_all" + assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all" + assert trio.to_thread.__name__ == 
"trio.to_thread" + assert trio.to_thread.run_sync.__name__ == "run_sync" + assert trio.to_thread.run_sync.__qualname__ == "run_sync" + + +async def test_is_main_thread(): + assert is_main_thread() + + def not_main_thread(): + assert not is_main_thread() + + await trio.to_thread.run_sync(not_main_thread) + + +# @coroutine is deprecated since python 3.8, which is fine with us. +@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning") +def test_coroutine_or_error(): + class Deferred: + "Just kidding" + + with ignore_coroutine_never_awaited_warnings(): + + async def f(): # pragma: no cover + pass + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(f()) + assert "expecting an async function" in str(excinfo.value) + + import asyncio + + if sys.version_info < (3, 11): + + @asyncio.coroutine + def generator_based_coro(): # pragma: no cover + yield from asyncio.sleep(1) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(generator_based_coro()) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(create_asyncio_future_in_new_loop()) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(create_asyncio_future_in_new_loop) + assert "asyncio" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(Deferred()) + assert "twisted" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(lambda: Deferred()) + assert "twisted" in str(excinfo.value) + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(len, [[1, 2, 3]]) + + assert "appears to be synchronous" in str(excinfo.value) + + async def async_gen(arg): # pragma: no cover + yield + + with pytest.raises(TypeError) as excinfo: + coroutine_or_error(async_gen, [0]) + msg = "expected an async function but got an async generator" + assert msg in str(excinfo.value) + + # Make sure no references are kept 
around to keep anything alive + del excinfo + + +def test_generic_function(): + @generic_function + def test_func(arg): + """Look, a docstring!""" + return arg + + assert test_func is test_func[int] is test_func[int, str] + assert test_func(42) == test_func[int](42) == 42 + assert test_func.__doc__ == "Look, a docstring!" + assert test_func.__qualname__ == "test_generic_function..test_func" + assert test_func.__name__ == "test_func" + assert test_func.__module__ == __name__ + + +def test_final_metaclass(): + class FinalClass(metaclass=Final): + pass + + with pytest.raises(TypeError): + + class SubClass(FinalClass): + pass + + +def test_no_public_constructor_metaclass(): + class SpecialClass(metaclass=NoPublicConstructor): + pass + + with pytest.raises(TypeError): + SpecialClass() + + with pytest.raises(TypeError): + + class SubClass(SpecialClass): + pass + + # Private constructor should not raise + assert isinstance(SpecialClass._create(), SpecialClass) diff --git a/trio/tests/test_wait_for_object.py b/trio/_tests/test_wait_for_object.py similarity index 85% rename from trio/tests/test_wait_for_object.py rename to trio/_tests/test_wait_for_object.py index 3c3830ea39..ea16684289 100644 --- a/trio/tests/test_wait_for_object.py +++ b/trio/_tests/test_wait_for_object.py @@ -2,17 +2,18 @@ import pytest -on_windows = (os.name == "nt") +on_windows = os.name == "nt" # Mark all the tests in this file as being windows-only pytestmark = pytest.mark.skipif(not on_windows, reason="windows only") -from .._core.tests.tutil import slow import trio -from .. import _core -from .. import _timeouts + +from .. 
import _core, _timeouts +from .._core._tests.tutil import slow + if on_windows: from .._core._windows_cffi import ffi, kernel32 - from .._wait_for_object import WaitForSingleObject, WaitForMultipleObjects_sync + from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject async def test_WaitForMultipleObjects_sync(): @@ -29,7 +30,7 @@ async def test_WaitForMultipleObjects_sync(): kernel32.SetEvent(handle1) WaitForMultipleObjects_sync(handle1) kernel32.CloseHandle(handle1) - print('test_WaitForMultipleObjects_sync one OK') + print("test_WaitForMultipleObjects_sync one OK") # Two handles, signal first handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -38,7 +39,7 @@ async def test_WaitForMultipleObjects_sync(): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle1) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync set first OK') + print("test_WaitForMultipleObjects_sync set first OK") # Two handles, signal second handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -47,7 +48,7 @@ async def test_WaitForMultipleObjects_sync(): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle1) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync set second OK') + print("test_WaitForMultipleObjects_sync set second OK") # Two handles, close first handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -56,7 +57,7 @@ async def test_WaitForMultipleObjects_sync(): with pytest.raises(OSError): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync close first OK') + print("test_WaitForMultipleObjects_sync close first OK") # Two handles, close second handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -65,7 +66,7 @@ async def test_WaitForMultipleObjects_sync(): with pytest.raises(OSError): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle1) - 
print('test_WaitForMultipleObjects_sync close second OK') + print("test_WaitForMultipleObjects_sync close second OK") @slow @@ -89,7 +90,7 @@ async def test_WaitForMultipleObjects_sync_slow(): t1 = _core.current_time() assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT kernel32.CloseHandle(handle1) - print('test_WaitForMultipleObjects_sync_slow one OK') + print("test_WaitForMultipleObjects_sync_slow one OK") # Two handles, signal first handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -97,8 +98,7 @@ async def test_WaitForMultipleObjects_sync_slow(): t0 = _core.current_time() async with _core.open_nursery() as nursery: nursery.start_soon( - trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, - handle2 + trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2 ) await _timeouts.sleep(TIMEOUT) kernel32.SetEvent(handle1) @@ -106,7 +106,7 @@ async def test_WaitForMultipleObjects_sync_slow(): assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT kernel32.CloseHandle(handle1) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync_slow thread-set first OK') + print("test_WaitForMultipleObjects_sync_slow thread-set first OK") # Two handles, signal second handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -114,8 +114,7 @@ async def test_WaitForMultipleObjects_sync_slow(): t0 = _core.current_time() async with _core.open_nursery() as nursery: nursery.start_soon( - trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, - handle2 + trio.to_thread.run_sync, WaitForMultipleObjects_sync, handle1, handle2 ) await _timeouts.sleep(TIMEOUT) kernel32.SetEvent(handle2) @@ -123,7 +122,7 @@ async def test_WaitForMultipleObjects_sync_slow(): assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT kernel32.CloseHandle(handle1) kernel32.CloseHandle(handle2) - print('test_WaitForMultipleObjects_sync_slow thread-set second OK') + print("test_WaitForMultipleObjects_sync_slow thread-set second OK") async def test_WaitForSingleObject(): @@ 
-135,7 +134,7 @@ async def test_WaitForSingleObject(): kernel32.SetEvent(handle) await WaitForSingleObject(handle) # should return at once kernel32.CloseHandle(handle) - print('test_WaitForSingleObject already set OK') + print("test_WaitForSingleObject already set OK") # Test already set, as int handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) @@ -143,21 +142,21 @@ async def test_WaitForSingleObject(): kernel32.SetEvent(handle) await WaitForSingleObject(handle_int) # should return at once kernel32.CloseHandle(handle) - print('test_WaitForSingleObject already set OK') + print("test_WaitForSingleObject already set OK") # Test already closed handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) kernel32.CloseHandle(handle) with pytest.raises(OSError): await WaitForSingleObject(handle) # should return at once - print('test_WaitForSingleObject already closed OK') + print("test_WaitForSingleObject already closed OK") # Not a handle with pytest.raises(TypeError): await WaitForSingleObject("not a handle") # Wrong type # with pytest.raises(OSError): # await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :( - print('test_WaitForSingleObject not a handle OK') + print("test_WaitForSingleObject not a handle OK") @slow @@ -185,7 +184,7 @@ async def signal_soon_async(handle): kernel32.CloseHandle(handle) t1 = _core.current_time() assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT - print('test_WaitForSingleObject_slow set from task OK') + print("test_WaitForSingleObject_slow set from task OK") # Test handle is SET after TIMEOUT in separate coroutine, as int @@ -200,7 +199,7 @@ async def signal_soon_async(handle): kernel32.CloseHandle(handle) t1 = _core.current_time() assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT - print('test_WaitForSingleObject_slow set from task as int OK') + print("test_WaitForSingleObject_slow set from task as int OK") # Test handle is CLOSED after 1 sec - NOPE see comment above @@ -215,4 +214,4 @@ async def 
signal_soon_async(handle): kernel32.CloseHandle(handle) t1 = _core.current_time() assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT - print('test_WaitForSingleObject_slow cancellation OK') + print("test_WaitForSingleObject_slow cancellation OK") diff --git a/trio/tests/test_windows_pipes.py b/trio/_tests/test_windows_pipes.py similarity index 85% rename from trio/tests/test_windows_pipes.py rename to trio/_tests/test_windows_pipes.py index 864aaf768e..5c4bae7d25 100644 --- a/trio/tests/test_windows_pipes.py +++ b/trio/_tests/test_windows_pipes.py @@ -1,22 +1,24 @@ -import errno -import select +import sys +from typing import Any, Tuple -import os import pytest -from .._core.tests.tutil import gc_collect_harder -from .. import _core, move_on_after -from ..testing import wait_all_tasks_blocked, check_one_way_stream +from .. import _core +from ..testing import check_one_way_stream, wait_all_tasks_blocked -windows = os.name == "nt" -pytestmark = pytest.mark.skipif(not windows, reason="windows only") -if windows: - from .._windows_pipes import PipeSendStream, PipeReceiveStream - from .._core._windows_cffi import _handle, kernel32 +if sys.platform == "win32": from asyncio.windows_utils import pipe + from .._core._windows_cffi import _handle, kernel32 + from .._windows_pipes import PipeReceiveStream, PipeSendStream +else: + pytestmark = pytest.mark.skip(reason="windows only") + pipe: Any = None + PipeSendStream: Any = None + PipeReceiveStream: Any = None + -async def make_pipe() -> "Tuple[PipeSendStream, PipeReceiveStream]": +async def make_pipe() -> Tuple[PipeSendStream, PipeReceiveStream]: """Makes a new pair of pipes.""" (r, w) = pipe() return PipeSendStream(w), PipeReceiveStream(r) diff --git a/trio/tests/tools/__init__.py b/trio/_tests/tools/__init__.py similarity index 100% rename from trio/tests/tools/__init__.py rename to trio/_tests/tools/__init__.py diff --git a/trio/tests/tools/test_gen_exports.py b/trio/_tests/tools/test_gen_exports.py similarity index 86% rename 
from trio/tests/tools/test_gen_exports.py rename to trio/_tests/tools/test_gen_exports.py index 6c1fb0d668..9436105fa4 100644 --- a/trio/tests/tools/test_gen_exports.py +++ b/trio/_tests/tools/test_gen_exports.py @@ -1,13 +1,8 @@ import ast -import astor + import pytest -import os -import sys -from shutil import copyfile -from trio._tools.gen_exports import ( - get_public_methods, create_passthrough_args, process -) +from trio._tools.gen_exports import create_passthrough_args, get_public_methods, process SOURCE = '''from _run import _public @@ -43,11 +38,11 @@ def test_create_pass_through_args(): ("def f(one, *args)", "(one, *args)"), ( "def f(one, *args, kw1, kw2=None, **kwargs)", - "(one, *args, kw1=kw1, kw2=kw2, **kwargs)" + "(one, *args, kw1=kw1, kw2=kw2, **kwargs)", ), ] - for (funcdef, expected) in testcases: + for funcdef, expected in testcases: func_node = ast.parse(funcdef + ":\n pass").body[0] assert isinstance(func_node, ast.FunctionDef) assert create_passthrough_args(func_node) == expected diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json new file mode 100644 index 0000000000..60132e07fd --- /dev/null +++ b/trio/_tests/verify_types.json @@ -0,0 +1,167 @@ +{ + "generalDiagnostics": [], + "summary": { + "errorCount": 0, + "filesAnalyzed": 8, + "informationCount": 0, + "warningCount": 0 + }, + "typeCompleteness": { + "completenessScore": 0.9072, + "exportedSymbolCounts": { + "withAmbiguousType": 1, + "withKnownType": 567, + "withUnknownType": 57 + }, + "ignoreUnknownTypesFromImports": true, + "missingClassDocStringCount": 1, + "missingDefaultParamCount": 0, + "missingFunctionDocStringCount": 4, + "moduleName": "trio", + "modules": [ + { + "name": "trio" + }, + { + "name": "trio.abc" + }, + { + "name": "trio.from_thread" + }, + { + "name": "trio.lowlevel" + }, + { + "name": "trio.socket" + }, + { + "name": "trio.testing" + }, + { + "name": "trio.tests" + }, + { + "name": "trio.to_thread" + } + ], + "otherSymbolCounts": { + 
"withAmbiguousType": 3, + "withKnownType": 574, + "withUnknownType": 76 + }, + "packageName": "trio", + "symbols": [ + "trio.__deprecated_attributes__", + "trio._core._entry_queue.TrioToken.run_sync_soon", + "trio._core._mock_clock.MockClock.jump", + "trio._core._run.Nursery.start", + "trio._core._run.Nursery.start_soon", + "trio._core._run.TaskStatus.__repr__", + "trio._core._run.TaskStatus.started", + "trio._core._unbounded_queue.UnboundedQueue.__aiter__", + "trio._core._unbounded_queue.UnboundedQueue.__anext__", + "trio._core._unbounded_queue.UnboundedQueue.__repr__", + "trio._core._unbounded_queue.UnboundedQueue.empty", + "trio._core._unbounded_queue.UnboundedQueue.get_batch", + "trio._core._unbounded_queue.UnboundedQueue.get_batch_nowait", + "trio._core._unbounded_queue.UnboundedQueue.qsize", + "trio._core._unbounded_queue.UnboundedQueue.statistics", + "trio._dtls.DTLSChannel.__init__", + "trio._dtls.DTLSEndpoint.serve", + "trio._highlevel_socket.SocketStream.getsockopt", + "trio._highlevel_socket.SocketStream.send_all", + "trio._highlevel_socket.SocketStream.setsockopt", + "trio._ssl.SSLListener.__init__", + "trio._ssl.SSLListener.accept", + "trio._ssl.SSLListener.aclose", + "trio._ssl.SSLStream.__dir__", + "trio._ssl.SSLStream.__getattr__", + "trio._ssl.SSLStream.__init__", + "trio._ssl.SSLStream.__setattr__", + "trio._ssl.SSLStream.aclose", + "trio._ssl.SSLStream.do_handshake", + "trio._ssl.SSLStream.receive_some", + "trio._ssl.SSLStream.send_all", + "trio._ssl.SSLStream.transport_stream", + "trio._ssl.SSLStream.unwrap", + "trio._ssl.SSLStream.wait_send_all_might_not_block", + "trio._subprocess.Process.__aenter__", + "trio._subprocess.Process.__init__", + "trio._subprocess.Process.__repr__", + "trio._subprocess.Process.aclose", + "trio._subprocess.Process.args", + "trio._subprocess.Process.encoding", + "trio._subprocess.Process.errors", + "trio._subprocess.Process.kill", + "trio._subprocess.Process.pid", + "trio._subprocess.Process.poll", + 
"trio._subprocess.Process.returncode", + "trio._subprocess.Process.send_signal", + "trio._subprocess.Process.terminate", + "trio._subprocess.Process.wait", + "trio.current_time", + "trio.from_thread.run", + "trio.from_thread.run_sync", + "trio.lowlevel.cancel_shielded_checkpoint", + "trio.lowlevel.current_clock", + "trio.lowlevel.current_root_task", + "trio.lowlevel.current_statistics", + "trio.lowlevel.current_trio_token", + "trio.lowlevel.currently_ki_protected", + "trio.lowlevel.notify_closing", + "trio.lowlevel.open_process", + "trio.lowlevel.permanently_detach_coroutine_object", + "trio.lowlevel.reattach_detached_coroutine_object", + "trio.lowlevel.reschedule", + "trio.lowlevel.spawn_system_task", + "trio.lowlevel.start_guest_run", + "trio.lowlevel.start_thread_soon", + "trio.lowlevel.temporarily_detach_coroutine_object", + "trio.lowlevel.wait_readable", + "trio.lowlevel.wait_writable", + "trio.open_ssl_over_tcp_listeners", + "trio.open_ssl_over_tcp_stream", + "trio.open_tcp_listeners", + "trio.open_tcp_stream", + "trio.open_unix_socket", + "trio.run", + "trio.run_process", + "trio.serve_listeners", + "trio.serve_ssl_over_tcp", + "trio.serve_tcp", + "trio.testing._memory_streams.MemoryReceiveStream.__init__", + "trio.testing._memory_streams.MemoryReceiveStream.aclose", + "trio.testing._memory_streams.MemoryReceiveStream.close", + "trio.testing._memory_streams.MemoryReceiveStream.close_hook", + "trio.testing._memory_streams.MemoryReceiveStream.put_data", + "trio.testing._memory_streams.MemoryReceiveStream.put_eof", + "trio.testing._memory_streams.MemoryReceiveStream.receive_some", + "trio.testing._memory_streams.MemoryReceiveStream.receive_some_hook", + "trio.testing._memory_streams.MemorySendStream.__init__", + "trio.testing._memory_streams.MemorySendStream.aclose", + "trio.testing._memory_streams.MemorySendStream.close", + "trio.testing._memory_streams.MemorySendStream.close_hook", + "trio.testing._memory_streams.MemorySendStream.get_data", + 
"trio.testing._memory_streams.MemorySendStream.get_data_nowait", + "trio.testing._memory_streams.MemorySendStream.send_all", + "trio.testing._memory_streams.MemorySendStream.send_all_hook", + "trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block", + "trio.testing._memory_streams.MemorySendStream.wait_send_all_might_not_block_hook", + "trio.testing.assert_checkpoints", + "trio.testing.assert_no_checkpoints", + "trio.testing.check_half_closeable_stream", + "trio.testing.check_one_way_stream", + "trio.testing.check_two_way_stream", + "trio.testing.lockstep_stream_one_way_pair", + "trio.testing.lockstep_stream_pair", + "trio.testing.memory_stream_one_way_pair", + "trio.testing.memory_stream_pair", + "trio.testing.memory_stream_pump", + "trio.testing.open_stream_to_socket_listener", + "trio.testing.trio_test", + "trio.testing.wait_all_tasks_blocked", + "trio.tests.TestsDeprecationWrapper", + "trio.to_thread.current_default_thread_limiter" + ] + } +} diff --git a/trio/_threads.py b/trio/_threads.py index e5ffb74b7c..3fbab05750 100644 --- a/trio/_threads.py +++ b/trio/_threads.py @@ -1,133 +1,36 @@ -import threading +from __future__ import annotations + +import contextvars +import functools +import inspect import queue as stdlib_queue +import threading from itertools import count +from typing import Any, Callable, Optional, TypeVar import attr import outcome +from sniffio import current_async_library_cvar import trio - +from trio._core._traps import RaiseCancelT + +from ._core import ( + RunVar, + TrioToken, + disable_ki_protection, + enable_ki_protection, + start_thread_soon, +) from ._sync import CapacityLimiter -from ._core import enable_ki_protection, disable_ki_protection, RunVar, TrioToken +from ._util import coroutine_or_error + +T = TypeVar("T") # Global due to Threading API, thread local storage for trio token TOKEN_LOCAL = threading.local() - -class BlockingTrioPortal: - def __init__(self, trio_token=None): - if trio_token is None: - 
trio_token = trio.hazmat.current_trio_token() - self._trio_token = trio_token - - def run(self, afn, *args): - return from_thread_run(afn, *args, trio_token=self._trio_token) - - def run_sync(self, fn, *args): - return from_thread_run_sync(fn, *args, trio_token=self._trio_token) - - -################################################################ - -# XX at some point it probably makes sense to implement some sort of thread -# pool? Or at least that's what everyone says. -# -# There are two arguments for thread pools: -# - speed (re-using threads instead of starting new ones) -# - throttling (if you have 1000 tasks, queue them up instead of spawning 1000 -# threads and running out of memory) -# -# Regarding speed, it's not clear how much of an advantage this is. Some -# numbers on my Linux laptop: -# -# Spawning and then joining a thread: -# -# In [25]: %timeit t = threading.Thread(target=lambda: None); t.start(); t.join() -# 10000 loops, best of 3: 44 µs per loop -# -# Using a thread pool: -# -# In [26]: tpp = concurrent.futures.ThreadPoolExecutor() -# In [27]: %timeit tpp.submit(lambda: None).result() -# -# In [28]: %timeit tpp.submit(lambda: None).result() -# 10000 loops, best of 3: 40.8 µs per loop -# -# What's a fast getaddrinfo look like? -# -# # with hot DNS cache: -# In [23]: %timeit socket.getaddrinfo("google.com", "80") -# 10 loops, best of 3: 50.9 ms per loop -# -# In [29]: %timeit socket.getaddrinfo("127.0.0.1", "80") -# 100000 loops, best of 3: 9.73 µs per loop -# -# -# So... maybe we can beat concurrent.futures with a super-efficient thread -# pool or something, but there really is not a lot of headroom here. -# -# Of course other systems might be different... 
here's CPython 3.6 in a -# Virtualbox VM running Windows 10 on that same Linux laptop: -# -# In [13]: %timeit t = threading.Thread(target=lambda: None); t.start(); t.join() -# 10000 loops, best of 3: 127 µs per loop -# -# In [18]: %timeit tpp.submit(lambda: None).result() -# 10000 loops, best of 3: 31.9 µs per loop -# -# So on Windows there *might* be an advantage? You've gotta be doing a lot of -# connections, with very fast DNS indeed, for that 100 us to matter. But maybe -# someone is. -# -# -# Regarding throttling: this is very much a trade-off. On the one hand, you -# don't want to overwhelm the machine, obviously. On the other hand, queueing -# up work on a central thread-pool creates a central coordination point which -# can potentially create deadlocks and all kinds of fun things. This is very -# context dependent. For getaddrinfo, whatever, they'll make progress and -# complete (we hope), and you want to throttle them to some reasonable -# amount. For calling waitpid() (because just say no to SIGCHLD), then you -# really want one thread-per-waitpid(), because for all you know the user has -# written some ridiculous thing like: -# -# for p in processes: -# await spawn(p.wait) -# # Deadlock here if there are enough processes: -# await some_other_subprocess.wait() -# for p in processes: -# p.terminate() -# -# This goes doubly for the sort of wacky thread usage we see in curio.abide -# (though, I'm not sure if that's actually useful in practice in our context, -# run_in_trio_thread seems like it might be a nicer synchronization primitive -# for most uses than trying to make threading.Lock awaitable). -# -# See also this very relevant discussion: -# -# https://twistedmatrix.com/trac/ticket/5298 -# -# "Interacting with the products at Rackspace which use Twisted, I've seen -# problems caused by thread-pool maximum sizes with some annoying -# regularity. 
The basic problem is this: if you have a hard limit on the -# number of threads, *it is not possible to write a correct program which may -# require starting a new thread to un-block a blocked pool thread*" - glyph -# -# For now, if we want to throttle getaddrinfo I think the simplest thing is -# for the socket code to have a semaphore for getaddrinfo calls. -# -# Regarding the memory overhead of threads, in theory one should be able to -# reduce this a *lot* for a thread that's just calling getaddrinfo or -# (especially) waitpid. Windows and pthreads both offer the ability to set -# thread stack size on a thread-by-thread basis. Unfortunately as of 3.6 -# CPython doesn't expose this in a useful way (all you can do is set it -# globally for the whole process, so it's - ironically - not thread safe). -# -# (It's also unclear how much stack size actually matters; on a 64-bit Linux -# server with overcommit -- i.e., the most common configuration -- then AFAICT -# really the only real limit is on stack size actually *used*; how much you -# *allocate* should be pretty much irrelevant.) - -_limiter_local = RunVar("limiter") +_limiter_local: RunVar[CapacityLimiter] = RunVar("limiter") # I pulled this number out of the air; it isn't based on anything. Probably we # should make some kind of measurements to pick a good value. DEFAULT_LIMIT = 40 @@ -160,7 +63,13 @@ class ThreadPlaceholder: @enable_ki_protection -async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None): +async def to_thread_run_sync( + sync_fn: Callable[..., T], + *args: Any, + thread_name: Optional[str] = None, + cancellable: bool = False, + limiter: CapacityLimiter | None = None, +) -> T: """Convert a blocking operation into an async operation using a thread. These two lines are equivalent:: @@ -182,6 +91,12 @@ async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None): arguments, use :func:`functools.partial`. 
cancellable (bool): Whether to allow cancellation of this operation. See discussion below. + thread_name (str): Optional string to set the name of the thread. + Will always set `threading.Thread.name`, but only set the os name + if pthread.h is available (i.e. most POSIX installations). + pthread names are limited to 15 characters, and can be read from + ``/proc//task//comm`` or with ``ps -eT``, among others. + Defaults to ``{sync_fn.__name__|None} from {trio.lowlevel.current_task().name}``. limiter (None, or CapacityLimiter-like object): An object used to limit the number of simultaneous threads. Most commonly this will be a `~trio.CapacityLimiter`, but it could be @@ -238,16 +153,16 @@ async def to_thread_run_sync(sync_fn, *args, cancellable=False, limiter=None): Exception: Whatever ``sync_fn(*args)`` raises. """ - await trio.hazmat.checkpoint_if_cancelled() - token = trio.hazmat.current_trio_token() + await trio.lowlevel.checkpoint_if_cancelled() + cancellable = bool(cancellable) # raise early if cancellable.__bool__ raises if limiter is None: limiter = current_default_thread_limiter() # Holds a reference to the task that's blocked in this function waiting # for the result – or None if this function was cancelled and we should # discard the result. 
- task_register = [trio.hazmat.current_task()] - name = "trio-worker-{}".format(next(_thread_counter)) + task_register: list[trio.lowlevel.Task | None] = [trio.lowlevel.current_task()] + name = f"trio.to_thread.run_sync-{next(_thread_counter)}" placeholder = ThreadPlaceholder(name) # This function gets scheduled into the Trio run loop to deliver the @@ -265,52 +180,64 @@ def do_release_then_return_result(): result = outcome.capture(do_release_then_return_result) if task_register[0] is not None: - trio.hazmat.reschedule(task_register[0], result) - - # This is the function that runs in the worker thread to do the actual - # work and then schedule the call to report_back_in_trio_thread_fn - # Since this is spawned in a new thread, the trio token needs to be passed - # explicitly to it so it can inject it into thread local storage - def worker_thread_fn(trio_token): - TOKEN_LOCAL.token = trio_token + trio.lowlevel.reschedule(task_register[0], result) + + current_trio_token = trio.lowlevel.current_trio_token() + + if thread_name is None: + thread_name = f"{getattr(sync_fn, '__name__', None)} from {trio.lowlevel.current_task().name}" + + def worker_fn(): + current_async_library_cvar.set(None) + TOKEN_LOCAL.token = current_trio_token try: - result = outcome.capture(sync_fn, *args) - try: - token.run_sync_soon(report_back_in_trio_thread_fn, result) - except trio.RunFinishedError: - # The entire run finished, so our particular task is certainly - # long gone -- it must have cancelled. 
- pass + ret = sync_fn(*args) + + if inspect.iscoroutine(ret): + # Manually close coroutine to avoid RuntimeWarnings + ret.close() + raise TypeError( + "Trio expected a sync function, but {!r} appears to be " + "asynchronous".format(getattr(sync_fn, "__qualname__", sync_fn)) + ) + + return ret finally: del TOKEN_LOCAL.token + context = contextvars.copy_context() + contextvars_aware_worker_fn = functools.partial(context.run, worker_fn) + + def deliver_worker_fn_result(result): + try: + current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result) + except trio.RunFinishedError: + # The entire run finished, so the task we're trying to contact is + # certainly long gone -- it must have been cancelled and abandoned + # us. + pass + await limiter.acquire_on_behalf_of(placeholder) try: - # daemon=True because it might get left behind if we cancel, and in - # this case shouldn't block process exit. - current_trio_token = trio.hazmat.current_trio_token() - thread = threading.Thread( - target=worker_thread_fn, - args=(current_trio_token,), - name=name, - daemon=True + start_thread_soon( + contextvars_aware_worker_fn, deliver_worker_fn_result, thread_name ) - thread.start() except: limiter.release_on_behalf_of(placeholder) raise - def abort(_): + def abort(_: RaiseCancelT) -> trio.lowlevel.Abort: if cancellable: task_register[0] = None - return trio.hazmat.Abort.SUCCEEDED + return trio.lowlevel.Abort.SUCCEEDED else: - return trio.hazmat.Abort.FAILED + return trio.lowlevel.Abort.FAILED - return await trio.hazmat.wait_task_rescheduled(abort) + # wait_task_rescheduled return value cannot be typed + return await trio.lowlevel.wait_task_rescheduled(abort) # type: ignore[no-any-return] -def _run_fn_as_system_task(cb, fn, *args, trio_token=None): +def _run_fn_as_system_task(cb, fn, *args, context, trio_token=None): """Helper function for from_thread.run and from_thread.run_sync. 
Since this internally uses TrioToken.run_sync_soon, all warnings about @@ -328,21 +255,16 @@ def _run_fn_as_system_task(cb, fn, *args, trio_token=None): "this thread wasn't created by Trio, pass kwarg trio_token=..." ) - # TODO: This is only necessary for compatibility with BlockingTrioPortal. - # once that is deprecated, this check should no longer be necessary because - # thread local storage (or the absence of) is sufficient to check if trio - # is running in a thread or not. + # Avoid deadlock by making sure we're not called from Trio thread try: - trio.hazmat.current_task() + trio.lowlevel.current_task() except RuntimeError: pass else: - raise RuntimeError( - "this is a blocking function; call it from a thread" - ) + raise RuntimeError("this is a blocking function; call it from a thread") - q = stdlib_queue.Queue() - trio_token.run_sync_soon(cb, q, fn, args) + q = stdlib_queue.SimpleQueue() + trio_token.run_sync_soon(context.run, cb, q, fn, args) return q.get().unwrap() @@ -358,7 +280,8 @@ def from_thread_run(afn, *args, trio_token=None): Raises: RunFinishedError: if the corresponding call to :func:`trio.run` has - already completed. + already completed, or if the run has started its final cleanup phase + and can no longer spawn new system tasks. Cancelled: if the corresponding call to :func:`trio.run` completes while ``afn(*args)`` is running, then ``afn`` is likely to raise :exc:`trio.Cancelled`, and this will propagate out into @@ -366,6 +289,7 @@ def from_thread_run(afn, *args, trio_token=None): which would otherwise cause a deadlock. AttributeError: if no ``trio_token`` was provided, and we can't infer one from context. + TypeError: if ``afn`` is not an asynchronous function. **Locating a Trio Token**: There are two ways to specify which `trio.run` loop to reenter: @@ -373,22 +297,39 @@ def from_thread_run(afn, *args, trio_token=None): - Spawn this thread from `trio.to_thread.run_sync`. 
Trio will automatically capture the relevant Trio token and use it when you want to re-enter Trio. - - Pass a keyword argument, ``trio_token`` specifiying a specific + - Pass a keyword argument, ``trio_token`` specifying a specific `trio.run` loop to re-enter. This is useful in case you have a "foreign" thread, spawned using some other framework, and still want to enter Trio. """ + def callback(q, afn, args): @disable_ki_protection async def unprotected_afn(): - return await afn(*args) + coro = coroutine_or_error(afn, *args) + return await coro async def await_in_trio_thread_task(): q.put_nowait(await outcome.acapture(unprotected_afn)) - trio.hazmat.spawn_system_task(await_in_trio_thread_task, name=afn) + context = contextvars.copy_context() + try: + trio.lowlevel.spawn_system_task( + await_in_trio_thread_task, name=afn, context=context + ) + except RuntimeError: # system nursery is closed + q.put_nowait( + outcome.Error(trio.RunFinishedError("system nursery is closed")) + ) - return _run_fn_as_system_task(callback, afn, *args, trio_token=trio_token) + context = contextvars.copy_context() + return _run_fn_as_system_task( + callback, + afn, + *args, + context=context, + trio_token=trio_token, + ) def from_thread_run_sync(fn, *args, trio_token=None): @@ -404,13 +345,11 @@ def from_thread_run_sync(fn, *args, trio_token=None): Raises: RunFinishedError: if the corresponding call to `trio.run` has already completed. - Cancelled: if the corresponding call to `trio.run` completes - while ``afn(*args)`` is running, then ``afn`` is likely to raise - :exc:`trio.Cancelled`, and this will propagate out into RuntimeError: if you try calling this from inside the Trio thread, which would otherwise cause a deadlock. AttributeError: if no ``trio_token`` was provided, and we can't infer one from context. + TypeError: if ``fn`` is an async function. 
**Locating a Trio Token**: There are two ways to specify which `trio.run` loop to reenter: @@ -418,17 +357,38 @@ def from_thread_run_sync(fn, *args, trio_token=None): - Spawn this thread from `trio.to_thread.run_sync`. Trio will automatically capture the relevant Trio token and use it when you want to re-enter Trio. - - Pass a keyword argument, ``trio_token`` specifiying a specific + - Pass a keyword argument, ``trio_token`` specifying a specific `trio.run` loop to re-enter. This is useful in case you have a "foreign" thread, spawned using some other framework, and still want to enter Trio. """ + def callback(q, fn, args): + current_async_library_cvar.set("trio") + @disable_ki_protection def unprotected_fn(): - return fn(*args) + ret = fn(*args) + + if inspect.iscoroutine(ret): + # Manually close coroutine to avoid RuntimeWarnings + ret.close() + raise TypeError( + "Trio expected a sync function, but {!r} appears to be " + "asynchronous".format(getattr(fn, "__qualname__", fn)) + ) + + return ret res = outcome.capture(unprotected_fn) q.put_nowait(res) - return _run_fn_as_system_task(callback, fn, *args, trio_token=trio_token) + context = contextvars.copy_context() + + return _run_fn_as_system_task( + callback, + fn, + *args, + context=context, + trio_token=trio_token, + ) diff --git a/trio/_timeouts.py b/trio/_timeouts.py index 25cef08d0d..1d03b2f2e3 100644 --- a/trio/_timeouts.py +++ b/trio/_timeouts.py @@ -1,31 +1,29 @@ -from contextlib import contextmanager +from __future__ import annotations -import trio +import math +from contextlib import AbstractContextManager, contextmanager +from typing import TYPE_CHECKING -__all__ = [ - "move_on_at", - "move_on_after", - "sleep_forever", - "sleep_until", - "sleep", - "fail_at", - "fail_after", - "TooSlowError", -] +import trio -def move_on_at(deadline): +def move_on_at(deadline: float) -> trio.CancelScope: """Use as a context manager to create a cancel scope with the given absolute deadline. 
Args: deadline (float): The deadline. + Raises: + ValueError: if deadline is NaN. + """ + if math.isnan(deadline): + raise ValueError("deadline must not be NaN") return trio.CancelScope(deadline=deadline) -def move_on_after(seconds): +def move_on_after(seconds: float) -> trio.CancelScope: """Use as a context manager to create a cancel scope whose deadline is set to now + *seconds*. @@ -33,43 +31,44 @@ def move_on_after(seconds): seconds (float): The timeout. Raises: - ValueError: if timeout is less than zero. + ValueError: if timeout is less than zero or NaN. """ - if seconds < 0: raise ValueError("timeout must be non-negative") return move_on_at(trio.current_time() + seconds) -async def sleep_forever(): +async def sleep_forever() -> None: """Pause execution of the current task forever (or until cancelled). Equivalent to calling ``await sleep(math.inf)``. """ - await trio.hazmat.wait_task_rescheduled( - lambda _: trio.hazmat.Abort.SUCCEEDED - ) + await trio.lowlevel.wait_task_rescheduled(lambda _: trio.lowlevel.Abort.SUCCEEDED) -async def sleep_until(deadline): +async def sleep_until(deadline: float) -> None: """Pause execution of the current task until the given time. The difference between :func:`sleep` and :func:`sleep_until` is that the - former takes a relative time and the latter takes an absolute time. + former takes a relative time and the latter takes an absolute time + according to Trio's internal clock (as returned by :func:`current_time`). Args: deadline (float): The time at which we should wake up again. May be in the past, in which case this function executes a checkpoint but does not block. + Raises: + ValueError: if deadline is NaN. + """ with move_on_at(deadline): await sleep_forever() -async def sleep(seconds): +async def sleep(seconds: float) -> None: """Pause execution of the current task for the given number of seconds. Args: @@ -77,13 +76,13 @@ async def sleep(seconds): insert a checkpoint without actually blocking. 
Raises: - ValueError: if *seconds* is negative. + ValueError: if *seconds* is negative or NaN. """ if seconds < 0: raise ValueError("duration must be non-negative") if seconds == 0: - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() else: await sleep_until(trio.current_time() + seconds) @@ -95,8 +94,9 @@ class TooSlowError(Exception): """ -@contextmanager -def fail_at(deadline): +# workaround for PyCharm not being able to infer return type from @contextmanager +# see https://youtrack.jetbrains.com/issue/PY-36444/PyCharm-doesnt-infer-types-when-using-contextlib.contextmanager-decorator +def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # type: ignore[misc] """Creates a cancel scope with the given deadline, and raises an error if it is actually cancelled. @@ -108,19 +108,26 @@ def fail_at(deadline): :func:`fail_at`, then it's caught and :exc:`TooSlowError` is raised in its place. + Args: + deadline (float): The deadline. + Raises: TooSlowError: if a :exc:`Cancelled` exception is raised in this scope and caught by the context manager. + ValueError: if deadline is NaN. """ - with move_on_at(deadline) as scope: yield scope if scope.cancelled_caught: raise TooSlowError -def fail_after(seconds): +if not TYPE_CHECKING: + fail_at = contextmanager(fail_at) + + +def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]: """Creates a cancel scope with the given timeout, and raises an error if it is actually cancelled. @@ -131,10 +138,13 @@ def fail_after(seconds): it's caught and discarded. When it reaches :func:`fail_after`, then it's caught and :exc:`TooSlowError` is raised in its place. + Args: + seconds (float): The timeout. + Raises: TooSlowError: if a :exc:`Cancelled` exception is raised in this scope and caught by the context manager. - ValueError: if *seconds* is less than zero. + ValueError: if *seconds* is less than zero or NaN. 
""" if seconds < 0: diff --git a/trio/_tools/gen_exports.py b/trio/_tools/gen_exports.py index cc0738177a..a5d8529b53 100755 --- a/trio/_tools/gen_exports.py +++ b/trio/_tools/gen_exports.py @@ -1,35 +1,38 @@ #! /usr/bin/env python3 -# -*- coding: utf-8 -`- """ Code generation script for class methods to be exported as public API """ import argparse import ast -import astor import os -from pathlib import Path import sys -import yapf.yapflib.yapf_api as formatter - +from pathlib import Path from textwrap import indent -PREFIX = '_generated' +import astor + +PREFIX = "_generated" HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* -from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND +# isort: skip +from ._instrumentation import Instrument from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT + +# fmt: off +""" - +FOOTER = """# fmt: on """ TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: - return {} GLOBAL_RUN_CONTEXT.{}.{} + return{}GLOBAL_RUN_CONTEXT.{}.{} except AttributeError: - raise RuntimeError('must be called from async context') + raise RuntimeError("must be called from async context") """ @@ -37,17 +40,13 @@ def is_function(node): """Check if the AST node is either a function or an async function """ - if ( - isinstance(node, ast.FunctionDef) - or isinstance(node, ast.AsyncFunctionDef) - ): + if isinstance(node, ast.FunctionDef) or isinstance(node, ast.AsyncFunctionDef): return True return False def is_public(node): - """Check if the AST node has a _public decorator - """ + """Check if the AST node has a _public decorator""" if not is_function(node): return False for decorator in node.decorator_list: @@ -57,7 +56,7 @@ def is_public(node): def get_public_methods(tree): - """ Return a list of methods marked as public. 
+ """Return a list of methods marked as public. The function walks the given tree and extracts all objects that are functions which are marked public. @@ -113,28 +112,29 @@ def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str: del method.body[1:] # Create the function definition including the body - func = astor.to_source(method, indent_with=' ' * 4) + func = astor.to_source(method, indent_with=" " * 4) # Create export function body template = TEMPLATE.format( - 'await' if isinstance(method, ast.AsyncFunctionDef) else '', + " await " if isinstance(method, ast.AsyncFunctionDef) else " ", lookup_path, method.name + new_args, ) # Assemble function definition arguments and body - snippet = func + indent(template, ' ' * 4) + snippet = func + indent(template, " " * 4) # Append the snippet to the corresponding module generated.append(snippet) - return "\n".join(generated) + generated.append(FOOTER) + return "\n\n".join(generated) def matches_disk_files(new_files): for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False - with open(new_path, "r", encoding="utf-8") as old_file: + with open(new_path, encoding="utf-8") as old_file: old_source = old_file.read() if old_source != new_source: return False @@ -166,22 +166,20 @@ def process(sources_and_lookups, *, do_test): # doesn't collect coverage. 
def main(): # pragma: no cover parser = argparse.ArgumentParser( - description='Generate python code for public api wrappers' + description="Generate python code for public api wrappers" ) parser.add_argument( - '--test', - '-t', - action='store_true', - help='test if code is still up to date' + "--test", "-t", action="store_true", help="test if code is still up to date" ) parsed_args = parser.parse_args() source_root = Path.cwd() # Double-check we found the right directory assert (source_root / "LICENSE").exists() - core = source_root / 'trio/_core' + core = source_root / "trio/_core" to_wrap = [ (core / "_run.py", "runner"), + (core / "_instrumentation.py", "runner.instruments"), (core / "_io_windows.py", "runner.io_manager"), (core / "_io_epoll.py", "runner.io_manager"), (core / "_io_kqueue.py", "runner.io_manager"), @@ -190,5 +188,5 @@ def main(): # pragma: no cover process(to_wrap, do_test=parsed_args.test) -if __name__ == '__main__': # pragma: no cover +if __name__ == "__main__": # pragma: no cover main() diff --git a/trio/_unix_pipes.py b/trio/_unix_pipes.py index fb6515d8df..716550790e 100644 --- a/trio/_unix_pipes.py +++ b/trio/_unix_pipes.py @@ -1,19 +1,25 @@ -import os +from __future__ import annotations + import errno +import os +from typing import TYPE_CHECKING + +import trio from ._abc import Stream -from ._util import ConflictDetector +from ._util import ConflictDetector, Final -import trio +if TYPE_CHECKING: + from typing import Final as FinalType if os.name != "posix": - # We raise an error here rather than gating the import in hazmat.py + # We raise an error here rather than gating the import in lowlevel.py # in order to keep jedi static analysis happy. raise ImportError # XX TODO: is this a good number? who knows... it does match the default Linux # pipe capacity though. 
-DEFAULT_RECEIVE_SIZE = 65536 +DEFAULT_RECEIVE_SIZE: FinalType = 65536 class _FdHolder: @@ -34,7 +40,9 @@ class _FdHolder: # impossible to make this mistake – we'll just get an EBADF. # # (This trick was copied from the stdlib socket module.) - def __init__(self, fd: int): + fd: int + + def __init__(self, fd: int) -> None: # make sure self.fd is always initialized to *something*, because even # if we error out here then __del__ will run and access it. self.fd = -1 @@ -46,10 +54,10 @@ def __init__(self, fd: int): os.set_blocking(fd, False) @property - def closed(self): + def closed(self) -> bool: return self.fd == -1 - def _raw_close(self): + def _raw_close(self) -> None: # This doesn't assume it's in a Trio context, so it can be called from # __del__. You should never call it from Trio context, because it # skips calling notify_fd_close. But from __del__, skipping that is @@ -64,17 +72,16 @@ def _raw_close(self): os.set_blocking(fd, self._original_is_blocking) os.close(fd) - def __del__(self): + def __del__(self) -> None: self._raw_close() - async def aclose(self): + def close(self) -> None: if not self.closed: - trio.hazmat.notify_closing(self.fd) + trio.lowlevel.notify_closing(self.fd) self._raw_close() - await trio.hazmat.checkpoint() -class FdStream(Stream): +class FdStream(Stream, metaclass=Final): """ Represents a stream given the file descriptor to a pipe, TTY, etc. @@ -91,13 +98,14 @@ class FdStream(Stream): or processes are using file descriptors that are related through `os.dup` or inheritance across `os.fork` to the one that Trio is using, they are unlikely to be prepared to have non-blocking I/O semantics suddenly - thrust upon them. 
For example, you can use ``FdStream(os.dup(0))`` to - obtain a stream for reading from standard input, but it is only safe to - do so with heavy caveats: your stdin must not be shared by any other - processes and you must not make any calls to synchronous methods of - `sys.stdin` until the stream returned by `FdStream` is closed. See - `issue #174 `__ for a - discussion of the challenges involved in relaxing this restriction. + thrust upon them. For example, you can use + ``FdStream(os.dup(sys.stdin.fileno()))`` to obtain a stream for reading + from standard input, but it is only safe to do so with heavy caveats: your + stdin must not be shared by any other processes, and you must not make any + calls to synchronous methods of `sys.stdin` until the stream returned by + `FdStream` is closed. See `issue #174 + `__ for a discussion of the + challenges involved in relaxing this restriction. Args: fd (int): The fd to be wrapped. @@ -105,7 +113,8 @@ class FdStream(Stream): Returns: A new `FdStream` object. 
""" - def __init__(self, fd: int): + + def __init__(self, fd: int) -> None: self._fd_holder = _FdHolder(fd) self._send_conflict_detector = ConflictDetector( "another task is using this stream for send" @@ -114,13 +123,13 @@ def __init__(self, fd: int): "another task is using this stream for receive" ) - async def send_all(self, data: bytes): + async def send_all(self, data: bytes) -> None: with self._send_conflict_detector: # have to check up front, because send_all(b"") on a closed pipe # should raise if self._fd_holder.closed: raise trio.ClosedResourceError("file was already closed") - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() length = len(data) # adapted from the SocketStream code with memoryview(data) as view: @@ -130,7 +139,7 @@ async def send_all(self, data: bytes): try: sent += os.write(self._fd_holder.fd, remaining) except BlockingIOError: - await trio.hazmat.wait_writable(self._fd_holder.fd) + await trio.lowlevel.wait_writable(self._fd_holder.fd) except OSError as e: if e.errno == errno.EBADF: raise trio.ClosedResourceError( @@ -144,13 +153,13 @@ async def wait_send_all_might_not_block(self) -> None: if self._fd_holder.closed: raise trio.ClosedResourceError("file was already closed") try: - await trio.hazmat.wait_writable(self._fd_holder.fd) + await trio.lowlevel.wait_writable(self._fd_holder.fd) except BrokenPipeError as e: # kqueue: raises EPIPE on wait_writable instead # of sending, which is annoying raise trio.BrokenResourceError from e - async def receive_some(self, max_bytes=None) -> bytes: + async def receive_some(self, max_bytes: int | None = None) -> bytes: with self._receive_conflict_detector: if max_bytes is None: max_bytes = DEFAULT_RECEIVE_SIZE @@ -160,12 +169,12 @@ async def receive_some(self, max_bytes=None) -> bytes: if max_bytes < 1: raise ValueError("max_bytes must be integer >= 1") - await trio.hazmat.checkpoint() + await trio.lowlevel.checkpoint() while True: try: data = os.read(self._fd_holder.fd, max_bytes) 
except BlockingIOError: - await trio.hazmat.wait_readable(self._fd_holder.fd) + await trio.lowlevel.wait_readable(self._fd_holder.fd) except OSError as e: if e.errno == errno.EBADF: raise trio.ClosedResourceError( @@ -178,8 +187,12 @@ async def receive_some(self, max_bytes=None) -> bytes: return data - async def aclose(self): - await self._fd_holder.aclose() + def close(self) -> None: + self._fd_holder.close() + + async def aclose(self) -> None: + self.close() + await trio.lowlevel.checkpoint() - def fileno(self): + def fileno(self) -> int: return self._fd_holder.fd diff --git a/trio/_util.py b/trio/_util.py index 9204b83349..a87f1fc02c 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -1,27 +1,24 @@ # Little utilities we use internally +from __future__ import annotations -from abc import ABCMeta +import collections +import inspect import os import signal -import sys -import pathlib -from functools import wraps, update_wrapper -import typing as t import threading +import typing as t +from abc import ABCMeta +from functools import update_wrapper +from types import TracebackType -import async_generator +import trio -# There's a dependency loop here... _core is allowed to use this file (in fact -# it's the *only* file in the main trio/ package it's allowed to use), but -# ConflictDetector needs checkpoint so it also has to import -# _core. Possibly we should split this file into two: one for true generic -# low-level utility code, and one for higher level helpers? +CallT = t.TypeVar("CallT", bound=t.Callable[..., t.Any]) -import trio # Equivalent to the C function raise(), which Python doesn't wrap if os.name == "nt": - # On windows, os.kill exists but is really weird. + # On Windows, os.kill exists but is really weird. # # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver # those using GenerateConsoleCtrlEvent. But I found that when I tried @@ -40,7 +37,7 @@ # OTOH, if you pass os.kill any *other* signal number... 
then CPython # just calls TerminateProcess (wtf). # - # So, anyway, os.kill is not so useful for testing purposes. Instead + # So, anyway, os.kill is not so useful for testing purposes. Instead, # we use raise(): # # https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx @@ -68,24 +65,6 @@ def signal_raise(signum): signal.pthread_kill(threading.get_ident(), signum) -# Decorator to handle the change to __aiter__ in 3.5.2 -if sys.version_info < (3, 5, 2): - - def aiter_compat(aiter_impl): - # de-sugar decorator to fix Python 3.8 coverage issue - # https://github.com/python-trio/trio/pull/784#issuecomment-446438407 - async def __aiter__(*args, **kwargs): - return aiter_impl(*args, **kwargs) - - __aiter__ = wraps(aiter_impl)(__aiter__) - - return __aiter__ -else: - - def aiter_compat(aiter_impl): - return aiter_impl - - # See: #461 as to why this is needed. # The gist is that threading.main_thread() has the capability to lie to us # if somebody else edits the threading ident cache to replace the main @@ -99,10 +78,97 @@ def is_main_thread(): try: signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT)) return True - except ValueError: + except (TypeError, ValueError): return False +###### +# Call the function and get the coroutine object, while giving helpful +# errors for common mistakes. Returns coroutine object. +###### +def coroutine_or_error(async_fn, *args): + def _return_value_looks_like_wrong_library(value): + # Returned by legacy @asyncio.coroutine functions, which includes + # a surprising proportion of asyncio builtins. + if isinstance(value, collections.abc.Generator): + return True + # The protocol for detecting an asyncio Future-like object + if getattr(value, "_asyncio_future_blocking", None) is not None: + return True + # This janky check catches tornado Futures and twisted Deferreds. + # By the time we're calling this function, we already know + # something has gone wrong, so a heuristic is pretty safe. 
+ if value.__class__.__name__ in ("Future", "Deferred"): + return True + return False + + try: + coro = async_fn(*args) + + except TypeError: + # Give good error for: nursery.start_soon(trio.sleep(1)) + if isinstance(async_fn, collections.abc.Coroutine): + # explicitly close coroutine to avoid RuntimeWarning + async_fn.close() + + raise TypeError( + "Trio was expecting an async function, but instead it got " + "a coroutine object {async_fn!r}\n" + "\n" + "Probably you did something like:\n" + "\n" + " trio.run({async_fn.__name__}(...)) # incorrect!\n" + " nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n" + "\n" + "Instead, you want (notice the parentheses!):\n" + "\n" + " trio.run({async_fn.__name__}, ...) # correct!\n" + " nursery.start_soon({async_fn.__name__}, ...) # correct!".format( + async_fn=async_fn + ) + ) from None + + # Give good error for: nursery.start_soon(future) + if _return_value_looks_like_wrong_library(async_fn): + raise TypeError( + "Trio was expecting an async function, but instead it got " + "{!r} – are you trying to use a library written for " + "asyncio/twisted/tornado or similar? That won't work " + "without some sort of compatibility shim.".format(async_fn) + ) from None + + raise + + # We can't check iscoroutinefunction(async_fn), because that will fail + # for things like functools.partial objects wrapping an async + # function. So we have to just call it and then check whether the + # return value is a coroutine object. + # Note: will not be necessary on python>=3.8, see https://bugs.python.org/issue34890 + # TODO: python3.7 support is now dropped, so the above can be addressed. + if not isinstance(coro, collections.abc.Coroutine): + # Give good error for: nursery.start_soon(func_returning_future) + if _return_value_looks_like_wrong_library(coro): + raise TypeError( + "Trio got unexpected {!r} – are you trying to use a " + "library written for asyncio/twisted/tornado or similar? 
" + "That won't work without some sort of compatibility shim.".format(coro) + ) + + if inspect.isasyncgen(coro): + raise TypeError( + "start_soon expected an async function but got an async " + "generator {!r}".format(coro) + ) + + # Give good error for: nursery.start_soon(some_sync_fn) + raise TypeError( + "Trio expected an async function, but {!r} appears to be " + "synchronous".format(getattr(async_fn, "__qualname__", async_fn)) + ) + + return coro + + class ConflictDetector: """Detect when two tasks are about to perform operations that would conflict. @@ -116,6 +182,7 @@ class ConflictDetector: tasks don't call sendall simultaneously on the same stream. """ + def __init__(self, msg): self._msg = msg self._held = False @@ -126,17 +193,25 @@ def __enter__(self): else: self._held = True - def __exit__(self, *args): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self._held = False -def async_wraps(cls, wrapped_cls, attr_name): - """Similar to wraps, but for async wrappers of non-async functions. +def async_wraps( + cls: type[object], + wrapped_cls: type[object], + attr_name: str, +) -> t.Callable[[CallT], CallT]: + """Similar to wraps, but for async wrappers of non-async functions.""" - """ - def decorator(func): + def decorator(func: CallT) -> CallT: func.__name__ = attr_name - func.__qualname__ = '.'.join((cls.__qualname__, attr_name)) + func.__qualname__ = ".".join((cls.__qualname__, attr_name)) func.__doc__ = """Like :meth:`~{}.{}.{}`, but async. @@ -162,7 +237,7 @@ def fix_one(qualname, name, obj): mod = getattr(obj, "__module__", None) if mod is not None and mod.startswith("trio."): obj.__module__ = module_name - # Modules, unlike everything else in Python, put fully-qualitied + # Modules, unlike everything else in Python, put fully-qualified # names into their __name__ attribute. We check for "." to avoid # rewriting these. if hasattr(obj, "__name__") and "." 
not in obj.__name__: @@ -177,66 +252,6 @@ def fix_one(qualname, name, obj): fix_one(objname, objname, obj) -# os.fspath is defined on Python 3.6+ but we need to support Python 3.5 too -# This is why we provide our own implementation. On Python 3.6+ we use the -# StdLib's version and on Python 3.5 our own version. -# Our own implementation implementation is based on PEP 519 while it has also -# been adapted to work with pathlib objects on python 3.5 -# The input typehint is removed as there is no os.PathLike on 3.5. -# See: https://www.python.org/dev/peps/pep-0519/#os - - -def fspath(path) -> t.Union[str, bytes]: - """Return the path representation of a path-like object. - - Returns - ------- - - If str or bytes is passed in, it is returned unchanged. - - If the os.PathLike interface is implemented it is used to get the path - representation. - - If the python version is 3.5 or earlier and a pathlib object is passed, - the object's string representation is returned. - - Raises - ------ - - Regardless of the input, if the path representation (e.g. the value - returned from __fspath__) is not str or bytes, TypeError is raised. - - If the provided path is not str, bytes, pathlib.PurePath or os.PathLike, - TypeError is raised. - """ - if isinstance(path, (str, bytes)): - return path - # Work from the object's type to match method resolution of other magic - # methods. - path_type = type(path) - # On python 3.5, pathlib objects don't have the __fspath__ method, - # but we still want to get their string representation. 
- if issubclass(path_type, pathlib.PurePath): - return str(path) - try: - path_repr = path_type.__fspath__(path) - except AttributeError: - if hasattr(path_type, '__fspath__'): - raise - else: - raise TypeError( - "expected str, bytes or os.PathLike object, " - "not " + path_type.__name__ - ) - if isinstance(path_repr, (str, bytes)): - return path_repr - else: - raise TypeError( - "expected {}.__fspath__() to return str or bytes, " - "not {}".format(path_type.__name__, - type(path_repr).__name__) - ) - - -if hasattr(os, "fspath"): - fspath = os.fspath # noqa - - class generic_function: """Decorator that makes a function indexable, to communicate non-inferrable generic type parameters to a static type checker. @@ -253,6 +268,7 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ and currently won't type-check without a mypy plugin or clever stubs, but at least it becomes possible to write those. """ + def __init__(self, fn): update_wrapper(self, fn) self._fn = fn @@ -264,21 +280,7 @@ def __getitem__(self, _): return self -# If a new class inherits from any ABC, then the new class's metaclass has to -# inherit from ABCMeta. If a new class inherits from typing.Generic, and -# you're using Python 3.6 or earlier, then the new class's metaclass has to -# inherit from typing.GenericMeta. Some of the classes that want to use Final -# or NoPublicConstructor inherit from ABCs and generics, so Final has to -# inherit from these metaclasses. Fortunately, GenericMeta inherits from -# ABCMeta, so inheriting from GenericMeta alone is sufficient (when it -# exists at all). -if hasattr(t, "GenericMeta"): - BaseMeta = t.GenericMeta -else: - BaseMeta = ABCMeta - - -class Final(BaseMeta): +class Final(ABCMeta): """Metaclass that enforces a class to be final (i.e., subclass not allowed). If a class uses this metaclass like this:: @@ -286,21 +288,28 @@ class Final(BaseMeta): class SomeClass(metaclass=Final): pass - The metaclass will ensure that no sub class can be created. 
+ The metaclass will ensure that no subclass can be created. Raises ------ - - TypeError if a sub class is created + - TypeError if a subclass is created """ - def __new__(cls, name, bases, cls_namespace): + + def __new__( + cls, name: str, bases: tuple[type, ...], cls_namespace: dict[str, object] + ) -> Final: for base in bases: if isinstance(base, Final): raise TypeError( - "`%s` does not support subclassing" % base.__name__ + f"{base.__module__}.{base.__qualname__} does not support" + " subclassing" ) return super().__new__(cls, name, bases, cls_namespace) +T = t.TypeVar("T") + + class NoPublicConstructor(Final): """Metaclass that enforces a class to be final (i.e., subclass not allowed) and ensures a private constructor. @@ -310,17 +319,37 @@ class NoPublicConstructor(Final): class SomeClass(metaclass=NoPublicConstructor): pass - The metaclass will ensure that no sub class can be created, and that no instance + The metaclass will ensure that no subclass can be created, and that no instance can be initialized. If you try to instantiate your class (SomeClass()), a TypeError will be thrown. Raises ------ - - TypeError if a sub class or an instance is created. + - TypeError if a subclass or an instance is created. """ - def __call__(self, *args, **kwargs): - raise TypeError("no public constructor available") - def _create(self, *args, **kwargs): - return super().__call__(*args, **kwargs) + def __call__(cls, *args: object, **kwargs: object) -> None: + raise TypeError( + f"{cls.__module__}.{cls.__qualname__} has no public constructor" + ) + + def _create(cls: t.Type[T], *args: object, **kwargs: object) -> T: + return super().__call__(*args, **kwargs) # type: ignore + + +def name_asyncgen(agen): + """Return the fully-qualified name of the async generator function + that produced the async generator iterator *agen*. 
+ """ + if not hasattr(agen, "ag_code"): # pragma: no cover + return repr(agen) + try: + module = agen.ag_frame.f_globals["__name__"] + except (AttributeError, KeyError): + module = f"<{agen.ag_code.co_filename}>" + try: + qualname = agen.__qualname__ + except AttributeError: + qualname = agen.ag_code.co_name + return f"{module}.{qualname}" diff --git a/trio/_version.py b/trio/_version.py index 56ce2746f8..65242863a9 100644 --- a/trio/_version.py +++ b/trio/_version.py @@ -1,3 +1,3 @@ # This file is imported from __init__.py and exec'd from setup.py -__version__ = "0.13.0+dev" +__version__ = "0.22.2+dev" diff --git a/trio/_wait_for_object.py b/trio/_wait_for_object.py index dfbf47d8a3..32a88e5398 100644 --- a/trio/_wait_for_object.py +++ b/trio/_wait_for_object.py @@ -1,9 +1,8 @@ import math -from . import _timeouts + import trio -from ._core._windows_cffi import ffi, kernel32, ErrorCodes, raise_winerror, _handle -__all__ = ["WaitForSingleObject"] +from ._core._windows_cffi import ErrorCodes, _handle, ffi, kernel32, raise_winerror async def WaitForSingleObject(obj): @@ -47,16 +46,12 @@ async def WaitForSingleObject(obj): def WaitForMultipleObjects_sync(*handles): - """Wait for any of the given Windows handles to be signaled. - - """ + """Wait for any of the given Windows handles to be signaled.""" n = len(handles) - handle_arr = ffi.new("HANDLE[{}]".format(n)) + handle_arr = ffi.new(f"HANDLE[{n}]") for i in range(n): handle_arr[i] = handles[i] - timeout = 0xffffffff # INFINITE - retcode = kernel32.WaitForMultipleObjects( - n, handle_arr, False, timeout - ) # blocking + timeout = 0xFFFFFFFF # INFINITE + retcode = kernel32.WaitForMultipleObjects(n, handle_arr, False, timeout) # blocking if retcode == ErrorCodes.WAIT_FAILED: raise_winerror() diff --git a/trio/_windows_pipes.py b/trio/_windows_pipes.py index 04bcdc7100..c1c357b018 100644 --- a/trio/_windows_pipes.py +++ b/trio/_windows_pipes.py @@ -1,7 +1,12 @@ +import sys +from typing import TYPE_CHECKING + from . 
import _core -from ._abc import SendStream, ReceiveStream -from ._util import ConflictDetector -from ._core._windows_cffi import _handle, raise_winerror, kernel32, ffi +from ._abc import ReceiveStream, SendStream +from ._core._windows_cffi import _handle, kernel32, raise_winerror +from ._util import ConflictDetector, Final + +assert sys.platform == "win32" or not TYPE_CHECKING # XX TODO: don't just make this up based on nothing. DEFAULT_RECEIVE_SIZE = 65536 @@ -21,7 +26,7 @@ def __init__(self, handle: int) -> None: def closed(self): return self.handle == -1 - def _close(self): + def close(self): if self.closed: return handle = self.handle @@ -29,18 +34,15 @@ def _close(self): if not kernel32.CloseHandle(_handle(handle)): raise_winerror() - async def aclose(self): - self._close() - await _core.checkpoint() - def __del__(self): - self._close() + self.close() -class PipeSendStream(SendStream): +class PipeSendStream(SendStream, metaclass=Final): """Represents a send stream over a Windows named pipe that has been opened in OVERLAPPED mode. 
""" + def __init__(self, handle: int) -> None: self._handle_holder = _HandleHolder(handle) self._conflict_detector = ConflictDetector( @@ -57,9 +59,7 @@ async def send_all(self, data: bytes): return try: - written = await _core.write_overlapped( - self._handle_holder.handle, data - ) + written = await _core.write_overlapped(self._handle_holder.handle, data) except BrokenPipeError as ex: raise _core.BrokenResourceError from ex # By my reading of MSDN, this assert is guaranteed to pass so long @@ -75,12 +75,17 @@ async def wait_send_all_might_not_block(self) -> None: # not implemented yet, and probably not needed await _core.checkpoint() + def close(self): + self._handle_holder.close() + async def aclose(self): - await self._handle_holder.aclose() + self.close() + await _core.checkpoint() -class PipeReceiveStream(ReceiveStream): +class PipeReceiveStream(ReceiveStream, metaclass=Final): """Represents a receive stream over an os.pipe object.""" + def __init__(self, handle: int) -> None: self._handle_holder = _HandleHolder(handle) self._conflict_detector = ConflictDetector( @@ -126,5 +131,9 @@ async def receive_some(self, max_bytes=None) -> bytes: del buffer[size:] return buffer + def close(self): + self._handle_holder.close() + async def aclose(self): - await self._handle_holder.aclose() + self.close() + await _core.checkpoint() diff --git a/trio/abc.py b/trio/abc.py index e3348360e4..439995640e 100644 --- a/trio/abc.py +++ b/trio/abc.py @@ -4,8 +4,20 @@ # temporaries, imports, etc. when implementing the module. So we put the # implementation in an underscored module, and then re-export the public parts # here. 
+ +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) from ._abc import ( - Clock, Instrument, AsyncResource, SendStream, ReceiveStream, Stream, - HalfCloseableStream, SocketFactory, HostnameResolver, Listener, - SendChannel, ReceiveChannel, Channel + AsyncResource as AsyncResource, + Channel as Channel, + Clock as Clock, + HalfCloseableStream as HalfCloseableStream, + HostnameResolver as HostnameResolver, + Instrument as Instrument, + Listener as Listener, + ReceiveChannel as ReceiveChannel, + ReceiveStream as ReceiveStream, + SendChannel as SendChannel, + SendStream as SendStream, + SocketFactory as SocketFactory, + Stream as Stream, ) diff --git a/trio/from_thread.py b/trio/from_thread.py index 296a5a89ea..e6f7b2495e 100644 --- a/trio/from_thread.py +++ b/trio/from_thread.py @@ -3,5 +3,8 @@ an external thread by means of a Trio Token present in Thread Local Storage """ -from ._threads import from_thread_run as run -from ._threads import from_thread_run_sync as run_sync + +from ._threads import from_thread_run as run, from_thread_run_sync as run_sync + +# need to use __all__ for pyright --verifytypes to see re-exports when renaming them +__all__ = ["run", "run_sync"] diff --git a/trio/hazmat.py b/trio/hazmat.py deleted file mode 100644 index 5fe32c03d9..0000000000 --- a/trio/hazmat.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -This namespace represents low-level functionality not intended for daily use, -but useful for extending Trio's functionality. -""" - -import os -import sys - -# This is the union of a subset of trio/_core/ and some things from trio/*.py. -# See comments in trio/__init__.py for details. To make static analysis easier, -# this lists all possible symbols from trio._core, and then we prune those that -# aren't available on this system. After that we add some symbols from trio/*.py. 
- -# Generally available symbols -from ._core import ( - cancel_shielded_checkpoint, Abort, wait_task_rescheduled, - enable_ki_protection, disable_ki_protection, currently_ki_protected, Task, - checkpoint, current_task, ParkingLot, UnboundedQueue, RunVar, TrioToken, - current_trio_token, temporarily_detach_coroutine_object, - permanently_detach_coroutine_object, reattach_detached_coroutine_object, - current_statistics, reschedule, remove_instrument, add_instrument, - current_clock, current_root_task, checkpoint_if_cancelled, - spawn_system_task, wait_readable, wait_writable, notify_closing -) - -# Unix-specific symbols -try: - from ._unix_pipes import FdStream -except ImportError: - pass - -# Kqueue-specific symbols -try: - from ._core import ( - current_kqueue, - monitor_kevent, - wait_kevent, - ) -except ImportError: - pass - -# Windows symbols -try: - from ._core import ( - current_iocp, - register_with_iocp, - wait_overlapped, - monitor_completion_key, - readinto_overlapped, - write_overlapped, - ) -except ImportError: - pass - -from . import _core - -# Import bits from trio/*.py -if sys.platform.startswith("win"): - from ._wait_for_object import WaitForSingleObject diff --git a/trio/lowlevel.py b/trio/lowlevel.py new file mode 100644 index 0000000000..54f4ef3141 --- /dev/null +++ b/trio/lowlevel.py @@ -0,0 +1,76 @@ +""" +This namespace represents low-level functionality not intended for daily use, +but useful for extending Trio's functionality. 
+""" + +import select as _select +import sys +import typing as _t + +# Generally available symbols +from ._core import ( + Abort as Abort, + ParkingLot as ParkingLot, + ParkingLotStatistics as ParkingLotStatistics, + RaiseCancelT as RaiseCancelT, + RunVar as RunVar, + Task as Task, + TrioToken as TrioToken, + UnboundedQueue as UnboundedQueue, + add_instrument as add_instrument, + cancel_shielded_checkpoint as cancel_shielded_checkpoint, + checkpoint as checkpoint, + checkpoint_if_cancelled as checkpoint_if_cancelled, + current_clock as current_clock, + current_root_task as current_root_task, + current_statistics as current_statistics, + current_task as current_task, + current_trio_token as current_trio_token, + currently_ki_protected as currently_ki_protected, + disable_ki_protection as disable_ki_protection, + enable_ki_protection as enable_ki_protection, + notify_closing as notify_closing, + permanently_detach_coroutine_object as permanently_detach_coroutine_object, + reattach_detached_coroutine_object as reattach_detached_coroutine_object, + remove_instrument as remove_instrument, + reschedule as reschedule, + spawn_system_task as spawn_system_task, + start_guest_run as start_guest_run, + start_thread_soon as start_thread_soon, + temporarily_detach_coroutine_object as temporarily_detach_coroutine_object, + wait_readable as wait_readable, + wait_task_rescheduled as wait_task_rescheduled, + wait_writable as wait_writable, +) +from ._subprocess import open_process as open_process + +# This is the union of a subset of trio/_core/ and some things from trio/*.py. +# See comments in trio/__init__.py for details. 
+ +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) + + +if sys.platform == "win32": + # Windows symbols + from ._core import ( + current_iocp as current_iocp, + monitor_completion_key as monitor_completion_key, + readinto_overlapped as readinto_overlapped, + register_with_iocp as register_with_iocp, + wait_overlapped as wait_overlapped, + write_overlapped as write_overlapped, + ) + from ._wait_for_object import WaitForSingleObject as WaitForSingleObject +else: + # Unix symbols + from ._unix_pipes import FdStream as FdStream + + # Kqueue-specific symbols + if sys.platform != "linux" and (_t.TYPE_CHECKING or not hasattr(_select, "epoll")): + from ._core import ( + current_kqueue as current_kqueue, + monitor_kevent as monitor_kevent, + wait_kevent as wait_kevent, + ) + +del sys diff --git a/trio/socket.py b/trio/socket.py index 266fc0dbda..f6aebb6a6e 100644 --- a/trio/socket.py +++ b/trio/socket.py @@ -6,173 +6,561 @@ # here. # We still have some underscore names though but only a few. -from . import _socket -import sys as _sys -# The socket module exports a bunch of platform-specific constants. We want to -# re-export them. Since the exact set of constants varies depending on Python -# version, platform, the libc installed on the system where Python was built, -# etc., we figure out which constants to re-export dynamically at runtime (see -# below). But that confuses static analysis tools like jedi and mypy. So this -# import statement statically lists every constant that *could* be -# exported. It always fails at runtime, since no single Python build exports -# all these constants, but it lets static analysis tools understand what's -# going on. There's a test in test_exports.py to make sure that the list is -# kept up to date. 
-try: - from socket import ( - CMSG_LEN, CMSG_SPACE, CAPI, AF_UNSPEC, AF_INET, AF_UNIX, AF_IPX, - AF_APPLETALK, AF_INET6, AF_ROUTE, AF_LINK, AF_SNA, PF_SYSTEM, - AF_SYSTEM, SOCK_STREAM, SOCK_DGRAM, SOCK_RAW, SOCK_SEQPACKET, SOCK_RDM, - SO_DEBUG, SO_ACCEPTCONN, SO_REUSEADDR, SO_KEEPALIVE, SO_DONTROUTE, - SO_BROADCAST, SO_USELOOPBACK, SO_LINGER, SO_OOBINLINE, SO_REUSEPORT, - SO_SNDBUF, SO_RCVBUF, SO_SNDLOWAT, SO_RCVLOWAT, SO_SNDTIMEO, - SO_RCVTIMEO, SO_ERROR, SO_TYPE, LOCAL_PEERCRED, SOMAXCONN, SCM_RIGHTS, - SCM_CREDS, MSG_OOB, MSG_PEEK, MSG_DONTROUTE, MSG_DONTWAIT, MSG_EOR, - MSG_TRUNC, MSG_CTRUNC, MSG_WAITALL, MSG_EOF, SOL_SOCKET, SOL_IP, - SOL_TCP, SOL_UDP, IPPROTO_IP, IPPROTO_HOPOPTS, IPPROTO_ICMP, - IPPROTO_IGMP, IPPROTO_GGP, IPPROTO_IPV4, IPPROTO_IPIP, IPPROTO_TCP, - IPPROTO_EGP, IPPROTO_PUP, IPPROTO_UDP, IPPROTO_IDP, IPPROTO_HELLO, - IPPROTO_ND, IPPROTO_TP, IPPROTO_ROUTING, IPPROTO_FRAGMENT, - IPPROTO_RSVP, IPPROTO_GRE, IPPROTO_ESP, IPPROTO_AH, IPPROTO_ICMPV6, - IPPROTO_NONE, IPPROTO_DSTOPTS, IPPROTO_XTP, IPPROTO_EON, IPPROTO_PIM, - IPPROTO_IPCOMP, IPPROTO_SCTP, IPPROTO_RAW, IPPROTO_MAX, - SYSPROTO_CONTROL, IPPORT_RESERVED, IPPORT_USERRESERVED, INADDR_ANY, - INADDR_BROADCAST, INADDR_LOOPBACK, INADDR_UNSPEC_GROUP, - INADDR_ALLHOSTS_GROUP, INADDR_MAX_LOCAL_GROUP, INADDR_NONE, IP_OPTIONS, - IP_HDRINCL, IP_TOS, IP_TTL, IP_RECVOPTS, IP_RECVRETOPTS, - IP_RECVDSTADDR, IP_RETOPTS, IP_MULTICAST_IF, IP_MULTICAST_TTL, - IP_MULTICAST_LOOP, IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, - IP_DEFAULT_MULTICAST_TTL, IP_DEFAULT_MULTICAST_LOOP, - IP_MAX_MEMBERSHIPS, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, - IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, - IPV6_UNICAST_HOPS, IPV6_V6ONLY, IPV6_CHECKSUM, IPV6_RECVTCLASS, - IPV6_RTHDR_TYPE_0, IPV6_TCLASS, TCP_NODELAY, TCP_MAXSEG, TCP_KEEPINTVL, - TCP_KEEPCNT, TCP_FASTOPEN, TCP_NOTSENT_LOWAT, EAI_ADDRFAMILY, - EAI_AGAIN, EAI_BADFLAGS, EAI_FAIL, EAI_FAMILY, EAI_MEMORY, EAI_NODATA, - EAI_NONAME, EAI_OVERFLOW, EAI_SERVICE, 
EAI_SOCKTYPE, EAI_SYSTEM, - EAI_BADHINTS, EAI_PROTOCOL, EAI_MAX, AI_PASSIVE, AI_CANONNAME, - AI_NUMERICHOST, AI_NUMERICSERV, AI_MASK, AI_ALL, AI_V4MAPPED_CFG, - AI_ADDRCONFIG, AI_V4MAPPED, AI_DEFAULT, NI_MAXHOST, NI_MAXSERV, - NI_NOFQDN, NI_NUMERICHOST, NI_NAMEREQD, NI_NUMERICSERV, NI_DGRAM, - SHUT_RD, SHUT_WR, SHUT_RDWR, EBADF, EAGAIN, EWOULDBLOCK, AF_ASH, - AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_ECONET, - AF_IRDA, AF_KEY, AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, - AF_PPPOX, AF_ROSE, AF_SECURITY, AF_WANPIPE, AF_X25, BDADDR_ANY, - BDADDR_LOCAL, FD_SETSIZE, IPV6_DSTOPTS, IPV6_HOPLIMIT, IPV6_HOPOPTS, - IPV6_NEXTHOP, IPV6_PKTINFO, IPV6_RECVDSTOPTS, IPV6_RECVHOPLIMIT, - IPV6_RECVHOPOPTS, IPV6_RECVPKTINFO, IPV6_RECVRTHDR, IPV6_RTHDR, - IPV6_RTHDRDSTOPTS, MSG_ERRQUEUE, NETLINK_DNRTMSG, NETLINK_FIREWALL, - NETLINK_IP6_FW, NETLINK_NFLOG, NETLINK_ROUTE, NETLINK_USERSOCK, - NETLINK_XFRM, PACKET_BROADCAST, PACKET_FASTROUTE, PACKET_HOST, - PACKET_LOOPBACK, PACKET_MULTICAST, PACKET_OTHERHOST, PACKET_OUTGOING, - POLLERR, POLLHUP, POLLIN, POLLMSG, POLLNVAL, POLLOUT, POLLPRI, - POLLRDBAND, POLLRDNORM, POLLWRNORM, SIOCGIFINDEX, SIOCGIFNAME, - SOCK_CLOEXEC, TCP_CORK, TCP_DEFER_ACCEPT, TCP_INFO, TCP_KEEPIDLE, - TCP_LINGER2, TCP_QUICKACK, TCP_SYNCNT, TCP_WINDOW_CLAMP, AF_ALG, - AF_CAN, AF_RDS, AF_TIPC, AF_VSOCK, ALG_OP_DECRYPT, ALG_OP_ENCRYPT, - ALG_OP_SIGN, ALG_OP_VERIFY, ALG_SET_AEAD_ASSOCLEN, - ALG_SET_AEAD_AUTHSIZE, ALG_SET_IV, ALG_SET_KEY, ALG_SET_OP, - ALG_SET_PUBKEY, CAN_BCM, CAN_BCM_RX_CHANGED, CAN_BCM_RX_DELETE, - CAN_BCM_RX_READ, CAN_BCM_RX_SETUP, CAN_BCM_RX_STATUS, - CAN_BCM_RX_TIMEOUT, CAN_BCM_TX_DELETE, CAN_BCM_TX_EXPIRED, - CAN_BCM_TX_READ, CAN_BCM_TX_SEND, CAN_BCM_TX_SETUP, CAN_BCM_TX_STATUS, - CAN_EFF_FLAG, CAN_EFF_MASK, CAN_ERR_FLAG, CAN_ERR_MASK, CAN_ISOTP, - CAN_RAW, CAN_RAW_ERR_FILTER, CAN_RAW_FD_FRAMES, CAN_RAW_FILTER, - CAN_RAW_LOOPBACK, CAN_RAW_RECV_OWN_MSGS, CAN_RTR_FLAG, CAN_SFF_MASK, - IOCTL_VM_SOCKETS_GET_LOCAL_CID, 
IPV6_DONTFRAG, IPV6_PATHMTU, - IPV6_RECVPATHMTU, IP_TRANSPARENT, MSG_CMSG_CLOEXEC, MSG_CONFIRM, - MSG_FASTOPEN, MSG_MORE, MSG_NOSIGNAL, NETLINK_CRYPTO, PF_CAN, - PF_PACKET, PF_RDS, SCM_CREDENTIALS, SOCK_NONBLOCK, SOL_ALG, - SOL_CAN_BASE, SOL_CAN_RAW, SOL_TIPC, SO_BINDTODEVICE, SO_DOMAIN, - SO_MARK, SO_PASSCRED, SO_PASSSEC, SO_PEERCRED, SO_PEERSEC, SO_PRIORITY, - SO_PROTOCOL, SO_VM_SOCKETS_BUFFER_MAX_SIZE, - SO_VM_SOCKETS_BUFFER_MIN_SIZE, SO_VM_SOCKETS_BUFFER_SIZE, - TCP_CONGESTION, TCP_USER_TIMEOUT, TIPC_ADDR_ID, TIPC_ADDR_NAME, - TIPC_ADDR_NAMESEQ, TIPC_CFG_SRV, TIPC_CLUSTER_SCOPE, TIPC_CONN_TIMEOUT, - TIPC_CRITICAL_IMPORTANCE, TIPC_DEST_DROPPABLE, TIPC_HIGH_IMPORTANCE, - TIPC_IMPORTANCE, TIPC_LOW_IMPORTANCE, TIPC_MEDIUM_IMPORTANCE, - TIPC_NODE_SCOPE, TIPC_PUBLISHED, TIPC_SRC_DROPPABLE, - TIPC_SUBSCR_TIMEOUT, TIPC_SUB_CANCEL, TIPC_SUB_PORTS, TIPC_SUB_SERVICE, - TIPC_TOP_SRV, TIPC_WAIT_FOREVER, TIPC_WITHDRAWN, TIPC_ZONE_SCOPE, - VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_PORT_ANY, - VM_SOCKETS_INVALID_VERSION, MSG_BCAST, MSG_MCAST, RCVALL_MAX, - RCVALL_OFF, RCVALL_ON, RCVALL_SOCKETLEVELONLY, SIO_KEEPALIVE_VALS, - SIO_LOOPBACK_FAST_PATH, SIO_RCVALL, SO_EXCLUSIVEADDRUSE, HCI_FILTER, - BTPROTO_SCO, BTPROTO_HCI, HCI_TIME_STAMP, SOL_RDS, BTPROTO_L2CAP, - BTPROTO_RFCOMM, HCI_DATA_DIR, SOL_HCI, CAN_BCM_RX_ANNOUNCE_RESUME, - CAN_BCM_RX_CHECK_DLC, CAN_BCM_RX_FILTER_ID, CAN_BCM_RX_NO_AUTOTIMER, - CAN_BCM_RX_RTR_FRAME, CAN_BCM_SETTIMER, CAN_BCM_STARTTIMER, - CAN_BCM_TX_ANNOUNCE, CAN_BCM_TX_COUNTEVT, CAN_BCM_TX_CP_CAN_ID, - CAN_BCM_TX_RESET_MULTI_IDX, IPPROTO_CBT, IPPROTO_ICLFXBM, IPPROTO_IGP, - IPPROTO_L2TP, IPPROTO_PGM, IPPROTO_RDP, IPPROTO_ST, AF_QIPCRTR, - CAN_BCM_CAN_FD_FRAME - ) -except ImportError: - pass +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) # Dynamically re-export whatever constants this particular Python happens to # have: import socket as _stdlib_socket +import sys +import typing as _t + +from . 
import _socket + +_bad_symbols: _t.Set[str] = set() +if sys.platform == "win32": + # See https://github.com/python-trio/trio/issues/39 + # Do not import for windows platform + # (you can still get it from stdlib socket, of course, if you want it) + _bad_symbols.add("SO_REUSEADDR") globals().update( { _name: getattr(_stdlib_socket, _name) - for _name in _stdlib_socket.__all__ if _name.isupper() + for _name in _stdlib_socket.__all__ # type: ignore + if _name.isupper() and _name not in _bad_symbols } ) # import the overwrites from ._socket import ( - fromfd, from_stdlib_socket, getprotobyname, socketpair, getnameinfo, - socket, getaddrinfo, set_custom_hostname_resolver, - set_custom_socket_factory, SocketType + SocketType as SocketType, + _SocketType as _SocketType, + from_stdlib_socket as from_stdlib_socket, + fromfd as fromfd, + getaddrinfo as getaddrinfo, + getnameinfo as getnameinfo, + getprotobyname as getprotobyname, + set_custom_hostname_resolver as set_custom_hostname_resolver, + set_custom_socket_factory as set_custom_socket_factory, + socket as socket, + socketpair as socketpair, ) # not always available so expose only if -try: - from ._socket import fromshare -except ImportError: - pass +if sys.platform == "win32" or not _t.TYPE_CHECKING: + try: + from ._socket import fromshare as fromshare + except ImportError: + pass # expose these functions to trio.socket from socket import ( - gaierror, - herror, - gethostname, - ntohs, - htonl, - htons, - inet_aton, - inet_ntoa, - inet_pton, - inet_ntop, + gaierror as gaierror, + gethostname as gethostname, + herror as herror, + htonl as htonl, + htons as htons, + inet_aton as inet_aton, + inet_ntoa as inet_ntoa, + inet_ntop as inet_ntop, + inet_pton as inet_pton, + ntohs as ntohs, ) # not always available so expose only if -try: - from socket import ( - sethostname, if_nameindex, if_nametoindex, if_indextoname - ) -except ImportError: - pass +if sys.platform != "win32" or not _t.TYPE_CHECKING: + try: + from socket 
import ( + if_indextoname as if_indextoname, + if_nameindex as if_nameindex, + if_nametoindex as if_nametoindex, + sethostname as sethostname, + ) + except ImportError: + pass -if _sys.platform == 'win32': - # See https://github.com/python-trio/trio/issues/39 - # Do not import for windows platform - # (you can still get it from stdlib socket, of course, if you want it) - del SO_REUSEADDR +if _t.TYPE_CHECKING: + IP_BIND_ADDRESS_NO_PORT: int +else: + try: + IP_BIND_ADDRESS_NO_PORT + except NameError: + if sys.platform == "linux": + IP_BIND_ADDRESS_NO_PORT = 24 + +del sys -# get names used by Trio that we define on our own -from ._socket import IPPROTO_IPV6 -# Not defined in all python versions and platforms but sometimes needed -try: - TCP_NOTSENT_LOWAT -except NameError: - # Hopefully will show up in 3.7: - # https://github.com/python/cpython/pull/477 - if _sys.platform == "darwin": - TCP_NOTSENT_LOWAT = 0x201 - elif _sys.platform == "linux": - TCP_NOTSENT_LOWAT = 25 +# The socket module exports a bunch of platform-specific constants. We want to +# re-export them. Since the exact set of constants varies depending on Python +# version, platform, the libc installed on the system where Python was built, +# etc., we figure out which constants to re-export dynamically at runtime (see +# below). But that confuses static analysis tools like jedi and mypy. So this +# import statement statically lists every constant that *could* be +# exported. There's a test in test_exports.py to make sure that the list is +# kept up to date. 
+if _t.TYPE_CHECKING: + from socket import ( # type: ignore[attr-defined] + AF_ALG as AF_ALG, + AF_APPLETALK as AF_APPLETALK, + AF_ASH as AF_ASH, + AF_ATMPVC as AF_ATMPVC, + AF_ATMSVC as AF_ATMSVC, + AF_AX25 as AF_AX25, + AF_BLUETOOTH as AF_BLUETOOTH, + AF_BRIDGE as AF_BRIDGE, + AF_CAN as AF_CAN, + AF_ECONET as AF_ECONET, + AF_INET as AF_INET, + AF_INET6 as AF_INET6, + AF_IPX as AF_IPX, + AF_IRDA as AF_IRDA, + AF_KEY as AF_KEY, + AF_LINK as AF_LINK, + AF_LLC as AF_LLC, + AF_NETBEUI as AF_NETBEUI, + AF_NETLINK as AF_NETLINK, + AF_NETROM as AF_NETROM, + AF_PACKET as AF_PACKET, + AF_PPPOX as AF_PPPOX, + AF_QIPCRTR as AF_QIPCRTR, + AF_RDS as AF_RDS, + AF_ROSE as AF_ROSE, + AF_ROUTE as AF_ROUTE, + AF_SECURITY as AF_SECURITY, + AF_SNA as AF_SNA, + AF_SYSTEM as AF_SYSTEM, + AF_TIPC as AF_TIPC, + AF_UNIX as AF_UNIX, + AF_UNSPEC as AF_UNSPEC, + AF_VSOCK as AF_VSOCK, + AF_WANPIPE as AF_WANPIPE, + AF_X25 as AF_X25, + AI_ADDRCONFIG as AI_ADDRCONFIG, + AI_ALL as AI_ALL, + AI_CANONNAME as AI_CANONNAME, + AI_DEFAULT as AI_DEFAULT, + AI_MASK as AI_MASK, + AI_NUMERICHOST as AI_NUMERICHOST, + AI_NUMERICSERV as AI_NUMERICSERV, + AI_PASSIVE as AI_PASSIVE, + AI_V4MAPPED as AI_V4MAPPED, + AI_V4MAPPED_CFG as AI_V4MAPPED_CFG, + ALG_OP_DECRYPT as ALG_OP_DECRYPT, + ALG_OP_ENCRYPT as ALG_OP_ENCRYPT, + ALG_OP_SIGN as ALG_OP_SIGN, + ALG_OP_VERIFY as ALG_OP_VERIFY, + ALG_SET_AEAD_ASSOCLEN as ALG_SET_AEAD_ASSOCLEN, + ALG_SET_AEAD_AUTHSIZE as ALG_SET_AEAD_AUTHSIZE, + ALG_SET_IV as ALG_SET_IV, + ALG_SET_KEY as ALG_SET_KEY, + ALG_SET_OP as ALG_SET_OP, + ALG_SET_PUBKEY as ALG_SET_PUBKEY, + BDADDR_ANY as BDADDR_ANY, + BDADDR_LOCAL as BDADDR_LOCAL, + BTPROTO_HCI as BTPROTO_HCI, + BTPROTO_L2CAP as BTPROTO_L2CAP, + BTPROTO_RFCOMM as BTPROTO_RFCOMM, + BTPROTO_SCO as BTPROTO_SCO, + CAN_BCM as CAN_BCM, + CAN_BCM_CAN_FD_FRAME as CAN_BCM_CAN_FD_FRAME, + CAN_BCM_RX_ANNOUNCE_RESUME as CAN_BCM_RX_ANNOUNCE_RESUME, + CAN_BCM_RX_CHANGED as CAN_BCM_RX_CHANGED, + CAN_BCM_RX_CHECK_DLC as CAN_BCM_RX_CHECK_DLC, + 
CAN_BCM_RX_DELETE as CAN_BCM_RX_DELETE, + CAN_BCM_RX_FILTER_ID as CAN_BCM_RX_FILTER_ID, + CAN_BCM_RX_NO_AUTOTIMER as CAN_BCM_RX_NO_AUTOTIMER, + CAN_BCM_RX_READ as CAN_BCM_RX_READ, + CAN_BCM_RX_RTR_FRAME as CAN_BCM_RX_RTR_FRAME, + CAN_BCM_RX_SETUP as CAN_BCM_RX_SETUP, + CAN_BCM_RX_STATUS as CAN_BCM_RX_STATUS, + CAN_BCM_RX_TIMEOUT as CAN_BCM_RX_TIMEOUT, + CAN_BCM_SETTIMER as CAN_BCM_SETTIMER, + CAN_BCM_STARTTIMER as CAN_BCM_STARTTIMER, + CAN_BCM_TX_ANNOUNCE as CAN_BCM_TX_ANNOUNCE, + CAN_BCM_TX_COUNTEVT as CAN_BCM_TX_COUNTEVT, + CAN_BCM_TX_CP_CAN_ID as CAN_BCM_TX_CP_CAN_ID, + CAN_BCM_TX_DELETE as CAN_BCM_TX_DELETE, + CAN_BCM_TX_EXPIRED as CAN_BCM_TX_EXPIRED, + CAN_BCM_TX_READ as CAN_BCM_TX_READ, + CAN_BCM_TX_RESET_MULTI_IDX as CAN_BCM_TX_RESET_MULTI_IDX, + CAN_BCM_TX_SEND as CAN_BCM_TX_SEND, + CAN_BCM_TX_SETUP as CAN_BCM_TX_SETUP, + CAN_BCM_TX_STATUS as CAN_BCM_TX_STATUS, + CAN_EFF_FLAG as CAN_EFF_FLAG, + CAN_EFF_MASK as CAN_EFF_MASK, + CAN_ERR_FLAG as CAN_ERR_FLAG, + CAN_ERR_MASK as CAN_ERR_MASK, + CAN_ISOTP as CAN_ISOTP, + CAN_J1939 as CAN_J1939, + CAN_RAW as CAN_RAW, + CAN_RAW_ERR_FILTER as CAN_RAW_ERR_FILTER, + CAN_RAW_FD_FRAMES as CAN_RAW_FD_FRAMES, + CAN_RAW_FILTER as CAN_RAW_FILTER, + CAN_RAW_JOIN_FILTERS as CAN_RAW_JOIN_FILTERS, + CAN_RAW_LOOPBACK as CAN_RAW_LOOPBACK, + CAN_RAW_RECV_OWN_MSGS as CAN_RAW_RECV_OWN_MSGS, + CAN_RTR_FLAG as CAN_RTR_FLAG, + CAN_SFF_MASK as CAN_SFF_MASK, + CAPI as CAPI, + CMSG_LEN as CMSG_LEN, + CMSG_SPACE as CMSG_SPACE, + EAGAIN as EAGAIN, + EAI_ADDRFAMILY as EAI_ADDRFAMILY, + EAI_AGAIN as EAI_AGAIN, + EAI_BADFLAGS as EAI_BADFLAGS, + EAI_BADHINTS as EAI_BADHINTS, + EAI_FAIL as EAI_FAIL, + EAI_FAMILY as EAI_FAMILY, + EAI_MAX as EAI_MAX, + EAI_MEMORY as EAI_MEMORY, + EAI_NODATA as EAI_NODATA, + EAI_NONAME as EAI_NONAME, + EAI_OVERFLOW as EAI_OVERFLOW, + EAI_PROTOCOL as EAI_PROTOCOL, + EAI_SERVICE as EAI_SERVICE, + EAI_SOCKTYPE as EAI_SOCKTYPE, + EAI_SYSTEM as EAI_SYSTEM, + EBADF as EBADF, + ETH_P_ALL as ETH_P_ALL, + ETHERTYPE_ARP as 
ETHERTYPE_ARP, + ETHERTYPE_IP as ETHERTYPE_IP, + ETHERTYPE_IPV6 as ETHERTYPE_IPV6, + ETHERTYPE_VLAN as ETHERTYPE_VLAN, + EWOULDBLOCK as EWOULDBLOCK, + FD_ACCEPT as FD_ACCEPT, + FD_CLOSE as FD_CLOSE, + FD_CLOSE_BIT as FD_CLOSE_BIT, + FD_CONNECT as FD_CONNECT, + FD_CONNECT_BIT as FD_CONNECT_BIT, + FD_READ as FD_READ, + FD_SETSIZE as FD_SETSIZE, + FD_WRITE as FD_WRITE, + HCI_DATA_DIR as HCI_DATA_DIR, + HCI_FILTER as HCI_FILTER, + HCI_TIME_STAMP as HCI_TIME_STAMP, + INADDR_ALLHOSTS_GROUP as INADDR_ALLHOSTS_GROUP, + INADDR_ANY as INADDR_ANY, + INADDR_BROADCAST as INADDR_BROADCAST, + INADDR_LOOPBACK as INADDR_LOOPBACK, + INADDR_MAX_LOCAL_GROUP as INADDR_MAX_LOCAL_GROUP, + INADDR_NONE as INADDR_NONE, + INADDR_UNSPEC_GROUP as INADDR_UNSPEC_GROUP, + INFINITE as INFINITE, + IOCTL_VM_SOCKETS_GET_LOCAL_CID as IOCTL_VM_SOCKETS_GET_LOCAL_CID, + IP_ADD_MEMBERSHIP as IP_ADD_MEMBERSHIP, + IP_ADD_SOURCE_MEMBERSHIP as IP_ADD_SOURCE_MEMBERSHIP, + IP_BLOCK_SOURCE as IP_BLOCK_SOURCE, + IP_DEFAULT_MULTICAST_LOOP as IP_DEFAULT_MULTICAST_LOOP, + IP_DEFAULT_MULTICAST_TTL as IP_DEFAULT_MULTICAST_TTL, + IP_DROP_MEMBERSHIP as IP_DROP_MEMBERSHIP, + IP_DROP_SOURCE_MEMBERSHIP as IP_DROP_SOURCE_MEMBERSHIP, + IP_HDRINCL as IP_HDRINCL, + IP_MAX_MEMBERSHIPS as IP_MAX_MEMBERSHIPS, + IP_MULTICAST_IF as IP_MULTICAST_IF, + IP_MULTICAST_LOOP as IP_MULTICAST_LOOP, + IP_MULTICAST_TTL as IP_MULTICAST_TTL, + IP_OPTIONS as IP_OPTIONS, + IP_PKTINFO as IP_PKTINFO, + IP_RECVDSTADDR as IP_RECVDSTADDR, + IP_RECVOPTS as IP_RECVOPTS, + IP_RECVRETOPTS as IP_RECVRETOPTS, + IP_RECVTOS as IP_RECVTOS, + IP_RETOPTS as IP_RETOPTS, + IP_TOS as IP_TOS, + IP_TRANSPARENT as IP_TRANSPARENT, + IP_TTL as IP_TTL, + IP_UNBLOCK_SOURCE as IP_UNBLOCK_SOURCE, + IPPORT_RESERVED as IPPORT_RESERVED, + IPPORT_USERRESERVED as IPPORT_USERRESERVED, + IPPROTO_AH as IPPROTO_AH, + IPPROTO_CBT as IPPROTO_CBT, + IPPROTO_DSTOPTS as IPPROTO_DSTOPTS, + IPPROTO_EGP as IPPROTO_EGP, + IPPROTO_EON as IPPROTO_EON, + IPPROTO_ESP as IPPROTO_ESP, + 
IPPROTO_FRAGMENT as IPPROTO_FRAGMENT, + IPPROTO_GGP as IPPROTO_GGP, + IPPROTO_GRE as IPPROTO_GRE, + IPPROTO_HELLO as IPPROTO_HELLO, + IPPROTO_HOPOPTS as IPPROTO_HOPOPTS, + IPPROTO_ICLFXBM as IPPROTO_ICLFXBM, + IPPROTO_ICMP as IPPROTO_ICMP, + IPPROTO_ICMPV6 as IPPROTO_ICMPV6, + IPPROTO_IDP as IPPROTO_IDP, + IPPROTO_IGMP as IPPROTO_IGMP, + IPPROTO_IGP as IPPROTO_IGP, + IPPROTO_IP as IPPROTO_IP, + IPPROTO_IPCOMP as IPPROTO_IPCOMP, + IPPROTO_IPIP as IPPROTO_IPIP, + IPPROTO_IPV4 as IPPROTO_IPV4, + IPPROTO_IPV6 as IPPROTO_IPV6, + IPPROTO_L2TP as IPPROTO_L2TP, + IPPROTO_MAX as IPPROTO_MAX, + IPPROTO_MOBILE as IPPROTO_MOBILE, + IPPROTO_MPTCP as IPPROTO_MPTCP, + IPPROTO_ND as IPPROTO_ND, + IPPROTO_NONE as IPPROTO_NONE, + IPPROTO_PGM as IPPROTO_PGM, + IPPROTO_PIM as IPPROTO_PIM, + IPPROTO_PUP as IPPROTO_PUP, + IPPROTO_RAW as IPPROTO_RAW, + IPPROTO_RDP as IPPROTO_RDP, + IPPROTO_ROUTING as IPPROTO_ROUTING, + IPPROTO_RSVP as IPPROTO_RSVP, + IPPROTO_SCTP as IPPROTO_SCTP, + IPPROTO_ST as IPPROTO_ST, + IPPROTO_TCP as IPPROTO_TCP, + IPPROTO_TP as IPPROTO_TP, + IPPROTO_UDP as IPPROTO_UDP, + IPPROTO_UDPLITE as IPPROTO_UDPLITE, + IPPROTO_XTP as IPPROTO_XTP, + IPV6_CHECKSUM as IPV6_CHECKSUM, + IPV6_DONTFRAG as IPV6_DONTFRAG, + IPV6_DSTOPTS as IPV6_DSTOPTS, + IPV6_HOPLIMIT as IPV6_HOPLIMIT, + IPV6_HOPOPTS as IPV6_HOPOPTS, + IPV6_JOIN_GROUP as IPV6_JOIN_GROUP, + IPV6_LEAVE_GROUP as IPV6_LEAVE_GROUP, + IPV6_MULTICAST_HOPS as IPV6_MULTICAST_HOPS, + IPV6_MULTICAST_IF as IPV6_MULTICAST_IF, + IPV6_MULTICAST_LOOP as IPV6_MULTICAST_LOOP, + IPV6_NEXTHOP as IPV6_NEXTHOP, + IPV6_PATHMTU as IPV6_PATHMTU, + IPV6_PKTINFO as IPV6_PKTINFO, + IPV6_RECVDSTOPTS as IPV6_RECVDSTOPTS, + IPV6_RECVHOPLIMIT as IPV6_RECVHOPLIMIT, + IPV6_RECVHOPOPTS as IPV6_RECVHOPOPTS, + IPV6_RECVPATHMTU as IPV6_RECVPATHMTU, + IPV6_RECVPKTINFO as IPV6_RECVPKTINFO, + IPV6_RECVRTHDR as IPV6_RECVRTHDR, + IPV6_RECVTCLASS as IPV6_RECVTCLASS, + IPV6_RTHDR as IPV6_RTHDR, + IPV6_RTHDR_TYPE_0 as IPV6_RTHDR_TYPE_0, + IPV6_RTHDRDSTOPTS as 
IPV6_RTHDRDSTOPTS, + IPV6_TCLASS as IPV6_TCLASS, + IPV6_UNICAST_HOPS as IPV6_UNICAST_HOPS, + IPV6_USE_MIN_MTU as IPV6_USE_MIN_MTU, + IPV6_V6ONLY as IPV6_V6ONLY, + J1939_EE_INFO_NONE as J1939_EE_INFO_NONE, + J1939_EE_INFO_TX_ABORT as J1939_EE_INFO_TX_ABORT, + J1939_FILTER_MAX as J1939_FILTER_MAX, + J1939_IDLE_ADDR as J1939_IDLE_ADDR, + J1939_MAX_UNICAST_ADDR as J1939_MAX_UNICAST_ADDR, + J1939_NLA_BYTES_ACKED as J1939_NLA_BYTES_ACKED, + J1939_NLA_PAD as J1939_NLA_PAD, + J1939_NO_ADDR as J1939_NO_ADDR, + J1939_NO_NAME as J1939_NO_NAME, + J1939_NO_PGN as J1939_NO_PGN, + J1939_PGN_ADDRESS_CLAIMED as J1939_PGN_ADDRESS_CLAIMED, + J1939_PGN_ADDRESS_COMMANDED as J1939_PGN_ADDRESS_COMMANDED, + J1939_PGN_MAX as J1939_PGN_MAX, + J1939_PGN_PDU1_MAX as J1939_PGN_PDU1_MAX, + J1939_PGN_REQUEST as J1939_PGN_REQUEST, + LOCAL_PEERCRED as LOCAL_PEERCRED, + MSG_BCAST as MSG_BCAST, + MSG_CMSG_CLOEXEC as MSG_CMSG_CLOEXEC, + MSG_CONFIRM as MSG_CONFIRM, + MSG_CTRUNC as MSG_CTRUNC, + MSG_DONTROUTE as MSG_DONTROUTE, + MSG_DONTWAIT as MSG_DONTWAIT, + MSG_EOF as MSG_EOF, + MSG_EOR as MSG_EOR, + MSG_ERRQUEUE as MSG_ERRQUEUE, + MSG_FASTOPEN as MSG_FASTOPEN, + MSG_MCAST as MSG_MCAST, + MSG_MORE as MSG_MORE, + MSG_NOSIGNAL as MSG_NOSIGNAL, + MSG_NOTIFICATION as MSG_NOTIFICATION, + MSG_OOB as MSG_OOB, + MSG_PEEK as MSG_PEEK, + MSG_TRUNC as MSG_TRUNC, + MSG_WAITALL as MSG_WAITALL, + NETLINK_CRYPTO as NETLINK_CRYPTO, + NETLINK_DNRTMSG as NETLINK_DNRTMSG, + NETLINK_FIREWALL as NETLINK_FIREWALL, + NETLINK_IP6_FW as NETLINK_IP6_FW, + NETLINK_NFLOG as NETLINK_NFLOG, + NETLINK_ROUTE as NETLINK_ROUTE, + NETLINK_USERSOCK as NETLINK_USERSOCK, + NETLINK_XFRM as NETLINK_XFRM, + NI_DGRAM as NI_DGRAM, + NI_MAXHOST as NI_MAXHOST, + NI_MAXSERV as NI_MAXSERV, + NI_NAMEREQD as NI_NAMEREQD, + NI_NOFQDN as NI_NOFQDN, + NI_NUMERICHOST as NI_NUMERICHOST, + NI_NUMERICSERV as NI_NUMERICSERV, + PACKET_BROADCAST as PACKET_BROADCAST, + PACKET_FASTROUTE as PACKET_FASTROUTE, + PACKET_HOST as PACKET_HOST, + PACKET_LOOPBACK as 
PACKET_LOOPBACK, + PACKET_MULTICAST as PACKET_MULTICAST, + PACKET_OTHERHOST as PACKET_OTHERHOST, + PACKET_OUTGOING as PACKET_OUTGOING, + PF_CAN as PF_CAN, + PF_PACKET as PF_PACKET, + PF_RDS as PF_RDS, + PF_SYSTEM as PF_SYSTEM, + POLLERR as POLLERR, + POLLHUP as POLLHUP, + POLLIN as POLLIN, + POLLMSG as POLLMSG, + POLLNVAL as POLLNVAL, + POLLOUT as POLLOUT, + POLLPRI as POLLPRI, + POLLRDBAND as POLLRDBAND, + POLLRDNORM as POLLRDNORM, + POLLWRNORM as POLLWRNORM, + RCVALL_MAX as RCVALL_MAX, + RCVALL_OFF as RCVALL_OFF, + RCVALL_ON as RCVALL_ON, + RCVALL_SOCKETLEVELONLY as RCVALL_SOCKETLEVELONLY, + SCM_CREDENTIALS as SCM_CREDENTIALS, + SCM_CREDS as SCM_CREDS, + SCM_J1939_DEST_ADDR as SCM_J1939_DEST_ADDR, + SCM_J1939_DEST_NAME as SCM_J1939_DEST_NAME, + SCM_J1939_ERRQUEUE as SCM_J1939_ERRQUEUE, + SCM_J1939_PRIO as SCM_J1939_PRIO, + SCM_RIGHTS as SCM_RIGHTS, + SHUT_RD as SHUT_RD, + SHUT_RDWR as SHUT_RDWR, + SHUT_WR as SHUT_WR, + SIO_KEEPALIVE_VALS as SIO_KEEPALIVE_VALS, + SIO_LOOPBACK_FAST_PATH as SIO_LOOPBACK_FAST_PATH, + SIO_RCVALL as SIO_RCVALL, + SIOCGIFINDEX as SIOCGIFINDEX, + SIOCGIFNAME as SIOCGIFNAME, + SO_ACCEPTCONN as SO_ACCEPTCONN, + SO_BINDTODEVICE as SO_BINDTODEVICE, + SO_BROADCAST as SO_BROADCAST, + SO_DEBUG as SO_DEBUG, + SO_DOMAIN as SO_DOMAIN, + SO_DONTROUTE as SO_DONTROUTE, + SO_ERROR as SO_ERROR, + SO_EXCLUSIVEADDRUSE as SO_EXCLUSIVEADDRUSE, + SO_INCOMING_CPU as SO_INCOMING_CPU, + SO_J1939_ERRQUEUE as SO_J1939_ERRQUEUE, + SO_J1939_FILTER as SO_J1939_FILTER, + SO_J1939_PROMISC as SO_J1939_PROMISC, + SO_J1939_SEND_PRIO as SO_J1939_SEND_PRIO, + SO_KEEPALIVE as SO_KEEPALIVE, + SO_LINGER as SO_LINGER, + SO_MARK as SO_MARK, + SO_OOBINLINE as SO_OOBINLINE, + SO_PASSCRED as SO_PASSCRED, + SO_PASSSEC as SO_PASSSEC, + SO_PEERCRED as SO_PEERCRED, + SO_PEERSEC as SO_PEERSEC, + SO_PRIORITY as SO_PRIORITY, + SO_PROTOCOL as SO_PROTOCOL, + SO_RCVBUF as SO_RCVBUF, + SO_RCVLOWAT as SO_RCVLOWAT, + SO_RCVTIMEO as SO_RCVTIMEO, + SO_REUSEADDR as SO_REUSEADDR, + SO_REUSEPORT 
as SO_REUSEPORT, + SO_SETFIB as SO_SETFIB, + SO_SNDBUF as SO_SNDBUF, + SO_SNDLOWAT as SO_SNDLOWAT, + SO_SNDTIMEO as SO_SNDTIMEO, + SO_TYPE as SO_TYPE, + SO_USELOOPBACK as SO_USELOOPBACK, + SO_VM_SOCKETS_BUFFER_MAX_SIZE as SO_VM_SOCKETS_BUFFER_MAX_SIZE, + SO_VM_SOCKETS_BUFFER_MIN_SIZE as SO_VM_SOCKETS_BUFFER_MIN_SIZE, + SO_VM_SOCKETS_BUFFER_SIZE as SO_VM_SOCKETS_BUFFER_SIZE, + SOCK_CLOEXEC as SOCK_CLOEXEC, + SOCK_DGRAM as SOCK_DGRAM, + SOCK_NONBLOCK as SOCK_NONBLOCK, + SOCK_RAW as SOCK_RAW, + SOCK_RDM as SOCK_RDM, + SOCK_SEQPACKET as SOCK_SEQPACKET, + SOCK_STREAM as SOCK_STREAM, + SOL_ALG as SOL_ALG, + SOL_CAN_BASE as SOL_CAN_BASE, + SOL_CAN_RAW as SOL_CAN_RAW, + SOL_HCI as SOL_HCI, + SOL_IP as SOL_IP, + SOL_RDS as SOL_RDS, + SOL_SOCKET as SOL_SOCKET, + SOL_TCP as SOL_TCP, + SOL_TIPC as SOL_TIPC, + SOL_UDP as SOL_UDP, + SOMAXCONN as SOMAXCONN, + SYSPROTO_CONTROL as SYSPROTO_CONTROL, + TCP_CC_INFO as TCP_CC_INFO, + TCP_CONGESTION as TCP_CONGESTION, + TCP_CORK as TCP_CORK, + TCP_DEFER_ACCEPT as TCP_DEFER_ACCEPT, + TCP_FASTOPEN as TCP_FASTOPEN, + TCP_FASTOPEN_CONNECT as TCP_FASTOPEN_CONNECT, + TCP_FASTOPEN_KEY as TCP_FASTOPEN_KEY, + TCP_FASTOPEN_NO_COOKIE as TCP_FASTOPEN_NO_COOKIE, + TCP_INFO as TCP_INFO, + TCP_INQ as TCP_INQ, + TCP_KEEPALIVE as TCP_KEEPALIVE, + TCP_KEEPCNT as TCP_KEEPCNT, + TCP_KEEPIDLE as TCP_KEEPIDLE, + TCP_KEEPINTVL as TCP_KEEPINTVL, + TCP_LINGER2 as TCP_LINGER2, + TCP_MAXSEG as TCP_MAXSEG, + TCP_MD5SIG as TCP_MD5SIG, + TCP_MD5SIG_EXT as TCP_MD5SIG_EXT, + TCP_NODELAY as TCP_NODELAY, + TCP_NOTSENT_LOWAT as TCP_NOTSENT_LOWAT, + TCP_QUEUE_SEQ as TCP_QUEUE_SEQ, + TCP_QUICKACK as TCP_QUICKACK, + TCP_REPAIR as TCP_REPAIR, + TCP_REPAIR_OPTIONS as TCP_REPAIR_OPTIONS, + TCP_REPAIR_QUEUE as TCP_REPAIR_QUEUE, + TCP_REPAIR_WINDOW as TCP_REPAIR_WINDOW, + TCP_SAVE_SYN as TCP_SAVE_SYN, + TCP_SAVED_SYN as TCP_SAVED_SYN, + TCP_SYNCNT as TCP_SYNCNT, + TCP_THIN_DUPACK as TCP_THIN_DUPACK, + TCP_THIN_LINEAR_TIMEOUTS as TCP_THIN_LINEAR_TIMEOUTS, + TCP_TIMESTAMP as 
TCP_TIMESTAMP, + TCP_TX_DELAY as TCP_TX_DELAY, + TCP_ULP as TCP_ULP, + TCP_USER_TIMEOUT as TCP_USER_TIMEOUT, + TCP_WINDOW_CLAMP as TCP_WINDOW_CLAMP, + TCP_ZEROCOPY_RECEIVE as TCP_ZEROCOPY_RECEIVE, + TIPC_ADDR_ID as TIPC_ADDR_ID, + TIPC_ADDR_NAME as TIPC_ADDR_NAME, + TIPC_ADDR_NAMESEQ as TIPC_ADDR_NAMESEQ, + TIPC_CFG_SRV as TIPC_CFG_SRV, + TIPC_CLUSTER_SCOPE as TIPC_CLUSTER_SCOPE, + TIPC_CONN_TIMEOUT as TIPC_CONN_TIMEOUT, + TIPC_CRITICAL_IMPORTANCE as TIPC_CRITICAL_IMPORTANCE, + TIPC_DEST_DROPPABLE as TIPC_DEST_DROPPABLE, + TIPC_HIGH_IMPORTANCE as TIPC_HIGH_IMPORTANCE, + TIPC_IMPORTANCE as TIPC_IMPORTANCE, + TIPC_LOW_IMPORTANCE as TIPC_LOW_IMPORTANCE, + TIPC_MEDIUM_IMPORTANCE as TIPC_MEDIUM_IMPORTANCE, + TIPC_NODE_SCOPE as TIPC_NODE_SCOPE, + TIPC_PUBLISHED as TIPC_PUBLISHED, + TIPC_SRC_DROPPABLE as TIPC_SRC_DROPPABLE, + TIPC_SUB_CANCEL as TIPC_SUB_CANCEL, + TIPC_SUB_PORTS as TIPC_SUB_PORTS, + TIPC_SUB_SERVICE as TIPC_SUB_SERVICE, + TIPC_SUBSCR_TIMEOUT as TIPC_SUBSCR_TIMEOUT, + TIPC_TOP_SRV as TIPC_TOP_SRV, + TIPC_WAIT_FOREVER as TIPC_WAIT_FOREVER, + TIPC_WITHDRAWN as TIPC_WITHDRAWN, + TIPC_ZONE_SCOPE as TIPC_ZONE_SCOPE, + UDPLITE_RECV_CSCOV as UDPLITE_RECV_CSCOV, + UDPLITE_SEND_CSCOV as UDPLITE_SEND_CSCOV, + VM_SOCKETS_INVALID_VERSION as VM_SOCKETS_INVALID_VERSION, + VMADDR_CID_ANY as VMADDR_CID_ANY, + VMADDR_CID_HOST as VMADDR_CID_HOST, + VMADDR_PORT_ANY as VMADDR_PORT_ANY, + WSA_FLAG_OVERLAPPED as WSA_FLAG_OVERLAPPED, + WSA_INVALID_HANDLE as WSA_INVALID_HANDLE, + WSA_INVALID_PARAMETER as WSA_INVALID_PARAMETER, + WSA_IO_INCOMPLETE as WSA_IO_INCOMPLETE, + WSA_IO_PENDING as WSA_IO_PENDING, + WSA_NOT_ENOUGH_MEMORY as WSA_NOT_ENOUGH_MEMORY, + WSA_OPERATION_ABORTED as WSA_OPERATION_ABORTED, + WSA_WAIT_FAILED as WSA_WAIT_FAILED, + WSA_WAIT_TIMEOUT as WSA_WAIT_TIMEOUT, + ) diff --git a/trio/testing/__init__.py b/trio/testing/__init__.py index df150ec62b..fa683e1145 100644 --- a/trio/testing/__init__.py +++ b/trio/testing/__init__.py @@ -1,27 +1,34 @@ -from .._core import 
wait_all_tasks_blocked - -from ._trio_test import trio_test - -from ._mock_clock import MockClock - -from ._checkpoints import assert_checkpoints, assert_no_checkpoints - -from ._sequencer import Sequencer +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) +from .._core import ( + MockClock as MockClock, + wait_all_tasks_blocked as wait_all_tasks_blocked, +) +from .._util import fixup_module_metadata from ._check_streams import ( - check_one_way_stream, check_two_way_stream, check_half_closeable_stream + check_half_closeable_stream as check_half_closeable_stream, + check_one_way_stream as check_one_way_stream, + check_two_way_stream as check_two_way_stream, +) +from ._checkpoints import ( + assert_checkpoints as assert_checkpoints, + assert_no_checkpoints as assert_no_checkpoints, ) - from ._memory_streams import ( - MemorySendStream, MemoryReceiveStream, memory_stream_pump, - memory_stream_one_way_pair, memory_stream_pair, - lockstep_stream_one_way_pair, lockstep_stream_pair + MemoryReceiveStream as MemoryReceiveStream, + MemorySendStream as MemorySendStream, + lockstep_stream_one_way_pair as lockstep_stream_one_way_pair, + lockstep_stream_pair as lockstep_stream_pair, + memory_stream_one_way_pair as memory_stream_one_way_pair, + memory_stream_pair as memory_stream_pair, + memory_stream_pump as memory_stream_pump, ) - -from ._network import open_stream_to_socket_listener +from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener +from ._sequencer import Sequencer as Sequencer +from ._trio_test import trio_test as trio_test ################################################################ -from .._util import fixup_module_metadata + fixup_module_metadata(__name__, globals()) del fixup_module_metadata diff --git a/trio/testing/_check_streams.py b/trio/testing/_check_streams.py index 2216692df4..401b8ef0c2 100644 --- a/trio/testing/_check_streams.py +++ b/trio/testing/_check_streams.py @@ -1,18 +1,17 @@ # 
Generic stream tests +from __future__ import annotations -from contextlib import contextmanager import random +from contextlib import contextmanager +from typing import TYPE_CHECKING from .. import _core +from .._abc import HalfCloseableStream, ReceiveStream, SendStream, Stream from .._highlevel_generic import aclose_forcefully -from .._abc import SendStream, ReceiveStream, Stream, HalfCloseableStream from ._checkpoints import assert_checkpoints -__all__ = [ - "check_one_way_stream", - "check_two_way_stream", - "check_half_closeable_stream", -] +if TYPE_CHECKING: + from types import TracebackType class _ForceCloseBoth: @@ -22,7 +21,12 @@ def __init__(self, both): async def __aenter__(self): return self._both - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: try: await aclose_forcefully(self._both[0]) finally: @@ -37,7 +41,7 @@ def _assert_raises(exc): except exc: pass else: - raise AssertionError("expected exception: {}".format(exc)) + raise AssertionError(f"expected exception: {exc}") async def check_one_way_stream(stream_maker, clogged_stream_maker): @@ -136,8 +140,7 @@ async def simple_check_wait_send_all_might_not_block(scope): async with _core.open_nursery() as nursery: nursery.start_soon( - simple_check_wait_send_all_might_not_block, - nursery.cancel_scope + simple_check_wait_send_all_might_not_block, nursery.cancel_scope ) nursery.start_soon(do_receive_some, 1) @@ -336,14 +339,15 @@ async def receiver(): nursery.start_soon(s.send_all, b"123") nursery.start_soon(s.send_all, b"123") - # closing the receiver causes wait_send_all_might_not_block to return + # closing the receiver causes wait_send_all_might_not_block to return, + # with or without an exception async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): async def sender(): try: with assert_checkpoints(): await s.wait_send_all_might_not_block() - except 
_core.BrokenResourceError: + except _core.BrokenResourceError: # pragma: no cover pass async def receiver(): @@ -360,7 +364,7 @@ async def receiver(): try: with assert_checkpoints(): await s.wait_send_all_might_not_block() - except _core.BrokenResourceError: + except _core.BrokenResourceError: # pragma: no cover pass # Check that if a task is blocked in a send-side method, then closing @@ -404,11 +408,10 @@ async def flipped_stream_maker(): async def flipped_clogged_stream_maker(): return reversed(await clogged_stream_maker()) + else: flipped_clogged_stream_maker = None - await check_one_way_stream( - flipped_stream_maker, flipped_clogged_stream_maker - ) + await check_one_way_stream(flipped_stream_maker, flipped_clogged_stream_maker) async with _ForceCloseBoth(await stream_maker()) as (s1, s2): assert isinstance(s1, Stream) diff --git a/trio/testing/_checkpoints.py b/trio/testing/_checkpoints.py index 716260893b..5804295300 100644 --- a/trio/testing/_checkpoints.py +++ b/trio/testing/_checkpoints.py @@ -2,8 +2,6 @@ from .. 
import _core -__all__ = ["assert_checkpoints", "assert_no_checkpoints"] - @contextmanager def _assert_yields_or_not(expected): @@ -13,19 +11,13 @@ def _assert_yields_or_not(expected): orig_schedule = task._schedule_points try: yield - if ( - expected and ( - task._cancel_points == orig_cancel - or task._schedule_points == orig_schedule - ) + if expected and ( + task._cancel_points == orig_cancel or task._schedule_points == orig_schedule ): raise AssertionError("assert_checkpoints block did not yield!") finally: - if ( - not expected and ( - task._cancel_points != orig_cancel - or task._schedule_points != orig_schedule - ) + if not expected and ( + task._cancel_points != orig_cancel or task._schedule_points != orig_schedule ): raise AssertionError("assert_no_checkpoints block yielded!") diff --git a/trio/testing/_fake_net.py b/trio/testing/_fake_net.py new file mode 100644 index 0000000000..b3bdfd85c0 --- /dev/null +++ b/trio/testing/_fake_net.py @@ -0,0 +1,408 @@ +# This should eventually be cleaned up and become public, but for right now I'm just +# implementing enough to test DTLS. 
+ +# TODO: +# - user-defined routers +# - TCP +# - UDP broadcast + +from __future__ import annotations + +import errno +import ipaddress +import os +from typing import TYPE_CHECKING, Optional, Union + +import attr + +import trio +from trio._util import Final, NoPublicConstructor + +if TYPE_CHECKING: + from types import TracebackType + +IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + + +def _family_for(ip: IPAddress) -> int: + if isinstance(ip, ipaddress.IPv4Address): + return trio.socket.AF_INET + elif isinstance(ip, ipaddress.IPv6Address): + return trio.socket.AF_INET6 + assert False # pragma: no cover + + +def _wildcard_ip_for(family: int) -> IPAddress: + if family == trio.socket.AF_INET: + return ipaddress.ip_address("0.0.0.0") + elif family == trio.socket.AF_INET6: + return ipaddress.ip_address("::") + else: + assert False + + +def _localhost_ip_for(family: int) -> IPAddress: + if family == trio.socket.AF_INET: + return ipaddress.ip_address("127.0.0.1") + elif family == trio.socket.AF_INET6: + return ipaddress.ip_address("::1") + else: + assert False + + +def _fake_err(code): + raise OSError(code, os.strerror(code)) + + +def _scatter(data, buffers): + written = 0 + for buf in buffers: + next_piece = data[written : written + len(buf)] + with memoryview(buf) as mbuf: + mbuf[: len(next_piece)] = next_piece + written += len(next_piece) + if written == len(data): + break + return written + + +@attr.frozen +class UDPEndpoint: + ip: IPAddress + port: int + + def as_python_sockaddr(self): + sockaddr = (self.ip.compressed, self.port) + if isinstance(self.ip, ipaddress.IPv6Address): + sockaddr += (0, 0) + return sockaddr + + @classmethod + def from_python_sockaddr(cls, sockaddr): + ip, port = sockaddr[:2] + return cls(ip=ipaddress.ip_address(ip), port=port) + + +@attr.frozen +class UDPBinding: + local: UDPEndpoint + + +@attr.frozen +class UDPPacket: + source: UDPEndpoint + destination: UDPEndpoint + payload: bytes = attr.ib(repr=lambda p: p.hex()) + + 
def reply(self, payload): + return UDPPacket( + source=self.destination, destination=self.source, payload=payload + ) + + +@attr.frozen +class FakeSocketFactory(trio.abc.SocketFactory): + fake_net: "FakeNet" + + def socket(self, family: int, type: int, proto: int) -> "FakeSocket": + return FakeSocket._create(self.fake_net, family, type, proto) + + +@attr.frozen +class FakeHostnameResolver(trio.abc.HostnameResolver): + fake_net: "FakeNet" + + async def getaddrinfo( + self, host: str, port: Union[int, str], family=0, type=0, proto=0, flags=0 + ): + raise NotImplementedError("FakeNet doesn't do fake DNS yet") + + async def getnameinfo(self, sockaddr, flags: int): + raise NotImplementedError("FakeNet doesn't do fake DNS yet") + + +class FakeNet(metaclass=Final): + def __init__(self): + # When we need to pick an arbitrary unique ip address/port, use these: + self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() + self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() + self._auto_port_iter = iter(range(50000, 65535)) + + self._bound: Dict[UDPBinding, FakeSocket] = {} + + self.route_packet = None + + def _bind(self, binding: UDPBinding, socket: "FakeSocket") -> None: + if binding in self._bound: + _fake_err(errno.EADDRINUSE) + self._bound[binding] = socket + + def enable(self) -> None: + trio.socket.set_custom_socket_factory(FakeSocketFactory(self)) + trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self)) + + def send_packet(self, packet) -> None: + if self.route_packet is None: + self.deliver_packet(packet) + else: + self.route_packet(packet) + + def deliver_packet(self, packet) -> None: + binding = UDPBinding(local=packet.destination) + if binding in self._bound: + self._bound[binding]._deliver_packet(packet) + else: + # No valid destination, so drop it + pass + + +class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor): + def __init__(self, fake_net: FakeNet, family: int, type: int, proto: int): + self._fake_net = 
fake_net + + if not family: + family = trio.socket.AF_INET + if not type: + type = trio.socket.SOCK_STREAM + + if family not in (trio.socket.AF_INET, trio.socket.AF_INET6): + raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}") + if type != trio.socket.SOCK_DGRAM: + raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}") + + self.family = family + self.type = type + self.proto = proto + + self._closed = False + + self._packet_sender, self._packet_receiver = trio.open_memory_channel( + float("inf") + ) + + # This is the source-of-truth for what port etc. this socket is bound to + self._binding: Optional[UDPBinding] = None + + def _check_closed(self): + if self._closed: + _fake_err(errno.EBADF) + + def close(self): + # breakpoint() + if self._closed: + return + self._closed = True + if self._binding is not None: + del self._fake_net._bound[self._binding] + self._packet_receiver.close() + + async def _resolve_address_nocp(self, address, *, local): + return await trio._socket._resolve_address_nocp( + self.type, + self.family, + self.proto, + address=address, + ipv6_v6only=False, + local=local, + ) + + def _deliver_packet(self, packet: UDPPacket): + try: + self._packet_sender.send_nowait(packet) + except trio.BrokenResourceError: + # sending to a closed socket -- UDP packets get dropped + pass + + ################################################################ + # Actual IO operation implementations + ################################################################ + + async def bind(self, addr): + self._check_closed() + if self._binding is not None: + _fake_error(errno.EINVAL) + await trio.lowlevel.checkpoint() + ip_str, port = await self._resolve_address_nocp(addr, local=True) + ip = ipaddress.ip_address(ip_str) + assert _family_for(ip) == self.family + # We convert binds to INET_ANY into binds to localhost + if ip == ipaddress.ip_address("0.0.0.0"): + ip = ipaddress.ip_address("127.0.0.1") + elif ip == 
ipaddress.ip_address("::"): + ip = ipaddress.ip_address("::1") + if port == 0: + port = next(self._fake_net._auto_port_iter) + binding = UDPBinding(local=UDPEndpoint(ip, port)) + self._fake_net._bind(binding, self) + self._binding = binding + + async def connect(self, peer): + raise NotImplementedError("FakeNet does not (yet) support connected sockets") + + async def sendmsg(self, *args): + self._check_closed() + ancdata = [] + flags = 0 + address = None + if len(args) == 1: + (buffers,) = args + elif len(args) == 2: + buffers, address = args + elif len(args) == 3: + buffers, flags, address = args + elif len(args) == 4: + buffers, ancdata, flags, address = args + else: + raise TypeError("wrong number of arguments") + + await trio.lowlevel.checkpoint() + + if address is not None: + address = await self._resolve_address_nocp(address, local=False) + if ancdata: + raise NotImplementedError("FakeNet doesn't support ancillary data") + if flags: + raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}") + + if address is None: + _fake_err(errno.ENOTCONN) + + destination = UDPEndpoint.from_python_sockaddr(address) + + if self._binding is None: + await self.bind((_wildcard_ip_for(self.family).compressed, 0)) + + payload = b"".join(buffers) + + packet = UDPPacket( + source=self._binding.local, + destination=destination, + payload=payload, + ) + + self._fake_net.send_packet(packet) + + return len(payload) + + async def recvmsg_into(self, buffers, ancbufsize=0, flags=0): + if ancbufsize != 0: + raise NotImplementedError("FakeNet doesn't support ancillary data") + if flags != 0: + raise NotImplementedError("FakeNet doesn't support any recv flags") + + self._check_closed() + + ancdata = [] + msg_flags = 0 + + packet = await self._packet_receiver.receive() + address = packet.source.as_python_sockaddr() + written = _scatter(packet.payload, buffers) + if written < len(packet.payload): + msg_flags |= trio.socket.MSG_TRUNC + return written, ancdata, msg_flags, address 
+ + ################################################################ + # Simple state query stuff + ################################################################ + + def getsockname(self): + self._check_closed() + if self._binding is not None: + return self._binding.local.as_python_sockaddr() + elif self.family == trio.socket.AF_INET: + return ("0.0.0.0", 0) + else: + assert self.family == trio.socket.AF_INET6 + return ("::", 0) + + def getpeername(self): + self._check_closed() + if self._binding is not None: + if self._binding.remote is not None: + return self._binding.remote.as_python_sockaddr() + _fake_err(errno.ENOTCONN) + + def getsockopt(self, level, item): + self._check_closed() + raise OSError(f"FakeNet doesn't implement getsockopt({level}, {item})") + + def setsockopt(self, level, item, value): + self._check_closed() + + if (level, item) == (trio.socket.IPPROTO_IPV6, trio.socket.IPV6_V6ONLY): + if not value: + raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True") + + raise OSError(f"FakeNet doesn't implement setsockopt({level}, {item}, ...)") + + ################################################################ + # Various boilerplate and trivial stubs + ################################################################ + + def __enter__(self): + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + async def send(self, data, flags=0): + return await self.sendto(data, flags, None) + + async def sendto(self, *args): + if len(args) == 2: + data, address = args + flags = 0 + elif len(args) == 3: + data, flags, address = args + else: + raise TypeError("wrong number of arguments") + return await self.sendmsg([data], [], flags, address) + + async def recv(self, bufsize, flags=0): + data, address = await self.recvfrom(bufsize, flags) + return data + + async def recv_into(self, buf, nbytes=0, flags=0): + got_bytes, address = 
await self.recvfrom_into(buf, nbytes, flags) + return got_bytes + + async def recvfrom(self, bufsize, flags=0): + data, ancdata, msg_flags, address = await self.recvmsg(bufsize, flags) + return data, address + + async def recvfrom_into(self, buf, nbytes=0, flags=0): + if nbytes != 0 and nbytes != len(buf): + raise NotImplementedError("partial recvfrom_into") + got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + [buf], 0, flags + ) + return got_nbytes, address + + async def recvmsg(self, bufsize, ancbufsize=0, flags=0): + buf = bytearray(bufsize) + got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + [buf], ancbufsize, flags + ) + return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) + + def fileno(self): + raise NotImplementedError("can't get fileno() for FakeNet sockets") + + def detach(self): + raise NotImplementedError("can't detach() a FakeNet socket") + + def get_inheritable(self): + return False + + def set_inheritable(self, inheritable): + if inheritable: + raise NotImplementedError("FakeNet can't make inheritable sockets") + + def share(self, process_id): + raise NotImplementedError("FakeNet can't share sockets") diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index d86e301888..38e8e54de8 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -1,19 +1,8 @@ import operator -from .. import _core +from .. import _core, _util from .._highlevel_generic import StapledStream -from .. 
import _util -from ..abc import SendStream, ReceiveStream - -__all__ = [ - "MemorySendStream", - "MemoryReceiveStream", - "memory_stream_pump", - "memory_stream_one_way_pair", - "memory_stream_pair", - "lockstep_stream_one_way_pair", - "lockstep_stream_pair", -] +from ..abc import ReceiveStream, SendStream ################################################################ # In-memory streams - Unbounded buffer version @@ -83,7 +72,7 @@ async def get(self, max_bytes=None): return self._get_impl(max_bytes) -class MemorySendStream(SendStream): +class MemorySendStream(SendStream, metaclass=_util.Final): """An in-memory :class:`~trio.abc.SendStream`. Args: @@ -103,11 +92,12 @@ class MemorySendStream(SendStream): you can change them at any time. """ + def __init__( self, send_all_hook=None, wait_send_all_might_not_block_hook=None, - close_hook=None + close_hook=None, ): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" @@ -164,9 +154,7 @@ def close(self): self.close_hook() async def aclose(self): - """Same as :meth:`close`, but async. - - """ + """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() @@ -198,7 +186,7 @@ def get_data_nowait(self, max_bytes=None): return self._outgoing.get_nowait(max_bytes) -class MemoryReceiveStream(ReceiveStream): +class MemoryReceiveStream(ReceiveStream, metaclass=_util.Final): """An in-memory :class:`~trio.abc.ReceiveStream`. Args: @@ -214,6 +202,7 @@ class MemoryReceiveStream(ReceiveStream): change them at any time. """ + def __init__(self, receive_some_hook=None, close_hook=None): self._conflict_detector = _util.ConflictDetector( "another task is using this stream" @@ -257,28 +246,20 @@ def close(self): self.close_hook() async def aclose(self): - """Same as :meth:`close`, but async. - - """ + """Same as :meth:`close`, but async.""" self.close() await _core.checkpoint() def put_data(self, data): - """Appends the given data to the internal buffer. 
- - """ + """Appends the given data to the internal buffer.""" self._incoming.put(data) def put_eof(self): - """Adds an end-of-file marker to the internal buffer. - - """ + """Adds an end-of-file marker to the internal buffer.""" self._incoming.close() -def memory_stream_pump( - memory_send_stream, memory_receive_stream, *, max_bytes=None -): +def memory_stream_pump(memory_send_stream, memory_receive_stream, *, max_bytes=None): """Take data out of the given :class:`MemorySendStream`'s internal buffer, and put it into the given :class:`MemoryReceiveStream`'s internal buffer. diff --git a/trio/testing/_mock_clock.py b/trio/testing/_mock_clock.py deleted file mode 100644 index c3cf863392..0000000000 --- a/trio/testing/_mock_clock.py +++ /dev/null @@ -1,206 +0,0 @@ -import time -from math import inf - -from .. import _core -from .._abc import Clock - -__all__ = ["MockClock"] - -################################################################ -# The glorious MockClock -################################################################ - - -# Prior art: -# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html -# https://github.com/ztellman/manifold/issues/57 -class MockClock(Clock): - """A user-controllable clock suitable for writing tests. - - Args: - rate (float): the initial :attr:`rate`. - autojump_threshold (float): the initial :attr:`autojump_threshold`. - - .. attribute:: rate - - How many seconds of clock time pass per second of real time. Default is - 0.0, i.e. the clock only advances through manuals calls to :meth:`jump` - or when the :attr:`autojump_threshold` is triggered. You can assign to - this attribute to change it. - - .. attribute:: autojump_threshold - - The clock keeps an eye on the run loop, and if at any point it detects - that all tasks have been blocked for this many real seconds (i.e., - according to the actual clock, not this clock), then the clock - automatically jumps ahead to the run loop's next scheduled - timeout. 
Default is :data:`math.inf`, i.e., to never autojump. You can - assign to this attribute to change it. - - Basically the idea is that if you have code or tests that use sleeps - and timeouts, you can use this to make it run much faster, totally - automatically. (At least, as long as those sleeps/timeouts are - happening inside Trio; if your test involves talking to external - service and waiting for it to timeout then obviously we can't help you - there.) - - You should set this to the smallest value that lets you reliably avoid - "false alarms" where some I/O is in flight (e.g. between two halves of - a socketpair) but the threshold gets triggered and time gets advanced - anyway. This will depend on the details of your tests and test - environment. If you aren't doing any I/O (like in our sleeping example - above) then just set it to zero, and the clock will jump whenever all - tasks are blocked. - - .. warning:: - - If you're using :func:`wait_all_tasks_blocked` and - :attr:`autojump_threshold` together, then you have to be - careful. Setting :attr:`autojump_threshold` acts like a background - task calling:: - - while True: - await wait_all_tasks_blocked( - cushion=clock.autojump_threshold, tiebreaker=float("inf")) - - This means that if you call :func:`wait_all_tasks_blocked` with a - cushion *larger* than your autojump threshold, then your call to - :func:`wait_all_tasks_blocked` will never return, because the - autojump task will keep waking up before your task does, and each - time it does it'll reset your task's timer. However, if your cushion - and the autojump threshold are the *same*, then the autojump's - tiebreaker will prevent them from interfering (unless you also set - your tiebreaker to infinity for some reason. Don't do that). As an - important special case: this means that if you set an autojump - threshold of zero and use :func:`wait_all_tasks_blocked` with the - default zero cushion, then everything will work fine. 
- - **Summary**: you should set :attr:`autojump_threshold` to be at - least as large as the largest cushion you plan to pass to - :func:`wait_all_tasks_blocked`. - - """ - def __init__(self, rate=0.0, autojump_threshold=inf): - # when the real clock said 'real_base', the virtual time was - # 'virtual_base', and since then it's advanced at 'rate' virtual - # seconds per real second. - self._real_base = 0.0 - self._virtual_base = 0.0 - self._rate = 0.0 - self._autojump_threshold = 0.0 - self._autojump_task = None - self._autojump_cancel_scope = None - # kept as an attribute so that our tests can monkeypatch it - self._real_clock = time.perf_counter - - # use the property update logic to set initial values - self.rate = rate - self.autojump_threshold = autojump_threshold - - def __repr__(self): - return ( - "".format( - self.current_time(), self._rate, id(self) - ) - ) - - @property - def rate(self): - return self._rate - - @rate.setter - def rate(self, new_rate): - if new_rate < 0: - raise ValueError("rate must be >= 0") - else: - real = self._real_clock() - virtual = self._real_to_virtual(real) - self._virtual_base = virtual - self._real_base = real - self._rate = float(new_rate) - - @property - def autojump_threshold(self): - return self._autojump_threshold - - @autojump_threshold.setter - def autojump_threshold(self, new_autojump_threshold): - self._autojump_threshold = float(new_autojump_threshold) - self._maybe_spawn_autojump_task() - if self._autojump_cancel_scope is not None: - # Task is running and currently blocked on the old setting, wake - # it up so it picks up the new setting - self._autojump_cancel_scope.cancel() - - async def _autojumper(self): - while True: - with _core.CancelScope() as cancel_scope: - self._autojump_cancel_scope = cancel_scope - try: - # If the autojump_threshold changes, then the setter does - # cancel_scope.cancel(), which causes the next line here - # to raise Cancelled, which is absorbed by the cancel - # scope above, and 
effectively just causes us to skip back - # to the start the loop, like a 'continue'. - await _core.wait_all_tasks_blocked( - self._autojump_threshold, inf - ) - statistics = _core.current_statistics() - jump = statistics.seconds_to_next_deadline - if 0 < jump < inf: - self.jump(jump) - else: - # There are no deadlines, nothing is going to happen - # until some actual I/O arrives (or maybe another - # wait_all_tasks_blocked task wakes up). That's fine, - # but if our threshold is zero then this will become a - # busy-wait -- so insert a small-but-non-zero _sleep to - # avoid that. - if self._autojump_threshold == 0: - await _core.wait_all_tasks_blocked(0.01) - finally: - self._autojump_cancel_scope = None - - def _maybe_spawn_autojump_task(self): - if self._autojump_threshold < inf and self._autojump_task is None: - try: - clock = _core.current_clock() - except RuntimeError: - return - if clock is self: - self._autojump_task = _core.spawn_system_task(self._autojumper) - - def _real_to_virtual(self, real): - real_offset = real - self._real_base - virtual_offset = self._rate * real_offset - return self._virtual_base + virtual_offset - - def start_clock(self): - token = _core.current_trio_token() - token.run_sync_soon(self._maybe_spawn_autojump_task) - - def current_time(self): - return self._real_to_virtual(self._real_clock()) - - def deadline_to_sleep_time(self, deadline): - virtual_timeout = deadline - self.current_time() - if virtual_timeout <= 0: - return 0 - elif self._rate > 0: - return virtual_timeout / self._rate - else: - return 999999999 - - def jump(self, seconds): - """Manually advance the clock by the given number of seconds. - - Args: - seconds (float): the number of seconds to jump the clock forward. - - Raises: - ValueError: if you try to pass a negative value for ``seconds``. 
- - """ - if seconds < 0: - raise ValueError("time can't go backwards") - self._virtual_base += seconds diff --git a/trio/testing/_network.py b/trio/testing/_network.py index e3844847d5..615ce2effb 100644 --- a/trio/testing/_network.py +++ b/trio/testing/_network.py @@ -1,7 +1,5 @@ from .. import socket as tsocket -from .._highlevel_socket import SocketListener, SocketStream - -__all__ = ["open_stream_to_socket_listener"] +from .._highlevel_socket import SocketStream async def open_stream_to_socket_listener(socket_listener): diff --git a/trio/testing/_sequencer.py b/trio/testing/_sequencer.py index 118969f83d..137fd3c522 100644 --- a/trio/testing/_sequencer.py +++ b/trio/testing/_sequencer.py @@ -1,20 +1,19 @@ +from __future__ import annotations + from collections import defaultdict +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING import attr -from async_generator import async_generator, yield_, asynccontextmanager - -from .. import _core -from .. import _util -from .. import Event -if False: - from typing import DefaultDict, Set +from .. import Event, _core, _util -__all__ = ["Sequencer"] +if TYPE_CHECKING: + from collections.abc import AsyncIterator @attr.s(eq=False, hash=False) -class Sequencer: +class Sequencer(metaclass=_util.Final): """A convenience class for forcing code in different tasks to run in an explicit linear order. 
@@ -54,19 +53,16 @@ async def main(): """ - _sequence_points = attr.ib( + _sequence_points: defaultdict[int, Event] = attr.ib( factory=lambda: defaultdict(Event), init=False - ) # type: DefaultDict[int, Event] - _claimed = attr.ib(factory=set, init=False) # type: Set[int] - _broken = attr.ib(default=False, init=False) + ) + _claimed: set[int] = attr.ib(factory=set, init=False) + _broken: bool = attr.ib(default=False, init=False) @asynccontextmanager - @async_generator - async def __call__(self, position: int): + async def __call__(self, position: int) -> AsyncIterator[None]: if position in self._claimed: - raise RuntimeError( - "Attempted to re-use sequence point {}".format(position) - ) + raise RuntimeError(f"Attempted to re-use sequence point {position}") if self._broken: raise RuntimeError("sequence broken!") self._claimed.add(position) @@ -77,13 +73,11 @@ async def __call__(self, position: int): self._broken = True for event in self._sequence_points.values(): event.set() - raise RuntimeError( - "Sequencer wait cancelled -- sequence broken" - ) + raise RuntimeError("Sequencer wait cancelled -- sequence broken") else: if self._broken: raise RuntimeError("sequence broken!") try: - await yield_() + yield finally: self._sequence_points[position + 1].set() diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index 9215084955..b4ef69ef09 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -1,10 +1,8 @@ -from functools import wraps, partial +from functools import partial, wraps from .. 
import _core from ..abc import Clock, Instrument -__all__ = ["trio_test"] - # Use: # @@ -26,8 +24,6 @@ def wrapper(**kwargs): else: raise ValueError("too many clocks spoil the broth!") instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] - return _core.run( - partial(fn, **kwargs), clock=clock, instruments=instruments - ) + return _core.run(partial(fn, **kwargs), clock=clock, instruments=instruments) return wrapper diff --git a/trio/tests.py b/trio/tests.py new file mode 100644 index 0000000000..573a076da8 --- /dev/null +++ b/trio/tests.py @@ -0,0 +1,38 @@ +import importlib +import sys +from typing import Any + +from . import _tests +from ._deprecate import warn_deprecated + +warn_deprecated( + "trio.tests", + "0.22.1", + instead="trio._tests", + issue="https://github.com/python-trio/trio/issues/274", +) + + +# This won't give deprecation warning on import, but will give a warning on use of any +# attribute in tests, and static analysis tools will also not see any content inside. +class TestsDeprecationWrapper: + __name__ = "trio.tests" + + def __getattr__(self, attr: str) -> Any: + warn_deprecated( + f"trio.tests.{attr}", + "0.22.1", + instead=f"trio._tests.{attr}", + issue="https://github.com/python-trio/trio/issues/274", + ) + + # needed to access e.g. trio._tests.tools, although pytest doesn't need it + if not hasattr(_tests, attr): # pragma: no cover + importlib.import_module(f"trio._tests.{attr}", "trio._tests") + return attr + + return getattr(_tests, attr) + + +# https://stackoverflow.com/questions/2447353/getattr-on-a-module +sys.modules[__name__] = TestsDeprecationWrapper() # type: ignore[assignment] diff --git a/trio/tests/test_exports.py b/trio/tests/test_exports.py deleted file mode 100644 index abc6f07963..0000000000 --- a/trio/tests/test_exports.py +++ /dev/null @@ -1,112 +0,0 @@ -import sys -import importlib -import types - -import pytest - -import trio -import trio.testing - -from .. 
import _core - - -def test_core_is_properly_reexported(): - # Each export from _core should be re-exported by exactly one of these - # three modules: - sources = [trio, trio.hazmat, trio.testing] - for symbol in dir(_core): - if symbol.startswith('_') or symbol == 'tests': - continue - found = 0 - for source in sources: - if ( - symbol in dir(source) - and getattr(source, symbol) is getattr(_core, symbol) - ): - found += 1 - print(symbol, found) - assert found == 1 - - -def public_namespaces(module): - yield module.__name__ - for name, value in module.__dict__.items(): - if name.startswith("_"): - continue - if not isinstance(value, types.ModuleType): - continue - if not value.__name__.startswith(module.__name__): - continue - if value is module: - continue - # We should rename the trio.tests module (#274), but until then we use - # a special-case hack: - if value.__name__ == "trio.tests": - continue - yield from public_namespaces(value) - - -NAMESPACES = list(public_namespaces(trio)) - - -# It doesn't make sense for downstream redistributors to run this test, since -# they might be using a newer version of Python with additional symbols which -# won't be reflected in trio.socket, and this shouldn't cause downstream test -# runs to start failing. -@pytest.mark.redistributors_should_skip -# pylint/jedi often have trouble with alpha releases, where Python's internals -# are in flux, grammar may not have settled down, etc. 
-@pytest.mark.skipif( - sys.version_info.releaselevel == "alpha", - reason="skip static introspection tools on Python dev/alpha releases", -) -@pytest.mark.filterwarnings( - # https://github.com/PyCQA/astroid/issues/681 - "ignore:the imp module is deprecated.*:DeprecationWarning" -) -@pytest.mark.filterwarnings( - # Same as above, but on Python 3.5 - "ignore:the imp module is deprecated.*:PendingDeprecationWarning" -) -@pytest.mark.parametrize("modname", NAMESPACES) -@pytest.mark.parametrize("tool", ["pylint", "jedi"]) -def test_static_tool_sees_all_symbols(tool, modname): - module = importlib.import_module(modname) - - def no_underscores(symbols): - return {symbol for symbol in symbols if not symbol.startswith("_")} - - runtime_names = no_underscores(dir(module)) - - # We should rename the trio.tests module (#274), but until then we use a - # special-case hack: - if modname == "trio": - runtime_names.remove("tests") - - if tool == "pylint": - from pylint.lint import PyLinter - linter = PyLinter() - ast = linter.get_ast(module.__file__, modname) - static_names = no_underscores(ast) - elif tool == "jedi": - import jedi - # Simulate typing "import trio; trio." - script = jedi.Script("import {}; {}.".format(modname, modname)) - completions = script.completions() - static_names = no_underscores(c.name for c in completions) - else: # pragma: no cover - assert False - - # It's expected that the static set will contain more names than the - # runtime set: - # - static tools are sometimes sloppy and include deleted names - # - some symbols are platform-specific at runtime, but always show up in - # static analysis (e.g. in trio.socket or trio.hazmat) - # So we check that the runtime names are a subset of the static names. 
- missing_names = runtime_names - static_names - if missing_names: # pragma: no cover - print("{} can't see the following names in {}:".format(tool, modname)) - print() - for name in sorted(missing_names): - print(" {}".format(name)) - assert False diff --git a/trio/tests/test_highlevel_ssl_helpers.py b/trio/tests/test_highlevel_ssl_helpers.py deleted file mode 100644 index 1583f4cd54..0000000000 --- a/trio/tests/test_highlevel_ssl_helpers.py +++ /dev/null @@ -1,121 +0,0 @@ -import pytest - -from functools import partial - -import attr - -import trio -from trio.socket import AF_INET, SOCK_STREAM, IPPROTO_TCP -import trio.testing -from .test_ssl import client_ctx, SERVER_CTX - -from .._highlevel_ssl_helpers import ( - open_ssl_over_tcp_stream, open_ssl_over_tcp_listeners, serve_ssl_over_tcp -) - - -async def echo_handler(stream): - async with stream: - try: - while True: - data = await stream.receive_some(10000) - if not data: - break - await stream.send_all(data) - except trio.BrokenResourceError: - pass - - -# Resolver that always returns the given sockaddr, no matter what host/port -# you ask for. -@attr.s -class FakeHostnameResolver(trio.abc.HostnameResolver): - sockaddr = attr.ib() - - async def getaddrinfo(self, *args): - return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)] - - async def getnameinfo(self, *args): # pragma: no cover - raise NotImplementedError - - -# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners... -# noqa is needed because flake8 doesn't understand how pytest fixtures work. 
-async def test_open_ssl_over_tcp_stream_and_everything_else( - client_ctx, # noqa: F811 -): - async with trio.open_nursery() as nursery: - (listener,) = await nursery.start( - partial( - serve_ssl_over_tcp, - echo_handler, - 0, - SERVER_CTX, - host="127.0.0.1" - ) - ) - sockaddr = listener.transport_listener.socket.getsockname() - hostname_resolver = FakeHostnameResolver(sockaddr) - trio.socket.set_custom_hostname_resolver(hostname_resolver) - - # We don't have the right trust set up - # (checks that ssl_context=None is doing some validation) - stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80) - with pytest.raises(trio.BrokenResourceError): - await stream.do_handshake() - - # We have the trust but not the hostname - # (checks custom ssl_context + hostname checking) - stream = await open_ssl_over_tcp_stream( - "xyzzy.example.org", - 80, - ssl_context=client_ctx, - ) - with pytest.raises(trio.BrokenResourceError): - await stream.do_handshake() - - # This one should work! 
- stream = await open_ssl_over_tcp_stream( - "trio-test-1.example.org", - 80, - ssl_context=client_ctx, - ) - assert isinstance(stream, trio.SSLStream) - assert stream.server_hostname == "trio-test-1.example.org" - await stream.send_all(b"x") - assert await stream.receive_some(1) == b"x" - await stream.aclose() - - # Check https_compatible settings are being passed through - assert not stream._https_compatible - stream = await open_ssl_over_tcp_stream( - "trio-test-1.example.org", - 80, - ssl_context=client_ctx, - https_compatible=True, - # also, smoke test happy_eyeballs_delay - happy_eyeballs_delay=1, - ) - assert stream._https_compatible - - # Stop the echo server - nursery.cancel_scope.cancel() - - -async def test_open_ssl_over_tcp_listeners(): - (listener,) = await open_ssl_over_tcp_listeners( - 0, SERVER_CTX, host="127.0.0.1" - ) # yapf: disable - async with listener: - assert isinstance(listener, trio.SSLListener) - tl = listener.transport_listener - assert isinstance(tl, trio.SocketListener) - assert tl.socket.getsockname()[0] == "127.0.0.1" - - assert not listener._https_compatible - - (listener,) = await open_ssl_over_tcp_listeners( - 0, SERVER_CTX, host="127.0.0.1", https_compatible=True - ) - async with listener: - assert listener._https_compatible diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py deleted file mode 100644 index 5815ae3e42..0000000000 --- a/trio/tests/test_util.py +++ /dev/null @@ -1,215 +0,0 @@ -import os -import pathlib -import signal -import sys - -import pytest - -import trio -from .. import _core -from .._util import ( - signal_raise, ConflictDetector, fspath, is_main_thread, generic_function, - Final, NoPublicConstructor -) -from ..testing import wait_all_tasks_blocked, assert_checkpoints - - -def raise_(exc): - """ Raise provided exception. - Just a helper for raising exceptions from lambdas. 
""" - raise exc - - -def test_signal_raise(): - record = [] - - def handler(signum, _): - record.append(signum) - - old = signal.signal(signal.SIGFPE, handler) - try: - signal_raise(signal.SIGFPE) - finally: - signal.signal(signal.SIGFPE, old) - assert record == [signal.SIGFPE] - - -async def test_ConflictDetector(): - ul1 = ConflictDetector("ul1") - ul2 = ConflictDetector("ul2") - - with ul1: - with ul2: - print("ok") - - with pytest.raises(_core.BusyResourceError) as excinfo: - with ul1: - with ul1: - pass # pragma: no cover - assert "ul1" in str(excinfo.value) - - async def wait_with_ul1(): - with ul1: - await wait_all_tasks_blocked() - - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(wait_with_ul1) - nursery.start_soon(wait_with_ul1) - assert "ul1" in str(excinfo.value) - - -def test_module_metadata_is_fixed_up(): - import trio - assert trio.Cancelled.__module__ == "trio" - assert trio.open_nursery.__module__ == "trio" - assert trio.abc.Stream.__module__ == "trio.abc" - assert trio.hazmat.wait_task_rescheduled.__module__ == "trio.hazmat" - import trio.testing - assert trio.testing.trio_test.__module__ == "trio.testing" - - # Also check methods - assert trio.hazmat.ParkingLot.__init__.__module__ == "trio.hazmat" - assert trio.abc.Stream.send_all.__module__ == "trio.abc" - - # And names - assert trio.Cancelled.__name__ == "Cancelled" - assert trio.Cancelled.__qualname__ == "Cancelled" - assert trio.abc.SendStream.send_all.__name__ == "send_all" - assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all" - assert trio.to_thread.__name__ == "trio.to_thread" - assert trio.to_thread.run_sync.__name__ == "run_sync" - assert trio.to_thread.run_sync.__qualname__ == "run_sync" - - -# define a concrete class implementing the PathLike protocol -# Since we want to have compatibility with Python 3.5 we need -# to define the base class on runtime. 
-BaseKlass = os.PathLike if hasattr(os, "PathLike") else object - - -class ConcretePathLike(BaseKlass): - """ Class implementing the file system path protocol.""" - def __init__(self, path=""): - self.path = path - - def __fspath__(self): - return self.path - - -class TestFspath: - - # based on: - # https://github.com/python/cpython/blob/da6c3da6c33c6bf794f741e348b9c6d86cc43ec5/Lib/test/test_os.py#L3527-L3571 - - @pytest.mark.parametrize( - "path", (b'hello', b'goodbye', b'some/path/and/file') - ) - def test_return_bytes(self, path): - assert path == fspath(path) - - @pytest.mark.parametrize( - "path", ('hello', 'goodbye', 'some/path/and/file') - ) - def test_return_string(self, path): - assert path == fspath(path) - - @pytest.mark.parametrize( - "path", (pathlib.Path("/home"), pathlib.Path("C:\\windows")) - ) - def test_handle_pathlib(self, path): - assert str(path) == fspath(path) - - @pytest.mark.parametrize("path", ("path/like/object", b"path/like/object")) - def test_handle_pathlike_protocol(self, path): - pathlike = ConcretePathLike(path) - assert path == fspath(pathlike) - if sys.version_info > (3, 6): - assert issubclass(ConcretePathLike, os.PathLike) - assert isinstance(pathlike, os.PathLike) - - def test_argument_required(self): - with pytest.raises(TypeError): - fspath() - - def test_throw_error_at_multiple_arguments(self): - with pytest.raises(TypeError): - fspath(1, 2) - - @pytest.mark.parametrize( - "klass", (23, object(), int, type, os, type("blah", (), {})()) - ) - def test_throw_error_at_non_pathlike(self, klass): - with pytest.raises(TypeError): - fspath(klass) - - @pytest.mark.parametrize( - "exception, method", - [ - (TypeError, 1), # __fspath__ is not callable - (TypeError, lambda x: 23 - ), # __fspath__ returns a value other than str or bytes - (Exception, lambda x: raise_(Exception) - ), # __fspath__raises a random exception - (AttributeError, lambda x: raise_(AttributeError) - ), # __fspath__ raises AttributeError - ] - ) - def 
test_bad_pathlike_implementation(self, exception, method): - klass = type('foo', (), {}) - klass.__fspath__ = method - with pytest.raises(exception): - fspath(klass()) - - -async def test_is_main_thread(): - assert is_main_thread() - - def not_main_thread(): - assert not is_main_thread() - - await trio.to_thread.run_sync(not_main_thread) - - -def test_generic_function(): - @generic_function - def test_func(arg): - """Look, a docstring!""" - return arg - - assert test_func is test_func[int] is test_func[int, str] - assert test_func(42) == test_func[int](42) == 42 - assert test_func.__doc__ == "Look, a docstring!" - assert test_func.__qualname__ == "test_generic_function..test_func" - assert test_func.__name__ == "test_func" - assert test_func.__module__ == __name__ - - -def test_final_metaclass(): - class FinalClass(metaclass=Final): - pass - - with pytest.raises( - TypeError, match="`FinalClass` does not support subclassing" - ): - - class SubClass(FinalClass): - pass - - -def test_no_public_constructor_metaclass(): - class SpecialClass(metaclass=NoPublicConstructor): - pass - - with pytest.raises(TypeError, match="no public constructor available"): - SpecialClass() - - with pytest.raises( - TypeError, match="`SpecialClass` does not support subclassing" - ): - - class SubClass(SpecialClass): - pass - - # Private constructor should not raise - assert isinstance(SpecialClass._create(), SpecialClass) diff --git a/trio/to_thread.py b/trio/to_thread.py index 6eec7b36c7..45ea5b480b 100644 --- a/trio/to_thread.py +++ b/trio/to_thread.py @@ -1,2 +1,4 @@ -from ._threads import to_thread_run_sync as run_sync -from ._threads import current_default_thread_limiter +from ._threads import current_default_thread_limiter, to_thread_run_sync as run_sync + +# need to use __all__ for pyright --verifytypes to see re-exports when renaming them +__all__ = ["current_default_thread_limiter", "run_sync"]