From db01f74e880323039d3c7cad38b41523a816fff5 Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Mon, 22 Jul 2024 19:38:13 +0300 Subject: [PATCH 1/5] details-stream --- flymyai/core/_client.py | 80 ++++++++++++++++++++++++++++++-------- flymyai/core/_response.py | 10 +++++ flymyai/core/_streaming.py | 2 +- flymyai/core/models.py | 6 +++ tests/test_stream.py | 26 +++++++++---- 5 files changed, 98 insertions(+), 26 deletions(-) diff --git a/flymyai/core/_client.py b/flymyai/core/_client.py index c758fc7..88d469b 100644 --- a/flymyai/core/_client.py +++ b/flymyai/core/_client.py @@ -14,6 +14,7 @@ import httpx +from flymyai.core._response import FlyMyAIResponse from flymyai.core._response_factory import ResponseFactory from flymyai.core._streaming import SSEDecoder from flymyai.core.authorizations import APIKeyClientInfo @@ -28,6 +29,7 @@ PredictionResponse, OpenAPISchemaResponse, PredictionPartial, + StreamDetails, ) from flymyai.multipart.payload import MultipartPayload from flymyai.utils.utils import retryable_callback, aretryable_callback @@ -168,6 +170,34 @@ def _construct_client(self): raise NotImplemented +class PredictionStream: + stream_details: StreamDetails + + def __init__(self, response_iterator: Iterator): + self.response_iterator = response_iterator + + def __iter__(self): + return self + + def __next__(self): + response_end = None + try: + next_resp: FlyMyAIResponse = self.response_iterator.__next__() + response_end = next_resp + return PredictionPartial.from_response(response_end) + except BaseFlyMyAIException as e: + response_end = e.response + raise e + finally: + if not response_end: + raise StopIteration() + stream_details_marshalled = response_end.json().get("stream_details") + if stream_details_marshalled: + self.stream_details = StreamDetails.model_validate( + stream_details_marshalled + ) + + class BaseSyncClient(BaseClient[httpx.Client]): def _construct_client(self): return httpx.Client( @@ -251,13 +281,8 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict): def stream(self, payload: dict, model: Optional[str] = None): stream_iter = self._stream(self.amend_client_info(model), payload) - last_response = None - for response in stream_iter: - response.stream = stream_iter - yield PredictionPartial.from_response(response) - last_response = response - if last_response: - last_response.is_stream_consumed = True + stream_wrapper = PredictionStream(stream_iter) + return stream_wrapper def _openapi_schema(self, client_info: APIKeyClientInfo): """ @@ -307,6 +332,34 @@ def run_predict(cls, apikey: str, model: str, payload: dict): return client.predict(payload) +class AsyncPredictionStream: + stream_details: StreamDetails + + def __init__(self, response_iterator: AsyncIterator): + self.response_iterator = response_iterator + + def __aiter__(self): + return self + + async def __anext__(self): + response_end = None + try: + next_resp: FlyMyAIResponse = await self.response_iterator.__anext__() + response_end = next_resp + return PredictionPartial.from_response(response_end) + except BaseFlyMyAIException as e: + response_end = e.response + raise e + finally: + if not response_end: + raise StopAsyncIteration() + stream_details_marshalled = response_end.json().get("stream_details") + if stream_details_marshalled: + self.stream_details = StreamDetails.model_validate( + stream_details_marshalled + ) + + class BaseAsyncClient(BaseClient[httpx.AsyncClient]): def _construct_client(self): return httpx.AsyncClient( @@ -430,17 +483,10 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict): raise FlyMyAIPredictException.from_response(e.response) yield response - async def stream( - self, payload: dict, model: Optional[str] = None, max_retries=None - ): + def stream(self, payload: dict, model: Optional[str] = None, max_retries=None): stream_iter = self._stream(self.amend_client_info(model), payload) - last_response = None - async for response in stream_iter: - response.stream = stream_iter - yield PredictionPartial.from_response(response) - last_response = response - if last_response: - last_response.is_stream_consumed = True + stream_wrapper = AsyncPredictionStream(stream_iter) + return stream_wrapper @staticmethod async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response]]): diff --git a/flymyai/core/_response.py b/flymyai/core/_response.py index 545dd4d..551f646 100644 --- a/flymyai/core/_response.py +++ b/flymyai/core/_response.py @@ -1,3 +1,6 @@ +import json +import typing + import httpx @@ -10,3 +13,10 @@ def from_httpx(cls, response: httpx.Response): request=response.request, headers=response.headers, ) + + def json(self, **kwargs) -> typing.Any: + if self.content.startswith(b"data"): + trail_content = self.content[: len(b"data")] + return json.loads(trail_content) + else: + return super().json(**kwargs) diff --git a/flymyai/core/_streaming.py b/flymyai/core/_streaming.py index db38ff4..ccef6d6 100644 --- a/flymyai/core/_streaming.py +++ b/flymyai/core/_streaming.py @@ -162,6 +162,6 @@ def decode(self, line: str) -> ServerSentEvent | None: except (TypeError, ValueError): pass else: - pass # Field is ignored. + pass # the field is ignored. return None diff --git a/flymyai/core/models.py b/flymyai/core/models.py index 8e14ac5..e88929e 100644 --- a/flymyai/core/models.py +++ b/flymyai/core/models.py @@ -139,3 +139,9 @@ class PredictionPartial(BaseFromServer): output_data: Optional[dict] = None _response: FlyMyAIResponse = PrivateAttr() + + +class StreamDetails(pydantic.BaseModel): + input_tokens: int + output_tokens: int + model_size_in_billions: float diff --git a/tests/test_stream.py b/tests/test_stream.py index b867301..8adb1ba 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -31,24 +31,34 @@ def output_field(): def test_stream(stream_auth, stream_payload, dsn, output_field): stream_iterator = sync_client(**stream_auth).stream(stream_payload) - for response in stream_iterator: - assert response.status == 200 - assert response.output_data - print(response.output_data[output_field].pop(), end="") - print("\n") + try: + for response in stream_iterator: + assert response.status == 200 + assert response.output_data or hasattr(stream_iterator, "stream_details") + if response.output_data.get(output_field): + print(response.output_data[output_field].pop(), end="") + except Exception as e: + if hasattr(e, "msg"): + print(e) + raise e + finally: + print() + print(stream_iterator.stream_details) @pytest.mark.asyncio async def test_async_stream(stream_auth, stream_payload, dsn, output_field): + stream_iterator = async_client(**stream_auth).stream(stream_payload) try: - stream_iterator = async_client(**stream_auth).stream(stream_payload) async for response in stream_iterator: assert response.status == 200 - assert response.output_data - print(response.output_data[output_field].pop(), end="") + assert response.output_data or hasattr(stream_iterator, "stream_details") + if response.output_data.get(output_field): + print(response.output_data[output_field].pop(), end="") except Exception as e: if hasattr(e, "msg"): print(e) raise e finally: print() + print(stream_iterator.stream_details) From 8466d499261b862a7ebf2f42b55299edd1e5dafc Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Tue, 23 Jul 2024 03:12:25 +0300 Subject: [PATCH 2/5] details-stream --- flymyai/core/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flymyai/core/models.py b/flymyai/core/models.py index e88929e..6a2fdc1 100644 --- a/flymyai/core/models.py +++ b/flymyai/core/models.py @@ -144,4 +144,4 @@ class PredictionPartial(BaseFromServer): class StreamDetails(pydantic.BaseModel): input_tokens: int output_tokens: int - model_size_in_billions: float + size_in_billions: float = pydantic.Field(alias="model_size_in_billions") From 91f7e0ea4d0226df0dcf64205ec40acadd98a360 Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Wed, 7 Aug 2024 22:57:03 +0300 Subject: [PATCH 3/5] stream-cancel logic --- flymyai/core/_client.py | 60 +-------------- flymyai/core/_response.py | 7 +- flymyai/core/_response_factory.py | 4 +- flymyai/core/_streaming.py | 5 +- flymyai/core/models.py | 16 +++- .../stream_iterators/AsyncPredictionStream.py | 74 +++++++++++++++++++ .../core/stream_iterators/PredictionStream.py | 63 ++++++++++++++++ flymyai/core/stream_iterators/__init__.py | 0 flymyai/core/types/__init__.py | 0 flymyai/core/types/event_types.py | 6 ++ tests/test_stream.py | 8 +- 11 files changed, 176 insertions(+), 67 deletions(-) create mode 100644 flymyai/core/stream_iterators/AsyncPredictionStream.py create mode 100644 flymyai/core/stream_iterators/PredictionStream.py create mode 100644 flymyai/core/stream_iterators/__init__.py create mode 100644 flymyai/core/types/__init__.py create mode 100644 flymyai/core/types/event_types.py diff --git a/flymyai/core/_client.py b/flymyai/core/_client.py index 88d469b..403a0ad 100644 --- a/flymyai/core/_client.py +++ b/flymyai/core/_client.py @@ -14,7 +14,6 @@ import httpx -from flymyai.core._response import FlyMyAIResponse from flymyai.core._response_factory import ResponseFactory from flymyai.core._streaming import SSEDecoder from flymyai.core.authorizations import APIKeyClientInfo @@ -29,8 +28,9 @@ PredictionResponse, OpenAPISchemaResponse, PredictionPartial, - StreamDetails, ) +from flymyai.core.stream_iterators.AsyncPredictionStream import AsyncPredictionStream +from flymyai.core.stream_iterators.PredictionStream import PredictionStream from flymyai.multipart.payload import MultipartPayload from flymyai.utils.utils import retryable_callback, aretryable_callback @@ -170,34 +170,6 @@ def _construct_client(self): raise NotImplemented -class PredictionStream: - stream_details: StreamDetails - - def __init__(self, response_iterator: Iterator): - self.response_iterator = response_iterator - - def __iter__(self): - return self - - def __next__(self): - response_end = None - try: - next_resp: FlyMyAIResponse = self.response_iterator.__next__() - response_end = next_resp - return PredictionPartial.from_response(response_end) - except BaseFlyMyAIException as e: - response_end = e.response - raise e - finally: - if not response_end: - raise StopIteration() - stream_details_marshalled = response_end.json().get("stream_details") - if stream_details_marshalled: - self.stream_details = StreamDetails.model_validate( - stream_details_marshalled - ) - - class BaseSyncClient(BaseClient[httpx.Client]): def _construct_client(self): return httpx.Client( @@ -332,34 +304,6 @@ def run_predict(cls, apikey: str, model: str, payload: dict): return client.predict(payload) -class AsyncPredictionStream: - stream_details: StreamDetails - - def __init__(self, response_iterator: AsyncIterator): - self.response_iterator = response_iterator - - def __aiter__(self): - return self - - async def __anext__(self): - response_end = None - try: - next_resp: FlyMyAIResponse = await self.response_iterator.__anext__() - response_end = next_resp - return PredictionPartial.from_response(response_end) - except BaseFlyMyAIException as e: - response_end = e.response - raise e - finally: - if not response_end: - raise StopAsyncIteration() - stream_details_marshalled = response_end.json().get("stream_details") - if stream_details_marshalled: - self.stream_details = StreamDetails.model_validate( - stream_details_marshalled - ) - - class BaseAsyncClient(BaseClient[httpx.AsyncClient]): def _construct_client(self): return httpx.AsyncClient( diff --git a/flymyai/core/_response.py b/flymyai/core/_response.py index 551f646..3a49535 100644 --- a/flymyai/core/_response.py +++ b/flymyai/core/_response.py @@ -5,6 +5,8 @@ class FlyMyAIResponse(httpx.Response): + is_event: bool = False + @classmethod def from_httpx(cls, response: httpx.Response): return cls( @@ -16,7 +18,8 @@ def from_httpx(cls, response: httpx.Response): def json(self, **kwargs) -> typing.Any: if self.content.startswith(b"data"): - trail_content = self.content[: len(b"data")] - return json.loads(trail_content) + return json.loads(self.content.removeprefix(b"data")) + elif self.content.startswith(b"event"): + return json.loads(self.content.removeprefix(b"event")) else: return super().json(**kwargs) diff --git a/flymyai/core/_response_factory.py b/flymyai/core/_response_factory.py index 1dc7484..6f0fa2a 100644 --- a/flymyai/core/_response_factory.py +++ b/flymyai/core/_response_factory.py @@ -37,12 +37,14 @@ def get_sse_status_code(self): def _base_construct_from_sse(self): sse_status = self.get_sse_status_code() if sse_status < 400: - return FlyMyAIResponse( + response = FlyMyAIResponse( status_code=sse_status, content=self.sse.data or self.sse.event, request=self.httpx_request, headers=self.httpx_response.headers or self.sse.headers, ) + response.is_event = self.sse.event is not None + return response else: raise BaseFlyMyAIException.from_response( FlyMyAIResponse( diff --git a/flymyai/core/_streaming.py b/flymyai/core/_streaming.py index ccef6d6..f2e26fa 100644 --- a/flymyai/core/_streaming.py +++ b/flymyai/core/_streaming.py @@ -48,7 +48,10 @@ def data(self) -> str: return self._data def json(self) -> Any: - return json.loads(self.data.strip()) + if self.data: + return json.loads(self.data.strip()) + if self.event: + return json.loads(self.event.strip()) @property def headers(self): diff --git a/flymyai/core/models.py b/flymyai/core/models.py index 6a2fdc1..53abfaa 100644 --- a/flymyai/core/models.py +++ b/flymyai/core/models.py @@ -7,6 +7,7 @@ from pydantic import PrivateAttr from flymyai.core._response import FlyMyAIResponse +from flymyai.core.types.event_types import EventType @dataclasses.dataclass @@ -141,7 +142,16 @@ class PredictionPartial(BaseFromServer): _response: FlyMyAIResponse = PrivateAttr() +class PredictionEvent(BaseFromServer): + status: int + event_type: EventType + + prediction_id: Optional[str] = None # EventType.STREAM_ID + + class StreamDetails(pydantic.BaseModel): - input_tokens: int - output_tokens: int - size_in_billions: float = pydantic.Field(alias="model_size_in_billions") + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + size_in_billions: Optional[float] = pydantic.Field( + default=None, alias="model_size_in_billions" + ) diff --git a/flymyai/core/stream_iterators/AsyncPredictionStream.py b/flymyai/core/stream_iterators/AsyncPredictionStream.py new file mode 100644 index 0000000..3cd4bb3 --- /dev/null +++ b/flymyai/core/stream_iterators/AsyncPredictionStream.py @@ -0,0 +1,74 @@ +import asyncio +from typing import AsyncIterator, TypeVar, Callable, Union, Awaitable + +from flymyai.core._response import FlyMyAIResponse +from flymyai.core.exceptions import BaseFlyMyAIException +from flymyai.core.models import StreamDetails, PredictionPartial, PredictionEvent +from flymyai.core.types.event_types import EventType + + +_AsyncEventCallbackType = TypeVar( + "_AsyncEventCallbackType", + bound=Union[ + Callable[[PredictionEvent], None], Callable[[PredictionEvent], Awaitable[None]] + ], +) + + +class AsyncPredictionStream: + stream_details: StreamDetails + + event_callback: _AsyncEventCallbackType = None + + follow_cancelling: bool = True + + def __init__(self, response_iterator: AsyncIterator): + self.response_iterator = response_iterator + + def __aiter__(self): + return self + + def set_on_event(self, callback_or_coro: _AsyncEventCallbackType): + self.event_callback = callback_or_coro + + async def loop_iter(self): + response_end = None + while not response_end: + next_resp: FlyMyAIResponse = await self.response_iterator.__anext__() + if not next_resp.is_event: + response_end = next_resp + return response_end + else: + evt = PredictionEvent.from_response(next_resp) + if not self.event_callback: + pass + else: + coro_or_res = self.event_callback(evt) + if asyncio.iscoroutine(coro_or_res): + asyncio.run_coroutine_threadsafe( + coro_or_res, asyncio.get_event_loop() + ) + if ( + self.follow_cancelling + and evt.event_type == EventType.CANCELLING + ): + raise StopAsyncIteration + + async def __anext__(self): + response_end = None + try: + response_end = await self.loop_iter() + return PredictionPartial.from_response(response_end) + except BaseFlyMyAIException as e: + response_end = e.response + raise e + except Exception as e: + raise e + finally: + if not response_end: + raise StopAsyncIteration() + stream_details_marshalled = response_end.json().get("stream_details") + if stream_details_marshalled: + self.stream_details = StreamDetails.model_validate( + stream_details_marshalled + ) diff --git a/flymyai/core/stream_iterators/PredictionStream.py b/flymyai/core/stream_iterators/PredictionStream.py new file mode 100644 index 0000000..53c114b --- /dev/null +++ b/flymyai/core/stream_iterators/PredictionStream.py @@ -0,0 +1,63 @@ +from typing import Optional, Iterator, TypeVar, Callable + +from flymyai.core._response import FlyMyAIResponse +from flymyai.core.exceptions import BaseFlyMyAIException +from flymyai.core.models import StreamDetails, PredictionPartial, PredictionEvent +from flymyai.core.types.event_types import EventType + + +_SyncEventCallbackType = TypeVar( + "_SyncEventCallbackType", bound=Callable[[PredictionEvent], None] +) + + +class PredictionStream: + stream_details: StreamDetails + + event_callback: Optional[_SyncEventCallbackType] = None + follow_cancelling = True + + def __init__(self, response_iterator: Iterator): + self.response_iterator = response_iterator + + def set_on_event(self, callback: _SyncEventCallbackType): + self.event_callback = callback + + def __iter__(self): + return self + + def loop_iter(self): + response_end = None + while not response_end: + next_resp: FlyMyAIResponse = self.response_iterator.__next__() + if not next_resp.is_event: + response_end = next_resp + return response_end + else: + evt = PredictionEvent.from_response(next_resp) + if not self.event_callback: + pass + else: + self.event_callback(evt) + if ( + self.follow_cancelling + and evt.event_type == EventType.CANCELLING + ): + raise StopIteration + + def __next__(self): + response_end = None + try: + response_end = self.loop_iter() + return PredictionPartial.from_response(response_end) + except BaseFlyMyAIException as e: + response_end = e.response + raise e + finally: + if not response_end: + raise StopIteration() + stream_details_marshalled = response_end.json().get("stream_details") + if stream_details_marshalled: + self.stream_details = StreamDetails.model_validate( + stream_details_marshalled + ) diff --git a/flymyai/core/stream_iterators/__init__.py b/flymyai/core/stream_iterators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flymyai/core/types/__init__.py b/flymyai/core/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flymyai/core/types/event_types.py b/flymyai/core/types/event_types.py new file mode 100644 index 0000000..a1f0144 --- /dev/null +++ b/flymyai/core/types/event_types.py @@ -0,0 +1,6 @@ +import enum + + +class EventType(str, enum.Enum): + CANCELLING = "stream_cancelling" + STREAM_ID = "id" diff --git a/tests/test_stream.py b/tests/test_stream.py index 8adb1ba..7c723ee 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -31,6 +31,8 @@ def output_field(): def test_stream(stream_auth, stream_payload, dsn, output_field): stream_iterator = sync_client(**stream_auth).stream(stream_payload) + stream_iterator.follow_cancelling = False + stream_iterator.set_on_event(print) try: for response in stream_iterator: assert response.status == 200 @@ -43,12 +45,14 @@ def test_stream(stream_auth, stream_payload, dsn, output_field): raise e finally: print() - print(stream_iterator.stream_details) + print(getattr(stream_iterator, "stream_details", None)) @pytest.mark.asyncio async def test_async_stream(stream_auth, stream_payload, dsn, output_field): stream_iterator = async_client(**stream_auth).stream(stream_payload) + stream_iterator.follow_cancelling = False + stream_iterator.set_on_event(print) try: async for response in stream_iterator: assert response.status == 200 @@ -61,4 +65,4 @@ async def test_async_stream(stream_auth, stream_payload, dsn, output_field): raise e finally: print() - print(stream_iterator.stream_details) + print(getattr(stream_iterator, "stream_details", None)) From aef2da1fc07fa11456ef715fd70bb271b8bc12c8 Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Thu, 8 Aug 2024 00:08:40 +0300 Subject: [PATCH 4/5] stream.cancel() method, refactor and tests --- flymyai/__init__.py | 4 +- flymyai/core/_client.py | 461 ------------------ flymyai/core/authorizations.py | 4 + flymyai/core/client.py | 6 +- flymyai/core/clients/AsyncClient.py | 198 ++++++++ flymyai/core/clients/SyncClient.py | 175 +++++++ flymyai/core/clients/__init__.py | 0 flymyai/core/clients/base_client.py | 179 +++++++ flymyai/core/exceptions.py | 2 +- flymyai/core/models/__init__.py | 0 .../{models.py => models/error_responses.py} | 71 --- flymyai/core/models/successful_responses.py | 72 +++ .../stream_iterators/AsyncPredictionStream.py | 39 +- .../core/stream_iterators/PredictionStream.py | 47 +- flymyai/core/stream_iterators/exceptions.py | 2 + tests/test_stream.py | 131 ++++- 16 files changed, 831 insertions(+), 560 deletions(-) delete mode 100644 flymyai/core/_client.py create mode 100644 flymyai/core/clients/AsyncClient.py create mode 100644 flymyai/core/clients/SyncClient.py create mode 100644 flymyai/core/clients/__init__.py create mode 100644 flymyai/core/clients/base_client.py create mode 100644 flymyai/core/models/__init__.py rename flymyai/core/{models.py => models/error_responses.py} (51%) create mode 100644 flymyai/core/models/successful_responses.py create mode 100644 flymyai/core/stream_iterators/exceptions.py diff --git a/flymyai/__init__.py b/flymyai/__init__.py index 2c67333..3fd3ca3 100644 --- a/flymyai/__init__.py +++ b/flymyai/__init__.py @@ -1,7 +1,7 @@ import httpx -from .core.client import FlyMyAI, AsyncFlyMyAI -from .core.exceptions import FlyMyAIPredictException, FlyMyAIExceptionGroup +from flymyai.core.client import FlyMyAI, AsyncFlyMyAI +from flymyai.core.exceptions import FlyMyAIPredictException, FlyMyAIExceptionGroup __all__ = [ diff --git a/flymyai/core/_client.py b/flymyai/core/_client.py deleted file mode 100644 index 403a0ad..0000000 --- a/flymyai/core/_client.py +++ /dev/null @@ -1,461 +0,0 @@ -import os -from typing import ( - Callable, - Awaitable, - Generic, - TypeVar, - Union, - overload, - Iterator, - AsyncContextManager, - AsyncIterator, - Optional, -) - -import httpx - -from flymyai.core._response_factory import ResponseFactory -from flymyai.core._streaming import SSEDecoder -from flymyai.core.authorizations import APIKeyClientInfo -from flymyai.core.exceptions import ( - FlyMyAIPredictException, - FlyMyAIExceptionGroup, - BaseFlyMyAIException, - FlyMyAIOpenAPIException, - ImproperlyConfiguredClientException, -) -from flymyai.core.models import ( - PredictionResponse, - OpenAPISchemaResponse, - PredictionPartial, -) -from flymyai.core.stream_iterators.AsyncPredictionStream import AsyncPredictionStream -from flymyai.core.stream_iterators.PredictionStream import PredictionStream -from flymyai.multipart.payload import MultipartPayload -from flymyai.utils.utils import retryable_callback, aretryable_callback - -DEFAULT_RETRY_COUNT = os.getenv("FLYMYAI_MAX_RETRIES", 2) - -_PossibleClients = TypeVar( - "_PossibleClients", bound=Union[httpx.Client, httpx.AsyncClient] -) - - -_predict_timeout = httpx.Timeout(None, connect=10) - - -class BaseClient(Generic[_PossibleClients]): - - """ - Base class for FlyMyAI clients - """ - - _client: _PossibleClients - max_retries: int - client_info: APIKeyClientInfo - - def __init__( - self, apikey: str, model: Optional[str] = None, max_retries=DEFAULT_RETRY_COUNT - ): - self.client_info = APIKeyClientInfo(apikey) - if model: - self.client_info = self.client_info.copy_for_model(model) - self._client = self._construct_client() - self.max_retries = max_retries - - def amend_client_info(self, model: Optional[str] = None): - if model: - client_info = self.client_info.copy_for_model(model) - else: - client_info = self.client_info - if not client_info.project_name or not client_info.username: - raise ImproperlyConfiguredClientException( - "model should be provided as /" - ) - return client_info - - @overload - async def predict( - self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: - ... - - @overload - def predict( - self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: - ... - - def predict( - self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: - ... - - @overload - async def openapi_schema( - self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: - ... - - @overload - def openapi_schema( - self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: - ... - - def openapi_schema( - self, model: Optional[str] = None, max_retries=None - ) -> OpenAPISchemaResponse: - ... - - @overload - async def stream( - self, - payload: dict, - model: Optional[str] = None, - ) -> AsyncIterator[PredictionPartial]: - ... - - @overload - def stream( - self, - payload: dict, - model: Optional[str] = None, - ) -> Iterator[PredictionPartial]: - ... - - def stream( - self, - payload: dict, - model: Optional[str] = None, - ): - ... - - def _stream_iterator( - self, client_info, payload: MultipartPayload, is_long_stream: bool - ) -> Union[Iterator[httpx.Response], AsyncIterator[httpx.Response]]: - return self._client.stream( - method="post", - url=( - client_info.prediction_path - if not is_long_stream - else client_info.prediction_stream_path - ), - **payload.serialize(), - timeout=_predict_timeout, - headers=client_info.authorization_headers, - follow_redirects=True, - ) - - @staticmethod - def _wrap_request(request_callback: Callable): - response = request_callback() - return ResponseFactory(httpx_response=response).construct() - - def is_closed(self) -> bool: - return self._client.is_closed - - def close(self) -> None: - """ - Close the underlying HTTPX client. - - The client will *not* be usable after this. - """ - # If an error is thrown while constructing a client, self._client - # may not be present - if hasattr(self, "_client"): - self._client.close() - - def _construct_client(self): - raise NotImplemented - - -class BaseSyncClient(BaseClient[httpx.Client]): - def _construct_client(self): - return httpx.Client( - http2=True, - headers=self.client_info.authorization_headers, - base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"), - ) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._client.close() - - @classmethod - def _sse_instant(cls, stream_iter_func: Callable[[], Iterator[httpx.Response]]): - """ - Fetch sse response on prediction - :param stream_iter_func: context manager with underlying stream - :return: FlyMyAIResponse - """ - with stream_iter_func() as stream: - stream: httpx.Response - response = ResponseFactory( - sse=next(SSEDecoder().iter(stream.iter_lines())), - httpx_request=stream.request, - httpx_response=stream, - ).construct() - return response - - def _predict(self, payload: MultipartPayload, client_info: APIKeyClientInfo): - """ - Wrap predict method in sse - """ - - try: - return self._sse_instant( - lambda: self._stream_iterator(client_info, payload, False) - ) - except BaseFlyMyAIException as e: - raise FlyMyAIPredictException.from_response(e.response) - - def predict(self, payload: dict, model: Optional[str] = None, max_retries=None): - """ - Wrap predict method in sse. - Retries until max_retries or self.max_retries is reached - :param model: flymyai/bert | None, If none - get self.client_info. - :param payload: anything for model - :param max_retries: retries - :return: PredictionResponse(exc_history, output_data, response): - exc_history - list of exception history during prediction - output_data - dict with prediction output - """ - - payload = MultipartPayload(payload) - history, response = retryable_callback( - lambda: self._predict(payload, self.amend_client_info(model)), - max_retries or self.max_retries, - FlyMyAIPredictException, - FlyMyAIExceptionGroup, - ) - return PredictionResponse.from_response(response, exc_history=history) - - def _stream(self, client_info: APIKeyClientInfo, payload: dict): - payload = MultipartPayload(payload) - response_iterator = self._stream_iterator( - client_info, payload, is_long_stream=True - ) - decoder = SSEDecoder() - with response_iterator as sse_stream: - for sse_partial in decoder.iter(sse_stream.iter_lines()): - try: - response = ResponseFactory( - sse=sse_partial, - httpx_request=sse_stream.request, - httpx_response=sse_stream, - ).construct() - except BaseFlyMyAIException as e: - raise FlyMyAIPredictException.from_response(e.response) - yield response - - def stream(self, payload: dict, model: Optional[str] = None): - stream_iter = self._stream(self.amend_client_info(model), payload) - stream_wrapper = PredictionStream(stream_iter) - return stream_wrapper - - def _openapi_schema(self, client_info: APIKeyClientInfo): - """ - OpenAPI request for the current project, wrapped in executor-method (using HTTP/1) - :return: - """ - try: - return self._wrap_request( - lambda: self._client.get( - client_info.openapi_schema_path, - headers=client_info.authorization_headers, - ) - ) - except BaseFlyMyAIException as e: - raise FlyMyAIOpenAPIException.from_response(e.response) - - def openapi_schema(self, model: Optional[str] = None, max_retries=None): - """ - :param model: flymyai/bert - :param max_retries: retries before give up - :return: - :return: OpenAPISchemaResponse(exc_history, openapi_schema, response): - exc_history - dict with exceptions; - openapi_schema - dict with openapi; - """ - history, response = retryable_callback( - lambda: self._openapi_schema(client_info=self.amend_client_info(model)), - max_retries or self.max_retries, - FlyMyAIPredictException, - FlyMyAIExceptionGroup, - ) - return OpenAPISchemaResponse.from_response( - exc_history=history, openapi_schema=response.json(), response=response - ) - - @classmethod - def run_predict(cls, apikey: str, model: str, payload: dict): - """ - :param apikey: fly-... - :param model: flymyai/bert - :param payload: jsonable / multipart/form-data available data - :return: PredictionResponse(exc_history, output_data, response): - exc_history - list of exception history during prediction; - output_data - dict with prediction output; - """ - with cls(apikey, model) as client: - return client.predict(payload) - - -class BaseAsyncClient(BaseClient[httpx.AsyncClient]): - def _construct_client(self): - return httpx.AsyncClient( - http2=True, - headers=self.client_info.authorization_headers, - base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"), - ) - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if hasattr(self, "_client"): - await self._client.aclose() - - async def openapi_schema(self, model: Optional[str] = None, max_retries=None): - """ - :param max_retries: retries before giving up - :return: - :return: OpenAPISchemaResponse(exc_history, openapi_schema, response): - exc_history - dict with exceptions; - openapi_schema - dict with openapi; - """ - history, response = await aretryable_callback( - lambda: self._openapi_schema(), - max_retries or self.max_retries, - FlyMyAIPredictException, - FlyMyAIExceptionGroup, - ) - return OpenAPISchemaResponse.from_response( - exc_history=history, openapi_schema=response.json(), response=response - ) - - def _openapi_schema(self, client_info: APIKeyClientInfo): - """ - OpenAPI request for the current project, wrapped in executor-method (using HTTP/1) - :return: - """ - try: - return self._wrap_request( - lambda: self._client.get( - client_info.openapi_schema_path, - headers=client_info.authorization_headers, - ) - ) - except BaseFlyMyAIException as e: - raise FlyMyAIOpenAPIException.from_response(e.response) - - @classmethod - async def _sse_instant( - cls, async_response_stream: Callable[[], AsyncContextManager[httpx.Response]] - ): - """ - A non-blocking approach to fetch a response stream - :param async_response_stream: context manager with underlying stream - :return: FlyMyAIResponse - """ - async with async_response_stream() as stream: - sse = await SSEDecoder().aiter(stream.aiter_lines()).__anext__() - response = ResponseFactory( - sse=sse, httpx_request=stream.request, httpx_response=stream - ).construct() - return response - - def _predict(self, client_info, payload: MultipartPayload): - """ - Executes request and waits for sse data - :param payload: model input data - :return: FlyMyAIResponse or raise an exception - """ - try: - return self._sse_instant( - lambda: self._client.stream( - method="post", - url=client_info.prediction_path, - timeout=_predict_timeout, - **payload.serialize(), - headers=client_info.authorization_headers, - ) - ) - except BaseFlyMyAIException as e: - raise FlyMyAIPredictException.from_response(e.response) - - async def predict( - self, payload: dict, model: Optional[str] = None, max_retries=None - ) -> PredictionResponse: - """ - Wrap predict method in sse. - Retries until max_retries or self.max_retries is reached - :param model: flymyai/bert - :param payload: anything for model - :param max_retries: retries - :return: PredictionResponse(exc_history, output_data, response): - exc_history - list of exception history during prediction - output_data - dict with prediction output - """ - payload = MultipartPayload(input_data=payload) - history, response = await aretryable_callback( - lambda: self._predict(self.amend_client_info(model), payload), - max_retries or self.max_retries, - FlyMyAIPredictException, - FlyMyAIExceptionGroup, - ) - return PredictionResponse.from_response(response, exc_history=history) - - async def _stream(self, client_info: APIKeyClientInfo, payload: dict): - payload = MultipartPayload(payload) - stream_iterator = self._stream_iterator( - client_info, payload, is_long_stream=True - ) - decoder = SSEDecoder() - async with stream_iterator as sse_stream: - async for sse_partial in decoder.aiter(sse_stream.aiter_lines()): - try: - response = ResponseFactory( - sse=sse_partial, - httpx_request=sse_stream.request, - httpx_response=sse_stream, - ).construct() - except BaseFlyMyAIException as e: - raise FlyMyAIPredictException.from_response(e.response) - yield response - - def stream(self, payload: dict, model: Optional[str] = None, max_retries=None): - stream_iter = self._stream(self.amend_client_info(model), payload) - stream_wrapper = AsyncPredictionStream(stream_iter) - return stream_wrapper - - @staticmethod - async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response]]): - """ - Execute a request callback and return the response - """ - response = await request_callback() - return ResponseFactory(httpx_response=response).construct() - - async def close(self): - """ - Close the client - """ - await self._client.aclose() - - @classmethod - async def arun_predict(cls, apikey: str, model: str, payload: dict): - """ - Execute simple prediction out of a box - :param model: flymyai/bert - :param apikey: fly-... - :param payload: {dict with prediction input} - :return: PredictionResponse(exc_history, output_data, response) - exc_history - list of exception history during prediction - output_data - dict with prediction output - """ - async with cls(apikey, model) as client: - return await client.predict(payload) diff --git a/flymyai/core/authorizations.py b/flymyai/core/authorizations.py index 2e0b561..5be4d56 100644 --- a/flymyai/core/authorizations.py +++ b/flymyai/core/authorizations.py @@ -54,6 +54,10 @@ def _project_path(self): def prediction_path(self): return self._project_path.join(httpx.URL("predict")) + @property + def prediction_cancel_path(self): + return self._project_path.join(httpx.URL("predict/cancel/")) + @property def prediction_stream_path(self): return self._project_path.join(httpx.URL("predict/stream/")) diff --git a/flymyai/core/client.py b/flymyai/core/client.py index 30d8e99..9d503d6 100644 --- a/flymyai/core/client.py +++ b/flymyai/core/client.py @@ -1,7 +1,5 @@ -from ._client import ( - BaseSyncClient, - BaseAsyncClient, -) +from flymyai.core.clients.AsyncClient import BaseAsyncClient +from flymyai.core.clients.SyncClient import BaseSyncClient class FlyMyAI(BaseSyncClient): diff --git a/flymyai/core/clients/AsyncClient.py b/flymyai/core/clients/AsyncClient.py new file mode 100644 index 0000000..b7cc4cb --- /dev/null +++ b/flymyai/core/clients/AsyncClient.py @@ -0,0 +1,198 @@ +import os +from typing import Optional, Callable, AsyncContextManager, Awaitable + +import httpx + +from flymyai.core._response_factory import ResponseFactory +from flymyai.core._streaming import SSEDecoder +from flymyai.core.authorizations import APIKeyClientInfo +from flymyai.core.clients.base_client import BaseClient, _predict_timeout +from flymyai.core.exceptions import ( + BaseFlyMyAIException, + FlyMyAIOpenAPIException, + FlyMyAIPredictException, + FlyMyAIExceptionGroup, +) +from flymyai.core.models.successful_responses import ( + OpenAPISchemaResponse, + PredictionResponse, +) +from flymyai.core.stream_iterators.AsyncPredictionStream import AsyncPredictionStream +from flymyai.multipart import MultipartPayload +from flymyai.utils.utils import aretryable_callback + + +class BaseAsyncClient(BaseClient[httpx.AsyncClient]): + def _construct_client(self): + return httpx.AsyncClient( + http2=True, + headers=self.client_info.authorization_headers, + base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"), + ) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if hasattr(self, "_client"): + await self._client.aclose() + + async def openapi_schema(self, model: Optional[str] = None, max_retries=None): + """ + :param max_retries: retries before giving up + :return: + :return: OpenAPISchemaResponse(exc_history, openapi_schema, response): + exc_history - dict with exceptions; + openapi_schema - dict with openapi; + """ + history, response = await aretryable_callback( + lambda: self._openapi_schema(), + max_retries or self.max_retries, + FlyMyAIPredictException, + FlyMyAIExceptionGroup, + ) + return OpenAPISchemaResponse.from_response( + exc_history=history, openapi_schema=response.json(), response=response + ) + + def _openapi_schema(self, client_info: APIKeyClientInfo): + """ + OpenAPI request for the current project, wrapped in executor-method (using HTTP/1) + :return: + """ + try: + return self._wrap_request( + lambda: self._client.get( + client_info.openapi_schema_path, + headers=client_info.authorization_headers, + ) + ) + except BaseFlyMyAIException as e: + raise FlyMyAIOpenAPIException.from_response(e.response) + + async def cancel_prediction( + self, + prediction_id: str, + model: Optional[str] = None, + client_info: APIKeyClientInfo = None, + ): + if client_info: + full_client_info = client_info + else: + full_client_info = self.amend_client_info(model) + response = await self._client.patch( + url=full_client_info.prediction_cancel_path, + json={"infer_id": prediction_id}, + ) + return ResponseFactory( + httpx_response=response, httpx_request=response.request + ).construct() + + @classmethod + async def _sse_instant( + cls, async_response_stream: Callable[[], AsyncContextManager[httpx.Response]] + ): + """ + A non-blocking approach to fetch a response stream + :param async_response_stream: context manager with underlying stream + :return: FlyMyAIResponse + """ + async with async_response_stream() as stream: + sse = await SSEDecoder().aiter(stream.aiter_lines()).__anext__() + response = ResponseFactory( + sse=sse, httpx_request=stream.request, httpx_response=stream + ).construct() + return response + + def _predict(self, client_info, payload: MultipartPayload): + """ + Executes request and waits for sse data + :param payload: model input data + :return: FlyMyAIResponse or raise an exception + """ + try: + return self._sse_instant( + lambda: self._client.stream( + method="post", + url=client_info.prediction_path, + timeout=_predict_timeout, + **payload.serialize(), + headers=client_info.authorization_headers, + ) + ) + except BaseFlyMyAIException as e: + raise FlyMyAIPredictException.from_response(e.response) + + async def predict( + self, payload: dict, model: Optional[str] = None, max_retries=None + ) -> PredictionResponse: + """ + Wrap predict method in sse. + Retries until max_retries or self.max_retries is reached + :param model: flymyai/bert + :param payload: anything for model + :param max_retries: retries + :return: PredictionResponse(exc_history, output_data, response): + exc_history - list of exception history during prediction + output_data - dict with prediction output + """ + payload = MultipartPayload(input_data=payload) + history, response = await aretryable_callback( + lambda: self._predict(self.amend_client_info(model), payload), + max_retries or self.max_retries, + FlyMyAIPredictException, + FlyMyAIExceptionGroup, + ) + return PredictionResponse.from_response(response, exc_history=history) + + async def _stream(self, client_info: APIKeyClientInfo, payload: dict): + payload = MultipartPayload(payload) + stream_iterator = self._stream_iterator( + client_info, payload, is_long_stream=True + ) + decoder = SSEDecoder() + async with stream_iterator as sse_stream: + async for sse_partial in decoder.aiter(sse_stream.aiter_lines()): + try: + response = ResponseFactory( + sse=sse_partial, + httpx_request=sse_stream.request, + httpx_response=sse_stream, + ).construct() + except BaseFlyMyAIException as e: + raise FlyMyAIPredictException.from_response(e.response) + yield response + + def stream(self, payload: dict, model: Optional[str] = None, max_retries=None): + full_client_info = self.amend_client_info(model) + stream_iter = self._stream(full_client_info, payload) + stream_wrapper = AsyncPredictionStream(stream_iter, self, full_client_info) + return stream_wrapper + + @staticmethod + async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response]]): + """ + Execute a request callback and return the response + """ + response = await request_callback() + return ResponseFactory(httpx_response=response).construct() + + async def close(self): + """ + Close the client + """ + await self._client.aclose() + + @classmethod + async def arun_predict(cls, apikey: str, model: str, payload: dict): + """ + Execute simple prediction out of a box + :param model: flymyai/bert + :param apikey: fly-... + :param payload: {dict with prediction input} + :return: PredictionResponse(exc_history, output_data, response) + exc_history - list of exception history during prediction + output_data - dict with prediction output + """ + async with cls(apikey, model) as client: + return await client.predict(payload) diff --git a/flymyai/core/clients/SyncClient.py b/flymyai/core/clients/SyncClient.py new file mode 100644 index 0000000..b40f9c4 --- /dev/null +++ b/flymyai/core/clients/SyncClient.py @@ -0,0 +1,175 @@ +import os +from typing import Callable, Iterator, Optional + +import httpx + +from flymyai.core._response_factory import ResponseFactory +from flymyai.core._streaming import SSEDecoder +from flymyai.core.authorizations import APIKeyClientInfo +from flymyai.core.clients.base_client import BaseClient +from flymyai.core.exceptions import ( + BaseFlyMyAIException, + FlyMyAIOpenAPIException, + FlyMyAIPredictException, + FlyMyAIExceptionGroup, +) +from flymyai.core.models.successful_responses import ( + PredictionResponse, + OpenAPISchemaResponse, +) +from flymyai.core.stream_iterators.PredictionStream import PredictionStream +from flymyai.multipart import MultipartPayload +from flymyai.utils.utils import retryable_callback + + +class BaseSyncClient(BaseClient[httpx.Client]): + def _construct_client(self): + return httpx.Client( + http2=True, + headers=self.client_info.authorization_headers, + base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"), + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._client.close() + + @classmethod + def _sse_instant(cls, stream_iter_func: Callable[[], Iterator[httpx.Response]]): + """ + Fetch sse response on prediction + :param stream_iter_func: context manager with underlying stream + :return: FlyMyAIResponse + """ + with stream_iter_func() as stream: + stream: httpx.Response + response = ResponseFactory( + sse=next(SSEDecoder().iter(stream.iter_lines())), + httpx_request=stream.request, + httpx_response=stream, + ).construct() + return response + + def _predict(self, payload: MultipartPayload, client_info: APIKeyClientInfo): + """ + Wrap predict method in sse + """ + + try: + return self._sse_instant( + lambda: self._stream_iterator(client_info, payload, False) + ) + except BaseFlyMyAIException as e: + raise FlyMyAIPredictException.from_response(e.response) + + def predict(self, payload: dict, model: Optional[str] = None, max_retries=None): + """ + Wrap predict method in sse. + Retries until max_retries or self.max_retries is reached + :param model: flymyai/bert | None, If none - get self.client_info. + :param payload: anything for model + :param max_retries: retries + :return: PredictionResponse(exc_history, output_data, response): + exc_history - list of exception history during prediction + output_data - dict with prediction output + """ + + payload = MultipartPayload(payload) + history, response = retryable_callback( + lambda: self._predict(payload, self.amend_client_info(model)), + max_retries or self.max_retries, + FlyMyAIPredictException, + FlyMyAIExceptionGroup, + ) + return PredictionResponse.from_response(response, exc_history=history) + + def _stream(self, client_info: APIKeyClientInfo, payload: dict): + payload = MultipartPayload(payload) + response_iterator = self._stream_iterator( + client_info, payload, is_long_stream=True + ) + decoder = SSEDecoder() + with response_iterator as sse_stream: + for sse_partial in decoder.iter(sse_stream.iter_lines()): + try: + response = ResponseFactory( + sse=sse_partial, + httpx_request=sse_stream.request, + httpx_response=sse_stream, + ).construct() + except BaseFlyMyAIException as e: + raise FlyMyAIPredictException.from_response(e.response) + yield response + + def stream(self, payload: dict, model: Optional[str] = None): + full_client_info = self.amend_client_info(model) + stream_iter = self._stream(full_client_info, payload) + stream_wrapper = PredictionStream(stream_iter, self, full_client_info) + return stream_wrapper + + def _openapi_schema(self, client_info: APIKeyClientInfo): + """ + OpenAPI request for the current project, wrapped in executor-method (using HTTP/1) + :return: + """ + try: + return self._wrap_request( + lambda: self._client.get( + client_info.openapi_schema_path, + headers=client_info.authorization_headers, + ) + ) + except BaseFlyMyAIException as e: + raise FlyMyAIOpenAPIException.from_response(e.response) + + def cancel_prediction( + self, + prediction_id: str, + model: Optional[str] = None, + client_info: APIKeyClientInfo = None, + ): + if client_info: + full_client_info = client_info + else: + full_client_info = self.amend_client_info(model) + response = self._client.patch( + url=full_client_info.prediction_cancel_path, + json={"infer_id": prediction_id}, + ) + return ResponseFactory( + httpx_response=response, httpx_request=response.request + ).construct() + + def openapi_schema(self, model: Optional[str] = None, max_retries=None): + """ + :param model: flymyai/bert + :param max_retries: retries before give up + :return: + :return: OpenAPISchemaResponse(exc_history, openapi_schema, response): + exc_history - dict with exceptions; + openapi_schema - dict with openapi; + """ + history, response = retryable_callback( + lambda: self._openapi_schema(client_info=self.amend_client_info(model)), + max_retries or self.max_retries, + FlyMyAIPredictException, + FlyMyAIExceptionGroup, + ) + return OpenAPISchemaResponse.from_response( + exc_history=history, openapi_schema=response.json(), response=response + ) + + @classmethod + def run_predict(cls, apikey: str, model: str, payload: dict): + """ + :param apikey: fly-... + :param model: flymyai/bert + :param payload: jsonable / multipart/form-data available data + :return: PredictionResponse(exc_history, output_data, response): + exc_history - list of exception history during prediction; + output_data - dict with prediction output; + """ + with cls(apikey, model) as client: + return client.predict(payload) diff --git a/flymyai/core/clients/__init__.py b/flymyai/core/clients/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flymyai/core/clients/base_client.py b/flymyai/core/clients/base_client.py new file mode 100644 index 0000000..606b6e8 --- /dev/null +++ b/flymyai/core/clients/base_client.py @@ -0,0 +1,179 @@ +import os +from typing import Generic, Optional, overload, AsyncIterator, Iterator, Callable +from typing import ( + TypeVar, + Union, +) + +import httpx + +from flymyai.core._response_factory import ResponseFactory +from flymyai.core.authorizations import APIKeyClientInfo +from flymyai.core.exceptions import ImproperlyConfiguredClientException +from flymyai.core.models.successful_responses import ( + PredictionResponse, + OpenAPISchemaResponse, + PredictionPartial, +) +from flymyai.multipart import MultipartPayload + +DEFAULT_RETRY_COUNT = os.getenv("FLYMYAI_MAX_RETRIES", 2) + +_PossibleClients = TypeVar( + "_PossibleClients", bound=Union[httpx.Client, httpx.AsyncClient] +) + + +_predict_timeout = httpx.Timeout(None, connect=10) + + +class BaseClient(Generic[_PossibleClients]): + + """ + Base class for FlyMyAI clients + """ + + _client: _PossibleClients + max_retries: int + client_info: APIKeyClientInfo + + def __init__( + self, apikey: str, model: Optional[str] = None, max_retries=DEFAULT_RETRY_COUNT + ): + self.client_info = APIKeyClientInfo(apikey) + if model: + self.client_info = self.client_info.copy_for_model(model) + self._client = self._construct_client() + self.max_retries = max_retries + + def amend_client_info(self, model: Optional[str] = None): + if model: + client_info = self.client_info.copy_for_model(model) + else: + client_info = self.client_info + if not client_info.project_name or not client_info.username: + raise ImproperlyConfiguredClientException( + "model should be provided as /" + ) + return client_info + + @overload + async def predict( + self, payload: dict, model: Optional[str] = None, max_retries=None + ) -> PredictionResponse: + ... + + @overload + def predict( + self, payload: dict, model: Optional[str] = None, max_retries=None + ) -> PredictionResponse: + ... + + def predict( + self, payload: dict, model: Optional[str] = None, max_retries=None + ) -> PredictionResponse: + ... + + @overload + async def openapi_schema( + self, model: Optional[str] = None, max_retries=None + ) -> OpenAPISchemaResponse: + ... + + @overload + def openapi_schema( + self, model: Optional[str] = None, max_retries=None + ) -> OpenAPISchemaResponse: + ... + + def openapi_schema( + self, model: Optional[str] = None, max_retries=None + ) -> OpenAPISchemaResponse: + ... + + @overload + async def stream( + self, + payload: dict, + model: Optional[str] = None, + ) -> AsyncIterator[PredictionPartial]: + ... + + @overload + def stream( + self, + payload: dict, + model: Optional[str] = None, + ) -> Iterator[PredictionPartial]: + ... + + def stream( + self, + payload: dict, + model: Optional[str] = None, + ): + ... + + def _stream_iterator( + self, client_info, payload: MultipartPayload, is_long_stream: bool + ) -> Union[Iterator[httpx.Response], AsyncIterator[httpx.Response]]: + return self._client.stream( + method="post", + url=( + client_info.prediction_path + if not is_long_stream + else client_info.prediction_stream_path + ), + **payload.serialize(), + timeout=_predict_timeout, + headers=client_info.authorization_headers, + follow_redirects=True, + ) + + @staticmethod + def _wrap_request(request_callback: Callable): + response = request_callback() + return ResponseFactory(httpx_response=response).construct() + + def is_closed(self) -> bool: + return self._client.is_closed + + def close(self) -> None: + """ + Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + # If an error is thrown while constructing a client, self._client + # may not be present + if hasattr(self, "_client"): + self._client.close() + + def _construct_client(self): + raise NotImplemented + + @overload + async def cancel_prediction( + self, + prediction_id: str, + model: Optional[str] = None, + client_info: APIKeyClientInfo = None, + ): + ... + + @overload + def cancel_prediction( + self, + prediction_id: str, + model: Optional[str] = None, + client_info: APIKeyClientInfo = None, + ): + ... + + def cancel_prediction( + self, + prediction_id: str, + model: Optional[str] = None, + client_info: APIKeyClientInfo = None, + ): + ... diff --git a/flymyai/core/exceptions.py b/flymyai/core/exceptions.py index ea99753..40ce283 100644 --- a/flymyai/core/exceptions.py +++ b/flymyai/core/exceptions.py @@ -1,7 +1,7 @@ from typing import List from ._response import FlyMyAIResponse -from .models import ( +from .models.error_responses import ( FlyMyAI401Response, FlyMyAI422Response, Base4xxResponse, diff --git a/flymyai/core/models/__init__.py b/flymyai/core/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flymyai/core/models.py b/flymyai/core/models/error_responses.py similarity index 51% rename from flymyai/core/models.py rename to flymyai/core/models/error_responses.py index 53abfaa..1b29e67 100644 --- a/flymyai/core/models.py +++ b/flymyai/core/models/error_responses.py @@ -1,13 +1,7 @@ import dataclasses import json -from typing import Optional import httpx -import pydantic -from pydantic import PrivateAttr - -from flymyai.core._response import FlyMyAIResponse -from flymyai.core.types.event_types import EventType @dataclasses.dataclass @@ -90,68 +84,3 @@ def to_msg(self): if detail := jsoned.get("detail"): msg += f"Details: {detail}" return msg - - -class BaseFromServer(pydantic.BaseModel): - _response: FlyMyAIResponse = PrivateAttr() - - @property - def response(self): - return self._response - - @classmethod - def from_response(cls, response: FlyMyAIResponse, **kwargs): - status_code = kwargs.pop("status", response.status_code) - response_json = response.json() - response_json["status"] = response_json.get("status", status_code) - self = cls(**response_json, **kwargs) - self._response = response - return self - - -class PredictionResponse(BaseFromServer): - """ - Prediction response from FlyMyAI - """ - - exc_history: Optional[list] - output_data: dict - status: int - - inference_time: Optional[float] = None - - @property - def response(self): - return self._response - - -class OpenAPISchemaResponse(BaseFromServer): - """ - OpenAPI schema for the current project. Use it to construct your own schema - """ - - exc_history: Optional[list] - openapi_schema: dict - status: int - - -class PredictionPartial(BaseFromServer): - status: int - output_data: Optional[dict] = None - - _response: FlyMyAIResponse = PrivateAttr() - - -class PredictionEvent(BaseFromServer): - status: int - event_type: EventType - - prediction_id: Optional[str] = None # EventType.STREAM_ID - - -class StreamDetails(pydantic.BaseModel): - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - size_in_billions: Optional[float] = pydantic.Field( - default=None, alias="model_size_in_billions" - ) diff --git a/flymyai/core/models/successful_responses.py b/flymyai/core/models/successful_responses.py new file mode 100644 index 0000000..f645782 --- /dev/null +++ b/flymyai/core/models/successful_responses.py @@ -0,0 +1,72 @@ +from typing import Optional + +import pydantic +from pydantic import PrivateAttr + +from flymyai.core._response import FlyMyAIResponse +from flymyai.core.types.event_types import EventType + + +class BaseFromServer(pydantic.BaseModel): + _response: FlyMyAIResponse = PrivateAttr() + + @property + def response(self): + return self._response + + @classmethod + def from_response(cls, response: FlyMyAIResponse, **kwargs): + status_code = kwargs.pop("status", response.status_code) + response_json = response.json() + response_json["status"] = response_json.get("status", status_code) + self = cls(**response_json, **kwargs) + self._response = response + return self + + +class PredictionResponse(BaseFromServer): + """ + Prediction response from FlyMyAI + """ + + exc_history: Optional[list] + output_data: dict + status: int + + inference_time: Optional[float] = None + + @property + def response(self): + return self._response + + +class OpenAPISchemaResponse(BaseFromServer): + """ + OpenAPI schema for the current project. Use it to construct your own schema + """ + + exc_history: Optional[list] + openapi_schema: dict + status: int + + +class PredictionPartial(BaseFromServer): + status: int + output_data: Optional[dict] = None + + _response: FlyMyAIResponse = PrivateAttr() + + +class PredictionEvent(BaseFromServer): + status: int + event_type: EventType + + prediction_id: Optional[str] = None # EventType.STREAM_ID + + +class StreamDetails(pydantic.BaseModel): + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + size_in_billions: Optional[float] = pydantic.Field( + default=None, alias="model_size_in_billions" + ) diff --git a/flymyai/core/stream_iterators/AsyncPredictionStream.py b/flymyai/core/stream_iterators/AsyncPredictionStream.py index 3cd4bb3..433a15d 100644 --- a/flymyai/core/stream_iterators/AsyncPredictionStream.py +++ b/flymyai/core/stream_iterators/AsyncPredictionStream.py @@ -2,8 +2,15 @@ from typing import AsyncIterator, TypeVar, Callable, Union, Awaitable from flymyai.core._response import FlyMyAIResponse +from flymyai.core.authorizations import APIKeyClientInfo +from flymyai.core.clients.base_client import BaseClient from flymyai.core.exceptions import BaseFlyMyAIException -from flymyai.core.models import StreamDetails, PredictionPartial, PredictionEvent +from flymyai.core.models.successful_responses import ( + StreamDetails, + PredictionPartial, + PredictionEvent, +) +from flymyai.core.stream_iterators.exceptions import StreamCancellationException from flymyai.core.types.event_types import EventType @@ -20,10 +27,29 @@ class AsyncPredictionStream: event_callback: _AsyncEventCallbackType = None + prediction_id: str + follow_cancelling: bool = True - def __init__(self, response_iterator: AsyncIterator): + _client: BaseClient + _client_info: APIKeyClientInfo + + def __init__( + self, + response_iterator: AsyncIterator, + client: BaseClient, + client_info: APIKeyClientInfo, + ): self.response_iterator = response_iterator + self._client = client + self._client_info = client_info + + async def cancel(self): + if not hasattr(self, "prediction_id"): + raise StreamCancellationException("No prediction_id obtained!") + return await self._client.cancel_prediction( + self.prediction_id, client_info=self._client_info + ) def __aiter__(self): return self @@ -48,11 +74,10 @@ async def loop_iter(self): asyncio.run_coroutine_threadsafe( coro_or_res, asyncio.get_event_loop() ) - if ( - self.follow_cancelling - and evt.event_type == EventType.CANCELLING - ): - raise StopAsyncIteration + if self.follow_cancelling and evt.event_type == EventType.CANCELLING: + raise StopAsyncIteration + if evt.event_type == EventType.STREAM_ID: + self.prediction_id = evt.prediction_id async def __anext__(self): response_end = None diff --git a/flymyai/core/stream_iterators/PredictionStream.py b/flymyai/core/stream_iterators/PredictionStream.py index 53c114b..78dd1b3 100644 --- a/flymyai/core/stream_iterators/PredictionStream.py +++ b/flymyai/core/stream_iterators/PredictionStream.py @@ -1,8 +1,15 @@ from typing import Optional, Iterator, TypeVar, Callable from flymyai.core._response import FlyMyAIResponse +from flymyai.core.authorizations import APIKeyClientInfo +from flymyai.core.clients.base_client import BaseClient from flymyai.core.exceptions import BaseFlyMyAIException -from flymyai.core.models import StreamDetails, PredictionPartial, PredictionEvent +from flymyai.core.models.successful_responses import ( + StreamDetails, + PredictionPartial, + PredictionEvent, +) +from flymyai.core.stream_iterators.exceptions import StreamCancellationException from flymyai.core.types.event_types import EventType @@ -13,12 +20,29 @@ class PredictionStream: stream_details: StreamDetails + event_callback: _SyncEventCallbackType = None + prediction_id: str + follow_cancelling: bool = True - event_callback: Optional[_SyncEventCallbackType] = None - follow_cancelling = True + _client: BaseClient + _client_info: APIKeyClientInfo - def __init__(self, response_iterator: Iterator): + def __init__( + self, + response_iterator: Iterator, + client: BaseClient, + client_info: APIKeyClientInfo, + ): self.response_iterator = response_iterator + self._client = client + self._client_info = client_info + + def cancel(self): + if not hasattr(self, "prediction_id"): + raise StreamCancellationException("No prediction_id obtained!") + return self._client.cancel_prediction( + self.prediction_id, client_info=self._client_info + ) def set_on_event(self, callback: _SyncEventCallbackType): self.event_callback = callback @@ -35,15 +59,12 @@ def loop_iter(self): return response_end else: evt = PredictionEvent.from_response(next_resp) - if not self.event_callback: - pass - else: + if evt.event_type == EventType.STREAM_ID: + self.prediction_id = evt.prediction_id + if self.event_callback: self.event_callback(evt) - if ( - self.follow_cancelling - and evt.event_type == EventType.CANCELLING - ): - raise StopIteration + if self.follow_cancelling and evt.event_type == EventType.CANCELLING: + raise StopIteration def __next__(self): response_end = None @@ -53,6 +74,8 @@ def __next__(self): except BaseFlyMyAIException as e: response_end = e.response raise e + except Exception as e: + raise e finally: if not response_end: raise StopIteration() diff --git a/flymyai/core/stream_iterators/exceptions.py b/flymyai/core/stream_iterators/exceptions.py new file mode 100644 index 0000000..1875d91 --- /dev/null +++ b/flymyai/core/stream_iterators/exceptions.py @@ -0,0 +1,2 @@ +class StreamCancellationException(Exception): + ... diff --git a/tests/test_stream.py b/tests/test_stream.py index 7c723ee..32a39d2 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,8 +1,12 @@ +import asyncio import os +import threading import pytest from flymyai import client as sync_client, async_client +from flymyai.core.models.successful_responses import PredictionEvent +from flymyai.core.types.event_types import EventType from tests.FixtureFactory import FixtureFactory @@ -31,7 +35,7 @@ def output_field(): def test_stream(stream_auth, stream_payload, dsn, output_field): stream_iterator = sync_client(**stream_auth).stream(stream_payload) - stream_iterator.follow_cancelling = False + stream_iterator.follow_cancelling = True stream_iterator.set_on_event(print) try: for response in stream_iterator: @@ -51,7 +55,7 @@ def test_stream(stream_auth, stream_payload, dsn, output_field): @pytest.mark.asyncio async def test_async_stream(stream_auth, stream_payload, dsn, output_field): stream_iterator = async_client(**stream_auth).stream(stream_payload) - stream_iterator.follow_cancelling = False + stream_iterator.follow_cancelling = True stream_iterator.set_on_event(print) try: async for response in stream_iterator: @@ -66,3 +70,126 @@ async def test_async_stream(stream_auth, stream_payload, dsn, output_field): finally: print() print(getattr(stream_iterator, "stream_details", None)) + + +def test_stream_cancel(stream_auth, stream_payload, dsn, output_field): + stream_iterator = sync_client(**stream_auth).stream(stream_payload) + stream_iterator.follow_cancelling = False + cancelling_obtained = threading.Event() + + def cancel_callback(event: PredictionEvent): + if event.event_type == EventType.STREAM_ID: + stream_iterator.cancel() + else: + cancelling_obtained.set() + + stream_iterator.set_on_event(cancel_callback) + try: + for response in stream_iterator: + assert response.status == 200 + assert response.output_data or hasattr(stream_iterator, "stream_details") + if response.output_data.get(output_field): + print(response.output_data[output_field].pop(), end="") + except Exception as e: + if hasattr(e, "msg"): + print(e) + raise e + finally: + print() + assert cancelling_obtained.is_set() + print(getattr(stream_iterator, "stream_details", None)) + + +def test_cancel_with_client(stream_auth, stream_payload, dsn, output_field): + client = sync_client(**stream_auth) + stream_iterator = client.stream(stream_payload) + stream_iterator.follow_cancelling = True + cancelling_obtained = threading.Event() + + def cancel_callback(event: PredictionEvent): + if event.event_type == EventType.STREAM_ID: + client.cancel_prediction( + stream_iterator.prediction_id, model=stream_auth["model"] + ) + else: + cancelling_obtained.set() + + stream_iterator.set_on_event(cancel_callback) + try: + for response in stream_iterator: + assert response.status == 200 + assert response.output_data or hasattr(stream_iterator, "stream_details") + if response.output_data.get(output_field): + print(response.output_data[output_field].pop(), end="") + except Exception as e: + if hasattr(e, "msg"): + print(e) + raise e + finally: + print() + assert cancelling_obtained.is_set() + print(getattr(stream_iterator, "stream_details", None)) + + +@pytest.mark.asyncio +async def test_async_stream_cancel(stream_auth, stream_payload, dsn, output_field): + client = async_client(**stream_auth) + stream_iterator = client.stream(stream_payload) + stream_iterator.follow_cancelling = False + cancelling_obtained = asyncio.Event() + + async def cancel_callback(event: PredictionEvent): + if event.event_type == EventType.STREAM_ID: + await stream_iterator.cancel() + else: + cancelling_obtained.set() + + stream_iterator.set_on_event(cancel_callback) + try: + async for response in stream_iterator: + assert response.status == 200 + assert response.output_data or hasattr(stream_iterator, "stream_details") + if response.output_data.get(output_field): + print(response.output_data[output_field].pop(), end="") + except Exception as e: + if hasattr(e, "msg"): + print(e) + raise e + finally: + print() + assert cancelling_obtained.is_set() + print(getattr(stream_iterator, "stream_details", None)) + + +@pytest.mark.asyncio +async def test_async_stream_cancel_with_client( + stream_auth, stream_payload, dsn, output_field +): + client = async_client(**stream_auth) + stream_iterator = client.stream(stream_payload) + stream_iterator.follow_cancelling = False + cancelling_obtained = asyncio.Event() + + async def cancel_callback(event: PredictionEvent): + if event.event_type == EventType.STREAM_ID: + await client.cancel_prediction( + stream_iterator.prediction_id, model=stream_auth["model"] + ) + else: + cancelling_obtained.set() + + stream_iterator.set_on_event(cancel_callback) + try: + async for response in stream_iterator: + assert response.status == 200 + assert response.output_data or hasattr(stream_iterator, "stream_details") + if response.output_data.get(output_field): + print(response.output_data[output_field].pop(), end="") + except Exception as e: + if hasattr(e, "msg"): + print(e) + raise e + finally: + print() + assert cancelling_obtained.is_set() + print(getattr(stream_iterator, "stream_details", None)) From ac8de9d7398f4f1e86db2ca72d428eba234c0f13 Mon Sep 17 00:00:00 2001 From: D1-3105 Date: Fri, 9 Aug 2024 00:30:22 +0300 Subject: [PATCH 5/5] tests refactor for cicd --- tests/test_flymyai_client.py | 8 +++++++- tests/test_multipart.py | 17 ----------------- tests/test_stream.py | 8 +++++++- 3 files changed, 14 insertions(+), 19 deletions(-) delete mode 100644 tests/test_multipart.py diff --git a/tests/test_flymyai_client.py b/tests/test_flymyai_client.py index a24f5b7..cda0832 100644 --- a/tests/test_flymyai_client.py +++ b/tests/test_flymyai_client.py @@ -33,7 +33,13 @@ def fake_payload_fixture(binary_file_paths) -> dict: @pytest.fixture def client_auth_fixture() -> dict: - return factory("client_auth_fixture") + client_auth_fixture: dict = factory("client_auth_fixture") + auth_apikey_env = client_auth_fixture.pop("apikey_environ", "") + if auth_apikey_env: + client_auth_fixture["apikey"] = os.getenv( + auth_apikey_env, client_auth_fixture.get("apikey") + ) + return client_auth_fixture def test_flymyai_client(address_fixture, fake_payload_fixture, client_auth_fixture): diff --git a/tests/test_multipart.py b/tests/test_multipart.py deleted file mode 100644 index 5264b8e..0000000 --- a/tests/test_multipart.py +++ /dev/null @@ -1,17 +0,0 @@ -import pytest -from flymyai.multipart import MultipartPayload - -from .FixtureFactory import FixtureFactory - -factory = FixtureFactory(__file__) - - -@pytest.fixture -def multiparts(): - return factory("multiparts") - - -def test_multipart_payload(multiparts): - for payload in multiparts: - files = MultipartPayload(payload[0]).serialize().get("files") - assert bool(files) == payload[1] diff --git a/tests/test_stream.py b/tests/test_stream.py index 587d804..4c7c624 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -25,7 +25,13 @@ def stream_payload(): @pytest.fixture def stream_auth(): - return factory("auth") + client_auth_fixture = factory("auth") + auth_apikey_env = client_auth_fixture.pop("apikey_environ", "") + if auth_apikey_env: + client_auth_fixture["apikey"] = os.getenv( + auth_apikey_env, client_auth_fixture.get("apikey") + ) + return client_auth_fixture @pytest.fixture