diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index da235ff2dd78..c7e6e82dd016 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -24,6 +24,7 @@ from nonebot.internal.driver import Response as Response from nonebot.internal.driver import ReverseDriver as ReverseDriver from nonebot.internal.driver import ReverseMixin as ReverseMixin +from nonebot.internal.driver import Timeout as Timeout from nonebot.internal.driver import WebSocket as WebSocket from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup @@ -34,6 +35,7 @@ "Cookies": True, "Request": True, "Response": True, + "Timeout": True, "WebSocket": True, "HTTPVersion": True, "Driver": True, diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 7a2e06a7eca9..b77e374e6aa5 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -37,7 +37,14 @@ from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers.none import Driver as NoneDriver from nonebot.exception import WebSocketClosed -from nonebot.internal.driver import Cookies, CookieTypes, HeaderTypes, QueryTypes +from nonebot.internal.driver import ( + Cookies, + CookieTypes, + HeaderTypes, + QueryTypes, + Timeout, + TimeoutTypes, +) try: import aiohttp @@ -56,7 +63,7 @@ def __init__( headers: HeaderTypes = None, cookies: CookieTypes = None, version: Union[str, HTTPVersion] = HTTPVersion.H11, - timeout: Optional[float] = None, + timeout: TimeoutTypes = None, proxy: Optional[str] = None, ): self._client: Optional[aiohttp.ClientSession] = None @@ -78,7 +85,15 @@ def __init__( else: raise RuntimeError(f"Unsupported HTTP version: {version}") - self._timeout = timeout + if isinstance(timeout, Timeout): + self._timeout = aiohttp.ClientTimeout( + total=timeout.total, + connect=timeout.connect, + sock_read=timeout.read, + ) + else: + self._timeout = aiohttp.ClientTimeout(timeout) + self._proxy = proxy @property @@ -106,7 +121,14 @@ async def request(self, setup: Request) -> Response: if cookie.value is not None ) - timeout = aiohttp.ClientTimeout(setup.timeout) + if isinstance(setup.timeout, Timeout): + timeout = aiohttp.ClientTimeout( + total=setup.timeout.total, + connect=setup.timeout.connect, + sock_read=setup.timeout.read, + ) + else: + timeout = aiohttp.ClientTimeout(setup.timeout) async with await self.client.request( setup.method, @@ -149,7 +171,14 @@ async def stream_request( if cookie.value is not None ) - timeout = aiohttp.ClientTimeout(setup.timeout) + if isinstance(setup.timeout, Timeout): + timeout = aiohttp.ClientTimeout( + total=setup.timeout.total, + connect=setup.timeout.connect, + sock_read=setup.timeout.read, + ) + else: + timeout = aiohttp.ClientTimeout(setup.timeout) async with self.client.request( setup.method, @@ -226,7 +255,13 @@ async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: else: raise RuntimeError(f"Unsupported HTTP version: {setup.version}") - timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore + if isinstance(setup.timeout, Timeout): + timeout = aiohttp.ClientWSTimeout( + ws_receive=setup.timeout.read, # type: ignore + ws_close=setup.timeout.total, # type: ignore + ) + else: + timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore async with aiohttp.ClientSession(version=version, trust_env=True) as session: async with session.ws_connect( @@ -245,7 +280,7 @@ def get_session( headers: HeaderTypes = None, cookies: CookieTypes = None, version: Union[str, HTTPVersion] = HTTPVersion.H11, - timeout: Optional[float] = None, + timeout: TimeoutTypes = None, proxy: Optional[str] = None, ) -> Session: return Session( diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index bca949a13cf6..49161566a7f7 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -33,7 +33,14 @@ combine_driver, ) from nonebot.drivers.none import Driver as NoneDriver -from nonebot.internal.driver import Cookies, CookieTypes, HeaderTypes, QueryTypes +from nonebot.internal.driver import ( + Cookies, + CookieTypes, + HeaderTypes, + QueryTypes, + Timeout, + TimeoutTypes, +) try: import httpx @@ -52,7 +59,7 @@ def __init__( headers: HeaderTypes = None, cookies: CookieTypes = None, version: Union[str, HTTPVersion] = HTTPVersion.H11, - timeout: Optional[float] = None, + timeout: TimeoutTypes = None, proxy: Optional[str] = None, ): self._client: Optional[httpx.AsyncClient] = None @@ -65,7 +72,16 @@ def __init__( ) self._cookies = Cookies(cookies) self._version = HTTPVersion(version) - self._timeout = timeout + + if isinstance(timeout, Timeout): + self._timeout = httpx.Timeout( + timeout=timeout.total, + connect=timeout.connect, + read=timeout.read, + ) + else: + self._timeout = httpx.Timeout(timeout) + self._proxy = proxy @property @@ -76,6 +92,15 @@ def client(self) -> httpx.AsyncClient: @override async def request(self, setup: Request) -> Response: + if isinstance(setup.timeout, Timeout): + timeout = httpx.Timeout( + timeout=setup.timeout.total, + connect=setup.timeout.connect, + read=setup.timeout.read, + ) + else: + timeout = httpx.Timeout(setup.timeout) + response = await self.client.request( setup.method, str(setup.url), @@ -87,7 +112,7 @@ async def request(self, setup: Request) -> Response: params=setup.url.raw_query_string, headers=tuple(setup.headers.items()), cookies=setup.cookies.jar, - timeout=setup.timeout, + timeout=timeout, ) return Response( response.status_code, @@ -103,6 +128,15 @@ async def stream_request( *, chunk_size: int = 1024, ) -> AsyncGenerator[Response, None]: + if isinstance(setup.timeout, Timeout): + timeout = httpx.Timeout( + timeout=setup.timeout.total, + connect=setup.timeout.connect, + read=setup.timeout.read, + ) + else: + timeout = httpx.Timeout(setup.timeout) + async with self.client.stream( setup.method, str(setup.url), @@ -114,7 +148,7 @@ async def stream_request( params=setup.url.raw_query_string, headers=tuple(setup.headers.items()), cookies=setup.cookies.jar, - timeout=setup.timeout, + timeout=timeout, ) as response: response_headers = response.headers.multi_items() async for chunk in response.aiter_bytes(chunk_size=chunk_size): @@ -183,7 +217,7 @@ def get_session( headers: HeaderTypes = None, cookies: CookieTypes = None, version: Union[str, HTTPVersion] = HTTPVersion.H11, - timeout: Optional[float] = None, + timeout: TimeoutTypes = None, proxy: Optional[str] = None, ) -> Session: return Session( diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index ff9601b05b9d..b348c417502a 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union from typing_extensions import ParamSpec, override -from nonebot.drivers import Request, WebSocketClientMixin, combine_driver +from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers.none import Driver as NoneDriver from nonebot.exception import WebSocketClosed @@ -73,10 +73,16 @@ def type(self) -> str: async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: if setup.proxy is not None: logger.warning("proxy is not supported by websockets driver") + + if isinstance(setup.timeout, Timeout): + timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read + else: + timeout = setup.timeout + connection = Connect( str(setup.url), extra_headers={**setup.headers, **setup.cookies.as_header(setup)}, - open_timeout=setup.timeout, + open_timeout=timeout, ) async with connection as ws: yield WebSocket(request=setup, websocket=ws) diff --git a/nonebot/internal/driver/__init__.py b/nonebot/internal/driver/__init__.py index 0fd5c8553e51..e4b3f042c3f6 100644 --- a/nonebot/internal/driver/__init__.py +++ b/nonebot/internal/driver/__init__.py @@ -27,5 +27,7 @@ from .model import Request as Request from .model import Response as Response from .model import SimpleQuery as SimpleQuery +from .model import Timeout as Timeout +from .model import TimeoutTypes as TimeoutTypes from .model import WebSocket as WebSocket from .model import WebSocketServerSetup as WebSocketServerSetup diff --git a/nonebot/internal/driver/abstract.py b/nonebot/internal/driver/abstract.py index d35e0011f58a..253354f18db2 100644 --- a/nonebot/internal/driver/abstract.py +++ b/nonebot/internal/driver/abstract.py @@ -30,6 +30,7 @@ QueryTypes, Request, Response, + TimeoutTypes, WebSocket, WebSocketServerSetup, ) @@ -245,7 +246,7 @@ def __init__( headers: HeaderTypes = None, cookies: CookieTypes = None, version: Union[str, HTTPVersion] = HTTPVersion.H11, - timeout: Optional[float] = None, + timeout: TimeoutTypes = None, proxy: Optional[str] = None, ): raise NotImplementedError @@ -315,7 +316,7 @@ def get_session( headers: HeaderTypes = None, cookies: CookieTypes = None, version: Union[str, HTTPVersion] = HTTPVersion.H11, - timeout: Optional[float] = None, + timeout: TimeoutTypes = None, proxy: Optional[str] = None, ) -> HTTPClientSession: """获取一个 HTTP 会话""" diff --git a/nonebot/internal/driver/model.py b/nonebot/internal/driver/model.py index 0f0b2c27827c..ca6558c92777 100644 --- a/nonebot/internal/driver/model.py +++ b/nonebot/internal/driver/model.py @@ -42,6 +42,7 @@ FileType, ] FilesTypes: TypeAlias = Union[dict[str, FileTypes], list[tuple[str, FileTypes]], None] +TimeoutTypes: TypeAlias = Union[float, "Timeout", None] class HTTPVersion(Enum): @@ -50,6 +51,15 @@ class HTTPVersion(Enum): H2 = "2" +@dataclass +class Timeout: + """Request 超时配置。""" + + total: Optional[float] = None + connect: Optional[float] = None + read: Optional[float] = None + + class Request: def __init__( self, @@ -64,7 +74,7 @@ def __init__( json: Any = None, files: FilesTypes = None, version: Union[str, HTTPVersion] = HTTPVersion.H11, - timeout: Optional[float] = None, + timeout: TimeoutTypes = None, proxy: Optional[str] = None, ): # method @@ -76,7 +86,7 @@ def __init__( # http version self.version: HTTPVersion = HTTPVersion(version) # timeout - self.timeout: Optional[float] = timeout + self.timeout: TimeoutTypes = timeout # proxy self.proxy: Optional[str] = proxy diff --git a/tests/test_driver.py b/tests/test_driver.py index 3094ea62a0f1..cf364daa083e 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -16,6 +16,7 @@ HTTPServerSetup, Request, Response, + Timeout, WebSocket, WebSocketClientMixin, WebSocketServerSetup, @@ -235,6 +236,7 @@ async def test_http_client(driver: Driver, server_url: URL): headers={"X-Test": "test"}, cookies={"session": "test"}, content="test", + timeout=Timeout(total=4, connect=2, read=2), ) response = await driver.request(request) assert server_url.host is not None @@ -250,6 +252,7 @@ async def test_http_client(driver: Driver, server_url: URL): headers={"X-Test": "test"}, cookies={"session": "test"}, content="test", + timeout=Timeout(total=4, connect=2, read=2), ) assert request.url == request_raw_url.url, ( "request.url should be equal to request_raw_url.url" @@ -312,6 +315,7 @@ async def test_http_client(driver: Driver, server_url: URL): headers={"X-Test": "stream"}, cookies={"session": "stream"}, content="stream_test" * 1024, + timeout=Timeout(total=4, connect=2, read=2), ) chunks = [] async for resp in driver.stream_request(request, chunk_size=4): @@ -414,6 +418,7 @@ async def test_http_client_session(driver: Driver, server_url: URL): headers={"X-Test": "test"}, cookies={"cookie": "test"}, content="test", + timeout=Timeout(total=4, connect=2, read=2), ) response = await session.request(request) assert response.status_code == 200 @@ -499,6 +504,7 @@ async def test_http_client_session(driver: Driver, server_url: URL): headers={"X-Test": "stream"}, cookies={"cookie": "stream"}, content="stream_test" * 1024, + timeout=Timeout(total=4, connect=2, read=2), ) chunks = [] async for resp in session.stream_request(request, chunk_size=4):