From 702df486695c1dfa930eb45a17d857e6d030bfda Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Fri, 18 Oct 2024 03:11:23 +0300 Subject: [PATCH 1/5] add support for async inference route --- .github/workflows/test.yaml | 26 ++++ .gitignore | 2 +- flymyai/core/authorizations.py | 8 ++ flymyai/core/clients/AsyncClient.py | 58 ++++++++- flymyai/core/clients/SyncClient.py | 59 ++++++++- flymyai/core/clients/base_client.py | 116 ++++++++++++++--- flymyai/core/exceptions.py | 79 ++++++++---- flymyai/core/models/base.py | 16 +++ flymyai/core/models/error_responses.py | 18 +-- flymyai/core/models/successful_responses.py | 108 ++++++++++++++-- flymyai/core/response_factory/__init__.py | 0 .../async_task_result_factory.py | 10 ++ .../response_factory/base_response_factory.py | 42 ++++++ .../plain_inference_response_factory.py} | 27 +--- flymyai/utils/utils.py | 112 +++++++++++----- tests/fixtures/test_async_inference.json | 15 +++ tests/test_async_inference.py | 122 ++++++++++++++++++ 17 files changed, 699 insertions(+), 119 deletions(-) create mode 100644 flymyai/core/models/base.py create mode 100644 flymyai/core/response_factory/__init__.py create mode 100644 flymyai/core/response_factory/async_task_result_factory.py create mode 100644 flymyai/core/response_factory/base_response_factory.py rename flymyai/core/{_response_factory.py => response_factory/plain_inference_response_factory.py} (73%) create mode 100644 tests/fixtures/test_async_inference.json create mode 100644 tests/test_async_inference.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4ed51bc..f600032 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -62,6 +62,7 @@ jobs: test_unknown_error_handle: needs: - lint + - test_fields runs-on: ubuntu-latest @@ -139,3 +140,28 @@ jobs: FMA_APIKEY: ${{ secrets.FMA_APIKEY }} run: pytest tests/test_stream.py --tb=short + test_async_inference: + needs: + - lint + - test_unknown_error_handle + runs-on: ubuntu-latest + + container: + image: python:3.8 + + steps: + - name: Check out git repo + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Fix + run: git config --global --add safe.directory '*' + + - name: Install dependencies + run: pip3 install poetry pytest-asyncio && poetry config virtualenvs.create false && poetry install + + - name: Test + env: + FMA_APIKEY: ${{ secrets.FMA_APIKEY }} + run: pytest tests/test_async_inference.py --tb=short diff --git a/.gitignore b/.gitignore index 3781077..52dce9b 100644 --- a/.gitignore +++ b/.gitignore @@ -142,4 +142,4 @@ dmypy.json .idea/* poetry.lock venv* -tests/fixtures* +tests/fixtures_* diff --git a/flymyai/core/authorizations.py b/flymyai/core/authorizations.py index 748d65b..e2d303e 100644 --- a/flymyai/core/authorizations.py +++ b/flymyai/core/authorizations.py @@ -52,6 +52,14 @@ def _project_path(self): def prediction_path(self): return self._project_path.join(httpx.URL("predict")) + @property + def prediction_async_path(self): + return self._project_path.join(httpx.URL("predict/async/")) + + @property + def prediction_result_path(self): + return self._project_path.join(httpx.URL("predict/async/result/")) + @property def prediction_cancel_path(self): return self._project_path.join(httpx.URL("predict/cancel/")) diff --git a/flymyai/core/clients/AsyncClient.py b/flymyai/core/clients/AsyncClient.py index 7145016..840fd34 100644 --- a/flymyai/core/clients/AsyncClient.py +++ b/flymyai/core/clients/AsyncClient.py @@ -3,7 +3,9 @@ import httpx -from flymyai.core._response_factory import ResponseFactory +from flymyai.core.response_factory.plain_inference_response_factory import ( + SSEInferenceResponseFactory, +) from flymyai.core._streaming import SSEDecoder from flymyai.core.authorizations import APIKeyClientInfo from flymyai.core.clients.base_client import BaseClient, _predict_timeout @@ -12,10 +14,12 @@ FlyMyAIOpenAPIException, FlyMyAIPredictException, FlyMyAIExceptionGroup, + FlyMyAIAsyncTaskException, ) from flymyai.core.models.successful_responses import ( OpenAPISchemaResponse, PredictionResponse, + AsyncPredictionTask, ) from flymyai.core.stream_iterators.AsyncPredictionStream import AsyncPredictionStream from flymyai.multipart import MultipartPayload @@ -28,6 +32,7 @@ def _construct_client(self): http2=True, headers=self.client_info.authorization_headers, base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"), + timeout=_predict_timeout, ) async def __aenter__(self): @@ -84,7 +89,7 @@ async def cancel_prediction( url=full_client_info.prediction_cancel_path, json={"infer_id": prediction_id}, ) - return ResponseFactory( + return SSEInferenceResponseFactory( httpx_response=response, httpx_request=response.request ).construct() @@ -100,7 +105,7 @@ async def _sse_instant( async with async_response_stream() as stream: sse = await SSEDecoder().aiter(stream.aiter_lines()).__anext__() try: - response = ResponseFactory( + response = SSEInferenceResponseFactory( sse=sse, httpx_request=stream.request, httpx_response=stream ).construct() return response @@ -145,6 +150,49 @@ async def predict( ) return PredictionResponse.from_response(response, exc_history=history) + async def predict_async_task( + self, payload: dict, model: Optional[str] = None, max_retries=None + ) -> AsyncPredictionTask: + payload = MultipartPayload(input_data=payload) + client_info = self.amend_client_info(model) + try: + _, response = await aretryable_callback( + lambda: self._client.post( + client_info.prediction_async_path, **payload.serialize() + ), + max_retries or self.max_retries, + FlyMyAIAsyncTaskException, + FlyMyAIExceptionGroup, + ) + response = SSEInferenceResponseFactory(response).construct() + return self._async_prediction_task_construct(response, client_info) + except BaseFlyMyAIException as e: + raise FlyMyAIAsyncTaskException.from_base_exception(e) + + async def prediction_task_result( + self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None + ): + prediction_id = prediction_task.prediction_id + + async def get_res(): + data_resp = await self._client.get( + url=( + prediction_task.client_info or self.client_info + ).prediction_result_path, + params={"request_id": prediction_id}, + ) + return self._construct_task_result(data_resp) + + _, res = await aretryable_callback( + lambda: get_res(), + None, + FlyMyAIAsyncTaskException, + FlyMyAIExceptionGroup, + timeout, + 0.5, + ) + return res + async def _stream(self, client_info: APIKeyClientInfo, payload: dict): payload = MultipartPayload(payload) stream_iterator = self._stream_iterator( @@ -154,7 +202,7 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict): async with stream_iterator as sse_stream: async for sse_partial in decoder.aiter(sse_stream.aiter_lines()): try: - response = ResponseFactory( + response = SSEInferenceResponseFactory( sse=sse_partial, httpx_request=sse_stream.request, httpx_response=sse_stream, @@ -175,7 +223,7 @@ async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response Execute a request callback and return the response """ response = await request_callback() - return ResponseFactory(httpx_response=response).construct() + return SSEInferenceResponseFactory(httpx_response=response).construct() async def close(self): """ diff --git a/flymyai/core/clients/SyncClient.py b/flymyai/core/clients/SyncClient.py index 32e60c9..ff6a1a3 100644 --- a/flymyai/core/clients/SyncClient.py +++ b/flymyai/core/clients/SyncClient.py @@ -3,19 +3,23 @@ import httpx -from flymyai.core._response_factory import ResponseFactory from flymyai.core._streaming import SSEDecoder from flymyai.core.authorizations import APIKeyClientInfo -from flymyai.core.clients.base_client import BaseClient +from flymyai.core.clients.base_client import BaseClient, _predict_timeout from flymyai.core.exceptions import ( BaseFlyMyAIException, FlyMyAIOpenAPIException, FlyMyAIPredictException, FlyMyAIExceptionGroup, + FlyMyAIAsyncTaskException, ) from flymyai.core.models.successful_responses import ( PredictionResponse, OpenAPISchemaResponse, + AsyncPredictionTask, +) +from flymyai.core.response_factory.plain_inference_response_factory import ( + SSEInferenceResponseFactory, ) from flymyai.core.stream_iterators.PredictionStream import PredictionStream from flymyai.multipart import MultipartPayload @@ -28,6 +32,7 @@ def _construct_client(self): http2=True, headers=self.client_info.authorization_headers, base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"), + timeout=_predict_timeout, ) def __enter__(self): @@ -45,7 +50,7 @@ def _sse_instant(cls, stream_iter_func: Callable[[], Iterator[httpx.Response]]): """ with stream_iter_func() as stream: stream: httpx.Response - response = ResponseFactory( + response = SSEInferenceResponseFactory( sse=next(SSEDecoder().iter(stream.iter_lines())), httpx_request=stream.request, httpx_response=stream, @@ -85,6 +90,50 @@ def predict(self, payload: dict, model: Optional[str] = None, max_retries=None): ) return PredictionResponse.from_response(response, exc_history=history) + def predict_async_task( + self, payload: dict, model: Optional[str] = None, max_retries=None + ): + payload = MultipartPayload(input_data=payload) + client_info = self.amend_client_info(model) + try: + _, response = retryable_callback( + lambda: self._client.post( + client_info.prediction_async_path, **payload.serialize() + ), + max_retries or self.max_retries, + FlyMyAIAsyncTaskException, + FlyMyAIExceptionGroup, + ) + response = SSEInferenceResponseFactory(response).construct() + return self._async_prediction_task_construct(response, client_info) + except BaseFlyMyAIException as e: + raise FlyMyAIAsyncTaskException.from_base_exception(e) + + def prediction_task_result( + self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None + ): + prediction_id = prediction_task.prediction_id + + def get_res(): + resp = self._client.get( + url=( + prediction_task.client_info or self.client_info + ).prediction_result_path, + params={"request_id": prediction_id}, + ) + return self._construct_task_result(resp) + + _, res = retryable_callback( + lambda: get_res(), + None, + FlyMyAIAsyncTaskException, + FlyMyAIExceptionGroup, + timeout, + 0.5, + ) + + return res + def _stream(self, client_info: APIKeyClientInfo, payload: dict): payload = MultipartPayload(payload) response_iterator = self._stream_iterator( @@ -94,7 +143,7 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict): with response_iterator as sse_stream: for sse_partial in decoder.iter(sse_stream.iter_lines()): try: - response = ResponseFactory( + response = SSEInferenceResponseFactory( sse=sse_partial, httpx_request=sse_stream.request, httpx_response=sse_stream, @@ -138,7 +187,7 @@ def cancel_prediction( url=full_client_info.prediction_cancel_path, json={"infer_id": prediction_id}, ) - return ResponseFactory( + return SSEInferenceResponseFactory( httpx_response=response, httpx_request=response.request ).construct() diff --git a/flymyai/core/clients/base_client.py b/flymyai/core/clients/base_client.py index af9cff3..eb8c5db 100644 --- a/flymyai/core/clients/base_client.py +++ b/flymyai/core/clients/base_client.py @@ -6,14 +6,26 @@ ) import httpx +import pydantic -from flymyai.core._response_factory import ResponseFactory +from flymyai.core.response_factory.async_task_result_factory import ( + AsyncTaskResultFactory, +) +from flymyai.core.response_factory.plain_inference_response_factory import ( + SSEInferenceResponseFactory, +) from flymyai.core.authorizations import APIKeyClientInfo -from flymyai.core.exceptions import ImproperlyConfiguredClientException +from flymyai.core.exceptions import ( + ImproperlyConfiguredClientException, + BaseFlyMyAIException, + FlyMyAIAsyncTaskException, +) from flymyai.core.models.successful_responses import ( PredictionResponse, OpenAPISchemaResponse, PredictionPartial, + AsyncPredictionTask, + AsyncPredictionResponseList, ) from flymyai.multipart import MultipartPayload @@ -24,7 +36,12 @@ ) -_predict_timeout = httpx.Timeout(None, connect=10) +_predict_timeout = httpx.Timeout( + connect=int(os.getenv("FMA_CONNECT_TIMEOUT", 999999)), + read=int(os.getenv("FMA_READ_TIMEOUT", 999999)), + write=int(os.getenv("FMA_WRITE_TIMEOUT", 999999)), + pool=int(os.getenv("FMA_POOL_TIMEOUT", 999999)), +) class BaseClient(Generic[_PossibleClients]): @@ -59,50 +76,114 @@ def amend_client_info(self, model: Optional[str] = None): @overload async def predict( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: ... + ) -> PredictionResponse: + ... @overload def predict( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: ... + ) -> PredictionResponse: + ... def predict( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: ... + ) -> PredictionResponse: + ... + + @overload + async def predict_async_task( + self, payload: dict, model: Optional[str] = None, max_retries=None + ) -> AsyncPredictionTask: + ... + + @overload + def predict_async_task( + self, payload: dict, model: Optional[str] = None, max_retries=None + ) -> AsyncPredictionTask: + ... + + def predict_async_task( + self, payload: dict, model: Optional[str] = None, max_retries=None + ) -> AsyncPredictionTask: + ... + + @classmethod + def _construct_task_result(cls, response): + try: + validated = AsyncTaskResultFactory(httpx_response=response).construct() + except BaseFlyMyAIException as e: + raise FlyMyAIAsyncTaskException.from_base_exception(e) + try: + return AsyncPredictionResponseList.from_response(validated, status=200) + except pydantic.ValidationError as e: + raise AsyncPredictionResponseList.convert_error(e) from e + + def _async_prediction_task_construct( + self, response: httpx.Response, client_info: APIKeyClientInfo + ) -> AsyncPredictionTask: + async_prediction_task = AsyncPredictionTask[self.__class__].model_validate_json( + json_data=response.content + ) + async_prediction_task.client_info = client_info + async_prediction_task.set_client(self) + return async_prediction_task + + @overload + async def prediction_task_result( + self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None + ): + ... + + @overload + def prediction_task_result( + self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None + ): + ... + + def prediction_task_result( + self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None + ): + ... @overload async def openapi_schema( self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: ... + ) -> OpenAPISchemaResponse: + ... @overload def openapi_schema( self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: ... + ) -> OpenAPISchemaResponse: + ... def openapi_schema( self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: ... + ) -> OpenAPISchemaResponse: + ... @overload async def stream( self, payload: dict, model: Optional[str] = None, - ) -> AsyncIterator[PredictionPartial]: ... + ) -> AsyncIterator[PredictionPartial]: + ... @overload def stream( self, payload: dict, model: Optional[str] = None, - ) -> Iterator[PredictionPartial]: ... + ) -> Iterator[PredictionPartial]: + ... def stream( self, payload: dict, model: Optional[str] = None, - ): ... + ): + ... def _stream_iterator( self, client_info, payload: MultipartPayload, is_long_stream: bool @@ -123,7 +204,7 @@ def _stream_iterator( @staticmethod def _wrap_request(request_callback: Callable): response = request_callback() - return ResponseFactory(httpx_response=response).construct() + return SSEInferenceResponseFactory(httpx_response=response).construct() def is_closed(self) -> bool: return self._client.is_closed @@ -148,7 +229,8 @@ async def cancel_prediction( prediction_id: str, model: Optional[str] = None, client_info: APIKeyClientInfo = None, - ): ... + ): + ... @overload def cancel_prediction( @@ -156,11 +238,13 @@ def cancel_prediction( prediction_id: str, model: Optional[str] = None, client_info: APIKeyClientInfo = None, - ): ... + ): + ... def cancel_prediction( self, prediction_id: str, model: Optional[str] = None, client_info: APIKeyClientInfo = None, - ): ... + ): + ... diff --git a/flymyai/core/exceptions.py b/flymyai/core/exceptions.py index f89ca80..cc3ccef 100644 --- a/flymyai/core/exceptions.py +++ b/flymyai/core/exceptions.py @@ -1,23 +1,30 @@ import datetime -from typing import List, Type +from typing import List, Union from ._response import FlyMyAIResponse +from .models.base import ResponseLike from .models.error_responses import ( FlyMyAI401Response, FlyMyAI422Response, Base4xxResponse, FlyMyAI400Response, FlyMyAI421Response, + FlyMyAI425Response, ) -class ImproperlyConfiguredClientException(Exception): ... +class RetryTimeoutExceededException(TimeoutError): + ... + + +class ImproperlyConfiguredClientException(Exception): + ... class BaseFlyMyAIException(Exception): msg: str requires_retry: bool - _response: FlyMyAIResponse + _response: Union[FlyMyAIResponse, ResponseLike] def __init__(self, msg, requires_retry=False, response=None): self.msg = msg @@ -29,37 +36,46 @@ def response(self): return self._response @classmethod - def from_5xx(cls, response: FlyMyAIResponse): + def internal_error_mapping(cls): + return { + 500: lambda msg, response: cls(msg, False, response=response), + 502: lambda msg, response: cls(msg, True, response=response), + 503: lambda msg, response: cls(msg, False, response=response), + 504: lambda msg, response: cls(msg, True, response=response), + 524: lambda msg, response: cls(msg, True, response=response), + # unknown issue, probably detected on the client side + 599: lambda msg, response: cls(msg, False, response=response), + # broker issues, they are not billed at all + 5000: lambda msg, response: cls(msg, False, response=response), + 5320: lambda msg, response: cls(msg, True, response=response), + } + + @classmethod + def from_5xx(cls, response: Union[FlyMyAIResponse, ResponseLike]): msg = f""" INTERNAL SERVER ERROR ({response.status_code}): REQUEST URL: {response.url}; Content [0:250]: {response.content.decode()[0:250]} Timestamp [UTC]: {datetime.datetime.utcnow()} """ - internal_error_mapping = { - 500: lambda: cls(msg, False, response=response), - 502: lambda: cls(msg, True, response=response), - 503: lambda: cls(msg, False, response=response), - 504: lambda: cls(msg, True, response=response), - 524: lambda: cls(msg, True, response=response), - # unknown issue, probably detected on the client side - 599: lambda: cls(msg, False, response=response), - # broker issues, they are not billed at all - 5000: lambda: cls(msg, False, response=response), - 5320: lambda: cls(msg, True, response=response), - } - return internal_error_mapping.get( - response.status_code, lambda: cls(msg, False) - )() + return cls.internal_error_mapping().get( + response.status_code, lambda m, _: cls(msg, False) + )(msg, response) @classmethod - def from_4xx(cls, response: FlyMyAIResponse): - response_validation_templates = { + def client_error_mapping(cls): + return { 400: FlyMyAI400Response, 401: FlyMyAI401Response, 421: FlyMyAI421Response, 422: FlyMyAI422Response, + # requested too early + 425: FlyMyAI425Response, } + + @classmethod + def from_4xx(cls, response: Union[FlyMyAIResponse, ResponseLike]): + response_validation_templates = cls.client_error_mapping() response_4xx = response_validation_templates.get( response.status_code, Base4xxResponse ).from_response(response) @@ -70,7 +86,7 @@ def from_4xx(cls, response: FlyMyAIResponse): ) @classmethod - def from_response(cls, response: FlyMyAIResponse): + def from_response(cls, response: Union[FlyMyAIResponse, ResponseLike]): if 400 <= response.status_code < 500: return cls.from_4xx(response) if response.status_code >= 500: @@ -86,11 +102,16 @@ def from_base_exception(cls, exception: BaseFlyMyAIException): return cls(exception.msg, exception.requires_retry, exception.response) -class FlyMyAIOpenAPIException(BaseFlyMyAIException): ... +class FlyMyAIAsyncTaskException(FlyMyAIPredictException): + ... + + +class FlyMyAIOpenAPIException(BaseFlyMyAIException): + ... class FlyMyAIExceptionGroup(Exception): - def __init__(self, errors: List[BaseFlyMyAIException], **kwargs): + def __init__(self, errors: List[Exception], **kwargs): self.errors = errors exceptions_message = ";".join([str(err) for err in errors]) self.message = f"FlyMyAI exception history: {exceptions_message}" @@ -98,3 +119,13 @@ def __init__(self, errors: List[BaseFlyMyAIException], **kwargs): def __str__(self): return self.message + + def fma_errors(self): + return list( + filter(lambda err: isinstance(err, BaseFlyMyAIException), self.errors) + ) + + def non_fma_errors(self): + return list( + filter(lambda err: not isinstance(err, BaseFlyMyAIException), self.errors) + ) diff --git a/flymyai/core/models/base.py b/flymyai/core/models/base.py new file mode 100644 index 0000000..fe55d17 --- /dev/null +++ b/flymyai/core/models/base.py @@ -0,0 +1,16 @@ +import dataclasses + +import httpx + + +@dataclasses.dataclass +class ResponseLike: + status_code: int + url: httpx.URL + content: bytes + + def to_msg(self): + return f""" + BAD REQUEST DETECTED ({self.status_code}): + REQUEST URL: {self.url}; + """ diff --git a/flymyai/core/models/error_responses.py b/flymyai/core/models/error_responses.py index 1b29e67..193fa0f 100644 --- a/flymyai/core/models/error_responses.py +++ b/flymyai/core/models/error_responses.py @@ -1,11 +1,14 @@ import dataclasses import json +from typing import Union import httpx +from flymyai.core.models.base import ResponseLike + @dataclasses.dataclass -class Base4xxResponse: +class Base4xxResponse(ResponseLike): """ Base class for all 4xx """ @@ -16,14 +19,8 @@ class Base4xxResponse: requires_retry: bool = False - def to_msg(self): - return f""" - BAD REQUEST DETECTED ({self.status_code}): - REQUEST URL: {self.url}; - """ - @classmethod - def from_response(cls, response: httpx.Response): + def from_response(cls, response: Union[httpx.Response, ResponseLike]): return cls(response.status_code, response.url, response.content) @@ -84,3 +81,8 @@ def to_msg(self): if detail := jsoned.get("detail"): msg += f"Details: {detail}" return msg + + +@dataclasses.dataclass +class FlyMyAI425Response(Base4xxResponse): + requires_retry: bool = True diff --git a/flymyai/core/models/successful_responses.py b/flymyai/core/models/successful_responses.py index f645782..4f3023e 100644 --- a/flymyai/core/models/successful_responses.py +++ b/flymyai/core/models/successful_responses.py @@ -1,14 +1,21 @@ -from typing import Optional +from typing import Optional, Generic, TypeVar, List, TypedDict, Union, Awaitable import pydantic -from pydantic import PrivateAttr +from pydantic import PrivateAttr, model_validator, Field +from pydantic_core._pydantic_core import PydanticCustomError +from typing_extensions import Self from flymyai.core._response import FlyMyAIResponse +from flymyai.core.authorizations import APIKeyClientInfo +from flymyai.core.exceptions import BaseFlyMyAIException, FlyMyAIExceptionGroup +from flymyai.core.models.base import ResponseLike from flymyai.core.types.event_types import EventType +_ClientT = TypeVar("_ClientT", bound="BaseClient") + class BaseFromServer(pydantic.BaseModel): - _response: FlyMyAIResponse = PrivateAttr() + _response: Optional[FlyMyAIResponse] = PrivateAttr(default=None) @property def response(self): @@ -19,20 +26,23 @@ def from_response(cls, response: FlyMyAIResponse, **kwargs): status_code = kwargs.pop("status", response.status_code) response_json = response.json() response_json["status"] = response_json.get("status", status_code) - self = cls(**response_json, **kwargs) + ctx = kwargs.pop("context", None) + self = cls.model_validate(dict(**response_json, **kwargs), context=ctx) self._response = response return self -class PredictionResponse(BaseFromServer): +class BasePredictionResponse(BaseFromServer): + exc_history: Optional[list] = Field(default_factory=list) + output_data: dict + + +class PredictionResponse(BasePredictionResponse): """ Prediction response from FlyMyAI """ - exc_history: Optional[list] - output_data: dict status: int - inference_time: Optional[float] = None @property @@ -40,6 +50,88 @@ def response(self): return self._response +class AsyncPredictionResponse(BasePredictionResponse): + infer_details: dict + output_data: dict = Field(validation_alias="response") + + @property + def status(self) -> int: + return self.infer_details.get("status", 200) + + @model_validator(mode="after") + def validate_inference_details(self, ctx): + self._response = ctx.context.get("_response") + if (status := self.infer_details.get("status", 200)) != 200: + raise PydanticCustomError( + "inference_response_error", + "Inference details contains incorrect status: {failure_status}", + dict(failure_status=status, instance=self), + ) + return self + + +class AsyncPredictionResponseList(BaseFromServer): + inference_responses: List[AsyncPredictionResponse] + + @classmethod + def from_response(cls, response: FlyMyAIResponse, **kwargs): + result = super().from_response( + response, context={"_response": response}, **kwargs + ) + return result + + class _ErrorCTX(TypedDict): + failure_status: int + instance: Self + + @classmethod + def convert_error(cls, e: pydantic.ValidationError) -> FlyMyAIExceptionGroup: + errors = [] + error_data = e.errors(include_input=False) + for error in error_data: + err_t = error.get("type") + ctx: Optional[cls._ErrorCTX] = error.get("ctx", {}) + if err_t == "inference_response_error": + if not ctx: + raise KeyError("Inference response pydantic error should have ctx!") + errors.append( + BaseFlyMyAIException.from_response( + ResponseLike( + status_code=ctx["failure_status"], + url=getattr(ctx["instance"], "_response").url, + content=ctx["instance"].model_dump_json().encode(), + ) + ) + ) + if len(errors) != error_data: + errors.append(e) + exc_group = FlyMyAIExceptionGroup(errors) + return exc_group + + +class AsyncPredictionTask(Generic[_ClientT], BaseFromServer): + _affiliated_client: Optional[_ClientT] = PrivateAttr(default=None) + _client_info: APIKeyClientInfo = PrivateAttr(default=None) + + prediction_id: str + + def result( + self, timeout=None + ) -> Union[AsyncPredictionResponseList, Awaitable[AsyncPredictionResponseList]]: + return self._affiliated_client.prediction_task_result(self, timeout=timeout) + + @property + def client_info(self) -> APIKeyClientInfo: + return self._client_info + + @client_info.setter + def client_info(self, v: APIKeyClientInfo): + self._client_info = v + + def set_client(self, client: _ClientT): + self._affiliated_client = client + + class OpenAPISchemaResponse(BaseFromServer): """ OpenAPI schema for the current project. Use it to construct your own schema diff --git a/flymyai/core/response_factory/__init__.py b/flymyai/core/response_factory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flymyai/core/response_factory/async_task_result_factory.py b/flymyai/core/response_factory/async_task_result_factory.py new file mode 100644 index 0000000..4fb26d0 --- /dev/null +++ b/flymyai/core/response_factory/async_task_result_factory.py @@ -0,0 +1,10 @@ +from flymyai.core.response_factory.base_response_factory import ResponseFactory + + +class MaybeNotExistent(Exception): + ... + + +class AsyncTaskResultFactory(ResponseFactory): + def construct(self): + return self._base_construct_from_httpx_response() diff --git a/flymyai/core/response_factory/base_response_factory.py b/flymyai/core/response_factory/base_response_factory.py new file mode 100644 index 0000000..862785b --- /dev/null +++ b/flymyai/core/response_factory/base_response_factory.py @@ -0,0 +1,42 @@ +from abc import abstractmethod + +import httpx + +from flymyai.core._response import FlyMyAIResponse +from flymyai.core._streaming import ServerSentEvent +from flymyai.core.exceptions import BaseFlyMyAIException + + +class ResponseFactoryException(Exception): + ... + + +class ResponseFactory(object): + """ + Factory for FlyMyAIResponse objects + """ + + def __init__( + self, + httpx_response: httpx.Response = None, + httpx_request: httpx.Request = None, + *_, + **__ + ): + if httpx_response and not httpx_request: + self.httpx_request = httpx_response.request + else: + self.httpx_request = httpx_request + self.httpx_response = httpx_response + + def _base_construct_from_httpx_response(self): + if self.httpx_response.status_code < 400: + return FlyMyAIResponse.from_httpx(self.httpx_response) + else: + raise BaseFlyMyAIException.from_response( + FlyMyAIResponse.from_httpx(self.httpx_response) + ) + + @abstractmethod + def construct(self): + raise NotImplementedError diff --git a/flymyai/core/_response_factory.py b/flymyai/core/response_factory/plain_inference_response_factory.py similarity index 73% rename from flymyai/core/_response_factory.py rename to flymyai/core/response_factory/plain_inference_response_factory.py index f7feab5..988c21a 100644 --- a/flymyai/core/_response_factory.py +++ b/flymyai/core/response_factory/plain_inference_response_factory.py @@ -3,16 +3,13 @@ from flymyai.core._response import FlyMyAIResponse from flymyai.core._streaming import ServerSentEvent from flymyai.core.exceptions import BaseFlyMyAIException +from flymyai.core.response_factory.base_response_factory import ( + ResponseFactory, + ResponseFactoryException, +) -class ResponseFactoryException(Exception): ... - - -class ResponseFactory(object): - """ - Factory for FlyMyAIResponse objects - """ - +class SSEInferenceResponseFactory(ResponseFactory): def __init__( self, httpx_response: httpx.Response = None, @@ -21,12 +18,8 @@ def __init__( ): if not httpx_response and not sse: raise ResponseFactoryException("httpx_response and sse params required") - if httpx_response and not httpx_request: - self.httpx_request = httpx_response.request - else: - self.httpx_request = httpx_request + super().__init__(httpx_response, httpx_request) self.sse = sse - self.httpx_response = httpx_response def get_sse_status_code(self): return self.sse.json().get( @@ -58,14 +51,6 @@ def _base_construct_from_sse(self): ) ) - def _base_construct_from_httpx_response(self): - if self.httpx_response.status_code < 400: - return FlyMyAIResponse.from_httpx(self.httpx_response) - else: - raise BaseFlyMyAIException.from_response( - FlyMyAIResponse.from_httpx(self.httpx_response) - ) - def construct(self): if self.sse: return self._base_construct_from_sse() diff --git a/flymyai/utils/utils.py b/flymyai/utils/utils.py index 2219c11..1e5c162 100644 --- a/flymyai/utils/utils.py +++ b/flymyai/utils/utils.py @@ -1,31 +1,67 @@ -from typing import Callable, Awaitable, Type +import asyncio +import threading +import time +from typing import Callable, Awaitable, Type, Optional import httpx +from flymyai.core.exceptions import RetryTimeoutExceededException + def retryable_callback( cb: Callable, - retries: int, + retries: Optional[int], append_on_exception_cls: Type[Exception], exception_group_cls: Type[Exception], + timeout_seconds: Optional[float] = None, + await_treshold: Optional[float] = None, ): """ Decorator to retry a function """ - retries_history = [] - for _ in range(retries): - try: - res = cb() - return retries_history, res - except append_on_exception_cls as e: - retries_history.append(e) - if e.requires_retry: - continue - else: - raise exception_group_cls(retries_history) - else: + + should_stop = False + result_container = None + exc_container = None + + def wrapper(): + nonlocal should_stop, result_container, exc_container + retries_history = [] + r = 0 + while r != retries: + try: + res = cb() + result_container = retries_history, res + return + except append_on_exception_cls as e: + retries_history.append(e) + if e.requires_retry and not should_stop: + time.sleep(await_treshold) + r += 1 + continue + else: + exc_container = exception_group_cls(retries_history) + return + except exception_group_cls as e1: + exc_container = e1 + return + except Exception as e2: + exc_container = e2 + return exception_gr = exception_group_cls(retries_history) - raise exception_gr + exc_container = exception_gr + + waiting_thread = threading.Thread(target=wrapper) + waiting_thread.start() + if timeout_seconds is not None: + timeout_seconds += 0.01 + waiting_thread.join(timeout=timeout_seconds) + should_stop = True + if not result_container and not exc_container: + raise RetryTimeoutExceededException() + if exc_container: + raise exc_container + return result_container async def aretryable_callback( @@ -33,23 +69,37 @@ async def aretryable_callback( retries, append_on_exception_cls: Type[Exception], exception_group_cls: Type[Exception], + timeout_seconds: Optional[float] = None, + await_treshold: Optional[float] = None, ): """ Decorator to retry a function """ - retries_history = [] - for _ in range(retries): - try: - res = await cb() - return retries_history, res - except append_on_exception_cls as e: - retries_history.append(e) - if e.requires_retry: - continue - else: - raise exception_group_cls(retries_history) - except Exception as e: - raise e - else: - exception_gr = exception_group_cls(retries_history) - raise exception_gr + if timeout_seconds is not None: + timeout_seconds += 0.01 + + async def wrapper(): + retries_history = [] + r = 0 + while r != retries: + try: + res = await cb() + return retries_history, res + except append_on_exception_cls as e1: + retries_history.append(e1) + if e1.requires_retry: + await asyncio.sleep(await_treshold) + r += 1 + continue + else: + raise exception_group_cls(retries_history) + except Exception as e2: + raise e2 + else: + exception_gr = exception_group_cls(retries_history) + raise exception_gr + + try: + return await asyncio.wait_for(wrapper(), timeout_seconds) + except asyncio.TimeoutError as e: + raise RetryTimeoutExceededException() from e diff --git a/tests/fixtures/test_async_inference.json b/tests/fixtures/test_async_inference.json new file mode 100644 index 0000000..b81a009 --- /dev/null +++ b/tests/fixtures/test_async_inference.json @@ -0,0 +1,15 @@ +{ + "client_auth_fixture": { + "model": "flymyai/perf_cpn-1", + "apikey_environ": "FMA_APIKEY" + }, + "address_fixture": "https://dev-api.flymy.ai", + "fake_payload_fixture": { + "prompt": "2w432442" + }, + "binary_file_paths": {}, + "broken_payload_fixture": { + "prompt": "2w432442", + "delay": -1 + } +} \ No newline at end of file diff --git a/tests/test_async_inference.py b/tests/test_async_inference.py new file mode 100644 index 0000000..98c7080 --- /dev/null +++ b/tests/test_async_inference.py @@ -0,0 +1,122 @@ +import os +import pathlib + +import pytest + +from flymyai.core.exceptions import ( + RetryTimeoutExceededException, + BaseFlyMyAIException, + FlyMyAIExceptionGroup, +) +from flymyai.core.models.successful_responses import ( + AsyncPredictionResponseList, + AsyncPredictionTask, +) +from .FixtureFactory import FixtureFactory +from flymyai import client as sync_client, async_client + +factory = FixtureFactory(__file__) + + +@pytest.fixture +def address_fixture(): + os.environ["FLYMYAI_DSN"] = factory("address_fixture") + + +@pytest.fixture +def binary_file_paths(): + return factory("binary_file_paths") + + +@pytest.fixture +def fake_payload_fixture(binary_file_paths) -> dict: + files = {} + for k, v in binary_file_paths.items(): + files[k] = pathlib.Path(v) + payload = factory("fake_payload_fixture") + payload.update(files) + return payload + + +@pytest.fixture +def broken_payload_fixture(binary_file_paths) -> dict: + files = {} + for k, v in binary_file_paths.items(): + files[k] = pathlib.Path(v) + payload = factory("broken_payload_fixture") + payload.update(files) + return payload + + +@pytest.fixture +def client_auth_fixture() -> dict: + client_auth_fixture: dict = factory("client_auth_fixture") + auth_apikey_env = client_auth_fixture.pop("apikey_environ", "") + if auth_apikey_env: + client_auth_fixture["apikey"] = os.getenv( + auth_apikey_env, client_auth_fixture.get("apikey") + ) + return client_auth_fixture + + +def test_sync_client_async_inference( + address_fixture, fake_payload_fixture, client_auth_fixture +): + client = sync_client(**client_auth_fixture) + prediction_task = client.predict_async_task(payload=fake_payload_fixture) + assert prediction_task.prediction_id is not None + with pytest.raises(RetryTimeoutExceededException): + res = prediction_task.result(0) + assert res is None # should not achieve this point + res = prediction_task.result() + assert isinstance(res, AsyncPredictionResponseList) + assert all(map(lambda x: x.infer_details["status"] == 200, res.inference_responses)) + + mocked_pred_task = AsyncPredictionTask(prediction_id="123") + mocked_pred_task.set_client(client) + with pytest.raises(FlyMyAIExceptionGroup): + mocked_pred_task.result() + + +@pytest.mark.asyncio +async def test_async_client_async_inference( + address_fixture, fake_payload_fixture, client_auth_fixture +): + client = async_client(**client_auth_fixture) + prediction_task = await client.predict_async_task(payload=fake_payload_fixture) + assert prediction_task.prediction_id is not None + + with pytest.raises(RetryTimeoutExceededException): + res = await prediction_task.result(0) + assert res is None # should not achieve this point + res = await prediction_task.result() + assert isinstance(res, AsyncPredictionResponseList), res + assert all(map(lambda x: x.infer_details["status"] == 200, res.inference_responses)) + + mocked_pred_task = AsyncPredictionTask(prediction_id="123") + mocked_pred_task.set_client(client) + with pytest.raises(FlyMyAIExceptionGroup): + await mocked_pred_task.result() + + +def test_sync_client_async_inference_with_guaranteed_error( + address_fixture, broken_payload_fixture, client_auth_fixture +): + client = sync_client(**client_auth_fixture) + prediction_task = client.predict_async_task(payload=broken_payload_fixture) + assert prediction_task.prediction_id is not None + with pytest.raises(FlyMyAIExceptionGroup): + res = prediction_task.result() + assert res is None # should not achieve this point + + +@pytest.mark.asyncio +async def test_async_client_async_inference_with_guaranteed_error( + address_fixture, broken_payload_fixture, client_auth_fixture +): + client = async_client(**client_auth_fixture) + prediction_task = await client.predict_async_task(payload=broken_payload_fixture) + assert prediction_task.prediction_id is not None + with pytest.raises(FlyMyAIExceptionGroup): + res = await prediction_task.result() + assert res is None # should not achieve this point From a18f6319dc0ea079273f9e15c8f01a7e0c2bdf7b Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Fri, 18 Oct 2024 03:13:40 +0300 Subject: [PATCH 2/5] lint fix --- flymyai/core/clients/base_client.py | 54 +++++++------------ flymyai/core/exceptions.py | 12 ++--- .../async_task_result_factory.py | 3 +- .../response_factory/base_response_factory.py | 3 +- 4 files changed, 24 insertions(+), 48 deletions(-) diff --git a/flymyai/core/clients/base_client.py b/flymyai/core/clients/base_client.py index eb8c5db..fe02874 100644 --- a/flymyai/core/clients/base_client.py +++ b/flymyai/core/clients/base_client.py @@ -76,36 +76,30 @@ def amend_client_info(self, model: Optional[str] = None): @overload async def predict( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: - ... + ) -> PredictionResponse: ... @overload def predict( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: - ... + ) -> PredictionResponse: ... def predict( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: - ... + ) -> PredictionResponse: ... @overload async def predict_async_task( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> AsyncPredictionTask: - ... + ) -> AsyncPredictionTask: ... @overload def predict_async_task( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> AsyncPredictionTask: - ... + ) -> AsyncPredictionTask: ... def predict_async_task( self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> AsyncPredictionTask: - ... + ) -> AsyncPredictionTask: ... @classmethod def _construct_task_result(cls, response): @@ -131,59 +125,50 @@ def _async_prediction_task_construct( @overload async def prediction_task_result( self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None - ): - ... + ): ... @overload def prediction_task_result( self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None - ): - ... + ): ... def prediction_task_result( self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None - ): - ... + ): ... @overload async def openapi_schema( self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: - ... + ) -> OpenAPISchemaResponse: ... @overload def openapi_schema( self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: - ... + ) -> OpenAPISchemaResponse: ... def openapi_schema( self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: - ... + ) -> OpenAPISchemaResponse: ... @overload async def stream( self, payload: dict, model: Optional[str] = None, - ) -> AsyncIterator[PredictionPartial]: - ... + ) -> AsyncIterator[PredictionPartial]: ... @overload def stream( self, payload: dict, model: Optional[str] = None, - ) -> Iterator[PredictionPartial]: - ... + ) -> Iterator[PredictionPartial]: ... def stream( self, payload: dict, model: Optional[str] = None, - ): - ... + ): ... def _stream_iterator( self, client_info, payload: MultipartPayload, is_long_stream: bool @@ -229,8 +214,7 @@ async def cancel_prediction( prediction_id: str, model: Optional[str] = None, client_info: APIKeyClientInfo = None, - ): - ... + ): ... @overload def cancel_prediction( @@ -238,13 +222,11 @@ def cancel_prediction( prediction_id: str, model: Optional[str] = None, client_info: APIKeyClientInfo = None, - ): - ... + ): ... def cancel_prediction( self, prediction_id: str, model: Optional[str] = None, client_info: APIKeyClientInfo = None, - ): - ... + ): ... diff --git a/flymyai/core/exceptions.py b/flymyai/core/exceptions.py index cc3ccef..2e0a13f 100644 --- a/flymyai/core/exceptions.py +++ b/flymyai/core/exceptions.py @@ -13,12 +13,10 @@ ) -class RetryTimeoutExceededException(TimeoutError): - ... +class RetryTimeoutExceededException(TimeoutError): ... -class ImproperlyConfiguredClientException(Exception): - ... +class ImproperlyConfiguredClientException(Exception): ... class BaseFlyMyAIException(Exception): @@ -102,12 +100,10 @@ def from_base_exception(cls, exception: BaseFlyMyAIException): return cls(exception.msg, exception.requires_retry, exception.response) -class FlyMyAIAsyncTaskException(FlyMyAIPredictException): - ... +class FlyMyAIAsyncTaskException(FlyMyAIPredictException): ... -class FlyMyAIOpenAPIException(BaseFlyMyAIException): - ... +class FlyMyAIOpenAPIException(BaseFlyMyAIException): ... class FlyMyAIExceptionGroup(Exception): diff --git a/flymyai/core/response_factory/async_task_result_factory.py b/flymyai/core/response_factory/async_task_result_factory.py index 4fb26d0..675ef02 100644 --- a/flymyai/core/response_factory/async_task_result_factory.py +++ b/flymyai/core/response_factory/async_task_result_factory.py @@ -1,8 +1,7 @@ from flymyai.core.response_factory.base_response_factory import ResponseFactory -class MaybeNotExistent(Exception): - ... +class MaybeNotExistent(Exception): ... class AsyncTaskResultFactory(ResponseFactory): diff --git a/flymyai/core/response_factory/base_response_factory.py b/flymyai/core/response_factory/base_response_factory.py index 862785b..a212d12 100644 --- a/flymyai/core/response_factory/base_response_factory.py +++ b/flymyai/core/response_factory/base_response_factory.py @@ -7,8 +7,7 @@ from flymyai.core.exceptions import BaseFlyMyAIException -class ResponseFactoryException(Exception): - ... +class ResponseFactoryException(Exception): ... class ResponseFactory(object): From 1d6964ecbe1bdf0a6a8cb3c6a1858d37dc4af3b0 Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Fri, 18 Oct 2024 03:18:23 +0300 Subject: [PATCH 3/5] fix --- flymyai/core/models/successful_responses.py | 2 +- flymyai/utils/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flymyai/core/models/successful_responses.py b/flymyai/core/models/successful_responses.py index 4f3023e..482c94f 100644 --- a/flymyai/core/models/successful_responses.py +++ b/flymyai/core/models/successful_responses.py @@ -109,7 +109,7 @@ def convert_error(cls, e: pydantic.ValidationError) -> FlyMyAIExceptionGroup: return exc_group -class AsyncPredictionTask(Generic[_ClientT], BaseFromServer): +class AsyncPredictionTask(BaseFromServer, Generic[_ClientT]): _affiliated_client: Optional[_ClientT] = PrivateAttr(default=None) _client_info: APIKeyClientInfo = PrivateAttr(default=None) diff --git a/flymyai/utils/utils.py b/flymyai/utils/utils.py index 1e5c162..6ba69a0 100644 --- a/flymyai/utils/utils.py +++ b/flymyai/utils/utils.py @@ -36,7 +36,7 @@ def wrapper(): except append_on_exception_cls as e: retries_history.append(e) if e.requires_retry and not should_stop: - time.sleep(await_treshold) + time.sleep(await_treshold or 0) r += 1 continue else: From 2ee3ab28d9aea924181b28e4897e9b95e4bdd3ea Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Fri, 18 Oct 2024 03:23:52 +0300 Subject: [PATCH 4/5] fix sleep --- flymyai/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flymyai/utils/utils.py b/flymyai/utils/utils.py index 6ba69a0..f0177da 100644 --- a/flymyai/utils/utils.py +++ b/flymyai/utils/utils.py @@ -88,7 +88,7 @@ async def wrapper(): except append_on_exception_cls as e1: retries_history.append(e1) if e1.requires_retry: - await asyncio.sleep(await_treshold) + await asyncio.sleep(await_treshold or 0) r += 1 continue else: From 8f7de353b3e3dcd6539970c0e564c916325e7518 Mon Sep 17 00:00:00 2001 From: teith Date: Thu, 7 Nov 2024 22:55:22 +0100 Subject: [PATCH 5/5] Updated doc --- README.md | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/README.md b/README.md index e92a656..d37d102 100644 --- a/README.md +++ b/README.md @@ -233,3 +233,77 @@ asyncio.run(main()) # Continue with other operations while the model runs in the background ``` + +## Asynchronous Prediction Tasks + +For long-running operations, FlyMyAI provides asynchronous prediction tasks. This allows you to submit a task and check its status later, which is useful for handling time-consuming predictions without blocking your application. + +### Using Synchronous Client + +```python +from flymyai import client +from flymyai.core.exceptions import ( + RetryTimeoutExceededException, + FlyMyAIExceptionGroup, +) + +# Initialize client +fma_client = client(apikey="fly-secret-key") + +# Submit async prediction task +prediction_task = fma_client.predict_async_task( + model="flymyai/flux-schnell", + payload={"prompt": "Funny Cat with Stupid Dog"} +) + +try: + # Get result + result = prediction_task.result() + + print(f"Prediction completed: {result.inference_responses}") +except RetryTimeoutExceededException: + print("Prediction is taking longer than expected") +except FlyMyAIExceptionGroup as e: + print(f"Prediction failed: {e}") +``` + +### Using Asynchronous Client + +```python +import asyncio +from flymyai import async_client +from flymyai.core.exceptions import ( + RetryTimeoutExceededException, + FlyMyAIExceptionGroup, +) + +async def run_prediction(): + # Initialize async client + fma_client = async_client(apikey="fly-secret-key") + + # Submit async prediction task + prediction_task = await fma_client.predict_async_task( + model="flymyai/flux-schnell", + payload={"prompt": "Funny Cat with Stupid Dog"} +) + + try: + # Await result with default timeout + result = await prediction_task.result() + print(f"Prediction completed: {result.inference_responses}") + + # Check response status + all_successful = all( + resp.infer_details["status"] == 200 + for resp in result.inference_responses + ) + print(f"All predictions successful: {all_successful}") + + except RetryTimeoutExceededException: + print("Prediction is taking longer than expected") + except FlyMyAIExceptionGroup as e: + print(f"Prediction failed: {e}") + +# Run async function +asyncio.run(run_prediction()) +``` \ No newline at end of file