Skip to content

Commit 627db6c

Browse files
committed
hack h11 Sentinel return type with an enum protocol and cast
1 parent f42eab9 commit 627db6c

File tree

2 files changed

+60
-16
lines changed

2 files changed

+60
-16
lines changed

httpcore/_async/http11.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from types import TracebackType
44
from typing import (
5+
TYPE_CHECKING,
56
AsyncIterable,
67
AsyncIterator,
78
List,
@@ -26,6 +27,27 @@
2627
from ..backends.base import AsyncNetworkStream
2728
from .interfaces import AsyncConnectionInterface
2829

30+
if TYPE_CHECKING:
31+
from typing_extensions import Literal, Protocol, TypeAlias
32+
33+
class _Sentinel(enum.Enum):
34+
PAUSED = enum.auto()
35+
NEED_DATA = enum.auto()
36+
37+
_PausedType: TypeAlias = Literal[_Sentinel.PAUSED]
38+
_NeedDataType: TypeAlias = Literal[_Sentinel.NEED_DATA]
39+
_PAUSED: _PausedType = _Sentinel.PAUSED
40+
_NEED_DATA: _NeedDataType = _Sentinel.NEED_DATA
41+
42+
class _NextEventType(Protocol):
43+
async def __call__(self) -> Union[h11.Event, _PausedType, _NeedDataType]:
44+
...
45+
46+
else:
47+
_PausedType = _PAUSED = h11.PAUSED
48+
_NeedDataType = _NEED_DATA = h11.NEED_DATA
49+
_Sentinel = _NextEventType = object
50+
2951

3052
class HTTPConnectionState(enum.IntEnum):
3153
NEW = 0
@@ -165,21 +187,21 @@ async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]
165187
event = await self._receive_event(timeout=timeout)
166188
if isinstance(event, h11.Data):
167189
yield bytes(event.data)
168-
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
190+
elif event is _PAUSED:
191+
break
192+
elif isinstance(event, h11.EndOfMessage):
169193
break
170194

171195
async def _receive_event(
172196
self, timeout: Optional[float] = None
173-
) -> Union[h11.Event, h11.PAUSED]:
197+
) -> Union[h11.Event, _PausedType]:
198+
# The h11 type signature uses a private return type
199+
next_event = cast(_NextEventType, self._h11_state.next_event)
174200
while True:
175201
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
176-
# The h11 type signature uses a private return type
177-
event = cast(
178-
Union[h11.Event, h11.NEED_DATA, h11.PAUSED],
179-
self._h11_state.next_event(),
180-
)
202+
event = next_event()
181203

182-
if isinstance(event, h11.NEED_DATA):
204+
if event is _NEED_DATA:
183205
data = await self._network_stream.read(
184206
self.READ_NUM_BYTES, timeout=timeout
185207
)

httpcore/_sync/http11.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from types import TracebackType
44
from typing import (
5+
TYPE_CHECKING,
56
Iterable,
67
Iterator,
78
List,
@@ -26,6 +27,27 @@
2627
from ..backends.base import NetworkStream
2728
from .interfaces import ConnectionInterface
2829

30+
if TYPE_CHECKING:
31+
from typing_extensions import Literal, Protocol, TypeAlias
32+
33+
class _Sentinel(enum.Enum):
34+
PAUSED = enum.auto()
35+
NEED_DATA = enum.auto()
36+
37+
_PausedType: TypeAlias = Literal[_Sentinel.PAUSED]
38+
_NeedDataType: TypeAlias = Literal[_Sentinel.NEED_DATA]
39+
_PAUSED: _PausedType = _Sentinel.PAUSED
40+
_NEED_DATA: _NeedDataType = _Sentinel.NEED_DATA
41+
42+
class _NextEventType(Protocol):
43+
def __call__(self) -> Union[h11.Event, _PausedType, _NeedDataType]:
44+
...
45+
46+
else:
47+
_PausedType = _PAUSED = h11.PAUSED
48+
_NeedDataType = _NEED_DATA = h11.NEED_DATA
49+
_Sentinel = _NextEventType = object
50+
2951

3052
class HTTPConnectionState(enum.IntEnum):
3153
NEW = 0
@@ -165,21 +187,21 @@ def _receive_response_body(self, request: Request) -> Iterator[bytes]:
165187
event = self._receive_event(timeout=timeout)
166188
if isinstance(event, h11.Data):
167189
yield bytes(event.data)
168-
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
190+
elif event is _PAUSED:
191+
break
192+
elif isinstance(event, h11.EndOfMessage):
169193
break
170194

171195
def _receive_event(
172196
self, timeout: Optional[float] = None
173-
) -> Union[h11.Event, h11.PAUSED]:
197+
) -> Union[h11.Event, _PausedType]:
198+
# The h11 type signature uses a private return type
199+
next_event = cast(_NextEventType, self._h11_state.next_event)
174200
while True:
175201
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
176-
# The h11 type signature uses a private return type
177-
event = cast(
178-
Union[h11.Event, h11.NEED_DATA, h11.PAUSED],
179-
self._h11_state.next_event(),
180-
)
202+
event = next_event()
181203

182-
if isinstance(event, h11.NEED_DATA):
204+
if event is _NEED_DATA:
183205
data = self._network_stream.read(
184206
self.READ_NUM_BYTES, timeout=timeout
185207
)

0 commit comments

Comments
 (0)