From d611f294db543ede9a3728a751a7a4b29a11233b Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 6 Apr 2024 18:13:21 +0800 Subject: [PATCH 1/2] :sparkles: add session context check --- nonebot/drivers/aiohttp.py | 2 ++ nonebot/drivers/httpx.py | 2 ++ tests/test_driver.py | 5 +++++ 3 files changed, 9 insertions(+) diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index b12d204b447d..6b437e223fc7 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -124,6 +124,8 @@ async def request(self, setup: Request) -> Response: @override async def setup(self) -> None: + if self._client is not None: + raise RuntimeError("Session has already been initialized") self._client = aiohttp.ClientSession( cookies=self._cookies, headers=self._headers, diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index 8300323bd65c..68eed464853b 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -93,6 +93,8 @@ async def request(self, setup: Request) -> Response: @override async def setup(self) -> None: + if self._client is not None: + raise RuntimeError("Session has already been initialized") self._client = httpx.AsyncClient( params=self._params, headers=self._headers, diff --git a/tests/test_driver.py b/tests/test_driver.py index 546cfeecb5ad..137488a3d11f 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -328,6 +328,11 @@ async def test_http_client_session(driver: Driver, server_url: URL): with pytest.raises(RuntimeError): await session.request(request) + with pytest.raises(RuntimeError): + async with session: + async with session: + ... + async with session as session: # simple post with query, headers, cookies and content request = Request( From 3f5e4f0a966682e702e7a929fa5a292b109911e7 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 6 Apr 2024 18:17:48 +0800 Subject: [PATCH 2/2] Update tests/test_driver.py --- tests/test_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_driver.py b/tests/test_driver.py index 137488a3d11f..636e02c3e633 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -328,7 +328,7 @@ async def test_http_client_session(driver: Driver, server_url: URL): with pytest.raises(RuntimeError): await session.request(request) - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError): # noqa: PT012 async with session: async with session: ...