Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aidial_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from aidial_client._client_pool import AsyncDialClientPool, DialClientPool
from aidial_client._exception import (
DialException,
EtagMismatchError,
InvalidDialURLError,
InvalidRequestError,
ParsingDataError,
ResourceNotFoundError,
)

__all__ = [
Expand All @@ -21,4 +23,6 @@
"InvalidDialURLError",
"InvalidRequestError",
"ParsingDataError",
"EtagMismatchError",
"ResourceNotFoundError",
]
18 changes: 18 additions & 0 deletions aidial_client/_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,21 @@ def __init__(self, message: str, **kwargs) -> None:
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
**kwargs,
)


class EtagMismatchError(DialException):
def __init__(self, message: str, **kwargs) -> None:
super().__init__(
message=message,
status_code=HTTPStatus.PRECONDITION_FAILED,
**kwargs,
)


class ResourceNotFoundError(DialException):
def __init__(self, message: str, **kwargs) -> None:
super().__init__(
message=message,
status_code=HTTPStatus.NOT_FOUND,
**kwargs,
)
4 changes: 2 additions & 2 deletions aidial_client/_http_client/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def request(
options: FinalRequestOptions,
cast_to: Type[ResponseT],
remaining_retries: Optional[int] = None,
_on_http_error: Optional[
on_http_error: Optional[
Callable[[httpx.HTTPStatusError], Optional[DialException]]
] = None,
) -> ResponseT:
Expand Down Expand Up @@ -100,7 +100,7 @@ async def request(
remaining_retries=retries,
)
# Try to get custom error from response status_code/code/message
custom_error = _on_http_error(err) if _on_http_error else None
custom_error = on_http_error(err) if on_http_error else None
# or fallback to default processing
raised_error = custom_error or self._make_dial_error_from_response(
err.response
Expand Down
6 changes: 3 additions & 3 deletions aidial_client/_http_client/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ def auth_headers(self) -> Dict[str, str]:

def request(
self,
*,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
remaining_retries: Optional[int] = None,
*,
error_processor: Optional[
on_http_error: Optional[
Callable[[httpx.HTTPStatusError], Optional[DialException]]
] = None,
) -> ResponseT:
Expand Down Expand Up @@ -100,7 +100,7 @@ def request(
remaining_retries=retries,
)
# Try to get custom error from response status_code/code/message
custom_error = error_processor(err) if error_processor else None
custom_error = on_http_error(err) if on_http_error else None
# or fallback to default processing
raised_error = custom_error or self._make_dial_error_from_response(
err.response
Expand Down
96 changes: 88 additions & 8 deletions aidial_client/resources/files.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,75 @@
from pathlib import PurePosixPath
from typing import Union
from typing import Literal, Optional, Union
from urllib.parse import urljoin

import httpx

from aidial_client._constants import API_PREFIX
from aidial_client._exception import InvalidDialURLError
from aidial_client._exception import (
DialException,
EtagMismatchError,
InvalidDialURLError,
ResourceNotFoundError,
)
from aidial_client._internal_types._generic import NoneType
from aidial_client._internal_types._http_request import (
FileTypes,
FinalRequestOptions,
)
from aidial_client._utils._dict import remove_none
from aidial_client.helpers.storage_resource import DialStorageResourceMixin
from aidial_client.resources.base import AsyncResource, Resource
from aidial_client.resources.metadata import AsyncMetadata, Metadata
from aidial_client.types.file import FileDownloadResponse
from aidial_client.types.metadata import FileMetadata


def _files_error_processor(
http_status_error: httpx.HTTPStatusError,
) -> Optional[DialException]:
if http_status_error.response.status_code == 412:
return EtagMismatchError(
message=http_status_error.response.text,
)
elif http_status_error.response.status_code == 404:
return ResourceNotFoundError(
message=http_status_error.response.text,
)
return None


class Files(Resource, DialStorageResourceMixin):
metadata: Metadata
resource_type: str = "files"

def upload(
self, url: Union[str, PurePosixPath], file: FileTypes
self,
url: Union[str, PurePosixPath],
file: FileTypes,
etag_if_match: Optional[str] = None,
etag_if_none_match: Optional[Literal["*"]] = None,
) -> FileMetadata:
return self.http_client.request(
cast_to=FileMetadata,
options=FinalRequestOptions(
method="PUT",
url=urljoin(API_PREFIX, self.get_api_path(str(url))),
files={"file": file},
headers=remove_none(
{
"If-Match": etag_if_match,
"If-None-Match": etag_if_none_match,
}
),
),
on_http_error=_files_error_processor,
)

def download(self, url: Union[str, PurePosixPath]) -> FileDownloadResponse:
def download(
self,
url: Union[str, PurePosixPath],
etag_if_match: Optional[str] = None,
) -> FileDownloadResponse:
storage_resource = self.get_storage_resource(str(url))
if storage_resource.filename is None:
raise InvalidDialURLError("URL points to a directory, not a file")
Expand All @@ -43,19 +78,35 @@ def download(self, url: Union[str, PurePosixPath]) -> FileDownloadResponse:
options=FinalRequestOptions(
method="GET",
url=urljoin(API_PREFIX, storage_resource.api_path),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
),
on_http_error=_files_error_processor,
)
return FileDownloadResponse(
response=response, filename=storage_resource.filename
)

def delete(self, url: Union[str, PurePosixPath]) -> None:
def delete(
self,
url: Union[str, PurePosixPath],
etag_if_match: Optional[str] = None,
) -> None:
return self.http_client.request(
cast_to=NoneType,
options=FinalRequestOptions(
method="DELETE",
url=urljoin(API_PREFIX, self.get_api_path(str(url))),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
),
on_http_error=_files_error_processor,
)

def get_metadata(self, url: Union[str, PurePosixPath]) -> FileMetadata:
Expand All @@ -70,7 +121,11 @@ class AsyncFiles(AsyncResource, DialStorageResourceMixin):
resource_type: str = "files"

async def upload(
self, url: Union[str, PurePosixPath], file: FileTypes
self,
url: Union[str, PurePosixPath],
file: FileTypes,
etag_if_match: Optional[str] = None,
etag_if_none_match: Optional[Literal["*"]] = None,
) -> FileMetadata:

return await self.http_client.request(
Expand All @@ -79,11 +134,20 @@ async def upload(
method="PUT",
url=urljoin(API_PREFIX, self.get_api_path(str(url))),
files={"file": file},
headers=remove_none(
{
"If-Match": etag_if_match,
"If-None-Match": etag_if_none_match,
}
),
),
on_http_error=_files_error_processor,
)

async def download(
self, url: Union[str, PurePosixPath]
self,
url: Union[str, PurePosixPath],
etag_if_match: Optional[str] = None,
) -> FileDownloadResponse:
storage_resource = self.get_storage_resource(str(url))
if storage_resource.filename is None:
Expand All @@ -93,19 +157,35 @@ async def download(
options=FinalRequestOptions(
method="GET",
url=urljoin(API_PREFIX, storage_resource.api_path),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
),
on_http_error=_files_error_processor,
)
return FileDownloadResponse(
response=response, filename=storage_resource.filename
)

async def delete(self, url: Union[str, PurePosixPath]) -> None:
async def delete(
self,
url: Union[str, PurePosixPath],
etag_if_match: Optional[str] = None,
) -> None:
return await self.http_client.request(
cast_to=NoneType,
options=FinalRequestOptions(
method="DELETE",
url=urljoin(API_PREFIX, self.get_api_path(str(url))),
headers=remove_none(
{
"If-Match": etag_if_match,
}
),
),
on_http_error=_files_error_processor,
)

async def get_metadata(
Expand Down
2 changes: 1 addition & 1 deletion aidial_client/types/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Deployment(ExtraAllowModel):
id: str
model: str
owner: str
object: Literal["deployment"]
object: Literal["deployment", "model"]
status: Literal["succeeded"]
created_at: int
updated_at: int
Expand Down
1 change: 1 addition & 0 deletions aidial_client/types/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class FileMetadata(BaseMetadata):
content_length: Optional[int] = None
content_type: Optional[str] = None
items: Optional[List[FileItem]] = None
etag: Optional[str] = None


class ConversationItem(BaseMetadata):
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import uuid

import pytest

from aidial_client import AsyncDial, Dial
from aidial_client._exception import ResourceNotFoundError


@pytest.fixture
Expand Down Expand Up @@ -36,3 +38,19 @@ def test_deployment(sync_client: Dial) -> str:
deployment = next((d for d in deployments if d.id.startswith("gpt-")))
assert deployment
return deployment.id


@pytest.fixture
def absent_test_file(sync_client):

def _save_delete_file(p):
try:
sync_client.files.delete(p)
except ResourceNotFoundError:
pass

unique_name = f"test-file-{uuid.uuid4()}.txt"
full_path = sync_client.my_files_home() / unique_name
_save_delete_file(full_path)
yield full_path
_save_delete_file(full_path)
Loading