114 changes: 93 additions & 21 deletions flymyai/core/_client.py
@@ -8,6 +8,7 @@
     overload,
     Iterator,
     AsyncContextManager,
+    AsyncIterator,
 )
 
 import httpx
@@ -21,7 +22,11 @@
     BaseFlyMyAIException,
     FlyMyAIOpenAPIException,
 )
-from flymyai.core.models import PredictionResponse, OpenAPISchemaResponse
+from flymyai.core.models import (
+    PredictionResponse,
+    OpenAPISchemaResponse,
+    PredictionPartial,
+)
 from flymyai.multipart.payload import MultipartPayload
 from flymyai.utils.utils import retryable_callback, aretryable_callback
 
@@ -56,14 +61,14 @@ def __init__(self, auth: APIKeyClientInfo | dict, max_retries=DEFAULT_RETRY_COUNT
         self.max_retries = max_retries
 
     @overload
-    async def predict(self, input_data: dict, max_retries=None) -> PredictionResponse:
+    async def predict(self, payload: dict, max_retries=None) -> PredictionResponse:
         ...
 
     @overload
-    def predict(self, input_data: dict, max_retries=None) -> PredictionResponse:
+    def predict(self, payload: dict, max_retries=None) -> PredictionResponse:
         ...
 
-    def predict(self, input_data: dict, max_retries=None) -> PredictionResponse:
+    def predict(self, payload: dict, max_retries=None) -> PredictionResponse:
         ...
 
     @overload
@@ -77,6 +82,33 @@ def openapi_schema(self, max_retries=None) -> OpenAPISchemaResponse:
     def openapi_schema(self, max_retries=None) -> OpenAPISchemaResponse:
         ...
 
+    @overload
+    async def stream(self, payload: dict) -> AsyncIterator[PredictionPartial]:
+        ...
+
+    @overload
+    def stream(self, payload: dict) -> Iterator[PredictionPartial]:
+        ...
+
+    def stream(self, payload: dict):
+        ...
+
+    def _stream_iterator(
+        self, payload: MultipartPayload, is_long_stream: bool
+    ) -> Iterator[httpx.Response] | AsyncIterator[httpx.Response]:
+        return self._client.stream(
+            method="post",
+            url=(
+                self.auth.prediction_path
+                if not is_long_stream
+                else self.auth.prediction_stream_path
+            ),
+            **payload.serialize(),
+            timeout=_predict_timeout,
+            headers=self.auth.authorization_headers,
+            follow_redirects=True,
+        )
+
     @staticmethod
     def _wrap_request(request_callback: Callable):
         response = request_callback()
@@ -135,15 +167,7 @@ def _predict(self, payload: MultipartPayload):
         Wrap predict method in sse
         """
         try:
-            return self._sse_instant(
-                lambda: self._client.stream(
-                    method="post",
-                    url=self.auth.prediction_path,
-                    **payload.serialize(),
-                    timeout=_predict_timeout,
-                    headers=self.auth.authorization_headers,
-                )
-            )
+            return self._sse_instant(lambda: self._stream_iterator(payload, False))
         except BaseFlyMyAIException as e:
             raise FlyMyAIPredictException.from_response(e.response)
 
@@ -164,9 +188,33 @@ def predict(self, payload: dict, max_retries=None):
             FlyMyAIPredictException,
             FlyMyAIExceptionGroup,
         )
-        return PredictionResponse(
-            exc_history=history, response=response, **response.json()
-        )
+        return PredictionResponse.from_response(response, exc_history=history)
+
+    def _stream(self, payload: dict):
+        payload = MultipartPayload(payload)
+        response_iterator = self._stream_iterator(payload, is_long_stream=True)
+        decoder = SSEDecoder()
+        with response_iterator as sse_stream:
+            for sse_partial in decoder.iter(sse_stream.iter_lines()):
+                try:
+                    response = ResponseFactory(
+                        sse=sse_partial,
+                        httpx_request=sse_stream.request,
+                        httpx_response=sse_stream,
+                    ).construct()
+                except BaseFlyMyAIException as e:
+                    raise FlyMyAIPredictException.from_response(e.response)
+                yield response
+
+    def stream(self, payload: dict):
+        stream_iter = self._stream(payload)
+        last_response = None
+        for response in stream_iter:
+            response.stream = stream_iter
+            yield PredictionPartial.from_response(response)
+            last_response = response
+        if last_response:
+            last_response.is_stream_consumed = True
 
     def _openapi_schema(self):
         """
@@ -197,7 +245,7 @@ def openapi_schema(self, max_retries=None):
             FlyMyAIPredictException,
             FlyMyAIExceptionGroup,
         )
-        return OpenAPISchemaResponse(
+        return OpenAPISchemaResponse.from_response(
             exc_history=history, openapi_schema=response.json(), response=response
         )
 
@@ -244,7 +292,7 @@ async def openapi_schema(self, max_retries=None):
             FlyMyAIPredictException,
             FlyMyAIExceptionGroup,
         )
-        return OpenAPISchemaResponse(
+        return OpenAPISchemaResponse.from_response(
             exc_history=history, openapi_schema=response.json(), response=response
         )
 
@@ -315,9 +363,33 @@ async def predict(self, payload: dict, max_retries=None):
             FlyMyAIPredictException,
             FlyMyAIExceptionGroup,
         )
-        return PredictionResponse(
-            exc_history=history, response=response, **response.json()
-        )
+        return PredictionResponse.from_response(response, exc_history=history)
+
+    async def _stream(self, payload: dict):
+        payload = MultipartPayload(payload)
+        stream_iterator = self._stream_iterator(payload, is_long_stream=True)
+        decoder = SSEDecoder()
+        async with stream_iterator as sse_stream:
+            async for sse_partial in decoder.aiter(sse_stream.aiter_lines()):
+                try:
+                    response = ResponseFactory(
+                        sse=sse_partial,
+                        httpx_request=sse_stream.request,
+                        httpx_response=sse_stream,
+                    ).construct()
+                except BaseFlyMyAIException as e:
+                    raise FlyMyAIPredictException.from_response(e.response)
+                yield response
+
+    async def stream(self, payload: dict):
+        stream_iter = self._stream(payload)
+        last_response = None
+        async for response in stream_iter:
+            response.stream = stream_iter
+            yield PredictionPartial.from_response(response)
+            last_response = response
+        if last_response:
+            last_response.is_stream_consumed = True
 
     @staticmethod
     async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response]]):
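Taken together, the _client.py changes add a streaming entry point alongside predict, with both paths routed through the shared _stream_iterator. A minimal usage sketch, assuming the client/async_client factories used in the tests at the end of this PR; the API key and the payload/output keys are placeholders for your model's actual schema, not values defined in this PR:

    # Hedged sketch of the new stream() API (sync side).
    from flymyai import client

    fma = client(auth="fly-...")  # placeholder API key
    for partial in fma.stream({"i_prompt": "Hello!"}):  # placeholder input key
        if partial.output_data:  # PredictionPartial: status plus optional chunk
            print(partial.output_data)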
4 changes: 3 additions & 1 deletion flymyai/core/_response_factory.py
@@ -30,7 +30,9 @@ def __init__(
         self.httpx_response = httpx_response
 
     def get_sse_status_code(self):
-        return self.sse.json().get("status_code", 200)
+        return self.sse.json().get(
+            "status", self.httpx_response.status_code if self.httpx_response else 200
+        )
 
     def _base_construct_from_sse(self):
         sse_status = self.get_sse_status_code()
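Note the SSE status lookup now keys on "status" (previously "status_code") and, when the event body omits it, falls back to the wrapped HTTP response's status rather than a hard-coded 200. A small sketch of the fallback, under the assumption of an event body with no status field:

    # Sketch only: an SSE event without "status" inherits the HTTP status.
    event = {"output_data": {"o_text_output": ["Hi"]}}  # no "status" key
    status = event.get("status", 200)  # 200 standing in for httpx_response.status_code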
4 changes: 4 additions & 0 deletions flymyai/core/authorizations.py
@@ -50,6 +50,10 @@ def _project_path(self):
     def prediction_path(self):
         return self._project_path.join(httpx.URL("predict"))
 
+    @property
+    def prediction_stream_path(self):
+        return self._project_path.join(httpx.URL("predict/stream/"))
+
     @property
     def openapi_schema_path(self):
         return self._project_path.join(httpx.URL("openapi.json"))
2 changes: 2 additions & 0 deletions flymyai/core/exceptions.py
@@ -4,6 +4,7 @@
     FlyMyAI422Response,
     Base4xxResponse,
     FlyMyAI400Response,
+    FlyMyAI421Response,
 )
 
 
@@ -43,6 +44,7 @@ def from_4xx(cls, response: FlyMyAIResponse):
         response_validation_templates = {
             400: FlyMyAI400Response,
             401: FlyMyAI401Response,
+            421: FlyMyAI421Response,
             422: FlyMyAI422Response,
         }
         response_4xx = response_validation_templates.get(
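With the template registered, from_4xx resolves 421 responses to FlyMyAI421Response by a plain dict lookup; the default for unmapped codes is cut off in this diff, so Base4xxResponse below is an assumption:

    # Sketch of the lookup performed inside from_4xx; Base4xxResponse is
    # the assumed fallback for status codes missing from the template map.
    response_4xx = response_validation_templates.get(421, Base4xxResponse)
    # -> FlyMyAI421Response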
54 changes: 39 additions & 15 deletions flymyai/core/models.py
@@ -61,6 +61,18 @@ def to_msg(self):
     """
 
 
+@dataclasses.dataclass
+class FlyMyAI421Response(Base4xxResponse):
+    requires_retry = False
+
+    def to_msg(self):
+        jsoned = json.loads(self.content)
+        msg = super().to_msg()
+        if detail := jsoned.get("detail"):
+            msg += f"\nDetail: {detail}"
+        return msg
+
+
 @dataclasses.dataclass
 class FlyMyAI422Response(Base4xxResponse):
     """
@@ -78,39 +90,51 @@ def to_msg(self):
         return msg
 
 
-class PredictionResponse(pydantic.BaseModel):
+class BaseFromServer(pydantic.BaseModel):
+    _response: FlyMyAIResponse = PrivateAttr()
+
+    @property
+    def response(self):
+        return self._response
+
+    @classmethod
+    def from_response(cls, response: FlyMyAIResponse, **kwargs):
+        status_code = kwargs.pop("status", response.status_code)
+        response_json = response.json()
+        response_json["status"] = response_json.get("status", status_code)
+        self = cls(**response_json, **kwargs)
+        self._response = response
+        return self
+
+
+class PredictionResponse(BaseFromServer):
     """
     Prediction response from FlyMyAI
     """
 
     exc_history: list | None
     output_data: dict
-    _response: FlyMyAIResponse = PrivateAttr()
+    status: int
 
     inference_time: float | None = None
 
-    def __init__(self, response=None, **data):
-        super().__init__(**data)
-        self._response = data.get("response")
-
-    @property
-    def response(self):
-        return self._response
-
 
-class OpenAPISchemaResponse(pydantic.BaseModel):
+class OpenAPISchemaResponse(BaseFromServer):
     """
-    OpenAPI schema for current project. Use it to construct your own schema
+    OpenAPI schema for the current project. Use it to construct your own schema
     """
 
     exc_history: list | None
     openapi_schema: dict
-    _response: FlyMyAIResponse = PrivateAttr()
+    status: int
 
-    def __init__(self, response=None, **data):
-        super().__init__(**data)
-        self._response = response
-
-    @property
-    def response(self):
-        return self._response
+
+class PredictionPartial(BaseFromServer):
+    status: int
+    output_data: dict | None = None
+
+    _response: FlyMyAIResponse = PrivateAttr()
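BaseFromServer centralizes the bookkeeping that PredictionResponse and OpenAPISchemaResponse previously duplicated in their __init__ methods: from_response attaches the raw FlyMyAIResponse as a private attribute and back-fills status from the HTTP status code whenever the JSON body omits it. A hedged sketch of a resulting call site, where response stands in for an already-constructed FlyMyAIResponse:

    # Sketch: "status" comes from the body when present, otherwise from
    # response.status_code; the raw response stays reachable afterwards.
    prediction = PredictionResponse.from_response(response, exc_history=[])
    prediction.status    # int, always populated
    prediction.response  # the underlying FlyMyAIResponse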
51 changes: 51 additions & 0 deletions tests/test_stream.py
@@ -0,0 +1,51 @@
+import os
+
+import pytest
+
+from flymyai import client as sync_client, async_client
+
+from tests.FixtureFactory import FixtureFactory
+
+factory = FixtureFactory(__file__)
+
+
+@pytest.fixture
+def dsn():
+    os.environ["FLYMYAI_DSN"] = factory("address_fixture")
+
+
+@pytest.fixture
+def vllm_stream_payload():
+    return factory("vllm_stream_payload")
+
+
+@pytest.fixture
+def vllm_stream_auth():
+    return factory("vllm_auth")
+
+
+def test_vllm_stream(vllm_stream_auth, vllm_stream_payload, dsn):
+    stream_iterator = sync_client(auth=vllm_stream_auth).stream(vllm_stream_payload)
+    for response in stream_iterator:
+        assert response.status == 200
+        assert response.output_data
+        print(response.output_data["o_text_output"].pop(), end="")
+    print("\n")
+
+
+@pytest.mark.asyncio
+async def test_vllm_async_stream(vllm_stream_auth, vllm_stream_payload, dsn):
+    try:
+        stream_iterator = async_client(auth=vllm_stream_auth).stream(
+            vllm_stream_payload
+        )
+        async for response in stream_iterator:
+            assert response.status == 200
+            assert response.output_data
+            print(response.output_data["o_text_output"].pop(), end="")
+    except Exception as e:
+        if hasattr(e, "msg"):
+            print(e)
+        raise e
+    finally:
+        print()
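These tests drive both the sync and async clients end to end against a vLLM fixture. Assuming a standard pytest setup with pytest-asyncio installed (implied by the pytest.mark.asyncio marker), they would run with pytest tests/test_stream.py -s, the -s flag keeping the streamed chunks visible on stdout.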