diff --git a/newsfragments/612.bugfix.rst b/newsfragments/612.bugfix.rst new file mode 100644 index 0000000000..d9a9bdfb1d --- /dev/null +++ b/newsfragments/612.bugfix.rst @@ -0,0 +1,4 @@ +The nursery context manager was rewritten to avoid use of +`@asynccontextmanager` and `@async_generator`. This reduces extraneous frames +in exception traces and addresses bugs regarding `StopIteration` and +`StopAsyncIteration` exceptions not propagating correctly. diff --git a/trio/_core/_run.py b/trio/_core/_run.py index 6c8651ceb8..8421f3be0a 100644 --- a/trio/_core/_run.py +++ b/trio/_core/_run.py @@ -16,9 +16,7 @@ from sniffio import current_async_library_cvar import attr -from async_generator import ( - async_generator, yield_, asynccontextmanager, isasyncgen -) +from async_generator import isasyncgen from sortedcontainers import SortedDict from outcome import Error, Value @@ -295,57 +293,55 @@ def started(self, value=None): self._old_nursery._check_nursery_closed() -@asynccontextmanager -@async_generator -@enable_ki_protection -async def open_nursery(): - """Returns an async context manager which creates a new nursery. +class NurseryManager: + """Nursery context manager. - This context manager's ``__aenter__`` method executes synchronously. Its - ``__aexit__`` method blocks until all child tasks have exited. + Note we explicitly avoid @asynccontextmanager and @async_generator + since they add a lot of extraneous stack frames to exceptions, as + well as cause problematic behavior with handling of StopIteration + and StopAsyncIteration. """ - assert currently_ki_protected() - with open_cancel_scope() as scope: - nursery = Nursery(current_task(), scope) - nested_child_exc = None - try: - await yield_(nursery) - except BaseException as exc: - nested_child_exc = exc + + @enable_ki_protection + async def __aenter__(self): + assert currently_ki_protected() + self._scope_manager = open_cancel_scope() + scope = self._scope_manager.__enter__() + self._nursery = Nursery(current_task(), scope) + return self._nursery + + @enable_ki_protection + async def __aexit__(self, etype, exc, tb): assert currently_ki_protected() - await nursery._nested_child_finished(nested_child_exc) - - -# I *think* this is equivalent to the above, and it gives *much* nicer -# exception tracebacks... but I'm a little nervous about it because it's much -# trickier code :-( -# -# class NurseryManager: -# @enable_ki_protection -# async def __aenter__(self): -# self._scope_manager = open_cancel_scope() -# scope = self._scope_manager.__enter__() -# self._parent_nursery = Nursery(current_task(), scope) -# return self._parent_nursery -# -# @enable_ki_protection -# async def __aexit__(self, etype, exc, tb): -# try: -# await self._parent_nursery._clean_up(exc) -# except BaseException as new_exc: -# if not self._scope_manager.__exit__( -# type(new_exc), new_exc, new_exc.__traceback__): -# if exc is new_exc: -# return False -# else: -# raise -# else: -# self._scope_manager.__exit__(None, None, None) -# return True -# -# def open_nursery(): -# return NurseryManager() + try: + await self._nursery._nested_child_finished(exc) + except BaseException as new_exc: + try: + if self._scope_manager.__exit__( + type(new_exc), new_exc, new_exc.__traceback__ + ): + return True + except BaseException as scope_manager_exc: + if scope_manager_exc == exc: + return False + raise # scope_manager_exc + raise # new_exc + else: + self._scope_manager.__exit__(None, None, None) + return True + + def __enter__(self): + raise RuntimeError( + "use 'async with open_nursery(...)', not 'with open_nursery(...)'" + ) + + def __exit__(self): # pragma: no cover + assert False, """Never called, but should be defined""" + + +def open_nursery(): + return NurseryManager() class Nursery: diff --git a/trio/_core/tests/test_run.py b/trio/_core/tests/test_run.py index 49ffb76baf..c264c39e75 100644 --- a/trio/_core/tests/test_run.py +++ b/trio/_core/tests/test_run.py @@ -17,6 +17,7 @@ from .tutil import check_sequence_matches, gc_collect_harder from ... import _core from ..._timeouts import sleep +from ..._util import aiter_compat from ...testing import ( wait_all_tasks_blocked, Sequencer, @@ -1823,6 +1824,79 @@ async def start_sleep_then_crash(nursery): assert _core.current_time() - t0 == 7 +async def test_nursery_explicit_exception(): + with pytest.raises(KeyError): + async with _core.open_nursery(): + raise KeyError() + + +async def test_nursery_stop_iteration(): + async def fail(): + raise ValueError + + try: + async with _core.open_nursery() as nursery: + nursery.start_soon(fail) + raise StopIteration + except _core.MultiError as e: + assert tuple(map(type, e.exceptions)) == (StopIteration, ValueError) + + +async def test_nursery_stop_async_iteration(): + class it(object): + def __init__(self, count): + self.count = count + self.val = 0 + + async def __anext__(self): + await sleep(0) + val = self.val + if val >= self.count: + raise StopAsyncIteration + self.val += 1 + return val + + class async_zip(object): + def __init__(self, *largs): + self.nexts = [obj.__anext__ for obj in largs] + + async def _accumulate(self, f, items, i): + items[i] = await f() + + @aiter_compat + def __aiter__(self): + return self + + async def __anext__(self): + nexts = self.nexts + items = [ + None, + ] * len(nexts) + got_stop = False + + def handle(exc): + nonlocal got_stop + if isinstance(exc, StopAsyncIteration): + got_stop = True + return None + else: # pragma: no cover + return exc + + with _core.MultiError.catch(handle): + async with _core.open_nursery() as nursery: + for i, f in enumerate(nexts): + nursery.start_soon(self._accumulate, f, items, i) + + if got_stop: + raise StopAsyncIteration + return items + + result = [] + async for vals in async_zip(it(4), it(2)): + result.append(vals) + assert result == [[0, 0], [1, 1]] + + def test_contextvar_support(): var = contextvars.ContextVar("test") var.set("before")