diff --git a/trio/_util.py b/trio/_util.py index 3592065820..333896acde 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -99,7 +99,7 @@ async def __aexit__(self, type, value, traceback): try: await self._agen.asend(None) except StopAsyncIteration: - return + return False else: raise RuntimeError("async generator didn't stop") else: @@ -122,7 +122,7 @@ async def __aexit__(self, type, value, traceback): # Likewise, avoid suppressing if a StopIteration exception # was passed to throw() and later wrapped into a RuntimeError # (see PEP 479). - if exc.__cause__ is value: + if isinstance(value, (StopIteration, StopAsyncIteration)) and exc.__cause__ is value: return False raise except: @@ -133,8 +133,9 @@ async def __aexit__(self, type, value, traceback): # fixes the impedance mismatch between the throw() protocol # and the __exit__() protocol. # - if sys.exc_info()[1] is not value: - raise + if sys.exc_info()[1] is value: + return False + raise def __enter__(self): raise RuntimeError("use 'async with {func_name}(...)', not 'with {func_name}(...)'".format(func_name=self._func_name)) diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index 00b0907d32..309e6ac81f 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -1,6 +1,10 @@ import pytest import signal +import sys +import textwrap + +from async_generator import async_generator, yield_ from .._util import * from .. import _core @@ -58,3 +62,68 @@ async def wait_with_ul1(): async with ul1: pass # pragma: no cover assert "ul1" in str(excinfo.value) + + +async def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(): + @acontextmanager + @async_generator + async def manager_issue29692(): + try: + await yield_() + except Exception as exc: + raise RuntimeError('issue29692:Chained') from exc + + with pytest.raises(RuntimeError) as excinfo: + async with manager_issue29692(): + raise ZeroDivisionError + assert excinfo.value.args[0] == 'issue29692:Chained' + assert isinstance(excinfo.value.__cause__, ZeroDivisionError) + + # This is a little funky because of implementation details in async_generator + # It can all go away once we stop supporting Python3.5 + with pytest.raises(RuntimeError) as excinfo: + async with manager_issue29692(): + exc = StopIteration('issue29692:Unchained') + raise exc + assert excinfo.value.args[0] == 'issue29692:Chained' + cause = excinfo.value.__cause__ + assert cause.args[0] == 'generator raised StopIteration' + assert cause.__cause__ is exc + + with pytest.raises(StopAsyncIteration) as excinfo: + async with manager_issue29692(): + raise StopAsyncIteration('issue29692:Unchained') + assert excinfo.value.args[0] == 'issue29692:Unchained' + assert excinfo.value.__cause__ is None + + +# Native async generators are only available from Python 3.6 and onwards +nativeasyncgenerators = True +try: + exec(""" +@acontextmanager +async def manager_issue29692_2(): + try: + yield + except Exception as exc: + raise RuntimeError('issue29692:Chained') from exc +""") +except SyntaxError: + nativeasyncgenerators = False + + +@pytest.mark.skipif(not nativeasyncgenerators, reason="Python < 3.6 doesn't have native async generators") +async def test_native_contextmanager_do_not_unchain_non_stopiteration_exceptions(): + + with pytest.raises(RuntimeError) as excinfo: + async with manager_issue29692_2(): + raise ZeroDivisionError + assert excinfo.value.args[0] == 'issue29692:Chained' + assert isinstance(excinfo.value.__cause__, ZeroDivisionError) + + for cls in [StopIteration, StopAsyncIteration]: + with pytest.raises(cls) as excinfo: + async with manager_issue29692_2(): + raise cls('issue29692:Unchained') + assert excinfo.value.args[0] == 'issue29692:Unchained' + assert excinfo.value.__cause__ is None