diff --git a/nonebug/fixture.py b/nonebug/fixture.py
index 006cc4c..d7254f7 100644
--- a/nonebug/fixture.py
+++ b/nonebug/fixture.py
@@ -1,17 +1,57 @@
+from contextlib import nullcontext, asynccontextmanager
+
import pytest
+from async_asgi_testclient import TestClient
from nonebug.app import App
+from nonebug.mixin.driver import set_global_client
from . import NONEBOT_INIT_KWARGS, NONEBOT_START_LIFESPAN
+@asynccontextmanager
+async def lifespan_ctx():
+ import nonebot
+ from nonebot import logger
+ from nonebot.drivers import ASGIMixin
+
+ driver = nonebot.get_driver()
+
+ if isinstance(driver, ASGIMixin):
+ # if the driver has an asgi application
+ # use asgi lifespan to startup/shutdown
+ ctx = TestClient(driver.asgi)
+ set_global_client(ctx)
+ else:
+ ctx = driver._lifespan
+
+ try:
+ await ctx.__aenter__()
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error occurred while running startup hook."
+ ""
+ )
+ raise
+
+ try:
+ yield
+ finally:
+ try:
+ await ctx.__aexit__(None, None, None)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error occurred while running shutdown hook."
+ ""
+ )
+
+
@pytest.fixture(scope="session", autouse=True)
-async def nonebug_init(request: pytest.FixtureRequest): # noqa: PT004
+def _nonebot_init(request: pytest.FixtureRequest):
"""
Initialize nonebot before test case running.
"""
import nonebot
- from nonebot import logger
from nonebot.matcher import matchers
from nonebug.provider import NoneBugProvider
@@ -19,29 +59,22 @@ async def nonebug_init(request: pytest.FixtureRequest): # noqa: PT004
nonebot.init(**request.config.stash.get(NONEBOT_INIT_KWARGS, {}))
matchers.set_provider(NoneBugProvider)
+
+@pytest.fixture(scope="session", autouse=True)
+async def after_nonebot_init(_nonebot_init: None): # noqa: PT004
+ pass
+
+
+@pytest.fixture(scope="session", autouse=True)
+async def nonebug_init( # noqa: PT004
+ _nonebot_init: None, after_nonebot_init: None, request: pytest.FixtureRequest
+):
run_lifespan = request.config.stash.get(NONEBOT_START_LIFESPAN, True)
- driver = nonebot.get_driver()
- if run_lifespan:
- try:
- await driver._lifespan.startup()
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error occurred while running startup hook."
- ""
- )
- raise
- try:
+ ctx = lifespan_ctx() if run_lifespan else nullcontext()
+
+ async with ctx:
yield
- finally:
- if run_lifespan:
- try:
- await driver._lifespan.shutdown()
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error occurred while running shutdown hook."
- ""
- )
@pytest.fixture(name="app")
@@ -53,4 +86,4 @@ def nonebug_app(nonebug_init) -> App:
return App()
-__all__ = ["nonebug_init", "nonebug_app"]
+__all__ = ["after_nonebot_init", "nonebug_init", "nonebug_app"]
diff --git a/nonebug/mixin/driver.py b/nonebug/mixin/driver.py
index c14b1f8..bc79c77 100644
--- a/nonebug/mixin/driver.py
+++ b/nonebug/mixin/driver.py
@@ -6,6 +6,21 @@
from nonebug.base import BaseApp, Context
+_global_client: Optional[TestClient] = None
+
+
+def set_global_client(client: TestClient):
+ global _global_client
+
+ if _global_client is not None:
+ raise RuntimeError()
+
+ _global_client = client
+
+
+def get_global_client() -> Optional[TestClient]:
+ return _global_client
+
@final
class ServerContext(Context):
@@ -14,18 +29,21 @@ def __init__(
app: BaseApp,
*args,
asgi: ASGIApplication,
+ client: Optional[TestClient] = None,
**kwargs,
):
super().__init__(app, *args, **kwargs)
self.asgi = asgi
+ self.specified_client = client
self.client = TestClient(self.asgi)
def get_client(self) -> TestClient:
- return self.client
+ return self.specified_client or self.client
async def setup(self) -> None:
await super().setup()
- await self.stack.enter_async_context(self.client)
+ if self.specified_client is None:
+ await self.stack.enter_async_context(self.client)
# @final
@@ -50,8 +68,12 @@ class DriverMixin(BaseApp):
def test_server(self, asgi: Optional[ASGIApplication] = None) -> ServerContext:
import nonebot
+ client = None
+ if asgi is None:
+ client = get_global_client()
+
asgi = asgi or nonebot.get_asgi()
- return ServerContext(self, asgi=asgi)
+ return ServerContext(self, asgi=asgi, client=client)
# def test_client(self):
# ...
diff --git a/tests/conftest.py b/tests/conftest.py
index 6685658..76ba821 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,13 +5,13 @@
from nonebot.plugin import Plugin
from nonebug import NONEBOT_INIT_KWARGS
-from nonebug.fixture import * # noqa: F403
+from nonebug.fixture import nonebug_app, nonebug_init, _nonebot_init # noqa: F401
def pytest_configure(config: pytest.Config) -> None:
config.stash[NONEBOT_INIT_KWARGS] = {"custom_key": "custom_value"}
-@pytest.fixture(scope="session")
-def load_plugin(nonebug_init: None) -> set[Plugin]:
+@pytest.fixture(scope="session", autouse=True)
+async def after_nonebot_init(_nonebot_init: None) -> set[Plugin]: # noqa: F811
return nonebot.load_plugins(str(Path(__file__).parent / "plugins"))