From 8a8b73a32e09adaa2bd06b7710374dfa5adafc3e Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Mon, 18 Nov 2024 12:31:45 +0000 Subject: [PATCH 1/5] chore: added unit tests for azure provider --- libs/core/llmstudio_core/providers/azure.py | 70 +++++- libs/core/tests/unit_tests/conftest.py | 29 +++ libs/core/tests/unit_tests/test_azure.py | 233 ++++++++++++++++++++ 3 files changed, 326 insertions(+), 6 deletions(-) create mode 100644 libs/core/tests/unit_tests/test_azure.py diff --git a/libs/core/llmstudio_core/providers/azure.py b/libs/core/llmstudio_core/providers/azure.py index 6924f411..5adb75d9 100644 --- a/libs/core/llmstudio_core/providers/azure.py +++ b/libs/core/llmstudio_core/providers/azure.py @@ -161,7 +161,7 @@ def prepare_messages(self, request: ChatRequest): if self.is_llama and (self.has_tools or self.has_functions): user_message = self.convert_to_openai_format(request.chat_input) content = "<|begin_of_text|>" - content = self.add_system_message( + content = self.build_llama_system_message( user_message, content, request.parameters.get("tools"), @@ -190,6 +190,20 @@ async def aparse_response( yield c def parse_response(self, response: Generator, **kwargs) -> Any: + """ + Processes a generator response and yields processed chunks. + + If `is_llama` is True and tools or functions are enabled, it processes the response + using `handle_tool_response`. Otherwise, it processes each chunk and yields only those + containing "choices". + + Args: + response (Generator): The response generator to process. + **kwargs: Additional arguments for tool handling. + + Yields: + Any: Processed response chunks. + """ if self.is_llama and (self.has_tools or self.has_functions): for chunk in self.handle_tool_response(response, **kwargs): if chunk: @@ -439,9 +453,25 @@ def convert_to_openai_format(self, message: Union[str, list]) -> list: return [{"role": "user", "content": message}] return message - def add_system_message( + def build_llama_system_message( self, openai_message: list, llama_message: str, tools: list, functions: list ) -> str: + """ + Builds a complete system message for Llama based on OpenAI's message, tools, and functions. + + If a system message is present in the OpenAI message, it is included in the result. + Otherwise, a default system message is used. Additional tool and function instructions + are appended if provided. + + Args: + openai_message (list): List of OpenAI messages. + llama_message (str): The message to prepend to the system message. + tools (list): List of tools to include in the system message. + functions (list): List of functions to include in the system message. + + Returns: + str: The formatted system message combined with Llama message. + """ system_message = "" system_message_found = False for message in openai_message: @@ -458,15 +488,29 @@ def add_system_message( """ if tools: - system_message = system_message + self.add_tool_instructions(tools) + system_message = system_message + self.build_tool_instructions(tools) if functions: - system_message = system_message + self.add_function_instructions(functions) + system_message = system_message + self.build_function_instructions(functions) end_tag = "\n<|eot_id|>" return llama_message + system_message + end_tag - def add_tool_instructions(self, tools: list) -> str: + def build_tool_instructions(self, tools: list) -> str: + """ + Builds a detailed instructional prompt for tools available to the assistant. 
+ + This function generates a message describing the available tools, focusing on tools + of type "function." It explains to the LLM how to use each tool and provides an example of the + correct response format for function calls. + + Args: + tools (list): A list of tool dictionaries, where each dictionary contains tool + details such as type, function name, description, and parameters. + + Returns: + str: A formatted string detailing the tool instructions and usage examples. + """ tool_prompt = """ You have access to the following tools: """ @@ -500,7 +544,21 @@ def add_tool_instructions(self, tools: list) -> str: return tool_prompt - def add_function_instructions(self, functions: list) -> str: + def build_function_instructions(self, functions: list) -> str: + """ + Builds a detailed instructional prompt for available functions. + + This method creates a message describing the functions accessible to the assistant. + It includes the function name, description, and required parameters, along with + specific guidelines for calling functions. + + Args: + functions (list): A list of function dictionaries, each containing details such as + name, description, and parameters. + + Returns: + str: A formatted string with instructions on using the provided functions. + """ function_prompt = """ You have access to the following functions: """ diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 955f9276..fcc74f16 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -2,6 +2,7 @@ import pytest from llmstudio_core.providers.provider import ProviderCore +from llmstudio_core.providers.azure import AzureProvider class MockProvider(ProviderCore): @@ -50,3 +51,31 @@ def mock_provider(): tokenizer = MagicMock() tokenizer.encode = lambda x: x.split() # Simple tokenizer mock return MockProvider(config=config, tokenizer=tokenizer) + + +class MockAzureProvider(AzureProvider): + async def aparse_response(self, response, **kwargs): + return response + + async def agenerate_client(self, request): + # For testing, return an async generator + async def async_gen(): + yield {} + return async_gen() + + def generate_client(self, request): + # For testing, return a generator + def gen(): + yield {} + return gen() + + @staticmethod + def _provider_config_name(): + return "mock_azure_provider" + +@pytest.fixture +def mock_azure_provider(): + config = MagicMock() + config.id = "mock_azure_provider" + base_url = 'mock_url.com' + return MockAzureProvider(config=config, base_url=base_url) \ No newline at end of file diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py new file mode 100644 index 00000000..fc4dd026 --- /dev/null +++ b/libs/core/tests/unit_tests/test_azure.py @@ -0,0 +1,233 @@ +import pytest +from unittest.mock import MagicMock + +class TestParseResponse: + def test_tool_response_handling(self, mock_azure_provider): + + mock_azure_provider.is_llama = True + mock_azure_provider.has_tools = True + mock_azure_provider.has_functions = False + + mock_azure_provider.handle_tool_response = MagicMock(return_value=iter(["chunk1", None, "chunk2", None, "chunk3"])) + + response = iter(["irrelevant"]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == ["chunk1", "chunk2", "chunk3"] + mock_azure_provider.handle_tool_response.assert_called_once_with(response) + + def test_direct_response_handling_with_choices(self, mock_azure_provider): + 
mock_azure_provider.is_llama = False + + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"choices": ["choice1","choice2"]} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"choices": ["choice2"]} + response = iter([chunk1, chunk2]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == [{"choices": ["choice1", "choice2"]}, {"choices": ["choice2"]}] + chunk1.model_dump.assert_called_once() + chunk2.model_dump.assert_called_once() + + def test_direct_response_handling_without_choices(self, mock_azure_provider): + mock_azure_provider.is_llama = False + + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"key": "value"} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"another_key": "another_value"} + response = iter([chunk1, chunk2]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == [] + chunk1.model_dump.assert_called_once() + chunk2.model_dump.assert_called_once() + +class TestBuildLlamaSystemMessage: + def test_with_existing_system_message(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") + mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") + + openai_message = [ + {"role": "user", "content": "Hello"}, + {"role": "system", "content": "Custom system message"} + ] + llama_message = "Initial message" + tools = ["Tool1", "Tool2"] + functions = ["Function1"] + + result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " Custom system message\n" + " Tool Instructions\nFunction Instructions\n<|eot_id|>" # identation here exists because in Python when adding a newline to a triple quote string it keeps identation + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) + mock_azure_provider.build_function_instructions.assert_called_once_with(functions) + + def test_with_default_system_message(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") + mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") + + openai_message = [{"role": "user", "content": "Hello"}] + llama_message = "Initial message" + tools = ["Tool1"] + functions = [] + + result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " You are a helpful AI assistant.\n" + " Tool Instructions\n<|eot_id|>" + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) + mock_azure_provider.build_function_instructions.assert_not_called() + + def test_without_tools_or_functions(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock() + mock_azure_provider.build_function_instructions = MagicMock() + + openai_message = [{"role": "system", "content": "Minimal system message"}] + llama_message = "Initial message" + tools = [] + functions = [] + + result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " Minimal system message\n \n<|eot_id|>" + ) + assert result == expected + 
mock_azure_provider.build_tool_instructions.assert_not_called() + mock_azure_provider.build_function_instructions.assert_not_called() + +class TestBuildInstructions: + def test_build_tool_instructions(self, mock_azure_provider): + tools = [ + { + "type": "function", + "function": { + "name": "python_repl_ast", + "description": "execute Python code", + "parameters": { + "query": "string" + } + } + }, + { + "type": "function", + "function": { + "name": "data_lookup", + "description": "retrieve data from a database", + "parameters": { + "database": "string", + "query": "string" + } + } + } + ] + + result = mock_azure_provider.build_tool_instructions(tools) + + expected = ( + """ + You have access to the following tools: + Use the function 'python_repl_ast' to 'execute Python code': +Parameters format: +{ + "query": "string" +} + +Use the function 'data_lookup' to 'retrieve data from a database': +Parameters format: +{ + "database": "string", + "query": "string" +} + + +If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: +§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} +IMPORTANT: IT IS VITAL THAT YOU NEVER ADD A PREFIX OR A SUFFIX TO THE FUNCTION CALL. + +Here is an example of the output I desiere when performing function call: +§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} +NOTE: There is no prefix before the symbol '§' and nothing comes after the call is done. + + Reminder: + - Function calls MUST follow the specified format. + - Only call one function at a time. + - Required parameters MUST be specified. + - Put the entire function call reply on one line. + - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. + - If you have already called a tool and got the response for the users question please reply with the response. + """ + ) + assert result.strip() == expected.strip() + + def test_add_function_instructions(self, mock_azure_provider): + functions = [ + { + "name": "python_repl_ast", + "description": "execute Python code", + "parameters": { + "query": "string" + } + }, + { + "name": "data_lookup", + "description": "retrieve data from a database", + "parameters": { + "database": "string", + "query": "string" + } + } + ] + + result = mock_azure_provider.build_function_instructions(functions) + + expected = ( + """ +You have access to the following functions: +Use the function 'python_repl_ast' to: 'execute Python code' +{ + "query": "string" +} + +Use the function 'data_lookup' to: 'retrieve data from a database' +{ + "database": "string", + "query": "string" +} + + +If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: +§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} + +Here is an example of the output I desiere when performing function call: +§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} + +Reminder: +- Function calls MUST follow the specified format. +- Only call one function at a time. +- NEVER call more than one function at a time. +- Required parameters MUST be specified. +- Put the entire function call reply on one line. 
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. +- If you have already called a function and got the response for the user's question, please reply with the response. +""" + ) + + assert result.strip() == expected.strip() From 1be6bbac7b634959a642534d6eeab0eb83973b9c Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Mon, 18 Nov 2024 15:02:57 +0000 Subject: [PATCH 2/5] chore: added more unit tests and docstrings on azure, removed redundant comments --- libs/core/llmstudio_core/providers/azure.py | 41 ++++- libs/core/tests/unit_tests/test_azure.py | 176 +++++++++++++++++++- 2 files changed, 205 insertions(+), 12 deletions(-) diff --git a/libs/core/llmstudio_core/providers/azure.py b/libs/core/llmstudio_core/providers/azure.py index 5adb75d9..e923ebca 100644 --- a/libs/core/llmstudio_core/providers/azure.py +++ b/libs/core/llmstudio_core/providers/azure.py @@ -167,7 +167,7 @@ def prepare_messages(self, request: ChatRequest): request.parameters.get("tools"), request.parameters.get("functions"), ) - content = self.add_conversation(user_message, content) + content = self.build_llama_conversation(user_message, content) return [{"role": "user", "content": content}] else: return ( @@ -588,35 +588,60 @@ def build_function_instructions(self, functions: list) -> str: """ return function_prompt - def add_conversation(self, openai_message: list, llama_message: str) -> str: + def build_llama_conversation(self, openai_message: list, llama_message: str) -> str: + """ + Appends the OpenAI message to the Llama message while formatting OpenAI messages. + + This function iterates through a list of OpenAI messages and formats them for inclusion + in a Llama message. It handles user messages that might include nested content (lists of + messages) by safely evaluating the content. System messages are skipped. + + Args: + openai_message (list): A list of dictionaries representing the OpenAI messages. Each + dictionary should have "role" and "content" keys. + llama_message (str): The initial Llama message to which the conversation is appended. + + Returns: + str: The Llama message with the conversation appended. + """ conversation_parts = [] for message in openai_message: if message["role"] == "system": continue elif message["role"] == "user" and isinstance(message["content"], str): try: - # Attempt to safely evaluate the string to a Python object content_as_list = ast.literal_eval(message["content"]) if isinstance(content_as_list, list): - # If the content is a list, process each nested message for nested_message in content_as_list: conversation_parts.append( self.format_message(nested_message) ) else: - # If the content is not a list, append it directly conversation_parts.append(self.format_message(message)) except (ValueError, SyntaxError): - # If evaluation fails or content is not a list/dict string, append the message directly conversation_parts.append(self.format_message(message)) else: - # For all other messages, use the existing formatting logic conversation_parts.append(self.format_message(message)) return llama_message + "".join(conversation_parts) def format_message(self, message: dict) -> str: - """Format a single message for the conversation.""" + """ + Formats a single message dictionary into a structured string for a conversation. + + The formatting depends on the content of the message, such as tool calls, + function calls, or simple user/assistant messages. 
Each type of message + is formatted with specific headers and tags. + + Args: + message (dict): A dictionary containing message details. Expected keys + include "role", "content", and optionally "tool_calls", + "tool_call_id", or "function_call". + + Returns: + str: A formatted string representing the message. Returns an empty + string if the message cannot be formatted. + """ if "tool_calls" in message: for tool_call in message["tool_calls"]: function_name = tool_call["function"]["name"] diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py index fc4dd026..fb5d0343 100644 --- a/libs/core/tests/unit_tests/test_azure.py +++ b/libs/core/tests/unit_tests/test_azure.py @@ -48,7 +48,7 @@ def test_direct_response_handling_without_choices(self, mock_azure_provider): chunk2.model_dump.assert_called_once() class TestBuildLlamaSystemMessage: - def test_with_existing_system_message(self, mock_azure_provider): + def test_build_llama_system_message_with_existing_sm(self, mock_azure_provider): mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") @@ -72,7 +72,7 @@ def test_with_existing_system_message(self, mock_azure_provider): mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) mock_azure_provider.build_function_instructions.assert_called_once_with(functions) - def test_with_default_system_message(self, mock_azure_provider): + def test_build_llama_system_message_with_default_sm(self, mock_azure_provider): mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") @@ -93,7 +93,7 @@ def test_with_default_system_message(self, mock_azure_provider): mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) mock_azure_provider.build_function_instructions.assert_not_called() - def test_without_tools_or_functions(self, mock_azure_provider): + def test_build_llama_system_message_without_tools_or_functions(self, mock_azure_provider): mock_azure_provider.build_tool_instructions = MagicMock() mock_azure_provider.build_function_instructions = MagicMock() @@ -114,6 +114,7 @@ def test_without_tools_or_functions(self, mock_azure_provider): mock_azure_provider.build_function_instructions.assert_not_called() class TestBuildInstructions: + def test_build_tool_instructions(self, mock_azure_provider): tools = [ { @@ -177,7 +178,7 @@ def test_build_tool_instructions(self, mock_azure_provider): ) assert result.strip() == expected.strip() - def test_add_function_instructions(self, mock_azure_provider): + def test_build_function_instructions(self, mock_azure_provider): functions = [ { "name": "python_repl_ast", @@ -231,3 +232,170 @@ def test_add_function_instructions(self, mock_azure_provider): ) assert result.strip() == expected.strip() + +class TestBuildLlamaConversation: + + def test_build_llama_conversation_with_nested_messages(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + + openai_message = [ + {"role": "user", "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]"}, + {"role": "assistant", "content": "assistant reply"} + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + + expected = 
"Initial message: [formatted:nested message 1][formatted:nested message 2][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call({"content": "nested message 1"}) + mock_azure_provider.format_message.assert_any_call({"content": "nested message 2"}) + mock_azure_provider.format_message.assert_any_call({"role": "assistant", "content": "assistant reply"}) + + def test_build_llama_conversation_with_invalid_nested_content(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + + openai_message = [ + {"role": "user", "content": "[invalid json]"}, + {"role": "assistant", "content": "assistant reply"} + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + + expected = "Initial message: [formatted:[invalid json]][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "[invalid json]"}) + mock_azure_provider.format_message.assert_any_call({"role": "assistant", "content": "assistant reply"}) + + def test_build_llama_conversation_skipping_system_messages(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + + openai_message = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"} + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + + expected = "Initial message: [formatted:user message]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "user message"}) + + + def test_build_llama_conversation_with_non_list_user_content(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + + openai_message = [ + {"role": "user", "content": "simple user message"} + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + + expected = "Initial message: [formatted:simple user message]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "simple user message"}) + +class TestFormatMessage: + + def test_format_message_tool_calls(self, mock_azure_provider): + message = { + "tool_calls": [ + { + "function": { + "name": "example_tool", + "arguments": '{"arg1": "value1"}' + } + } + ] + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + {"arg1": "value1"} + <|eom_id|> + """ + assert result.strip() == expected.strip() + + + def test_format_message_tool_call_id(self, mock_azure_provider): + message = { + "tool_call_id": "123", + "content": "This is the tool response." + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>ipython<|end_header_id|> + This is the tool response. 
+ <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_function_call(self, mock_azure_provider): + message = { + "function_call": { + "name": "example_function", + "arguments": '{"arg1": "value1"}' + } + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + {"arg1": "value1"} + <|eom_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_user_message(self, mock_azure_provider): + message = { + "role": "user", + "content": "This is a user message." + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>user<|end_header_id|> + This is a user message. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_assistant_message(self, mock_azure_provider): + message = { + "role": "assistant", + "content": "This is an assistant message." + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + This is an assistant message. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_function_response(self, mock_azure_provider): + message = { + "role": "function", + "content": "This is the function response." + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>ipython<|end_header_id|> + This is the function response. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_empty_message(self, mock_azure_provider): + message = { + "role": "user", + "content": None + } + result = mock_azure_provider.format_message(message) + expected = "" + assert result == expected \ No newline at end of file From 70ed8601fa791247cbb6ed64140550256c90d151 Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Mon, 18 Nov 2024 15:29:47 +0000 Subject: [PATCH 3/5] chore: added unit tests for generate client on Azure Provider --- libs/core/llmstudio_core/providers/azure.py | 24 +++++-- libs/core/tests/unit_tests/conftest.py | 6 -- libs/core/tests/unit_tests/test_azure.py | 78 ++++++++++++++++++++- 3 files changed, 96 insertions(+), 12 deletions(-) diff --git a/libs/core/llmstudio_core/providers/azure.py b/libs/core/llmstudio_core/providers/azure.py index e923ebca..f3c5f255 100644 --- a/libs/core/llmstudio_core/providers/azure.py +++ b/libs/core/llmstudio_core/providers/azure.py @@ -107,7 +107,25 @@ async def agenerate_client(self, request: ChatRequest) -> Any: raise ProviderError(e.response.json()) def generate_client(self, request: ChatRequest) -> Any: - """Generate an AzureOpenAI client""" + """ + Generates an AzureOpenAI client for processing a chat request. + + This method prepares and configures the arguments required to create a client + request to AzureOpenAI's chat completions API. It determines model-specific + configurations (e.g., whether tools or functions are enabled) and combines + these with the base arguments for the API call. + + Args: + request (ChatRequest): The chat request object containing the model, + parameters, and other necessary details. + + Returns: + Any: The result of the chat completions API call. + + Raises: + ProviderError: If there is an issue with the API connection or an error + returned from the API. 
+ """ self.is_llama = "llama" in request.model.lower() self.is_openai = "gpt" in request.model.lower() @@ -117,7 +135,6 @@ def generate_client(self, request: ChatRequest) -> Any: try: messages = self.prepare_messages(request) - # Prepare the optional tool-related arguments tool_args = {} if not self.is_llama and self.has_tools and self.is_openai: tool_args = { @@ -125,7 +142,6 @@ def generate_client(self, request: ChatRequest) -> Any: "tool_choice": "auto" if request.parameters.get("tools") else None, } - # Prepare the optional function-related arguments function_args = {} if not self.is_llama and self.has_functions and self.is_openai: function_args = { @@ -135,14 +151,12 @@ def generate_client(self, request: ChatRequest) -> Any: else None, } - # Prepare the base arguments base_args = { "model": request.model, "messages": messages, "stream": True, } - # Combine all arguments combined_args = { **base_args, **tool_args, diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index fcc74f16..662b0688 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -63,12 +63,6 @@ async def async_gen(): yield {} return async_gen() - def generate_client(self, request): - # For testing, return a generator - def gen(): - yield {} - return gen() - @staticmethod def _provider_config_name(): return "mock_azure_provider" diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py index fb5d0343..9873c913 100644 --- a/libs/core/tests/unit_tests/test_azure.py +++ b/libs/core/tests/unit_tests/test_azure.py @@ -1,5 +1,6 @@ import pytest from unittest.mock import MagicMock +from llmstudio_core.providers.azure import ProviderError, openai class TestParseResponse: def test_tool_response_handling(self, mock_azure_provider): @@ -398,4 +399,79 @@ def test_format_message_empty_message(self, mock_azure_provider): } result = mock_azure_provider.format_message(message) expected = "" - assert result == expected \ No newline at end of file + assert result == expected + +class TestGenerateClient: + + def test_generate_client_with_tools_and_functions(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock(return_value="prepared_messages") + mock_azure_provider._client.chat.completions.create = MagicMock(return_value="mock_response") + + request = MagicMock() + request.model = "gpt-4" + request.parameters = { + "tools": ["tool1", "tool2"], + "functions": ["function1", "function2"], + "other_param": "value", + } + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "gpt-4", + "messages": "prepared_messages", + "stream": True, + "tools": ["tool1", "tool2"], + "tool_choice": "auto", + "functions": ["function1", "function2"], + "function_call": "auto", + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with(**expected_args) + + def test_generate_client_without_tools_or_functions(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock(return_value="prepared_messages") + mock_azure_provider._client.chat.completions.create = MagicMock(return_value="mock_response") + + request = MagicMock() + request.model = "gpt-4" + request.parameters = {"other_param": "value"} + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "gpt-4", + "messages": 
"prepared_messages", + "stream": True, + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with(**expected_args) + + def test_generate_client_with_llama_model(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock(return_value="prepared_messages") + mock_azure_provider._client.chat.completions.create = MagicMock(return_value="mock_response") + + request = MagicMock() + request.model = "llama-2" + request.parameters = {"tools": ["tool1"], "functions": ["function1"], "other_param": "value"} + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "llama-2", + "messages": "prepared_messages", + "stream": True, + "tools": ["tool1"], + "functions": ["function1"], + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with(**expected_args) \ No newline at end of file From 773df1aa5985d71ae837e6a57d880fe2bcf0db18 Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Wed, 18 Dec 2024 09:43:45 +0000 Subject: [PATCH 4/5] chore: separated azure unit tests into separate files. fixed some of its tests. --- libs/core/tests/unit_tests/test_azure.py | 257 ------------------ .../core/tests/unit_tests/test_azure_build.py | 243 +++++++++++++++++ 2 files changed, 243 insertions(+), 257 deletions(-) create mode 100644 libs/core/tests/unit_tests/test_azure_build.py diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py index 9873c913..63c658b8 100644 --- a/libs/core/tests/unit_tests/test_azure.py +++ b/libs/core/tests/unit_tests/test_azure.py @@ -1,6 +1,4 @@ -import pytest from unittest.mock import MagicMock -from llmstudio_core.providers.azure import ProviderError, openai class TestParseResponse: def test_tool_response_handling(self, mock_azure_provider): @@ -48,261 +46,6 @@ def test_direct_response_handling_without_choices(self, mock_azure_provider): chunk1.model_dump.assert_called_once() chunk2.model_dump.assert_called_once() -class TestBuildLlamaSystemMessage: - def test_build_llama_system_message_with_existing_sm(self, mock_azure_provider): - mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") - mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") - - openai_message = [ - {"role": "user", "content": "Hello"}, - {"role": "system", "content": "Custom system message"} - ] - llama_message = "Initial message" - tools = ["Tool1", "Tool2"] - functions = ["Function1"] - - result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) - - expected = ( - "Initial message\n" - " <|start_header_id|>system<|end_header_id|>\n" - " Custom system message\n" - " Tool Instructions\nFunction Instructions\n<|eot_id|>" # identation here exists because in Python when adding a newline to a triple quote string it keeps identation - ) - assert result == expected - mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) - mock_azure_provider.build_function_instructions.assert_called_once_with(functions) - - def test_build_llama_system_message_with_default_sm(self, mock_azure_provider): - mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") - 
mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") - - openai_message = [{"role": "user", "content": "Hello"}] - llama_message = "Initial message" - tools = ["Tool1"] - functions = [] - - result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) - - expected = ( - "Initial message\n" - " <|start_header_id|>system<|end_header_id|>\n" - " You are a helpful AI assistant.\n" - " Tool Instructions\n<|eot_id|>" - ) - assert result == expected - mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) - mock_azure_provider.build_function_instructions.assert_not_called() - - def test_build_llama_system_message_without_tools_or_functions(self, mock_azure_provider): - mock_azure_provider.build_tool_instructions = MagicMock() - mock_azure_provider.build_function_instructions = MagicMock() - - openai_message = [{"role": "system", "content": "Minimal system message"}] - llama_message = "Initial message" - tools = [] - functions = [] - - result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) - - expected = ( - "Initial message\n" - " <|start_header_id|>system<|end_header_id|>\n" - " Minimal system message\n \n<|eot_id|>" - ) - assert result == expected - mock_azure_provider.build_tool_instructions.assert_not_called() - mock_azure_provider.build_function_instructions.assert_not_called() - -class TestBuildInstructions: - - def test_build_tool_instructions(self, mock_azure_provider): - tools = [ - { - "type": "function", - "function": { - "name": "python_repl_ast", - "description": "execute Python code", - "parameters": { - "query": "string" - } - } - }, - { - "type": "function", - "function": { - "name": "data_lookup", - "description": "retrieve data from a database", - "parameters": { - "database": "string", - "query": "string" - } - } - } - ] - - result = mock_azure_provider.build_tool_instructions(tools) - - expected = ( - """ - You have access to the following tools: - Use the function 'python_repl_ast' to 'execute Python code': -Parameters format: -{ - "query": "string" -} - -Use the function 'data_lookup' to 'retrieve data from a database': -Parameters format: -{ - "database": "string", - "query": "string" -} - - -If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: -§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} -IMPORTANT: IT IS VITAL THAT YOU NEVER ADD A PREFIX OR A SUFFIX TO THE FUNCTION CALL. - -Here is an example of the output I desiere when performing function call: -§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} -NOTE: There is no prefix before the symbol '§' and nothing comes after the call is done. - - Reminder: - - Function calls MUST follow the specified format. - - Only call one function at a time. - - Required parameters MUST be specified. - - Put the entire function call reply on one line. - - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. - - If you have already called a tool and got the response for the users question please reply with the response. 
- """ - ) - assert result.strip() == expected.strip() - - def test_build_function_instructions(self, mock_azure_provider): - functions = [ - { - "name": "python_repl_ast", - "description": "execute Python code", - "parameters": { - "query": "string" - } - }, - { - "name": "data_lookup", - "description": "retrieve data from a database", - "parameters": { - "database": "string", - "query": "string" - } - } - ] - - result = mock_azure_provider.build_function_instructions(functions) - - expected = ( - """ -You have access to the following functions: -Use the function 'python_repl_ast' to: 'execute Python code' -{ - "query": "string" -} - -Use the function 'data_lookup' to: 'retrieve data from a database' -{ - "database": "string", - "query": "string" -} - - -If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: -§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} - -Here is an example of the output I desiere when performing function call: -§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} - -Reminder: -- Function calls MUST follow the specified format. -- Only call one function at a time. -- NEVER call more than one function at a time. -- Required parameters MUST be specified. -- Put the entire function call reply on one line. -- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. -- If you have already called a function and got the response for the user's question, please reply with the response. -""" - ) - - assert result.strip() == expected.strip() - -class TestBuildLlamaConversation: - - def test_build_llama_conversation_with_nested_messages(self, mock_azure_provider): - mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") - - openai_message = [ - {"role": "user", "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]"}, - {"role": "assistant", "content": "assistant reply"} - ] - llama_message = "Initial message: " - - result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) - - expected = "Initial message: [formatted:nested message 1][formatted:nested message 2][formatted:assistant reply]" - - assert result == expected - mock_azure_provider.format_message.assert_any_call({"content": "nested message 1"}) - mock_azure_provider.format_message.assert_any_call({"content": "nested message 2"}) - mock_azure_provider.format_message.assert_any_call({"role": "assistant", "content": "assistant reply"}) - - def test_build_llama_conversation_with_invalid_nested_content(self, mock_azure_provider): - mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") - - openai_message = [ - {"role": "user", "content": "[invalid json]"}, - {"role": "assistant", "content": "assistant reply"} - ] - llama_message = "Initial message: " - - result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) - - expected = "Initial message: [formatted:[invalid json]][formatted:assistant reply]" - - assert result == expected - mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "[invalid json]"}) - mock_azure_provider.format_message.assert_any_call({"role": "assistant", "content": "assistant reply"}) - - def test_build_llama_conversation_skipping_system_messages(self, 
mock_azure_provider): - mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") - - openai_message = [ - {"role": "system", "content": "system message"}, - {"role": "user", "content": "user message"} - ] - llama_message = "Initial message: " - - result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) - - expected = "Initial message: [formatted:user message]" - - assert result == expected - mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "user message"}) - - - def test_build_llama_conversation_with_non_list_user_content(self, mock_azure_provider): - mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") - - openai_message = [ - {"role": "user", "content": "simple user message"} - ] - llama_message = "Initial message: " - - result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) - - expected = "Initial message: [formatted:simple user message]" - - assert result == expected - mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "simple user message"}) - class TestFormatMessage: def test_format_message_tool_calls(self, mock_azure_provider): diff --git a/libs/core/tests/unit_tests/test_azure_build.py b/libs/core/tests/unit_tests/test_azure_build.py new file mode 100644 index 00000000..10a6c8d6 --- /dev/null +++ b/libs/core/tests/unit_tests/test_azure_build.py @@ -0,0 +1,243 @@ +from unittest.mock import MagicMock, patch + +class TestBuildLlamaSystemMessage: + def test_build_llama_system_message_with_existing_sm(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") + mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") + + openai_message = [ + {"role": "user", "content": "Hello"}, + {"role": "system", "content": "Custom system message"} + ] + llama_message = "Initial message" + tools = ["Tool1", "Tool2"] + functions = ["Function1"] + + result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " Custom system message\n" + " Tool Instructions\nFunction Instructions\n<|eot_id|>" # identation here exists because in Python when adding a newline to a triple quote string it keeps identation + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) + mock_azure_provider.build_function_instructions.assert_called_once_with(functions) + + def test_build_llama_system_message_with_default_sm(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") + mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") + + openai_message = [{"role": "user", "content": "Hello"}] + llama_message = "Initial message" + tools = ["Tool1"] + functions = [] + + result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " You are a helpful AI assistant.\n" + " Tool Instructions\n<|eot_id|>" + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) + mock_azure_provider.build_function_instructions.assert_not_called() + + def 
test_build_llama_system_message_without_tools_or_functions(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock() + mock_azure_provider.build_function_instructions = MagicMock() + + openai_message = [{"role": "system", "content": "Minimal system message"}] + llama_message = "Initial message" + tools = [] + functions = [] + + result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " Minimal system message\n \n<|eot_id|>" + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_not_called() + mock_azure_provider.build_function_instructions.assert_not_called() + +class TestBuildInstructions: + + def test_build_tool_instructions(self, mock_azure_provider): + tools = [ + { + "type": "function", + "function": { + "name": "python_repl_ast", + "description": "execute Python code", + "parameters": { + "query": "string" + } + } + }, + { + "type": "function", + "function": { + "name": "data_lookup", + "description": "retrieve data from a database", + "parameters": { + "database": "string", + "query": "string" + } + } + } + ] + + result = mock_azure_provider.build_tool_instructions(tools) + + expected = ( + """ + You have access to the following tools: + Use the function 'python_repl_ast' to 'execute Python code': +Parameters format: +{ + "query": "string" +} + +Use the function 'data_lookup' to 'retrieve data from a database': +Parameters format: +{ + "database": "string", + "query": "string" +} + + +If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: +§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} +IMPORTANT: IT IS VITAL THAT YOU NEVER ADD A PREFIX OR A SUFFIX TO THE FUNCTION CALL. + +Here is an example of the output I desiere when performing function call: +§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} +NOTE: There is no prefix before the symbol '§' and nothing comes after the call is done. + + Reminder: + - Function calls MUST follow the specified format. + - Only call one function at a time. + - Required parameters MUST be specified. + - Put the entire function call reply on one line. + - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. + - If you have already called a tool and got the response for the users question please reply with the response. 
+ """ + ) + assert result.strip() == expected.strip() + + def test_build_function_instructions(self, mock_azure_provider): + functions = [ + { + "name": "python_repl_ast", + "description": "execute Python code", + "parameters": { + "query": "string" + } + }, + { + "name": "data_lookup", + "description": "retrieve data from a database", + "parameters": { + "database": "string", + "query": "string" + } + } + ] + + result = mock_azure_provider.build_function_instructions(functions) + + expected = ( + """ +You have access to the following functions: +Use the function 'python_repl_ast' to: 'execute Python code' +{ + "query": "string" +} + +Use the function 'data_lookup' to: 'retrieve data from a database' +{ + "database": "string", + "query": "string" +} + + +If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: +§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} + +Here is an example of the output I desiere when performing function call: +§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} + +Reminder: +- Function calls MUST follow the specified format. +- Only call one function at a time. +- NEVER call more than one function at a time. +- Required parameters MUST be specified. +- Put the entire function call reply on one line. +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. +- If you have already called a function and got the response for the user's question, please reply with the response. +""" + ) + + assert result.strip() == expected.strip() + +class TestBuildLlamaConversation: + + def test_build_llama_conversation_with_nested_messages(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + + openai_message = [ + {"role": "user", "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]"}, + {"role": "assistant", "content": "assistant reply"} + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + + expected = "Initial message: [formatted:nested message 1][formatted:nested message 2][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call({"content": "nested message 1"}) + mock_azure_provider.format_message.assert_any_call({"content": "nested message 2"}) + mock_azure_provider.format_message.assert_any_call({"role": "assistant", "content": "assistant reply"}) + + def test_build_llama_conversation_with_invalid_nested_content(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + + openai_message = [ + {"role": "user", "content": "[invalid json/dict]"}, + {"role": "assistant", "content": "assistant reply"} + ] + llama_message = "Initial message: " + + with patch("ast.literal_eval", side_effect=ValueError) as mock_literal_eval: + result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + + expected = "Initial message: [formatted:[invalid json/dict]][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "[invalid json/dict]"}) + mock_azure_provider.format_message.assert_any_call({"role": "assistant", "content": "assistant 
reply"}) + + mock_literal_eval.assert_called_once_with("[invalid json/dict]") + + def test_build_llama_conversation_skipping_system_messages(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + + openai_message = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"} + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + + expected = "Initial message: [formatted:user message]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "user message"}) \ No newline at end of file From fbe0602752a97aaad9cde32775a164c777f8717e Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Wed, 18 Dec 2024 10:04:20 +0000 Subject: [PATCH 5/5] chore: linted code --- .gitignore | 4 +- libs/core/llmstudio_core/providers/azure.py | 40 +-- .../core/llmstudio_core/providers/provider.py | 54 +++-- libs/core/tests/unit_tests/conftest.py | 14 +- libs/core/tests/unit_tests/test_azure.py | 106 ++++---- .../core/tests/unit_tests/test_azure_build.py | 160 +++++++----- libs/core/tests/unit_tests/test_provider.py | 228 ++++++++++++------ 7 files changed, 368 insertions(+), 238 deletions(-) diff --git a/.gitignore b/.gitignore index 19015866..d80305ef 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,7 @@ env3 .env* .env*.local .venv* +*venv* env*/ venv*/ ENV/ @@ -66,6 +67,7 @@ venv.bak/ config.yaml bun.lockb + # Jupyter Notebook .ipynb_checkpoints @@ -76,4 +78,4 @@ bun.lockb llmstudio/llm_engine/logs/execution_logs.jsonl *.db .prettierignore -db \ No newline at end of file +db diff --git a/libs/core/llmstudio_core/providers/azure.py b/libs/core/llmstudio_core/providers/azure.py index f3c5f255..666bc6a0 100644 --- a/libs/core/llmstudio_core/providers/azure.py +++ b/libs/core/llmstudio_core/providers/azure.py @@ -116,7 +116,7 @@ def generate_client(self, request: ChatRequest) -> Any: these with the base arguments for the API call. Args: - request (ChatRequest): The chat request object containing the model, + request (ChatRequest): The chat request object containing the model, parameters, and other necessary details. Returns: @@ -207,8 +207,8 @@ def parse_response(self, response: Generator, **kwargs) -> Any: """ Processes a generator response and yields processed chunks. - If `is_llama` is True and tools or functions are enabled, it processes the response - using `handle_tool_response`. Otherwise, it processes each chunk and yields only those + If `is_llama` is True and tools or functions are enabled, it processes the response + using `handle_tool_response`. Otherwise, it processes each chunk and yields only those containing "choices". Args: @@ -473,7 +473,7 @@ def build_llama_system_message( """ Builds a complete system message for Llama based on OpenAI's message, tools, and functions. - If a system message is present in the OpenAI message, it is included in the result. + If a system message is present in the OpenAI message, it is included in the result. Otherwise, a default system message is used. Additional tool and function instructions are appended if provided. 
@@ -505,7 +505,9 @@ def build_llama_system_message( system_message = system_message + self.build_tool_instructions(tools) if functions: - system_message = system_message + self.build_function_instructions(functions) + system_message = system_message + self.build_function_instructions( + functions + ) end_tag = "\n<|eot_id|>" return llama_message + system_message + end_tag @@ -514,12 +516,12 @@ def build_tool_instructions(self, tools: list) -> str: """ Builds a detailed instructional prompt for tools available to the assistant. - This function generates a message describing the available tools, focusing on tools - of type "function." It explains to the LLM how to use each tool and provides an example of the + This function generates a message describing the available tools, focusing on tools + of type "function." It explains to the LLM how to use each tool and provides an example of the correct response format for function calls. Args: - tools (list): A list of tool dictionaries, where each dictionary contains tool + tools (list): A list of tool dictionaries, where each dictionary contains tool details such as type, function name, description, and parameters. Returns: @@ -562,12 +564,12 @@ def build_function_instructions(self, functions: list) -> str: """ Builds a detailed instructional prompt for available functions. - This method creates a message describing the functions accessible to the assistant. - It includes the function name, description, and required parameters, along with + This method creates a message describing the functions accessible to the assistant. + It includes the function name, description, and required parameters, along with specific guidelines for calling functions. Args: - functions (list): A list of function dictionaries, each containing details such as + functions (list): A list of function dictionaries, each containing details such as name, description, and parameters. Returns: @@ -606,12 +608,12 @@ def build_llama_conversation(self, openai_message: list, llama_message: str) -> """ Appends the OpenAI message to the Llama message while formatting OpenAI messages. - This function iterates through a list of OpenAI messages and formats them for inclusion - in a Llama message. It handles user messages that might include nested content (lists of + This function iterates through a list of OpenAI messages and formats them for inclusion + in a Llama message. It handles user messages that might include nested content (lists of messages) by safely evaluating the content. System messages are skipped. Args: - openai_message (list): A list of dictionaries representing the OpenAI messages. Each + openai_message (list): A list of dictionaries representing the OpenAI messages. Each dictionary should have "role" and "content" keys. llama_message (str): The initial Llama message to which the conversation is appended. @@ -643,17 +645,17 @@ def format_message(self, message: dict) -> str: """ Formats a single message dictionary into a structured string for a conversation. - The formatting depends on the content of the message, such as tool calls, - function calls, or simple user/assistant messages. Each type of message + The formatting depends on the content of the message, such as tool calls, + function calls, or simple user/assistant messages. Each type of message is formatted with specific headers and tags. Args: - message (dict): A dictionary containing message details. 
Expected keys - include "role", "content", and optionally "tool_calls", + message (dict): A dictionary containing message details. Expected keys + include "role", "content", and optionally "tool_calls", "tool_call_id", or "function_call". Returns: - str: A formatted string representing the message. Returns an empty + str: A formatted string representing the message. Returns an empty string if the message cannot be formatted. """ if "tool_calls" in message: diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py index c3d4e23a..eabb25b7 100644 --- a/libs/core/llmstudio_core/providers/provider.py +++ b/libs/core/llmstudio_core/providers/provider.py @@ -148,7 +148,7 @@ async def achat( **kwargs, ): """ - Asynchronously establishes a chat connection with the provider’s API, handling retries, + Asynchronously establishes a chat connection with the provider’s API, handling retries, request validation, and streaming response options. Parameters @@ -158,7 +158,7 @@ async def achat( model : str The identifier of the model to be used for the chat request. is_stream : Optional[bool], default=False - Flag to indicate if the response should be streamed. If True, returns an async generator + Flag to indicate if the response should be streamed. If True, returns an async generator for streaming content; otherwise, returns the complete response. retries : Optional[int], default=0 Number of retry attempts on error. Retries will be attempted for specific HTTP errors like rate limits. @@ -224,7 +224,7 @@ def chat( **kwargs, ): """ - Establishes a chat connection with the provider’s API, handling retries, request validation, + Establishes a chat connection with the provider’s API, handling retries, request validation, and streaming response options. Parameters @@ -234,7 +234,7 @@ def chat( model : str The model identifier for selecting the model used in the chat request. is_stream : Optional[bool], default=False - Flag to indicate if the response should be streamed. If True, the function returns a generator + Flag to indicate if the response should be streamed. If True, the function returns a generator for streaming content. Otherwise, it returns the complete response. retries : Optional[int], default=0 Number of retry attempts on error. Retries will be attempted on specific HTTP errors like rate limits. @@ -294,9 +294,9 @@ async def ahandle_response( self, request: ChatRequest, response: AsyncGenerator, start_time: float ) -> AsyncGenerator[str, None]: """ - Asynchronously handles the response from an API, processing response chunks for either + Asynchronously handles the response from an API, processing response chunks for either streaming or non-streaming responses. - + Buffers response chunks for non-streaming responses to output one single message. For streaming responses sends incremental chunks. Parameters @@ -311,7 +311,7 @@ async def ahandle_response( Yields ------ Union[ChatCompletionChunk, ChatCompletion] - - If `request.is_stream` is True, yields `ChatCompletionChunk` objects with incremental + - If `request.is_stream` is True, yields `ChatCompletionChunk` objects with incremental response chunks for streaming. - If `request.is_stream` is False, yields a final `ChatCompletion` object after processing all chunks. 
""" @@ -423,10 +423,10 @@ def handle_response( self, request: ChatRequest, response: Generator, start_time: float ) -> Generator: """ - Processes API response chunks to build a structured, complete response, yielding + Processes API response chunks to build a structured, complete response, yielding each chunk if streaming is enabled. - - If streaming, each chunk is yielded as soon as it’s processed. Otherwise, all chunks + + If streaming, each chunk is yielded as soon as it’s processed. Otherwise, all chunks are combined and yielded as a single response at the end. Parameters @@ -551,8 +551,8 @@ def handle_response( def join_chunks(self, chunks): """ - Combine multiple response chunks from the model into a single, structured response. - Handles tool calls, function calls, and standard text completion based on the + Combine multiple response chunks from the model into a single, structured response. + Handles tool calls, function calls, and standard text completion based on the purpose indicated by the final chunk. Parameters @@ -563,7 +563,7 @@ def join_chunks(self, chunks): Returns ------- Tuple[ChatCompletion, str] - - `ChatCompletion`: The structured response based on the type of completion + - `ChatCompletion`: The structured response based on the type of completion (tool calls, function call, or text). - `str`: The concatenated content or arguments, depending on the completion type. @@ -733,7 +733,7 @@ def calculate_metrics( token_count: int, ) -> Dict[str, Any]: """ - Calculates performance and cost metrics for a model response based on timing + Calculates performance and cost metrics for a model response based on timing information, token counts, and model-specific costs. Parameters @@ -783,17 +783,21 @@ def calculate_metrics( "cost_usd": input_cost + output_cost, "latency_s": total_time, "time_to_first_token_s": first_token_time - start_time, - "inter_token_latency_s": sum(token_times) / len(token_times) if token_times else 0, - "tokens_per_second": token_count / total_time if token_times else 1 / total_time, + "inter_token_latency_s": sum(token_times) / len(token_times) + if token_times + else 0, + "tokens_per_second": token_count / total_time + if token_times + else 1 / total_time, } def calculate_cost( self, token_count: int, token_cost: Union[float, List[Dict[str, Any]]] ) -> float: """ - Calculates the cost for a given number of tokens based on a fixed cost per token + Calculates the cost for a given number of tokens based on a fixed cost per token or a variable rate structure. - + If `token_cost` is a fixed float, the total cost is `token_count * token_cost`. If `token_cost` is a list, it checks each range and calculates cost based on the applicable range's rate. @@ -802,14 +806,14 @@ def calculate_cost( token_count : int The total number of tokens for which the cost is being calculated. token_cost : Union[float, List[Dict[str, Any]]] - Either a fixed cost per token (as a float) or a list of dictionaries defining - variable cost ranges. Each dictionary in the list represents a range with + Either a fixed cost per token (as a float) or a list of dictionaries defining + variable cost ranges. Each dictionary in the list represents a range with 'range' (a tuple of minimum and maximum token counts) and 'cost' (cost per token) keys. Returns ------- float - The calculated cost based on the token count and cost structure. + The calculated cost based on the token count and cost structure. 
""" if isinstance(token_cost, list): for cost_range in token_cost: @@ -830,13 +834,13 @@ def input_to_string(self, input): input : Any The input data to be converted. This can be: - A simple string, which is returned as-is. - - A list of message dictionaries, where each dictionary may contain `content`, `role`, + - A list of message dictionaries, where each dictionary may contain `content`, `role`, and nested items like `text` or `image_url`. Returns ------- str - A concatenated string representing the text content of all messages, + A concatenated string representing the text content of all messages, including text and URLs from image content if present. """ if isinstance(input, str): @@ -861,13 +865,13 @@ def input_to_string(self, input): def output_to_string(self, output): """ - Extracts and returns the content or arguments from the output based on + Extracts and returns the content or arguments from the output based on the `finish_reason` of the first choice in `output`. Parameters ---------- output : Any - The model output object, expected to have a `choices` attribute that should contain a `finish_reason` indicating the type of output + The model output object, expected to have a `choices` attribute that should contain a `finish_reason` indicating the type of output ("stop", "tool_calls", or "function_call") and corresponding content or arguments. Returns diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 662b0688..5eeccc83 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -1,8 +1,8 @@ from unittest.mock import MagicMock import pytest -from llmstudio_core.providers.provider import ProviderCore from llmstudio_core.providers.azure import AzureProvider +from llmstudio_core.providers.provider import ProviderCore class MockProvider(ProviderCore): @@ -19,7 +19,7 @@ def output_to_string(self, output): if output.choices[0].finish_reason == "stop": return output.choices[0].message.content return "" - + def validate_request(self, request): # For testing, simply return the request return request @@ -28,12 +28,14 @@ async def agenerate_client(self, request): # For testing, return an async generator async def async_gen(): yield {} + return async_gen() def generate_client(self, request): # For testing, return a generator def gen(): yield {} + return gen() @staticmethod @@ -61,15 +63,17 @@ async def agenerate_client(self, request): # For testing, return an async generator async def async_gen(): yield {} + return async_gen() @staticmethod def _provider_config_name(): return "mock_azure_provider" - + + @pytest.fixture def mock_azure_provider(): config = MagicMock() config.id = "mock_azure_provider" - base_url = 'mock_url.com' - return MockAzureProvider(config=config, base_url=base_url) \ No newline at end of file + base_url = "mock_url.com" + return MockAzureProvider(config=config, base_url=base_url) diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py index 63c658b8..28e5f684 100644 --- a/libs/core/tests/unit_tests/test_azure.py +++ b/libs/core/tests/unit_tests/test_azure.py @@ -1,18 +1,21 @@ from unittest.mock import MagicMock + class TestParseResponse: def test_tool_response_handling(self, mock_azure_provider): - + mock_azure_provider.is_llama = True mock_azure_provider.has_tools = True mock_azure_provider.has_functions = False - - mock_azure_provider.handle_tool_response = MagicMock(return_value=iter(["chunk1", None, "chunk2", None, "chunk3"])) - + + 
mock_azure_provider.handle_tool_response = MagicMock( + return_value=iter(["chunk1", None, "chunk2", None, "chunk3"]) + ) + response = iter(["irrelevant"]) - + results = list(mock_azure_provider.parse_response(response)) - + assert results == ["chunk1", "chunk2", "chunk3"] mock_azure_provider.handle_tool_response.assert_called_once_with(response) @@ -20,17 +23,20 @@ def test_direct_response_handling_with_choices(self, mock_azure_provider): mock_azure_provider.is_llama = False chunk1 = MagicMock() - chunk1.model_dump.return_value = {"choices": ["choice1","choice2"]} + chunk1.model_dump.return_value = {"choices": ["choice1", "choice2"]} chunk2 = MagicMock() chunk2.model_dump.return_value = {"choices": ["choice2"]} response = iter([chunk1, chunk2]) results = list(mock_azure_provider.parse_response(response)) - assert results == [{"choices": ["choice1", "choice2"]}, {"choices": ["choice2"]}] + assert results == [ + {"choices": ["choice1", "choice2"]}, + {"choices": ["choice2"]}, + ] chunk1.model_dump.assert_called_once() chunk2.model_dump.assert_called_once() - + def test_direct_response_handling_without_choices(self, mock_azure_provider): mock_azure_provider.is_llama = False @@ -45,16 +51,16 @@ def test_direct_response_handling_without_choices(self, mock_azure_provider): assert results == [] chunk1.model_dump.assert_called_once() chunk2.model_dump.assert_called_once() - -class TestFormatMessage: + +class TestFormatMessage: def test_format_message_tool_calls(self, mock_azure_provider): message = { "tool_calls": [ { "function": { "name": "example_tool", - "arguments": '{"arg1": "value1"}' + "arguments": '{"arg1": "value1"}', } } ] @@ -67,12 +73,8 @@ def test_format_message_tool_calls(self, mock_azure_provider): """ assert result.strip() == expected.strip() - def test_format_message_tool_call_id(self, mock_azure_provider): - message = { - "tool_call_id": "123", - "content": "This is the tool response." - } + message = {"tool_call_id": "123", "content": "This is the tool response."} result = mock_azure_provider.format_message(message) expected = """ <|start_header_id|>ipython<|end_header_id|> @@ -85,7 +87,7 @@ def test_format_message_function_call(self, mock_azure_provider): message = { "function_call": { "name": "example_function", - "arguments": '{"arg1": "value1"}' + "arguments": '{"arg1": "value1"}', } } result = mock_azure_provider.format_message(message) @@ -97,10 +99,7 @@ def test_format_message_function_call(self, mock_azure_provider): assert result.strip() == expected.strip() def test_format_message_user_message(self, mock_azure_provider): - message = { - "role": "user", - "content": "This is a user message." - } + message = {"role": "user", "content": "This is a user message."} result = mock_azure_provider.format_message(message) expected = """ <|start_header_id|>user<|end_header_id|> @@ -110,10 +109,7 @@ def test_format_message_user_message(self, mock_azure_provider): assert result.strip() == expected.strip() def test_format_message_assistant_message(self, mock_azure_provider): - message = { - "role": "assistant", - "content": "This is an assistant message." 
- } + message = {"role": "assistant", "content": "This is an assistant message."} result = mock_azure_provider.format_message(message) expected = """ <|start_header_id|>assistant<|end_header_id|> @@ -123,10 +119,7 @@ def test_format_message_assistant_message(self, mock_azure_provider): assert result.strip() == expected.strip() def test_format_message_function_response(self, mock_azure_provider): - message = { - "role": "function", - "content": "This is the function response." - } + message = {"role": "function", "content": "This is the function response."} result = mock_azure_provider.format_message(message) expected = """ <|start_header_id|>ipython<|end_header_id|> @@ -136,19 +129,20 @@ def test_format_message_function_response(self, mock_azure_provider): assert result.strip() == expected.strip() def test_format_message_empty_message(self, mock_azure_provider): - message = { - "role": "user", - "content": None - } + message = {"role": "user", "content": None} result = mock_azure_provider.format_message(message) expected = "" assert result == expected - + + class TestGenerateClient: - def test_generate_client_with_tools_and_functions(self, mock_azure_provider): - mock_azure_provider.prepare_messages = MagicMock(return_value="prepared_messages") - mock_azure_provider._client.chat.completions.create = MagicMock(return_value="mock_response") + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) request = MagicMock() request.model = "gpt-4" @@ -173,11 +167,17 @@ def test_generate_client_with_tools_and_functions(self, mock_azure_provider): assert result == "mock_response" mock_azure_provider.prepare_messages.assert_called_once_with(request) - mock_azure_provider._client.chat.completions.create.assert_called_once_with(**expected_args) - + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) + def test_generate_client_without_tools_or_functions(self, mock_azure_provider): - mock_azure_provider.prepare_messages = MagicMock(return_value="prepared_messages") - mock_azure_provider._client.chat.completions.create = MagicMock(return_value="mock_response") + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) request = MagicMock() request.model = "gpt-4" @@ -194,15 +194,25 @@ def test_generate_client_without_tools_or_functions(self, mock_azure_provider): assert result == "mock_response" mock_azure_provider.prepare_messages.assert_called_once_with(request) - mock_azure_provider._client.chat.completions.create.assert_called_once_with(**expected_args) + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) def test_generate_client_with_llama_model(self, mock_azure_provider): - mock_azure_provider.prepare_messages = MagicMock(return_value="prepared_messages") - mock_azure_provider._client.chat.completions.create = MagicMock(return_value="mock_response") + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) request = MagicMock() request.model = "llama-2" - request.parameters = {"tools": ["tool1"], "functions": ["function1"], "other_param": "value"} + request.parameters = { + "tools": ["tool1"], + "functions": 
["function1"], + "other_param": "value", + } result = mock_azure_provider.generate_client(request) @@ -217,4 +227,6 @@ def test_generate_client_with_llama_model(self, mock_azure_provider): assert result == "mock_response" mock_azure_provider.prepare_messages.assert_called_once_with(request) - mock_azure_provider._client.chat.completions.create.assert_called_once_with(**expected_args) \ No newline at end of file + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) diff --git a/libs/core/tests/unit_tests/test_azure_build.py b/libs/core/tests/unit_tests/test_azure_build.py index 10a6c8d6..17d73f99 100644 --- a/libs/core/tests/unit_tests/test_azure_build.py +++ b/libs/core/tests/unit_tests/test_azure_build.py @@ -1,40 +1,55 @@ from unittest.mock import MagicMock, patch + class TestBuildLlamaSystemMessage: def test_build_llama_system_message_with_existing_sm(self, mock_azure_provider): - mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") - mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") - + mock_azure_provider.build_tool_instructions = MagicMock( + return_value="Tool Instructions" + ) + mock_azure_provider.build_function_instructions = MagicMock( + return_value="\nFunction Instructions" + ) + openai_message = [ {"role": "user", "content": "Hello"}, - {"role": "system", "content": "Custom system message"} + {"role": "system", "content": "Custom system message"}, ] llama_message = "Initial message" tools = ["Tool1", "Tool2"] functions = ["Function1"] - result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) expected = ( "Initial message\n" " <|start_header_id|>system<|end_header_id|>\n" " Custom system message\n" - " Tool Instructions\nFunction Instructions\n<|eot_id|>" # identation here exists because in Python when adding a newline to a triple quote string it keeps identation + " Tool Instructions\nFunction Instructions\n<|eot_id|>" # identation here exists because in Python when adding a newline to a triple quote string it keeps identation ) assert result == expected mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) - mock_azure_provider.build_function_instructions.assert_called_once_with(functions) - + mock_azure_provider.build_function_instructions.assert_called_once_with( + functions + ) + def test_build_llama_system_message_with_default_sm(self, mock_azure_provider): - mock_azure_provider.build_tool_instructions = MagicMock(return_value="Tool Instructions") - mock_azure_provider.build_function_instructions = MagicMock(return_value="\nFunction Instructions") - + mock_azure_provider.build_tool_instructions = MagicMock( + return_value="Tool Instructions" + ) + mock_azure_provider.build_function_instructions = MagicMock( + return_value="\nFunction Instructions" + ) + openai_message = [{"role": "user", "content": "Hello"}] llama_message = "Initial message" tools = ["Tool1"] functions = [] - result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) expected = ( "Initial message\n" @@ -45,17 +60,21 @@ def test_build_llama_system_message_with_default_sm(self, mock_azure_provider): assert result == expected 
mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) mock_azure_provider.build_function_instructions.assert_not_called() - - def test_build_llama_system_message_without_tools_or_functions(self, mock_azure_provider): + + def test_build_llama_system_message_without_tools_or_functions( + self, mock_azure_provider + ): mock_azure_provider.build_tool_instructions = MagicMock() mock_azure_provider.build_function_instructions = MagicMock() - + openai_message = [{"role": "system", "content": "Minimal system message"}] llama_message = "Initial message" tools = [] functions = [] - result = mock_azure_provider.build_llama_system_message(openai_message, llama_message, tools, functions) + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) expected = ( "Initial message\n" @@ -65,9 +84,9 @@ def test_build_llama_system_message_without_tools_or_functions(self, mock_azure_ assert result == expected mock_azure_provider.build_tool_instructions.assert_not_called() mock_azure_provider.build_function_instructions.assert_not_called() - + + class TestBuildInstructions: - def test_build_tool_instructions(self, mock_azure_provider): tools = [ { @@ -75,28 +94,22 @@ def test_build_tool_instructions(self, mock_azure_provider): "function": { "name": "python_repl_ast", "description": "execute Python code", - "parameters": { - "query": "string" - } - } + "parameters": {"query": "string"}, + }, }, { "type": "function", "function": { "name": "data_lookup", "description": "retrieve data from a database", - "parameters": { - "database": "string", - "query": "string" - } - } - } + "parameters": {"database": "string", "query": "string"}, + }, + }, ] result = mock_azure_provider.build_tool_instructions(tools) - expected = ( - """ + expected = """ You have access to the following tools: Use the function 'python_repl_ast' to 'execute Python code': Parameters format: @@ -128,32 +141,25 @@ def test_build_tool_instructions(self, mock_azure_provider): - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. - If you have already called a tool and got the response for the users question please reply with the response. """ - ) assert result.strip() == expected.strip() - + def test_build_function_instructions(self, mock_azure_provider): functions = [ { "name": "python_repl_ast", "description": "execute Python code", - "parameters": { - "query": "string" - } + "parameters": {"query": "string"}, }, { "name": "data_lookup", "description": "retrieve data from a database", - "parameters": { - "database": "string", - "query": "string" - } - } + "parameters": {"database": "string", "query": "string"}, + }, ] result = mock_azure_provider.build_function_instructions(functions) - expected = ( - """ + expected = """ You have access to the following functions: Use the function 'python_repl_ast' to: 'execute Python code' { @@ -182,62 +188,92 @@ def test_build_function_instructions(self, mock_azure_provider): - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. - If you have already called a function and got the response for the user's question, please reply with the response. 
""" - ) assert result.strip() == expected.strip() + class TestBuildLlamaConversation: - def test_build_llama_conversation_with_nested_messages(self, mock_azure_provider): - mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) openai_message = [ - {"role": "user", "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]"}, - {"role": "assistant", "content": "assistant reply"} + { + "role": "user", + "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]", + }, + {"role": "assistant", "content": "assistant reply"}, ] llama_message = "Initial message: " - result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) expected = "Initial message: [formatted:nested message 1][formatted:nested message 2][formatted:assistant reply]" assert result == expected - mock_azure_provider.format_message.assert_any_call({"content": "nested message 1"}) - mock_azure_provider.format_message.assert_any_call({"content": "nested message 2"}) - mock_azure_provider.format_message.assert_any_call({"role": "assistant", "content": "assistant reply"}) + mock_azure_provider.format_message.assert_any_call( + {"content": "nested message 1"} + ) + mock_azure_provider.format_message.assert_any_call( + {"content": "nested message 2"} + ) + mock_azure_provider.format_message.assert_any_call( + {"role": "assistant", "content": "assistant reply"} + ) - def test_build_llama_conversation_with_invalid_nested_content(self, mock_azure_provider): - mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + def test_build_llama_conversation_with_invalid_nested_content( + self, mock_azure_provider + ): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) openai_message = [ {"role": "user", "content": "[invalid json/dict]"}, - {"role": "assistant", "content": "assistant reply"} + {"role": "assistant", "content": "assistant reply"}, ] llama_message = "Initial message: " with patch("ast.literal_eval", side_effect=ValueError) as mock_literal_eval: - result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) expected = "Initial message: [formatted:[invalid json/dict]][formatted:assistant reply]" assert result == expected - mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "[invalid json/dict]"}) - mock_azure_provider.format_message.assert_any_call({"role": "assistant", "content": "assistant reply"}) + mock_azure_provider.format_message.assert_any_call( + {"role": "user", "content": "[invalid json/dict]"} + ) + mock_azure_provider.format_message.assert_any_call( + {"role": "assistant", "content": "assistant reply"} + ) mock_literal_eval.assert_called_once_with("[invalid json/dict]") - def test_build_llama_conversation_skipping_system_messages(self, mock_azure_provider): - mock_azure_provider.format_message = MagicMock(side_effect=lambda msg: f"[formatted:{msg['content']}]") + def test_build_llama_conversation_skipping_system_messages( + self, mock_azure_provider + ): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" 
+ ) openai_message = [ {"role": "system", "content": "system message"}, - {"role": "user", "content": "user message"} + {"role": "user", "content": "user message"}, ] llama_message = "Initial message: " - result = mock_azure_provider.build_llama_conversation(openai_message, llama_message) + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) expected = "Initial message: [formatted:user message]" assert result == expected - mock_azure_provider.format_message.assert_any_call({"role": "user", "content": "user message"}) \ No newline at end of file + mock_azure_provider.format_message.assert_any_call( + {"role": "user", "content": "user message"} + ) diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 5b1d4e5b..e25396a7 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -1,11 +1,16 @@ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest -from llmstudio_core.providers.provider import ChatRequest, ProviderError, ChatCompletion, time, ChatCompletionChunk +from llmstudio_core.providers.provider import ( + ChatCompletion, + ChatRequest, + ProviderError, + time, +) request = ChatRequest(chat_input="Hello World", model="test_model") - + def test_chat_response_non_stream(mock_provider): mock_provider.validate_request = MagicMock() mock_provider.validate_model = MagicMock() @@ -17,19 +22,25 @@ def test_chat_response_non_stream(mock_provider): assert response == "final_response" mock_provider.validate_request.assert_called_once() mock_provider.validate_model.assert_called_once() - + + def test_chat_streaming_response(mock_provider): mock_provider.validate_request = MagicMock() mock_provider.validate_model = MagicMock() mock_provider.generate_client = MagicMock(return_value="mock_response_stream") - mock_provider.handle_response = MagicMock(return_value=iter(["streamed_response_1", "streamed_response_2"])) + mock_provider.handle_response = MagicMock( + return_value=iter(["streamed_response_1", "streamed_response_2"]) + ) - response_stream = mock_provider.chat(chat_input="Hello", model="test_model", is_stream=True) + response_stream = mock_provider.chat( + chat_input="Hello", model="test_model", is_stream=True + ) assert next(response_stream) == "streamed_response_1" assert next(response_stream) == "streamed_response_2" mock_provider.validate_request.assert_called_once() mock_provider.validate_model.assert_called_once() + def test_validate_model(mock_provider): request = ChatRequest(chat_input="Hello", model="test_model") mock_provider.validate_model(request) # Should not raise @@ -38,6 +49,7 @@ def test_validate_model(mock_provider): with pytest.raises(ProviderError): mock_provider.validate_model(request_invalid) + def test_calculate_metrics(mock_provider): metrics = mock_provider.calculate_metrics( @@ -56,10 +68,13 @@ def test_calculate_metrics(mock_provider): assert metrics["total_tokens"] == 3 assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 2 # input_cost + output_cost assert metrics["latency_s"] == 1.0 # end_time - start_time - assert metrics["time_to_first_token_s"] == 0.5 - 0.0 # first_token_time - start_time + assert ( + metrics["time_to_first_token_s"] == 0.5 - 0.0 + ) # first_token_time - start_time assert metrics["inter_token_latency_s"] == 0.1 # Average of token_times assert metrics["tokens_per_second"] == 2 / 1.0 # token_count / total_time - + + def 
test_calculate_metrics_single_token(mock_provider): metrics = mock_provider.calculate_metrics( @@ -82,11 +97,16 @@ def test_calculate_metrics_single_token(mock_provider): assert metrics["inter_token_latency_s"] == 0 assert metrics["tokens_per_second"] == 1 / 1.0 + @pytest.mark.asyncio async def test_ahandle_response_non_streaming(mock_provider): - request = MagicMock(is_stream=False, chat_input="Hello", model="test_model", parameters={}) + request = MagicMock( + is_stream=False, chat_input="Hello", model="test_model", parameters={} + ) response_chunk = { - "choices": [{"delta": {"content": "Non-streamed response"}, "finish_reason": "stop"}], + "choices": [ + {"delta": {"content": "Non-streamed response"}, "finish_reason": "stop"} + ], "model": "test_model", } start_time = time.time() @@ -95,11 +115,24 @@ async def mock_aparse_response(*args, **kwargs): yield response_chunk mock_provider.aparse_response = mock_aparse_response - mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Non-streamed response")) + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Non-streamed response", + ) + ) mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) response = [] - async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): + async for chunk in mock_provider.ahandle_response( + request, mock_aparse_response(), start_time + ): response.append(chunk) assert isinstance(response[0], ChatCompletion) @@ -111,66 +144,76 @@ def test_handle_response_stop(mock_provider): current_time = int(time.time()) - response_generator = iter([ - { - "id": "test_id", - "model": "test_model", - "created": current_time, - "choices": [ - { - "delta": {"content": "Hello, "}, - "finish_reason": None, - "index": 0, - } - ], - }, - { - "id": "test_id", - "model": "test_model", - "created": current_time, - "choices": [ - { - "delta": {"content": "world!"}, - "finish_reason": "stop", - "index": 0, - } - ], - }, - ]) + response_generator = iter( + [ + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "Hello, "}, + "finish_reason": None, + "index": 0, + } + ], + }, + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "world!"}, + "finish_reason": "stop", + "index": 0, + } + ], + }, + ] + ) request = ChatRequest(chat_input="Hello", model="test_model") - result_generator = mock_provider.handle_response(request, response_generator, start_time=time.time()) + result_generator = mock_provider.handle_response( + request, response_generator, start_time=time.time() + ) result = next(result_generator) assert isinstance(result, ChatCompletion) assert result.choices[0].message.content == "Hello, world!" 
- + + def test_handle_response_stop_single_token(mock_provider): current_time = int(time.time()) - response_generator = iter([ - { - "id": "test_id", - "model": "test_model", - "created": current_time, - "choices": [ - { - "delta": {"content": "Hello, "}, - "finish_reason": "stop", - "index": 0, - } - ], - } - ]) + response_generator = iter( + [ + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "Hello, "}, + "finish_reason": "stop", + "index": 0, + } + ], + } + ] + ) request = ChatRequest(chat_input="Hello", model="test_model") - result_generator = mock_provider.handle_response(request, response_generator, start_time=time.time()) + result_generator = mock_provider.handle_response( + request, response_generator, start_time=time.time() + ) result = next(result_generator) assert isinstance(result, ChatCompletion) assert result.choices[0].message.content == "Hello, " - + + def test_join_chunks_finish_reason_stop(mock_provider): current_time = int(time.time()) chunks = [ @@ -204,6 +247,7 @@ def test_join_chunks_finish_reason_stop(mock_provider): assert output_string == "Hello, world!" assert response.choices[0].message.content == "Hello, world!" + def test_join_chunks_finish_reason_function_call(mock_provider): current_time = int(time.time()) chunks = [ @@ -213,7 +257,9 @@ def test_join_chunks_finish_reason_function_call(mock_provider): "created": current_time, "choices": [ { - "delta": {"function_call": {"name": "my_function", "arguments": "arg1"}}, + "delta": { + "function_call": {"name": "my_function", "arguments": "arg1"} + }, "finish_reason": None, "index": 0, } @@ -249,11 +295,11 @@ def test_join_chunks_finish_reason_function_call(mock_provider): assert output_string == "arg1arg2" assert response.choices[0].message.function_call.arguments == "arg1arg2" assert response.choices[0].message.function_call.name == "my_function" - - + + def test_join_chunks_tool_calls(mock_provider): current_time = int(time.time()) - + chunks = [ { "id": "test_id_1", @@ -266,15 +312,18 @@ def test_join_chunks_tool_calls(mock_provider): { "id": "tool_1", "index": 0, - "function": {"name": "search_tool", "arguments": "{\"query\": \"weather"}, - "type": "function" + "function": { + "name": "search_tool", + "arguments": '{"query": "weather', + }, + "type": "function", } ] }, "finish_reason": None, - "index": 0 + "index": 0, } - ] + ], }, { "id": "test_id_2", @@ -287,22 +336,23 @@ def test_join_chunks_tool_calls(mock_provider): { "id": "tool_1", "index": 0, - "function": {"name": "search_tool", "arguments": " details\"}"} + "function": { + "name": "search_tool", + "arguments": ' details"}', + }, } ] }, "finish_reason": "tool_calls", - "index": 0 + "index": 0, } - ] - } + ], + }, ] response, output_string = mock_provider.join_chunks(chunks) - - assert output_string == "['search_tool', '{\"query\": \"weather details\"}']" - + assert output_string == "['search_tool', '{\"query\": \"weather details\"}']" assert response.object == "chat.completion" assert response.choices[0].finish_reason == "tool_calls" @@ -310,9 +360,9 @@ def test_join_chunks_tool_calls(mock_provider): assert tool_call.id == "tool_1" assert tool_call.function.name == "search_tool" - assert tool_call.function.arguments == "{\"query\": \"weather details\"}" + assert tool_call.function.arguments == '{"query": "weather details"}' assert tool_call.type == "function" - + def test_calculate_cost_fixed_cost(mock_provider): fixed_cost = 0.02 @@ -320,6 +370,7 @@ def 
test_calculate_cost_fixed_cost(mock_provider): expected_cost = token_count * fixed_cost assert mock_provider.calculate_cost(token_count, fixed_cost) == expected_cost + def test_calculate_cost_variable_cost(mock_provider): cost_range_1 = MagicMock() cost_range_1.range = (0, 50) @@ -334,6 +385,7 @@ def test_calculate_cost_variable_cost(mock_provider): expected_cost = token_count * 0.02 assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + def test_calculate_cost_variable_cost_higher_range(mock_provider): cost_range_1 = MagicMock() cost_range_1.range = (0, 50) @@ -352,6 +404,7 @@ def test_calculate_cost_variable_cost_higher_range(mock_provider): expected_cost = token_count * 0.03 assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + def test_calculate_cost_variable_cost_no_matching_range(mock_provider): cost_range_1 = MagicMock() cost_range_1.range = (0, 50) @@ -369,7 +422,8 @@ def test_calculate_cost_variable_cost_no_matching_range(mock_provider): token_count = 200 expected_cost = 0 assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost - + + def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): cost_range_1 = MagicMock() cost_range_1.range = (10, 50) @@ -387,7 +441,8 @@ def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): token_count = 5 expected_cost = 0 assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost - + + def test_input_to_string_with_string(mock_provider): input_data = "Hello, world!" assert mock_provider.input_to_string(input_data) == "Hello, world!" @@ -404,7 +459,15 @@ def test_input_to_string_with_list_of_text_messages(mock_provider): def test_input_to_string_with_list_of_text_and_url(mock_provider): input_data = [ {"role": "user", "content": [{"type": "text", "text": "Hello "}]}, - {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}}]}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "http://example.com/image.jpg"}, + } + ], + }, {"role": "user", "content": [{"type": "text", "text": " world!"}]}, ] expected_output = "Hello http://example.com/image.jpg world!" @@ -415,7 +478,15 @@ def test_input_to_string_with_mixed_roles_and_missing_content(mock_provider): input_data = [ {"role": "assistant", "content": "Admin text;"}, {"role": "user", "content": [{"type": "text", "text": "User text"}]}, - {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "http://example.com/another.jpg"}}]}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "http://example.com/another.jpg"}, + } + ], + }, ] expected_output = "Admin text;User texthttp://example.com/another.jpg" assert mock_provider.input_to_string(input_data) == expected_output @@ -426,11 +497,10 @@ def test_input_to_string_with_missing_content_key(mock_provider): {"role": "user"}, {"role": "user", "content": [{"type": "text", "text": "Hello again"}]}, ] - expected_output = "Hello again" + expected_output = "Hello again" assert mock_provider.input_to_string(input_data) == expected_output def test_input_to_string_with_empty_list(mock_provider): input_data = [] assert mock_provider.input_to_string(input_data) == "" -
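
Taken together, the cost tests above pin down the contract of `calculate_cost`: a bare float is a flat per-token rate, while a list is a set of per-range rates with 0 returned when no range matches the token count. Below is a minimal sketch of that contract, assuming each variable-cost entry exposes `.range` (inclusive bounds) and `.cost` the way the MagicMock fixtures do.

from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class CostRange:
    range: Tuple[int, int]  # (min_tokens, max_tokens), both inclusive
    cost: float  # cost per token while the count falls inside `range`


def calculate_cost(token_count: int, token_cost: Union[float, List[CostRange]]) -> float:
    if isinstance(token_cost, list):
        for cost_range in token_cost:
            lo, hi = cost_range.range
            if lo <= token_count <= hi:
                return token_count * cost_range.cost
        return 0  # no range matched, as the "no matching range" tests expect
    return token_count * token_cost


assert calculate_cost(100, 0.02) == 2.0  # flat per-token rate
assert calculate_cost(5, [CostRange((10, 50), 0.01)]) == 0  # below every range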