Skip to content
Merged
2 changes: 2 additions & 0 deletions nonebot/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nonebot.internal.driver import Response as Response
from nonebot.internal.driver import ReverseDriver as ReverseDriver
from nonebot.internal.driver import ReverseMixin as ReverseMixin
from nonebot.internal.driver import Timeout as Timeout
from nonebot.internal.driver import WebSocket as WebSocket
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
Expand All @@ -34,6 +35,7 @@
"Cookies": True,
"Request": True,
"Response": True,
"Timeout": True,
"WebSocket": True,
"HTTPVersion": True,
"Driver": True,
Expand Down
49 changes: 42 additions & 7 deletions nonebot/drivers/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed
from nonebot.internal.driver import Cookies, CookieTypes, HeaderTypes, QueryTypes
from nonebot.internal.driver import (
Cookies,
CookieTypes,
HeaderTypes,
QueryTypes,
Timeout,
TimeoutTypes,
)

try:
import aiohttp
Expand All @@ -56,7 +63,7 @@ def __init__(
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
):
self._client: Optional[aiohttp.ClientSession] = None
Expand All @@ -78,7 +85,15 @@ def __init__(
else:
raise RuntimeError(f"Unsupported HTTP version: {version}")

self._timeout = timeout
if isinstance(timeout, Timeout):
self._timeout = aiohttp.ClientTimeout(
total=timeout.total,
connect=timeout.connect,
sock_read=timeout.read,
)
else:
self._timeout = aiohttp.ClientTimeout(timeout)

self._proxy = proxy

@property
Expand Down Expand Up @@ -106,7 +121,14 @@ async def request(self, setup: Request) -> Response:
if cookie.value is not None
)

timeout = aiohttp.ClientTimeout(setup.timeout)
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)

async with await self.client.request(
setup.method,
Expand Down Expand Up @@ -149,7 +171,14 @@ async def stream_request(
if cookie.value is not None
)

timeout = aiohttp.ClientTimeout(setup.timeout)
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientTimeout(
total=setup.timeout.total,
connect=setup.timeout.connect,
sock_read=setup.timeout.read,
)
else:
timeout = aiohttp.ClientTimeout(setup.timeout)

async with self.client.request(
setup.method,
Expand Down Expand Up @@ -226,7 +255,13 @@ async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore
if isinstance(setup.timeout, Timeout):
timeout = aiohttp.ClientWSTimeout(
ws_receive=setup.timeout.read, # type: ignore
ws_close=setup.timeout.total, # type: ignore
)
else:
timeout = aiohttp.ClientWSTimeout(ws_close=setup.timeout or 10.0) # type: ignore

async with aiohttp.ClientSession(version=version, trust_env=True) as session:
async with session.ws_connect(
Expand All @@ -245,7 +280,7 @@ def get_session(
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
) -> Session:
return Session(
Expand Down
46 changes: 40 additions & 6 deletions nonebot/drivers/httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@
combine_driver,
)
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.internal.driver import Cookies, CookieTypes, HeaderTypes, QueryTypes
from nonebot.internal.driver import (
Cookies,
CookieTypes,
HeaderTypes,
QueryTypes,
Timeout,
TimeoutTypes,
)

try:
import httpx
Expand All @@ -52,7 +59,7 @@ def __init__(
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
):
self._client: Optional[httpx.AsyncClient] = None
Expand All @@ -65,7 +72,16 @@ def __init__(
)
self._cookies = Cookies(cookies)
self._version = HTTPVersion(version)
self._timeout = timeout

if isinstance(timeout, Timeout):
self._timeout = httpx.Timeout(
timeout=timeout.total,
connect=timeout.connect,
read=timeout.read,
)
else:
self._timeout = httpx.Timeout(timeout)

self._proxy = proxy

@property
Expand All @@ -76,6 +92,15 @@ def client(self) -> httpx.AsyncClient:

@override
async def request(self, setup: Request) -> Response:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
else:
timeout = httpx.Timeout(setup.timeout)

response = await self.client.request(
setup.method,
str(setup.url),
Expand All @@ -87,7 +112,7 @@ async def request(self, setup: Request) -> Response:
params=setup.url.raw_query_string,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=setup.timeout,
timeout=timeout,
)
return Response(
response.status_code,
Expand All @@ -103,6 +128,15 @@ async def stream_request(
*,
chunk_size: int = 1024,
) -> AsyncGenerator[Response, None]:
if isinstance(setup.timeout, Timeout):
timeout = httpx.Timeout(
timeout=setup.timeout.total,
connect=setup.timeout.connect,
read=setup.timeout.read,
)
else:
timeout = httpx.Timeout(setup.timeout)

async with self.client.stream(
setup.method,
str(setup.url),
Expand All @@ -114,7 +148,7 @@ async def stream_request(
params=setup.url.raw_query_string,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=setup.timeout,
timeout=timeout,
) as response:
response_headers = response.headers.multi_items()
async for chunk in response.aiter_bytes(chunk_size=chunk_size):
Expand Down Expand Up @@ -183,7 +217,7 @@ def get_session(
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
) -> Session:
return Session(
Expand Down
10 changes: 8 additions & 2 deletions nonebot/drivers/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
from typing_extensions import ParamSpec, override

from nonebot.drivers import Request, WebSocketClientMixin, combine_driver
from nonebot.drivers import Request, Timeout, WebSocketClientMixin, combine_driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.exception import WebSocketClosed
Expand Down Expand Up @@ -73,10 +73,16 @@ def type(self) -> str:
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
if setup.proxy is not None:
logger.warning("proxy is not supported by websockets driver")

if isinstance(setup.timeout, Timeout):
timeout = setup.timeout.total or setup.timeout.connect or setup.timeout.read
else:
timeout = setup.timeout

connection = Connect(
str(setup.url),
extra_headers={**setup.headers, **setup.cookies.as_header(setup)},
open_timeout=setup.timeout,
open_timeout=timeout,
)
async with connection as ws:
yield WebSocket(request=setup, websocket=ws)
Expand Down
2 changes: 2 additions & 0 deletions nonebot/internal/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@
from .model import Request as Request
from .model import Response as Response
from .model import SimpleQuery as SimpleQuery
from .model import Timeout as Timeout
from .model import TimeoutTypes as TimeoutTypes
from .model import WebSocket as WebSocket
from .model import WebSocketServerSetup as WebSocketServerSetup
5 changes: 3 additions & 2 deletions nonebot/internal/driver/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
QueryTypes,
Request,
Response,
TimeoutTypes,
WebSocket,
WebSocketServerSetup,
)
Expand Down Expand Up @@ -245,7 +246,7 @@ def __init__(
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
):
raise NotImplementedError
Expand Down Expand Up @@ -315,7 +316,7 @@ def get_session(
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
) -> HTTPClientSession:
"""获取一个 HTTP 会话"""
Expand Down
14 changes: 12 additions & 2 deletions nonebot/internal/driver/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
FileType,
]
FilesTypes: TypeAlias = Union[dict[str, FileTypes], list[tuple[str, FileTypes]], None]
TimeoutTypes: TypeAlias = Union[float, "Timeout", None]


class HTTPVersion(Enum):
Expand All @@ -50,6 +51,15 @@ class HTTPVersion(Enum):
H2 = "2"


@dataclass
class Timeout:
"""Request 超时配置。"""

total: Optional[float] = None
connect: Optional[float] = None
read: Optional[float] = None


class Request:
def __init__(
self,
Expand All @@ -64,7 +74,7 @@ def __init__(
json: Any = None,
files: FilesTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
timeout: TimeoutTypes = None,
proxy: Optional[str] = None,
):
# method
Expand All @@ -76,7 +86,7 @@ def __init__(
# http version
self.version: HTTPVersion = HTTPVersion(version)
# timeout
self.timeout: Optional[float] = timeout
self.timeout: TimeoutTypes = timeout
# proxy
self.proxy: Optional[str] = proxy

Expand Down
6 changes: 6 additions & 0 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
HTTPServerSetup,
Request,
Response,
Timeout,
WebSocket,
WebSocketClientMixin,
WebSocketServerSetup,
Expand Down Expand Up @@ -235,6 +236,7 @@ async def test_http_client(driver: Driver, server_url: URL):
headers={"X-Test": "test"},
cookies={"session": "test"},
content="test",
timeout=Timeout(total=4, connect=2, read=2),
)
response = await driver.request(request)
assert server_url.host is not None
Expand All @@ -250,6 +252,7 @@ async def test_http_client(driver: Driver, server_url: URL):
headers={"X-Test": "test"},
cookies={"session": "test"},
content="test",
timeout=Timeout(total=4, connect=2, read=2),
)
assert request.url == request_raw_url.url, (
"request.url should be equal to request_raw_url.url"
Expand Down Expand Up @@ -312,6 +315,7 @@ async def test_http_client(driver: Driver, server_url: URL):
headers={"X-Test": "stream"},
cookies={"session": "stream"},
content="stream_test" * 1024,
timeout=Timeout(total=4, connect=2, read=2),
)
chunks = []
async for resp in driver.stream_request(request, chunk_size=4):
Expand Down Expand Up @@ -414,6 +418,7 @@ async def test_http_client_session(driver: Driver, server_url: URL):
headers={"X-Test": "test"},
cookies={"cookie": "test"},
content="test",
timeout=Timeout(total=4, connect=2, read=2),
)
response = await session.request(request)
assert response.status_code == 200
Expand Down Expand Up @@ -499,6 +504,7 @@ async def test_http_client_session(driver: Driver, server_url: URL):
headers={"X-Test": "stream"},
cookies={"cookie": "stream"},
content="stream_test" * 1024,
timeout=Timeout(total=4, connect=2, read=2),
)
chunks = []
async for resp in session.stream_request(request, chunk_size=4):
Expand Down
Loading