diff --git a/aidial_client/types/chat/response.py b/aidial_client/types/chat/response.py
index 6560b39..a868ecb 100644
--- a/aidial_client/types/chat/response.py
+++ b/aidial_client/types/chat/response.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional
 
 from aidial_client._compatibility.pydantic import PYDANTIC_V2
 from aidial_client._compatibility.pydantic_v1 import root_validator
@@ -54,12 +54,24 @@ class FunctionCall(ExtraAllowModel):
     name: str
 
 
+class FunctionCallDelta(ExtraAllowModel):
+    arguments: Optional[str] = None
+    name: Optional[str] = None
+
+
 class ChatCompletionMessageToolCall(ExtraAllowModel):
     id: str
     function: FunctionCall
     type: Literal["function"]
 
 
+class ToolCallDelta(ExtraAllowModel):
+    index: int
+    id: Optional[str] = None
+    function: Optional[FunctionCallDelta] = None
+    type: Optional[Literal["function"]] = None
+
+
 class ChatCompletionMessage(ExtraAllowModel):
     role: Literal["assistant"]
     content: Optional[str] = None
@@ -68,12 +80,26 @@ class ChatCompletionMessage(ExtraAllowModel):
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 
+class ChatCompletionMessageDelta(ExtraAllowModel):
+    role: Optional[Literal["assistant"]] = None
+    content: Optional[str] = None
+    custom_content: Optional[CustomContent] = None
+    function_call: Optional[FunctionCallDelta] = None
+    tool_calls: Optional[List[ToolCallDelta]] = None
+
+
 class Choice(ExtraAllowModel):
     index: int
     message: ChatCompletionMessage
     finish_reason: Optional[str]
 
 
+class ChoiceDelta(ExtraAllowModel):
+    index: int
+    delta: ChatCompletionMessageDelta
+    finish_reason: Optional[str] = None
+
+
 class ChatCompletionResponse(ExtraAllowModel):
     id: str
     object: Literal["chat.completion"]
@@ -83,23 +109,6 @@ class ChatCompletionResponse(ExtraAllowModel):
     usage: Optional[CompletionUsage] = None
 
 
-class ChunkEmptyDelta(ExtraAllowModel):
-    """
-    Sometimes delta could be just empty, or have just content
-    """
-
-    content: Optional[str] = None
-    object: Literal[None] = None
-    tool_calls: Literal[None] = None
-    role: Literal[None] = None
-
-
-class ChoiceDelta(ExtraAllowModel):
-    index: int
-    delta: Union[ChatCompletionMessage, ChunkEmptyDelta]
-    finish_reason: Optional[str] = None
-
-
 class ChatCompletionChunk(ExtraAllowModel):
     id: str
     object: Literal["chat.completion.chunk"]
diff --git a/noxfile.py b/noxfile.py
index 5d114ab..acaeb73 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -56,7 +56,7 @@ def format(session: nox.Session):
 @nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"])
 @nox.parametrize("pydantic", ["1.10.17", "2.8.2"])
 @nox.parametrize("httpx", ["0.25.0", "0.27.0"])
-@nox.parametrize("openai", ["1.0.0", "1.51.0"])
+@nox.parametrize("openai", ["1.1.0", "1.51.0"])
 @nox.parametrize("aiofiles", ["0.5.0", "24.1.0"])
 def test(
     session: nox.Session, pydantic: str, httpx: str, openai: str, aiofiles: str
@@ -74,7 +74,7 @@ def test(
 
 @nox.session(python=["3.11"])
 @nox.parametrize("pydantic", ["1.10.17", "2.8.2"])
-@nox.parametrize("openai", ["1.0.0", "1.51.0"])
+@nox.parametrize("openai", ["1.1.0", "1.51.0"])
 @nox.parametrize("aiofiles", ["0.5.0", "24.1.0"])
 def integration_test(
     session: nox.Session, pydantic: str, openai: str, aiofiles: str
diff --git a/poetry.lock b/poetry.lock
index bfa7374..96e9982 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -952,4 +952,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "306f788a10adacae304c2e90099ee0827e97c56ace59b6f2926b406d2e403599"
+content-hash = "879999e1de3da9332111f3cba2d0e0b721b69d0b3de85b4c1bf5e9ba13527f72"
diff --git a/pyproject.toml b/pyproject.toml
index af5d261..0a9eb20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ repository = "https://github.com/epam/ai-dial-client-python"
 packages = [{ include = "aidial_client" }]
 
 [tool.poetry.dependencies]
-openai = ">=1.0.0,<2.0.0"
+openai = ">=1.1.0,<2.0.0"
 python = ">=3.8.1,<4.0"
 httpx = ">=0.25.0,<1.0"
 pydantic = ">=1.10,<3"
diff --git a/tests/resources/completions/test_completions_streaming_tool_call.py b/tests/resources/completions/test_completions_streaming_tool_call.py
new file mode 100644
index 0000000..a5b6cfb
--- /dev/null
+++ b/tests/resources/completions/test_completions_streaming_tool_call.py
@@ -0,0 +1,203 @@
+import inspect
+from typing import Iterable, List
+
+import pytest
+
+from aidial_client.types.chat import ChatCompletionChunk, ToolParam
+from tests.client_mock import get_async_client_mock, get_client_mock
+from tests.utils.chunks import create_mock_chunk, create_sse_data_field
+
+_TOOL_DEFINITION: ToolParam = {
+    "type": "function",
+    "function": {
+        "name": "web_search",
+        "description": "Performs WEB search.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "request": {
+                    "type": "string",
+                    "description": "The search query or question to search for on the web",
+                }
+            },
+            "required": ["request"],
+        },
+    },
+}
+
+_DELTA_CHUNKS: List[dict] = [
+    {
+        "role": "assistant",
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": "call_giAQRJYhG7UEMKwTU5dkuOKq",
+                "function": {"arguments": "", "name": "web_search"},
+                "type": "function",
+            }
+        ],
+    },
+    {
+        "role": None,
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": None,
+                "function": {"arguments": '{"', "name": None},
+                "type": None,
+            }
+        ],
+    },
+    {
+        "role": None,
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": None,
+                "function": {"arguments": "request", "name": None},
+                "type": None,
+            }
+        ],
+    },
+    {
+        "role": None,
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": None,
+                "function": {"arguments": '":"', "name": None},
+                "type": None,
+            }
+        ],
+    },
+    {
+        "role": None,
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": None,
+                "function": {"arguments": "current", "name": None},
+                "type": None,
+            }
+        ],
+    },
+    {
+        "role": None,
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": None,
+                "function": {"arguments": " weather", "name": None},
+                "type": None,
+            }
+        ],
+    },
+    {
+        "role": None,
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": None,
+                "function": {"arguments": " in", "name": None},
+                "type": None,
+            }
+        ],
+    },
+    {
+        "role": None,
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": None,
+                "function": {"arguments": " Paris", "name": None},
+                "type": None,
+            }
+        ],
+    },
+    {
+        "role": None,
+        "tool_calls": [
+            {
+                "index": 0,
+                "id": None,
+                "function": {"arguments": '"}', "name": None},
+                "type": None,
+            }
+        ],
+    },
+    {"role": None, "tool_calls": None},
+]
+
+_STREAM_CHUNKS_MOCK: List[bytes] = [
+    *[
+        create_sse_data_field(create_mock_chunk(delta=delta))
+        for delta in _DELTA_CHUNKS
+    ],
+    create_sse_data_field(
+        create_mock_chunk(
+            finish_reason="tool_calls",
+            usage={
+                "prompt_tokens": 20,
+                "completion_tokens": 10,
+                "total_tokens": 30,
+            },
+        )
+    ),
+]
+
+
+def _validate_chunks(chunks: List[ChatCompletionChunk]):
+    """Check that streamed tool-call deltas rebuild the expected arguments."""
+    assert all(len(chunk.choices) for chunk in chunks)
+    assert all(chunk.choices[0].delta for chunk in chunks)
+    assert all(chunk.choices[0].delta.tool_calls for chunk in chunks[:-2])
+    assert all(chunk.choices[0].delta.content is None for chunk in chunks)
+    total_arguments = "".join(
+        chunk.choices[0].delta.tool_calls[0].function.arguments or "" for chunk in chunks[:-2]  # type: ignore
+    )
+    assert total_arguments == '{"request":"current weather in Paris"}'
+
+    # Last chunk has no content, but carries usage and finish_reason
+    assert chunks[-1].choices[0].delta.content is None
+    assert chunks[-1].choices[0].finish_reason == "tool_calls"
+    assert chunks[-1].usage
+    assert chunks[-1].usage.total_tokens == 30
+    assert chunks[-1].usage.prompt_tokens == 20
+    assert chunks[-1].usage.completion_tokens == 10
+
+
+def test_sync_streaming_tool_call():
+    client = get_client_mock(
+        status_code=200,
+        stream_chunks_mock=_STREAM_CHUNKS_MOCK,
+    )
+
+    response = client.chat.completions.create(
+        deployment_name="gpt-35-turbo",
+        messages=[{"role": "user", "content": "what's the weather in Paris?"}],
+        tools=[_TOOL_DEFINITION],
+        stream=True,
+    )
+    assert isinstance(response, Iterable)
+    chunks = [chunk for chunk in response]
+    assert all(isinstance(chunk, ChatCompletionChunk) for chunk in chunks)
+    _validate_chunks(chunks)
+
+
+@pytest.mark.asyncio
+async def test_async_streaming_tool_call():
+    async_client = get_async_client_mock(
+        status_code=200,
+        stream_chunks_mock=_STREAM_CHUNKS_MOCK,
+    )
+    response = await async_client.chat.completions.create(
+        deployment_name="gpt-35-turbo",
+        messages=[{"role": "user", "content": "what's the weather in Paris?"}],
+        tools=[_TOOL_DEFINITION],
+        stream=True,
+    )
+
+    assert inspect.isasyncgen(response)
+    chunks = [chunk async for chunk in response]
+    assert all(isinstance(chunk, ChatCompletionChunk) for chunk in chunks)
+    _validate_chunks(chunks)
diff --git a/tests/resources/completions/test_completions_streaming.py b/tests/resources/completions/test_completions_streaming_vanilla.py
similarity index 70%
rename from tests/resources/completions/test_completions_streaming.py
rename to tests/resources/completions/test_completions_streaming_vanilla.py
index cf8fda8..dcb8492 100644
--- a/tests/resources/completions/test_completions_streaming.py
+++ b/tests/resources/completions/test_completions_streaming_vanilla.py
@@ -5,15 +5,29 @@
 
 from aidial_client.types.chat import ChatCompletionChunk
 from tests.client_mock import get_async_client_mock, get_client_mock
+from tests.utils.chunks import create_mock_chunk, create_sse_data_field
 
 STREAM_CHUNKS_MOCK: List[bytes] = [
-    b'data: {"id":"chatcmpl-test","choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1723806872,"model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":null}\n\n',  # noqa: E501
-    b'data: {"id":"chatcmpl-test","choices":[{"delta":{"content":"5"},"finish_reason":null,"index":0,"logprobs":null}],"created":1723806872,"model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":null}\n\n'  # noqa: E501
-    b'data: {"id":"chatcmpl-test","choices":[{"delta":{},"finish_reason":"stop","index":0,"logprobs":null}],"created":1723806872,"model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":null,"usage":{"completion_tokens":1,"prompt_tokens":11,"total_tokens":12}}\n\n',  # noqa: E501
+    # NOTE: the explicit ``+`` below joins the first two SSE events into a
+    # single list item, so one element of the stream carries two events.
+    create_sse_data_field(
+        create_mock_chunk(delta={"content": "", "role": "assistant"})
+    )
+    + create_sse_data_field(create_mock_chunk(delta={"content": "5"})),
+    create_sse_data_field(
+        create_mock_chunk(
+            finish_reason="stop",
+            usage={
+                "completion_tokens": 1,
+                "prompt_tokens": 11,
+                "total_tokens": 12,
+            },
+        )
+    ),
 ]
 
 
-def _validate_chunks(chunks):
+def _validate_chunks(chunks: List[ChatCompletionChunk]):
     assert all(len(chunk.choices) for chunk in chunks)
     assert all(chunk.choices[0].delta for chunk in chunks)
     # All except last chunk has some content
diff --git a/tests/utils/chunks.py b/tests/utils/chunks.py
new file mode 100644
index 0000000..eee4ca9
--- /dev/null
+++ b/tests/utils/chunks.py
@@ -0,0 +1,40 @@
+import json
+from typing import Optional, Union
+
+
+def create_mock_chunk(
+    *,
+    delta: Optional[dict] = None,
+    finish_reason: Optional[str] = None,
+    usage: Optional[dict] = None,
+) -> dict:
+    """Build a ``chat.completion.chunk`` payload with a single choice.
+
+    A missing (or empty) ``delta`` becomes ``{}``; the ``usage`` key is
+    emitted only when ``usage`` is not ``None``.
+    """
+    return {
+        "id": "chatcmpl-test",
+        "choices": [
+            {
+                "delta": delta or {},
+                "finish_reason": finish_reason,
+                "index": 0,
+                "logprobs": None,
+            }
+        ],
+        "created": 1723806872,
+        "model": "gpt-35-turbo",
+        "object": "chat.completion.chunk",
+        "system_fingerprint": None,
+        **({} if usage is None else {"usage": usage}),
+    }
+
+
+def create_sse_data_field(chunk: Union[dict, str]) -> bytes:
+    """Encode ``chunk`` as one SSE ``data:`` event terminated by a blank line.
+
+    Dicts are JSON-serialised first; strings are embedded verbatim.
+    """
+    payload = json.dumps(chunk) if isinstance(chunk, dict) else chunk
+    return f"data: {payload}\n\n".encode()