Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -143,3 +143,4 @@ dmypy.json
.idea/*
poetry.lock
tests/fixtures*/
venv*
4 changes: 2 additions & 2 deletions README.md
@@ -24,14 +24,14 @@ This is a Python client for [FlyMyAI](https://flymy.ai). It allows you to easily

## Requirements

- Python 3.10+
- Python 3.8+

## Installation

Install the FlyMyAI client using pip:

```sh
pip install flymyai-client
pip install flymyai
```

## Authentication
41 changes: 22 additions & 19 deletions flymyai/core/_client.py
@@ -9,6 +9,7 @@
Iterator,
AsyncContextManager,
AsyncIterator,
Optional,
)

import httpx
@@ -52,15 +53,15 @@ class BaseClient(Generic[_PossibleClients]):
client_info: APIKeyClientInfo

def __init__(
self, apikey: str, model: str | None = None, max_retries=DEFAULT_RETRY_COUNT
self, apikey: str, model: Optional[str] = None, max_retries=DEFAULT_RETRY_COUNT
):
self.client_info = APIKeyClientInfo(apikey)
if model:
self.client_info = self.client_info.copy_for_model(model)
self._client = self._construct_client()
self.max_retries = max_retries

def amend_client_info(self, model: str | None = None):
def amend_client_info(self, model: Optional[str] = None):
if model:
client_info = self.client_info.copy_for_model(model)
else:
@@ -73,64 +74,64 @@ def amend_client_info(self, model: str | None = None):

@overload
async def predict(
self, payload: dict, model: str | None = None, max_retries=None
self, payload: dict, model: Optional[str] = None, max_retries=None
) -> PredictionResponse:
...

@overload
def predict(
self, payload: dict, model: str | None = None, max_retries=None
self, payload: dict, model: Optional[str] = None, max_retries=None
) -> PredictionResponse:
...

def predict(
self, payload: dict, model: str | None = None, max_retries=None
self, payload: dict, model: Optional[str] = None, max_retries=None
) -> PredictionResponse:
...

@overload
async def openapi_schema(
self, model: str | None = None, max_retries=None
self, model: Optional[str] = None, max_retries=None
) -> OpenAPISchemaResponse:
...

@overload
def openapi_schema(
self, model: str | None = None, max_retries=None
self, model: Optional[str] = None, max_retries=None
) -> OpenAPISchemaResponse:
...

def openapi_schema(
self, model: str | None = None, max_retries=None
self, model: Optional[str] = None, max_retries=None
) -> OpenAPISchemaResponse:
...

@overload
async def stream(
self,
payload: dict,
model: str | None = None,
model: Optional[str] = None,
) -> AsyncIterator[PredictionPartial]:
...

@overload
def stream(
self,
payload: dict,
model: str | None = None,
model: Optional[str] = None,
) -> Iterator[PredictionPartial]:
...

def stream(
self,
payload: dict,
model: str | None = None,
model: Optional[str] = None,
):
...

def _stream_iterator(
self, client_info, payload: MultipartPayload, is_long_stream: bool
) -> Iterator[httpx.Response] | AsyncIterator[httpx.Response]:
) -> Union[Iterator[httpx.Response], AsyncIterator[httpx.Response]]:
return self._client.stream(
method="post",
url=(
@@ -209,7 +210,7 @@ def _predict(self, payload: MultipartPayload, client_info: APIKeyClientInfo):
except BaseFlyMyAIException as e:
raise FlyMyAIPredictException.from_response(e.response)

def predict(self, payload: dict, model: str | None = None, max_retries=None):
def predict(self, payload: dict, model: Optional[str] = None, max_retries=None):
"""
Wrap predict method in sse.
Retries until max_retries or self.max_retries is reached
@@ -248,7 +249,7 @@ def _stream(self, client_info: APIKeyClientInfo, payload: dict):
raise FlyMyAIPredictException.from_response(e.response)
yield response

def stream(self, payload: dict, model: str | None = None):
def stream(self, payload: dict, model: Optional[str] = None):
stream_iter = self._stream(self.amend_client_info(model), payload)
last_response = None
for response in stream_iter:
@@ -273,7 +274,7 @@ def _openapi_schema(self, client_info: APIKeyClientInfo):
except BaseFlyMyAIException as e:
raise FlyMyAIOpenAPIException.from_response(e.response)

def openapi_schema(self, model: str | None = None, max_retries=None):
def openapi_schema(self, model: Optional[str] = None, max_retries=None):
"""
:param model: flymyai/bert
:param max_retries: retries before give up
@@ -321,7 +322,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
if hasattr(self, "_client"):
await self._client.aclose()

async def openapi_schema(self, model: str | None = None, max_retries=None):
async def openapi_schema(self, model: Optional[str] = None, max_retries=None):
"""
:param max_retries: retries before giving up
:return:
@@ -364,7 +365,7 @@ async def _sse_instant(
:return: FlyMyAIResponse
"""
async with async_response_stream() as stream:
sse = await anext(SSEDecoder().aiter(stream.aiter_lines()))
sse = await SSEDecoder().aiter(stream.aiter_lines()).__anext__()
response = ResponseFactory(
sse=sse, httpx_request=stream.request, httpx_response=stream
).construct()
Expand All @@ -390,7 +391,7 @@ def _predict(self, client_info, payload: MultipartPayload):
raise FlyMyAIPredictException.from_response(e.response)

async def predict(
self, payload: dict, model: str | None = None, max_retries=None
self, payload: dict, model: Optional[str] = None, max_retries=None
) -> PredictionResponse:
"""
Wrap predict method in sse.
@@ -429,7 +430,9 @@ async def _stream(self, client_info: APIKeyClientInfo, payload: dict):
raise FlyMyAIPredictException.from_response(e.response)
yield response

async def stream(self, payload: dict, model: str | None = None, max_retries=None):
async def stream(
self, payload: dict, model: Optional[str] = None, max_retries=None
):
stream_iter = self._stream(self.amend_client_info(model), payload)
last_response = None
async for response in stream_iter:
5 changes: 3 additions & 2 deletions flymyai/core/authorizations.py
@@ -1,5 +1,6 @@
import copy
import dataclasses
from typing import Optional

import httpx

@@ -38,8 +39,8 @@ class APIKeyClientInfo(ClientInfo):
"""

apikey: str
username: str | None = None
project_name: str | None = None
username: Optional[str] = None
project_name: Optional[str] = None

@property
def authorization_headers(self):
4 changes: 3 additions & 1 deletion flymyai/core/exceptions.py
@@ -1,3 +1,5 @@
from typing import List

from ._response import FlyMyAIResponse
from .models import (
FlyMyAI401Response,
@@ -80,7 +82,7 @@ class FlyMyAIOpenAPIException(BaseFlyMyAIException):


class FlyMyAIExceptionGroup(Exception):
def __init__(self, errors: list[BaseFlyMyAIException], **kwargs):
def __init__(self, errors: List[BaseFlyMyAIException], **kwargs):
self.errors = errors
exceptions_message = ";".join([str(err) for err in errors])
self.message = f"FlyMyAI exception history: {exceptions_message}"
9 changes: 5 additions & 4 deletions flymyai/core/models.py
@@ -1,5 +1,6 @@
import dataclasses
import json
from typing import Optional

import httpx
import pydantic
@@ -112,11 +113,11 @@ class PredictionResponse(BaseFromServer):
Prediction response from FlyMyAI
"""

exc_history: list | None
exc_history: Optional[list]
output_data: dict
status: int

inference_time: float | None = None
inference_time: Optional[float] = None

@property
def response(self):
@@ -128,13 +129,13 @@ class OpenAPISchemaResponse(BaseFromServer):
OpenAPI schema for the current project. Use it to construct your own schema
"""

exc_history: list | None
exc_history: Optional[list]
openapi_schema: dict
status: int


class PredictionPartial(BaseFromServer):
status: int
output_data: dict | None = None
output_data: Optional[dict] = None

_response: FlyMyAIResponse = PrivateAttr()
12 changes: 5 additions & 7 deletions flymyai/multipart/binary_field.py
@@ -3,13 +3,13 @@
import pathlib
import uuid
from io import BytesIO
from typing import Union, BinaryIO, Any
from typing import Union, BinaryIO, Any, Optional, Tuple

from .base_field import BaseField

_BinaryInput = Union[bytes, pathlib.Path, BinaryIO, str]
_IOOutput = Union[BinaryIO, io.BytesIO]
_FieldOutput = tuple[str, _IOOutput, str] # filename, io[binary], mime
_IOOutput = Union[BinaryIO, BytesIO]
_FieldOutput = Tuple[str, _IOOutput, str] # filename, io[binary], mime


def is_binary_input(value: _BinaryInput) -> bool:
@@ -35,7 +35,7 @@ class BinaryField(BaseField):
def __init__(self, value: _BinaryInput):
super().__init__(value)

def validate(self, value: _BinaryInput | None = None) -> None:
def validate(self, value: Optional[_BinaryInput] = None) -> None:
value = value or self.value
if not is_binary_input(value):
raise TypeError()
@@ -55,9 +55,7 @@ def to_io(value: _BinaryInput) -> _IOOutput:
)
return io_obj

def serialize(
self, value=None
) -> tuple[BinaryIO | BytesIO, str | Any, tuple[str | None, str | None] | str]:
def serialize(self, value=None) -> Tuple[Union[str, Any], _IOOutput, Optional[str]]:
value = value or self.value
io_obj = self.to_io(value)
filename = io_obj.name
2 changes: 1 addition & 1 deletion flymyai/utils/utils.py
@@ -1,4 +1,4 @@
from typing import Callable, Awaitable, Type, Any, Coroutine
from typing import Callable, Awaitable, Type

import httpx

4 changes: 3 additions & 1 deletion pyproject.toml
@@ -15,8 +15,10 @@ typing-extensions = "^4.9.0"
[tool.poetry.dev-dependencies]
python = ">=3.8"
httpx = ">=0.26.0"
pytest = "^7.4.3"

[tool.poetry.group.dev.dependencies]
tomli = "^2.0.1"
pytest-asyncio = "^0.23.7"

[build-system]
requires = ["poetry-core"]
4 changes: 3 additions & 1 deletion pyproject.toml.template
@@ -15,8 +15,10 @@ typing-extensions = "^4.9.0"
[tool.poetry.dev-dependencies]
python = ">=3.8"
httpx = ">=0.26.0"
pytest = "^7.4.3"

[tool.poetry.group.dev.dependencies]
tomli = "^2.0.1"
pytest-asyncio = "^0.23.7"

[build-system]
requires = ["poetry-core"]
3 changes: 2 additions & 1 deletion tests/FixtureFactory.py
@@ -1,12 +1,13 @@
import json
import os
import pathlib
from typing import Union

fixture_dir = os.getenv("FIXTURE_DIR", "fixtures")


class FixtureFactory:
def __init__(self, test_module_name: str | pathlib.Path):
def __init__(self, test_module_name: Union[str, pathlib.Path]):
if test_module_name.endswith(".py"):
test_module_name = test_module_name[:-3]
if isinstance(test_module_name, str):