45 changes: 27 additions & 18 deletions aidial_client/types/chat/response.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional

from aidial_client._compatibility.pydantic import PYDANTIC_V2
from aidial_client._compatibility.pydantic_v1 import root_validator
@@ -54,12 +54,24 @@ class FunctionCall(ExtraAllowModel):
name: str


class FunctionCallDelta(ExtraAllowModel):
arguments: Optional[str] = None
name: Optional[str] = None


class ChatCompletionMessageToolCall(ExtraAllowModel):
id: str
function: FunctionCall
type: Literal["function"]


class ToolCallDelta(ExtraAllowModel):
index: int
id: Optional[str] = None
function: Optional[FunctionCallDelta] = None
type: Optional[Literal["function"]] = None


class ChatCompletionMessage(ExtraAllowModel):
role: Literal["assistant"]
content: Optional[str] = None
@@ -68,12 +80,26 @@ class ChatCompletionMessage(ExtraAllowModel):
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


class ChatCompletionMessageDelta(ExtraAllowModel):
role: Optional[Literal["assistant"]] = None
content: Optional[str] = None
custom_content: Optional[CustomContent] = None
function_call: Optional[FunctionCallDelta] = None
tool_calls: Optional[List[ToolCallDelta]] = None


class Choice(ExtraAllowModel):
index: int
message: ChatCompletionMessage
finish_reason: Optional[str]


class ChoiceDelta(ExtraAllowModel):
index: int
delta: ChatCompletionMessageDelta
finish_reason: Optional[str] = None


class ChatCompletionResponse(ExtraAllowModel):
id: str
object: Literal["chat.completion"]
@@ -83,23 +109,6 @@
usage: Optional[CompletionUsage] = None


class ChunkEmptyDelta(ExtraAllowModel):
"""
Sometimes delta could be just empty, or have just content
"""

content: Optional[str] = None
object: Literal[None] = None
tool_calls: Literal[None] = None
role: Literal[None] = None


class ChoiceDelta(ExtraAllowModel):
index: int
delta: Union[ChatCompletionMessage, ChunkEmptyDelta]
finish_reason: Optional[str] = None


class ChatCompletionChunk(ExtraAllowModel):
id: str
object: Literal["chat.completion.chunk"]
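Note: with the new ChatCompletionMessageDelta / ToolCallDelta models above, streamed tool-call fragments can be merged back into complete calls on the client side. A minimal sketch of such client-side accumulation; the accumulate_tool_calls helper and the iteration pattern are illustrative assumptions, not part of this change:

from typing import Dict, Iterable

from aidial_client.types.chat import ChatCompletionChunk


def accumulate_tool_calls(chunks: Iterable[ChatCompletionChunk]) -> Dict[int, dict]:
    # Merge ToolCallDelta fragments, keyed by the tool call index.
    calls: Dict[int, dict] = {}
    for chunk in chunks:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        for tool_call in delta.tool_calls or []:
            entry = calls.setdefault(
                tool_call.index, {"id": None, "name": None, "arguments": ""}
            )
            if tool_call.id:
                entry["id"] = tool_call.id
            if tool_call.function:
                if tool_call.function.name:
                    entry["name"] = tool_call.function.name
                entry["arguments"] += tool_call.function.arguments or ""
    return calls
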
4 changes: 2 additions & 2 deletions noxfile.py
@@ -56,7 +56,7 @@ def format(session: nox.Session):
@nox.session(python=["3.8", "3.9", "3.10", "3.11", "3.12"])
@nox.parametrize("pydantic", ["1.10.17", "2.8.2"])
@nox.parametrize("httpx", ["0.25.0", "0.27.0"])
@nox.parametrize("openai", ["1.0.0", "1.51.0"])
@nox.parametrize("openai", ["1.1.0", "1.51.0"])
@nox.parametrize("aiofiles", ["0.5.0", "24.1.0"])
def test(
session: nox.Session, pydantic: str, httpx: str, openai: str, aiofiles: str
@@ -74,7 +74,7 @@ def test(

@nox.session(python=["3.11"])
@nox.parametrize("pydantic", ["1.10.17", "2.8.2"])
@nox.parametrize("openai", ["1.0.0", "1.51.0"])
@nox.parametrize("openai", ["1.1.0", "1.51.0"])
@nox.parametrize("aiofiles", ["0.5.0", "24.1.0"])
def integration_test(
session: nox.Session, pydantic: str, openai: str, aiofiles: str
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -11,7 +11,7 @@ repository = "https://github.com/epam/ai-dial-client-python"
packages = [{ include = "aidial_client" }]

[tool.poetry.dependencies]
openai = ">=1.0.0,<2.0.0"
openai = ">=1.1.0,<2.0.0"
python = ">=3.8.1,<4.0"
httpx = ">=0.25.0,<1.0"
pydantic = ">=1.10,<3"
203 changes: 203 additions & 0 deletions tests/resources/completions/test_completions_streaming_tool_call.py
@@ -0,0 +1,203 @@
import inspect
from typing import Iterable, List

import pytest

from aidial_client.types.chat import ChatCompletionChunk, ToolParam
from tests.client_mock import get_async_client_mock, get_client_mock
from tests.utils.chunks import create_mock_chunk, create_sse_data_field

_TOOL_DEFINITION: ToolParam = {
"type": "function",
"function": {
"name": "web_search",
"description": "Performs WEB search.",
"parameters": {
"type": "object",
"properties": {
"request": {
"type": "string",
"description": "The search query or question to search for on the web",
}
},
"required": ["request"],
},
},
}

_DELTA_CHUNKS: List[dict] = [
{
"role": "assistant",
"tool_calls": [
{
"index": 0,
"id": "call_giAQRJYhG7UEMKwTU5dkuOKq",
"function": {"arguments": "", "name": "web_search"},
"type": "function",
}
],
},
{
"role": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {"arguments": '{"', "name": None},
"type": None,
}
],
},
{
"role": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {"arguments": "request", "name": None},
"type": None,
}
],
},
{
"role": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {"arguments": '":"', "name": None},
"type": None,
}
],
},
{
"role": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {"arguments": "current", "name": None},
"type": None,
}
],
},
{
"role": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {"arguments": " weather", "name": None},
"type": None,
}
],
},
{
"role": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {"arguments": " in", "name": None},
"type": None,
}
],
},
{
"role": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {"arguments": " Paris", "name": None},
"type": None,
}
],
},
{
"role": None,
"tool_calls": [
{
"index": 0,
"id": None,
"function": {"arguments": '"}', "name": None},
"type": None,
}
],
},
{"role": None, "tool_calls": None},
]

_STREAM_CHUNKS_MOCK: List[bytes] = [
*[
create_sse_data_field(create_mock_chunk(delta=delta))
for delta in _DELTA_CHUNKS
],
create_sse_data_field(
create_mock_chunk(
finish_reason="tool_calls",
usage={
"prompt_tokens": 20,
"completion_tokens": 10,
"total_tokens": 30,
},
)
),
]


def _validate_chunks(chunks: List[ChatCompletionChunk]):
assert all(len(chunk.choices) for chunk in chunks)
assert all(chunk.choices[0].delta for chunk in chunks)
assert all(chunk.choices[0].delta.tool_calls for chunk in chunks[:-2])
assert all(chunk.choices[0].delta.content is None for chunk in chunks)
total_arguments = "".join(
chunk.choices[0].delta.tool_calls[0].function.arguments or "" for chunk in chunks[:-2] # type: ignore
)
assert total_arguments == '{"request":"current weather in Paris"}'

# The last chunk has no content, but carries usage and finish_reason
assert chunks[-1].choices[0].delta.content is None
assert chunks[-1].choices[0].finish_reason == "tool_calls"
assert chunks[-1].usage
assert chunks[-1].usage.total_tokens == 30
assert chunks[-1].usage.prompt_tokens == 20
assert chunks[-1].usage.completion_tokens == 10


def test_sync_streaming_tool_call():
client = get_client_mock(
status_code=200,
stream_chunks_mock=_STREAM_CHUNKS_MOCK,
)

response = client.chat.completions.create(
deployment_name="gpt-35-turbo",
messages=[{"role": "user", "content": "what's the weather in Paris?"}],
tools=[_TOOL_DEFINITION],
stream=True,
)
assert isinstance(response, Iterable)
chunks = [chunk for chunk in response]
assert all(isinstance(chunk, ChatCompletionChunk) for chunk in chunks)
_validate_chunks(chunks)


@pytest.mark.asyncio
async def test_async_streaming_tool_call():
async_client = get_async_client_mock(
status_code=200,
stream_chunks_mock=_STREAM_CHUNKS_MOCK,
)
response = await async_client.chat.completions.create(
deployment_name="gpt-35-turbo",
messages=[{"role": "user", "content": "what's the weather in Paris?"}],
tools=[_TOOL_DEFINITION],
stream=True,
)

assert inspect.isasyncgen(response)
chunks = [chunk async for chunk in response]
assert all(isinstance(chunk, ChatCompletionChunk) for chunk in chunks)
_validate_chunks(chunks)
@@ -5,15 +5,27 @@

from aidial_client.types.chat import ChatCompletionChunk
from tests.client_mock import get_async_client_mock, get_client_mock
from tests.utils.chunks import create_mock_chunk, create_sse_data_field

STREAM_CHUNKS_MOCK: List[bytes] = [
b'data: {"id":"chatcmpl-test","choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1723806872,"model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":null}\n\n', # noqa: E501
b'data: {"id":"chatcmpl-test","choices":[{"delta":{"content":"5"},"finish_reason":null,"index":0,"logprobs":null}],"created":1723806872,"model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":null}\n\n' # noqa: E501
b'data: {"id":"chatcmpl-test","choices":[{"delta":{},"finish_reason":"stop","index":0,"logprobs":null}],"created":1723806872,"model":"gpt-35-turbo","object":"chat.completion.chunk","system_fingerprint":null,"usage":{"completion_tokens":1,"prompt_tokens":11,"total_tokens":12}}\n\n', # noqa: E501
create_sse_data_field(
create_mock_chunk(delta={"content": "", "role": "assistant"})
)
+ create_sse_data_field(create_mock_chunk(delta={"content": "5"})),
create_sse_data_field(
create_mock_chunk(
finish_reason="stop",
usage={
"completion_tokens": 1,
"prompt_tokens": 11,
"total_tokens": 12,
},
)
),
]


def _validate_chunks(chunks):
def _validate_chunks(chunks: List[ChatCompletionChunk]):
assert all(len(chunk.choices) for chunk in chunks)
assert all(chunk.choices[0].delta for chunk in chunks)
# All except last chunk has some content
34 changes: 34 additions & 0 deletions tests/utils/chunks.py
@@ -0,0 +1,34 @@
import json
from typing import Optional, Union


def create_mock_chunk(
*,
delta: Optional[dict] = None,
finish_reason: Optional[str] = None,
usage: Optional[dict] = None,
) -> dict:
return {
"id": "chatcmpl-test",
"choices": [
{
"delta": delta or {},
"finish_reason": finish_reason,
"index": 0,
"logprobs": None,
}
],
"created": 1723806872,
"model": "gpt-35-turbo",
"object": "chat.completion.chunk",
"system_fingerprint": None,
**({} if usage is None else {"usage": usage}),
}


def create_sse_data_field(chunk: Union[dict, str]) -> bytes:
if isinstance(chunk, dict):
s = json.dumps(chunk)
else:
s = chunk
return f"data: {s}\n\n".encode()
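For reference, a quick usage sketch of these helpers (the values are arbitrary examples, and the import path assumes the tests package is importable from the repository root):

from tests.utils.chunks import create_mock_chunk, create_sse_data_field

# A content delta framed as an SSE "data:" field.
frame = create_sse_data_field(create_mock_chunk(delta={"content": "Hello"}))
assert frame.startswith(b"data: ") and frame.endswith(b"\n\n")
assert b'"content": "Hello"' in frame

# A terminal chunk carrying finish_reason and usage.
final = create_sse_data_field(
    create_mock_chunk(finish_reason="stop", usage={"total_tokens": 3})
)
assert b'"finish_reason": "stop"' in final
assert b'"usage"' in final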