26 changes: 26 additions & 0 deletions .github/workflows/test.yaml
@@ -62,6 +62,7 @@ jobs:
  test_unknown_error_handle:
    needs:
      - lint
      - test_fields

    runs-on: ubuntu-latest

@@ -139,3 +140,28 @@ jobs:
          FMA_APIKEY: ${{ secrets.FMA_APIKEY }}
        run: pytest tests/test_stream.py --tb=short

  test_async_inference:
    needs:
      - lint
      - test_unknown_error_handle
    runs-on: ubuntu-latest

    container:
      image: python:3.8

    steps:
      - name: Check out git repo
        uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Fix
        run: git config --global --add safe.directory '*'

      - name: Install dependencies
        run: pip3 install poetry pytest-asyncio && poetry config virtualenvs.create false && poetry install

      - name: Test
        env:
          FMA_APIKEY: ${{ secrets.FMA_APIKEY }}
        run: pytest tests/test_async_inference.py --tb=short
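
The job above runs `tests/test_async_inference.py`. For orientation only, here is a hypothetical sketch of such a test (not the repository's actual file), using the async-task API documented in the README changes below; it assumes `pytest-asyncio` is installed (as the job does) and `FMA_APIKEY` is set in the environment:

```python
import os

import pytest

from flymyai import async_client


@pytest.mark.asyncio
async def test_async_prediction_roundtrip():
    # Submit an async prediction task and wait for its result.
    fma_client = async_client(apikey=os.environ["FMA_APIKEY"])
    task = await fma_client.predict_async_task(
        model="flymyai/flux-schnell",
        payload={"prompt": "Funny Cat with Stupid Dog"},
    )
    result = await task.result()
    assert result.inference_responses
```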
2 changes: 1 addition & 1 deletion .gitignore
@@ -142,4 +142,4 @@ dmypy.json
.idea/*
poetry.lock
venv*
tests/fixtures*
tests/fixtures_*
74 changes: 74 additions & 0 deletions README.md
@@ -233,3 +233,77 @@ asyncio.run(main())
# Continue with other operations while the model runs in the background
```


## Asynchronous Prediction Tasks

For long-running operations, FlyMyAI provides asynchronous prediction tasks: you submit a task, carry on with other work, and retrieve the result later, so time-consuming predictions don't block your application.

### Using Synchronous Client

```python
from flymyai import client
from flymyai.core.exceptions import (
    RetryTimeoutExceededException,
    FlyMyAIExceptionGroup,
)

# Initialize client
fma_client = client(apikey="fly-secret-key")

# Submit async prediction task
prediction_task = fma_client.predict_async_task(
    model="flymyai/flux-schnell",
    payload={"prompt": "Funny Cat with Stupid Dog"}
)

try:
    # Get result
    result = prediction_task.result()

    print(f"Prediction completed: {result.inference_responses}")
except RetryTimeoutExceededException:
    print("Prediction is taking longer than expected")
except FlyMyAIExceptionGroup as e:
    print(f"Prediction failed: {e}")
```
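
Continuing from the example above: if you need to bound how long the call waits, the client also exposes `prediction_task_result`, which takes the task plus an explicit timeout in seconds. A short sketch; the 120-second value is illustrative, and the expiry is expected to surface as `RetryTimeoutExceededException`, as above:

```python
# Wait at most 120 seconds for the result (illustrative value).
try:
    result = fma_client.prediction_task_result(prediction_task, timeout=120.0)
    print(result.inference_responses)
except RetryTimeoutExceededException:
    print("No result within 120 seconds; try again later")
```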

### Using Asynchronous Client

```python
import asyncio
from flymyai import async_client
from flymyai.core.exceptions import (
    RetryTimeoutExceededException,
    FlyMyAIExceptionGroup,
)

async def run_prediction():
    # Initialize async client
    fma_client = async_client(apikey="fly-secret-key")

    # Submit async prediction task
    prediction_task = await fma_client.predict_async_task(
        model="flymyai/flux-schnell",
        payload={"prompt": "Funny Cat with Stupid Dog"}
    )

    try:
        # Await result with default timeout
        result = await prediction_task.result()
        print(f"Prediction completed: {result.inference_responses}")

        # Check response status
        all_successful = all(
            resp.infer_details["status"] == 200
            for resp in result.inference_responses
        )
        print(f"All predictions successful: {all_successful}")

    except RetryTimeoutExceededException:
        print("Prediction is taking longer than expected")
    except FlyMyAIExceptionGroup as e:
        print(f"Prediction failed: {e}")

# Run async function
asyncio.run(run_prediction())
```
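
Judging from the retry parameters in this change set, the result call polls the `predict/async/result/` endpoint roughly every 0.5 seconds until a response is available or the timeout elapses.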
8 changes: 8 additions & 0 deletions flymyai/core/authorizations.py
@@ -52,6 +52,14 @@ def _project_path(self):
    def prediction_path(self):
        return self._project_path.join(httpx.URL("predict"))

    @property
    def prediction_async_path(self):
        return self._project_path.join(httpx.URL("predict/async/"))

    @property
    def prediction_result_path(self):
        return self._project_path.join(httpx.URL("predict/async/result/"))

    @property
    def prediction_cancel_path(self):
        return self._project_path.join(httpx.URL("predict/cancel/"))
58 changes: 53 additions & 5 deletions flymyai/core/clients/AsyncClient.py
@@ -3,7 +3,9 @@

import httpx

from flymyai.core._response_factory import ResponseFactory
from flymyai.core.response_factory.plain_inference_response_factory import (
    SSEInferenceResponseFactory,
)
from flymyai.core._streaming import SSEDecoder
from flymyai.core.authorizations import APIKeyClientInfo
from flymyai.core.clients.base_client import BaseClient, _predict_timeout
@@ -12,10 +14,12 @@
    FlyMyAIOpenAPIException,
    FlyMyAIPredictException,
    FlyMyAIExceptionGroup,
    FlyMyAIAsyncTaskException,
)
from flymyai.core.models.successful_responses import (
    OpenAPISchemaResponse,
    PredictionResponse,
    AsyncPredictionTask,
)
from flymyai.core.stream_iterators.AsyncPredictionStream import AsyncPredictionStream
from flymyai.multipart import MultipartPayload
@@ -28,6 +32,7 @@ def _construct_client(self):
            http2=True,
            headers=self.client_info.authorization_headers,
            base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"),
            timeout=_predict_timeout,
        )

async def __aenter__(self):
Expand Down Expand Up @@ -84,7 +89,7 @@ async def cancel_prediction(
url=full_client_info.prediction_cancel_path,
json={"infer_id": prediction_id},
)
return ResponseFactory(
return SSEInferenceResponseFactory(
httpx_response=response, httpx_request=response.request
).construct()

@@ -100,7 +105,7 @@ async def _sse_instant(
async with async_response_stream() as stream:
sse = await SSEDecoder().aiter(stream.aiter_lines()).__anext__()
try:
response = ResponseFactory(
response = SSEInferenceResponseFactory(
sse=sse, httpx_request=stream.request, httpx_response=stream
).construct()
return response
@@ -145,6 +150,49 @@ async def predict(
        )
        return PredictionResponse.from_response(response, exc_history=history)

    async def predict_async_task(
        self, payload: dict, model: Optional[str] = None, max_retries=None
    ) -> AsyncPredictionTask:
        """Submit a prediction as an asynchronous task and return its handle."""
        payload = MultipartPayload(input_data=payload)
        client_info = self.amend_client_info(model)
        try:
            _, response = await aretryable_callback(
                lambda: self._client.post(
                    client_info.prediction_async_path, **payload.serialize()
                ),
                max_retries or self.max_retries,
                FlyMyAIAsyncTaskException,
                FlyMyAIExceptionGroup,
            )
            response = SSEInferenceResponseFactory(response).construct()
            return self._async_prediction_task_construct(response, client_info)
        except BaseFlyMyAIException as e:
            raise FlyMyAIAsyncTaskException.from_base_exception(e)

    async def prediction_task_result(
        self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None
    ):
        """Poll for the result of a previously submitted async prediction task."""
        prediction_id = prediction_task.prediction_id

        async def get_res():
            data_resp = await self._client.get(
                url=(
                    prediction_task.client_info or self.client_info
                ).prediction_result_path,
                params={"request_id": prediction_id},
            )
            return self._construct_task_result(data_resp)

        _, res = await aretryable_callback(
            lambda: get_res(),
            None,
            FlyMyAIAsyncTaskException,
            FlyMyAIExceptionGroup,
            timeout,
            0.5,
        )
        return res

async def _stream(self, client_info: APIKeyClientInfo, payload: dict):
payload = MultipartPayload(payload)
stream_iterator = self._stream_iterator(
@@ -154,7 +202,7 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict):
async with stream_iterator as sse_stream:
async for sse_partial in decoder.aiter(sse_stream.aiter_lines()):
try:
response = ResponseFactory(
response = SSEInferenceResponseFactory(
sse=sse_partial,
httpx_request=sse_stream.request,
httpx_response=sse_stream,
@@ -175,7 +223,7 @@ async def _wrap_request(request_callback: Callable[..., Awaitable[httpx.Response
Execute a request callback and return the response
"""
response = await request_callback()
return ResponseFactory(httpx_response=response).construct()
return SSEInferenceResponseFactory(httpx_response=response).construct()

async def close(self):
"""
59 changes: 54 additions & 5 deletions flymyai/core/clients/SyncClient.py
@@ -3,19 +3,23 @@

import httpx

from flymyai.core._response_factory import ResponseFactory
from flymyai.core._streaming import SSEDecoder
from flymyai.core.authorizations import APIKeyClientInfo
from flymyai.core.clients.base_client import BaseClient
from flymyai.core.clients.base_client import BaseClient, _predict_timeout
from flymyai.core.exceptions import (
    BaseFlyMyAIException,
    FlyMyAIOpenAPIException,
    FlyMyAIPredictException,
    FlyMyAIExceptionGroup,
    FlyMyAIAsyncTaskException,
)
from flymyai.core.models.successful_responses import (
    PredictionResponse,
    OpenAPISchemaResponse,
    AsyncPredictionTask,
)
from flymyai.core.response_factory.plain_inference_response_factory import (
    SSEInferenceResponseFactory,
)
from flymyai.core.stream_iterators.PredictionStream import PredictionStream
from flymyai.multipart import MultipartPayload
@@ -28,6 +32,7 @@ def _construct_client(self):
            http2=True,
            headers=self.client_info.authorization_headers,
            base_url=os.getenv("FLYMYAI_DSN", "https://api.flymy.ai/"),
            timeout=_predict_timeout,
        )

def __enter__(self):
@@ -45,7 +50,7 @@ def _sse_instant(cls, stream_iter_func: Callable[[], Iterator[httpx.Response]]):
"""
with stream_iter_func() as stream:
stream: httpx.Response
response = ResponseFactory(
response = SSEInferenceResponseFactory(
sse=next(SSEDecoder().iter(stream.iter_lines())),
httpx_request=stream.request,
httpx_response=stream,
@@ -85,6 +90,50 @@ def predict(self, payload: dict, model: Optional[str] = None, max_retries=None):
        )
        return PredictionResponse.from_response(response, exc_history=history)

    def predict_async_task(
        self, payload: dict, model: Optional[str] = None, max_retries=None
    ):
        """Submit a prediction as an asynchronous task and return its handle."""
        payload = MultipartPayload(input_data=payload)
        client_info = self.amend_client_info(model)
        try:
            _, response = retryable_callback(
                lambda: self._client.post(
                    client_info.prediction_async_path, **payload.serialize()
                ),
                max_retries or self.max_retries,
                FlyMyAIAsyncTaskException,
                FlyMyAIExceptionGroup,
            )
            response = SSEInferenceResponseFactory(response).construct()
            return self._async_prediction_task_construct(response, client_info)
        except BaseFlyMyAIException as e:
            raise FlyMyAIAsyncTaskException.from_base_exception(e)

    def prediction_task_result(
        self, prediction_task: AsyncPredictionTask, timeout: Optional[float] = None
    ):
        """Poll for the result of a previously submitted async prediction task."""
        prediction_id = prediction_task.prediction_id

        def get_res():
            resp = self._client.get(
                url=(
                    prediction_task.client_info or self.client_info
                ).prediction_result_path,
                params={"request_id": prediction_id},
            )
            return self._construct_task_result(resp)

        _, res = retryable_callback(
            lambda: get_res(),
            None,
            FlyMyAIAsyncTaskException,
            FlyMyAIExceptionGroup,
            timeout,
            0.5,
        )

        return res

def _stream(self, client_info: APIKeyClientInfo, payload: dict):
payload = MultipartPayload(payload)
response_iterator = self._stream_iterator(
@@ -94,7 +143,7 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict):
with response_iterator as sse_stream:
for sse_partial in decoder.iter(sse_stream.iter_lines()):
try:
response = ResponseFactory(
response = SSEInferenceResponseFactory(
sse=sse_partial,
httpx_request=sse_stream.request,
httpx_response=sse_stream,
@@ -138,7 +187,7 @@ def cancel_prediction(
url=full_client_info.prediction_cancel_path,
json={"infer_id": prediction_id},
)
return ResponseFactory(
return SSEInferenceResponseFactory(
httpx_response=response, httpx_request=response.request
).construct()
