diff --git a/README.md b/README.md index 4efc45a..a7027aa 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,10 @@ - [Authentication](#authentication) - [API Keys](#api-keys) - [Bearer Token](#bearer-token) - - [List Deployments](#list-deployments) + - [Deployments](#deployments) + - [List Deployments](#list-deployments) + - [Get Deployment by Id](#get-deployment-by-id) + - [Get Deployment Configuration](#get-deployment-configuration) - [Make Chat Completions Requests](#make-completions-requests) - [Without Streaming](#without-streaming) - [With Streaming](#with-streaming) @@ -30,6 +33,12 @@ - [Applications](#applications) - [List Applications](#list-applications) - [Get Application by Id](#get-application-by-id) + - [Models](#models) + - [Get Model by Name](#get-model-by-name) + - [Toolsets](#toolsets) + - [Get Toolset by Id](#get-toolset-by-id) + - [Resource Permissions](#resource-permissions) + - [Grant Permissions](#grant-permissions) - [Client Pool](#client-pool) - [Synchronous Client Pool](#synchronous-client-pool) - [Asynchronous Client Pool](#asynchronous-client-pool) @@ -136,19 +145,94 @@ dial_client = Dial( ) ``` -### List Deployments +### Deployments -If you want to get a list of available deployments, use `client.deployments.list()` or method: +#### List Deployments + +To get a list of available deployments: + +```python +# Sync +deployments = client.deployments.list() +# Async +deployments = await async_client.deployments.list() +``` ```pycon >>> client.deployments.list() [ - Deployment(id='gpt-35-turbo', model='gpt-35-turbo', owner='organization-owner', object='deployment', status='succeeded', created_at=1724760524, updated_at=1724760524, scale_settings=ScaleSettings(scale_type='standard'), features={'rate': False, 'tokenize': False, 'truncate_prompt': False, 'configuration': False, 'system_prompt': True, 'tools': False, 'seed': False, 'url_attachments': False, 'folder_attachments': False, 'allow_resume': True}), - 
Deployment(id='stable-diffusion-xl', model='stable-diffusion-xl', owner='organization-owner', object='deployment', status='succeeded', created_at=1724760524, updated_at=1724760524, scale_settings=ScaleSettings(scale_type='standard'), features={'rate': False, 'tokenize': False, 'truncate_prompt': False, 'configuration': False, 'system_prompt': True, 'tools': False, 'seed': False, 'url_attachments': False, 'folder_attachments': False, 'allow_resume': True}), - Deployment(id='gemini-pro-vision', model='gemini-pro-vision', owner='organization-owner', object='deployment', status='succeeded', created_at=1724760524, updated_at=1724760524, scale_settings=ScaleSettings(scale_type='standard'), features={'rate': False, 'tokenize': False, 'truncate_prompt': False, 'configuration': False, 'system_prompt': True, 'tools': False, 'seed': False, 'url_attachments': False, 'folder_attachments': False, 'allow_resume': True}), + Deployment(id='gpt-35-turbo', model='gpt-35-turbo', owner='organization-owner', object='deployment', status='succeeded', created_at=1724760524, updated_at=1724760524, scale_settings=ScaleSettings(scale_type='standard'), features=Features(rate=False, tokenize=False, truncate_prompt=False, configuration=False, system_prompt=True, tools=False, seed=False, url_attachments=False, folder_attachments=False, allow_resume=True)), + Deployment(id='stable-diffusion-xl', model='stable-diffusion-xl', owner='organization-owner', object='deployment', status='succeeded', created_at=1724760524, updated_at=1724760524, scale_settings=ScaleSettings(scale_type='standard'), features=Features(rate=False, tokenize=False, truncate_prompt=False, configuration=False, system_prompt=True, tools=False, seed=False, url_attachments=False, folder_attachments=False, allow_resume=True)), + ..., ] ``` +#### Get Deployment by Id + +To fetch a single deployment by its identifier: + +```python +# Sync +deployment = client.deployments.get("gpt-35-turbo") +# Async +deployment = await 
async_client.deployments.get("gpt-35-turbo") +``` + +As a result, you will receive a `Deployment` object: + +```python +Deployment( + id="gpt-35-turbo", + model="gpt-35-turbo", + object="deployment", + owner="organization-owner", + status="succeeded", + created_at=1724760524, + updated_at=1724760524, + scale_settings=ScaleSettings(scale_type="standard"), + features=Features( + rate=False, + tokenize=False, + truncate_prompt=False, + configuration=True, + system_prompt=True, + tools=True, + seed=False, + url_attachments=False, + folder_attachments=False, + allow_resume=True, + ), + defaults={}, +) +``` + +#### Get Deployment Configuration + +Some deployments expose a JSON Schema document describing their runtime configuration. Use `get_configuration_schema()` to retrieve it: + +```python +# Sync +config = client.deployments.get_configuration_schema("gpt-35-turbo") +# Async +config = await async_client.deployments.get_configuration_schema("gpt-35-turbo") +``` + +The response is a plain `dict` whose shape is entirely deployment-specific: + +```python +{ + "type": "object", + "properties": { + "model_to_use": { + "type": "string", + "enum": ["gpt-4", "gpt-4o"], + "default": "gpt-4", + } + }, + "additionalProperties": False, +} +``` + ### Make Completions Requests #### Without Streaming @@ -535,6 +619,111 @@ application = await async_client.application.get("app_id") As a result, you will receive a list of `Application` objects. Refer to the [previous example](#list-applications). 
+### Models + +#### Get Model by Name + +To retrieve metadata, capabilities, and pricing for a specific model: + +```python +# Sync +model_info = client.model.get("gpt-4") +# Async +model_info = await async_client.model.get("gpt-4") +``` + +As a result, you will receive a `ModelInfo` object: + +```python +ModelInfo( + id="gpt-4", + model="gpt-4", + object="model", + owner="organization-owner", + status="succeeded", + created_at=1724760524, + updated_at=1724760524, + lifecycle_status="generally-available", + display_name="GPT-4", + description="OpenAI GPT-4 model.", + capabilities=ModelCapabilities( + scale_types=["standard"], + completion=False, + chat_completion=True, + embeddings=False, + fine_tune=False, + inference=False, + ), + limits=ModelLimits( + max_prompt_tokens=8192, + max_completion_tokens=4096, + max_total_tokens=None, + ), + pricing=ModelPricing( + unit="token", + prompt="0.00003", + completion="0.00006", + ), +) +``` + +### Toolsets + +#### Get Toolset by Id + +To retrieve information about a specific MCP toolset: + +```python +# Sync +toolset_info = client.toolset.get("my-toolset") +# Async +toolset_info = await async_client.toolset.get("my-toolset") +``` + +As a result, you will receive a `ToolsetInfo` object: + +```python +ToolsetInfo( + id="my-toolset", + toolset="my-toolset", + display_name="My Toolset", + description="A collection of tools for data processing.", + transport="HTTP", + allowed_tools=["tool-a", "tool-b"], + owner="organization-owner", + status="succeeded", + created_at=1724760524, + updated_at=1724760524, +) +``` + +### Resource Permissions + +#### Grant Permissions + +Use `resource_permissions.grant()` to grant access to one or more files in DIAL storage to a specific deployment (receiver). This is typically used when a deployment needs to read files on behalf of a user. 
+ +```python +# Sync +client.resource_permissions.grant( + resources=["files/my-bucket/report.pdf"], + receiver="my-deployment", + permissions=["READ"], +) +# Async +await async_client.resource_permissions.grant( + resources=["files/my-bucket/report.pdf"], + receiver="my-deployment", + permissions=["READ"], +) +``` + +- `resources` — list of DIAL file URL strings to share. +- `receiver` — the deployment ID that should receive access. +- `permissions` — list of permission strings; defaults to `["READ"]`. + +The method returns `None` on success and raises `DialException` on HTTP error. + ### Client Pool When you need to create multiple DIAL clients and wish to enhance performance by reusing the HTTP connection for the same DIAL instance, consider using synchronous and asynchronous **client pools**. diff --git a/aidial_client/__init__.py b/aidial_client/__init__.py index 7706e53..397016c 100644 --- a/aidial_client/__init__.py +++ b/aidial_client/__init__.py @@ -9,6 +9,8 @@ ParsingDataError, ResourceNotFoundError, ) +from aidial_client.types.model import ModelInfo, ModelLimits, ModelPricing +from aidial_client.types.toolset import ToolsetInfo __all__ = [ "Dial", @@ -24,4 +26,8 @@ "ParsingDataError", "EtagMismatchError", "ResourceNotFoundError", + "ToolsetInfo", + "ModelInfo", + "ModelPricing", + "ModelLimits", ] diff --git a/aidial_client/_client.py b/aidial_client/_client.py index 9ff03c8..ba17864 100644 --- a/aidial_client/_client.py +++ b/aidial_client/_client.py @@ -109,6 +109,11 @@ def _init_resources(self) -> None: ) self.deployments = resources.Deployments(http_client=self._http_client) self.application = resources.Application(http_client=self._http_client) + self.toolset = resources.Toolset(http_client=self._http_client) + self.model = resources.Model(http_client=self._http_client) + self.resource_permissions = resources.ResourcePermissions( + http_client=self._http_client + ) def _create_http_client(self) -> SyncHTTPClient: return SyncHTTPClient( @@ -189,6 
+194,11 @@ def _init_resources(self) -> None: self.application = resources.AsyncApplication( http_client=self._http_client ) + self.toolset = resources.AsyncToolset(http_client=self._http_client) + self.model = resources.AsyncModel(http_client=self._http_client) + self.resource_permissions = resources.AsyncResourcePermissions( + http_client=self._http_client + ) def _create_http_client(self) -> AsyncHTTPClient: return AsyncHTTPClient( diff --git a/aidial_client/_internal_types/_generic.py b/aidial_client/_internal_types/_generic.py index d8e2893..113e9db 100644 --- a/aidial_client/_internal_types/_generic.py +++ b/aidial_client/_internal_types/_generic.py @@ -15,6 +15,7 @@ ExtraForbidModel, bytes, str, + dict, httpx.Response, FileDownloadResponse, None, diff --git a/aidial_client/_utils/_response_processing.py b/aidial_client/_utils/_response_processing.py index d3e4e51..5135397 100644 --- a/aidial_client/_utils/_response_processing.py +++ b/aidial_client/_utils/_response_processing.py @@ -21,6 +21,13 @@ def process_block_response( return cast(ResponseT, response.text) elif cast_to == NoneType: return cast(ResponseT, None) + elif cast_to == dict: + try: + return cast(ResponseT, response.json()) + except Exception as e: + raise ParsingDataError( + message=f"Error during parsing of response data: {str(e)}" + ) elif issubclass(cast_to, (ExtraForbidModel, ExtraAllowModel)): try: data = response.json() diff --git a/aidial_client/resources/__init__.py b/aidial_client/resources/__init__.py index 6467e8b..26abbc3 100644 --- a/aidial_client/resources/__init__.py +++ b/aidial_client/resources/__init__.py @@ -1,5 +1,11 @@ from aidial_client.resources.deployments import AsyncDeployments, Deployments from aidial_client.resources.metadata import AsyncMetadata, Metadata +from aidial_client.resources.model import AsyncModel, Model +from aidial_client.resources.resource_permissions import ( + AsyncResourcePermissions, + ResourcePermissions, +) +from aidial_client.resources.toolset 
import AsyncToolset, Toolset from .application import Application, AsyncApplication from .bucket import AsyncBucket, Bucket @@ -19,4 +25,10 @@ "Metadata", "Application", "AsyncApplication", + "Toolset", + "AsyncToolset", + "Model", + "AsyncModel", + "ResourcePermissions", + "AsyncResourcePermissions", ] diff --git a/aidial_client/resources/deployments.py b/aidial_client/resources/deployments.py index 32234e3..10e4fb5 100644 --- a/aidial_client/resources/deployments.py +++ b/aidial_client/resources/deployments.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, Dict, List from aidial_client._internal_types._http_request import FinalRequestOptions from aidial_client.resources.base import AsyncResource, Resource @@ -15,6 +15,23 @@ def _list_raw(self) -> DeploymentsResponse: def list(self) -> List[Deployment]: return self._list_raw().data + def get(self, deployment_id: str) -> Deployment: + return self.http_client.request( + cast_to=Deployment, + options=FinalRequestOptions( + method="GET", url=f"openai/deployments/{deployment_id}" + ), + ) + + def get_configuration_schema(self, deployment_id: str) -> Dict[str, Any]: + return self.http_client.request( + cast_to=dict, + options=FinalRequestOptions( + method="GET", + url=f"v1/deployments/{deployment_id}/configuration", + ), + ) + class AsyncDeployments(AsyncResource): async def _list_raw(self) -> DeploymentsResponse: @@ -25,3 +42,22 @@ async def _list_raw(self) -> DeploymentsResponse: async def list(self) -> List[Deployment]: return (await self._list_raw()).data + + async def get(self, deployment_id: str) -> Deployment: + return await self.http_client.request( + cast_to=Deployment, + options=FinalRequestOptions( + method="GET", url=f"openai/deployments/{deployment_id}" + ), + ) + + async def get_configuration_schema( + self, deployment_id: str + ) -> Dict[str, Any]: + return await self.http_client.request( + cast_to=dict, + options=FinalRequestOptions( + method="GET", + 
url=f"v1/deployments/{deployment_id}/configuration", + ), + ) diff --git a/aidial_client/resources/model.py b/aidial_client/resources/model.py new file mode 100644 index 0000000..78fa689 --- /dev/null +++ b/aidial_client/resources/model.py @@ -0,0 +1,23 @@ +from aidial_client._internal_types._http_request import FinalRequestOptions +from aidial_client.resources.base import AsyncResource, Resource +from aidial_client.types.model import ModelInfo + + +class Model(Resource): + def get(self, model_name: str) -> ModelInfo: + return self.http_client.request( + cast_to=ModelInfo, + options=FinalRequestOptions( + method="GET", url=f"openai/models/{model_name}" + ), + ) + + +class AsyncModel(AsyncResource): + async def get(self, model_name: str) -> ModelInfo: + return await self.http_client.request( + cast_to=ModelInfo, + options=FinalRequestOptions( + method="GET", url=f"openai/models/{model_name}" + ), + ) diff --git a/aidial_client/resources/resource_permissions.py b/aidial_client/resources/resource_permissions.py new file mode 100644 index 0000000..09c7b0e --- /dev/null +++ b/aidial_client/resources/resource_permissions.py @@ -0,0 +1,53 @@ +from typing import List + +from aidial_client._internal_types._generic import NoneType +from aidial_client._internal_types._http_request import FinalRequestOptions +from aidial_client.resources.base import AsyncResource, Resource + +_GRANT_URL = "v1/ops/resource/per-request-permissions/grant" + + +class ResourcePermissions(Resource): + def grant( + self, + resources: List[str], + receiver: str, + permissions: List[str] = ["READ"], + ) -> None: + self.http_client.request( + cast_to=NoneType, + options=FinalRequestOptions( + method="POST", + url=_GRANT_URL, + json_data={ + "resources": [ + {"url": url, "permissions": permissions} + for url in resources + ], + "receiver": receiver, + }, + ), + ) + + +class AsyncResourcePermissions(AsyncResource): + async def grant( + self, + resources: List[str], + receiver: str, + permissions: 
List[str] = ["READ"], + ) -> None: + await self.http_client.request( + cast_to=NoneType, + options=FinalRequestOptions( + method="POST", + url=_GRANT_URL, + json_data={ + "resources": [ + {"url": url, "permissions": permissions} + for url in resources + ], + "receiver": receiver, + }, + ), + ) diff --git a/aidial_client/resources/toolset.py b/aidial_client/resources/toolset.py new file mode 100644 index 0000000..6adf37e --- /dev/null +++ b/aidial_client/resources/toolset.py @@ -0,0 +1,23 @@ +from aidial_client._internal_types._http_request import FinalRequestOptions +from aidial_client.resources.base import AsyncResource, Resource +from aidial_client.types.toolset import ToolsetInfo + + +class Toolset(Resource): + def get(self, toolset_id: str) -> ToolsetInfo: + return self.http_client.request( + cast_to=ToolsetInfo, + options=FinalRequestOptions( + method="GET", url=f"openai/toolsets/{toolset_id}" + ), + ) + + +class AsyncToolset(AsyncResource): + async def get(self, toolset_id: str) -> ToolsetInfo: + return await self.http_client.request( + cast_to=ToolsetInfo, + options=FinalRequestOptions( + method="GET", url=f"openai/toolsets/{toolset_id}" + ), + ) diff --git a/aidial_client/types/deployment.py b/aidial_client/types/deployment.py index 8223d7b..fed5c88 100644 --- a/aidial_client/types/deployment.py +++ b/aidial_client/types/deployment.py @@ -20,6 +20,11 @@ class Features(ExtraAllowModel): folder_attachments: Optional[bool] = None allow_resume: Optional[bool] = None parallel_tool_calls: Optional[bool] = None + accessible_by_per_request_key: Optional[bool] = None + content_parts: Optional[bool] = None + cache: Optional[bool] = None + auto_caching: Optional[bool] = None + assistant_attachments_in_request: Optional[bool] = None class DeploymentBase(ExtraAllowModel): diff --git a/aidial_client/types/model.py b/aidial_client/types/model.py new file mode 100644 index 0000000..1c3b2cc --- /dev/null +++ b/aidial_client/types/model.py @@ -0,0 +1,48 @@ +from typing 
import List, Optional + +from aidial_client._internal_types._model import ExtraAllowModel + + +class ModelCapabilities(ExtraAllowModel): + scale_types: List[str] = [] + completion: Optional[bool] = None + chat_completion: Optional[bool] = None + embeddings: Optional[bool] = None + fine_tune: Optional[bool] = None + inference: Optional[bool] = None + + +class ModelLimits(ExtraAllowModel): + """Token limits for the model. + + Either `max_total_tokens` is set alone, or `max_prompt_tokens` and + `max_completion_tokens` are set together (oneOf in the schema). + All fields are Optional here to accommodate both variants. + """ + + max_total_tokens: Optional[int] = None + max_prompt_tokens: Optional[int] = None + max_completion_tokens: Optional[int] = None + + +class ModelPricing(ExtraAllowModel): + unit: str + prompt: str + completion: Optional[str] = None + + +class ModelInfo(ExtraAllowModel): + id: str + model: str + display_name: Optional[str] = None + description: Optional[str] = None + owner: Optional[str] = None + object: Optional[str] = None + status: Optional[str] = None + created_at: Optional[int] = None + updated_at: Optional[int] = None + lifecycle_status: Optional[str] = None + tokenizer_model: Optional[str] = None + capabilities: Optional[ModelCapabilities] = None + limits: Optional[ModelLimits] = None + pricing: Optional[ModelPricing] = None diff --git a/aidial_client/types/toolset.py b/aidial_client/types/toolset.py new file mode 100644 index 0000000..8d0ac94 --- /dev/null +++ b/aidial_client/types/toolset.py @@ -0,0 +1,24 @@ +from typing import List, Optional + +from aidial_client._internal_types._model import ExtraAllowModel +from aidial_client.types.deployment import Features + + +class ToolsetInfo(ExtraAllowModel): + id: str + toolset: str + display_name: Optional[str] = None + display_version: Optional[str] = None + description: Optional[str] = None + icon_url: Optional[str] = None + owner: Optional[str] = None + object: Optional[str] = None + status: 
Optional[str] = None + created_at: Optional[int] = None + updated_at: Optional[int] = None + reference: Optional[str] = None + description_keywords: List[str] = [] + max_retry_attempts: Optional[int] = None + transport: Optional[str] = None + allowed_tools: List[str] = [] + features: Optional[Features] = None diff --git a/tests/resources/test_deployments.py b/tests/resources/test_deployments.py new file mode 100644 index 0000000..92ddef1 --- /dev/null +++ b/tests/resources/test_deployments.py @@ -0,0 +1,112 @@ +import pytest + +from aidial_client._exception import DialException +from aidial_client.types.deployment import Deployment +from tests.client_mock import get_async_client_mock, get_client_mock + +DEPLOYMENT_MOCK = { + "id": "gpt-4", + "model": "gpt-4", + "object": "deployment", + "owner": "organization-owner", + "status": "succeeded", + "created_at": 1672534800, + "updated_at": 1672534800, + "scale_settings": {"scale_type": "standard"}, +} + +CONFIG_MOCK = { + "type": "object", + "properties": { + "model_to_use": { + "type": "string", + "enum": ["gpt-4", "gpt-4o"], + "default": "gpt-4", + } + }, + "additionalProperties": False, +} + + +# --------------------------------------------------------------------------- +# deployments.get() +# --------------------------------------------------------------------------- + + +def test_get_deployment(): + client = get_client_mock(status_code=200, json_mock=DEPLOYMENT_MOCK) + result = client.deployments.get("gpt-4") + assert isinstance(result, Deployment) + assert result.id == "gpt-4" + assert result.model == "gpt-4" + assert result.object == "deployment" + + +@pytest.mark.asyncio +async def test_async_get_deployment(): + client = get_async_client_mock(status_code=200, json_mock=DEPLOYMENT_MOCK) + result = await client.deployments.get("gpt-4") + assert isinstance(result, Deployment) + assert result.id == "gpt-4" + assert result.model == "gpt-4" + assert result.object == "deployment" + + +def 
test_get_deployment_http_error(): + client = get_client_mock( + status_code=401, + json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}}, + ) + with pytest.raises(DialException): + client.deployments.get("gpt-4") + + +@pytest.mark.asyncio +async def test_async_get_deployment_http_error(): + client = get_async_client_mock( + status_code=401, + json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}}, + ) + with pytest.raises(DialException): + await client.deployments.get("gpt-4") + + +# --------------------------------------------------------------------------- +# deployments.get_config() +# --------------------------------------------------------------------------- + + +def test_get_deployment_config(): + client = get_client_mock(status_code=200, json_mock=CONFIG_MOCK) + result = client.deployments.get_configuration_schema("gpt-4") + assert isinstance(result, dict) + assert result.get("type") == "object" + assert "properties" in result + + +@pytest.mark.asyncio +async def test_async_get_deployment_config(): + client = get_async_client_mock(status_code=200, json_mock=CONFIG_MOCK) + result = await client.deployments.get_configuration_schema("gpt-4") + assert isinstance(result, dict) + assert result.get("type") == "object" + assert "properties" in result + + +def test_get_deployment_config_http_error(): + client = get_client_mock( + status_code=401, + json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}}, + ) + with pytest.raises(DialException): + client.deployments.get_configuration_schema("gpt-4") + + +@pytest.mark.asyncio +async def test_async_get_deployment_config_http_error(): + client = get_async_client_mock( + status_code=401, + json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}}, + ) + with pytest.raises(DialException): + await client.deployments.get_configuration_schema("gpt-4") diff --git a/tests/resources/test_model.py b/tests/resources/test_model.py new file mode 100644 index 
0000000..caaab4b --- /dev/null +++ b/tests/resources/test_model.py @@ -0,0 +1,162 @@ +import pytest + +from aidial_client._exception import DialException +from aidial_client.types.model import ( + ModelCapabilities, + ModelInfo, + ModelLimits, + ModelPricing, +) +from tests.client_mock import get_async_client_mock, get_client_mock + +MODEL_MOCK = { + "id": "gpt-4", + "model": "gpt-4", + "display_name": "GPT 4", + "description": "Chat completion model.", + "owner": "organization-owner", + "object": "model", + "status": "succeeded", + "created_at": 1672534800, + "updated_at": 1672534800, + "lifecycle_status": "generally-available", + "tokenizer_model": "gpt-4-0314", + "capabilities": { + "scale_types": ["standard"], + "completion": False, + "chat_completion": True, + "embeddings": False, + "fine_tune": False, + "inference": False, + }, + "limits": { + "max_prompt_tokens": 8192, + "max_completion_tokens": 4096, + }, + "pricing": { + "unit": "token", + "prompt": "0.00003", + "completion": "0.00006", + }, +} + + +# --------------------------------------------------------------------------- +# Happy path +# --------------------------------------------------------------------------- + + +def test_get_model(): + client = get_client_mock(status_code=200, json_mock=MODEL_MOCK) + result = client.model.get("gpt-4") + assert isinstance(result, ModelInfo) + assert result.id == "gpt-4" + assert result.model == "gpt-4" + assert result.object == "model" + assert result.lifecycle_status == "generally-available" + + +@pytest.mark.asyncio +async def test_async_get_model(): + client = get_async_client_mock(status_code=200, json_mock=MODEL_MOCK) + result = await client.model.get("gpt-4") + assert isinstance(result, ModelInfo) + assert result.id == "gpt-4" + assert result.model == "gpt-4" + + +# --------------------------------------------------------------------------- +# Nested type fields +# --------------------------------------------------------------------------- + + +def 
test_get_model_pricing(): + client = get_client_mock(status_code=200, json_mock=MODEL_MOCK) + result = client.model.get("gpt-4") + assert isinstance(result.pricing, ModelPricing) + assert result.pricing.unit == "token" + assert result.pricing.prompt == "0.00003" + assert result.pricing.completion == "0.00006" + + +def test_get_model_limits_prompt_and_completion(): + client = get_client_mock(status_code=200, json_mock=MODEL_MOCK) + result = client.model.get("gpt-4") + assert isinstance(result.limits, ModelLimits) + assert result.limits.max_prompt_tokens == 8192 + assert result.limits.max_completion_tokens == 4096 + assert result.limits.max_total_tokens is None + + +def test_get_model_limits_total_tokens(): + mock = {**MODEL_MOCK, "limits": {"max_total_tokens": 16384}} + client = get_client_mock(status_code=200, json_mock=mock) + result = client.model.get("gpt-4") + assert isinstance(result.limits, ModelLimits) + assert result.limits.max_total_tokens == 16384 + assert result.limits.max_prompt_tokens is None + + +def test_get_model_capabilities(): + client = get_client_mock(status_code=200, json_mock=MODEL_MOCK) + result = client.model.get("gpt-4") + assert isinstance(result.capabilities, ModelCapabilities) + assert result.capabilities.chat_completion is True + assert result.capabilities.embeddings is False + assert result.capabilities.scale_types == ["standard"] + + +def test_get_model_tokenizer_model(): + client = get_client_mock(status_code=200, json_mock=MODEL_MOCK) + result = client.model.get("gpt-4") + assert result.tokenizer_model == "gpt-4-0314" + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +def test_get_model_no_pricing(): + mock = {**MODEL_MOCK} + del mock["pricing"] + client = get_client_mock(status_code=200, json_mock=mock) + result = client.model.get("gpt-4") + assert result.pricing is None + + +def 
test_get_model_embedding_no_completion_price(): + mock = { + **MODEL_MOCK, + "id": "ada-002", + "model": "ada-002", + "pricing": {"unit": "token", "prompt": "0.0000001"}, + } + client = get_client_mock(status_code=200, json_mock=mock) + result = client.model.get("ada-002") + assert isinstance(result.pricing, ModelPricing) + assert result.pricing.completion is None + + +# --------------------------------------------------------------------------- +# Error cases +# --------------------------------------------------------------------------- + + +def test_get_model_http_error(): + client = get_client_mock( + status_code=401, + json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}}, + ) + with pytest.raises(DialException): + client.model.get("gpt-4") + + +@pytest.mark.asyncio +async def test_async_get_model_http_error(): + client = get_async_client_mock( + status_code=401, + json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}}, + ) + with pytest.raises(DialException): + await client.model.get("gpt-4") diff --git a/tests/resources/test_resource_permissions.py b/tests/resources/test_resource_permissions.py new file mode 100644 index 0000000..6afbd77 --- /dev/null +++ b/tests/resources/test_resource_permissions.py @@ -0,0 +1,212 @@ +import json + +import httpx +import pytest + +from aidial_client import AsyncDial, Dial +from aidial_client._exception import DialException + +BASE_URL = "http://dial.core" +GRANT_PATH = "/v1/ops/resource/per-request-permissions/grant" + + +def _make_sync_client_capturing() -> tuple[Dial, list[httpx.Request]]: + """Returns a Dial client whose mock captures every sent request.""" + captured: list[httpx.Request] = [] + client = Dial(api_key="dummy", base_url=BASE_URL) + + def send_mock(request: httpx.Request, **kwargs): + captured.append(request) + return httpx.Response(200, request=request, json={}) + + client._http_client._internal_http_client.send = send_mock + return client, captured + + +def 
_make_async_client_capturing() -> tuple[AsyncDial, list[httpx.Request]]:
+    """Returns an AsyncDial client whose mock captures every sent request."""
+    captured: list[httpx.Request] = []
+    client = AsyncDial(api_key="dummy", base_url=BASE_URL)
+
+    async def send_mock(request: httpx.Request, **kwargs):
+        captured.append(request)
+        return httpx.Response(200, request=request, json={})
+
+    client._http_client._internal_http_client.send = send_mock
+    return client, captured
+
+
+# ---------------------------------------------------------------------------
+# Happy path — return value
+# ---------------------------------------------------------------------------
+
+
+def test_grant_returns_none():
+    client, _ = _make_sync_client_capturing()
+    result = client.resource_permissions.grant(
+        resources=["files/bucket/img.png"],
+        receiver="my-app",
+    )
+    assert result is None
+
+
+@pytest.mark.asyncio
+async def test_async_grant_returns_none():
+    client, _ = _make_async_client_capturing()
+    result = await client.resource_permissions.grant(
+        resources=["files/bucket/img.png"],
+        receiver="my-app",
+    )
+    assert result is None
+
+
+# ---------------------------------------------------------------------------
+# Request body verification
+# ---------------------------------------------------------------------------
+
+
+def test_grant_request_body_single_resource():
+    client, captured = _make_sync_client_capturing()
+    client.resource_permissions.grant(
+        resources=["files/bucket/img.png"],
+        receiver="my-app",
+        permissions=["READ"],
+    )
+    assert len(captured) == 1
+    body = json.loads(captured[0].content)
+    assert body == {
+        "resources": [{"url": "files/bucket/img.png", "permissions": ["READ"]}],
+        "receiver": "my-app",
+    }
+
+
+@pytest.mark.asyncio
+async def test_async_grant_request_body_single_resource():
+    client, captured = _make_async_client_capturing()
+    await client.resource_permissions.grant(
+        resources=["files/bucket/img.png"],
+        receiver="my-app",
+        permissions=["READ"],
+    )
+    assert len(captured) == 1
+    body = json.loads(captured[0].content)
+    assert body == {
+        "resources": [{"url": "files/bucket/img.png", "permissions": ["READ"]}],
+        "receiver": "my-app",
+    }
+
+
+def test_grant_request_body_multiple_resources():
+    client, captured = _make_sync_client_capturing()
+    client.resource_permissions.grant(
+        resources=["files/bucket/a.png", "files/bucket/b.png"],
+        receiver="deployment-x",
+        permissions=["READ"],
+    )
+    body = json.loads(captured[0].content)
+    assert body["receiver"] == "deployment-x"
+    assert len(body["resources"]) == 2
+    assert body["resources"][0] == {
+        "url": "files/bucket/a.png",
+        "permissions": ["READ"],
+    }
+    assert body["resources"][1] == {
+        "url": "files/bucket/b.png",
+        "permissions": ["READ"],
+    }
+
+
+def test_grant_request_body_write_permission():
+    client, captured = _make_sync_client_capturing()
+    client.resource_permissions.grant(
+        resources=["files/bucket/img.png"],
+        receiver="my-app",
+        permissions=["WRITE"],
+    )
+    body = json.loads(captured[0].content)
+    assert body["resources"][0]["permissions"] == ["WRITE"]
+
+
+def test_grant_request_url():
+    client, captured = _make_sync_client_capturing()
+    client.resource_permissions.grant(
+        resources=["files/bucket/img.png"],
+        receiver="my-app",
+    )
+    assert captured[0].url.path == GRANT_PATH
+
+
+def test_grant_request_method():
+    client, captured = _make_sync_client_capturing()
+    client.resource_permissions.grant(
+        resources=["files/bucket/img.png"],
+        receiver="my-app",
+    )
+    assert captured[0].method == "POST"
+
+
+# ---------------------------------------------------------------------------
+# Default permissions
+# ---------------------------------------------------------------------------
+
+
+def test_grant_default_permissions_are_read():
+    client, captured = _make_sync_client_capturing()
+    client.resource_permissions.grant(
+        resources=["files/bucket/img.png"],
+        receiver="my-app",
+    )
+    body = json.loads(captured[0].content)
+    assert body["resources"][0]["permissions"] == ["READ"]
+
+
+# ---------------------------------------------------------------------------
+# Error cases
+# ---------------------------------------------------------------------------
+
+
+def test_grant_http_error():
+    client = Dial(api_key="dummy", base_url=BASE_URL)
+
+    def send_mock(request: httpx.Request, **kwargs):
+        return httpx.Response(
+            403,
+            request=request,
+            json={
+                "error": {
+                    "message": "Forbidden: per-request key required",
+                    "type": "auth_error",
+                }
+            },
+        )
+
+    client._http_client._internal_http_client.send = send_mock
+    with pytest.raises(DialException):
+        client.resource_permissions.grant(
+            resources=["files/bucket/img.png"],
+            receiver="my-app",
+        )
+
+
+@pytest.mark.asyncio
+async def test_async_grant_http_error():
+    client = AsyncDial(api_key="dummy", base_url=BASE_URL)
+
+    async def send_mock(request: httpx.Request, **kwargs):
+        return httpx.Response(
+            403,
+            request=request,
+            json={
+                "error": {
+                    "message": "Forbidden: per-request key required",
+                    "type": "auth_error",
+                }
+            },
+        )
+
+    client._http_client._internal_http_client.send = send_mock
+    with pytest.raises(DialException):
+        await client.resource_permissions.grant(
+            resources=["files/bucket/img.png"],
+            receiver="my-app",
+        )
diff --git a/tests/resources/test_toolset.py b/tests/resources/test_toolset.py
new file mode 100644
index 0000000..0262a97
--- /dev/null
+++ b/tests/resources/test_toolset.py
@@ -0,0 +1,113 @@
+import pytest
+
+from aidial_client._exception import DialException
+from aidial_client.types.toolset import ToolsetInfo
+from tests.client_mock import get_async_client_mock, get_client_mock
+
+TOOLSET_MOCK = {
+    "id": "toolsets/bucket/folder/my-toolset",
+    "toolset": "toolsets/bucket/folder/my-toolset",
+    "display_name": "My Toolset",
+    "display_version": "1.0.0",
+    "description": "A test toolset",
+    "icon_url": "http://toolset/icon.svg",
+    "owner": "owner-name",
+    "object": "toolset",
+    "status": "succeeded",
+    "reference": "ff5584b7-a82b-4f4f-bf42-5bf74a3893d6",
+    "description_keywords": ["keyword1", "keyword2"],
+    "max_retry_attempts": 3,
+    "created_at": 1672534800,
+    "updated_at": 1672534900,
+    "transport": "HTTP",
+    "allowed_tools": ["tool1", "tool2"],
+    "features": {
+        "rate": True,
+        "tokenize": False,
+        "truncate_prompt": False,
+        "configuration": False,
+        "system_prompt": True,
+        "tools": True,
+        "seed": False,
+        "url_attachments": False,
+        "folder_attachments": False,
+    },
+}
+
+
+# ---------------------------------------------------------------------------
+# Happy path
+# ---------------------------------------------------------------------------
+
+
+def test_get_toolset():
+    client = get_client_mock(status_code=200, json_mock=TOOLSET_MOCK)
+    result = client.toolset.get("my-toolset")
+    assert isinstance(result, ToolsetInfo)
+    assert result.id == "toolsets/bucket/folder/my-toolset"
+    assert result.toolset == "toolsets/bucket/folder/my-toolset"
+    assert result.transport == "HTTP"
+    assert result.allowed_tools == ["tool1", "tool2"]
+    assert result.display_name == "My Toolset"
+
+
+@pytest.mark.asyncio
+async def test_async_get_toolset():
+    client = get_async_client_mock(status_code=200, json_mock=TOOLSET_MOCK)
+    result = await client.toolset.get("my-toolset")
+    assert isinstance(result, ToolsetInfo)
+    assert result.id == "toolsets/bucket/folder/my-toolset"
+    assert result.toolset == "toolsets/bucket/folder/my-toolset"
+    assert result.transport == "HTTP"
+    assert result.allowed_tools == ["tool1", "tool2"]
+
+
+# ---------------------------------------------------------------------------
+# Optional / nested fields
+# ---------------------------------------------------------------------------
+
+
+def test_get_toolset_features():
+    client = get_client_mock(status_code=200, json_mock=TOOLSET_MOCK)
+    result = client.toolset.get("my-toolset")
+    assert result.features is not None
+    assert result.features.rate is True
+    assert result.features.tools is True
+    assert result.reference == "ff5584b7-a82b-4f4f-bf42-5bf74a3893d6"
+    assert result.description_keywords == ["keyword1", "keyword2"]
+    assert result.max_retry_attempts == 3
+
+
+def test_get_toolset_missing_optional_fields():
+    minimal = {"id": "ts", "toolset": "ts"}
+    client = get_client_mock(status_code=200, json_mock=minimal)
+    result = client.toolset.get("ts")
+    assert isinstance(result, ToolsetInfo)
+    assert result.transport is None
+    assert result.allowed_tools == []
+    assert result.features is None
+    assert result.description_keywords == []
+
+
+# ---------------------------------------------------------------------------
+# Error cases
+# ---------------------------------------------------------------------------
+
+
+def test_get_toolset_http_error():
+    client = get_client_mock(
+        status_code=401,
+        json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}},
+    )
+    with pytest.raises(DialException):
+        client.toolset.get("my-toolset")
+
+
+@pytest.mark.asyncio
+async def test_async_get_toolset_http_error():
+    client = get_async_client_mock(
+        status_code=401,
+        json_mock={"error": {"message": "Unauthorized", "type": "auth_error"}},
+    )
+    with pytest.raises(DialException):
+        await client.toolset.get("my-toolset")