diff --git a/.gitignore b/.gitignore
index 19015866..d80305ef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,6 +56,7 @@ env3
.env*
.env*.local
.venv*
+*venv*
env*/
venv*/
ENV/
@@ -66,6 +67,7 @@ venv.bak/
config.yaml
bun.lockb
+
# Jupyter Notebook
.ipynb_checkpoints
@@ -76,4 +78,4 @@ bun.lockb
llmstudio/llm_engine/logs/execution_logs.jsonl
*.db
.prettierignore
-db
\ No newline at end of file
+db
diff --git a/Makefile b/Makefile
index 5a43a3e6..6c957607 100644
--- a/Makefile
+++ b/Makefile
@@ -1,2 +1,5 @@
format:
pre-commit run --all-files
+
+unit-tests:
+ pytest libs/core/tests/unit_tests
diff --git a/libs/core/llmstudio_core/providers/azure.py b/libs/core/llmstudio_core/providers/azure.py
index 2dbd7307..f558f9d6 100644
--- a/libs/core/llmstudio_core/providers/azure.py
+++ b/libs/core/llmstudio_core/providers/azure.py
@@ -62,7 +62,25 @@ async def agenerate_client(self, request: ChatRequest) -> Any:
return self.generate_client(request=request)
def generate_client(self, request: ChatRequest) -> Any:
- """Generate an AzureOpenAI client"""
+ """
+ Generates an AzureOpenAI client for processing a chat request.
+
+        This method prepares and configures the arguments required for a request to the
+        AzureOpenAI chat completions API. It determines model-specific
+ configurations (e.g., whether tools or functions are enabled) and combines
+ these with the base arguments for the API call.
+
+ Args:
+ request (ChatRequest): The chat request object containing the model,
+ parameters, and other necessary details.
+
+ Returns:
+ Any: The result of the chat completions API call.
+
+ Raises:
+ ProviderError: If there is an issue with the API connection or an error
+ returned from the API.
+ """
self.is_llama = "llama" in request.model.lower()
self.is_openai = "gpt" in request.model.lower()
@@ -72,7 +90,6 @@ def generate_client(self, request: ChatRequest) -> Any:
try:
messages = self.prepare_messages(request)
- # Prepare the optional tool-related arguments
tool_args = {}
if not self.is_llama and self.has_tools and self.is_openai:
tool_args = {
@@ -80,7 +97,6 @@ def generate_client(self, request: ChatRequest) -> Any:
"tool_choice": "auto" if request.parameters.get("tools") else None,
}
- # Prepare the optional function-related arguments
function_args = {}
if not self.is_llama and self.has_functions and self.is_openai:
function_args = {
@@ -90,14 +106,12 @@ def generate_client(self, request: ChatRequest) -> Any:
else None,
}
- # Prepare the base arguments
base_args = {
"model": request.model,
"messages": messages,
"stream": True,
}
- # Combine all arguments
combined_args = {
**base_args,
**tool_args,
@@ -116,13 +130,13 @@ def prepare_messages(self, request: ChatRequest):
if self.is_llama and (self.has_tools or self.has_functions):
user_message = self.convert_to_openai_format(request.chat_input)
content = "<|begin_of_text|>"
- content = self.add_system_message(
+ content = self.build_llama_system_message(
user_message,
content,
request.parameters.get("tools"),
request.parameters.get("functions"),
)
- content = self.add_conversation(user_message, content)
+ content = self.build_llama_conversation(user_message, content)
return [{"role": "user", "content": content}]
else:
return (
@@ -139,6 +153,20 @@ async def aparse_response(
yield chunk
def parse_response(self, response: AsyncGenerator, **kwargs) -> Any:
+ """
+ Processes a generator response and yields processed chunks.
+
+ If `is_llama` is True and tools or functions are enabled, it processes the response
+ using `handle_tool_response`. Otherwise, it processes each chunk and yields only those
+ containing "choices".
+
+ Args:
+ response (Generator): The response generator to process.
+ **kwargs: Additional arguments for tool handling.
+
+ Yields:
+ Any: Processed response chunks.
+ """
if self.is_llama and (self.has_tools or self.has_functions):
for chunk in self.handle_tool_response(response, **kwargs):
if chunk:
@@ -388,9 +416,25 @@ def convert_to_openai_format(self, message: Union[str, list]) -> list:
return [{"role": "user", "content": message}]
return message
- def add_system_message(
+ def build_llama_system_message(
self, openai_message: list, llama_message: str, tools: list, functions: list
) -> str:
+ """
+        Builds a complete system message for Llama from the OpenAI messages, tools, and functions.
+
+        If a system message is present in the OpenAI messages, it is included in the result.
+ Otherwise, a default system message is used. Additional tool and function instructions
+ are appended if provided.
+
+ Args:
+ openai_message (list): List of OpenAI messages.
+ llama_message (str): The message to prepend to the system message.
+ tools (list): List of tools to include in the system message.
+ functions (list): List of functions to include in the system message.
+
+ Returns:
+            str: The formatted system message combined with the Llama message.
+ """
system_message = ""
system_message_found = False
for message in openai_message:
@@ -407,15 +451,31 @@ def add_system_message(
"""
if tools:
- system_message = system_message + self.add_tool_instructions(tools)
+ system_message = system_message + self.build_tool_instructions(tools)
if functions:
- system_message = system_message + self.add_function_instructions(functions)
+ system_message = system_message + self.build_function_instructions(
+ functions
+ )
end_tag = "\n<|eot_id|>"
return llama_message + system_message + end_tag
- def add_tool_instructions(self, tools: list) -> str:
+ def build_tool_instructions(self, tools: list) -> str:
+ """
+ Builds a detailed instructional prompt for tools available to the assistant.
+
+ This function generates a message describing the available tools, focusing on tools
+ of type "function." It explains to the LLM how to use each tool and provides an example of the
+ correct response format for function calls.
+
+ Args:
+ tools (list): A list of tool dictionaries, where each dictionary contains tool
+ details such as type, function name, description, and parameters.
+
+ Returns:
+ str: A formatted string detailing the tool instructions and usage examples.
+ """
tool_prompt = """
You have access to the following tools:
"""
@@ -449,7 +509,21 @@ def add_tool_instructions(self, tools: list) -> str:
return tool_prompt
- def add_function_instructions(self, functions: list) -> str:
+ def build_function_instructions(self, functions: list) -> str:
+ """
+ Builds a detailed instructional prompt for available functions.
+
+ This method creates a message describing the functions accessible to the assistant.
+ It includes the function name, description, and required parameters, along with
+ specific guidelines for calling functions.
+
+ Args:
+ functions (list): A list of function dictionaries, each containing details such as
+ name, description, and parameters.
+
+ Returns:
+ str: A formatted string with instructions on using the provided functions.
+ """
function_prompt = """
You have access to the following functions:
"""
@@ -479,35 +553,60 @@ def add_function_instructions(self, functions: list) -> str:
"""
return function_prompt
- def add_conversation(self, openai_message: list, llama_message: str) -> str:
+ def build_llama_conversation(self, openai_message: list, llama_message: str) -> str:
+ """
+        Appends the formatted OpenAI conversation to the Llama message.
+
+ This function iterates through a list of OpenAI messages and formats them for inclusion
+ in a Llama message. It handles user messages that might include nested content (lists of
+ messages) by safely evaluating the content. System messages are skipped.
+
+ Args:
+ openai_message (list): A list of dictionaries representing the OpenAI messages. Each
+ dictionary should have "role" and "content" keys.
+ llama_message (str): The initial Llama message to which the conversation is appended.
+
+ Returns:
+ str: The Llama message with the conversation appended.
+ """
conversation_parts = []
for message in openai_message:
if message["role"] == "system":
continue
elif message["role"] == "user" and isinstance(message["content"], str):
try:
- # Attempt to safely evaluate the string to a Python object
content_as_list = ast.literal_eval(message["content"])
if isinstance(content_as_list, list):
- # If the content is a list, process each nested message
for nested_message in content_as_list:
conversation_parts.append(
self.format_message(nested_message)
)
else:
- # If the content is not a list, append it directly
conversation_parts.append(self.format_message(message))
except (ValueError, SyntaxError):
- # If evaluation fails or content is not a list/dict string, append the message directly
conversation_parts.append(self.format_message(message))
else:
- # For all other messages, use the existing formatting logic
conversation_parts.append(self.format_message(message))
return llama_message + "".join(conversation_parts)
def format_message(self, message: dict) -> str:
- """Format a single message for the conversation."""
+ """
+ Formats a single message dictionary into a structured string for a conversation.
+
+ The formatting depends on the content of the message, such as tool calls,
+ function calls, or simple user/assistant messages. Each type of message
+ is formatted with specific headers and tags.
+
+ Args:
+ message (dict): A dictionary containing message details. Expected keys
+ include "role", "content", and optionally "tool_calls",
+ "tool_call_id", or "function_call".
+
+ Returns:
+ str: A formatted string representing the message. Returns an empty
+ string if the message cannot be formatted.
+ """
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
function_name = tool_call["function"]["name"]
diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py
index 07e03835..69a673dd 100644
--- a/libs/core/llmstudio_core/providers/provider.py
+++ b/libs/core/llmstudio_core/providers/provider.py
@@ -152,8 +152,38 @@ async def achat(
parameters: Optional[dict] = {},
**kwargs,
):
-
- """Makes a chat connection with the provider's API"""
+ """
+ Asynchronously establishes a chat connection with the provider’s API, handling retries,
+ request validation, and streaming response options.
+
+ Parameters
+ ----------
+ chat_input : Any
+ The input data for the chat request, such as a string or dictionary, to be sent to the API.
+ model : str
+ The identifier of the model to be used for the chat request.
+ is_stream : Optional[bool], default=False
+ Flag to indicate if the response should be streamed. If True, returns an async generator
+ for streaming content; otherwise, returns the complete response.
+ retries : Optional[int], default=0
+ Number of retry attempts on error. Retries will be attempted for specific HTTP errors like rate limits.
+ parameters : Optional[dict], default={}
+ Additional configuration parameters for the request, such as temperature or max tokens.
+ **kwargs
+ Additional keyword arguments to customize the request.
+
+ Returns
+ -------
+ Union[AsyncGenerator, Any]
+ - If `is_stream` is True, returns an async generator yielding response chunks.
+ - If `is_stream` is False, returns the first complete response chunk.
+
+ Raises
+ ------
+ ProviderError
+ - Raised if the request validation fails or if all retry attempts are exhausted.
+ - Also raised for unexpected exceptions during request handling.
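+
+        Examples
+        --------
+        A minimal, illustrative sketch (`provider` is any concrete provider instance and the
+        model name is a placeholder)::
+
+            completion = await provider.achat("Hello", model="some-model")
+            stream = await provider.achat("Hello", model="some-model", is_stream=True)
+            async for chunk in stream:
+                ...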
+ """
try:
request = self.validate_request(
dict(
@@ -198,8 +228,38 @@ def chat(
parameters: Optional[dict] = {},
**kwargs,
):
-
- """Makes a chat connection with the provider's API"""
+ """
+ Establishes a chat connection with the provider’s API, handling retries, request validation,
+ and streaming response options.
+
+ Parameters
+ ----------
+ chat_input : Any
+ The input data for the chat request, often a string or dictionary, to be sent to the API.
+ model : str
+ The model identifier for selecting the model used in the chat request.
+ is_stream : Optional[bool], default=False
+ Flag to indicate if the response should be streamed. If True, the function returns a generator
+ for streaming content. Otherwise, it returns the complete response.
+ retries : Optional[int], default=0
+ Number of retry attempts on error. Retries will be attempted on specific HTTP errors like rate limits.
+ parameters : Optional[dict], default={}
+ Additional configuration parameters for the request, such as temperature or max tokens.
+ **kwargs
+ Additional keyword arguments that can be passed to customize the request.
+
+ Returns
+ -------
+ Union[Generator, Any]
+ - If `is_stream` is True, returns a generator that yields chunks of the response.
+ - If `is_stream` is False, returns the first complete response chunk.
+
+ Raises
+ ------
+ ProviderError
+ - Raised if the request validation fails or if the request fails after the specified number of retries.
+ - Also raised on other unexpected exceptions during request handling.
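+
+        Examples
+        --------
+        A minimal, illustrative sketch (`provider` is any concrete provider instance and the
+        model name is a placeholder)::
+
+            completion = provider.chat("Hello", model="some-model")
+            for chunk in provider.chat("Hello", model="some-model", is_stream=True):
+                ...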
+ """
try:
request = self.validate_request(
dict(
@@ -238,7 +298,28 @@ def chat(
async def ahandle_response(
self, request: ChatRequest, response: AsyncGenerator, start_time: float
) -> AsyncGenerator[str, None]:
- """Handles the response from an API"""
+ """
+ Asynchronously handles the response from an API, processing response chunks for either
+ streaming or non-streaming responses.
+
+        For non-streaming responses, chunks are buffered so that a single complete message is
+        produced; for streaming responses, incremental chunks are yielded as they arrive.
+
+ Parameters
+ ----------
+ request : ChatRequest
+ The chat request object, which includes input data, model name, and streaming options.
+ response : AsyncGenerator
+ The async generator yielding response chunks from the API.
+ start_time : float
+ The timestamp when the response handling started, used for latency calculations.
+
+ Yields
+ ------
+ Union[ChatCompletionChunk, ChatCompletion]
+ - If `request.is_stream` is True, yields `ChatCompletionChunk` objects with incremental
+ response chunks for streaming.
+ - If `request.is_stream` is False, yields a final `ChatCompletion` object after processing all chunks.
+ """
first_token_time = None
previous_token_time = None
token_times = []
@@ -294,7 +375,7 @@ async def ahandle_response(
chunks = [chunk[0] if isinstance(chunk, tuple) else chunk for chunk in chunks]
model = next(chunk["model"] for chunk in chunks if chunk.get("model"))
- response, output_string = self.join_chunks(chunks, request)
+ response, output_string = self.join_chunks(chunks)
metrics = self.calculate_metrics(
request.chat_input,
@@ -346,7 +427,29 @@ async def ahandle_response(
def handle_response(
self, request: ChatRequest, response: Generator, start_time: float
) -> Generator:
- """Handles the response from an API"""
+ """
+ Processes API response chunks to build a structured, complete response, yielding
+ each chunk if streaming is enabled.
+
+ If streaming, each chunk is yielded as soon as it’s processed. Otherwise, all chunks
+ are combined and yielded as a single response at the end.
+
+ Parameters
+ ----------
+ request : ChatRequest
+ The original request details, including model, input, and streaming preference.
+ response : Generator
+ A generator yielding partial response chunks from the API.
+ start_time : float
+ The start time for measuring response timing.
+
+ Yields
+ ------
+ Union[ChatCompletionChunk, ChatCompletion]
+ If streaming (`is_stream=True`), yields each `ChatCompletionChunk` as it’s processed.
+ Otherwise, yields a single `ChatCompletion` with the full response data.
+
+ """
first_token_time = None
previous_token_time = None
token_times = []
@@ -402,7 +505,7 @@ def handle_response(
chunks = [chunk[0] if isinstance(chunk, tuple) else chunk for chunk in chunks]
model = next(chunk["model"] for chunk in chunks if chunk.get("model"))
- response, output_string = self.join_chunks(chunks, request)
+ response, output_string = self.join_chunks(chunks)
metrics = self.calculate_metrics(
request.chat_input,
@@ -451,7 +554,29 @@ def handle_response(
else:
yield ChatCompletion(**response)
- def join_chunks(self, chunks, request):
+ def join_chunks(self, chunks):
+ """
+        Combines multiple response chunks from the model into a single, structured response.
+        Handles tool calls, function calls, and standard text completion based on the
+        finish reason reported by the final chunk.
+
+ Parameters
+ ----------
+ chunks : List[Dict]
+ A list of partial responses (chunks) from the model.
+
+ Returns
+ -------
+ Tuple[ChatCompletion, str]
+ - `ChatCompletion`: The structured response based on the type of completion
+ (tool calls, function call, or text).
+ - `str`: The concatenated content or arguments, depending on the completion type.
+
+ Raises
+ ------
+ Exception
+ If there is an issue constructing the response, an exception is raised.
+ """
finish_reason = chunks[-1].get("choices")[0].get("finish_reason")
if finish_reason == "tool_calls":
@@ -612,7 +737,42 @@ def calculate_metrics(
token_times: Tuple[float, ...],
token_count: int,
) -> Dict[str, Any]:
- """Calculates metrics based on token times and output"""
+ """
+ Calculates performance and cost metrics for a model response based on timing
+ information, token counts, and model-specific costs.
+
+ Parameters
+ ----------
+ input : Any
+ The input provided to the model, used to determine input token count.
+ output : Any
+ The output generated by the model, used to determine output token count.
+ model : str
+ The model identifier, used to retrieve model-specific configuration and costs.
+ start_time : float
+ The timestamp marking the start of the model response.
+ end_time : float
+ The timestamp marking the end of the model response.
+ first_token_time : float
+ The timestamp when the first token was received, used for latency calculations.
+ token_times : Tuple[float, ...]
+ A tuple of time intervals between received tokens, used for inter-token latency.
+ token_count : int
+ The total number of tokens processed in the response.
+
+ Returns
+ -------
+ Dict[str, Any]
+ A dictionary containing calculated metrics, including:
+ - `input_tokens`: Number of tokens in the input.
+ - `output_tokens`: Number of tokens in the output.
+ - `total_tokens`: Total token count (input + output).
+ - `cost_usd`: Total cost of the response in USD.
+ - `latency_s`: Total time taken for the response, in seconds.
+ - `time_to_first_token_s`: Time to receive the first token, in seconds.
+            - `inter_token_latency_s`: Average time between tokens, in seconds (0 if `token_times` is empty).
+ - `tokens_per_second`: Processing rate of tokens per second.
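+
+        Examples
+        --------
+        Illustrative values (mirroring the unit tests): with `start_time=0.0`, `end_time=1.0`,
+        `first_token_time=0.5`, `token_times=(0.1,)` and `token_count=2`, the result contains
+        `latency_s == 1.0`, `time_to_first_token_s == 0.5`, `inter_token_latency_s == 0.1`
+        and `tokens_per_second == 2.0`.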
+ """
model_config = self.config.models[model]
input_tokens = len(self.tokenizer.encode(self.input_to_string(input)))
output_tokens = len(self.tokenizer.encode(self.output_to_string(output)))
@@ -628,17 +788,42 @@ def calculate_metrics(
"cost_usd": input_cost + output_cost,
"latency_s": total_time,
"time_to_first_token_s": first_token_time - start_time,
- "inter_token_latency_s": sum(token_times) / len(token_times),
- "tokens_per_second": token_count / total_time,
+ "inter_token_latency_s": sum(token_times) / len(token_times)
+ if token_times
+ else 0,
+ "tokens_per_second": token_count / total_time
+ if token_times
+ else 1 / total_time,
}
def calculate_cost(
self, token_count: int, token_cost: Union[float, List[Dict[str, Any]]]
) -> float:
+ """
+ Calculates the cost for a given number of tokens based on a fixed cost per token
+ or a variable rate structure.
+
+ If `token_cost` is a fixed float, the total cost is `token_count * token_cost`.
+ If `token_cost` is a list, it checks each range and calculates cost based on the applicable range's rate.
+
+ Parameters
+ ----------
+ token_count : int
+ The total number of tokens for which the cost is being calculated.
+ token_cost : Union[float, List[Dict[str, Any]]]
+ Either a fixed cost per token (as a float) or a list of dictionaries defining
+ variable cost ranges. Each dictionary in the list represents a range with
+ 'range' (a tuple of minimum and maximum token counts) and 'cost' (cost per token) keys.
+
+ Returns
+ -------
+ float
+ The calculated cost based on the token count and cost structure.
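+
+        Examples
+        --------
+        Illustrative numbers only (not real model pricing)::
+
+            provider.calculate_cost(100, 0.00002)  # fixed rate -> 0.002
+            # With a list of cost ranges, the rate of the range containing
+            # `token_count` is applied to every token.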
+ """
if isinstance(token_cost, list):
for cost_range in token_cost:
if token_count >= cost_range.range[0] and (
- token_count <= cost_range.range[1] or cost_range.range[1] is None
+ cost_range.range[1] is None or token_count <= cost_range.range[1]
):
return cost_range.cost * token_count
else:
@@ -646,6 +831,23 @@ def calculate_cost(
return 0
def input_to_string(self, input):
+ """
+ Converts an input, which can be a string or a structured list of messages, into a single concatenated string.
+
+ Parameters
+ ----------
+ input : Any
+ The input data to be converted. This can be:
+ - A simple string, which is returned as-is.
+ - A list of message dictionaries, where each dictionary may contain `content`, `role`,
+ and nested items like `text` or `image_url`.
+
+ Returns
+ -------
+ str
+ A concatenated string representing the text content of all messages,
+ including text and URLs from image content if present.
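+
+        Examples
+        --------
+        Illustrative input shaped like the unit tests::
+
+            provider.input_to_string(
+                [
+                    {"role": "user", "content": [{"type": "text", "text": "Hello "}]},
+                    {"role": "user", "content": [{"type": "text", "text": "world!"}]},
+                ]
+            )  # -> "Hello world!"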
+ """
if isinstance(input, str):
return input
else:
@@ -667,6 +869,23 @@ def input_to_string(self, input):
return "".join(result)
def output_to_string(self, output):
+ """
+ Extracts and returns the content or arguments from the output based on
+ the `finish_reason` of the first choice in `output`.
+
+ Parameters
+ ----------
+        output : Any
+            The model output object. Its first choice is expected to carry a `finish_reason`
+            indicating the type of output ("stop", "tool_calls", or "function_call") and the
+            corresponding content or arguments.
+
+ Returns
+ -------
+ str
+ - If `finish_reason` is "stop": Returns the message content.
+ - If `finish_reason` is "tool_calls": Returns the arguments for the first tool call.
+ - If `finish_reason` is "function_call": Returns the arguments for the function call.
+ """
if output.choices[0].finish_reason == "stop":
return output.choices[0].message.content
elif output.choices[0].finish_reason == "tool_calls":
diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py
index 23f070e3..5eeccc83 100644
--- a/libs/core/tests/unit_tests/conftest.py
+++ b/libs/core/tests/unit_tests/conftest.py
@@ -1,6 +1,7 @@
from unittest.mock import MagicMock
import pytest
+from llmstudio_core.providers.azure import AzureProvider
from llmstudio_core.providers.provider import ProviderCore
@@ -11,14 +12,6 @@ async def aparse_response(self, response, **kwargs):
def parse_response(self, response, **kwargs):
return response
- def chat(self, chat_input, model, **kwargs):
- # Mock the response to match expected structure
- return MagicMock(choices=[MagicMock(finish_reason="stop")])
-
- async def achat(self, chat_input, model, **kwargs):
- # Mock the response to match expected structure
- return MagicMock(choices=[MagicMock(finish_reason="stop")])
-
def output_to_string(self, output):
# Handle string inputs
if isinstance(output, str):
@@ -27,6 +20,24 @@ def output_to_string(self, output):
return output.choices[0].message.content
return ""
+ def validate_request(self, request):
+ # For testing, simply return the request
+ return request
+
+ async def agenerate_client(self, request):
+ # For testing, return an async generator
+ async def async_gen():
+ yield {}
+
+ return async_gen()
+
+ def generate_client(self, request):
+ # For testing, return a generator
+ def gen():
+ yield {}
+
+ return gen()
+
@staticmethod
def _provider_config_name():
return "mock_provider"
@@ -42,3 +53,27 @@ def mock_provider():
tokenizer = MagicMock()
tokenizer.encode = lambda x: x.split() # Simple tokenizer mock
return MockProvider(config=config, tokenizer=tokenizer)
+
+
+class MockAzureProvider(AzureProvider):
+ async def aparse_response(self, response, **kwargs):
+ return response
+
+ async def agenerate_client(self, request):
+ # For testing, return an async generator
+ async def async_gen():
+ yield {}
+
+ return async_gen()
+
+ @staticmethod
+ def _provider_config_name():
+ return "mock_azure_provider"
+
+
+@pytest.fixture
+def mock_azure_provider():
+ config = MagicMock()
+ config.id = "mock_azure_provider"
+ base_url = "mock_url.com"
+ return MockAzureProvider(config=config, base_url=base_url)
diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py
new file mode 100644
index 00000000..28e5f684
--- /dev/null
+++ b/libs/core/tests/unit_tests/test_azure.py
@@ -0,0 +1,232 @@
+from unittest.mock import MagicMock
+
+
+class TestParseResponse:
+ def test_tool_response_handling(self, mock_azure_provider):
+
+ mock_azure_provider.is_llama = True
+ mock_azure_provider.has_tools = True
+ mock_azure_provider.has_functions = False
+
+ mock_azure_provider.handle_tool_response = MagicMock(
+ return_value=iter(["chunk1", None, "chunk2", None, "chunk3"])
+ )
+
+ response = iter(["irrelevant"])
+
+ results = list(mock_azure_provider.parse_response(response))
+
+ assert results == ["chunk1", "chunk2", "chunk3"]
+ mock_azure_provider.handle_tool_response.assert_called_once_with(response)
+
+ def test_direct_response_handling_with_choices(self, mock_azure_provider):
+ mock_azure_provider.is_llama = False
+
+ chunk1 = MagicMock()
+ chunk1.model_dump.return_value = {"choices": ["choice1", "choice2"]}
+ chunk2 = MagicMock()
+ chunk2.model_dump.return_value = {"choices": ["choice2"]}
+ response = iter([chunk1, chunk2])
+
+ results = list(mock_azure_provider.parse_response(response))
+
+ assert results == [
+ {"choices": ["choice1", "choice2"]},
+ {"choices": ["choice2"]},
+ ]
+ chunk1.model_dump.assert_called_once()
+ chunk2.model_dump.assert_called_once()
+
+ def test_direct_response_handling_without_choices(self, mock_azure_provider):
+ mock_azure_provider.is_llama = False
+
+ chunk1 = MagicMock()
+ chunk1.model_dump.return_value = {"key": "value"}
+ chunk2 = MagicMock()
+ chunk2.model_dump.return_value = {"another_key": "another_value"}
+ response = iter([chunk1, chunk2])
+
+ results = list(mock_azure_provider.parse_response(response))
+
+ assert results == []
+ chunk1.model_dump.assert_called_once()
+ chunk2.model_dump.assert_called_once()
+
+
+class TestFormatMessage:
+ def test_format_message_tool_calls(self, mock_azure_provider):
+ message = {
+ "tool_calls": [
+ {
+ "function": {
+ "name": "example_tool",
+ "arguments": '{"arg1": "value1"}',
+ }
+ }
+ ]
+ }
+ result = mock_azure_provider.format_message(message)
+ expected = """
+ <|start_header_id|>assistant<|end_header_id|>
+ {"arg1": "value1"}
+ <|eom_id|>
+ """
+ assert result.strip() == expected.strip()
+
+ def test_format_message_tool_call_id(self, mock_azure_provider):
+ message = {"tool_call_id": "123", "content": "This is the tool response."}
+ result = mock_azure_provider.format_message(message)
+ expected = """
+ <|start_header_id|>ipython<|end_header_id|>
+ This is the tool response.
+ <|eot_id|>
+ """
+ assert result.strip() == expected.strip()
+
+ def test_format_message_function_call(self, mock_azure_provider):
+ message = {
+ "function_call": {
+ "name": "example_function",
+ "arguments": '{"arg1": "value1"}',
+ }
+ }
+ result = mock_azure_provider.format_message(message)
+ expected = """
+ <|start_header_id|>assistant<|end_header_id|>
+ {"arg1": "value1"}
+ <|eom_id|>
+ """
+ assert result.strip() == expected.strip()
+
+ def test_format_message_user_message(self, mock_azure_provider):
+ message = {"role": "user", "content": "This is a user message."}
+ result = mock_azure_provider.format_message(message)
+ expected = """
+ <|start_header_id|>user<|end_header_id|>
+ This is a user message.
+ <|eot_id|>
+ """
+ assert result.strip() == expected.strip()
+
+ def test_format_message_assistant_message(self, mock_azure_provider):
+ message = {"role": "assistant", "content": "This is an assistant message."}
+ result = mock_azure_provider.format_message(message)
+ expected = """
+ <|start_header_id|>assistant<|end_header_id|>
+ This is an assistant message.
+ <|eot_id|>
+ """
+ assert result.strip() == expected.strip()
+
+ def test_format_message_function_response(self, mock_azure_provider):
+ message = {"role": "function", "content": "This is the function response."}
+ result = mock_azure_provider.format_message(message)
+ expected = """
+ <|start_header_id|>ipython<|end_header_id|>
+ This is the function response.
+ <|eot_id|>
+ """
+ assert result.strip() == expected.strip()
+
+ def test_format_message_empty_message(self, mock_azure_provider):
+ message = {"role": "user", "content": None}
+ result = mock_azure_provider.format_message(message)
+ expected = ""
+ assert result == expected
+
+
+class TestGenerateClient:
+ def test_generate_client_with_tools_and_functions(self, mock_azure_provider):
+ mock_azure_provider.prepare_messages = MagicMock(
+ return_value="prepared_messages"
+ )
+ mock_azure_provider._client.chat.completions.create = MagicMock(
+ return_value="mock_response"
+ )
+
+ request = MagicMock()
+ request.model = "gpt-4"
+ request.parameters = {
+ "tools": ["tool1", "tool2"],
+ "functions": ["function1", "function2"],
+ "other_param": "value",
+ }
+
+ result = mock_azure_provider.generate_client(request)
+
+ expected_args = {
+ "model": "gpt-4",
+ "messages": "prepared_messages",
+ "stream": True,
+ "tools": ["tool1", "tool2"],
+ "tool_choice": "auto",
+ "functions": ["function1", "function2"],
+ "function_call": "auto",
+ "other_param": "value",
+ }
+
+ assert result == "mock_response"
+ mock_azure_provider.prepare_messages.assert_called_once_with(request)
+ mock_azure_provider._client.chat.completions.create.assert_called_once_with(
+ **expected_args
+ )
+
+ def test_generate_client_without_tools_or_functions(self, mock_azure_provider):
+ mock_azure_provider.prepare_messages = MagicMock(
+ return_value="prepared_messages"
+ )
+ mock_azure_provider._client.chat.completions.create = MagicMock(
+ return_value="mock_response"
+ )
+
+ request = MagicMock()
+ request.model = "gpt-4"
+ request.parameters = {"other_param": "value"}
+
+ result = mock_azure_provider.generate_client(request)
+
+ expected_args = {
+ "model": "gpt-4",
+ "messages": "prepared_messages",
+ "stream": True,
+ "other_param": "value",
+ }
+
+ assert result == "mock_response"
+ mock_azure_provider.prepare_messages.assert_called_once_with(request)
+ mock_azure_provider._client.chat.completions.create.assert_called_once_with(
+ **expected_args
+ )
+
+ def test_generate_client_with_llama_model(self, mock_azure_provider):
+ mock_azure_provider.prepare_messages = MagicMock(
+ return_value="prepared_messages"
+ )
+ mock_azure_provider._client.chat.completions.create = MagicMock(
+ return_value="mock_response"
+ )
+
+ request = MagicMock()
+ request.model = "llama-2"
+ request.parameters = {
+ "tools": ["tool1"],
+ "functions": ["function1"],
+ "other_param": "value",
+ }
+
+ result = mock_azure_provider.generate_client(request)
+
+ expected_args = {
+ "model": "llama-2",
+ "messages": "prepared_messages",
+ "stream": True,
+ "tools": ["tool1"],
+ "functions": ["function1"],
+ "other_param": "value",
+ }
+
+ assert result == "mock_response"
+ mock_azure_provider.prepare_messages.assert_called_once_with(request)
+ mock_azure_provider._client.chat.completions.create.assert_called_once_with(
+ **expected_args
+ )
diff --git a/libs/core/tests/unit_tests/test_azure_build.py b/libs/core/tests/unit_tests/test_azure_build.py
new file mode 100644
index 00000000..17d73f99
--- /dev/null
+++ b/libs/core/tests/unit_tests/test_azure_build.py
@@ -0,0 +1,279 @@
+from unittest.mock import MagicMock, patch
+
+
+class TestBuildLlamaSystemMessage:
+ def test_build_llama_system_message_with_existing_sm(self, mock_azure_provider):
+ mock_azure_provider.build_tool_instructions = MagicMock(
+ return_value="Tool Instructions"
+ )
+ mock_azure_provider.build_function_instructions = MagicMock(
+ return_value="\nFunction Instructions"
+ )
+
+ openai_message = [
+ {"role": "user", "content": "Hello"},
+ {"role": "system", "content": "Custom system message"},
+ ]
+ llama_message = "Initial message"
+ tools = ["Tool1", "Tool2"]
+ functions = ["Function1"]
+
+ result = mock_azure_provider.build_llama_system_message(
+ openai_message, llama_message, tools, functions
+ )
+
+ expected = (
+ "Initial message\n"
+ " <|start_header_id|>system<|end_header_id|>\n"
+ " Custom system message\n"
+            " Tool Instructions\nFunction Instructions\n<|eot_id|>"  # indentation is expected here because triple-quoted strings preserve each line's leading whitespace
+ )
+ assert result == expected
+ mock_azure_provider.build_tool_instructions.assert_called_once_with(tools)
+ mock_azure_provider.build_function_instructions.assert_called_once_with(
+ functions
+ )
+
+ def test_build_llama_system_message_with_default_sm(self, mock_azure_provider):
+ mock_azure_provider.build_tool_instructions = MagicMock(
+ return_value="Tool Instructions"
+ )
+ mock_azure_provider.build_function_instructions = MagicMock(
+ return_value="\nFunction Instructions"
+ )
+
+ openai_message = [{"role": "user", "content": "Hello"}]
+ llama_message = "Initial message"
+ tools = ["Tool1"]
+ functions = []
+
+ result = mock_azure_provider.build_llama_system_message(
+ openai_message, llama_message, tools, functions
+ )
+
+ expected = (
+ "Initial message\n"
+ " <|start_header_id|>system<|end_header_id|>\n"
+ " You are a helpful AI assistant.\n"
+ " Tool Instructions\n<|eot_id|>"
+ )
+ assert result == expected
+ mock_azure_provider.build_tool_instructions.assert_called_once_with(tools)
+ mock_azure_provider.build_function_instructions.assert_not_called()
+
+ def test_build_llama_system_message_without_tools_or_functions(
+ self, mock_azure_provider
+ ):
+ mock_azure_provider.build_tool_instructions = MagicMock()
+ mock_azure_provider.build_function_instructions = MagicMock()
+
+ openai_message = [{"role": "system", "content": "Minimal system message"}]
+ llama_message = "Initial message"
+ tools = []
+ functions = []
+
+ result = mock_azure_provider.build_llama_system_message(
+ openai_message, llama_message, tools, functions
+ )
+
+ expected = (
+ "Initial message\n"
+ " <|start_header_id|>system<|end_header_id|>\n"
+ " Minimal system message\n \n<|eot_id|>"
+ )
+ assert result == expected
+ mock_azure_provider.build_tool_instructions.assert_not_called()
+ mock_azure_provider.build_function_instructions.assert_not_called()
+
+
+class TestBuildInstructions:
+ def test_build_tool_instructions(self, mock_azure_provider):
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "python_repl_ast",
+ "description": "execute Python code",
+ "parameters": {"query": "string"},
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "data_lookup",
+ "description": "retrieve data from a database",
+ "parameters": {"database": "string", "query": "string"},
+ },
+ },
+ ]
+
+ result = mock_azure_provider.build_tool_instructions(tools)
+
+ expected = """
+ You have access to the following tools:
+ Use the function 'python_repl_ast' to 'execute Python code':
+Parameters format:
+{
+ "query": "string"
+}
+
+Use the function 'data_lookup' to 'retrieve data from a database':
+Parameters format:
+{
+ "database": "string",
+ "query": "string"
+}
+
+
+If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix:
+§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}}
+IMPORTANT: IT IS VITAL THAT YOU NEVER ADD A PREFIX OR A SUFFIX TO THE FUNCTION CALL.
+
+Here is an example of the output I desiere when performing function call:
+§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}}
+NOTE: There is no prefix before the symbol '§' and nothing comes after the call is done.
+
+ Reminder:
+ - Function calls MUST follow the specified format.
+ - Only call one function at a time.
+ - Required parameters MUST be specified.
+ - Put the entire function call reply on one line.
+ - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls.
+ - If you have already called a tool and got the response for the users question please reply with the response.
+ """
+ assert result.strip() == expected.strip()
+
+ def test_build_function_instructions(self, mock_azure_provider):
+ functions = [
+ {
+ "name": "python_repl_ast",
+ "description": "execute Python code",
+ "parameters": {"query": "string"},
+ },
+ {
+ "name": "data_lookup",
+ "description": "retrieve data from a database",
+ "parameters": {"database": "string", "query": "string"},
+ },
+ ]
+
+ result = mock_azure_provider.build_function_instructions(functions)
+
+ expected = """
+You have access to the following functions:
+Use the function 'python_repl_ast' to: 'execute Python code'
+{
+ "query": "string"
+}
+
+Use the function 'data_lookup' to: 'retrieve data from a database'
+{
+ "database": "string",
+ "query": "string"
+}
+
+
+If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix:
+§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}}
+
+Here is an example of the output I desiere when performing function call:
+§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}}
+
+Reminder:
+- Function calls MUST follow the specified format.
+- Only call one function at a time.
+- NEVER call more than one function at a time.
+- Required parameters MUST be specified.
+- Put the entire function call reply on one line.
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls.
+- If you have already called a function and got the response for the user's question, please reply with the response.
+"""
+
+ assert result.strip() == expected.strip()
+
+
+class TestBuildLlamaConversation:
+ def test_build_llama_conversation_with_nested_messages(self, mock_azure_provider):
+ mock_azure_provider.format_message = MagicMock(
+ side_effect=lambda msg: f"[formatted:{msg['content']}]"
+ )
+
+ openai_message = [
+ {
+ "role": "user",
+ "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]",
+ },
+ {"role": "assistant", "content": "assistant reply"},
+ ]
+ llama_message = "Initial message: "
+
+ result = mock_azure_provider.build_llama_conversation(
+ openai_message, llama_message
+ )
+
+ expected = "Initial message: [formatted:nested message 1][formatted:nested message 2][formatted:assistant reply]"
+
+ assert result == expected
+ mock_azure_provider.format_message.assert_any_call(
+ {"content": "nested message 1"}
+ )
+ mock_azure_provider.format_message.assert_any_call(
+ {"content": "nested message 2"}
+ )
+ mock_azure_provider.format_message.assert_any_call(
+ {"role": "assistant", "content": "assistant reply"}
+ )
+
+ def test_build_llama_conversation_with_invalid_nested_content(
+ self, mock_azure_provider
+ ):
+ mock_azure_provider.format_message = MagicMock(
+ side_effect=lambda msg: f"[formatted:{msg['content']}]"
+ )
+
+ openai_message = [
+ {"role": "user", "content": "[invalid json/dict]"},
+ {"role": "assistant", "content": "assistant reply"},
+ ]
+ llama_message = "Initial message: "
+
+ with patch("ast.literal_eval", side_effect=ValueError) as mock_literal_eval:
+ result = mock_azure_provider.build_llama_conversation(
+ openai_message, llama_message
+ )
+
+ expected = "Initial message: [formatted:[invalid json/dict]][formatted:assistant reply]"
+
+ assert result == expected
+ mock_azure_provider.format_message.assert_any_call(
+ {"role": "user", "content": "[invalid json/dict]"}
+ )
+ mock_azure_provider.format_message.assert_any_call(
+ {"role": "assistant", "content": "assistant reply"}
+ )
+
+ mock_literal_eval.assert_called_once_with("[invalid json/dict]")
+
+ def test_build_llama_conversation_skipping_system_messages(
+ self, mock_azure_provider
+ ):
+ mock_azure_provider.format_message = MagicMock(
+ side_effect=lambda msg: f"[formatted:{msg['content']}]"
+ )
+
+ openai_message = [
+ {"role": "system", "content": "system message"},
+ {"role": "user", "content": "user message"},
+ ]
+ llama_message = "Initial message: "
+
+ result = mock_azure_provider.build_llama_conversation(
+ openai_message, llama_message
+ )
+
+ expected = "Initial message: [formatted:user message]"
+
+ assert result == expected
+ mock_azure_provider.format_message.assert_any_call(
+ {"role": "user", "content": "user message"}
+ )
diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py
index 118367a3..42396c99 100644
--- a/libs/core/tests/unit_tests/test_provider.py
+++ b/libs/core/tests/unit_tests/test_provider.py
@@ -1,30 +1,39 @@
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import MagicMock
import pytest
-from llmstudio_core.providers.provider import ChatRequest, ProviderError
+from llmstudio_core.providers.provider import ChatRequest, ProviderError, time
-request = ChatRequest(chat_input="Hello", model="test_model")
+request = ChatRequest(chat_input="Hello World", model="test_model")
-def test_chat(mock_provider):
- mock_provider.generate_client = MagicMock(return_value=MagicMock())
- mock_provider.handle_response = MagicMock(return_value=iter(["response"]))
+def test_chat_response_non_stream(mock_provider):
+ mock_provider.validate_request = MagicMock()
+ mock_provider.validate_model = MagicMock()
+ mock_provider.generate_client = MagicMock(return_value="mock_response")
+ mock_provider.handle_response = MagicMock(return_value="final_response")
- print(request.model_dump())
- response = mock_provider.chat(request.chat_input, request.model)
+ response = mock_provider.chat(chat_input="Hello", model="test_model")
- assert response is not None
+ assert response == "final_response"
+ mock_provider.validate_request.assert_called_once()
+ mock_provider.validate_model.assert_called_once()
-@pytest.mark.asyncio
-async def test_achat(mock_provider):
- mock_provider.agenerate_client = AsyncMock(return_value=AsyncMock())
- mock_provider.ahandle_response = AsyncMock(return_value=AsyncMock())
-
- print(request.model_dump())
- response = await mock_provider.achat(request.chat_input, request.model)
+def test_chat_streaming_response(mock_provider):
+ mock_provider.validate_request = MagicMock()
+ mock_provider.validate_model = MagicMock()
+ mock_provider.generate_client = MagicMock(return_value="mock_response_stream")
+ mock_provider.handle_response = MagicMock(
+ return_value=iter(["streamed_response_1", "streamed_response_2"])
+ )
- assert response is not None
+ response_stream = mock_provider.chat(
+ chat_input="Hello", model="test_model", is_stream=True
+ )
+ assert next(response_stream) == "streamed_response_1"
+ assert next(response_stream) == "streamed_response_2"
+ mock_provider.validate_request.assert_called_once()
+ mock_provider.validate_model.assert_called_once()
def test_validate_model(mock_provider):
@@ -36,22 +45,214 @@ def test_validate_model(mock_provider):
mock_provider.validate_model(request_invalid)
-def test_calculate_metrics(mock_provider):
- metrics = mock_provider.calculate_metrics(
- input="Hello",
- output="World",
- model="test_model",
- start_time=0,
- end_time=1,
- first_token_time=0.5,
- token_times=(0.1, 0.2),
- token_count=2,
- )
+def test_join_chunks_finish_reason_stop(mock_provider):
+ current_time = int(time.time())
+ chunks = [
+ {
+ "id": "test_id",
+ "model": "test_model",
+ "created": current_time,
+ "choices": [
+ {
+ "delta": {"content": "Hello, "},
+ "finish_reason": None,
+ "index": 0,
+ }
+ ],
+ },
+ {
+ "id": "test_id",
+ "model": "test_model",
+ "created": current_time,
+ "choices": [
+ {
+ "delta": {"content": "world!"},
+ "finish_reason": "stop",
+ "index": 0,
+ }
+ ],
+ },
+ ]
+ response, output_string = mock_provider.join_chunks(chunks)
+
+ assert output_string == "Hello, world!"
+ assert response.choices[0].message.content == "Hello, world!"
+
+
+def test_join_chunks_finish_reason_function_call(mock_provider):
+ current_time = int(time.time())
+ chunks = [
+ {
+ "id": "test_id",
+ "model": "test_model",
+ "created": current_time,
+ "choices": [
+ {
+ "delta": {
+ "function_call": {"name": "my_function", "arguments": "arg1"}
+ },
+ "finish_reason": None,
+ "index": 0,
+ }
+ ],
+ },
+ {
+ "id": "test_id",
+ "model": "test_model",
+ "created": current_time,
+ "choices": [
+ {
+ "delta": {"function_call": {"arguments": "arg2"}},
+ "finish_reason": "function_call",
+ "index": 0,
+ }
+ ],
+ },
+ {
+ "id": "test_id",
+ "model": "test_model",
+ "created": current_time,
+ "choices": [
+ {
+ "delta": {"function_call": {"arguments": "}"}},
+ "finish_reason": "function_call",
+ "index": 0,
+ }
+ ],
+ },
+ ]
+ response, output_string = mock_provider.join_chunks(chunks)
+
+ assert output_string == "arg1arg2"
+ assert response.choices[0].message.function_call.arguments == "arg1arg2"
+ assert response.choices[0].message.function_call.name == "my_function"
+
+
+def test_join_chunks_tool_calls(mock_provider):
+ current_time = int(time.time())
+
+ chunks = [
+ {
+ "id": "test_id_1",
+ "model": "test_model",
+ "created": current_time,
+ "choices": [
+ {
+ "delta": {
+ "tool_calls": [
+ {
+ "id": "tool_1",
+ "index": 0,
+ "function": {
+ "name": "search_tool",
+ "arguments": '{"query": "weather',
+ },
+ "type": "function",
+ }
+ ]
+ },
+ "finish_reason": None,
+ "index": 0,
+ }
+ ],
+ },
+ {
+ "id": "test_id_2",
+ "model": "test_model",
+ "created": current_time,
+ "choices": [
+ {
+ "delta": {
+ "tool_calls": [
+ {
+ "id": "tool_1",
+ "index": 0,
+ "function": {
+ "name": "search_tool",
+ "arguments": ' details"}',
+ },
+ }
+ ]
+ },
+ "finish_reason": "tool_calls",
+ "index": 0,
+ }
+ ],
+ },
+ ]
+
+ response, output_string = mock_provider.join_chunks(chunks)
+
+ assert output_string == "['search_tool', '{\"query\": \"weather details\"}']"
+
+ assert response.object == "chat.completion"
+ assert response.choices[0].finish_reason == "tool_calls"
+ tool_call = response.choices[0].message.tool_calls[0]
+
+ assert tool_call.id == "tool_1"
+ assert tool_call.function.name == "search_tool"
+ assert tool_call.function.arguments == '{"query": "weather details"}'
+ assert tool_call.type == "function"
+
+
+def test_input_to_string_with_string(mock_provider):
+ input_data = "Hello, world!"
+ assert mock_provider.input_to_string(input_data) == "Hello, world!"
+
+
+def test_input_to_string_with_list_of_text_messages(mock_provider):
+ input_data = [
+ {"content": "Hello"},
+ {"content": " world!"},
+ ]
+ assert mock_provider.input_to_string(input_data) == "Hello world!"
+
+
+def test_input_to_string_with_list_of_text_and_url(mock_provider):
+ input_data = [
+ {"role": "user", "content": [{"type": "text", "text": "Hello "}]},
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": "http://example.com/image.jpg"},
+ }
+ ],
+ },
+ {"role": "user", "content": [{"type": "text", "text": " world!"}]},
+ ]
+ expected_output = "Hello http://example.com/image.jpg world!"
+ assert mock_provider.input_to_string(input_data) == expected_output
+
+
+def test_input_to_string_with_mixed_roles_and_missing_content(mock_provider):
+ input_data = [
+ {"role": "assistant", "content": "Admin text;"},
+ {"role": "user", "content": [{"type": "text", "text": "User text"}]},
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": "http://example.com/another.jpg"},
+ }
+ ],
+ },
+ ]
+ expected_output = "Admin text;User texthttp://example.com/another.jpg"
+ assert mock_provider.input_to_string(input_data) == expected_output
+
+
+def test_input_to_string_with_missing_content_key(mock_provider):
+ input_data = [
+ {"role": "user"},
+ {"role": "user", "content": [{"type": "text", "text": "Hello again"}]},
+ ]
+ expected_output = "Hello again"
+ assert mock_provider.input_to_string(input_data) == expected_output
+
- assert metrics["input_tokens"] == pytest.approx(1)
- assert metrics["output_tokens"] == pytest.approx(1)
- assert metrics["cost_usd"] == pytest.approx(0.03)
- assert metrics["latency_s"] == pytest.approx(1)
- assert metrics["time_to_first_token_s"] == pytest.approx(0.5)
- assert metrics["inter_token_latency_s"] == pytest.approx(0.15)
- assert metrics["tokens_per_second"] == pytest.approx(2)
+def test_input_to_string_with_empty_list(mock_provider):
+ input_data = []
+ assert mock_provider.input_to_string(input_data) == ""
diff --git a/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py b/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py
new file mode 100644
index 00000000..fb54d602
--- /dev/null
+++ b/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py
@@ -0,0 +1,128 @@
+from unittest.mock import MagicMock
+
+
+def test_calculate_metrics(mock_provider):
+
+ metrics = mock_provider.calculate_metrics(
+ input="Hello",
+ output="Hello World",
+ model="test_model",
+ start_time=0.0,
+ end_time=1.0,
+ first_token_time=0.5,
+ token_times=(0.1,),
+ token_count=2,
+ )
+
+ assert metrics["input_tokens"] == 1
+ assert metrics["output_tokens"] == 2
+ assert metrics["total_tokens"] == 3
+ assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 2 # input_cost + output_cost
+ assert metrics["latency_s"] == 1.0 # end_time - start_time
+ assert (
+ metrics["time_to_first_token_s"] == 0.5 - 0.0
+ ) # first_token_time - start_time
+ assert metrics["inter_token_latency_s"] == 0.1 # Average of token_times
+ assert metrics["tokens_per_second"] == 2 / 1.0 # token_count / total_time
+
+
+def test_calculate_metrics_single_token(mock_provider):
+
+ metrics = mock_provider.calculate_metrics(
+ input="Hello",
+ output="World",
+ model="test_model",
+ start_time=0.0,
+ end_time=1.0,
+ first_token_time=0.5,
+ token_times=(),
+ token_count=1,
+ )
+
+ assert metrics["input_tokens"] == 1
+ assert metrics["output_tokens"] == 1
+ assert metrics["total_tokens"] == 2
+ assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 1
+ assert metrics["latency_s"] == 1.0
+ assert metrics["time_to_first_token_s"] == 0.5 - 0.0
+ assert metrics["inter_token_latency_s"] == 0
+ assert metrics["tokens_per_second"] == 1 / 1.0
+
+
+def test_calculate_cost_fixed_cost(mock_provider):
+ fixed_cost = 0.02
+ token_count = 100
+ expected_cost = token_count * fixed_cost
+ assert mock_provider.calculate_cost(token_count, fixed_cost) == expected_cost
+
+
+def test_calculate_cost_variable_cost(mock_provider):
+ cost_range_1 = MagicMock()
+ cost_range_1.range = (0, 50)
+ cost_range_1.cost = 0.01
+
+ cost_range_2 = MagicMock()
+ cost_range_2.range = (51, 100)
+ cost_range_2.cost = 0.02
+
+ variable_cost = [cost_range_1, cost_range_2]
+ token_count = 75
+ expected_cost = token_count * 0.02
+ assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost
+
+
+def test_calculate_cost_variable_cost_higher_range(mock_provider):
+ cost_range_1 = MagicMock()
+ cost_range_1.range = (0, 50)
+ cost_range_1.cost = 0.01
+
+ cost_range_2 = MagicMock()
+ cost_range_2.range = (51, 100)
+ cost_range_2.cost = 0.02
+
+ cost_range_3 = MagicMock()
+ cost_range_3.range = (101, None)
+ cost_range_3.cost = 0.03
+
+ variable_cost = [cost_range_1, cost_range_2, cost_range_3]
+ token_count = 150
+ expected_cost = token_count * 0.03
+ assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost
+
+
+def test_calculate_cost_variable_cost_no_matching_range(mock_provider):
+ cost_range_1 = MagicMock()
+ cost_range_1.range = (0, 50)
+ cost_range_1.cost = 0.01
+
+ cost_range_2 = MagicMock()
+ cost_range_2.range = (51, 100)
+ cost_range_2.cost = 0.02
+
+ cost_range_3 = MagicMock()
+ cost_range_3.range = (101, 150)
+ cost_range_3.cost = 0.03
+
+ variable_cost = [cost_range_1, cost_range_2, cost_range_3]
+ token_count = 200
+ expected_cost = 0
+ assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost
+
+
+def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider):
+ cost_range_1 = MagicMock()
+ cost_range_1.range = (10, 50)
+ cost_range_1.cost = 0.01
+
+ cost_range_2 = MagicMock()
+ cost_range_2.range = (51, 100)
+ cost_range_2.cost = 0.02
+
+ cost_range_3 = MagicMock()
+ cost_range_3.range = (101, 150)
+ cost_range_3.cost = 0.03
+
+ variable_cost = [cost_range_1, cost_range_2, cost_range_3]
+ token_count = 5
+ expected_cost = 0
+ assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost
diff --git a/libs/core/tests/unit_tests/test_provider_handle_response.py b/libs/core/tests/unit_tests/test_provider_handle_response.py
new file mode 100644
index 00000000..ac04e32b
--- /dev/null
+++ b/libs/core/tests/unit_tests/test_provider_handle_response.py
@@ -0,0 +1,281 @@
+from unittest.mock import MagicMock
+
+import pytest
+from llmstudio_core.providers.provider import ChatCompletion, ChatCompletionChunk, time
+
+
+@pytest.mark.asyncio
+async def test_ahandle_response_non_streaming(mock_provider):
+ request = MagicMock(
+ is_stream=False, chat_input="Hello", model="test_model", parameters={}
+ )
+ response_chunk = {
+ "choices": [
+ {
+ "delta": {"content": "Non-streamed response"},
+ "finish_reason": "stop",
+ "index": 0,
+ }
+ ],
+ "model": "test_model",
+ }
+ start_time = time.time()
+
+ async def mock_aparse_response(*args, **kwargs):
+ yield response_chunk
+
+ mock_provider.aparse_response = mock_aparse_response
+ mock_provider.join_chunks = MagicMock(
+ return_value=(
+ ChatCompletion(
+ id="id",
+ choices=[],
+ created=0,
+ model="test_model",
+ object="chat.completion",
+ ),
+ "Non-streamed response",
+ )
+ )
+ mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1})
+
+ response = []
+ async for chunk in mock_provider.ahandle_response(
+ request, mock_aparse_response(), start_time
+ ):
+ response.append(chunk)
+
+ assert isinstance(response[0], ChatCompletion)
+ assert response[0].choices == []
+ assert response[0].chat_output == "Non-streamed response"
+
+
+@pytest.mark.asyncio
+async def test_ahandle_response_streaming_length(mock_provider):
+ request = MagicMock(
+ is_stream=True, chat_input="Hello", model="test_model", parameters={}
+ )
+ response_chunk = {
+ "choices": [
+ {
+ "delta": {"content": "Streamed response"},
+ "finish_reason": "length",
+ "index": 0,
+ }
+ ],
+ "model": "test_model",
+ "object": "chat.completion.chunk",
+ "created": 0,
+ }
+ start_time = time.time()
+
+ async def mock_aparse_response(*args, **kwargs):
+ yield response_chunk
+
+ mock_provider.aparse_response = mock_aparse_response
+ mock_provider.join_chunks = MagicMock(
+ return_value=(
+ ChatCompletion(
+ id="id",
+ choices=[],
+ created=0,
+ model="test_model",
+ object="chat.completion",
+ ),
+ "Streamed response",
+ )
+ )
+ mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1})
+
+ response = []
+ async for chunk in mock_provider.ahandle_response(
+ request, mock_aparse_response(), start_time
+ ):
+ response.append(chunk)
+
+ assert isinstance(response[0], ChatCompletionChunk)
+ assert response[0].chat_output_stream == "Streamed response"
+
+
+@pytest.mark.asyncio
+async def test_ahandle_response_streaming_stop(mock_provider):
+ request = MagicMock(
+ is_stream=True, chat_input="Hello", model="test_model", parameters={}
+ )
+ response_chunk = {
+ "choices": [
+ {
+ "delta": {"content": "Streamed response"},
+ "finish_reason": "stop",
+ "index": 0,
+ }
+ ],
+ "model": "test_model",
+ "object": "chat.completion.chunk",
+ "created": 0,
+ }
+ start_time = time.time()
+
+ async def mock_aparse_response(*args, **kwargs):
+ yield response_chunk
+
+ mock_provider.aparse_response = mock_aparse_response
+ mock_provider.join_chunks = MagicMock(
+ return_value=(
+ ChatCompletion(
+ id="id",
+ choices=[],
+ created=0,
+ model="test_model",
+ object="chat.completion",
+ ),
+ "Streamed response",
+ )
+ )
+ mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1})
+
+ response = []
+ async for chunk in mock_provider.ahandle_response(
+ request, mock_aparse_response(), start_time
+ ):
+ response.append(chunk)
+
+ assert isinstance(response[0], ChatCompletionChunk)
+ assert response[0].chat_output == "Streamed response"
+
+
+def test_handle_response_non_streaming(mock_provider):
+ request = MagicMock(
+ is_stream=False, chat_input="Hello", model="test_model", parameters={}
+ )
+ response_chunk = {
+ "choices": [
+ {
+ "delta": {"content": "Non-streamed response"},
+ "finish_reason": "stop",
+ "index": 0,
+ }
+ ],
+ "model": "test_model",
+ }
+ start_time = time.time()
+
+ def mock_parse_response(*args, **kwargs):
+ yield response_chunk
+
+    mock_provider.parse_response = mock_parse_response
+ mock_provider.join_chunks = MagicMock(
+ return_value=(
+ ChatCompletion(
+ id="id",
+ choices=[],
+ created=0,
+ model="test_model",
+ object="chat.completion",
+ ),
+ "Non-streamed response",
+ )
+ )
+ mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1})
+
+ response = []
+ for chunk in mock_provider.handle_response(
+ request, mock_parse_response(), start_time
+ ):
+ response.append(chunk)
+
+ assert isinstance(response[0], ChatCompletion)
+ assert response[0].choices == []
+ assert response[0].chat_output == "Non-streamed response"
+
+
+def test_handle_response_streaming_length(mock_provider):
+ request = MagicMock(
+ is_stream=True, chat_input="Hello", model="test_model", parameters={}
+ )
+ response_chunk = {
+ "choices": [
+ {
+ "delta": {"content": "Streamed response"},
+ "finish_reason": "length",
+ "index": 0,
+ }
+ ],
+ "model": "test_model",
+ "object": "chat.completion.chunk",
+ "created": 0,
+ }
+ start_time = time.time()
+
+ def mock_parse_response(*args, **kwargs):
+ yield response_chunk
+
+    mock_provider.parse_response = mock_parse_response
+ mock_provider.join_chunks = MagicMock(
+ return_value=(
+ ChatCompletion(
+ id="id",
+ choices=[],
+ created=0,
+ model="test_model",
+ object="chat.completion",
+ ),
+ "Streamed response",
+ )
+ )
+ mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1})
+
+ response = []
+ for chunk in mock_provider.handle_response(
+ request, mock_parse_response(), start_time
+ ):
+ response.append(chunk)
+
+ assert isinstance(response[0], ChatCompletionChunk)
+ assert response[0].chat_output_stream == "Streamed response"
+
+
+def test_handle_response_streaming_stop(mock_provider):
+ request = MagicMock(
+ is_stream=True, chat_input="Hello", model="test_model", parameters={}
+ )
+ response_chunk = {
+ "choices": [
+ {
+ "delta": {"content": "Streamed response"},
+ "finish_reason": "stop",
+ "index": 0,
+ }
+ ],
+ "model": "test_model",
+ "object": "chat.completion.chunk",
+ "created": 0,
+ }
+ start_time = time.time()
+
+ def mock_parse_response(*args, **kwargs):
+ yield response_chunk
+
+ mock_provider.parse_response = mock_parse_response
+ mock_provider.join_chunks = MagicMock(
+ return_value=(
+ ChatCompletion(
+ id="id",
+ choices=[],
+ created=0,
+ model="test_model",
+ object="chat.completion",
+ ),
+ "Streamed response",
+ )
+ )
+ mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1})
+
+ response = []
+ for chunk in mock_provider.handle_response(
+ request, mock_parse_response(), start_time
+ ):
+ response.append(chunk)
+
+ assert isinstance(response[0], ChatCompletionChunk)
+ assert response[0].chat_output == "Streamed response"