diff --git a/nonebug/fixture.py b/nonebug/fixture.py index 006cc4c..d7254f7 100644 --- a/nonebug/fixture.py +++ b/nonebug/fixture.py @@ -1,17 +1,57 @@ +from contextlib import nullcontext, asynccontextmanager + import pytest +from async_asgi_testclient import TestClient from nonebug.app import App +from nonebug.mixin.driver import set_global_client from . import NONEBOT_INIT_KWARGS, NONEBOT_START_LIFESPAN +@asynccontextmanager +async def lifespan_ctx(): + import nonebot + from nonebot import logger + from nonebot.drivers import ASGIMixin + + driver = nonebot.get_driver() + + if isinstance(driver, ASGIMixin): + # if the driver has an asgi application + # use asgi lifespan to startup/shutdown + ctx = TestClient(driver.asgi) + set_global_client(ctx) + else: + ctx = driver._lifespan + + try: + await ctx.__aenter__() + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error occurred while running startup hook." + "" + ) + raise + + try: + yield + finally: + try: + await ctx.__aexit__(None, None, None) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error occurred while running shutdown hook." + "" + ) + + @pytest.fixture(scope="session", autouse=True) -async def nonebug_init(request: pytest.FixtureRequest): # noqa: PT004 +def _nonebot_init(request: pytest.FixtureRequest): """ Initialize nonebot before test case running. """ import nonebot - from nonebot import logger from nonebot.matcher import matchers from nonebug.provider import NoneBugProvider @@ -19,29 +59,22 @@ async def nonebug_init(request: pytest.FixtureRequest): # noqa: PT004 nonebot.init(**request.config.stash.get(NONEBOT_INIT_KWARGS, {})) matchers.set_provider(NoneBugProvider) + +@pytest.fixture(scope="session", autouse=True) +async def after_nonebot_init(_nonebot_init: None): # noqa: PT004 + pass + + +@pytest.fixture(scope="session", autouse=True) +async def nonebug_init( # noqa: PT004 + _nonebot_init: None, after_nonebot_init: None, request: pytest.FixtureRequest +): run_lifespan = request.config.stash.get(NONEBOT_START_LIFESPAN, True) - driver = nonebot.get_driver() - if run_lifespan: - try: - await driver._lifespan.startup() - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error occurred while running startup hook." - "" - ) - raise - try: + ctx = lifespan_ctx() if run_lifespan else nullcontext() + + async with ctx: yield - finally: - if run_lifespan: - try: - await driver._lifespan.shutdown() - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error occurred while running shutdown hook." - "" - ) @pytest.fixture(name="app") @@ -53,4 +86,4 @@ def nonebug_app(nonebug_init) -> App: return App() -__all__ = ["nonebug_init", "nonebug_app"] +__all__ = ["after_nonebot_init", "nonebug_init", "nonebug_app"] diff --git a/nonebug/mixin/driver.py b/nonebug/mixin/driver.py index c14b1f8..bc79c77 100644 --- a/nonebug/mixin/driver.py +++ b/nonebug/mixin/driver.py @@ -6,6 +6,21 @@ from nonebug.base import BaseApp, Context +_global_client: Optional[TestClient] = None + + +def set_global_client(client: TestClient): + global _global_client + + if _global_client is not None: + raise RuntimeError() + + _global_client = client + + +def get_global_client() -> Optional[TestClient]: + return _global_client + @final class ServerContext(Context): @@ -14,18 +29,21 @@ def __init__( app: BaseApp, *args, asgi: ASGIApplication, + client: Optional[TestClient] = None, **kwargs, ): super().__init__(app, *args, **kwargs) self.asgi = asgi + self.specified_client = client self.client = TestClient(self.asgi) def get_client(self) -> TestClient: - return self.client + return self.specified_client or self.client async def setup(self) -> None: await super().setup() - await self.stack.enter_async_context(self.client) + if self.specified_client is None: + await self.stack.enter_async_context(self.client) # @final @@ -50,8 +68,12 @@ class DriverMixin(BaseApp): def test_server(self, asgi: Optional[ASGIApplication] = None) -> ServerContext: import nonebot + client = None + if asgi is None: + client = get_global_client() + asgi = asgi or nonebot.get_asgi() - return ServerContext(self, asgi=asgi) + return ServerContext(self, asgi=asgi, client=client) # def test_client(self): # ... diff --git a/tests/conftest.py b/tests/conftest.py index 6685658..76ba821 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,13 +5,13 @@ from nonebot.plugin import Plugin from nonebug import NONEBOT_INIT_KWARGS -from nonebug.fixture import * # noqa: F403 +from nonebug.fixture import nonebug_app, nonebug_init, _nonebot_init # noqa: F401 def pytest_configure(config: pytest.Config) -> None: config.stash[NONEBOT_INIT_KWARGS] = {"custom_key": "custom_value"} -@pytest.fixture(scope="session") -def load_plugin(nonebug_init: None) -> set[Plugin]: +@pytest.fixture(scope="session", autouse=True) +async def after_nonebot_init(_nonebot_init: None) -> set[Plugin]: # noqa: F811 return nonebot.load_plugins(str(Path(__file__).parent / "plugins"))