Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,36 @@ jobs:
FMA_APIKEY: ${{ secrets.FMA_APIKEY }}
run: pytest tests/test_fields.py --tb=short

test_unknown_error_handle:
needs:
- lint

runs-on: ubuntu-latest

container:
image: python:3.8

steps:
- name: Check out git repo
uses: actions/checkout@v2
with:
fetch-depth: 0

- name: Fix
run: git config --global --add safe.directory '*'

- name: Install dependencies
run: pip3 install poetry pytest-asyncio && poetry config virtualenvs.create false && poetry install

- name: Test
env:
FMA_APIKEY: ${{ secrets.FMA_APIKEY }}
run: pytest tests/test_unknown_error_handle.py --tb=short

test_flymyai_client:
needs:
- lint
- test_unknown_error_handle

runs-on: ubuntu-latest

Expand All @@ -88,6 +115,7 @@ jobs:
test_stream:
needs:
- lint
- test_unknown_error_handle

runs-on: ubuntu-latest

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ dmypy.json
.idea/*
poetry.lock
venv*
tests/fixtures*
Empty file removed flymyai/core/_client.py
Empty file.
8 changes: 6 additions & 2 deletions flymyai/core/_response_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def get_sse_status_code(self):

def _base_construct_from_sse(self):
sse_status = self.get_sse_status_code()
if sse_status < 400:
is_details = self.sse.json().get("details") is not None
if is_details and sse_status == 200:
sse_status = 599
if sse_status < 400 and not is_details:
response = FlyMyAIResponse(
status_code=sse_status,
content=self.sse.data or self.sse.event,
Expand All @@ -50,7 +53,8 @@ def _base_construct_from_sse(self):
status_code=sse_status,
content=self.sse.data or self.sse.event,
request=self.httpx_request,
headers=self.httpx_response.headers or self.sse.headers,
headers=self.httpx_response.headers
or getattr(self.sse, "headers", {}),
)
)

Expand Down
9 changes: 7 additions & 2 deletions flymyai/core/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class ServerSentEvent:
_headers: dict[str, str]
_url: str

__jsoned: Any

def __init__(
self,
*,
Expand Down Expand Up @@ -47,10 +49,13 @@ def data(self) -> str:
return self._data

def json(self) -> Any:
if hasattr(self, "__jsoned"):
return self.__jsoned
if self.data:
return json.loads(self.data.strip())
self.__jsoned = json.loads(self.data.strip())
if self.event:
return json.loads(self.event.strip())
self.__jsoned = json.loads(self.event.strip())
return self.__jsoned

@property
def headers(self):
Expand Down
32 changes: 16 additions & 16 deletions flymyai/core/clients/AsyncClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,29 +99,29 @@ async def _sse_instant(
"""
async with async_response_stream() as stream:
sse = await SSEDecoder().aiter(stream.aiter_lines()).__anext__()
response = ResponseFactory(
sse=sse, httpx_request=stream.request, httpx_response=stream
).construct()
return response
try:
response = ResponseFactory(
sse=sse, httpx_request=stream.request, httpx_response=stream
).construct()
return response
except BaseFlyMyAIException as e:
raise FlyMyAIPredictException.from_base_exception(e)

def _predict(self, client_info, payload: MultipartPayload):
"""
Executes request and waits for sse data
:param payload: model input data
:return: FlyMyAIResponse or raise an exception
"""
try:
return self._sse_instant(
lambda: self._client.stream(
method="post",
url=client_info.prediction_path,
timeout=_predict_timeout,
**payload.serialize(),
headers=client_info.authorization_headers,
)
return self._sse_instant(
lambda: self._client.stream(
method="post",
url=client_info.prediction_path,
timeout=_predict_timeout,
**payload.serialize(),
headers=client_info.authorization_headers,
)
except BaseFlyMyAIException as e:
raise FlyMyAIPredictException.from_response(e.response)
)

async def predict(
self, payload: dict, model: Optional[str] = None, max_retries=None
Expand Down Expand Up @@ -160,7 +160,7 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict):
httpx_response=sse_stream,
).construct()
except BaseFlyMyAIException as e:
raise FlyMyAIPredictException.from_response(e.response)
raise FlyMyAIPredictException.from_base_exception(e)
yield response

def stream(self, payload: dict, model: Optional[str] = None, max_retries=None):
Expand Down
4 changes: 2 additions & 2 deletions flymyai/core/clients/SyncClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _predict(self, payload: MultipartPayload, client_info: APIKeyClientInfo):
lambda: self._stream_iterator(client_info, payload, False)
)
except BaseFlyMyAIException as e:
raise FlyMyAIPredictException.from_response(e.response)
raise FlyMyAIPredictException.from_base_exception(e)

def predict(self, payload: dict, model: Optional[str] = None, max_retries=None):
"""
Expand Down Expand Up @@ -100,7 +100,7 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict):
httpx_response=sse_stream,
).construct()
except BaseFlyMyAIException as e:
raise FlyMyAIPredictException.from_response(e.response)
raise FlyMyAIPredictException.from_base_exception(e)
yield response

def stream(self, payload: dict, model: Optional[str] = None):
Expand Down
17 changes: 14 additions & 3 deletions flymyai/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
import datetime
from typing import List, Type

from ._response import FlyMyAIResponse
from .models.error_responses import (
Expand Down Expand Up @@ -32,13 +33,20 @@ def from_5xx(cls, response: FlyMyAIResponse):
msg = f"""
INTERNAL SERVER ERROR ({response.status_code}):
REQUEST URL: {response.url};
"""
Content [0:250]: {response.content.decode()[0:250]}
Timestamp [UTC]: {datetime.datetime.utcnow()}
"""
internal_error_mapping = {
500: lambda: cls(msg, False, response=response),
502: lambda: cls(msg, True, response=response),
503: lambda: cls(msg, False, response=response),
504: lambda: cls(msg, True, response=response),
524: lambda: cls(msg, True, response=response),
# unknown issue, probably detected on the client side
599: lambda: cls(msg, False, response=response),
# broker issues, they are not billed at all
5000: lambda: cls(msg, False, response=response),
5320: lambda: cls(msg, True, response=response),
}
return internal_error_mapping.get(
response.status_code, lambda: cls(msg, False)
Expand Down Expand Up @@ -72,7 +80,10 @@ def __str__(self):
return self.msg


class FlyMyAIPredictException(BaseFlyMyAIException): ...
class FlyMyAIPredictException(BaseFlyMyAIException):
@classmethod
def from_base_exception(cls, exception: BaseFlyMyAIException):
return cls(exception.msg, exception.requires_retry, exception.response)


class FlyMyAIOpenAPIException(BaseFlyMyAIException): ...
Expand Down
2 changes: 2 additions & 0 deletions flymyai/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ async def aretryable_callback(
continue
else:
raise exception_group_cls(retries_history)
except Exception as e:
raise e
else:
exception_gr = exception_group_cls(retries_history)
raise exception_gr
Loading