diff --git a/README.md b/README.md index bb61208..d057392 100644 --- a/README.md +++ b/README.md @@ -47,9 +47,9 @@ import flymyai response = flymyai.run( apikey="fly-secret-key", model="flymyai/bert", - payload={"i_text": "What a fabulous fancy building! It looks like a palace!"} + payload={"text": "What a fabulous fancy building! It looks like a palace!"} ) -print(response.output_data["o_logits"][0]) +print(response.output_data["logits"][0]) ``` @@ -62,58 +62,65 @@ from flymyai import client, FlyMyAIPredictException fma_client = client(apikey="fly-secret-key") +stream_iterator = fma_client.stream( + payload={ + "prompt": "tell me a story about christmas tree", + "best_of": 12, + "max_tokens": 1024, + "stop": 1, + "temperature": 1, + "top_k": 1, + "top_p": "0.95", + }, + model="flymyai/llama-v3-8b" +) try: - stream_iterator = fma_client.stream( - payload={ - "i_prompt": "tell me a story about christmas tree", - "i_best_of": 12, - "i_max_tokens": 1024, - "i_stop": 1, - "i_temperature": 1, - "i_top_k": 1, - "i_top_p": "0.95", - }, - model="flymyai/llama3" - ) for response in stream_iterator: - print(response.output_data["o_output"].pop(), end="") + if response.output_data.get("output"): + print(response.output_data["output"].pop(), end="") except FlyMyAIPredictException as e: print(e) raise e finally: print() + print(stream_iterator.stream_details) ``` ## Async Streams For llms you should use stream method #### Stable Code Instruct 3b + ```python -from flymyai import async_client, FlyMyAIPredictException import asyncio +from flymyai import async_client, FlyMyAIPredictException + + async def run_stable_code(): fma_client = async_client(apikey="fly-secret-key") + stream_iterator = fma_client.stream( + payload={ + "prompt": "What's the difference between an iterator and a generator in Python?", + "best_of": 12, + "max_tokens": 512, + "stop": 1, + "temperature": 1, + "top_k": 1, + "top_p": "0.95", + }, + model="flymyai/Stable-Code-Instruct-3b" + ) try: - stream_iterator 
= fma_client.stream( - payload={ - "i_prompt": "What's the difference between an iterator and a generator in Python?", - "i_best_of": 12, - "i_max_tokens": 512, - "i_stop": 1, - "i_temperature": 1, - "i_top_k": 1, - "i_top_p": "0.95", - }, - model="flymyai/Stable-Code-Instruct-3b" - ) async for response in stream_iterator: - print(response.output_data["o_output"].pop(), end="") + if response.output_data.get("output"): + print(response.output_data["output"].pop(), end="") except FlyMyAIPredictException as e: print(e) raise e finally: print() + print(stream_iterator.stream_details) asyncio.run(run_stable_code()) @@ -126,15 +133,16 @@ asyncio.run(run_stable_code()) You can pass file inputs to models using file paths: ```python -import flymyai import pathlib +import flymyai + response = flymyai.run( apikey="fly-secret-key", model="flymyai/resnet", - payload={"i_image": pathlib.Path("/path/to/image.png")} + payload={"image": pathlib.Path("/path/to/image.png")} ) -print(response.output_data["o_495"]) +print(response.output_data["495"]) ``` @@ -150,10 +158,10 @@ response = flymyai.run( apikey="fly-secret-key", model="flymyai/SDTurboFMAAceleratedH100", payload={ - "i_prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic, photorealistic", + "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic, photorealistic", } ) -base64_image = response.output_data["o_sample"][0] +base64_image = response.output_data["sample"][0] image_data = base64.b64decode(base64_image) with open("generated_image.jpg", "wb") as file: file.write(image_data) @@ -171,10 +179,10 @@ import flymyai async def main(): payloads = [ { - "i_prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic, photorealistic", - "i_negative_prompt": "Dark colors, gloomy atmosphere, horror", - "i_seed": count, - "i_denoising_steps": 4, + "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic, photorealistic", + "negative_prompt": "Dark colors, gloomy atmosphere, 
horror", + "seed": count, + "denoising_steps": 4, "scheduler": "DPM++ SDE" } for count in range(1, 10) @@ -192,7 +200,7 @@ async def main(): ] results = await asyncio.gather(*tasks) for result in results: - print(result.output_data["o_output"]) + print(result.output_data["output"]) asyncio.run(main()) @@ -208,13 +216,13 @@ import pathlib async def background_task(): - payload = {"i_audio": pathlib.Path("/path/to/audio.mp3")} + payload = {"audio": pathlib.Path("/path/to/audio.mp3")} response = await flymyai.async_run( apikey="fly-secret-key", model="flymyai/whisper", payload=payload ) - print("Background task completed:", response.output_data["o_transcription"]) + print("Background task completed:", response.output_data["transcription"]) async def main(): diff --git a/flymyai/core/_client.py b/flymyai/core/_client.py index c758fc7..88d469b 100644 --- a/flymyai/core/_client.py +++ b/flymyai/core/_client.py @@ -14,6 +14,7 @@ import httpx +from flymyai.core._response import FlyMyAIResponse from flymyai.core._response_factory import ResponseFactory from flymyai.core._streaming import SSEDecoder from flymyai.core.authorizations import APIKeyClientInfo @@ -28,6 +29,7 @@ PredictionResponse, OpenAPISchemaResponse, PredictionPartial, + StreamDetails, ) from flymyai.multipart.payload import MultipartPayload from flymyai.utils.utils import retryable_callback, aretryable_callback @@ -168,6 +170,34 @@ def _construct_client(self): raise NotImplemented +class PredictionStream: + stream_details: StreamDetails + + def __init__(self, response_iterator: Iterator): + self.response_iterator = response_iterator + + def __iter__(self): + return self + + def __next__(self): + response_end = None + try: + next_resp: FlyMyAIResponse = self.response_iterator.__next__() + response_end = next_resp + return PredictionPartial.from_response(response_end) + except BaseFlyMyAIException as e: + response_end = e.response + raise e + finally: + if not response_end: + raise StopIteration() + 
stream_details_marshalled = response_end.json().get("stream_details") + if stream_details_marshalled: + self.stream_details = StreamDetails.model_validate( + stream_details_marshalled + ) + + class BaseSyncClient(BaseClient[httpx.Client]): def _construct_client(self): return httpx.Client( @@ -251,13 +281,8 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict): def stream(self, payload: dict, model: Optional[str] = None): stream_iter = self._stream(self.amend_client_info(model), payload) - last_response = None - for response in stream_iter: - response.stream = stream_iter - yield PredictionPartial.from_response(response) - last_response = response - if last_response: - last_response.is_stream_consumed = True + stream_wrapper = PredictionStream(stream_iter) + return stream_wrapper def _openapi_schema(self, client_info: APIKeyClientInfo): """ @@ -307,6 +332,34 @@ def run_predict(cls, apikey: str, model: str, payload: dict): return client.predict(payload) +class AsyncPredictionStream: + stream_details: StreamDetails + + def __init__(self, response_iterator: AsyncIterator): + self.response_iterator = response_iterator + + def __aiter__(self): + return self + + async def __anext__(self): + response_end = None + try: + next_resp: FlyMyAIResponse = await self.response_iterator.__anext__() + response_end = next_resp + return PredictionPartial.from_response(response_end) + except BaseFlyMyAIException as e: + response_end = e.response + raise e + finally: + if not response_end: + raise StopAsyncIteration() + stream_details_marshalled = response_end.json().get("stream_details") + if stream_details_marshalled: + self.stream_details = StreamDetails.model_validate( + stream_details_marshalled + ) + + class BaseAsyncClient(BaseClient[httpx.AsyncClient]): def _construct_client(self): return httpx.AsyncClient( @@ -430,17 +483,10 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict): raise FlyMyAIPredictException.from_response(e.response) yield 
response - async def stream( - self, payload: dict, model: Optional[str] = None, max_retries=None - ): + def stream(self, payload: dict, model: Optional[str] = None, max_retries=None): stream_iter = self._stream(self.amend_client_info(model), payload) - last_response = None - async for response in stream_iter: - response.stream = stream_iter - yield PredictionPartial.from_response(response) - last_response = response - if last_response: - last_response.is_stream_consumed = True + stream_wrapper = AsyncPredictionStream(stream_iter) + return stream_wrapper @staticmethod async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response]]): diff --git a/flymyai/core/_response.py b/flymyai/core/_response.py index 545dd4d..551f646 100644 --- a/flymyai/core/_response.py +++ b/flymyai/core/_response.py @@ -1,3 +1,6 @@ +import json +import typing + import httpx @@ -10,3 +13,10 @@ def from_httpx(cls, response: httpx.Response): request=response.request, headers=response.headers, ) + + def json(self, **kwargs) -> typing.Any: + if self.content.startswith(b"data"): + trail_content = self.content[len(b"data") :] + return json.loads(trail_content) + else: + return super().json(**kwargs) diff --git a/flymyai/core/_streaming.py b/flymyai/core/_streaming.py index db38ff4..ccef6d6 100644 --- a/flymyai/core/_streaming.py +++ b/flymyai/core/_streaming.py @@ -162,6 +162,6 @@ def decode(self, line: str) -> ServerSentEvent | None: except (TypeError, ValueError): pass else: - pass # Field is ignored. + pass # the field is ignored. 
return None diff --git a/flymyai/core/models.py b/flymyai/core/models.py index 8e14ac5..6a2fdc1 100644 --- a/flymyai/core/models.py +++ b/flymyai/core/models.py @@ -139,3 +139,9 @@ class PredictionPartial(BaseFromServer): output_data: Optional[dict] = None _response: FlyMyAIResponse = PrivateAttr() + + +class StreamDetails(pydantic.BaseModel): + input_tokens: int + output_tokens: int + size_in_billions: float = pydantic.Field(alias="model_size_in_billions") diff --git a/tests/test_stream.py b/tests/test_stream.py index b867301..8adb1ba 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -31,24 +31,34 @@ def output_field(): def test_stream(stream_auth, stream_payload, dsn, output_field): stream_iterator = sync_client(**stream_auth).stream(stream_payload) - for response in stream_iterator: - assert response.status == 200 - assert response.output_data - print(response.output_data[output_field].pop(), end="") - print("\n") + try: + for response in stream_iterator: + assert response.status == 200 + assert response.output_data or hasattr(stream_iterator, "stream_details") + if response.output_data and response.output_data.get(output_field): + print(response.output_data[output_field].pop(), end="") + except Exception as e: + if hasattr(e, "msg"): + print(e) + raise e + finally: + print() + print(stream_iterator.stream_details) @pytest.mark.asyncio async def test_async_stream(stream_auth, stream_payload, dsn, output_field): + stream_iterator = async_client(**stream_auth).stream(stream_payload) try: - stream_iterator = async_client(**stream_auth).stream(stream_payload) async for response in stream_iterator: assert response.status == 200 - assert response.output_data - print(response.output_data[output_field].pop(), end="") + assert response.output_data or hasattr(stream_iterator, "stream_details") + if response.output_data and response.output_data.get(output_field): + print(response.output_data[output_field].pop(), end="") except Exception as e: if hasattr(e, "msg"): print(e) raise e finally: print() + 
print(stream_iterator.stream_details)