README.md (90 changes: 49 additions & 41 deletions)
@@ -47,9 +47,9 @@ import flymyai
response = flymyai.run(
    apikey="fly-secret-key",
    model="flymyai/bert",
    payload={"i_text": "What a fabulous fancy building! It looks like a palace!"}
    payload={"text": "What a fabulous fancy building! It looks like a palace!"}
)
print(response.output_data["o_logits"][0])
print(response.output_data["logits"][0])
```


@@ -62,58 +62,65 @@ from flymyai import client, FlyMyAIPredictException

fma_client = client(apikey="fly-secret-key")

stream_iterator = fma_client.stream(
    payload={
        "prompt": "tell me a story about christmas tree",
        "best_of": 12,
        "max_tokens": 1024,
        "stop": 1,
        "temperature": 1,
        "top_k": 1,
        "top_p": "0.95",
    },
    model="flymyai/llama-v3-8b"
)
try:
    stream_iterator = fma_client.stream(
        payload={
            "i_prompt": "tell me a story about christmas tree",
            "i_best_of": 12,
            "i_max_tokens": 1024,
            "i_stop": 1,
            "i_temperature": 1,
            "i_top_k": 1,
            "i_top_p": "0.95",
        },
        model="flymyai/llama3"
    )
    for response in stream_iterator:
        print(response.output_data["o_output"].pop(), end="")
        if response.output_data.get("output"):
            print(response.output_data["output"].pop(), end="")
except FlyMyAIPredictException as e:
    print(e)
    raise e
finally:
    print()
    print(stream_iterator.stream_details)
```
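
When the stream finishes, `stream_iterator.stream_details` holds the usage statistics parsed from the final event. Here is a minimal sketch of reading it, assuming the fields declared in `flymyai/core/models.py` (the numbers you get back will vary):

```python
details = stream_iterator.stream_details
# StreamDetails exposes input_tokens, output_tokens and size_in_billions
print(f"prompt tokens: {details.input_tokens}")
print(f"completion tokens: {details.output_tokens}")
print(f"model size: {details.size_in_billions}B parameters")
```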

## Async Streams
For LLMs, use the `stream` method.

#### Stable Code Instruct 3b

```python
from flymyai import async_client, FlyMyAIPredictException
import asyncio

from flymyai import async_client, FlyMyAIPredictException


async def run_stable_code():
    fma_client = async_client(apikey="fly-secret-key")
    stream_iterator = fma_client.stream(
        payload={
            "prompt": "What's the difference between an iterator and a generator in Python?",
            "best_of": 12,
            "max_tokens": 512,
            "stop": 1,
            "temperature": 1,
            "top_k": 1,
            "top_p": "0.95",
        },
        model="flymyai/Stable-Code-Instruct-3b"
    )
    try:
        stream_iterator = fma_client.stream(
            payload={
                "i_prompt": "What's the difference between an iterator and a generator in Python?",
                "i_best_of": 12,
                "i_max_tokens": 512,
                "i_stop": 1,
                "i_temperature": 1,
                "i_top_k": 1,
                "i_top_p": "0.95",
            },
            model="flymyai/Stable-Code-Instruct-3b"
        )
        async for response in stream_iterator:
            print(response.output_data["o_output"].pop(), end="")
            if response.output_data.get("output"):
                print(response.output_data["output"].pop(), end="")
    except FlyMyAIPredictException as e:
        print(e)
        raise e
    finally:
        print()
        print(stream_iterator.stream_details)


asyncio.run(run_stable_code())
@@ -126,15 +133,16 @@ asyncio.run(run_stable_code())
You can pass file inputs to models using file paths:

```python
import flymyai
import pathlib

import flymyai

response = flymyai.run(
    apikey="fly-secret-key",
    model="flymyai/resnet",
    payload={"i_image": pathlib.Path("/path/to/image.png")}
    payload={"image": pathlib.Path("/path/to/image.png")}
)
print(response.output_data["o_495"])
print(response.output_data["495"])
```


@@ -150,10 +158,10 @@ response = flymyai.run(
    apikey="fly-secret-key",
    model="flymyai/SDTurboFMAAceleratedH100",
    payload={
        "i_prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic, photorealistic",
        "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic, photorealistic",
    }
)
base64_image = response.output_data["o_sample"][0]
base64_image = response.output_data["sample"][0]
image_data = base64.b64decode(base64_image)
with open("generated_image.jpg", "wb") as file:
    file.write(image_data)
@@ -171,10 +179,10 @@ import flymyai
async def main():
    payloads = [
        {
            "i_prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic, photorealistic",
            "i_negative_prompt": "Dark colors, gloomy atmosphere, horror",
            "i_seed": count,
            "i_denoising_steps": 4,
            "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic, photorealistic",
            "negative_prompt": "Dark colors, gloomy atmosphere, horror",
            "seed": count,
            "denoising_steps": 4,
            "scheduler": "DPM++ SDE"
        }
        for count in range(1, 10)
@@ -192,7 +200,7 @@ async def main():
    ]
    results = await asyncio.gather(*tasks)
    for result in results:
        print(result.output_data["o_output"])
        print(result.output_data["output"])


asyncio.run(main())
@@ -208,13 +216,13 @@ import pathlib


async def background_task():
    payload = {"i_audio": pathlib.Path("/path/to/audio.mp3")}
    payload = {"audio": pathlib.Path("/path/to/audio.mp3")}
    response = await flymyai.async_run(
        apikey="fly-secret-key",
        model="flymyai/whisper",
        payload=payload
    )
    print("Background task completed:", response.output_data["o_transcription"])
    print("Background task completed:", response.output_data["transcription"])


async def main():
flymyai/core/_client.py (80 changes: 63 additions & 17 deletions)
@@ -14,6 +14,7 @@

import httpx

from flymyai.core._response import FlyMyAIResponse
from flymyai.core._response_factory import ResponseFactory
from flymyai.core._streaming import SSEDecoder
from flymyai.core.authorizations import APIKeyClientInfo
@@ -28,6 +29,7 @@
    PredictionResponse,
    OpenAPISchemaResponse,
    PredictionPartial,
    StreamDetails,
)
from flymyai.multipart.payload import MultipartPayload
from flymyai.utils.utils import retryable_callback, aretryable_callback
@@ -168,6 +170,34 @@ def _construct_client(self):
        raise NotImplemented


class PredictionStream:
    stream_details: StreamDetails

    def __init__(self, response_iterator: Iterator):
        self.response_iterator = response_iterator

    def __iter__(self):
        return self

    def __next__(self):
        response_end = None
        try:
            next_resp: FlyMyAIResponse = self.response_iterator.__next__()
            response_end = next_resp
            return PredictionPartial.from_response(response_end)
        except BaseFlyMyAIException as e:
            response_end = e.response
            raise e
        finally:
            if not response_end:
                raise StopIteration()
            stream_details_marshalled = response_end.json().get("stream_details")
            if stream_details_marshalled:
                self.stream_details = StreamDetails.model_validate(
                    stream_details_marshalled
                )


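Reviewer note: a minimal sketch of the contract this wrapper adds, using a hypothetical `make_chunks()` helper that is not part of this diff. Iteration yields `PredictionPartial` objects; once a frame carrying `stream_details` has been decoded, the parsed `StreamDetails` is readable on the wrapper:

```python
# Hypothetical: make_chunks() returns an iterator of FlyMyAIResponse frames,
# the last of which embeds a "stream_details" object in its JSON body.
stream = PredictionStream(make_chunks())
for partial in stream:
    print(partial.output_data)    # each item is a PredictionPartial
print(stream.stream_details)      # populated once a details-bearing frame arrives
```
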
class BaseSyncClient(BaseClient[httpx.Client]):
    def _construct_client(self):
        return httpx.Client(
@@ -251,13 +281,8 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict):

    def stream(self, payload: dict, model: Optional[str] = None):
        stream_iter = self._stream(self.amend_client_info(model), payload)
        last_response = None
        for response in stream_iter:
            response.stream = stream_iter
            yield PredictionPartial.from_response(response)
            last_response = response
        if last_response:
            last_response.is_stream_consumed = True
        stream_wrapper = PredictionStream(stream_iter)
        return stream_wrapper

    def _openapi_schema(self, client_info: APIKeyClientInfo):
        """
@@ -307,6 +332,34 @@ def run_predict(cls, apikey: str, model: str, payload: dict):
        return client.predict(payload)


class AsyncPredictionStream:
    stream_details: StreamDetails

    def __init__(self, response_iterator: AsyncIterator):
        self.response_iterator = response_iterator

    def __aiter__(self):
        return self

    async def __anext__(self):
        response_end = None
        try:
            next_resp: FlyMyAIResponse = await self.response_iterator.__anext__()
            response_end = next_resp
            return PredictionPartial.from_response(response_end)
        except BaseFlyMyAIException as e:
            response_end = e.response
            raise e
        finally:
            if not response_end:
                raise StopAsyncIteration()
            stream_details_marshalled = response_end.json().get("stream_details")
            if stream_details_marshalled:
                self.stream_details = StreamDetails.model_validate(
                    stream_details_marshalled
                )


class BaseAsyncClient(BaseClient[httpx.AsyncClient]):
    def _construct_client(self):
        return httpx.AsyncClient(
@@ -430,17 +483,10 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict):
            raise FlyMyAIPredictException.from_response(e.response)
        yield response

    async def stream(
        self, payload: dict, model: Optional[str] = None, max_retries=None
    ):
    def stream(self, payload: dict, model: Optional[str] = None, max_retries=None):
        stream_iter = self._stream(self.amend_client_info(model), payload)
        last_response = None
        async for response in stream_iter:
            response.stream = stream_iter
            yield PredictionPartial.from_response(response)
            last_response = response
        if last_response:
            last_response.is_stream_consumed = True
        stream_wrapper = AsyncPredictionStream(stream_iter)
        return stream_wrapper

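Reviewer note: dropping `async def` here is intentional and call-site compatible. The old version was an async generator; the new one plainly returns `AsyncPredictionStream`. Neither is awaited at the call site, and both are consumed with `async for` (sketch with an illustrative payload):

```python
# Same consumption pattern before and after this change:
stream = fma_client.stream(payload={"i_prompt": "hi"}, model="flymyai/llama3")
async for partial in stream:
    ...
print(stream.stream_details)  # only the new wrapper exposes this attribute
```
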
    @staticmethod
    async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response]]):
flymyai/core/_response.py (10 changes: 10 additions & 0 deletions)
@@ -1,3 +1,6 @@
import json
import typing

import httpx


@@ -10,3 +13,10 @@ def from_httpx(cls, response: httpx.Response):
            request=response.request,
            headers=response.headers,
        )

    def json(self, **kwargs) -> typing.Any:
        # SSE frames arrive as b"data: {...}"; strip the prefix before decoding
        if self.content.startswith(b"data:"):
            trail_content = self.content[len(b"data:"):]
            return json.loads(trail_content)
        else:
            return super().json(**kwargs)
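
Reviewer note: a quick sketch of the SSE framing this override assumes, with an illustrative byte string rather than real server output:

```python
import json

raw = b'data: {"stream_details": {"input_tokens": 7, "output_tokens": 128, "model_size_in_billions": 8.0}}'
# Strip the "data:" prefix, then decode; json.loads tolerates the leading space
payload = json.loads(raw[len(b"data:"):])
print(payload["stream_details"]["output_tokens"])  # 128
```
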
flymyai/core/_streaming.py (2 changes: 1 addition & 1 deletion)
@@ -162,6 +162,6 @@ def decode(self, line: str) -> ServerSentEvent | None:
            except (TypeError, ValueError):
                pass
        else:
            pass  # Field is ignored.
            pass  # the field is ignored.

        return None
flymyai/core/models.py (6 changes: 6 additions & 0 deletions)
@@ -139,3 +139,9 @@ class PredictionPartial(BaseFromServer):
    output_data: Optional[dict] = None

    _response: FlyMyAIResponse = PrivateAttr()


class StreamDetails(pydantic.BaseModel):
    input_tokens: int
    output_tokens: int
    size_in_billions: float = pydantic.Field(alias="model_size_in_billions")
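
Reviewer note: because of the alias, the server is expected to send `model_size_in_billions`, while client code reads it back as `size_in_billions`. A minimal validation sketch with illustrative values:

```python
details = StreamDetails.model_validate(
    {"input_tokens": 7, "output_tokens": 128, "model_size_in_billions": 8.0}
)
print(details.size_in_billions)  # 8.0
```
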
tests/test_stream.py (26 changes: 18 additions & 8 deletions)
@@ -31,24 +31,34 @@ def output_field():

def test_stream(stream_auth, stream_payload, dsn, output_field):
    stream_iterator = sync_client(**stream_auth).stream(stream_payload)
    for response in stream_iterator:
        assert response.status == 200
        assert response.output_data
        print(response.output_data[output_field].pop(), end="")
    print("\n")
    try:
        for response in stream_iterator:
            assert response.status == 200
            assert response.output_data or hasattr(stream_iterator, "stream_details")
            if response.output_data and response.output_data.get(output_field):
                print(response.output_data[output_field].pop(), end="")
    except Exception as e:
        if hasattr(e, "msg"):
            print(e)
        raise e
    finally:
        print()
        print(stream_iterator.stream_details)


@pytest.mark.asyncio
async def test_async_stream(stream_auth, stream_payload, dsn, output_field):
    stream_iterator = async_client(**stream_auth).stream(stream_payload)
    try:
        stream_iterator = async_client(**stream_auth).stream(stream_payload)
        async for response in stream_iterator:
            assert response.status == 200
            assert response.output_data
            print(response.output_data[output_field].pop(), end="")
            assert response.output_data or hasattr(stream_iterator, "stream_details")
            if response.output_data and response.output_data.get(output_field):
                print(response.output_data[output_field].pop(), end="")
    except Exception as e:
        if hasattr(e, "msg"):
            print(e)
        raise e
    finally:
        print()
        print(stream_iterator.stream_details)