diff --git a/pyproject.toml b/pyproject.toml index 4f917c9..7326b5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "httpx", "pydantic", "requests", + "types-requests", "tenacity" ] @@ -51,6 +52,7 @@ httpx = [ ] requests = [ "requests", + "types-requests", "tenacity", ] dev = [ @@ -59,6 +61,7 @@ dev = [ "respx", "ruff", "nox", + "types-requests", ] docs = [ "mkdocs", diff --git a/retryhttp/_retry.py b/retryhttp/_retry.py index 0f44ba0..bfbc99c 100644 --- a/retryhttp/_retry.py +++ b/retryhttp/_retry.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union, overload from tenacity import ( RetryCallState, @@ -24,6 +24,32 @@ from ._wait import wait_context_aware, wait_retry_after +@overload +def retry(func: F) -> F: ... + + +@overload +def retry( + func: None = None, + *, + max_attempt_number: int = 3, + retry_server_errors: bool = True, + retry_network_errors: bool = True, + retry_timeouts: bool = True, + retry_rate_limited: bool = True, + wait_server_errors: wait_base = wait_random_exponential(), + wait_network_errors: wait_base = wait_exponential(), + wait_timeouts: wait_base = wait_random_exponential(), + wait_rate_limited: wait_base = wait_retry_after(), + server_error_codes: Union[Sequence[int], int] = (500, 502, 503, 504), + network_errors: Union[ + Type[BaseException], Tuple[Type[BaseException], ...], None + ] = None, + timeouts: Union[Type[BaseException], Tuple[Type[BaseException], ...], None] = None, + **kwargs: Any, +) -> Callable[[F], F]: ... + + def retry( func: Optional[F] = None, *, @@ -42,7 +68,7 @@ def retry( ] = None, timeouts: Union[Type[BaseException], Tuple[Type[BaseException], ...], None] = None, **kwargs: Any, -) -> F: +) -> Union[F, Callable[[F], F]]: """Retry potentially transient HTTP errors with sensible default behavior. By default, retries the following errors, for a total of 3 attempts, with @@ -105,7 +131,7 @@ def retry( if timeouts is None: timeouts = get_default_timeouts() - retry_strategies = [] + retry_strategies: list[retry_base] = [] if retry_server_errors: retry_strategies.append( retry_if_server_error(server_error_codes=server_error_codes) diff --git a/retryhttp/_utils.py b/retryhttp/_utils.py index b33f11c..4cf35c1 100644 --- a/retryhttp/_utils.py +++ b/retryhttp/_utils.py @@ -22,9 +22,7 @@ pass -def get_default_network_errors() -> Tuple[ - Union[Type[httpx.NetworkError], Type[requests.ConnectionError]], ... -]: +def get_default_network_errors() -> Tuple[Type[BaseException], ...]: """Get all network errors to use by default. Args: @@ -37,7 +35,7 @@ def get_default_network_errors() -> Tuple[ N/A """ - exceptions = [] + exceptions: list[type[BaseException]] = [] if _HTTPX_INSTALLED: exceptions.extend( [ @@ -56,16 +54,14 @@ def get_default_network_errors() -> Tuple[ return tuple(exceptions) -def get_default_timeouts() -> Tuple[ - Type[Union[httpx.TimeoutException, requests.Timeout]], ... -]: +def get_default_timeouts() -> Tuple[Type[BaseException], ...]: """Get all timeout exceptions to use by default. Returns: tuple: Timeout exceptions. """ - exceptions = [] + exceptions: list[Type[BaseException]] = [] if _HTTPX_INSTALLED: exceptions.append(httpx.TimeoutException) if _REQUESTS_INSTALLED: @@ -73,16 +69,14 @@ def get_default_timeouts() -> Tuple[ return tuple(exceptions) -def get_default_http_status_exceptions() -> Tuple[ - Union[Type[httpx.HTTPStatusError], Type[requests.HTTPError]], ... -]: +def get_default_http_status_exceptions() -> Tuple[Type[BaseException], ...]: """Get default HTTP status 4xx or 5xx exceptions. Returns: tuple: HTTP status exceptions. """ - exceptions = [] + exceptions: list[Type[BaseException]] = [] if _HTTPX_INSTALLED: exceptions.append(httpx.HTTPStatusError) if _REQUESTS_INSTALLED: @@ -90,7 +84,7 @@ def get_default_http_status_exceptions() -> Tuple[ return tuple(exceptions) -def is_rate_limited(exc: Union[BaseException, None]) -> bool: +def is_rate_limited(exc: Optional[BaseException]) -> bool: """Whether a given exception indicates the user has been rate limited. Args: @@ -100,7 +94,10 @@ def is_rate_limited(exc: Union[BaseException, None]) -> bool: bool: Whether exc indicates rate limiting. """ - if isinstance(exc, get_default_http_status_exceptions()): + if exc is None: + return False + exceptions = get_default_http_status_exceptions() + if isinstance(exc, exceptions) and hasattr(exc, "response"): return exc.response.status_code == 429 return False @@ -120,9 +117,12 @@ def is_server_error( bool: whether exc indicates an error included in status_codes. """ + if exc is None: + return False if isinstance(status_codes, int): status_codes = [status_codes] - if isinstance(exc, get_default_http_status_exceptions()): + exceptions = get_default_http_status_exceptions() + if isinstance(exc, exceptions) and hasattr(exc, "response"): return exc.response.status_code in status_codes return False diff --git a/retryhttp/_wait.py b/retryhttp/_wait.py index d6ddec7..a4c1d53 100644 --- a/retryhttp/_wait.py +++ b/retryhttp/_wait.py @@ -68,7 +68,11 @@ def _get_wait_value(self, retry_state: RetryCallState) -> float: """ if retry_state.outcome: exc = retry_state.outcome.exception() - if isinstance(exc, get_default_http_status_exceptions()): + if exc is None: + return 0 + if isinstance(exc, get_default_http_status_exceptions()) and hasattr( + exc, "response" + ): value = exc.response.headers.get(self.header) if value is None: raise ValueError(f"Header not present: {self.header}")