diff --git a/flymyai/core/_client.py b/flymyai/core/_client.py index 43affdd..3cb9ffb 100644 --- a/flymyai/core/_client.py +++ b/flymyai/core/_client.py @@ -8,6 +8,7 @@ overload, Iterator, AsyncContextManager, + AsyncIterator, ) import httpx @@ -21,7 +22,11 @@ BaseFlyMyAIException, FlyMyAIOpenAPIException, ) -from flymyai.core.models import PredictionResponse, OpenAPISchemaResponse +from flymyai.core.models import ( + PredictionResponse, + OpenAPISchemaResponse, + PredictionPartial, +) from flymyai.multipart.payload import MultipartPayload from flymyai.utils.utils import retryable_callback, aretryable_callback @@ -56,14 +61,14 @@ def __init__(self, auth: APIKeyClientInfo | dict, max_retries=DEFAULT_RETRY_COUN self.max_retries = max_retries @overload - async def predict(self, input_data: dict, max_retries=None) -> PredictionResponse: + async def predict(self, payload: dict, max_retries=None) -> PredictionResponse: ... @overload - def predict(self, input_data: dict, max_retries=None) -> PredictionResponse: + def predict(self, payload: dict, max_retries=None) -> PredictionResponse: ... - def predict(self, input_data: dict, max_retries=None) -> PredictionResponse: + def predict(self, payload: dict, max_retries=None) -> PredictionResponse: ... @overload @@ -77,6 +82,33 @@ def openapi_schema(self, max_retries=None) -> OpenAPISchemaResponse: def openapi_schema(self, max_retries=None) -> OpenAPISchemaResponse: ... + @overload + async def stream(self, payload: dict) -> AsyncIterator[PredictionPartial]: + ... + + @overload + def stream(self, payload: dict) -> Iterator[PredictionPartial]: + ... + + def stream(self, payload: dict): + ... + + def _stream_iterator( + self, payload: MultipartPayload, is_long_stream: bool + ) -> Iterator[httpx.Response] | AsyncIterator[httpx.Response]: + return self._client.stream( + method="post", + url=( + self.auth.prediction_path + if not is_long_stream + else self.auth.prediction_stream_path + ), + **payload.serialize(), + timeout=_predict_timeout, + headers=self.auth.authorization_headers, + follow_redirects=True, + ) + @staticmethod def _wrap_request(request_callback: Callable): response = request_callback() @@ -135,15 +167,7 @@ def _predict(self, payload: MultipartPayload): Wrap predict method in sse """ try: - return self._sse_instant( - lambda: self._client.stream( - method="post", - url=self.auth.prediction_path, - **payload.serialize(), - timeout=_predict_timeout, - headers=self.auth.authorization_headers, - ) - ) + return self._sse_instant(lambda: self._stream_iterator(payload, False)) except BaseFlyMyAIException as e: raise FlyMyAIPredictException.from_response(e.response) @@ -164,9 +188,33 @@ def predict(self, payload: dict, max_retries=None): FlyMyAIPredictException, FlyMyAIExceptionGroup, ) - return PredictionResponse( - exc_history=history, response=response, **response.json() - ) + return PredictionResponse.from_response(response, exc_history=history) + + def _stream(self, payload: dict): + payload = MultipartPayload(payload) + response_iterator = self._stream_iterator(payload, is_long_stream=True) + decoder = SSEDecoder() + with response_iterator as sse_stream: + for sse_partial in decoder.iter(sse_stream.iter_lines()): + try: + response = ResponseFactory( + sse=sse_partial, + httpx_request=sse_stream.request, + httpx_response=sse_stream, + ).construct() + except BaseFlyMyAIException as e: + raise FlyMyAIPredictException.from_response(e.response) + yield response + + def stream(self, payload: dict): + stream_iter = self._stream(payload) + last_response = None + for response in stream_iter: + response.stream = stream_iter + yield PredictionPartial.from_response(response) + last_response = response + if last_response: + last_response.is_stream_consumed = True def _openapi_schema(self): """ @@ -197,7 +245,7 @@ def openapi_schema(self, max_retries=None): FlyMyAIPredictException, FlyMyAIExceptionGroup, ) - return OpenAPISchemaResponse( + return OpenAPISchemaResponse.from_response( exc_history=history, openapi_schema=response.json(), response=response ) @@ -244,7 +292,7 @@ async def openapi_schema(self, max_retries=None): FlyMyAIPredictException, FlyMyAIExceptionGroup, ) - return OpenAPISchemaResponse( + return OpenAPISchemaResponse.from_response( exc_history=history, openapi_schema=response.json(), response=response ) @@ -315,9 +363,33 @@ async def predict(self, payload: dict, max_retries=None): FlyMyAIPredictException, FlyMyAIExceptionGroup, ) - return PredictionResponse( - exc_history=history, response=response, **response.json() - ) + return PredictionResponse.from_response(response, exc_history=history) + + async def _stream(self, payload: dict): + payload = MultipartPayload(payload) + stream_iterator = self._stream_iterator(payload, is_long_stream=True) + decoder = SSEDecoder() + async with stream_iterator as sse_stream: + async for sse_partial in decoder.aiter(sse_stream.aiter_lines()): + try: + response = ResponseFactory( + sse=sse_partial, + httpx_request=sse_stream.request, + httpx_response=sse_stream, + ).construct() + except BaseFlyMyAIException as e: + raise FlyMyAIPredictException.from_response(e.response) + yield response + + async def stream(self, payload: dict): + stream_iter = self._stream(payload) + last_response = None + async for response in stream_iter: + response.stream = stream_iter + yield PredictionPartial.from_response(response) + last_response = response + if last_response: + last_response.is_stream_consumed = True @staticmethod async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response]]): diff --git a/flymyai/core/_response_factory.py b/flymyai/core/_response_factory.py index 2df01d5..1dc7484 100644 --- a/flymyai/core/_response_factory.py +++ b/flymyai/core/_response_factory.py @@ -30,7 +30,9 @@ def __init__( self.httpx_response = httpx_response def get_sse_status_code(self): - return self.sse.json().get("status_code", 200) + return self.sse.json().get( + "status", self.httpx_response.status_code if self.httpx_response else 200 + ) def _base_construct_from_sse(self): sse_status = self.get_sse_status_code() diff --git a/flymyai/core/authorizations.py b/flymyai/core/authorizations.py index 280e667..ca74bb5 100644 --- a/flymyai/core/authorizations.py +++ b/flymyai/core/authorizations.py @@ -50,6 +50,10 @@ def _project_path(self): def prediction_path(self): return self._project_path.join(httpx.URL("predict")) + @property + def prediction_stream_path(self): + return self._project_path.join(httpx.URL("predict/stream/")) + @property def openapi_schema_path(self): return self._project_path.join(httpx.URL("openapi.json")) diff --git a/flymyai/core/exceptions.py b/flymyai/core/exceptions.py index 9acb84f..b1bba85 100644 --- a/flymyai/core/exceptions.py +++ b/flymyai/core/exceptions.py @@ -4,6 +4,7 @@ FlyMyAI422Response, Base4xxResponse, FlyMyAI400Response, + FlyMyAI421Response, ) @@ -43,6 +44,7 @@ def from_4xx(cls, response: FlyMyAIResponse): response_validation_templates = { 400: FlyMyAI400Response, 401: FlyMyAI401Response, + 421: FlyMyAI421Response, 422: FlyMyAI422Response, } response_4xx = response_validation_templates.get( diff --git a/flymyai/core/models.py b/flymyai/core/models.py index 4b8fe19..a74f1a8 100644 --- a/flymyai/core/models.py +++ b/flymyai/core/models.py @@ -61,6 +61,18 @@ def to_msg(self): """ +@dataclasses.dataclass +class FlyMyAI421Response(Base4xxResponse): + requires_retry = False + + def to_msg(self): + jsoned = json.loads(self.content) + msg = super().to_msg() + if detail := jsoned.get("detail"): + msg += f"\nDetail: {detail}" + return msg + + @dataclasses.dataclass class FlyMyAI422Response(Base4xxResponse): """ @@ -78,39 +90,51 @@ def to_msg(self): return msg -class PredictionResponse(pydantic.BaseModel): +class BaseFromServer(pydantic.BaseModel): + _response: FlyMyAIResponse = PrivateAttr() + + @property + def response(self): + return self._response + + @classmethod + def from_response(cls, response: FlyMyAIResponse, **kwargs): + status_code = kwargs.pop("status", response.status_code) + response_json = response.json() + response_json["status"] = response_json.get("status", status_code) + self = cls(**response_json, **kwargs) + self._response = response + return self + + +class PredictionResponse(BaseFromServer): """ Prediction response from FlyMyAI """ exc_history: list | None output_data: dict - _response: FlyMyAIResponse = PrivateAttr() + status: int inference_time: float | None = None - def __init__(self, response=None, **data): - super().__init__(**data) - self._response = data.get("response") - @property def response(self): return self._response -class OpenAPISchemaResponse(pydantic.BaseModel): +class OpenAPISchemaResponse(BaseFromServer): """ - OpenAPI schema for current project. Use it to construct your own schema + OpenAPI schema for the current project. Use it to construct your own schema """ exc_history: list | None openapi_schema: dict - _response: FlyMyAIResponse = PrivateAttr() + status: int - def __init__(self, response=None, **data): - super().__init__(**data) - self._response = response - @property - def response(self): - return self._response +class PredictionPartial(BaseFromServer): + status: int + output_data: dict | None = None + + _response: FlyMyAIResponse = PrivateAttr() diff --git a/tests/test_stream.py b/tests/test_stream.py new file mode 100644 index 0000000..e83e83c --- /dev/null +++ b/tests/test_stream.py @@ -0,0 +1,51 @@ +import os + +import pytest + +from flymyai import client as sync_client, async_client + +from tests.FixtureFactory import FixtureFactory + +factory = FixtureFactory(__file__) + + +@pytest.fixture +def dsn(): + os.environ["FLYMYAI_DSN"] = factory("address_fixture") + + +@pytest.fixture +def vllm_stream_payload(): + return factory("vllm_stream_payload") + + +@pytest.fixture +def vllm_stream_auth(): + return factory("vllm_auth") + + +def test_vllm_stream(vllm_stream_auth, vllm_stream_payload, dsn): + stream_iterator = sync_client(auth=vllm_stream_auth).stream(vllm_stream_payload) + for response in stream_iterator: + assert response.status == 200 + assert response.output_data + print(response.output_data["o_text_output"].pop(), end="") + print("\n") + + +@pytest.mark.asyncio +async def test_vllm_async_stream(vllm_stream_auth, vllm_stream_payload, dsn): + try: + stream_iterator = async_client(auth=vllm_stream_auth).stream( + vllm_stream_payload + ) + async for response in stream_iterator: + assert response.status == 200 + assert response.output_data + print(response.output_data["o_text_output"].pop(), end="") + except Exception as e: + if hasattr(e, "msg"): + print(e) + raise e + finally: + print()