From a4331f7b56700940836bc406f674df7fc03d3e77 Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Tue, 12 Nov 2024 19:24:26 +0000 Subject: [PATCH 01/13] chore: added unit tests for core provider. small bugfix on calculate_metrics of provider --- .../core/llmstudio_core/providers/provider.py | 4 +- libs/core/tests/unit_tests/conftest.py | 16 +++ libs/core/tests/unit_tests/test_provider.py | 120 ++++++++++++++++-- 3 files changed, 125 insertions(+), 15 deletions(-) diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py index c06ce998..27c79ced 100644 --- a/libs/core/llmstudio_core/providers/provider.py +++ b/libs/core/llmstudio_core/providers/provider.py @@ -623,8 +623,8 @@ def calculate_metrics( "cost_usd": input_cost + output_cost, "latency_s": total_time, "time_to_first_token_s": first_token_time - start_time, - "inter_token_latency_s": sum(token_times) / len(token_times), - "tokens_per_second": token_count / total_time, + "inter_token_latency_s": sum(token_times) / len(token_times) if token_times else 0, + "tokens_per_second": token_count / total_time if token_times else 1 / total_time, } def calculate_cost( diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 23f070e3..1a261eac 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -26,6 +26,22 @@ def output_to_string(self, output): if output.choices[0].finish_reason == "stop": return output.choices[0].message.content return "" + + def validate_request(self, request): + # For testing, simply return the request + return request + + async def agenerate_client(self, request): + # For testing, return an async generator + async def async_gen(): + yield {} + return async_gen() + + def generate_client(self, request): + # For testing, return a generator + def gen(): + yield {} + return gen() @staticmethod def _provider_config_name(): diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 118367a3..10f8178e 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -1,7 +1,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from llmstudio_core.providers.provider import ChatRequest, ProviderError +from llmstudio_core.providers.provider import ChatRequest, ProviderError, ChatCompletion, time request = ChatRequest(chat_input="Hello", model="test_model") @@ -35,23 +35,117 @@ def test_validate_model(mock_provider): with pytest.raises(ProviderError): mock_provider.validate_model(request_invalid) - def test_calculate_metrics(mock_provider): + + mock_provider.tokenizer.encode = lambda x: x.split() # Assuming tokenizer splits "Hello" and "World" into one token each + metrics = mock_provider.calculate_metrics( input="Hello", - output="World", + output="Hello World", model="test_model", - start_time=0, - end_time=1, + start_time=0.0, + end_time=1.0, first_token_time=0.5, - token_times=(0.1, 0.2), + token_times=(0.1,), token_count=2, ) - assert metrics["input_tokens"] == pytest.approx(1) - assert metrics["output_tokens"] == pytest.approx(1) - assert metrics["cost_usd"] == pytest.approx(0.03) - assert metrics["latency_s"] == pytest.approx(1) - assert metrics["time_to_first_token_s"] == pytest.approx(0.5) - assert metrics["inter_token_latency_s"] == pytest.approx(0.15) - assert metrics["tokens_per_second"] == pytest.approx(2) + assert metrics["input_tokens"] == 1 + assert metrics["output_tokens"] == 2 
+ assert metrics["total_tokens"] == 3 + assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 2 # input_cost + output_cost + assert metrics["latency_s"] == 1.0 # end_time - start_time + assert metrics["time_to_first_token_s"] == 0.5 - 0.0 # first_token_time - start_time + assert metrics["inter_token_latency_s"] == 0.1 # Average of token_times + assert metrics["tokens_per_second"] == 2 / 1.0 # token_count / total_time + +def test_calculate_metrics_single_token(mock_provider): + + mock_provider.tokenizer.encode = lambda x: x.split() + + metrics = mock_provider.calculate_metrics( + input="Hello", + output="World", + model="test_model", + start_time=0.0, + end_time=1.0, + first_token_time=0.5, + token_times=(), + token_count=1, + ) + + assert metrics["input_tokens"] == 1 + assert metrics["output_tokens"] == 1 + assert metrics["total_tokens"] == 2 + assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 1 + assert metrics["latency_s"] == 1.0 + assert metrics["time_to_first_token_s"] == 0.5 - 0.0 + assert metrics["inter_token_latency_s"] == 0 + assert metrics["tokens_per_second"] == 1 / 1.0 + +def test_handle_response_stop(mock_provider): + + current_time = int(time.time()) + + response_generator = iter([ + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "Hello, "}, + "finish_reason": None, + "index": 0, + } + ], + }, + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "world!"}, + "finish_reason": "stop", + "index": 0, + } + ], + }, + ]) + + request = ChatRequest(chat_input="Hello", model="test_model") + result_generator = mock_provider.handle_response(request, response_generator, start_time=time.time()) + result = next(result_generator) + + assert isinstance(result, ChatCompletion) + assert result.choices[0].message.content == "Hello, world!" + +def test_handle_response_stop_single_token(mock_provider): + """ + testing single token answer. 
token_times var will be 0 + """ + + current_time = int(time.time()) + + response_generator = iter([ + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "Hello, "}, + "finish_reason": "stop", + "index": 0, + } + ], + } + ]) + + request = ChatRequest(chat_input="Hello", model="test_model") + result_generator = mock_provider.handle_response(request, response_generator, start_time=time.time()) + result = next(result_generator) + + assert isinstance(result, ChatCompletion) + assert result.choices[0].message.content == "Hello, " \ No newline at end of file From ca7c52ce70d1b25c359a3c69dd65d4e8e52aeba9 Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Wed, 13 Nov 2024 09:53:11 +0000 Subject: [PATCH 02/13] added unit tests and docstring for join chunks --- .../core/llmstudio_core/providers/provider.py | 52 ++++++- libs/core/tests/unit_tests/test_provider.py | 144 +++++++++++++++++- 2 files changed, 191 insertions(+), 5 deletions(-) diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py index 27c79ced..8ac18053 100644 --- a/libs/core/llmstudio_core/providers/provider.py +++ b/libs/core/llmstudio_core/providers/provider.py @@ -289,7 +289,7 @@ async def ahandle_response( chunks = [chunk[0] if isinstance(chunk, tuple) else chunk for chunk in chunks] model = next(chunk["model"] for chunk in chunks if chunk.get("model")) - response, output_string = self.join_chunks(chunks, request) + response, output_string = self.join_chunks(chunks) metrics = self.calculate_metrics( request.chat_input, @@ -341,7 +341,29 @@ async def ahandle_response( def handle_response( self, request: ChatRequest, response: Generator, start_time: float ) -> Generator: - """Handles the response from an API""" + """ + Processes API response chunks to build a structured, complete response, yielding + each chunk if streaming is enabled. + + If streaming, each chunk is yielded as soon as it’s processed. Otherwise, all chunks + are combined and yielded as a single response at the end. + + Parameters + ---------- + request : ChatRequest + The original request details, including model, input, and streaming preference. + response : Generator + A generator yielding partial response chunks from the API. + start_time : float + The start time for measuring response timing. + + Yields + ------ + Union[ChatCompletionChunk, ChatCompletion] + If streaming (`is_stream=True`), yields each `ChatCompletionChunk` as it’s processed. + Otherwise, yields a single `ChatCompletion` with the full response data. + + """ first_token_time = None previous_token_time = None token_times = [] @@ -397,7 +419,7 @@ def handle_response( chunks = [chunk[0] if isinstance(chunk, tuple) else chunk for chunk in chunks] model = next(chunk["model"] for chunk in chunks if chunk.get("model")) - response, output_string = self.join_chunks(chunks, request) + response, output_string = self.join_chunks(chunks) metrics = self.calculate_metrics( request.chat_input, @@ -446,7 +468,29 @@ def handle_response( else: yield ChatCompletion(**response) - def join_chunks(self, chunks, request): + def join_chunks(self, chunks): + """ + Combine multiple response chunks from the model into a single, structured response. + Handles tool calls, function calls, and standard text completion based on the + purpose indicated by the final chunk. + + Parameters + ---------- + chunks : List[Dict] + A list of partial responses (chunks) from the model. 
+ + Returns + ------- + Tuple[ChatCompletion, str] + - `ChatCompletion`: The structured response based on the type of completion + (tool calls, function call, or text). + - `str`: The concatenated content or arguments, depending on the completion type. + + Raises + ------ + Exception + If there is an issue constructing the response, an exception is raised. + """ finish_reason = chunks[-1].get("choices")[0].get("finish_reason") if finish_reason == "tool_calls": diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 10f8178e..5a89138d 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -148,4 +148,146 @@ def test_handle_response_stop_single_token(mock_provider): result = next(result_generator) assert isinstance(result, ChatCompletion) - assert result.choices[0].message.content == "Hello, " \ No newline at end of file + assert result.choices[0].message.content == "Hello, " + +def test_join_chunks_finish_reason_stop(mock_provider): + current_time = int(time.time()) + chunks = [ + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "Hello, "}, + "finish_reason": None, + "index": 0, + } + ], + }, + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"content": "world!"}, + "finish_reason": "stop", + "index": 0, + } + ], + }, + ] + response, output_string = mock_provider.join_chunks(chunks) + + assert output_string == "Hello, world!" + assert response.choices[0].message.content == "Hello, world!" + +def test_join_chunks_finish_reason_function_call(mock_provider): + current_time = int(time.time()) + chunks = [ + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"function_call": {"name": "my_function", "arguments": "arg1"}}, + "finish_reason": None, + "index": 0, + } + ], + }, + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"function_call": {"arguments": "arg2"}}, + "finish_reason": "function_call", + "index": 0, + } + ], + }, + { + "id": "test_id", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": {"function_call": {"arguments": "}"}}, + "finish_reason": "function_call", + "index": 0, + } + ], + }, + ] + response, output_string = mock_provider.join_chunks(chunks) + + assert output_string == "arg1arg2" + assert response.choices[0].message.function_call.arguments == "arg1arg2" + assert response.choices[0].message.function_call.name == "my_function" + + +def test_join_chunks_tool_calls(mock_provider): + current_time = int(time.time()) + + chunks = [ + { + "id": "test_id_1", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": { + "tool_calls": [ + { + "id": "tool_1", + "index": 0, + "function": {"name": "search_tool", "arguments": "{\"query\": \"weather"}, + "type": "function" + } + ] + }, + "finish_reason": None, + "index": 0 + } + ] + }, + { + "id": "test_id_2", + "model": "test_model", + "created": current_time, + "choices": [ + { + "delta": { + "tool_calls": [ + { + "id": "tool_1", + "index": 0, + "function": {"name": "search_tool", "arguments": " details\"}"} + } + ] + }, + "finish_reason": "tool_calls", + "index": 0 + } + ] + } + ] + + response, output_string = mock_provider.join_chunks(chunks) + + assert output_string == "['search_tool', '{\"query\": \"weather details\"}']" + + + + assert 
response.object == "chat.completion" + assert response.choices[0].finish_reason == "tool_calls" + tool_call = response.choices[0].message.tool_calls[0] + + assert tool_call.id == "tool_1" + assert tool_call.function.name == "search_tool" + assert tool_call.function.arguments == "{\"query\": \"weather details\"}" + assert tool_call.type == "function" \ No newline at end of file From 3ed1c1d92b5c71ec1238242175ecdf517d447ef6 Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Thu, 14 Nov 2024 10:35:35 +0000 Subject: [PATCH 03/13] added unit tests and docstrings for calculate_cost on provider --- .../core/llmstudio_core/providers/provider.py | 60 ++++++++++++++- libs/core/tests/unit_tests/test_provider.py | 77 ++++++++++++++++++- 2 files changed, 134 insertions(+), 3 deletions(-) diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py index 8ac18053..b2467d72 100644 --- a/libs/core/llmstudio_core/providers/provider.py +++ b/libs/core/llmstudio_core/providers/provider.py @@ -651,7 +651,42 @@ def calculate_metrics( token_times: Tuple[float, ...], token_count: int, ) -> Dict[str, Any]: - """Calculates metrics based on token times and output""" + """ + Calculates performance and cost metrics for a model response based on timing + information, token counts, and model-specific costs. + + Parameters + ---------- + input : Any + The input provided to the model, used to determine input token count. + output : Any + The output generated by the model, used to determine output token count. + model : str + The model identifier, used to retrieve model-specific configuration and costs. + start_time : float + The timestamp marking the start of the model response. + end_time : float + The timestamp marking the end of the model response. + first_token_time : float + The timestamp when the first token was received, used for latency calculations. + token_times : Tuple[float, ...] + A tuple of time intervals between received tokens, used for inter-token latency. + token_count : int + The total number of tokens processed in the response. + + Returns + ------- + Dict[str, Any] + A dictionary containing calculated metrics, including: + - `input_tokens`: Number of tokens in the input. + - `output_tokens`: Number of tokens in the output. + - `total_tokens`: Total token count (input + output). + - `cost_usd`: Total cost of the response in USD. + - `latency_s`: Total time taken for the response, in seconds. + - `time_to_first_token_s`: Time to receive the first token, in seconds. + - `inter_token_latency_s`: Average time between tokens, in seconds. If `token_times` is empty sets it to 0. + - `tokens_per_second`: Processing rate of tokens per second. + """ model_config = self.config.models[model] input_tokens = len(self.tokenizer.encode(self.input_to_string(input))) output_tokens = len(self.tokenizer.encode(self.output_to_string(output))) @@ -674,10 +709,31 @@ def calculate_metrics( def calculate_cost( self, token_count: int, token_cost: Union[float, List[Dict[str, Any]]] ) -> float: + """ + Calculates the cost for a given number of tokens based on a fixed cost per token + or a variable rate structure. + + If `token_cost` is a fixed float, the total cost is `token_count * token_cost`. + If `token_cost` is a list, it checks each range and calculates cost based on the applicable range's rate. + + Parameters + ---------- + token_count : int + The total number of tokens for which the cost is being calculated. 
+ token_cost : Union[float, List[Dict[str, Any]]] + Either a fixed cost per token (as a float) or a list of dictionaries defining + variable cost ranges. Each dictionary in the list represents a range with + 'range' (a tuple of minimum and maximum token counts) and 'cost' (cost per token) keys. + + Returns + ------- + float + The calculated cost based on the token count and cost structure. + """ if isinstance(token_cost, list): for cost_range in token_cost: if token_count >= cost_range.range[0] and ( - token_count <= cost_range.range[1] or cost_range.range[1] is None + cost_range.range[1] is None or token_count <= cost_range.range[1] ): return cost_range.cost * token_count else: diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 5a89138d..15750c3d 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -290,4 +290,79 @@ def test_join_chunks_tool_calls(mock_provider): assert tool_call.id == "tool_1" assert tool_call.function.name == "search_tool" assert tool_call.function.arguments == "{\"query\": \"weather details\"}" - assert tool_call.type == "function" \ No newline at end of file + assert tool_call.type == "function" + + +def test_calculate_cost_fixed_cost(mock_provider): + fixed_cost = 0.02 + token_count = 100 + expected_cost = token_count * fixed_cost + assert mock_provider.calculate_cost(token_count, fixed_cost) == expected_cost + +def test_calculate_cost_variable_cost(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + variable_cost = [cost_range_1, cost_range_2] + token_count = 75 + expected_cost = token_count * 0.02 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + +def test_calculate_cost_variable_cost_higher_range(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, None) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 150 + expected_cost = token_count * 0.03 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + +def test_calculate_cost_variable_cost_no_matching_range(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, 150) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 200 + expected_cost = 0 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + +def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (10, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, 150) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 5 + expected_cost = 0 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost From bb0a4651152e67d1f6df5692a323b530c20a4141 
Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Thu, 14 Nov 2024 10:59:09 +0000 Subject: [PATCH 04/13] added unit tests and docstrings for input_to_string on provider --- .../core/llmstudio_core/providers/provider.py | 34 ++++++++++++++ libs/core/tests/unit_tests/test_provider.py | 47 +++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py index b2467d72..98cc93b4 100644 --- a/libs/core/llmstudio_core/providers/provider.py +++ b/libs/core/llmstudio_core/providers/provider.py @@ -741,6 +741,23 @@ def calculate_cost( return 0 def input_to_string(self, input): + """ + Converts an input, which can be a string or a structured list of messages, into a single concatenated string. + + Parameters + ---------- + input : Any + The input data to be converted. This can be: + - A simple string, which is returned as-is. + - A list of message dictionaries, where each dictionary may contain `content`, `role`, + and nested items like `text` or `image_url`. + + Returns + ------- + str + A concatenated string representing the text content of all messages, + including text and URLs from image content if present. + """ if isinstance(input, str): return input else: @@ -762,6 +779,23 @@ def input_to_string(self, input): return "".join(result) def output_to_string(self, output): + """ + Extracts and returns the content or arguments from the output based on + the `finish_reason` of the first choice in `output`. + + Parameters + ---------- + output : Any + The model output object, expected to have a `choices` attribute that should contain a `finish_reason` indicating the type of output + ("stop", "tool_calls", or "function_call") and corresponding content or arguments. + + Returns + ------- + str + - If `finish_reason` is "stop": Returns the message content. + - If `finish_reason` is "tool_calls": Returns the arguments for the first tool call. + - If `finish_reason` is "function_call": Returns the arguments for the function call. + """ if output.choices[0].finish_reason == "stop": return output.choices[0].message.content elif output.choices[0].finish_reason == "tool_calls": diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 15750c3d..408359cd 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -366,3 +366,50 @@ def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): token_count = 5 expected_cost = 0 assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + +def test_input_to_string_with_string(mock_provider): + input_data = "Hello, world!" + assert mock_provider.input_to_string(input_data) == "Hello, world!" + + +def test_input_to_string_with_list_of_text_messages(mock_provider): + input_data = [ + {"content": "Hello"}, + {"content": " world!"}, + ] + assert mock_provider.input_to_string(input_data) == "Hello world!" + + +def test_input_to_string_with_list_of_text_and_url(mock_provider): + input_data = [ + {"role": "user", "content": [{"type": "text", "text": "Hello "}]}, + {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}}]}, + {"role": "user", "content": [{"type": "text", "text": " world!"}]}, + ] + expected_output = "Hello http://example.com/image.jpg world!" 
+ assert mock_provider.input_to_string(input_data) == expected_output + + +def test_input_to_string_with_mixed_roles_and_missing_content(mock_provider): + input_data = [ + {"role": "assistant", "content": "Admin text;"}, + {"role": "user", "content": [{"type": "text", "text": "User text"}]}, + {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "http://example.com/another.jpg"}}]}, + ] + expected_output = "Admin text;User texthttp://example.com/another.jpg" + assert mock_provider.input_to_string(input_data) == expected_output + + +def test_input_to_string_with_missing_content_key(mock_provider): + input_data = [ + {"role": "user"}, + {"role": "user", "content": [{"type": "text", "text": "Hello again"}]}, + ] + expected_output = "Hello again" + assert mock_provider.input_to_string(input_data) == expected_output + + +def test_input_to_string_with_empty_list(mock_provider): + input_data = [] + assert mock_provider.input_to_string(input_data) == "" + From c9c6d0fc4afad1323d04a3c688682309dddd5524 Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Thu, 14 Nov 2024 15:40:23 +0000 Subject: [PATCH 05/13] added unit tests and docstrings for chat and achat --- .../core/llmstudio_core/providers/provider.py | 87 ++++++++++++++++++- libs/core/tests/unit_tests/conftest.py | 8 -- libs/core/tests/unit_tests/test_provider.py | 75 ++++++++++++---- 3 files changed, 141 insertions(+), 29 deletions(-) diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py index 98cc93b4..c3d4e23a 100644 --- a/libs/core/llmstudio_core/providers/provider.py +++ b/libs/core/llmstudio_core/providers/provider.py @@ -147,8 +147,38 @@ async def achat( parameters: Optional[dict] = {}, **kwargs, ): + """ + Asynchronously establishes a chat connection with the provider’s API, handling retries, + request validation, and streaming response options. + + Parameters + ---------- + chat_input : Any + The input data for the chat request, such as a string or dictionary, to be sent to the API. + model : str + The identifier of the model to be used for the chat request. + is_stream : Optional[bool], default=False + Flag to indicate if the response should be streamed. If True, returns an async generator + for streaming content; otherwise, returns the complete response. + retries : Optional[int], default=0 + Number of retry attempts on error. Retries will be attempted for specific HTTP errors like rate limits. + parameters : Optional[dict], default={} + Additional configuration parameters for the request, such as temperature or max tokens. + **kwargs + Additional keyword arguments to customize the request. - """Makes a chat connection with the provider's API""" + Returns + ------- + Union[AsyncGenerator, Any] + - If `is_stream` is True, returns an async generator yielding response chunks. + - If `is_stream` is False, returns the first complete response chunk. + + Raises + ------ + ProviderError + - Raised if the request validation fails or if all retry attempts are exhausted. + - Also raised for unexpected exceptions during request handling. + """ try: request = self.validate_request( dict( @@ -193,8 +223,38 @@ def chat( parameters: Optional[dict] = {}, **kwargs, ): + """ + Establishes a chat connection with the provider’s API, handling retries, request validation, + and streaming response options. + + Parameters + ---------- + chat_input : Any + The input data for the chat request, often a string or dictionary, to be sent to the API. 
+ model : str + The model identifier for selecting the model used in the chat request. + is_stream : Optional[bool], default=False + Flag to indicate if the response should be streamed. If True, the function returns a generator + for streaming content. Otherwise, it returns the complete response. + retries : Optional[int], default=0 + Number of retry attempts on error. Retries will be attempted on specific HTTP errors like rate limits. + parameters : Optional[dict], default={} + Additional configuration parameters for the request, such as temperature or max tokens. + **kwargs + Additional keyword arguments that can be passed to customize the request. + + Returns + ------- + Union[Generator, Any] + - If `is_stream` is True, returns a generator that yields chunks of the response. + - If `is_stream` is False, returns the first complete response chunk. - """Makes a chat connection with the provider's API""" + Raises + ------ + ProviderError + - Raised if the request validation fails or if the request fails after the specified number of retries. + - Also raised on other unexpected exceptions during request handling. + """ try: request = self.validate_request( dict( @@ -233,7 +293,28 @@ def chat( async def ahandle_response( self, request: ChatRequest, response: AsyncGenerator, start_time: float ) -> AsyncGenerator[str, None]: - """Handles the response from an API""" + """ + Asynchronously handles the response from an API, processing response chunks for either + streaming or non-streaming responses. + + Buffers response chunks for non-streaming responses to output one single message. For streaming responses sends incremental chunks. + + Parameters + ---------- + request : ChatRequest + The chat request object, which includes input data, model name, and streaming options. + response : AsyncGenerator + The async generator yielding response chunks from the API. + start_time : float + The timestamp when the response handling started, used for latency calculations. + + Yields + ------ + Union[ChatCompletionChunk, ChatCompletion] + - If `request.is_stream` is True, yields `ChatCompletionChunk` objects with incremental + response chunks for streaming. + - If `request.is_stream` is False, yields a final `ChatCompletion` object after processing all chunks. 
+ """ first_token_time = None previous_token_time = None token_times = [] diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 1a261eac..955f9276 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -11,14 +11,6 @@ async def aparse_response(self, response, **kwargs): def parse_response(self, response, **kwargs): return response - def chat(self, chat_input, model, **kwargs): - # Mock the response to match expected structure - return MagicMock(choices=[MagicMock(finish_reason="stop")]) - - async def achat(self, chat_input, model, **kwargs): - # Mock the response to match expected structure - return MagicMock(choices=[MagicMock(finish_reason="stop")]) - def output_to_string(self, output): # Handle string inputs if isinstance(output, str): diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 408359cd..6c00db12 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -1,30 +1,44 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from llmstudio_core.providers.provider import ChatRequest, ProviderError, ChatCompletion, time +from llmstudio_core.providers.provider import ChatRequest, ProviderError, ChatCompletion, time, ChatCompletionChunk -request = ChatRequest(chat_input="Hello", model="test_model") +request = ChatRequest(chat_input="Hello World", model="test_model") + +def test_chat_response_non_stream(mock_provider): + mock_provider.validate_request = MagicMock() + mock_provider.validate_model = MagicMock() + mock_provider.generate_client = MagicMock(return_value="mock_response") + mock_provider.handle_response = MagicMock(return_value="final_response") -def test_chat(mock_provider): - mock_provider.generate_client = MagicMock(return_value=MagicMock()) - mock_provider.handle_response = MagicMock(return_value=iter(["response"])) - - print(request.model_dump()) - response = mock_provider.chat(request.chat_input, request.model) - - assert response is not None + response = mock_provider.chat(chat_input="Hello", model="test_model") + assert response == "final_response" + mock_provider.validate_request.assert_called_once() + mock_provider.validate_model.assert_called_once() + +def test_chat_streaming_response(mock_provider): + mock_provider.validate_request = MagicMock() + mock_provider.validate_model = MagicMock() + mock_provider.generate_client = MagicMock(return_value="mock_response_stream") + mock_provider.handle_response = MagicMock(return_value=iter(["streamed_response_1", "streamed_response_2"])) + + response_stream = mock_provider.chat(chat_input="Hello", model="test_model", is_stream=True) + assert next(response_stream) == "streamed_response_1" + assert next(response_stream) == "streamed_response_2" + mock_provider.validate_request.assert_called_once() + mock_provider.validate_model.assert_called_once() -@pytest.mark.asyncio -async def test_achat(mock_provider): - mock_provider.agenerate_client = AsyncMock(return_value=AsyncMock()) - mock_provider.ahandle_response = AsyncMock(return_value=AsyncMock()) + +#@pytest.mark.asyncio +#async def test_achat_response_non_stream(mock_provider): +# pass - print(request.model_dump()) - response = await mock_provider.achat(request.chat_input, request.model) - assert response is not None +#@pytest.mark.asyncio +#async def test_achat_streaming_response(mock_provider): +# pass def test_validate_model(mock_provider): @@ -82,7 +96,32 @@ def 
test_calculate_metrics_single_token(mock_provider): assert metrics["time_to_first_token_s"] == 0.5 - 0.0 assert metrics["inter_token_latency_s"] == 0 assert metrics["tokens_per_second"] == 1 / 1.0 - + +@pytest.mark.asyncio +async def test_ahandle_response_non_streaming(mock_provider): + request = MagicMock(is_stream=False, chat_input="Hello", model="test_model", parameters={}) + response_chunk = { + "choices": [{"delta": {"content": "Non-streamed response"}, "finish_reason": "stop"}], + "model": "test_model", + } + start_time = time.time() + + async def mock_aparse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_aparse_response + mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Non-streamed response")) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): + response.append(chunk) + + assert isinstance(response[0], ChatCompletion) + assert response[0].choices == [] + assert response[0].chat_output == "Non-streamed response" + + def test_handle_response_stop(mock_provider): current_time = int(time.time()) From 64ad3e1c20c2d20c875fbff29a264a138185f437 Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Thu, 14 Nov 2024 15:40:44 +0000 Subject: [PATCH 06/13] added unit tests and docstrings for chat and achat --- libs/core/tests/unit_tests/test_provider.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 6c00db12..e61a1c02 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -30,17 +30,6 @@ def test_chat_streaming_response(mock_provider): mock_provider.validate_request.assert_called_once() mock_provider.validate_model.assert_called_once() - -#@pytest.mark.asyncio -#async def test_achat_response_non_stream(mock_provider): -# pass - - -#@pytest.mark.asyncio -#async def test_achat_streaming_response(mock_provider): -# pass - - def test_validate_model(mock_provider): request = ChatRequest(chat_input="Hello", model="test_model") mock_provider.validate_model(request) # Should not raise From 0440f09320675f26cb93afd1801551ec4af39ba4 Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Fri, 15 Nov 2024 20:07:40 +0000 Subject: [PATCH 07/13] chore: cleaned provider unit tests --- libs/core/tests/unit_tests/test_provider.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index e61a1c02..5b1d4e5b 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -39,8 +39,6 @@ def test_validate_model(mock_provider): mock_provider.validate_model(request_invalid) def test_calculate_metrics(mock_provider): - - mock_provider.tokenizer.encode = lambda x: x.split() # Assuming tokenizer splits "Hello" and "World" into one token each metrics = mock_provider.calculate_metrics( input="Hello", @@ -63,8 +61,6 @@ def test_calculate_metrics(mock_provider): assert metrics["tokens_per_second"] == 2 / 1.0 # token_count / total_time def test_calculate_metrics_single_token(mock_provider): - - mock_provider.tokenizer.encode = lambda x: x.split() metrics = mock_provider.calculate_metrics( input="Hello", @@ -150,9 +146,6 @@ def 
test_handle_response_stop(mock_provider): assert result.choices[0].message.content == "Hello, world!" def test_handle_response_stop_single_token(mock_provider): - """ - testing single token answer. token_times var will be 0 - """ current_time = int(time.time()) From 2d0580e3c78cff84891485ab5ca06b128312394f Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Tue, 17 Dec 2024 14:43:06 +0000 Subject: [PATCH 08/13] chore: separated provider tests into different files. fixed some of its tests --- libs/core/tests/unit_tests/test_provider.py | 213 +----------------- .../test_provider_costs_and_metrics.py | 119 ++++++++++ .../test_provider_handle_response.py | 152 +++++++++++++ 3 files changed, 273 insertions(+), 211 deletions(-) create mode 100644 libs/core/tests/unit_tests/test_provider_costs_and_metrics.py create mode 100644 libs/core/tests/unit_tests/test_provider_handle_response.py diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index 5b1d4e5b..dec8f567 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -1,10 +1,9 @@ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest -from llmstudio_core.providers.provider import ChatRequest, ProviderError, ChatCompletion, time, ChatCompletionChunk +from llmstudio_core.providers.provider import ChatRequest, ProviderError, ChatCompletion, time request = ChatRequest(chat_input="Hello World", model="test_model") - def test_chat_response_non_stream(mock_provider): mock_provider.validate_request = MagicMock() @@ -37,139 +36,6 @@ def test_validate_model(mock_provider): request_invalid = ChatRequest(chat_input="Hello", model="invalid_model") with pytest.raises(ProviderError): mock_provider.validate_model(request_invalid) - -def test_calculate_metrics(mock_provider): - - metrics = mock_provider.calculate_metrics( - input="Hello", - output="Hello World", - model="test_model", - start_time=0.0, - end_time=1.0, - first_token_time=0.5, - token_times=(0.1,), - token_count=2, - ) - - assert metrics["input_tokens"] == 1 - assert metrics["output_tokens"] == 2 - assert metrics["total_tokens"] == 3 - assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 2 # input_cost + output_cost - assert metrics["latency_s"] == 1.0 # end_time - start_time - assert metrics["time_to_first_token_s"] == 0.5 - 0.0 # first_token_time - start_time - assert metrics["inter_token_latency_s"] == 0.1 # Average of token_times - assert metrics["tokens_per_second"] == 2 / 1.0 # token_count / total_time - -def test_calculate_metrics_single_token(mock_provider): - - metrics = mock_provider.calculate_metrics( - input="Hello", - output="World", - model="test_model", - start_time=0.0, - end_time=1.0, - first_token_time=0.5, - token_times=(), - token_count=1, - ) - - assert metrics["input_tokens"] == 1 - assert metrics["output_tokens"] == 1 - assert metrics["total_tokens"] == 2 - assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 1 - assert metrics["latency_s"] == 1.0 - assert metrics["time_to_first_token_s"] == 0.5 - 0.0 - assert metrics["inter_token_latency_s"] == 0 - assert metrics["tokens_per_second"] == 1 / 1.0 - -@pytest.mark.asyncio -async def test_ahandle_response_non_streaming(mock_provider): - request = MagicMock(is_stream=False, chat_input="Hello", model="test_model", parameters={}) - response_chunk = { - "choices": [{"delta": {"content": "Non-streamed response"}, "finish_reason": "stop"}], - "model": "test_model", - } - start_time = time.time() - - 
async def mock_aparse_response(*args, **kwargs): - yield response_chunk - - mock_provider.aparse_response = mock_aparse_response - mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Non-streamed response")) - mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) - - response = [] - async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): - response.append(chunk) - - assert isinstance(response[0], ChatCompletion) - assert response[0].choices == [] - assert response[0].chat_output == "Non-streamed response" - - -def test_handle_response_stop(mock_provider): - - current_time = int(time.time()) - - response_generator = iter([ - { - "id": "test_id", - "model": "test_model", - "created": current_time, - "choices": [ - { - "delta": {"content": "Hello, "}, - "finish_reason": None, - "index": 0, - } - ], - }, - { - "id": "test_id", - "model": "test_model", - "created": current_time, - "choices": [ - { - "delta": {"content": "world!"}, - "finish_reason": "stop", - "index": 0, - } - ], - }, - ]) - - request = ChatRequest(chat_input="Hello", model="test_model") - result_generator = mock_provider.handle_response(request, response_generator, start_time=time.time()) - result = next(result_generator) - - assert isinstance(result, ChatCompletion) - assert result.choices[0].message.content == "Hello, world!" - -def test_handle_response_stop_single_token(mock_provider): - - current_time = int(time.time()) - - response_generator = iter([ - { - "id": "test_id", - "model": "test_model", - "created": current_time, - "choices": [ - { - "delta": {"content": "Hello, "}, - "finish_reason": "stop", - "index": 0, - } - ], - } - ]) - - request = ChatRequest(chat_input="Hello", model="test_model") - result_generator = mock_provider.handle_response(request, response_generator, start_time=time.time()) - result = next(result_generator) - - assert isinstance(result, ChatCompletion) - assert result.choices[0].message.content == "Hello, " def test_join_chunks_finish_reason_stop(mock_provider): current_time = int(time.time()) @@ -313,81 +179,6 @@ def test_join_chunks_tool_calls(mock_provider): assert tool_call.function.arguments == "{\"query\": \"weather details\"}" assert tool_call.type == "function" - -def test_calculate_cost_fixed_cost(mock_provider): - fixed_cost = 0.02 - token_count = 100 - expected_cost = token_count * fixed_cost - assert mock_provider.calculate_cost(token_count, fixed_cost) == expected_cost - -def test_calculate_cost_variable_cost(mock_provider): - cost_range_1 = MagicMock() - cost_range_1.range = (0, 50) - cost_range_1.cost = 0.01 - - cost_range_2 = MagicMock() - cost_range_2.range = (51, 100) - cost_range_2.cost = 0.02 - - variable_cost = [cost_range_1, cost_range_2] - token_count = 75 - expected_cost = token_count * 0.02 - assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost - -def test_calculate_cost_variable_cost_higher_range(mock_provider): - cost_range_1 = MagicMock() - cost_range_1.range = (0, 50) - cost_range_1.cost = 0.01 - - cost_range_2 = MagicMock() - cost_range_2.range = (51, 100) - cost_range_2.cost = 0.02 - - cost_range_3 = MagicMock() - cost_range_3.range = (101, None) - cost_range_3.cost = 0.03 - - variable_cost = [cost_range_1, cost_range_2, cost_range_3] - token_count = 150 - expected_cost = token_count * 0.03 - assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost - -def 
test_calculate_cost_variable_cost_no_matching_range(mock_provider): - cost_range_1 = MagicMock() - cost_range_1.range = (0, 50) - cost_range_1.cost = 0.01 - - cost_range_2 = MagicMock() - cost_range_2.range = (51, 100) - cost_range_2.cost = 0.02 - - cost_range_3 = MagicMock() - cost_range_3.range = (101, 150) - cost_range_3.cost = 0.03 - - variable_cost = [cost_range_1, cost_range_2, cost_range_3] - token_count = 200 - expected_cost = 0 - assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost - -def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): - cost_range_1 = MagicMock() - cost_range_1.range = (10, 50) - cost_range_1.cost = 0.01 - - cost_range_2 = MagicMock() - cost_range_2.range = (51, 100) - cost_range_2.cost = 0.02 - - cost_range_3 = MagicMock() - cost_range_3.range = (101, 150) - cost_range_3.cost = 0.03 - - variable_cost = [cost_range_1, cost_range_2, cost_range_3] - token_count = 5 - expected_cost = 0 - assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost - def test_input_to_string_with_string(mock_provider): input_data = "Hello, world!" assert mock_provider.input_to_string(input_data) == "Hello, world!" diff --git a/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py b/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py new file mode 100644 index 00000000..9cf0d52f --- /dev/null +++ b/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py @@ -0,0 +1,119 @@ +from unittest.mock import MagicMock + +def test_calculate_metrics(mock_provider): + + metrics = mock_provider.calculate_metrics( + input="Hello", + output="Hello World", + model="test_model", + start_time=0.0, + end_time=1.0, + first_token_time=0.5, + token_times=(0.1,), + token_count=2, + ) + + assert metrics["input_tokens"] == 1 + assert metrics["output_tokens"] == 2 + assert metrics["total_tokens"] == 3 + assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 2 # input_cost + output_cost + assert metrics["latency_s"] == 1.0 # end_time - start_time + assert metrics["time_to_first_token_s"] == 0.5 - 0.0 # first_token_time - start_time + assert metrics["inter_token_latency_s"] == 0.1 # Average of token_times + assert metrics["tokens_per_second"] == 2 / 1.0 # token_count / total_time + +def test_calculate_metrics_single_token(mock_provider): + + metrics = mock_provider.calculate_metrics( + input="Hello", + output="World", + model="test_model", + start_time=0.0, + end_time=1.0, + first_token_time=0.5, + token_times=(), + token_count=1, + ) + + assert metrics["input_tokens"] == 1 + assert metrics["output_tokens"] == 1 + assert metrics["total_tokens"] == 2 + assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 1 + assert metrics["latency_s"] == 1.0 + assert metrics["time_to_first_token_s"] == 0.5 - 0.0 + assert metrics["inter_token_latency_s"] == 0 + assert metrics["tokens_per_second"] == 1 / 1.0 + +def test_calculate_cost_fixed_cost(mock_provider): + fixed_cost = 0.02 + token_count = 100 + expected_cost = token_count * fixed_cost + assert mock_provider.calculate_cost(token_count, fixed_cost) == expected_cost + +def test_calculate_cost_variable_cost(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + variable_cost = [cost_range_1, cost_range_2] + token_count = 75 + expected_cost = token_count * 0.02 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + +def 
test_calculate_cost_variable_cost_higher_range(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, None) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 150 + expected_cost = token_count * 0.03 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + +def test_calculate_cost_variable_cost_no_matching_range(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (0, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, 150) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 200 + expected_cost = 0 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + +def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): + cost_range_1 = MagicMock() + cost_range_1.range = (10, 50) + cost_range_1.cost = 0.01 + + cost_range_2 = MagicMock() + cost_range_2.range = (51, 100) + cost_range_2.cost = 0.02 + + cost_range_3 = MagicMock() + cost_range_3.range = (101, 150) + cost_range_3.cost = 0.03 + + variable_cost = [cost_range_1, cost_range_2, cost_range_3] + token_count = 5 + expected_cost = 0 + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost \ No newline at end of file diff --git a/libs/core/tests/unit_tests/test_provider_handle_response.py b/libs/core/tests/unit_tests/test_provider_handle_response.py new file mode 100644 index 00000000..7e6cc5b1 --- /dev/null +++ b/libs/core/tests/unit_tests/test_provider_handle_response.py @@ -0,0 +1,152 @@ +from unittest.mock import MagicMock + +import pytest +from llmstudio_core.providers.provider import ChatCompletion, time, ChatCompletionChunk + +@pytest.mark.asyncio +async def test_ahandle_response_non_streaming(mock_provider): + request = MagicMock(is_stream=False, chat_input="Hello", model="test_model", parameters={}) + response_chunk = { + "choices": [{"delta": {"content": "Non-streamed response"}, "finish_reason": "stop", "index":0}], + "model": "test_model", + } + start_time = time.time() + + async def mock_aparse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_aparse_response + mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Non-streamed response")) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): + response.append(chunk) + + assert isinstance(response[0], ChatCompletion) + assert response[0].choices == [] + assert response[0].chat_output == "Non-streamed response" + +@pytest.mark.asyncio +async def test_ahandle_response_streaming_length(mock_provider): + request = MagicMock(is_stream=True, chat_input="Hello", model="test_model", parameters={}) + response_chunk = { + "choices": [{"delta": {"content": "Streamed response"}, "finish_reason": "length", "index":0}], + "model": "test_model", + "object":"chat.completion.chunk", + "created":0 + } + start_time = time.time() + + async def mock_aparse_response(*args, **kwargs): + 
yield response_chunk + + mock_provider.aparse_response = mock_aparse_response + mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Streamed response")) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): + response.append(chunk) + + assert isinstance(response[0], ChatCompletionChunk) + assert response[0].chat_output_stream == "Streamed response" + +@pytest.mark.asyncio +async def test_ahandle_response_streaming_stop(mock_provider): + request = MagicMock(is_stream=True, chat_input="Hello", model="test_model", parameters={}) + response_chunk = { + "choices": [{"delta": {"content": "Streamed response"}, "finish_reason": "stop", "index":0}], + "model": "test_model", + "object":"chat.completion.chunk", + "created":0 + } + start_time = time.time() + + async def mock_aparse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_aparse_response + mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Streamed response")) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): + response.append(chunk) + + assert isinstance(response[0], ChatCompletionChunk) + assert response[0].chat_output == "Streamed response" + #assert response[0].chat_output_stream == "Streamed response" + + +def test_handle_response_non_streaming(mock_provider): + request = MagicMock(is_stream=False, chat_input="Hello", model="test_model", parameters={}) + response_chunk = { + "choices": [{"delta": {"content": "Non-streamed response"}, "finish_reason": "stop", "index":0}], + "model": "test_model", + } + start_time = time.time() + + def mock_parse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_parse_response + mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Non-streamed response")) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + for chunk in mock_provider.handle_response(request, mock_parse_response(), start_time): + response.append(chunk) + + assert isinstance(response[0], ChatCompletion) + assert response[0].choices == [] + assert response[0].chat_output == "Non-streamed response" + +def test_handle_response_streaming_length(mock_provider): + request = MagicMock(is_stream=True, chat_input="Hello", model="test_model", parameters={}) + response_chunk = { + "choices": [{"delta": {"content": "Streamed response"}, "finish_reason": "length", "index":0}], + "model": "test_model", + "object":"chat.completion.chunk", + "created":0 + } + start_time = time.time() + + def mock_parse_response(*args, **kwargs): + yield response_chunk + + mock_provider.aparse_response = mock_parse_response + mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Streamed response")) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + for chunk in mock_provider.handle_response(request, mock_parse_response(), start_time): + 
response.append(chunk) + + assert isinstance(response[0], ChatCompletionChunk) + assert response[0].chat_output_stream == "Streamed response" + +def test_handle_response_streaming_stop(mock_provider): + request = MagicMock(is_stream=True, chat_input="Hello", model="test_model", parameters={}) + response_chunk = { + "choices": [{"delta": {"content": "Streamed response"}, "finish_reason": "stop", "index":0}], + "model": "test_model", + "object":"chat.completion.chunk", + "created":0 + } + start_time = time.time() + + def mock_parse_response(*args, **kwargs): + yield response_chunk + + mock_provider.parse_response = mock_parse_response + mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Streamed response")) + mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) + + response = [] + for chunk in mock_provider.handle_response(request, mock_parse_response(), start_time): + response.append(chunk) + + assert isinstance(response[0], ChatCompletionChunk) + assert response[0].chat_output == "Streamed response" + #assert response[0].chat_output_stream == "Streamed response" \ No newline at end of file From 0683bf4ff68bc5c34f4fecc7c1d9a8f02a4dabcc Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Wed, 18 Dec 2024 10:11:28 +0000 Subject: [PATCH 09/13] chore: linted code --- .gitignore | 3 +- .../core/llmstudio_core/providers/provider.py | 54 ++--- libs/core/tests/unit_tests/conftest.py | 4 +- libs/core/tests/unit_tests/test_provider.py | 89 +++++--- .../test_provider_costs_and_metrics.py | 19 +- .../test_provider_handle_response.py | 209 ++++++++++++++---- 6 files changed, 280 insertions(+), 98 deletions(-) diff --git a/.gitignore b/.gitignore index 19015866..456aa8d8 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,7 @@ env3 .env* .env*.local .venv* +*venv* env*/ venv*/ ENV/ @@ -76,4 +77,4 @@ bun.lockb llmstudio/llm_engine/logs/execution_logs.jsonl *.db .prettierignore -db \ No newline at end of file +db diff --git a/libs/core/llmstudio_core/providers/provider.py b/libs/core/llmstudio_core/providers/provider.py index c3d4e23a..eabb25b7 100644 --- a/libs/core/llmstudio_core/providers/provider.py +++ b/libs/core/llmstudio_core/providers/provider.py @@ -148,7 +148,7 @@ async def achat( **kwargs, ): """ - Asynchronously establishes a chat connection with the provider’s API, handling retries, + Asynchronously establishes a chat connection with the provider’s API, handling retries, request validation, and streaming response options. Parameters @@ -158,7 +158,7 @@ async def achat( model : str The identifier of the model to be used for the chat request. is_stream : Optional[bool], default=False - Flag to indicate if the response should be streamed. If True, returns an async generator + Flag to indicate if the response should be streamed. If True, returns an async generator for streaming content; otherwise, returns the complete response. retries : Optional[int], default=0 Number of retry attempts on error. Retries will be attempted for specific HTTP errors like rate limits. @@ -224,7 +224,7 @@ def chat( **kwargs, ): """ - Establishes a chat connection with the provider’s API, handling retries, request validation, + Establishes a chat connection with the provider’s API, handling retries, request validation, and streaming response options. Parameters @@ -234,7 +234,7 @@ def chat( model : str The model identifier for selecting the model used in the chat request. 
is_stream : Optional[bool], default=False - Flag to indicate if the response should be streamed. If True, the function returns a generator + Flag to indicate if the response should be streamed. If True, the function returns a generator for streaming content. Otherwise, it returns the complete response. retries : Optional[int], default=0 Number of retry attempts on error. Retries will be attempted on specific HTTP errors like rate limits. @@ -294,9 +294,9 @@ async def ahandle_response( self, request: ChatRequest, response: AsyncGenerator, start_time: float ) -> AsyncGenerator[str, None]: """ - Asynchronously handles the response from an API, processing response chunks for either + Asynchronously handles the response from an API, processing response chunks for either streaming or non-streaming responses. - + Buffers response chunks for non-streaming responses to output one single message. For streaming responses sends incremental chunks. Parameters @@ -311,7 +311,7 @@ async def ahandle_response( Yields ------ Union[ChatCompletionChunk, ChatCompletion] - - If `request.is_stream` is True, yields `ChatCompletionChunk` objects with incremental + - If `request.is_stream` is True, yields `ChatCompletionChunk` objects with incremental response chunks for streaming. - If `request.is_stream` is False, yields a final `ChatCompletion` object after processing all chunks. """ @@ -423,10 +423,10 @@ def handle_response( self, request: ChatRequest, response: Generator, start_time: float ) -> Generator: """ - Processes API response chunks to build a structured, complete response, yielding + Processes API response chunks to build a structured, complete response, yielding each chunk if streaming is enabled. - - If streaming, each chunk is yielded as soon as it’s processed. Otherwise, all chunks + + If streaming, each chunk is yielded as soon as it’s processed. Otherwise, all chunks are combined and yielded as a single response at the end. Parameters @@ -551,8 +551,8 @@ def handle_response( def join_chunks(self, chunks): """ - Combine multiple response chunks from the model into a single, structured response. - Handles tool calls, function calls, and standard text completion based on the + Combine multiple response chunks from the model into a single, structured response. + Handles tool calls, function calls, and standard text completion based on the purpose indicated by the final chunk. Parameters @@ -563,7 +563,7 @@ def join_chunks(self, chunks): Returns ------- Tuple[ChatCompletion, str] - - `ChatCompletion`: The structured response based on the type of completion + - `ChatCompletion`: The structured response based on the type of completion (tool calls, function call, or text). - `str`: The concatenated content or arguments, depending on the completion type. @@ -733,7 +733,7 @@ def calculate_metrics( token_count: int, ) -> Dict[str, Any]: """ - Calculates performance and cost metrics for a model response based on timing + Calculates performance and cost metrics for a model response based on timing information, token counts, and model-specific costs. 
Parameters @@ -783,17 +783,21 @@ def calculate_metrics( "cost_usd": input_cost + output_cost, "latency_s": total_time, "time_to_first_token_s": first_token_time - start_time, - "inter_token_latency_s": sum(token_times) / len(token_times) if token_times else 0, - "tokens_per_second": token_count / total_time if token_times else 1 / total_time, + "inter_token_latency_s": sum(token_times) / len(token_times) + if token_times + else 0, + "tokens_per_second": token_count / total_time + if token_times + else 1 / total_time, } def calculate_cost( self, token_count: int, token_cost: Union[float, List[Dict[str, Any]]] ) -> float: """ - Calculates the cost for a given number of tokens based on a fixed cost per token + Calculates the cost for a given number of tokens based on a fixed cost per token or a variable rate structure. - + If `token_cost` is a fixed float, the total cost is `token_count * token_cost`. If `token_cost` is a list, it checks each range and calculates cost based on the applicable range's rate. @@ -802,14 +806,14 @@ def calculate_cost( token_count : int The total number of tokens for which the cost is being calculated. token_cost : Union[float, List[Dict[str, Any]]] - Either a fixed cost per token (as a float) or a list of dictionaries defining - variable cost ranges. Each dictionary in the list represents a range with + Either a fixed cost per token (as a float) or a list of dictionaries defining + variable cost ranges. Each dictionary in the list represents a range with 'range' (a tuple of minimum and maximum token counts) and 'cost' (cost per token) keys. Returns ------- float - The calculated cost based on the token count and cost structure. + The calculated cost based on the token count and cost structure. """ if isinstance(token_cost, list): for cost_range in token_cost: @@ -830,13 +834,13 @@ def input_to_string(self, input): input : Any The input data to be converted. This can be: - A simple string, which is returned as-is. - - A list of message dictionaries, where each dictionary may contain `content`, `role`, + - A list of message dictionaries, where each dictionary may contain `content`, `role`, and nested items like `text` or `image_url`. Returns ------- str - A concatenated string representing the text content of all messages, + A concatenated string representing the text content of all messages, including text and URLs from image content if present. """ if isinstance(input, str): @@ -861,13 +865,13 @@ def input_to_string(self, input): def output_to_string(self, output): """ - Extracts and returns the content or arguments from the output based on + Extracts and returns the content or arguments from the output based on the `finish_reason` of the first choice in `output`. Parameters ---------- output : Any - The model output object, expected to have a `choices` attribute that should contain a `finish_reason` indicating the type of output + The model output object, expected to have a `choices` attribute that should contain a `finish_reason` indicating the type of output ("stop", "tool_calls", or "function_call") and corresponding content or arguments. 
Returns diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 955f9276..df8e04cc 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -18,7 +18,7 @@ def output_to_string(self, output): if output.choices[0].finish_reason == "stop": return output.choices[0].message.content return "" - + def validate_request(self, request): # For testing, simply return the request return request @@ -27,12 +27,14 @@ async def agenerate_client(self, request): # For testing, return an async generator async def async_gen(): yield {} + return async_gen() def generate_client(self, request): # For testing, return a generator def gen(): yield {} + return gen() @staticmethod diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index dec8f567..a3a2288e 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -1,10 +1,15 @@ from unittest.mock import MagicMock import pytest -from llmstudio_core.providers.provider import ChatRequest, ProviderError, ChatCompletion, time +from llmstudio_core.providers.provider import ( + ChatRequest, + ProviderError, + time, +) request = ChatRequest(chat_input="Hello World", model="test_model") - + + def test_chat_response_non_stream(mock_provider): mock_provider.validate_request = MagicMock() mock_provider.validate_model = MagicMock() @@ -16,19 +21,25 @@ def test_chat_response_non_stream(mock_provider): assert response == "final_response" mock_provider.validate_request.assert_called_once() mock_provider.validate_model.assert_called_once() - + + def test_chat_streaming_response(mock_provider): mock_provider.validate_request = MagicMock() mock_provider.validate_model = MagicMock() mock_provider.generate_client = MagicMock(return_value="mock_response_stream") - mock_provider.handle_response = MagicMock(return_value=iter(["streamed_response_1", "streamed_response_2"])) + mock_provider.handle_response = MagicMock( + return_value=iter(["streamed_response_1", "streamed_response_2"]) + ) - response_stream = mock_provider.chat(chat_input="Hello", model="test_model", is_stream=True) + response_stream = mock_provider.chat( + chat_input="Hello", model="test_model", is_stream=True + ) assert next(response_stream) == "streamed_response_1" assert next(response_stream) == "streamed_response_2" mock_provider.validate_request.assert_called_once() mock_provider.validate_model.assert_called_once() + def test_validate_model(mock_provider): request = ChatRequest(chat_input="Hello", model="test_model") mock_provider.validate_model(request) # Should not raise @@ -36,7 +47,8 @@ def test_validate_model(mock_provider): request_invalid = ChatRequest(chat_input="Hello", model="invalid_model") with pytest.raises(ProviderError): mock_provider.validate_model(request_invalid) - + + def test_join_chunks_finish_reason_stop(mock_provider): current_time = int(time.time()) chunks = [ @@ -70,6 +82,7 @@ def test_join_chunks_finish_reason_stop(mock_provider): assert output_string == "Hello, world!" assert response.choices[0].message.content == "Hello, world!" 
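# Editor's aside: a minimal sketch, not llmstudio's actual join_chunks implementation,
# of the behaviour the join_chunks tests in this file exercise. Streamed delta
# contents are concatenated in order, and the finish_reason of the final chunk
# decides how the combined ChatCompletion is assembled. The helper name and the
# chunk shape handled here are illustrative assumptions only.
def join_text_chunks_sketch(chunks: list) -> str:
    """Concatenate the delta contents of streamed chat-completion chunks."""
    return "".join(
        choice["delta"].get("content") or ""
        for chunk in chunks
        for choice in chunk["choices"]
    )

# Applied to the chunks used in test_join_chunks_finish_reason_stop, this sketch
# would likewise yield "Hello, world!", matching the output_string asserted above.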
+ def test_join_chunks_finish_reason_function_call(mock_provider): current_time = int(time.time()) chunks = [ @@ -79,7 +92,9 @@ def test_join_chunks_finish_reason_function_call(mock_provider): "created": current_time, "choices": [ { - "delta": {"function_call": {"name": "my_function", "arguments": "arg1"}}, + "delta": { + "function_call": {"name": "my_function", "arguments": "arg1"} + }, "finish_reason": None, "index": 0, } @@ -115,11 +130,11 @@ def test_join_chunks_finish_reason_function_call(mock_provider): assert output_string == "arg1arg2" assert response.choices[0].message.function_call.arguments == "arg1arg2" assert response.choices[0].message.function_call.name == "my_function" - - + + def test_join_chunks_tool_calls(mock_provider): current_time = int(time.time()) - + chunks = [ { "id": "test_id_1", @@ -132,15 +147,18 @@ def test_join_chunks_tool_calls(mock_provider): { "id": "tool_1", "index": 0, - "function": {"name": "search_tool", "arguments": "{\"query\": \"weather"}, - "type": "function" + "function": { + "name": "search_tool", + "arguments": '{"query": "weather', + }, + "type": "function", } ] }, "finish_reason": None, - "index": 0 + "index": 0, } - ] + ], }, { "id": "test_id_2", @@ -153,22 +171,23 @@ def test_join_chunks_tool_calls(mock_provider): { "id": "tool_1", "index": 0, - "function": {"name": "search_tool", "arguments": " details\"}"} + "function": { + "name": "search_tool", + "arguments": ' details"}', + }, } ] }, "finish_reason": "tool_calls", - "index": 0 + "index": 0, } - ] - } + ], + }, ] response, output_string = mock_provider.join_chunks(chunks) - - assert output_string == "['search_tool', '{\"query\": \"weather details\"}']" - + assert output_string == "['search_tool', '{\"query\": \"weather details\"}']" assert response.object == "chat.completion" assert response.choices[0].finish_reason == "tool_calls" @@ -176,9 +195,10 @@ def test_join_chunks_tool_calls(mock_provider): assert tool_call.id == "tool_1" assert tool_call.function.name == "search_tool" - assert tool_call.function.arguments == "{\"query\": \"weather details\"}" + assert tool_call.function.arguments == '{"query": "weather details"}' assert tool_call.type == "function" - + + def test_input_to_string_with_string(mock_provider): input_data = "Hello, world!" assert mock_provider.input_to_string(input_data) == "Hello, world!" @@ -195,7 +215,15 @@ def test_input_to_string_with_list_of_text_messages(mock_provider): def test_input_to_string_with_list_of_text_and_url(mock_provider): input_data = [ {"role": "user", "content": [{"type": "text", "text": "Hello "}]}, - {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}}]}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "http://example.com/image.jpg"}, + } + ], + }, {"role": "user", "content": [{"type": "text", "text": " world!"}]}, ] expected_output = "Hello http://example.com/image.jpg world!" 
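# Editor's aside: a hedged sketch of the concatenation behaviour that
# test_input_to_string_with_list_of_text_and_url checks. This is not the
# provider's actual input_to_string code; the helper name and the exact
# message shapes handled are assumptions for illustration.
def input_to_string_sketch(messages) -> str:
    """Join text parts and image URLs from OpenAI-style messages, in order."""
    parts = []
    for message in messages:
        content = message.get("content")
        if isinstance(content, str):
            parts.append(content)
        elif isinstance(content, list):
            for item in content:
                if item.get("type") == "text":
                    parts.append(item.get("text", ""))
                elif item.get("type") == "image_url":
                    parts.append(item["image_url"]["url"])
    return "".join(parts)

# For the input_data above, this sketch also produces
# "Hello http://example.com/image.jpg world!".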
@@ -206,7 +234,15 @@ def test_input_to_string_with_mixed_roles_and_missing_content(mock_provider): input_data = [ {"role": "assistant", "content": "Admin text;"}, {"role": "user", "content": [{"type": "text", "text": "User text"}]}, - {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "http://example.com/another.jpg"}}]}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "http://example.com/another.jpg"}, + } + ], + }, ] expected_output = "Admin text;User texthttp://example.com/another.jpg" assert mock_provider.input_to_string(input_data) == expected_output @@ -217,11 +253,10 @@ def test_input_to_string_with_missing_content_key(mock_provider): {"role": "user"}, {"role": "user", "content": [{"type": "text", "text": "Hello again"}]}, ] - expected_output = "Hello again" + expected_output = "Hello again" assert mock_provider.input_to_string(input_data) == expected_output def test_input_to_string_with_empty_list(mock_provider): input_data = [] assert mock_provider.input_to_string(input_data) == "" - diff --git a/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py b/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py index 9cf0d52f..fb54d602 100644 --- a/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py +++ b/libs/core/tests/unit_tests/test_provider_costs_and_metrics.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock + def test_calculate_metrics(mock_provider): metrics = mock_provider.calculate_metrics( @@ -18,10 +19,13 @@ def test_calculate_metrics(mock_provider): assert metrics["total_tokens"] == 3 assert metrics["cost_usd"] == 0.01 * 1 + 0.02 * 2 # input_cost + output_cost assert metrics["latency_s"] == 1.0 # end_time - start_time - assert metrics["time_to_first_token_s"] == 0.5 - 0.0 # first_token_time - start_time + assert ( + metrics["time_to_first_token_s"] == 0.5 - 0.0 + ) # first_token_time - start_time assert metrics["inter_token_latency_s"] == 0.1 # Average of token_times assert metrics["tokens_per_second"] == 2 / 1.0 # token_count / total_time - + + def test_calculate_metrics_single_token(mock_provider): metrics = mock_provider.calculate_metrics( @@ -43,13 +47,15 @@ def test_calculate_metrics_single_token(mock_provider): assert metrics["time_to_first_token_s"] == 0.5 - 0.0 assert metrics["inter_token_latency_s"] == 0 assert metrics["tokens_per_second"] == 1 / 1.0 - + + def test_calculate_cost_fixed_cost(mock_provider): fixed_cost = 0.02 token_count = 100 expected_cost = token_count * fixed_cost assert mock_provider.calculate_cost(token_count, fixed_cost) == expected_cost + def test_calculate_cost_variable_cost(mock_provider): cost_range_1 = MagicMock() cost_range_1.range = (0, 50) @@ -64,6 +70,7 @@ def test_calculate_cost_variable_cost(mock_provider): expected_cost = token_count * 0.02 assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + def test_calculate_cost_variable_cost_higher_range(mock_provider): cost_range_1 = MagicMock() cost_range_1.range = (0, 50) @@ -82,6 +89,7 @@ def test_calculate_cost_variable_cost_higher_range(mock_provider): expected_cost = token_count * 0.03 assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost + def test_calculate_cost_variable_cost_no_matching_range(mock_provider): cost_range_1 = MagicMock() cost_range_1.range = (0, 50) @@ -99,7 +107,8 @@ def test_calculate_cost_variable_cost_no_matching_range(mock_provider): token_count = 200 expected_cost = 0 assert mock_provider.calculate_cost(token_count, 
variable_cost) == expected_cost - + + def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): cost_range_1 = MagicMock() cost_range_1.range = (10, 50) @@ -116,4 +125,4 @@ def test_calculate_cost_variable_cost_no_matching_range_inferior(mock_provider): variable_cost = [cost_range_1, cost_range_2, cost_range_3] token_count = 5 expected_cost = 0 - assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost \ No newline at end of file + assert mock_provider.calculate_cost(token_count, variable_cost) == expected_cost diff --git a/libs/core/tests/unit_tests/test_provider_handle_response.py b/libs/core/tests/unit_tests/test_provider_handle_response.py index 7e6cc5b1..e7107eeb 100644 --- a/libs/core/tests/unit_tests/test_provider_handle_response.py +++ b/libs/core/tests/unit_tests/test_provider_handle_response.py @@ -1,13 +1,22 @@ from unittest.mock import MagicMock import pytest -from llmstudio_core.providers.provider import ChatCompletion, time, ChatCompletionChunk +from llmstudio_core.providers.provider import ChatCompletion, ChatCompletionChunk, time + @pytest.mark.asyncio async def test_ahandle_response_non_streaming(mock_provider): - request = MagicMock(is_stream=False, chat_input="Hello", model="test_model", parameters={}) + request = MagicMock( + is_stream=False, chat_input="Hello", model="test_model", parameters={} + ) response_chunk = { - "choices": [{"delta": {"content": "Non-streamed response"}, "finish_reason": "stop", "index":0}], + "choices": [ + { + "delta": {"content": "Non-streamed response"}, + "finish_reason": "stop", + "index": 0, + } + ], "model": "test_model", } start_time = time.time() @@ -16,25 +25,47 @@ async def mock_aparse_response(*args, **kwargs): yield response_chunk mock_provider.aparse_response = mock_aparse_response - mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Non-streamed response")) + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Non-streamed response", + ) + ) mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) response = [] - async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): + async for chunk in mock_provider.ahandle_response( + request, mock_aparse_response(), start_time + ): response.append(chunk) assert isinstance(response[0], ChatCompletion) assert response[0].choices == [] assert response[0].chat_output == "Non-streamed response" + @pytest.mark.asyncio async def test_ahandle_response_streaming_length(mock_provider): - request = MagicMock(is_stream=True, chat_input="Hello", model="test_model", parameters={}) + request = MagicMock( + is_stream=True, chat_input="Hello", model="test_model", parameters={} + ) response_chunk = { - "choices": [{"delta": {"content": "Streamed response"}, "finish_reason": "length", "index":0}], + "choices": [ + { + "delta": {"content": "Streamed response"}, + "finish_reason": "length", + "index": 0, + } + ], "model": "test_model", - "object":"chat.completion.chunk", - "created":0 + "object": "chat.completion.chunk", + "created": 0, } start_time = time.time() @@ -42,24 +73,46 @@ async def mock_aparse_response(*args, **kwargs): yield response_chunk mock_provider.aparse_response = mock_aparse_response - mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], 
created=0, model="test_model", object="chat.completion"), "Streamed response")) + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Streamed response", + ) + ) mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) response = [] - async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): + async for chunk in mock_provider.ahandle_response( + request, mock_aparse_response(), start_time + ): response.append(chunk) assert isinstance(response[0], ChatCompletionChunk) assert response[0].chat_output_stream == "Streamed response" - + + @pytest.mark.asyncio async def test_ahandle_response_streaming_stop(mock_provider): - request = MagicMock(is_stream=True, chat_input="Hello", model="test_model", parameters={}) + request = MagicMock( + is_stream=True, chat_input="Hello", model="test_model", parameters={} + ) response_chunk = { - "choices": [{"delta": {"content": "Streamed response"}, "finish_reason": "stop", "index":0}], + "choices": [ + { + "delta": {"content": "Streamed response"}, + "finish_reason": "stop", + "index": 0, + } + ], "model": "test_model", - "object":"chat.completion.chunk", - "created":0 + "object": "chat.completion.chunk", + "created": 0, } start_time = time.time() @@ -67,22 +120,43 @@ async def mock_aparse_response(*args, **kwargs): yield response_chunk mock_provider.aparse_response = mock_aparse_response - mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Streamed response")) + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Streamed response", + ) + ) mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) response = [] - async for chunk in mock_provider.ahandle_response(request, mock_aparse_response(), start_time): + async for chunk in mock_provider.ahandle_response( + request, mock_aparse_response(), start_time + ): response.append(chunk) assert isinstance(response[0], ChatCompletionChunk) assert response[0].chat_output == "Streamed response" - #assert response[0].chat_output_stream == "Streamed response" - + # assert response[0].chat_output_stream == "Streamed response" + def test_handle_response_non_streaming(mock_provider): - request = MagicMock(is_stream=False, chat_input="Hello", model="test_model", parameters={}) + request = MagicMock( + is_stream=False, chat_input="Hello", model="test_model", parameters={} + ) response_chunk = { - "choices": [{"delta": {"content": "Non-streamed response"}, "finish_reason": "stop", "index":0}], + "choices": [ + { + "delta": {"content": "Non-streamed response"}, + "finish_reason": "stop", + "index": 0, + } + ], "model": "test_model", } start_time = time.time() @@ -91,24 +165,46 @@ def mock_parse_response(*args, **kwargs): yield response_chunk mock_provider.aparse_response = mock_parse_response - mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Non-streamed response")) + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Non-streamed response", + ) + ) mock_provider.calculate_metrics = 
MagicMock(return_value={"input_tokens": 1}) response = [] - for chunk in mock_provider.handle_response(request, mock_parse_response(), start_time): + for chunk in mock_provider.handle_response( + request, mock_parse_response(), start_time + ): response.append(chunk) assert isinstance(response[0], ChatCompletion) assert response[0].choices == [] assert response[0].chat_output == "Non-streamed response" - + + def test_handle_response_streaming_length(mock_provider): - request = MagicMock(is_stream=True, chat_input="Hello", model="test_model", parameters={}) + request = MagicMock( + is_stream=True, chat_input="Hello", model="test_model", parameters={} + ) response_chunk = { - "choices": [{"delta": {"content": "Streamed response"}, "finish_reason": "length", "index":0}], + "choices": [ + { + "delta": {"content": "Streamed response"}, + "finish_reason": "length", + "index": 0, + } + ], "model": "test_model", - "object":"chat.completion.chunk", - "created":0 + "object": "chat.completion.chunk", + "created": 0, } start_time = time.time() @@ -116,23 +212,45 @@ def mock_parse_response(*args, **kwargs): yield response_chunk mock_provider.aparse_response = mock_parse_response - mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Streamed response")) + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Streamed response", + ) + ) mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) response = [] - for chunk in mock_provider.handle_response(request, mock_parse_response(), start_time): + for chunk in mock_provider.handle_response( + request, mock_parse_response(), start_time + ): response.append(chunk) assert isinstance(response[0], ChatCompletionChunk) assert response[0].chat_output_stream == "Streamed response" - + + def test_handle_response_streaming_stop(mock_provider): - request = MagicMock(is_stream=True, chat_input="Hello", model="test_model", parameters={}) + request = MagicMock( + is_stream=True, chat_input="Hello", model="test_model", parameters={} + ) response_chunk = { - "choices": [{"delta": {"content": "Streamed response"}, "finish_reason": "stop", "index":0}], + "choices": [ + { + "delta": {"content": "Streamed response"}, + "finish_reason": "stop", + "index": 0, + } + ], "model": "test_model", - "object":"chat.completion.chunk", - "created":0 + "object": "chat.completion.chunk", + "created": 0, } start_time = time.time() @@ -140,13 +258,26 @@ def mock_parse_response(*args, **kwargs): yield response_chunk mock_provider.parse_response = mock_parse_response - mock_provider.join_chunks = MagicMock(return_value=(ChatCompletion(id="id", choices=[], created=0, model="test_model", object="chat.completion"), "Streamed response")) + mock_provider.join_chunks = MagicMock( + return_value=( + ChatCompletion( + id="id", + choices=[], + created=0, + model="test_model", + object="chat.completion", + ), + "Streamed response", + ) + ) mock_provider.calculate_metrics = MagicMock(return_value={"input_tokens": 1}) response = [] - for chunk in mock_provider.handle_response(request, mock_parse_response(), start_time): + for chunk in mock_provider.handle_response( + request, mock_parse_response(), start_time + ): response.append(chunk) assert isinstance(response[0], ChatCompletionChunk) assert response[0].chat_output == "Streamed response" - #assert 
response[0].chat_output_stream == "Streamed response" \ No newline at end of file + # assert response[0].chat_output_stream == "Streamed response" From 6d616338e553230fb01efbd97c1f8900b710be6e Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Wed, 18 Dec 2024 10:18:37 +0000 Subject: [PATCH 10/13] chore: deleted some comments --- libs/core/tests/unit_tests/test_provider_handle_response.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/libs/core/tests/unit_tests/test_provider_handle_response.py b/libs/core/tests/unit_tests/test_provider_handle_response.py index e7107eeb..ac04e32b 100644 --- a/libs/core/tests/unit_tests/test_provider_handle_response.py +++ b/libs/core/tests/unit_tests/test_provider_handle_response.py @@ -142,7 +142,6 @@ async def mock_aparse_response(*args, **kwargs): assert isinstance(response[0], ChatCompletionChunk) assert response[0].chat_output == "Streamed response" - # assert response[0].chat_output_stream == "Streamed response" def test_handle_response_non_streaming(mock_provider): @@ -280,4 +279,3 @@ def mock_parse_response(*args, **kwargs): assert isinstance(response[0], ChatCompletionChunk) assert response[0].chat_output == "Streamed response" - # assert response[0].chat_output_stream == "Streamed response" From 2c4208c553e241fbcf76875922729fb0b301ab1a Mon Sep 17 00:00:00 2001 From: Miguel Neves Date: Wed, 18 Dec 2024 10:24:47 +0000 Subject: [PATCH 11/13] chore: linted --- libs/core/tests/unit_tests/test_provider.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/libs/core/tests/unit_tests/test_provider.py b/libs/core/tests/unit_tests/test_provider.py index a3a2288e..42396c99 100644 --- a/libs/core/tests/unit_tests/test_provider.py +++ b/libs/core/tests/unit_tests/test_provider.py @@ -1,11 +1,7 @@ from unittest.mock import MagicMock import pytest -from llmstudio_core.providers.provider import ( - ChatRequest, - ProviderError, - time, -) +from llmstudio_core.providers.provider import ChatRequest, ProviderError, time request = ChatRequest(chat_input="Hello World", model="test_model") From 2a2e5be4cc45f403cfe420ae4840871017632916 Mon Sep 17 00:00:00 2001 From: Miguel Neves <61327611+MiNeves00@users.noreply.github.com> Date: Wed, 18 Dec 2024 10:29:16 +0000 Subject: [PATCH 12/13] chore: Added Azure Provider Unit Tests (#176) * chore: added unit tests for azure provider * chore: added more unit tests and docstrings on azure, removed redundant comments * chore: added unit tests for generate client on Azure Provider * chore: separated azure unit tests into separate files. fixed some of its tests. 
* chore: linted code --- .gitignore | 1 + Makefile | 3 + libs/core/llmstudio_core/providers/azure.py | 137 +++++++-- libs/core/tests/unit_tests/conftest.py | 25 ++ libs/core/tests/unit_tests/test_azure.py | 232 +++++++++++++++ .../core/tests/unit_tests/test_azure_build.py | 279 ++++++++++++++++++ 6 files changed, 658 insertions(+), 19 deletions(-) create mode 100644 libs/core/tests/unit_tests/test_azure.py create mode 100644 libs/core/tests/unit_tests/test_azure_build.py diff --git a/.gitignore b/.gitignore index 456aa8d8..d80305ef 100644 --- a/.gitignore +++ b/.gitignore @@ -67,6 +67,7 @@ venv.bak/ config.yaml bun.lockb + # Jupyter Notebook .ipynb_checkpoints diff --git a/Makefile b/Makefile index 5a43a3e6..b9329a81 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,5 @@ format: pre-commit run --all-files + +unit-tests: + pytest libs/core/tests/unit_tests \ No newline at end of file diff --git a/libs/core/llmstudio_core/providers/azure.py b/libs/core/llmstudio_core/providers/azure.py index 2dbd7307..f558f9d6 100644 --- a/libs/core/llmstudio_core/providers/azure.py +++ b/libs/core/llmstudio_core/providers/azure.py @@ -62,7 +62,25 @@ async def agenerate_client(self, request: ChatRequest) -> Any: return self.generate_client(request=request) def generate_client(self, request: ChatRequest) -> Any: - """Generate an AzureOpenAI client""" + """ + Generates an AzureOpenAI client for processing a chat request. + + This method prepares and configures the arguments required to create a client + request to AzureOpenAI's chat completions API. It determines model-specific + configurations (e.g., whether tools or functions are enabled) and combines + these with the base arguments for the API call. + + Args: + request (ChatRequest): The chat request object containing the model, + parameters, and other necessary details. + + Returns: + Any: The result of the chat completions API call. + + Raises: + ProviderError: If there is an issue with the API connection or an error + returned from the API. 
+ """ self.is_llama = "llama" in request.model.lower() self.is_openai = "gpt" in request.model.lower() @@ -72,7 +90,6 @@ def generate_client(self, request: ChatRequest) -> Any: try: messages = self.prepare_messages(request) - # Prepare the optional tool-related arguments tool_args = {} if not self.is_llama and self.has_tools and self.is_openai: tool_args = { @@ -80,7 +97,6 @@ def generate_client(self, request: ChatRequest) -> Any: "tool_choice": "auto" if request.parameters.get("tools") else None, } - # Prepare the optional function-related arguments function_args = {} if not self.is_llama and self.has_functions and self.is_openai: function_args = { @@ -90,14 +106,12 @@ def generate_client(self, request: ChatRequest) -> Any: else None, } - # Prepare the base arguments base_args = { "model": request.model, "messages": messages, "stream": True, } - # Combine all arguments combined_args = { **base_args, **tool_args, @@ -116,13 +130,13 @@ def prepare_messages(self, request: ChatRequest): if self.is_llama and (self.has_tools or self.has_functions): user_message = self.convert_to_openai_format(request.chat_input) content = "<|begin_of_text|>" - content = self.add_system_message( + content = self.build_llama_system_message( user_message, content, request.parameters.get("tools"), request.parameters.get("functions"), ) - content = self.add_conversation(user_message, content) + content = self.build_llama_conversation(user_message, content) return [{"role": "user", "content": content}] else: return ( @@ -139,6 +153,20 @@ async def aparse_response( yield chunk def parse_response(self, response: AsyncGenerator, **kwargs) -> Any: + """ + Processes a generator response and yields processed chunks. + + If `is_llama` is True and tools or functions are enabled, it processes the response + using `handle_tool_response`. Otherwise, it processes each chunk and yields only those + containing "choices". + + Args: + response (Generator): The response generator to process. + **kwargs: Additional arguments for tool handling. + + Yields: + Any: Processed response chunks. + """ if self.is_llama and (self.has_tools or self.has_functions): for chunk in self.handle_tool_response(response, **kwargs): if chunk: @@ -388,9 +416,25 @@ def convert_to_openai_format(self, message: Union[str, list]) -> list: return [{"role": "user", "content": message}] return message - def add_system_message( + def build_llama_system_message( self, openai_message: list, llama_message: str, tools: list, functions: list ) -> str: + """ + Builds a complete system message for Llama based on OpenAI's message, tools, and functions. + + If a system message is present in the OpenAI message, it is included in the result. + Otherwise, a default system message is used. Additional tool and function instructions + are appended if provided. + + Args: + openai_message (list): List of OpenAI messages. + llama_message (str): The message to prepend to the system message. + tools (list): List of tools to include in the system message. + functions (list): List of functions to include in the system message. + + Returns: + str: The formatted system message combined with Llama message. 
+ """ system_message = "" system_message_found = False for message in openai_message: @@ -407,15 +451,31 @@ def add_system_message( """ if tools: - system_message = system_message + self.add_tool_instructions(tools) + system_message = system_message + self.build_tool_instructions(tools) if functions: - system_message = system_message + self.add_function_instructions(functions) + system_message = system_message + self.build_function_instructions( + functions + ) end_tag = "\n<|eot_id|>" return llama_message + system_message + end_tag - def add_tool_instructions(self, tools: list) -> str: + def build_tool_instructions(self, tools: list) -> str: + """ + Builds a detailed instructional prompt for tools available to the assistant. + + This function generates a message describing the available tools, focusing on tools + of type "function." It explains to the LLM how to use each tool and provides an example of the + correct response format for function calls. + + Args: + tools (list): A list of tool dictionaries, where each dictionary contains tool + details such as type, function name, description, and parameters. + + Returns: + str: A formatted string detailing the tool instructions and usage examples. + """ tool_prompt = """ You have access to the following tools: """ @@ -449,7 +509,21 @@ def add_tool_instructions(self, tools: list) -> str: return tool_prompt - def add_function_instructions(self, functions: list) -> str: + def build_function_instructions(self, functions: list) -> str: + """ + Builds a detailed instructional prompt for available functions. + + This method creates a message describing the functions accessible to the assistant. + It includes the function name, description, and required parameters, along with + specific guidelines for calling functions. + + Args: + functions (list): A list of function dictionaries, each containing details such as + name, description, and parameters. + + Returns: + str: A formatted string with instructions on using the provided functions. + """ function_prompt = """ You have access to the following functions: """ @@ -479,35 +553,60 @@ def add_function_instructions(self, functions: list) -> str: """ return function_prompt - def add_conversation(self, openai_message: list, llama_message: str) -> str: + def build_llama_conversation(self, openai_message: list, llama_message: str) -> str: + """ + Appends the OpenAI message to the Llama message while formatting OpenAI messages. + + This function iterates through a list of OpenAI messages and formats them for inclusion + in a Llama message. It handles user messages that might include nested content (lists of + messages) by safely evaluating the content. System messages are skipped. + + Args: + openai_message (list): A list of dictionaries representing the OpenAI messages. Each + dictionary should have "role" and "content" keys. + llama_message (str): The initial Llama message to which the conversation is appended. + + Returns: + str: The Llama message with the conversation appended. 
+ """ conversation_parts = [] for message in openai_message: if message["role"] == "system": continue elif message["role"] == "user" and isinstance(message["content"], str): try: - # Attempt to safely evaluate the string to a Python object content_as_list = ast.literal_eval(message["content"]) if isinstance(content_as_list, list): - # If the content is a list, process each nested message for nested_message in content_as_list: conversation_parts.append( self.format_message(nested_message) ) else: - # If the content is not a list, append it directly conversation_parts.append(self.format_message(message)) except (ValueError, SyntaxError): - # If evaluation fails or content is not a list/dict string, append the message directly conversation_parts.append(self.format_message(message)) else: - # For all other messages, use the existing formatting logic conversation_parts.append(self.format_message(message)) return llama_message + "".join(conversation_parts) def format_message(self, message: dict) -> str: - """Format a single message for the conversation.""" + """ + Formats a single message dictionary into a structured string for a conversation. + + The formatting depends on the content of the message, such as tool calls, + function calls, or simple user/assistant messages. Each type of message + is formatted with specific headers and tags. + + Args: + message (dict): A dictionary containing message details. Expected keys + include "role", "content", and optionally "tool_calls", + "tool_call_id", or "function_call". + + Returns: + str: A formatted string representing the message. Returns an empty + string if the message cannot be formatted. + """ if "tool_calls" in message: for tool_call in message["tool_calls"]: function_name = tool_call["function"]["name"] diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index df8e04cc..5eeccc83 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock import pytest +from llmstudio_core.providers.azure import AzureProvider from llmstudio_core.providers.provider import ProviderCore @@ -52,3 +53,27 @@ def mock_provider(): tokenizer = MagicMock() tokenizer.encode = lambda x: x.split() # Simple tokenizer mock return MockProvider(config=config, tokenizer=tokenizer) + + +class MockAzureProvider(AzureProvider): + async def aparse_response(self, response, **kwargs): + return response + + async def agenerate_client(self, request): + # For testing, return an async generator + async def async_gen(): + yield {} + + return async_gen() + + @staticmethod + def _provider_config_name(): + return "mock_azure_provider" + + +@pytest.fixture +def mock_azure_provider(): + config = MagicMock() + config.id = "mock_azure_provider" + base_url = "mock_url.com" + return MockAzureProvider(config=config, base_url=base_url) diff --git a/libs/core/tests/unit_tests/test_azure.py b/libs/core/tests/unit_tests/test_azure.py new file mode 100644 index 00000000..28e5f684 --- /dev/null +++ b/libs/core/tests/unit_tests/test_azure.py @@ -0,0 +1,232 @@ +from unittest.mock import MagicMock + + +class TestParseResponse: + def test_tool_response_handling(self, mock_azure_provider): + + mock_azure_provider.is_llama = True + mock_azure_provider.has_tools = True + mock_azure_provider.has_functions = False + + mock_azure_provider.handle_tool_response = MagicMock( + return_value=iter(["chunk1", None, "chunk2", None, "chunk3"]) + ) + + response = iter(["irrelevant"]) + + results 
= list(mock_azure_provider.parse_response(response)) + + assert results == ["chunk1", "chunk2", "chunk3"] + mock_azure_provider.handle_tool_response.assert_called_once_with(response) + + def test_direct_response_handling_with_choices(self, mock_azure_provider): + mock_azure_provider.is_llama = False + + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"choices": ["choice1", "choice2"]} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"choices": ["choice2"]} + response = iter([chunk1, chunk2]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == [ + {"choices": ["choice1", "choice2"]}, + {"choices": ["choice2"]}, + ] + chunk1.model_dump.assert_called_once() + chunk2.model_dump.assert_called_once() + + def test_direct_response_handling_without_choices(self, mock_azure_provider): + mock_azure_provider.is_llama = False + + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"key": "value"} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"another_key": "another_value"} + response = iter([chunk1, chunk2]) + + results = list(mock_azure_provider.parse_response(response)) + + assert results == [] + chunk1.model_dump.assert_called_once() + chunk2.model_dump.assert_called_once() + + +class TestFormatMessage: + def test_format_message_tool_calls(self, mock_azure_provider): + message = { + "tool_calls": [ + { + "function": { + "name": "example_tool", + "arguments": '{"arg1": "value1"}', + } + } + ] + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + {"arg1": "value1"} + <|eom_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_tool_call_id(self, mock_azure_provider): + message = {"tool_call_id": "123", "content": "This is the tool response."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>ipython<|end_header_id|> + This is the tool response. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_function_call(self, mock_azure_provider): + message = { + "function_call": { + "name": "example_function", + "arguments": '{"arg1": "value1"}', + } + } + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + {"arg1": "value1"} + <|eom_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_user_message(self, mock_azure_provider): + message = {"role": "user", "content": "This is a user message."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>user<|end_header_id|> + This is a user message. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_assistant_message(self, mock_azure_provider): + message = {"role": "assistant", "content": "This is an assistant message."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>assistant<|end_header_id|> + This is an assistant message. + <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_function_response(self, mock_azure_provider): + message = {"role": "function", "content": "This is the function response."} + result = mock_azure_provider.format_message(message) + expected = """ + <|start_header_id|>ipython<|end_header_id|> + This is the function response. 
+ <|eot_id|> + """ + assert result.strip() == expected.strip() + + def test_format_message_empty_message(self, mock_azure_provider): + message = {"role": "user", "content": None} + result = mock_azure_provider.format_message(message) + expected = "" + assert result == expected + + +class TestGenerateClient: + def test_generate_client_with_tools_and_functions(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) + + request = MagicMock() + request.model = "gpt-4" + request.parameters = { + "tools": ["tool1", "tool2"], + "functions": ["function1", "function2"], + "other_param": "value", + } + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "gpt-4", + "messages": "prepared_messages", + "stream": True, + "tools": ["tool1", "tool2"], + "tool_choice": "auto", + "functions": ["function1", "function2"], + "function_call": "auto", + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) + + def test_generate_client_without_tools_or_functions(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) + + request = MagicMock() + request.model = "gpt-4" + request.parameters = {"other_param": "value"} + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "gpt-4", + "messages": "prepared_messages", + "stream": True, + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) + + def test_generate_client_with_llama_model(self, mock_azure_provider): + mock_azure_provider.prepare_messages = MagicMock( + return_value="prepared_messages" + ) + mock_azure_provider._client.chat.completions.create = MagicMock( + return_value="mock_response" + ) + + request = MagicMock() + request.model = "llama-2" + request.parameters = { + "tools": ["tool1"], + "functions": ["function1"], + "other_param": "value", + } + + result = mock_azure_provider.generate_client(request) + + expected_args = { + "model": "llama-2", + "messages": "prepared_messages", + "stream": True, + "tools": ["tool1"], + "functions": ["function1"], + "other_param": "value", + } + + assert result == "mock_response" + mock_azure_provider.prepare_messages.assert_called_once_with(request) + mock_azure_provider._client.chat.completions.create.assert_called_once_with( + **expected_args + ) diff --git a/libs/core/tests/unit_tests/test_azure_build.py b/libs/core/tests/unit_tests/test_azure_build.py new file mode 100644 index 00000000..17d73f99 --- /dev/null +++ b/libs/core/tests/unit_tests/test_azure_build.py @@ -0,0 +1,279 @@ +from unittest.mock import MagicMock, patch + + +class TestBuildLlamaSystemMessage: + def test_build_llama_system_message_with_existing_sm(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock( + return_value="Tool Instructions" + ) + mock_azure_provider.build_function_instructions = MagicMock( + return_value="\nFunction Instructions" + ) + + 
openai_message = [ + {"role": "user", "content": "Hello"}, + {"role": "system", "content": "Custom system message"}, + ] + llama_message = "Initial message" + tools = ["Tool1", "Tool2"] + functions = ["Function1"] + + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " Custom system message\n" + " Tool Instructions\nFunction Instructions\n<|eot_id|>" # identation here exists because in Python when adding a newline to a triple quote string it keeps identation + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) + mock_azure_provider.build_function_instructions.assert_called_once_with( + functions + ) + + def test_build_llama_system_message_with_default_sm(self, mock_azure_provider): + mock_azure_provider.build_tool_instructions = MagicMock( + return_value="Tool Instructions" + ) + mock_azure_provider.build_function_instructions = MagicMock( + return_value="\nFunction Instructions" + ) + + openai_message = [{"role": "user", "content": "Hello"}] + llama_message = "Initial message" + tools = ["Tool1"] + functions = [] + + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " You are a helpful AI assistant.\n" + " Tool Instructions\n<|eot_id|>" + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_called_once_with(tools) + mock_azure_provider.build_function_instructions.assert_not_called() + + def test_build_llama_system_message_without_tools_or_functions( + self, mock_azure_provider + ): + mock_azure_provider.build_tool_instructions = MagicMock() + mock_azure_provider.build_function_instructions = MagicMock() + + openai_message = [{"role": "system", "content": "Minimal system message"}] + llama_message = "Initial message" + tools = [] + functions = [] + + result = mock_azure_provider.build_llama_system_message( + openai_message, llama_message, tools, functions + ) + + expected = ( + "Initial message\n" + " <|start_header_id|>system<|end_header_id|>\n" + " Minimal system message\n \n<|eot_id|>" + ) + assert result == expected + mock_azure_provider.build_tool_instructions.assert_not_called() + mock_azure_provider.build_function_instructions.assert_not_called() + + +class TestBuildInstructions: + def test_build_tool_instructions(self, mock_azure_provider): + tools = [ + { + "type": "function", + "function": { + "name": "python_repl_ast", + "description": "execute Python code", + "parameters": {"query": "string"}, + }, + }, + { + "type": "function", + "function": { + "name": "data_lookup", + "description": "retrieve data from a database", + "parameters": {"database": "string", "query": "string"}, + }, + }, + ] + + result = mock_azure_provider.build_tool_instructions(tools) + + expected = """ + You have access to the following tools: + Use the function 'python_repl_ast' to 'execute Python code': +Parameters format: +{ + "query": "string" +} + +Use the function 'data_lookup' to 'retrieve data from a database': +Parameters format: +{ + "database": "string", + "query": "string" +} + + +If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: +§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} +IMPORTANT: IT IS 
VITAL THAT YOU NEVER ADD A PREFIX OR A SUFFIX TO THE FUNCTION CALL. + +Here is an example of the output I desiere when performing function call: +§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} +NOTE: There is no prefix before the symbol '§' and nothing comes after the call is done. + + Reminder: + - Function calls MUST follow the specified format. + - Only call one function at a time. + - Required parameters MUST be specified. + - Put the entire function call reply on one line. + - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. + - If you have already called a tool and got the response for the users question please reply with the response. + """ + assert result.strip() == expected.strip() + + def test_build_function_instructions(self, mock_azure_provider): + functions = [ + { + "name": "python_repl_ast", + "description": "execute Python code", + "parameters": {"query": "string"}, + }, + { + "name": "data_lookup", + "description": "retrieve data from a database", + "parameters": {"database": "string", "query": "string"}, + }, + ] + + result = mock_azure_provider.build_function_instructions(functions) + + expected = """ +You have access to the following functions: +Use the function 'python_repl_ast' to: 'execute Python code' +{ + "query": "string" +} + +Use the function 'data_lookup' to: 'retrieve data from a database' +{ + "database": "string", + "query": "string" +} + + +If you choose to use a function to produce this response, ONLY reply in the following format with no prefix or suffix: +§{"type": "function", "name": "FUNCTION_NAME", "parameters": {"PARAMETER_NAME": PARAMETER_VALUE}} + +Here is an example of the output I desiere when performing function call: +§{"type": "function", "name": "python_repl_ast", "parameters": {"query": "print(df.shape)"}} + +Reminder: +- Function calls MUST follow the specified format. +- Only call one function at a time. +- NEVER call more than one function at a time. +- Required parameters MUST be specified. +- Put the entire function call reply on one line. +- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls. +- If you have already called a function and got the response for the user's question, please reply with the response. 
+""" + + assert result.strip() == expected.strip() + + +class TestBuildLlamaConversation: + def test_build_llama_conversation_with_nested_messages(self, mock_azure_provider): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + { + "role": "user", + "content": "[{'content': 'nested message 1'}, {'content': 'nested message 2'}]", + }, + {"role": "assistant", "content": "assistant reply"}, + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:nested message 1][formatted:nested message 2][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"content": "nested message 1"} + ) + mock_azure_provider.format_message.assert_any_call( + {"content": "nested message 2"} + ) + mock_azure_provider.format_message.assert_any_call( + {"role": "assistant", "content": "assistant reply"} + ) + + def test_build_llama_conversation_with_invalid_nested_content( + self, mock_azure_provider + ): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + {"role": "user", "content": "[invalid json/dict]"}, + {"role": "assistant", "content": "assistant reply"}, + ] + llama_message = "Initial message: " + + with patch("ast.literal_eval", side_effect=ValueError) as mock_literal_eval: + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:[invalid json/dict]][formatted:assistant reply]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"role": "user", "content": "[invalid json/dict]"} + ) + mock_azure_provider.format_message.assert_any_call( + {"role": "assistant", "content": "assistant reply"} + ) + + mock_literal_eval.assert_called_once_with("[invalid json/dict]") + + def test_build_llama_conversation_skipping_system_messages( + self, mock_azure_provider + ): + mock_azure_provider.format_message = MagicMock( + side_effect=lambda msg: f"[formatted:{msg['content']}]" + ) + + openai_message = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + ] + llama_message = "Initial message: " + + result = mock_azure_provider.build_llama_conversation( + openai_message, llama_message + ) + + expected = "Initial message: [formatted:user message]" + + assert result == expected + mock_azure_provider.format_message.assert_any_call( + {"role": "user", "content": "user message"} + ) From 3345efc85ee3fdb43098028e64e6f83a50f184b8 Mon Sep 17 00:00:00 2001 From: Diogo Goncalves Date: Wed, 18 Dec 2024 10:40:07 +0000 Subject: [PATCH 13/13] chore: new line Signed-off-by: Diogo Goncalves --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index b9329a81..6c957607 100644 --- a/Makefile +++ b/Makefile @@ -2,4 +2,4 @@ format: pre-commit run --all-files unit-tests: - pytest libs/core/tests/unit_tests \ No newline at end of file + pytest libs/core/tests/unit_tests
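# Editor's note: the new unit-tests Make target simply runs pytest against
# libs/core/tests/unit_tests. A usage sketch of the equivalent programmatic
# invocation follows; the path comes from the Makefile above, and everything
# else is a plain pytest idiom rather than project-specific code.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["libs/core/tests/unit_tests"]))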