|
2 | 2 | import time |
3 | 3 | from types import TracebackType |
4 | 4 | from typing import ( |
| 5 | + TYPE_CHECKING, |
5 | 6 | AsyncIterable, |
6 | 7 | AsyncIterator, |
7 | 8 | List, |
|
26 | 27 | from ..backends.base import AsyncNetworkStream |
27 | 28 | from .interfaces import AsyncConnectionInterface |
28 | 29 |
|
| 30 | +if TYPE_CHECKING: |
| 31 | + from typing_extensions import Literal, Protocol, TypeAlias |
| 32 | + |
| 33 | + class _Sentinel(enum.Enum): |
| 34 | + PAUSED = enum.auto() |
| 35 | + NEED_DATA = enum.auto() |
| 36 | + |
| 37 | + _PausedType: TypeAlias = Literal[_Sentinel.PAUSED] |
| 38 | + _NeedDataType: TypeAlias = Literal[_Sentinel.NEED_DATA] |
| 39 | + _PAUSED: _PausedType = _Sentinel.PAUSED |
| 40 | + _NEED_DATA: _NeedDataType = _Sentinel.NEED_DATA |
| 41 | + |
| 42 | + class _NextEventType(Protocol): |
| 43 | + async def __call__(self) -> Union[h11.Event, _PausedType, _NeedDataType]: |
| 44 | + ... |
| 45 | + |
| 46 | +else: |
| 47 | + _PausedType = _PAUSED = h11.PAUSED |
| 48 | + _NeedDataType = _NEED_DATA = h11.NEED_DATA |
| 49 | + _Sentinel = _NextEventType = object |
| 50 | + |
29 | 51 |
|
30 | 52 | class HTTPConnectionState(enum.IntEnum): |
31 | 53 | NEW = 0 |
@@ -165,21 +187,21 @@ async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes] |
165 | 187 | event = await self._receive_event(timeout=timeout) |
166 | 188 | if isinstance(event, h11.Data): |
167 | 189 | yield bytes(event.data) |
168 | | - elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): |
| 190 | + elif event is _PAUSED: |
| 191 | + break |
| 192 | + elif isinstance(event, h11.EndOfMessage): |
169 | 193 | break |
170 | 194 |
|
171 | 195 | async def _receive_event( |
172 | 196 | self, timeout: Optional[float] = None |
173 | | - ) -> Union[h11.Event, h11.PAUSED]: |
| 197 | + ) -> Union[h11.Event, _PausedType]: |
| 198 | + # The h11 type signature uses a private return type |
| 199 | + next_event = cast(_NextEventType, self._h11_state.next_event) |
174 | 200 | while True: |
175 | 201 | with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): |
176 | | - # The h11 type signature uses a private return type |
177 | | - event = cast( |
178 | | - Union[h11.Event, h11.NEED_DATA, h11.PAUSED], |
179 | | - self._h11_state.next_event(), |
180 | | - ) |
| 202 | + event = next_event() |
181 | 203 |
|
182 | | - if isinstance(event, h11.NEED_DATA): |
| 204 | + if event is _NEED_DATA: |
183 | 205 | data = await self._network_stream.read( |
184 | 206 | self.READ_NUM_BYTES, timeout=timeout |
185 | 207 | ) |
|
0 commit comments