diff --git a/pytest_trio/_tests/test_async_yield_fixture.py b/pytest_trio/_tests/test_async_yield_fixture.py index 88ed8ff..27b62e4 100644 --- a/pytest_trio/_tests/test_async_yield_fixture.py +++ b/pytest_trio/_tests/test_async_yield_fixture.py @@ -224,3 +224,40 @@ async def test_actual_test(fix1): # TODO: should trigger error instead of failure # result.assert_outcomes(error=1) result.assert_outcomes(failed=1) + + +@pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6") +def test_async_yield_fixture_with_nursery(testdir): + + testdir.makepyfile( + """ + import pytest + import trio + + + async def handle_client(stream): + while True: + buff = await stream.receive_some(4) + await stream.send_all(buff) + + + @pytest.fixture + async def server(): + async with trio.open_nursery() as nursery: + listeners = await nursery.start(trio.serve_tcp, handle_client, 0) + yield listeners[0] + nursery.cancel_scope.cancel() + + + @pytest.mark.trio + async def test_actual_test(server): + stream = await trio.testing.open_stream_to_socket_listener(server) + await stream.send_all(b'ping') + rep = await stream.receive_some(4) + assert rep == b'ping' + """ + ) + + result = testdir.runpytest() + + result.assert_outcomes(passed=1) diff --git a/pytest_trio/plugin.py b/pytest_trio/plugin.py index 5839992..298f129 100644 --- a/pytest_trio/plugin.py +++ b/pytest_trio/plugin.py @@ -39,24 +39,17 @@ async def _bootstrap_fixture_and_run_test(**kwargs): async def _setup_async_fixtures_in(deps): resolved_deps = {**deps} - async def _resolve_and_update_deps(afunc, deps, entry): - deps[entry] = await afunc() - - async with trio.open_nursery() as nursery: - for depname, depval in resolved_deps.items(): - if isinstance(depval, BaseAsyncFixture): - nursery.start_soon( - _resolve_and_update_deps, depval.setup, resolved_deps, - depname - ) + for depname, depval in resolved_deps.items(): + if isinstance(depval, BaseAsyncFixture): + resolved_deps[depname] = await depval.setup() + return resolved_deps async def _teardown_async_fixtures_in(deps): - async with trio.open_nursery() as nursery: - for depval in deps.values(): - if isinstance(depval, BaseAsyncFixture): - nursery.start_soon(depval.teardown) + for depval in deps.values(): + if isinstance(depval, BaseAsyncFixture): + await depval.teardown() class BaseAsyncFixture: