diff --git a/.gitignore b/.gitignore
index 456aa8d8..d80305ef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -67,6 +67,7 @@ venv.bak/
 config.yaml
 bun.lockb
+
 # Jupyter Notebook
 .ipynb_checkpoints
diff --git a/Makefile b/Makefile
index 5a43a3e6..b9329a81 100644
--- a/Makefile
+++ b/Makefile
@@ -1,2 +1,5 @@
 format:
 	pre-commit run --all-files
+
+unit-tests:
+	pytest libs/core/tests/unit_tests
\ No newline at end of file
diff --git a/libs/core/llmstudio_core/providers/azure.py b/libs/core/llmstudio_core/providers/azure.py
index 2dbd7307..f558f9d6 100644
--- a/libs/core/llmstudio_core/providers/azure.py
+++ b/libs/core/llmstudio_core/providers/azure.py
@@ -62,7 +62,25 @@ async def agenerate_client(self, request: ChatRequest) -> Any:
         return self.generate_client(request=request)

     def generate_client(self, request: ChatRequest) -> Any:
-        """Generate an AzureOpenAI client"""
+        """
+        Generates an AzureOpenAI client for processing a chat request.
+
+        This method prepares and configures the arguments required to create a client
+        request to AzureOpenAI's chat completions API. It determines model-specific
+        configurations (e.g., whether tools or functions are enabled) and combines
+        these with the base arguments for the API call.
+
+        Args:
+            request (ChatRequest): The chat request object containing the model,
+                parameters, and other necessary details.
+
+        Returns:
+            Any: The result of the chat completions API call.
+
+        Raises:
+            ProviderError: If there is an issue with the API connection or an error
+                returned from the API.
+        """
         self.is_llama = "llama" in request.model.lower()
         self.is_openai = "gpt" in request.model.lower()

@@ -72,7 +90,6 @@ def generate_client(self, request: ChatRequest) -> Any:

         try:
             messages = self.prepare_messages(request)
-            # Prepare the optional tool-related arguments
             tool_args = {}
             if not self.is_llama and self.has_tools and self.is_openai:
                 tool_args = {
@@ -80,7 +97,6 @@ def generate_client(self, request: ChatRequest) -> Any:
                     "tool_choice": "auto" if request.parameters.get("tools") else None,
                 }

-            # Prepare the optional function-related arguments
             function_args = {}
             if not self.is_llama and self.has_functions and self.is_openai:
                 function_args = {
@@ -90,14 +106,12 @@ def generate_client(self, request: ChatRequest) -> Any:
                     else None,
                 }

-            # Prepare the base arguments
             base_args = {
                 "model": request.model,
                 "messages": messages,
                 "stream": True,
             }

-            # Combine all arguments
             combined_args = {
                 **base_args,
                 **tool_args,
@@ -116,13 +130,13 @@ def prepare_messages(self, request: ChatRequest):
         if self.is_llama and (self.has_tools or self.has_functions):
             user_message = self.convert_to_openai_format(request.chat_input)
             content = "<|begin_of_text|>"
-            content = self.add_system_message(
+            content = self.build_llama_system_message(
                 user_message,
                 content,
                 request.parameters.get("tools"),
                 request.parameters.get("functions"),
             )
-            content = self.add_conversation(user_message, content)
+            content = self.build_llama_conversation(user_message, content)
             return [{"role": "user", "content": content}]
         else:
             return (
@@ -139,6 +153,20 @@ async def aparse_response(
             yield chunk

     def parse_response(self, response: AsyncGenerator, **kwargs) -> Any:
+        """
+        Processes a generator response and yields processed chunks.
+
+        If `is_llama` is True and tools or functions are enabled, it processes the response
+        using `handle_tool_response`. Otherwise, it processes each chunk and yields only those
+        containing "choices".
+
+        Args:
+            response (Generator): The response generator to process.
+            **kwargs: Additional arguments for tool handling.
+
+        Yields:
+            Any: Processed response chunks.
+        """
         if self.is_llama and (self.has_tools or self.has_functions):
             for chunk in self.handle_tool_response(response, **kwargs):
                 if chunk:
@@ -388,9 +416,25 @@ def convert_to_openai_format(self, message: Union[str, list]) -> list:
             return [{"role": "user", "content": message}]
         return message

-    def add_system_message(
+    def build_llama_system_message(
         self, openai_message: list, llama_message: str, tools: list, functions: list
     ) -> str:
+        """
+        Builds a complete system message for Llama based on the OpenAI messages, tools, and functions.
+
+        If a system message is present in the OpenAI messages, it is included in the result.
+        Otherwise, a default system message is used. Additional tool and function instructions
+        are appended if provided.
+
+        Args:
+            openai_message (list): List of OpenAI messages.
+            llama_message (str): The message to prepend to the system message.
+            tools (list): List of tools to include in the system message.
+            functions (list): List of functions to include in the system message.
+
+        Returns:
+            str: The formatted system message combined with the Llama message.
+        """
         system_message = ""
         system_message_found = False
         for message in openai_message:
@@ -407,15 +451,31 @@ def add_system_message(
     """

         if tools:
-            system_message = system_message + self.add_tool_instructions(tools)
+            system_message = system_message + self.build_tool_instructions(tools)
         if functions:
-            system_message = system_message + self.add_function_instructions(functions)
+            system_message = system_message + self.build_function_instructions(
+                functions
+            )

         end_tag = "\n<|eot_id|>"
         return llama_message + system_message + end_tag

-    def add_tool_instructions(self, tools: list) -> str:
+    def build_tool_instructions(self, tools: list) -> str:
+        """
+        Builds a detailed instructional prompt for tools available to the assistant.
+
+        This function generates a message describing the available tools, focusing on tools
+        of type "function." It explains to the LLM how to use each tool and provides an example of the
+        correct response format for function calls.
+
+        Args:
+            tools (list): A list of tool dictionaries, where each dictionary contains tool
+                details such as type, function name, description, and parameters.
+
+        Returns:
+            str: A formatted string detailing the tool instructions and usage examples.
+        """
         tool_prompt = """
     You have access to the following tools:
     """
@@ -449,7 +509,21 @@ def add_tool_instructions(self, tools: list) -> str:

         return tool_prompt

-    def add_function_instructions(self, functions: list) -> str:
+    def build_function_instructions(self, functions: list) -> str:
+        """
+        Builds a detailed instructional prompt for available functions.
+
+        This method creates a message describing the functions accessible to the assistant.
+        It includes the function name, description, and required parameters, along with
+        specific guidelines for calling functions.
+
+        Args:
+            functions (list): A list of function dictionaries, each containing details such as
+                name, description, and parameters.
+
+        Returns:
+            str: A formatted string with instructions on using the provided functions.
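+
+        Example (illustrative sketch only; ``provider`` and the ``get_time``
+        function are hypothetical, and the exact prompt wording is defined by
+        the template below):
+
+            >>> functions = [
+            ...     {
+            ...         "name": "get_time",
+            ...         "description": "get the current time",
+            ...         "parameters": {"timezone": "string"},
+            ...     }
+            ... ]
+            >>> prompt = provider.build_function_instructions(functions)
+            >>> "Use the function 'get_time'" in prompt  # doctest: +SKIP
+            True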
+ """ function_prompt = """ You have access to the following functions: """ @@ -479,35 +553,60 @@ def add_function_instructions(self, functions: list) -> str: """ return function_prompt - def add_conversation(self, openai_message: list, llama_message: str) -> str: + def build_llama_conversation(self, openai_message: list, llama_message: str) -> str: + """ + Appends the OpenAI message to the Llama message while formatting OpenAI messages. + + This function iterates through a list of OpenAI messages and formats them for inclusion + in a Llama message. It handles user messages that might include nested content (lists of + messages) by safely evaluating the content. System messages are skipped. + + Args: + openai_message (list): A list of dictionaries representing the OpenAI messages. Each + dictionary should have "role" and "content" keys. + llama_message (str): The initial Llama message to which the conversation is appended. + + Returns: + str: The Llama message with the conversation appended. + """ conversation_parts = [] for message in openai_message: if message["role"] == "system": continue elif message["role"] == "user" and isinstance(message["content"], str): try: - # Attempt to safely evaluate the string to a Python object content_as_list = ast.literal_eval(message["content"]) if isinstance(content_as_list, list): - # If the content is a list, process each nested message for nested_message in content_as_list: conversation_parts.append( self.format_message(nested_message) ) else: - # If the content is not a list, append it directly conversation_parts.append(self.format_message(message)) except (ValueError, SyntaxError): - # If evaluation fails or content is not a list/dict string, append the message directly conversation_parts.append(self.format_message(message)) else: - # For all other messages, use the existing formatting logic conversation_parts.append(self.format_message(message)) return llama_message + "".join(conversation_parts) def format_message(self, message: dict) -> str: - """Format a single message for the conversation.""" + """ + Formats a single message dictionary into a structured string for a conversation. + + The formatting depends on the content of the message, such as tool calls, + function calls, or simple user/assistant messages. Each type of message + is formatted with specific headers and tags. + + Args: + message (dict): A dictionary containing message details. Expected keys + include "role", "content", and optionally "tool_calls", + "tool_call_id", or "function_call". + + Returns: + str: A formatted string representing the message. Returns an empty + string if the message cannot be formatted. 
+ """ if "tool_calls" in message: for tool_call in message["tool_calls"]: function_name = tool_call["function"]["name"] diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index df8e04cc..5eeccc83 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock import pytest +from llmstudio_core.providers.azure import AzureProvider from llmstudio_core.providers.provider import ProviderCore @@ -52,3 +53,27 @@ def mock_provider(): tokenizer = MagicMock() tokenizer.encode = lambda x: x.split() # Simple tokenizer mock return MockProvider(config=config, tokenizer=tokenizer) + + +class MockAzureProvider(AzureProvider): + async def aparse_response(self, response, **kwargs): + return response + + async def agenerate_client(self, request): + # For testing, return an async generator + async def async_gen(): + yield {} + + return async_gen() + + @staticmethod + def _provider_config_name(): + return "mock_azure_provider" + + +@pytest.fixture +def mock_azure_provider(): + config = MagicMock() + config.id = "mock_azure_provider" + base_url = "mock_url.com" + return MockAzureProvider(config=config, base_url=base_url) diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py new file mode 100644 index 00000000..28e5f684 --- /dev/null +++ b/libs/core/tests/unit_tests/test_azure.py @@ -0,0 +1,232 @@ +from unittest.mock import MagicMock + + +class TestParseResponse: + def test_tool_response_handling(self, mock_azure_provider): + + mock_azure_provider.is_llama = True + mock_azure_provider.has_tools = True + mock_azure_provider.has_functions = False + + mock_azure_provider.handle_tool_response = MagicMock( + return_value=iter(["chunk1", None, "chunk2", None, "chunk3"]) + ) + + response = iter(["irrelevant"]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == ["chunk1", "chunk2", "chunk3"] + mock_azure_provider.handle_tool_response.assert_called_once_with(response) + + def test_direct_response_handling_with_choices(self, mock_azure_provider): + mock_azure_provider.is_llama = False + + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"choices": ["choice1", "choice2"]} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"choices": ["choice2"]} + response = iter([chunk1, chunk2]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == [ + {"choices": ["choice1", "choice2"]}, + {"choices": ["choice2"]}, + ] + chunk1.model_dump.assert_called_once() + chunk2.model_dump.assert_called_once() + + def test_direct_response_handling_without_choices(self, mock_azure_provider): + mock_azure_provider.is_llama = False + + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"key": "value"} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"another_key": "another_value"} + response = iter([chunk1, chunk2]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == [] + chunk1.model_dump.assert_called_once() + chunk2.model_dump.assert_called_once() + + +class TestFormatMessage: + def test_format_message_tool_calls(self, mock_azure_provider): + message = { + "tool_calls": [ + { + "function": { + "name": "example_tool", + "arguments": '{"arg1": "value1"}', + } + } + ] + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + {"arg1": "value1"} + <|eom_id|> + """ + assert 
result.strip() == expected.strip() + + def test_format_message_tool_call_id(self, mock_azure_provider): + message = {"tool_call_id": "123", "content": "This is the tool response."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>ipython<|end_header_id|> + This is the tool response. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_function_call(self, mock_azure_provider): + message = { + "function_call": { + "name": "example_function", + "arguments": '{"arg1": "value1"}', + } + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + {"arg1": "value1"} + <|eom_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_user_message(self, mock_azure_provider): + message = {"role": "user", "content": "This is a user message."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>user<|end_header_id|> + This is a user message. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_assistant_message(self, mock_azure_provider): + message = {"role": "assistant", "content": "This is an assistant message."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + This is an assistant message. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_function_response(self, mock_azure_provider): + message = {"role": "function", "content": "This is the function response."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>ipython<|end_header_id|> + This is the function response. 
+        <|eot_id|>
+        """
+        assert result.strip() == expected.strip()
+
+    def test_format_message_empty_message(self, mock_azure_provider):
+        message = {"role": "user", "content": None}
+        result = mock_azure_provider.format_message(message)
+        expected = ""
+        assert result == expected
+
+
+class TestGenerateClient:
+    def test_generate_client_with_tools_and_functions(self, mock_azure_provider):
+        mock_azure_provider.prepare_messages = MagicMock(
+            return_value="prepared_messages"
+        )
+        mock_azure_provider._client.chat.completions.create = MagicMock(
+            return_value="mock_response"
+        )
+
+        request = MagicMock()
+        request.model = "gpt-4"
+        request.parameters = {
+            "tools": ["tool1", "tool2"],
+            "functions": ["function1", "function2"],
+            "other_param": "value",
+        }
+
+        result = mock_azure_provider.generate_client(request)
+
+        expected_args = {
+            "model": "gpt-4",
+            "messages": "prepared_messages",
+            "stream": True,
+            "tools": ["tool1", "tool2"],
+            "tool_choice": "auto",
+            "functions": ["function1", "function2"],
+            "function_call": "auto",
+            "other_param": "value",
+        }
+
+        assert result == "mock_response"
+        mock_azure_provider.prepare_messages.assert_called_once_with(request)
+        mock_azure_provider._client.chat.completions.create.assert_called_once_with(
+            **expected_args
+        )
+
+    def test_generate_client_without_tools_or_functions(self, mock_azure_provider):
+        mock_azure_provider.prepare_messages = MagicMock(
+            return_value="prepared_messages"
+        )
+        mock_azure_provider._client.chat.completions.create = MagicMock(
+            return_value="mock_response"
+        )
+
+        request = MagicMock()
+        request.model = "gpt-4"
+        request.parameters = {"other_param": "value"}
+
+        result = mock_azure_provider.generate_client(request)
+
+        expected_args = {
+            "model": "gpt-4",
+            "messages": "prepared_messages",
+            "stream": True,
+            "other_param": "value",
+        }
+
+        assert result == "mock_response"
+        mock_azure_provider.prepare_messages.assert_called_once_with(request)
+        mock_azure_provider._client.chat.completions.create.assert_called_once_with(
+            **expected_args
+        )
+
+    def test_generate_client_with_llama_model(self, mock_azure_provider):
+        mock_azure_provider.prepare_messages = MagicMock(
+            return_value="prepared_messages"
+        )
+        mock_azure_provider._client.chat.completions.create = MagicMock(
+            return_value="mock_response"
+        )
+
+        request = MagicMock()
+        request.model = "llama-2"
+        request.parameters = {
+            "tools": ["tool1"],
+            "functions": ["function1"],
+            "other_param": "value",
+        }
+
+        result = mock_azure_provider.generate_client(request)
+
+        expected_args = {
+            "model": "llama-2",
+            "messages": "prepared_messages",
+            "stream": True,
+            "tools": ["tool1"],
+            "functions": ["function1"],
+            "other_param": "value",
+        }
+
+        assert result == "mock_response"
+        mock_azure_provider.prepare_messages.assert_called_once_with(request)
+        mock_azure_provider._client.chat.completions.create.assert_called_once_with(
+            **expected_args
+        )
diff --git a/libs/core/tests/unit_tests/test_azure_build.py b/libs/core/tests/unit_tests/test_azure_build.py
new file mode 100644
index 00000000..17d73f99
--- /dev/null
+++ b/libs/core/tests/unit_tests/test_azure_build.py
@@ -0,0 +1,279 @@
+from unittest.mock import MagicMock, patch
+
+
+class TestBuildLlamaSystemMessage:
+    def test_build_llama_system_message_with_existing_sm(self, mock_azure_provider):
+        mock_azure_provider.build_tool_instructions = MagicMock(
+            return_value="Tool Instructions"
+        )
+        mock_azure_provider.build_function_instructions = MagicMock(
+            return_value="\nFunction Instructions"
+        )
+
+        openai_message = [
+            {"role": "user", "content": "Hello"},
+            {"role": "system", "content": "Custom system message"},
+        ]
+        llama_message = "Initial message"
+        tools = ["Tool1", "Tool2"]
+        functions = ["Function1"]
+
+        result = mock_azure_provider.build_llama_system_message(
+            openai_message, llama_message, tools, functions
+        )
+
+        expected = (
+            "Initial message\n"
+            "    <|start_header_id|>system<|end_header_id|>\n"
+            "    Custom system message\n"
+            "    Tool Instructions\nFunction Instructions\n<|eot_id|>"  # indentation is kept here because a newline inside a triple-quoted string preserves the leading whitespace of the next line
+        )
+        assert result == expected
+        mock_azure_provider.build_tool_instructions.assert_called_once_with(tools)
+        mock_azure_provider.build_function_instructions.assert_called_once_with(
+            functions
+        )
+
+    def test_build_llama_system_message_with_default_sm(self, mock_azure_provider):
+        mock_azure_provider.build_tool_instructions = MagicMock(
+            return_value="Tool Instructions"
+        )
+        mock_azure_provider.build_function_instructions = MagicMock(
+            return_value="\nFunction Instructions"
+        )
+
+        openai_message = [{"role": "user", "content": "Hello"}]
+        llama_message = "Initial message"
+        tools = ["Tool1"]
+        functions = []
+
+        result = mock_azure_provider.build_llama_system_message(
+            openai_message, llama_message, tools, functions
+        )
+
+        expected = (
+            "Initial message\n"
+            "    <|start_header_id|>system<|end_header_id|>\n"
+            "    You are a helpful AI assistant.\n"
+            "    Tool Instructions\n<|eot_id|>"
+        )
+        assert result == expected
+        mock_azure_provider.build_tool_instructions.assert_called_once_with(tools)
+        mock_azure_provider.build_function_instructions.assert_not_called()
+
+    def test_build_llama_system_message_without_tools_or_functions(
+        self, mock_azure_provider
+    ):
+        mock_azure_provider.build_tool_instructions = MagicMock()
+        mock_azure_provider.build_function_instructions = MagicMock()
+
+        openai_message = [{"role": "system", "content": "Minimal system message"}]
+        llama_message = "Initial message"
+        tools = []
+        functions = []
+
+        result = mock_azure_provider.build_llama_system_message(
+            openai_message, llama_message, tools, functions
+        )
+
+        expected = (
+            "Initial message\n"
+            "    <|start_header_id|>system<|end_header_id|>\n"
+            "    Minimal system message\n    \n<|eot_id|>"
+        )
+        assert result == expected
+        mock_azure_provider.build_tool_instructions.assert_not_called()
+        mock_azure_provider.build_function_instructions.assert_not_called()
+
+
+class TestBuildInstructions:
+    def test_build_tool_instructions(self, mock_azure_provider):
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "python_repl_ast",
+                    "description": "execute Python code",
+                    "parameters": {"query": "string"},
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "data_lookup",
+                    "description": "retrieve data from a database",
+                    "parameters": {"database": "string", "query": "string"},
+                },
+            },
+        ]
+
+        result = mock_azure_provider.build_tool_instructions(tools)
+
+        expected = """
+    You have access to the following tools:
+    Use the function 'python_repl_ast' to 'execute Python code':
+Parameters format:
+{
+    "query": "string"
+}
+
+Use the function 'data_lookup' to 'retrieve data from a database':
+Parameters format:
+{
+    "database": "string",
+    "query": "string"
+}
+
+
+If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix:
+§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}}
+IMPORTANT: IT IS VITAL THAT YOU NEVER ADD A PREFIX OR A SUFFIX TO THE FUNCTION CALL.
+
+Here is an example of the output I desiere when performing function call:
+§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}}
+NOTE: There is no prefix before the symbol '§' and nothing comes after the call is done.
+
+    Reminder:
+    - Function calls MUST follow the specified format.
+    - Only call one function at a time.
+    - Required parameters MUST be specified.
+    - Put the entire function call reply on one line.
+    - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls.
+    - If you have already called a tool and got the response for the users question please reply with the response.
+    """
+        assert result.strip() == expected.strip()
+
+    def test_build_function_instructions(self, mock_azure_provider):
+        functions = [
+            {
+                "name": "python_repl_ast",
+                "description": "execute Python code",
+                "parameters": {"query": "string"},
+            },
+            {
+                "name": "data_lookup",
+                "description": "retrieve data from a database",
+                "parameters": {"database": "string", "query": "string"},
+            },
+        ]
+
+        result = mock_azure_provider.build_function_instructions(functions)
+
+        expected = """
+You have access to the following functions:
+Use the function 'python_repl_ast' to: 'execute Python code'
+{
+    "query": "string"
+}
+
+Use the function 'data_lookup' to: 'retrieve data from a database'
+{
+    "database": "string",
+    "query": "string"
+}
+
+
+If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix:
+§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}}
+
+Here is an example of the output I desiere when performing function call:
+§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}}
+
+Reminder:
+- Function calls MUST follow the specified format.
+- Only call one function at a time.
+- NEVER call more than one function at a time.
+- Required parameters MUST be specified.
+- Put the entire function call reply on one line.
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls.
+- If you have already called a function and got the response for the user's question, please reply with the response.
+""" + + assert result.strip() == expected.strip() + + +class TestBuildLlamaConversation: + def test_build_llama_conversation_with_nested_messages(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + { + "role": "user", + "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]", + }, + {"role": "assistant", "content": "assistant reply"}, + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:nested message 1][formatted:nested message 2][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"content": "nested message 1"} + ) + mock_azure_provider.format_message.assert_any_call( + {"content": "nested message 2"} + ) + mock_azure_provider.format_message.assert_any_call( + {"role": "assistant", "content": "assistant reply"} + ) + + def test_build_llama_conversation_with_invalid_nested_content( + self, mock_azure_provider + ): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + {"role": "user", "content": "[invalid json/dict]"}, + {"role": "assistant", "content": "assistant reply"}, + ] + llama_message = "Initial message: " + + with patch("ast.literal_eval", side_effect=ValueError) as mock_literal_eval: + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:[invalid json/dict]][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"role": "user", "content": "[invalid json/dict]"} + ) + mock_azure_provider.format_message.assert_any_call( + {"role": "assistant", "content": "assistant reply"} + ) + + mock_literal_eval.assert_called_once_with("[invalid json/dict]") + + def test_build_llama_conversation_skipping_system_messages( + self, mock_azure_provider + ): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:user message]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"role": "user", "content": "user message"} + )