diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2ab4232..4ed51bc 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -59,9 +59,36 @@ jobs: FMA_APIKEY: ${{ secrets.FMA_APIKEY }} run: pytest tests/test_fields.py --tb=short + test_unknown_error_handle: + needs: + - lint + + runs-on: ubuntu-latest + + container: + image: python:3.8 + + steps: + - name: Check out git repo + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Fix + run: git config --global --add safe.directory '*' + + - name: Install dependencies + run: pip3 install poetry pytest-asyncio && poetry config virtualenvs.create false && poetry install + + - name: Test + env: + FMA_APIKEY: ${{ secrets.FMA_APIKEY }} + run: pytest tests/test_unknown_error_handle.py --tb=short + test_flymyai_client: needs: - lint + - test_unknown_error_handle runs-on: ubuntu-latest @@ -88,6 +115,7 @@ jobs: test_stream: needs: - lint + - test_unknown_error_handle runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 520ab1f..3781077 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,4 @@ dmypy.json .idea/* poetry.lock venv* +tests/fixtures* diff --git a/flymyai/core/_client.py b/flymyai/core/_client.py deleted file mode 100644 index e69de29..0000000 diff --git a/flymyai/core/_response_factory.py b/flymyai/core/_response_factory.py index 576b0d7..f7feab5 100644 --- a/flymyai/core/_response_factory.py +++ b/flymyai/core/_response_factory.py @@ -35,7 +35,10 @@ def get_sse_status_code(self): def _base_construct_from_sse(self): sse_status = self.get_sse_status_code() - if sse_status < 400: + is_details = self.sse.json().get("details") is not None + if is_details and sse_status == 200: + sse_status = 599 + if sse_status < 400 and not is_details: response = FlyMyAIResponse( status_code=sse_status, content=self.sse.data or self.sse.event, @@ -50,7 +53,8 @@ def _base_construct_from_sse(self): status_code=sse_status, content=self.sse.data or self.sse.event, request=self.httpx_request, - headers=self.httpx_response.headers or self.sse.headers, + headers=self.httpx_response.headers + or getattr(self.sse, "headers", {}), ) ) diff --git a/flymyai/core/_streaming.py b/flymyai/core/_streaming.py index f09ce94..f6c4114 100644 --- a/flymyai/core/_streaming.py +++ b/flymyai/core/_streaming.py @@ -14,6 +14,8 @@ class ServerSentEvent: _headers: dict[str, str] _url: str + __jsoned: Any + def __init__( self, *, @@ -47,10 +49,13 @@ def data(self) -> str: return self._data def json(self) -> Any: + if hasattr(self, "__jsoned"): + return self.__jsoned if self.data: - return json.loads(self.data.strip()) + self.__jsoned = json.loads(self.data.strip()) if self.event: - return json.loads(self.event.strip()) + self.__jsoned = json.loads(self.event.strip()) + return self.__jsoned @property def headers(self): diff --git a/flymyai/core/clients/AsyncClient.py b/flymyai/core/clients/AsyncClient.py index b7cc4cb..7145016 100644 --- a/flymyai/core/clients/AsyncClient.py +++ b/flymyai/core/clients/AsyncClient.py @@ -99,10 +99,13 @@ async def _sse_instant( """ async with async_response_stream() as stream: sse = await SSEDecoder().aiter(stream.aiter_lines()).__anext__() - response = ResponseFactory( - sse=sse, httpx_request=stream.request, httpx_response=stream - ).construct() - return response + try: + response = ResponseFactory( + sse=sse, httpx_request=stream.request, httpx_response=stream + ).construct() + return response + except BaseFlyMyAIException as e: + raise FlyMyAIPredictException.from_base_exception(e) def _predict(self, client_info, payload: MultipartPayload): """ @@ -110,18 +113,15 @@ def _predict(self, client_info, payload: MultipartPayload): :param payload: model input data :return: FlyMyAIResponse or raise an exception """ - try: - return self._sse_instant( - lambda: self._client.stream( - method="post", - url=client_info.prediction_path, - timeout=_predict_timeout, - **payload.serialize(), - headers=client_info.authorization_headers, - ) + return self._sse_instant( + lambda: self._client.stream( + method="post", + url=client_info.prediction_path, + timeout=_predict_timeout, + **payload.serialize(), + headers=client_info.authorization_headers, ) - except BaseFlyMyAIException as e: - raise FlyMyAIPredictException.from_response(e.response) + ) async def predict( self, payload: dict, model: Optional[str] = None, max_retries=None @@ -160,7 +160,7 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict): httpx_response=sse_stream, ).construct() except BaseFlyMyAIException as e: - raise FlyMyAIPredictException.from_response(e.response) + raise FlyMyAIPredictException.from_base_exception(e) yield response def stream(self, payload: dict, model: Optional[str] = None, max_retries=None): diff --git a/flymyai/core/clients/SyncClient.py b/flymyai/core/clients/SyncClient.py index b40f9c4..32e60c9 100644 --- a/flymyai/core/clients/SyncClient.py +++ b/flymyai/core/clients/SyncClient.py @@ -62,7 +62,7 @@ def _predict(self, payload: MultipartPayload, client_info: APIKeyClientInfo): lambda: self._stream_iterator(client_info, payload, False) ) except BaseFlyMyAIException as e: - raise FlyMyAIPredictException.from_response(e.response) + raise FlyMyAIPredictException.from_base_exception(e) def predict(self, payload: dict, model: Optional[str] = None, max_retries=None): """ @@ -100,7 +100,7 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict): httpx_response=sse_stream, ).construct() except BaseFlyMyAIException as e: - raise FlyMyAIPredictException.from_response(e.response) + raise FlyMyAIPredictException.from_base_exception(e) yield response def stream(self, payload: dict, model: Optional[str] = None): diff --git a/flymyai/core/exceptions.py b/flymyai/core/exceptions.py index 90f919c..f89ca80 100644 --- a/flymyai/core/exceptions.py +++ b/flymyai/core/exceptions.py @@ -1,4 +1,5 @@ -from typing import List +import datetime +from typing import List, Type from ._response import FlyMyAIResponse from .models.error_responses import ( @@ -32,13 +33,20 @@ def from_5xx(cls, response: FlyMyAIResponse): msg = f""" INTERNAL SERVER ERROR ({response.status_code}): REQUEST URL: {response.url}; - """ + Content [0:250]: {response.content.decode()[0:250]} + Timestamp [UTC]: {datetime.datetime.utcnow()} + """ internal_error_mapping = { 500: lambda: cls(msg, False, response=response), 502: lambda: cls(msg, True, response=response), 503: lambda: cls(msg, False, response=response), 504: lambda: cls(msg, True, response=response), 524: lambda: cls(msg, True, response=response), + # unknown issue, probably detected on the client side + 599: lambda: cls(msg, False, response=response), + # broker issues, they are not billed at all + 5000: lambda: cls(msg, False, response=response), + 5320: lambda: cls(msg, True, response=response), } return internal_error_mapping.get( response.status_code, lambda: cls(msg, False) @@ -72,7 +80,10 @@ def __str__(self): return self.msg -class FlyMyAIPredictException(BaseFlyMyAIException): ... +class FlyMyAIPredictException(BaseFlyMyAIException): + @classmethod + def from_base_exception(cls, exception: BaseFlyMyAIException): + return cls(exception.msg, exception.requires_retry, exception.response) class FlyMyAIOpenAPIException(BaseFlyMyAIException): ... diff --git a/flymyai/utils/utils.py b/flymyai/utils/utils.py index 452da33..2219c11 100644 --- a/flymyai/utils/utils.py +++ b/flymyai/utils/utils.py @@ -48,6 +48,8 @@ async def aretryable_callback( continue else: raise exception_group_cls(retries_history) + except Exception as e: + raise e else: exception_gr = exception_group_cls(retries_history) raise exception_gr diff --git a/tests/test_unknown_error_handle.py b/tests/test_unknown_error_handle.py new file mode 100644 index 0000000..53b7e9f --- /dev/null +++ b/tests/test_unknown_error_handle.py @@ -0,0 +1,247 @@ +import dataclasses +import logging +from typing import Generator, AsyncGenerator, Union, Any + +import pytest + +from flymyai import FlyMyAIExceptionGroup +from flymyai.core.clients.AsyncClient import BaseAsyncClient +from flymyai.core.clients.SyncClient import BaseSyncClient + + +class MockedStream: + _gen: Union[Generator[bytes, Any, None], AsyncGenerator[bytes, Any]] + _http1_status: int + headers: dict + + class StreamWrapper: + def __init__( + self, + gen: Union[Generator[bytes, Any, None], AsyncGenerator[bytes, Any]], + ): + self.gen = gen + + def __next__(self): + data = next(self.gen) + return data.decode() + + async def __anext__(self): + self.gen: AsyncGenerator[bytes, None] + data = await self.gen.__anext__() + return data.decode() + + def __iter__(self): + return self + + def __aiter__(self): + return self + + @dataclasses.dataclass + class Request: + url: str + + def __init__( + self, + data_generator: Union[Generator[bytes, Any, None], AsyncGenerator[bytes, Any]], + status_code: int, + ): + self._gen = data_generator + self.http1_status = status_code + self.headers = {} + + def iter_lines(self): + return self.StreamWrapper(self._gen) + + def aiter_lines(self): + return self.StreamWrapper(self._gen) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + logging.critical("MockedStream __exit__") + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + logging.critical("MockedStream __aexit__") + + @property + def request(self) -> Request: + return self.Request("AnyStr") + + @property + def status_code(self) -> int: + return self.http1_status + + +@pytest.fixture +def mock_SyncStreamIteratorWith_200_Details(): + client = BaseSyncClient("123", "123/123", max_retries=3) + + def mocked_stream(*args, **kwargs): + yield b'data: {"details": "Correct", "status": 200}' + yield b"" + + client._client.stream = lambda *_, **__: MockedStream(mocked_stream(), 200) + return client + + +@pytest.fixture +def mock_SyncStreamIteratorWith_5000(): + client = BaseSyncClient("123", "123/123", max_retries=3) + + def mocked_stream(*args, **kwargs): + yield b'data: {"details": "Unexpected broker error! Contact support!", "status": 5000}' + yield b"" + + client._client.stream = lambda *_, **__: MockedStream(mocked_stream(), 200) + return client + + +@pytest.fixture +def mock_SyncStreamIteratorWith_5320(): + client = BaseSyncClient("123", "123/123", max_retries=3) + + def mocked_stream(*args, **kwargs): + yield b'data: {"details": "Broker is down. Try again later!", "status": 5320}' + yield b"" + + client._client.stream = lambda *_, **__: MockedStream(mocked_stream(), 200) + return client + + +@pytest.fixture +def mock_AsyncStreamIteratorWith_200_Details(): + client = BaseAsyncClient("123", "123/123", max_retries=3) + + async def mocked_stream(*args, **kwargs): + yield b'data: {"details": "Correct", "status": 200}' + yield b"" + + client._client.stream = lambda *_, **__: MockedStream(mocked_stream(), 200) + return client + + +@pytest.fixture +def mock_AsyncStreamIteratorWith_5000(): + client = BaseAsyncClient("123", "123/123", max_retries=3) + + async def mocked_stream(*args, **kwargs): + yield b'data: {"details": "Unexpected broker error! Contact support!", "status": 5000}' + yield b"" + + client._client.stream = lambda *_, **__: MockedStream(mocked_stream(), 200) + return client + + +@pytest.fixture +def mock_AsyncStreamIteratorWith_5320(): + client = BaseAsyncClient("123", "123/123", max_retries=3) + + async def mocked_stream(*args, **kwargs): + yield b'data: {"details": "Broker is down. Try again later!", "status": 5320}' + yield b"" + + client._client.stream = lambda *_, **__: MockedStream(mocked_stream(), 200) + return client + + +def test_sync_unknown_error_handle_predict(mock_SyncStreamIteratorWith_200_Details): + client = mock_SyncStreamIteratorWith_200_Details + exc = None + with pytest.raises(FlyMyAIExceptionGroup): + try: + result = client.predict({}) + except FlyMyAIExceptionGroup as e: + exc = e + raise e + raise Exception(f"Should not reach this code: {result}") + assert len(exc.errors) == 1 + assert exc.errors[0].response.status_code == 599 + assert "Timestamp" in str(exc) + + +def test_sync_broker_unknown_predict(mock_SyncStreamIteratorWith_5000): + client = mock_SyncStreamIteratorWith_5000 + exc = None + with pytest.raises(FlyMyAIExceptionGroup): + try: + result = client.predict({}) + except FlyMyAIExceptionGroup as e: + exc = e + raise e + raise Exception(f"Should not reach this code: {result}") + assert len(exc.errors) == 1 + assert exc.errors[0].response.status_code == 5000 + assert "Timestamp" in str(exc) + + +def test_sync_broker_disconnected_error_handle_predict( + mock_SyncStreamIteratorWith_5320, +): + client = mock_SyncStreamIteratorWith_5320 + exc = None + with pytest.raises(FlyMyAIExceptionGroup): + try: + result = client.predict({}) + except FlyMyAIExceptionGroup as e: + exc = e + raise e + raise Exception(f"Should not reach this code: {result}") + assert len(exc.errors) == client.max_retries + assert exc.errors[0].response.status_code == 5320 + assert "Timestamp" in str(exc) + + +@pytest.mark.asyncio +async def test_async_unknown_error_handle_predict( + mock_AsyncStreamIteratorWith_200_Details, +): + client = mock_AsyncStreamIteratorWith_200_Details + exc = None + with pytest.raises(FlyMyAIExceptionGroup): + try: + result = await client.predict({}) + except FlyMyAIExceptionGroup as e: + exc = e + raise e + raise Exception(f"Should not reach this code: {result}") + assert len(exc.errors) == 1 + assert exc.errors[0].response.status_code == 599 + assert "Timestamp" in str(exc) + + +@pytest.mark.asyncio +async def test_async_broker_unknown_predict(mock_AsyncStreamIteratorWith_5000): + client = mock_AsyncStreamIteratorWith_5000 + exc = None + with pytest.raises(FlyMyAIExceptionGroup): + try: + result = await client.predict({}) + except FlyMyAIExceptionGroup as e: + exc = e + raise e + raise Exception(f"Should not reach this code: {result}") + assert len(exc.errors) == 1 + assert exc.errors[0].response.status_code == 5000 + assert "Timestamp" in str(exc) + + +@pytest.mark.asyncio +async def test_async_broker_disconnected_error_handle_predict( + mock_AsyncStreamIteratorWith_5320, +): + client = mock_AsyncStreamIteratorWith_5320 + exc = None + with pytest.raises(FlyMyAIExceptionGroup): + try: + result = await client.predict({}) + except FlyMyAIExceptionGroup as e: + exc = e + raise e + raise Exception(f"Should not reach this code: {result}") + assert len(exc.errors) == client.max_retries + assert exc.errors[0].response.status_code == 5320 + assert "Timestamp" in str(exc)