diff --git a/flymyai/__init__.py b/flymyai/__init__.py
index 0818479..2c67333 100644
--- a/flymyai/__init__.py
+++ b/flymyai/__init__.py
@@ -1,6 +1,5 @@
 import httpx
-from .core.authorizations import ClientInfoFactory
 from .core.client import FlyMyAI, AsyncFlyMyAI
 from .core.exceptions import FlyMyAIPredictException, FlyMyAIExceptionGroup
@@ -11,7 +10,6 @@
     "async_run",
     "FlyMyAI",
     "AsyncFlyMyAI",
-    "ClientInfoFactory",
     "FlyMyAIExceptionGroup",
     "FlyMyAIPredictException",
 ]
diff --git a/flymyai/core/_client.py b/flymyai/core/_client.py
index 3cb9ffb..cbcb3a2 100644
--- a/flymyai/core/_client.py
+++ b/flymyai/core/_client.py
@@ -15,12 +15,13 @@
 from flymyai.core._response_factory import ResponseFactory
 from flymyai.core._streaming import SSEDecoder
-from flymyai.core.authorizations import APIKeyClientInfo, ClientInfoFactory
+from flymyai.core.authorizations import APIKeyClientInfo
 from flymyai.core.exceptions import (
     FlyMyAIPredictException,
     FlyMyAIExceptionGroup,
     BaseFlyMyAIException,
     FlyMyAIOpenAPIException,
+    ImproperlyConfiguredClientException,
 )
 from flymyai.core.models import (
     PredictionResponse,
@@ -48,64 +49,98 @@ class BaseClient(Generic[_PossibleClients]):
     _client: _PossibleClients
     max_retries: int
-    auth: APIKeyClientInfo
+    client_info: APIKeyClientInfo

-    def __init__(self, auth: APIKeyClientInfo | dict, max_retries=DEFAULT_RETRY_COUNT):
-        if isinstance(auth, dict):
-            self.auth = ClientInfoFactory(auth).build_auth()
-        elif isinstance(auth, APIKeyClientInfo):
-            self.auth = auth
-        else:
-            raise TypeError("Invalid credentials. dict required!")
+    def __init__(
+        self, apikey: str, model: str | None = None, max_retries=DEFAULT_RETRY_COUNT
+    ):
+        self.client_info = APIKeyClientInfo(apikey)
+        if model:
+            self.client_info = self.client_info.copy_for_model(model)
         self._client = self._construct_client()
         self.max_retries = max_retries

+    def amend_client_info(self, model: str | None = None):
+        if model:
+            client_info = self.client_info.copy_for_model(model)
+        else:
+            client_info = self.client_info
+        if not client_info.project_name or not client_info.username:
+            raise ImproperlyConfiguredClientException(
+                "model should be provided as <username>/<project_name>"
+            )
+        return client_info
+
     @overload
-    async def predict(self, payload: dict, max_retries=None) -> PredictionResponse:
+    async def predict(
+        self, payload: dict, model: str | None = None, max_retries=None
+    ) -> PredictionResponse:
         ...

     @overload
-    def predict(self, payload: dict, max_retries=None) -> PredictionResponse:
+    def predict(
+        self, payload: dict, model: str | None = None, max_retries=None
+    ) -> PredictionResponse:
         ...

-    def predict(self, payload: dict, max_retries=None) -> PredictionResponse:
+    def predict(
+        self, payload: dict, model: str | None = None, max_retries=None
+    ) -> PredictionResponse:
         ...

     @overload
-    async def openapi_schema(self, max_retries=None) -> OpenAPISchemaResponse:
+    async def openapi_schema(
+        self, model: str | None = None, max_retries=None
+    ) -> OpenAPISchemaResponse:
         ...

     @overload
-    def openapi_schema(self, max_retries=None) -> OpenAPISchemaResponse:
+    def openapi_schema(
+        self, model: str | None = None, max_retries=None
+    ) -> OpenAPISchemaResponse:
         ...

-    def openapi_schema(self, max_retries=None) -> OpenAPISchemaResponse:
+    def openapi_schema(
+        self, model: str | None = None, max_retries=None
+    ) -> OpenAPISchemaResponse:
         ...

     @overload
-    async def stream(self, payload: dict) -> AsyncIterator[PredictionPartial]:
+    async def stream(
+        self,
+        payload: dict,
+        model: str | None = None,
+    ) -> AsyncIterator[PredictionPartial]:
         ...

     @overload
-    def stream(self, payload: dict) -> Iterator[PredictionPartial]:
+    def stream(
+        self,
+        payload: dict,
+        model: str | None = None,
+    ) -> Iterator[PredictionPartial]:
         ...

-    def stream(self, payload: dict):
+    def stream(
+        self,
+        payload: dict,
+        model: str | None = None,
+    ):
         ...

     def _stream_iterator(
-        self, payload: MultipartPayload, is_long_stream: bool
+        self, client_info, payload: MultipartPayload, is_long_stream: bool
     ) -> Iterator[httpx.Response] | AsyncIterator[httpx.Response]:
         return self._client.stream(
             method="post",
             url=(
-                self.auth.prediction_path
+                client_info.prediction_path
                 if not is_long_stream
-                else self.auth.prediction_stream_path
+                else client_info.prediction_stream_path
             ),
             **payload.serialize(),
             timeout=_predict_timeout,
-            headers=self.auth.authorization_headers,
+            headers=client_info.authorization_headers,
             follow_redirects=True,
         )
@@ -136,7 +171,7 @@ class BaseSyncClient(BaseClient[httpx.Client]):
     def _construct_client(self):
         return httpx.Client(
             http2=True,
-            headers=self.auth.authorization_headers,
+            headers=self.client_info.authorization_headers,
             base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"),
         )
@@ -162,37 +197,44 @@ def _sse_instant(cls, stream_iter_func: Callable[[], Iterator[httpx.Response]]):
         ).construct()
         return response

-    def _predict(self, payload: MultipartPayload):
+    def _predict(self, payload: MultipartPayload, client_info: APIKeyClientInfo):
         """
         Wrap predict method in sse
         """
+
         try:
-            return self._sse_instant(lambda: self._stream_iterator(payload, False))
+            return self._sse_instant(
+                lambda: self._stream_iterator(client_info, payload, False)
+            )
         except BaseFlyMyAIException as e:
             raise FlyMyAIPredictException.from_response(e.response)

-    def predict(self, payload: dict, max_retries=None):
+    def predict(self, payload: dict, model: str | None = None, max_retries=None):
         """
         Wrap predict method in sse.
         Retries until max_retries or self.max_retries is reached
+        :param model: flymyai/bert | None; if None, the model configured on self.client_info is used
         :param payload: anything for model
         :param max_retries: retries
         :return: PredictionResponse(exc_history, output_data, response):
                 exc_history - list of exception history during prediction
                 output_data - dict with prediction output
         """
+
         payload = MultipartPayload(payload)
         history, response = retryable_callback(
-            lambda: self._predict(payload),
+            lambda: self._predict(payload, self.amend_client_info(model)),
             max_retries or self.max_retries,
             FlyMyAIPredictException,
             FlyMyAIExceptionGroup,
         )
         return PredictionResponse.from_response(response, exc_history=history)

-    def _stream(self, payload: dict):
+    def _stream(self, client_info: APIKeyClientInfo, payload: dict):
         payload = MultipartPayload(payload)
-        response_iterator = self._stream_iterator(payload, is_long_stream=True)
+        response_iterator = self._stream_iterator(
+            client_info, payload, is_long_stream=True
+        )
         decoder = SSEDecoder()
         with response_iterator as sse_stream:
             for sse_partial in decoder.iter(sse_stream.iter_lines()):
@@ -206,8 +248,8 @@ def _stream(self, payload: dict):
                 raise FlyMyAIPredictException.from_response(e.response)
             yield response

-    def stream(self, payload: dict):
-        stream_iter = self._stream(payload)
+    def stream(self, payload: dict, model: str | None = None):
+        stream_iter = self._stream(self.amend_client_info(model), payload)
         last_response = None
         for response in stream_iter:
             response.stream = stream_iter
@@ -216,31 +258,32 @@ def stream(self, payload: dict):
         if last_response:
             last_response.is_stream_consumed = True

-    def _openapi_schema(self):
+    def _openapi_schema(self, client_info: APIKeyClientInfo):
         """
-        OpenAPI request for current project, wrapped in executor-method (using HTTP/1)
+        OpenAPI request for the current project, wrapped in executor-method (using HTTP/1)
         :return:
         """
         try:
             return self._wrap_request(
                 lambda: self._client.get(
-                    self.auth.openapi_schema_path,
-                    headers=self.auth.authorization_headers,
+                    client_info.openapi_schema_path,
+                    headers=client_info.authorization_headers,
                 )
             )
         except BaseFlyMyAIException as e:
             raise FlyMyAIOpenAPIException.from_response(e.response)

-    def openapi_schema(self, max_retries=None):
+    def openapi_schema(self, model: str | None = None, max_retries=None):
         """
-        :param max_retries: retries before giving up
+        :param model: flymyai/bert
+        :param max_retries: retries before giving up
         :return:
         :return: OpenAPISchemaResponse(exc_history, openapi_schema, response):
                 exc_history - dict with exceptions;
                 openapi_schema - dict with openapi;
         """
         history, response = retryable_callback(
-            lambda: self._openapi_schema(),
+            lambda: self._openapi_schema(client_info=self.amend_client_info(model)),
             max_retries or self.max_retries,
             FlyMyAIPredictException,
             FlyMyAIExceptionGroup,
@@ -250,16 +293,16 @@
         )

     @classmethod
-    def run_predict(cls, auth: dict, payload: dict):
+    def run_predict(cls, apikey: str, model: str, payload: dict):
         """
-        :param auth: {"apikey": "...", "username": "...", "project_name": "..."}
+        :param apikey: fly-...
+        :param model: flymyai/bert
         :param payload: jsonable / multipart/form-data available data
         :return: PredictionResponse(exc_history, output_data, response):
                 exc_history - list of exception history during prediction;
                 output_data - dict with prediction output;
         """
-        auth = ClientInfoFactory(raw_auth=auth).build_auth()
-        with cls(auth) as client:
+        with cls(apikey, model) as client:
             return client.predict(payload)
@@ -267,7 +310,7 @@ class BaseAsyncClient(BaseClient[httpx.AsyncClient]):
     def _construct_client(self):
         return httpx.AsyncClient(
             http2=True,
-            headers=self.auth.authorization_headers,
+            headers=self.client_info.authorization_headers,
             base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"),
         )
@@ -278,7 +321,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
         if hasattr(self, "_client"):
             await self._client.aclose()

-    async def openapi_schema(self, max_retries=None):
+    async def openapi_schema(self, model: str | None = None, max_retries=None):
         """
         :param max_retries: retries before giving up
         :return:
@@ -296,16 +339,16 @@ async def openapi_schema(self, max_retries=None):
             exc_history=history, openapi_schema=response.json(), response=response
         )

-    def _openapi_schema(self):
+    def _openapi_schema(self, client_info: APIKeyClientInfo):
         """
-        OpenAPI request for current project, wrapped in executor-method (using HTTP/1)
+        OpenAPI request for the current project, wrapped in executor-method (using HTTP/1)
         :return:
         """
         try:
             return self._wrap_request(
                 lambda: self._client.get(
-                    self.auth.openapi_schema_path,
-                    headers=self.auth.authorization_headers,
+                    client_info.openapi_schema_path,
+                    headers=client_info.authorization_headers,
                 )
             )
         except BaseFlyMyAIException as e:
@@ -327,7 +370,7 @@ async def _sse_instant(
         ).construct()
         return response

-    def _predict(self, payload: MultipartPayload):
+    def _predict(self, client_info, payload: MultipartPayload):
         """
         Executes request and waits for sse data
         :param payload: model input data
@@ -337,19 +380,22 @@
             return self._sse_instant(
                 lambda: self._client.stream(
                     method="post",
-                    url=self.auth.prediction_path,
+                    url=client_info.prediction_path,
                     timeout=_predict_timeout,
                     **payload.serialize(),
-                    headers=self.auth.authorization_headers,
+                    headers=client_info.authorization_headers,
                 )
             )
         except BaseFlyMyAIException as e:
             raise FlyMyAIPredictException.from_response(e.response)

-    async def predict(self, payload: dict, max_retries=None):
+    async def predict(
+        self, payload: dict, model: str | None = None, max_retries=None
+    ) -> PredictionResponse:
         """
         Wrap predict method in sse.
         Retries until max_retries or self.max_retries is reached
+        :param model: flymyai/bert
         :param payload: anything for model
         :param max_retries: retries
         :return: PredictionResponse(exc_history, output_data, response):
@@ -358,16 +404,18 @@
         """
         payload = MultipartPayload(input_data=payload)
         history, response = await aretryable_callback(
-            lambda: self._predict(payload),
+            lambda: self._predict(self.amend_client_info(model), payload),
             max_retries or self.max_retries,
             FlyMyAIPredictException,
             FlyMyAIExceptionGroup,
         )
         return PredictionResponse.from_response(response, exc_history=history)

-    async def _stream(self, payload: dict):
+    async def _stream(self, client_info: APIKeyClientInfo, payload: dict):
         payload = MultipartPayload(payload)
-        stream_iterator = self._stream_iterator(payload, is_long_stream=True)
+        stream_iterator = self._stream_iterator(
+            client_info, payload, is_long_stream=True
+        )
         decoder = SSEDecoder()
         async with stream_iterator as sse_stream:
             async for sse_partial in decoder.aiter(sse_stream.aiter_lines()):
@@ -381,8 +429,8 @@ async def _stream(self, payload: dict):
                 raise FlyMyAIPredictException.from_response(e.response)
             yield response

-    async def stream(self, payload: dict):
-        stream_iter = self._stream(payload)
+    async def stream(self, payload: dict, model: str | None = None, max_retries=None):
+        stream_iter = self._stream(self.amend_client_info(model), payload)
         last_response = None
         async for response in stream_iter:
             response.stream = stream_iter
@@ -406,15 +454,15 @@ async def close(self):
         await self._client.aclose()

     @classmethod
-    async def arun_predict(cls, auth: dict, payload: dict):
+    async def arun_predict(cls, apikey: str, model: str, payload: dict):
         """
         Execute simple prediction out of a box
-        :param auth: {"apikey": "...", "username": "...", "project_name": "..."}
+        :param model: flymyai/bert
+        :param apikey: fly-...
         :param payload: {dict with prediction input}
         :return: PredictionResponse(exc_history, output_data, response)
                 exc_history - list of exception history during prediction
                 output_data - dict with prediction output
         """
-        auth = ClientInfoFactory(raw_auth=auth).build_auth()
-        async with cls(auth) as client:
+        async with cls(apikey, model) as client:
             return await client.predict(payload)
diff --git a/flymyai/core/authorizations.py b/flymyai/core/authorizations.py
index ca74bb5..6e60422 100644
--- a/flymyai/core/authorizations.py
+++ b/flymyai/core/authorizations.py
@@ -1,7 +1,10 @@
+import copy
 import dataclasses

 import httpx

+from flymyai.core.exceptions import ImproperlyConfiguredClientException
+

 class ClientInfo:
@@ -35,8 +38,8 @@ class APIKeyClientInfo(ClientInfo):
     """

     apikey: str
-    username: str
-    project_name: str
+    username: str | None = None
+    project_name: str | None = None

     @property
     def authorization_headers(self):
@@ -58,21 +61,17 @@ def prediction_stream_path(self):
     def openapi_schema_path(self):
         return self._project_path.join(httpx.URL("openapi.json"))

-
-class ClientInfoFactory:
-    _raw_auth: dict
-
-    def __init__(self, raw_auth: dict):
-        self._raw_auth = raw_auth
-
-    def _build_auth(self) -> ClientInfo:
-        """
-        Build authorization
-        """
-        if "apikey" in self._raw_auth:
-            return APIKeyClientInfo(**self._raw_auth)
-        else:
-            raise NotImplemented("This type of authorization is not implemented yet!")
-
-    def build_auth(self):
-        return self._build_auth()
+    def copy_for_model(self, model: str):
+        copied = copy.deepcopy(self)
+        if not model:
+            raise ImproperlyConfiguredClientException(
+                "model should be provided as <username>/<project_name>"
+            )
+        split_info = model.split("/")
+        if len(split_info) != 2:
+            raise ImproperlyConfiguredClientException(
+                "model should be provided as <username>/<project_name>"
+            )
+        copied.username = split_info[0]
+        copied.project_name = split_info[1]
+        return copied
diff --git a/flymyai/core/exceptions.py b/flymyai/core/exceptions.py
index b1bba85..f8955d5 100644
--- a/flymyai/core/exceptions.py
+++ b/flymyai/core/exceptions.py
@@ -8,6 +8,10 @@
 )


+class ImproperlyConfiguredClientException(Exception):
+    ...
+
+
 class BaseFlyMyAIException(Exception):
     msg: str
     requires_retry: bool
diff --git a/tests/test_flymyai_client.py b/tests/test_flymyai_client.py
index 9743d2f..a24f5b7 100644
--- a/tests/test_flymyai_client.py
+++ b/tests/test_flymyai_client.py
@@ -38,15 +38,18 @@ def client_auth_fixture() -> dict:

 def test_flymyai_client(address_fixture, fake_payload_fixture, client_auth_fixture):
     response = flymyai_sync_run(
-        auth=client_auth_fixture,
+        **client_auth_fixture,
         payload=fake_payload_fixture,
     )
     assert response


 def test_flymyai_openapi(address_fixture, client_auth_fixture):
-    response = flymyai_client(auth=client_auth_fixture).openapi_schema()
-    assert response
+    response1 = flymyai_client(**client_auth_fixture).openapi_schema()
+    response2 = flymyai_client(client_auth_fixture["apikey"]).openapi_schema(
+        model=client_auth_fixture["model"]
+    )
+    assert response1.model_dump() == response2.model_dump()


 @pytest.mark.asyncio
@@ -54,7 +57,7 @@ async def test_flymyai_async_run(
     address_fixture, client_auth_fixture, fake_payload_fixture
 ):
     response = await flymyai_async_run(
-        auth=client_auth_fixture, payload=fake_payload_fixture
+        **client_auth_fixture, payload=fake_payload_fixture
     )
     assert response

@@ -62,7 +65,7 @@ async def test_flymyai_async_run(
 @pytest.mark.asyncio
 async def test_doc_case(address_fixture, client_auth_fixture, fake_payload_fixture):
     tasks = [
-        asyncio.create_task(flymyai_async_run(auth=client_auth_fixture, payload=prompt))
+        asyncio.create_task(flymyai_async_run(**client_auth_fixture, payload=prompt))
         for prompt in [fake_payload_fixture] * 3
     ]
     results = await asyncio.gather(*tasks)
diff --git a/tests/test_stream.py b/tests/test_stream.py
index e83e83c..2cda047 100644
--- a/tests/test_stream.py
+++ b/tests/test_stream.py
@@ -25,7 +25,7 @@ def vllm_stream_auth():


 def test_vllm_stream(vllm_stream_auth, vllm_stream_payload, dsn):
-    stream_iterator = sync_client(auth=vllm_stream_auth).stream(vllm_stream_payload)
+    stream_iterator = sync_client(**vllm_stream_auth).stream(vllm_stream_payload)
     for response in stream_iterator:
         assert response.status == 200
         assert response.output_data
@@ -36,9 +36,7 @@
 @pytest.mark.asyncio
 async def test_vllm_async_stream(vllm_stream_auth, vllm_stream_payload, dsn):
     try:
-        stream_iterator = async_client(auth=vllm_stream_auth).stream(
-            vllm_stream_payload
-        )
+        stream_iterator = async_client(**vllm_stream_auth).stream(vllm_stream_payload)
         async for response in stream_iterator:
             assert response.status == 200
             assert response.output_data
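
Reviewer note (not part of the diff): a minimal usage sketch of the reworked constructor and the per-call `model` argument. It assumes `FlyMyAI` and `AsyncFlyMyAI`, re-exported in `flymyai/__init__.py`, are the public subclasses of the `BaseSyncClient` / `BaseAsyncClient` shown above; the API key, model string, and payload are placeholders.

```python
# Sketch only: names outside this diff (FlyMyAI, AsyncFlyMyAI) and all values are assumed.
import asyncio

from flymyai import FlyMyAI, AsyncFlyMyAI

payload = {"text": "hello"}  # placeholder model input

# Bind the model at construction time: apikey plus "<username>/<project_name>"...
client = FlyMyAI("fly-secret-key", model="flymyai/bert")
print(client.predict(payload).output_data)

# ...or pass it per call; amend_client_info() copies the client info for that model.
schema = FlyMyAI("fly-secret-key").openapi_schema(model="flymyai/bert")

# The async client mirrors the same signatures.
async def main():
    async with AsyncFlyMyAI("fly-secret-key", model="flymyai/bert") as aclient:
        response = await aclient.predict(payload)
        print(response.output_data)

asyncio.run(main())
```

The one-shot classmethods follow the same shape, e.g. `FlyMyAI.run_predict("fly-secret-key", "flymyai/bert", payload)` in place of the old `run_predict(auth_dict, payload)` call.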
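Also worth calling out for review: the new model-string validation. A rough sketch of the failure and success paths, based on `copy_for_model` and `amend_client_info` above (key and payload are again placeholders):

```python
# Sketch of the validation path added in authorizations.py; values are placeholders.
from flymyai import FlyMyAI
from flymyai.core.exceptions import ImproperlyConfiguredClientException

client = FlyMyAI("fly-secret-key")  # no model bound: username/project_name stay None

try:
    # amend_client_info() finds no username/project_name and raises.
    client.predict({"text": "hello"})
except ImproperlyConfiguredClientException as exc:
    print(exc)

try:
    # copy_for_model() rejects anything that does not split into exactly two parts on "/".
    client.predict({"text": "hello"}, model="flymyai-bert")
except ImproperlyConfiguredClientException as exc:
    print(exc)

# A well-formed "<username>/<project_name>" string selects the model for this call only.
response = client.predict({"text": "hello"}, model="flymyai/bert")
```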