Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 0 additions & 38 deletions httpx/auth.py

This file was deleted.

212 changes: 45 additions & 167 deletions httpx/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import hstspreload

from .auth import HTTPBasicAuth
from .concurrency.asyncio import AsyncioBackend
from .concurrency.base import ConcurrencyBackend
from .config import (
Expand All @@ -19,16 +18,13 @@
)
from .dispatch.asgi import ASGIDispatch
from .dispatch.base import AsyncDispatcher, Dispatcher
from .dispatch.basic_auth import BasicAuthDispatcher
from .dispatch.connection_pool import ConnectionPool
from .dispatch.custom_auth import CustomAuthDispatcher
from .dispatch.redirect import RedirectDispatcher
from .dispatch.threaded import ThreadedDispatcher
from .dispatch.wsgi import WSGIDispatch
from .exceptions import (
HTTPError,
InvalidURL,
RedirectBodyUnavailable,
RedirectLoop,
TooManyRedirects,
)
from .exceptions import HTTPError, InvalidURL
from .models import (
URL,
AsyncRequest,
Expand All @@ -47,7 +43,6 @@
ResponseContent,
URLTypes,
)
from .status_codes import codes
from .utils import get_netrc_login


Expand All @@ -67,7 +62,7 @@ def __init__(
dispatch: typing.Union[AsyncDispatcher, Dispatcher] = None,
app: typing.Callable = None,
backend: ConcurrencyBackend = None,
trust_env: bool = None,
trust_env: bool = True,
):
if backend is None:
backend = AsyncioBackend()
Expand Down Expand Up @@ -107,7 +102,7 @@ def __init__(
self.max_redirects = max_redirects
self.dispatch = async_dispatch
self.concurrency_backend = backend
self.trust_env = True if trust_env is None else trust_env
self.trust_env = trust_env

def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:
pass # pragma: no cover
Expand Down Expand Up @@ -146,45 +141,29 @@ async def send(
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
trust_env: bool = None,
trust_env: typing.Optional[bool] = None,
) -> AsyncResponse:
if auth is None:
auth = self.auth

url = request.url

if url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')

if auth is None:
if url.username or url.password:
auth = HTTPBasicAuth(username=url.username, password=url.password)
elif self.trust_env if trust_env is None else trust_env:
netrc_login = get_netrc_login(url.authority)
if netrc_login:
netrc_username, _, netrc_password = netrc_login
auth = HTTPBasicAuth(
username=netrc_username, password=netrc_password
)

if auth is not None:
if isinstance(auth, tuple):
auth = HTTPBasicAuth(username=auth[0], password=auth[1])
request = auth(request)
dispatcher: AsyncDispatcher = self._resolve_dispatcher(
request,
auth or self.auth,
self.trust_env if trust_env is None else trust_env,
allow_redirects,
)

try:
response = await self.send_handling_redirects(
request,
verify=verify,
cert=cert,
timeout=timeout,
allow_redirects=allow_redirects,
response = await dispatcher.send(
request, verify=verify, cert=cert, timeout=timeout
)
except HTTPError as exc:
# Add the original request to any HTTPError
exc.request = request
raise

self.cookies.extract_cookies(response)
if not stream:
try:
await response.read()
Expand All @@ -193,143 +172,42 @@ async def send(

return response

async def send_handling_redirects(
def _resolve_dispatcher(
self,
request: AsyncRequest,
*,
cert: CertTypes = None,
verify: VerifyTypes = None,
timeout: TimeoutTypes = None,
auth: AuthTypes = None,
trust_env: bool = False,
allow_redirects: bool = True,
history: typing.List[AsyncResponse] = None,
) -> AsyncResponse:
if history is None:
history = []

while True:
# We perform these checks here, so that calls to `response.next()`
# will raise redirect errors if appropriate.
if len(history) > self.max_redirects:
raise TooManyRedirects(response=history[-1])
if request.url in (response.url for response in history):
raise RedirectLoop(response=history[-1])

response = await self.dispatch.send(
request, verify=verify, cert=cert, timeout=timeout
)

should_close_response = True
try:
assert isinstance(response, AsyncResponse)
response.history = list(history)
self.cookies.extract_cookies(response)
history.append(response)

if allow_redirects and response.is_redirect:
request = self.build_redirect_request(request, response)
else:
should_close_response = False
break
finally:
if should_close_response:
await response.close()

if response.is_redirect:

async def send_next() -> AsyncResponse:
nonlocal request, response, verify, cert
nonlocal allow_redirects, timeout, history
request = self.build_redirect_request(request, response)
response = await self.send_handling_redirects(
request,
allow_redirects=allow_redirects,
verify=verify,
cert=cert,
timeout=timeout,
history=history,
)
return response

response.next = send_next # type: ignore

return response

def build_redirect_request(
self, request: AsyncRequest, response: AsyncResponse
) -> AsyncRequest:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
headers = self.redirect_headers(request, url)
content = self.redirect_content(request, method, response)
cookies = self.merge_cookies(request.cookies)
return AsyncRequest(
method=method, url=url, headers=headers, data=content, cookies=cookies
) -> AsyncDispatcher:
dispatcher: AsyncDispatcher = RedirectDispatcher(
next_dispatcher=self.dispatch,
base_cookies=self.cookies,
allow_redirects=allow_redirects,
)

def redirect_method(self, request: AsyncRequest, response: AsyncResponse) -> str:
"""
When being redirected we may want to change the method of the request
based on certain specs or browser behavior.
"""
method = request.method

# https://tools.ietf.org/html/rfc7231#section-6.4.4
if response.status_code == codes.SEE_OTHER and method != "HEAD":
method = "GET"

# Do what the browsers do, despite standards...
# Turn 302s into GETs.
if response.status_code == codes.FOUND and method != "HEAD":
method = "GET"

# If a POST is responded to with a 301, turn it into a GET.
# This bizarre behaviour is explained in 'requests' issue 1704.
if response.status_code == codes.MOVED_PERMANENTLY and method == "POST":
method = "GET"

return method

def redirect_url(self, request: AsyncRequest, response: AsyncResponse) -> URL:
"""
Return the URL for the redirect to follow.
"""
location = response.headers["Location"]

url = URL(location, allow_relative=True)

# Facilitate relative 'Location' headers, as allowed by RFC 7231.
# (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
if url.is_relative_url:
url = request.url.join(url)

# Attach previous fragment if needed (RFC 7231 7.1.2)
if request.url.fragment and not url.fragment:
url = url.copy_with(fragment=request.url.fragment)

return url
username: typing.Optional[typing.Union[str, bytes]] = None
password: typing.Optional[typing.Union[str, bytes]] = None
if auth is None:
if request.url.username or request.url.password:
username, password = request.url.username, request.url.password
elif trust_env:
netrc_login = get_netrc_login(request.url.authority)
if netrc_login:
username, _, password = netrc_login
else:
if isinstance(auth, tuple):
username, password = auth[0], auth[1]
elif callable(auth):
dispatcher = CustomAuthDispatcher(
next_dispatcher=dispatcher, auth_callable=auth
)

def redirect_headers(self, request: AsyncRequest, url: URL) -> Headers:
"""
Strip Authorization headers when responses are redirected away from
the origin.
"""
headers = Headers(request.headers)
if url.origin != request.url.origin:
del headers["Authorization"]
del headers["host"]
return headers
if username is not None and password is not None:
dispatcher = BasicAuthDispatcher(
next_dispatcher=dispatcher, username=username, password=password
)

def redirect_content(
self, request: AsyncRequest, method: str, response: AsyncResponse
) -> bytes:
"""
Return the body that should be used for the redirect request.
"""
if method != request.method and method == "GET":
return b""
if request.is_streaming:
raise RedirectBodyUnavailable(response=response)
return request.content
return dispatcher


class AsyncClient(BaseClient):
Expand Down
2 changes: 1 addition & 1 deletion httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseEvent,
BasePoolSemaphore,
BaseQueue,
BaseStream,
ConcurrencyBackend,
Expand Down
4 changes: 2 additions & 2 deletions httpx/dispatch/asgi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import typing

from .base import AsyncDispatcher
from ..concurrency.base import ConcurrencyBackend
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import AsyncRequest, AsyncResponse
from .base import AsyncDispatcher


class ASGIDispatch(AsyncDispatcher):
Expand Down
43 changes: 43 additions & 0 deletions httpx/dispatch/basic_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import typing
from base64 import b64encode

from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import AsyncRequest, AsyncResponse
from .base import AsyncDispatcher


class BasicAuthDispatcher(AsyncDispatcher):
def __init__(
self,
next_dispatcher: AsyncDispatcher,
username: typing.Union[str, bytes],
password: typing.Union[str, bytes],
):
self.next_dispatcher = next_dispatcher
self.username = username
self.password = password

async def send(
self,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
) -> AsyncResponse:
request.headers["Authorization"] = self.build_auth_header()
return await self.next_dispatcher.send(
request, verify=verify, cert=cert, timeout=timeout
)

def build_auth_header(self) -> str:
username, password = self.username, self.password

if isinstance(username, str):
username = username.encode("latin1")

if isinstance(password, str):
password = password.encode("latin1")

userpass = b":".join((username, password))
token = b64encode(userpass).decode().strip()
return f"Basic {token}"
Loading