Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions aidial_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from aidial_client._auth import AsyncAuthValue, AuthType, SyncAuthValue
from aidial_client._auth import AsyncAuthValue, SyncAuthValue
from aidial_client._client import AsyncDial, Dial
from aidial_client._client_pool import AsyncDialClientPool, DialClientPool
from aidial_client._exception import (
Expand All @@ -15,7 +15,6 @@
"AsyncDial",
"DialClientPool",
"AsyncDialClientPool",
"AuthType",
"SyncAuthValue",
"AsyncAuthValue",
# Exceptions
Expand Down
122 changes: 55 additions & 67 deletions aidial_client/_auth.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,5 @@
from enum import Enum
from inspect import isawaitable
from typing import (
Awaitable,
Callable,
Dict,
Optional,
Tuple,
TypeVar,
Union,
overload,
)

from typing_extensions import assert_never


class AuthType(Enum):
API_KEY = "API_KEY"
BEARER = "BEARER"

from typing import Awaitable, Callable, Dict, Optional, TypeVar, Union

SyncAuthValue = Union[str, Callable[[], str]]
AsyncAuthValue = Union[SyncAuthValue, Callable[[], Awaitable[str]]]
Expand All @@ -28,72 +10,78 @@ class AuthType(Enum):
)


@overload
def get_auth_value(auth_value: SyncAuthValue) -> str: ...

def get_auth_value(auth_value: SyncAuthValue) -> str:
if isinstance(auth_value, str):
return auth_value
if callable(auth_value):
return auth_value()
from typing import TYPE_CHECKING, assert_never

@overload
def get_auth_value(
auth_value: AsyncAuthValue,
) -> Union[str, Awaitable[str]]: ...
if TYPE_CHECKING:
assert_never(auth_value)
raise TypeError(
Comment thread
adubovik marked this conversation as resolved.
f"auth_value must be a string or a callable returning a string, got {type(auth_value).__name__}"
)


def get_auth_value(
auth_value: Union[SyncAuthValue, AsyncAuthValue]
) -> Union[str, Awaitable[str]]:
async def aget_auth_value(auth_value: AsyncAuthValue) -> str:
if isinstance(auth_value, str):
return auth_value
elif callable(auth_value):
return auth_value()
else:
if callable(auth_value):
result = auth_value()
return await result if isawaitable(result) else result
from typing import TYPE_CHECKING, assert_never

if TYPE_CHECKING:
assert_never(auth_value)
raise TypeError(
Comment thread
adubovik marked this conversation as resolved.
f"auth_value must be a string or a callable, got {type(auth_value).__name__}"
)


async def aget_auth_value(auth_value: AsyncAuthValue) -> str:
processed_auth_value = get_auth_value(auth_value)
if isawaitable(processed_auth_value):
return await processed_auth_value
return processed_auth_value
def get_combined_auth_headers(
*,
api_key: Optional[SyncAuthValue] = None,
bearer_token: Optional[SyncAuthValue] = None,
) -> Dict[str, str]:
headers: Dict[str, str] = {}

if api_key is not None:
headers["api-key"] = get_auth_value(api_key)

if bearer_token is not None:
bearer_str = get_auth_value(bearer_token)
headers["Authorization"] = f"Bearer {bearer_str}"

def _get_auth_headers(auth_type: AuthType, auth_value: str) -> Dict[str, str]:
if auth_type == AuthType.API_KEY:
return {"api-key": auth_value}
elif auth_type == AuthType.BEARER:
return {"Authorization": f"Bearer {auth_value}"}
else:
assert_never(auth_type)
return headers


def get_auth_headers(
async def aget_combined_auth_headers(
*,
auth_value: SyncAuthValue,
auth_type: AuthType,
api_key: Optional[AsyncAuthValue] = None,
bearer_token: Optional[AsyncAuthValue] = None,
) -> Dict[str, str]:
processed_auth_value = get_auth_value(auth_value)
return _get_auth_headers(auth_type, processed_auth_value)
"""Get combined authentication headers from both api_key and bearer_token (async)."""
headers: Dict[str, str] = {}

if api_key is not None:
processed_api_key = await aget_auth_value(api_key)
headers["api-key"] = processed_api_key

async def aget_auth_headers(
auth_value: AsyncAuthValue,
auth_type: AuthType,
) -> Dict[str, str]:
processed_auth_value = await aget_auth_value(auth_value)
return _get_auth_headers(auth_type, processed_auth_value)
if bearer_token is not None:
processed_bearer_token = await aget_auth_value(bearer_token)
headers["Authorization"] = f"Bearer {processed_bearer_token}"

return headers


def process_auth(
def validate_auth(
*,
api_key: Optional[AuthValueT] = None,
bearer_token: Optional[AuthValueT] = None,
) -> Tuple[AuthType, AuthValueT]:
if api_key and bearer_token:
api_key: Optional[AsyncAuthValue] = None,
bearer_token: Optional[AsyncAuthValue] = None,
) -> None:
"""Validate that at least one authentication method is provided."""
if not api_key and not bearer_token:
raise ValueError(
"Either api_key or bearer_token must be provided, but not both"
"At least one of api_key or bearer_token must be provided"
)
elif api_key:
return AuthType.API_KEY, api_key
elif bearer_token:
return AuthType.BEARER, bearer_token
else:
raise ValueError("Either api_key or bearer_token must be provided")
25 changes: 13 additions & 12 deletions aidial_client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import aidial_client.resources as resources
from aidial_client._auth import (
AsyncAuthValue,
AuthType,
AuthValueT,
SyncAuthValue,
process_auth,
validate_auth,
)
from aidial_client._constants import (
API_PREFIX,
Expand All @@ -31,8 +30,8 @@


class BaseDialClient(Generic[_HttpClientT, AuthValueT], ABC):
_auth_type: AuthType
_auth_value: AuthValueT
_api_key: Optional[AuthValueT]
_bearer_token: Optional[AuthValueT]
_base_url: str
_http_client: _HttpClientT
_auth_headers: Dict[str, str]
Expand All @@ -50,9 +49,9 @@ def __init__(
api_version: Optional[str] = None,
http_client: Optional[_HttpClientT] = None,
):
self._auth_type, self._auth_value = process_auth(
api_key=api_key, bearer_token=bearer_token
)
validate_auth(api_key=api_key, bearer_token=bearer_token)
self._api_key = api_key
self._bearer_token = bearer_token
self._max_retries = max_retries
self._timeout = timeout
self._base_url = enforce_trailing_slash(base_url)
Expand Down Expand Up @@ -85,6 +84,7 @@ def api_version(self) -> Optional[str]:


class Dial(BaseDialClient[SyncHTTPClient, SyncAuthValue]):

def _init_resources(self) -> None:
openai_client = openai.AzureOpenAI(
api_key="-",
Expand Down Expand Up @@ -113,8 +113,8 @@ def _init_resources(self) -> None:
def _create_http_client(self) -> SyncHTTPClient:
return SyncHTTPClient(
self._base_url,
self._auth_value,
self._auth_type,
self._api_key,
self._bearer_token,
self._max_retries,
self._timeout,
)
Expand Down Expand Up @@ -156,10 +156,11 @@ def auth_headers(self) -> Dict[str, str]:


class AsyncDial(BaseDialClient[AsyncHTTPClient, AsyncAuthValue]):

def _init_resources(self) -> None:
openai_client = openai.AsyncAzureOpenAI(
# set empty string, we will override
# it with our client values during request
# it with our client values during a request
api_key="",
api_version="",
base_url=urljoin(self._base_url, OPENAI_PREFIX),
Expand Down Expand Up @@ -192,8 +193,8 @@ def _init_resources(self) -> None:
def _create_http_client(self) -> AsyncHTTPClient:
return AsyncHTTPClient(
self._base_url,
self._auth_value,
self._auth_type,
self._api_key,
self._bearer_token,
self._max_retries,
self._timeout,
)
Expand Down
16 changes: 5 additions & 11 deletions aidial_client/_client_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import httpx

from aidial_client._auth import AsyncAuthValue, SyncAuthValue, process_auth
from aidial_client._auth import AsyncAuthValue, SyncAuthValue
from aidial_client._client import AsyncDial, Dial
from aidial_client._constants import (
DEFAULT_CONNECTION_LIMITS,
Expand Down Expand Up @@ -32,17 +32,14 @@ def create_client(
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: Union[httpx.Timeout, float] = DEFAULT_TIMEOUT,
) -> Dial:
auth_type, auth_value = process_auth(
api_key=api_key, bearer_token=bearer_token
)
return Dial(
base_url=base_url,
api_key=api_key,
bearer_token=bearer_token,
http_client=SyncHTTPClient(
base_url=base_url,
auth_value=auth_value,
auth_type=auth_type,
api_key=api_key,
bearer_token=bearer_token,
max_retries=max_retries,
timeout=timeout,
internal_http_client=self._internal_http_client,
Expand Down Expand Up @@ -70,17 +67,14 @@ def create_client(
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: Union[httpx.Timeout, float] = DEFAULT_TIMEOUT,
) -> AsyncDial:
auth_type, auth_value = process_auth(
api_key=api_key, bearer_token=bearer_token
)
return AsyncDial(
base_url=base_url,
api_key=api_key,
bearer_token=bearer_token,
http_client=AsyncHTTPClient(
base_url=base_url,
auth_value=auth_value,
auth_type=auth_type,
api_key=api_key,
bearer_token=bearer_token,
max_retries=max_retries,
timeout=timeout,
internal_http_client=self._internal_http_client,
Expand Down
8 changes: 4 additions & 4 deletions aidial_client/_http_client/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import httpx

from aidial_client._auth import AsyncAuthValue, aget_auth_headers
from aidial_client._auth import AsyncAuthValue, aget_combined_auth_headers
from aidial_client._exception import DialException
from aidial_client._http_client._base import BaseHTTPClient
from aidial_client._internal_types._generic import ResponseT
Expand All @@ -20,8 +20,8 @@ def _create_internal_client(self) -> httpx.AsyncClient:
)

async def auth_headers(self) -> Dict[str, str]:
return await aget_auth_headers(
auth_value=self._auth_value, auth_type=self._auth_type
return await aget_combined_auth_headers(
api_key=self._api_key, bearer_token=self._bearer_token
)

async def _retry_request(
Expand Down Expand Up @@ -99,7 +99,7 @@ async def request(
cast_to=cast_to,
remaining_retries=retries,
)
# Try to get custom error from response status_code/code/message
# Try to get a custom error from response status_code/code/message
custom_error = on_http_error(err) if on_http_error else None
# or fallback to default processing
raised_error = custom_error or self._make_dial_error_from_response(
Expand Down
14 changes: 7 additions & 7 deletions aidial_client/_http_client/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import httpx

from aidial_client._auth import AuthType, AuthValueT
from aidial_client._auth import AuthValueT
from aidial_client._constants import INITIAL_RETRY_DELAY, MAX_RETRY_DELAY
from aidial_client._exception import DialException
from aidial_client._internal_types._http_request import FinalRequestOptions
Expand All @@ -19,21 +19,21 @@

class BaseHTTPClient(ABC, Generic[_HttpInternalClientT, AuthValueT]):
_internal_http_client: _HttpInternalClientT
_auth_value: AuthValueT
_auth_type: AuthType
_api_key: Optional[AuthValueT]
_bearer_token: Optional[AuthValueT]

def __init__(
self,
base_url: str,
auth_value: AuthValueT,
auth_type: AuthType,
api_key: Optional[AuthValueT],
bearer_token: Optional[AuthValueT],
max_retries: int,
timeout: Union[float, httpx.Timeout, None],
internal_http_client: Optional[_HttpInternalClientT] = None,
):
self.base_url = httpx.URL(enforce_trailing_slash(base_url))
self._auth_value = auth_value
self._auth_type = auth_type
self._api_key = api_key
self._bearer_token = bearer_token
self._max_retries = max_retries
self._timeout = timeout
self._internal_http_client = (
Expand Down
8 changes: 4 additions & 4 deletions aidial_client/_http_client/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import httpx

from aidial_client._auth import SyncAuthValue, get_auth_headers
from aidial_client._auth import SyncAuthValue, get_combined_auth_headers
from aidial_client._exception import DialException
from aidial_client._http_client._base import BaseHTTPClient
from aidial_client._internal_types._generic import ResponseT
Expand Down Expand Up @@ -37,8 +37,8 @@ def _retry_request(
)

def auth_headers(self) -> Dict[str, str]:
return get_auth_headers(
auth_value=self._auth_value, auth_type=self._auth_type
return get_combined_auth_headers(
api_key=self._api_key, bearer_token=self._bearer_token
)

def request(
Expand Down Expand Up @@ -99,7 +99,7 @@ def request(
cast_to=cast_to,
remaining_retries=retries,
)
# Try to get custom error from response status_code/code/message
# Try to get a custom error from response status_code/code/message
custom_error = on_http_error(err) if on_http_error else None
# or fallback to default processing
raised_error = custom_error or self._make_dial_error_from_response(
Expand Down
Loading