diff --git a/src/lmstudio/history.py b/src/lmstudio/history.py index 4e62452..548371d 100644 --- a/src/lmstudio/history.py +++ b/src/lmstudio/history.py @@ -135,9 +135,13 @@ def _to_history_content(self) -> str: | ToolCallRequestData | ToolCallRequestDataDict ) +AssistantMultiPartInput = Iterable[AssistantResponseInput | ToolCallRequestInput] ToolCallResultInput = ToolCallResultData | ToolCallResultDataDict +ToolCallResultMultiPartInput = Iterable[ToolCallResultInput] ChatMessageInput = str | ChatMessageContent | ChatMessageContentDict -ChatMessageMultiPartInput = UserMessageMultiPartInput +ChatMessageMultiPartInput = ( + UserMessageMultiPartInput | AssistantMultiPartInput | ToolCallResultMultiPartInput +) AnyChatMessageInput = ChatMessageInput | ChatMessageMultiPartInput @@ -251,9 +255,12 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage: if role == "user": messages = cast(AnyUserMessageInput, content) return self.add_user_message(messages) + # Tool results accept multi-part content, so just forward it to that method + if role == "tool": + tool_results = cast(Iterable[ToolCallResultInput], content) + return self.add_tool_results(tool_results) # Assistant responses consist of a text response with zero or more tool requests if role == "assistant": - response: AssistantResponseInput if _is_chat_message_input(content): response = cast(AssistantResponseInput, content) return self.add_assistant_response(response) @@ -263,7 +270,7 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage: raise LMStudioValueError( f"Unable to parse assistant response content: {content}" ) from None - response = response_content + response = cast(AssistantResponseInput, response_content) tool_requests = cast(Iterable[ToolCallRequest], tool_request_contents) return self.add_assistant_response(response, tool_requests) @@ -276,7 +283,7 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage: content_item = content else: try: - (content_item,) = content + (content_item,) = cast(Iterable[ChatMessageInput], content) except ValueError: err_msg = f"{role!r} role does not support multi-part message content." raise LMStudioValueError(err_msg) from None @@ -284,9 +291,6 @@ def add_entry(self, role: str, content: AnyChatMessageInput) -> AnyChatMessage: case "system": prompt = cast(SystemPromptInput, content_item) result = self.add_system_prompt(prompt) - case "tool": - tool_result = cast(ToolCallResultInput, content_item) - result = self.add_tool_result(tool_result) case _: raise LMStudioValueError(f"Unknown history role: {role}") return result diff --git a/tests/test_history.py b/tests/test_history.py index 6ec7925..ce20483 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -10,9 +10,10 @@ from lmstudio.sdk_api import LMStudioOSError from lmstudio.schemas import DictObject from lmstudio.history import ( + AnyChatMessageDict, AnyChatMessageInput, + AssistantMultiPartInput, Chat, - AnyChatMessageDict, ChatHistoryData, ChatHistoryDataDict, LocalFileInput, @@ -29,6 +30,10 @@ LlmPredictionStats, PredictionResult, ) +from lmstudio._sdk_models import ( + ToolCallRequestDataDict, + ToolCallResultDataDict, +) from .support import IMAGE_FILEPATH, check_sdk_error @@ -125,6 +130,51 @@ "role": "system", "content": [{"type": "text", "text": "Structured text system prompt"}], }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Example tool call request"}, + { + "type": "toolCallRequest", + "toolCallRequest": { + "type": "function", + "id": "114663647", + "name": "example_tool_name", + "arguments": { + "n": 58013, + "t": "value", + }, + }, + }, + { + "type": "toolCallRequest", + "toolCallRequest": { + "type": "function", + "id": "114663648", + "name": "another_example_tool_name", + "arguments": { + "n": 23, + "t": "some other value", + }, + }, + }, + ], + }, + { + "role": "tool", + "content": [ + { + "type": "toolCallResult", + "toolCallId": "114663647", + "content": "example tool call result", + }, + { + "type": "toolCallResult", + "toolCallId": "114663648", + "content": "another example tool call result", + }, + ], + }, ] INPUT_HISTORY = {"messages": INPUT_ENTRIES} @@ -214,6 +264,51 @@ "role": "system", "content": [{"type": "text", "text": "Structured text system prompt"}], }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Example tool call request"}, + { + "type": "toolCallRequest", + "toolCallRequest": { + "type": "function", + "id": "114663647", + "name": "example_tool_name", + "arguments": { + "n": 58013, + "t": "value", + }, + }, + }, + { + "type": "toolCallRequest", + "toolCallRequest": { + "type": "function", + "id": "114663648", + "name": "another_example_tool_name", + "arguments": { + "n": 23, + "t": "some other value", + }, + }, + }, + ], + }, + { + "role": "tool", + "content": [ + { + "type": "toolCallResult", + "toolCallId": "114663647", + "content": "example tool call result", + }, + { + "type": "toolCallResult", + "toolCallId": "114663648", + "content": "another example tool call result", + }, + ], + }, ] @@ -271,6 +366,44 @@ def test_from_history_with_simple_text() -> None: "sizeBytes": 100, "fileType": "text/plain", } +INPUT_TOOL_REQUESTS: list[ToolCallRequestDataDict] = [ + { + "type": "toolCallRequest", + "toolCallRequest": { + "type": "function", + "id": "114663647", + "name": "example_tool_name", + "arguments": { + "n": 58013, + "t": "value", + }, + }, + }, + { + "type": "toolCallRequest", + "toolCallRequest": { + "type": "function", + "id": "114663648", + "name": "another_example_tool_name", + "arguments": { + "n": 23, + "t": "some other value", + }, + }, + }, +] +INPUT_TOOL_RESULTS: list[ToolCallResultDataDict] = [ + { + "type": "toolCallResult", + "toolCallId": "114663647", + "content": "example tool call result", + }, + { + "type": "toolCallResult", + "toolCallId": "114663648", + "content": "another example tool call result", + }, +] def test_get_history() -> None: @@ -289,6 +422,8 @@ def test_get_history() -> None: chat.add_user_message("Avoid consecutive responses") chat.add_assistant_response(INPUT_FILE_HANDLE_DICT) chat.add_system_prompt(TextData(text="Structured text system prompt")) + chat.add_assistant_response("Example tool call request", INPUT_TOOL_REQUESTS) + chat.add_tool_results(INPUT_TOOL_RESULTS) assert chat._get_history_for_prediction() == EXPECTED_HISTORY @@ -307,6 +442,19 @@ def test_add_entry() -> None: chat.add_entry("user", "Avoid consecutive responses") chat.add_entry("assistant", INPUT_FILE_HANDLE_DICT) chat.add_entry("system", TextData(text="Structured text system prompt")) + tool_call_message_contents: AssistantMultiPartInput = [ + "Example tool call request", + *INPUT_TOOL_REQUESTS, + ] + chat.add_entry("assistant", tool_call_message_contents) + chat.add_entry("tool", INPUT_TOOL_RESULTS) + assert chat._get_history_for_prediction() == EXPECTED_HISTORY + + +def test_append() -> None: + chat = Chat() + for message in INPUT_ENTRIES: + chat.append(cast(AnyChatMessageDict, message)) assert chat._get_history_for_prediction() == EXPECTED_HISTORY