Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions trio/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def __aexit__(self, type, value, traceback):
try:
await self._agen.asend(None)
except StopAsyncIteration:
return
return False
else:
raise RuntimeError("async generator didn't stop")
else:
Expand All @@ -122,7 +122,7 @@ async def __aexit__(self, type, value, traceback):
# Likewise, avoid suppressing if a StopIteration exception
# was passed to throw() and later wrapped into a RuntimeError
# (see PEP 479).
if exc.__cause__ is value:
if isinstance(value, (StopIteration, StopAsyncIteration)) and exc.__cause__ is value:
return False
raise
except:
Expand All @@ -133,8 +133,9 @@ async def __aexit__(self, type, value, traceback):
# fixes the impedance mismatch between the throw() protocol
# and the __exit__() protocol.
#
if sys.exc_info()[1] is not value:
raise
if sys.exc_info()[1] is value:
return False
raise

def __enter__(self):
raise RuntimeError("use 'async with {func_name}(...)', not 'with {func_name}(...)'".format(func_name=self._func_name))
Expand Down
69 changes: 69 additions & 0 deletions trio/tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest

import signal
import sys
import textwrap

from async_generator import async_generator, yield_

from .._util import *
from .. import _core
Expand Down Expand Up @@ -58,3 +62,68 @@ async def wait_with_ul1():
async with ul1:
pass # pragma: no cover
assert "ul1" in str(excinfo.value)


async def test_contextmanager_do_not_unchain_non_stopiteration_exceptions():
@acontextmanager
@async_generator
async def manager_issue29692():
try:
await yield_()
except Exception as exc:
raise RuntimeError('issue29692:Chained') from exc

with pytest.raises(RuntimeError) as excinfo:
async with manager_issue29692():
raise ZeroDivisionError
assert excinfo.value.args[0] == 'issue29692:Chained'
assert isinstance(excinfo.value.__cause__, ZeroDivisionError)

# This is a little funky because of implementation details in async_generator
# It can all go away once we stop supporting Python3.5
with pytest.raises(RuntimeError) as excinfo:
async with manager_issue29692():
exc = StopIteration('issue29692:Unchained')
raise exc
assert excinfo.value.args[0] == 'issue29692:Chained'
cause = excinfo.value.__cause__
assert cause.args[0] == 'generator raised StopIteration'
assert cause.__cause__ is exc

with pytest.raises(StopAsyncIteration) as excinfo:
async with manager_issue29692():
raise StopAsyncIteration('issue29692:Unchained')
assert excinfo.value.args[0] == 'issue29692:Unchained'
assert excinfo.value.__cause__ is None


# Native async generators are only available from Python 3.6 and onwards
nativeasyncgenerators = True
try:
exec("""
@acontextmanager
async def manager_issue29692_2():
try:
yield
except Exception as exc:
raise RuntimeError('issue29692:Chained') from exc
""")
except SyntaxError:
nativeasyncgenerators = False


@pytest.mark.skipif(not nativeasyncgenerators, reason="Python < 3.6 doesn't have native async generators")
async def test_native_contextmanager_do_not_unchain_non_stopiteration_exceptions():

with pytest.raises(RuntimeError) as excinfo:
async with manager_issue29692_2():
raise ZeroDivisionError
assert excinfo.value.args[0] == 'issue29692:Chained'
assert isinstance(excinfo.value.__cause__, ZeroDivisionError)

for cls in [StopIteration, StopAsyncIteration]:
with pytest.raises(cls) as excinfo:
async with manager_issue29692_2():
raise cls('issue29692:Unchained')
assert excinfo.value.args[0] == 'issue29692:Unchained'
assert excinfo.value.__cause__ is None