From 749e2fd57b00af65988142c5cacc74614ef7aa15 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 23 Apr 2026 16:22:14 +0800 Subject: [PATCH 1/3] perf: improve tool calls in reasoning and multiple tool calls display - Updated LiveChatRoute and OpenApiRoute to replace manual message accumulation with BotMessageAccumulator. - Simplified message saving logic by using build_bot_history_content and collect_plain_text_from_message_parts. - Enhanced message processing to handle various message types (plain, image, record, file, video) more efficiently. - Improved reasoning handling by extracting thinking parts and displaying them correctly in the UI components. - Refactored message normalization and reasoning extraction logic in useMessages composable for better clarity and maintainability. - Updated ChatMessageList, MessageList, StandaloneChat, and ReasoningBlock components to accommodate new message structure and rendering logic. --- .../agent/runners/tool_loop_agent_runner.py | 89 +++--- astrbot/core/astr_agent_run_util.py | 6 + .../core/provider/sources/openai_source.py | 2 + astrbot/dashboard/routes/chat.py | 294 +++++++++++++----- astrbot/dashboard/routes/live_chat.py | 105 ++----- astrbot/dashboard/routes/open_api.py | 85 +++-- .../src/components/chat/ChatMessageList.vue | 16 +- dashboard/src/components/chat/MessageList.vue | 17 +- .../src/components/chat/StandaloneChat.vue | 19 +- dashboard/src/components/chat/ThreadPanel.vue | 103 ++---- .../message_list_comps/ReasoningBlock.vue | 217 +++++++++---- dashboard/src/composables/useMessages.ts | 173 +++++++++-- 12 files changed, 692 insertions(+), 434 deletions(-) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 81b82403e6..cf70b41504 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -717,6 +717,15 @@ async def step(self): if self.stats.time_to_first_token == 0: self.stats.time_to_first_token = time.time() - self.stats.start_time + if llm_response.reasoning_content: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain(type="reasoning").message( + llm_response.reasoning_content, + ), + ), + ) if llm_response.result_chain: yield AgentResponse( type="streaming_delta", @@ -729,15 +738,6 @@ async def step(self): chain=MessageChain().message(llm_response.completion_text), ), ) - if llm_response.reasoning_content: - yield AgentResponse( - type="streaming_delta", - data=AgentResponseData( - chain=MessageChain(type="reasoning").message( - llm_response.reasoning_content, - ), - ), - ) if self._is_stop_requested(): llm_resp_result = LLMResponse( role="assistant", @@ -791,6 +791,15 @@ async def step(self): await self._complete_with_assistant_response(llm_resp) # 返回 LLM 结果 + if llm_resp.reasoning_content: + yield AgentResponse( + type="llm_result", + data=AgentResponseData( + chain=MessageChain(type="reasoning").message( + llm_resp.reasoning_content, + ), + ), + ) if llm_resp.result_chain: yield AgentResponse( type="llm_result", @@ -803,15 +812,6 @@ async def step(self): chain=MessageChain().message(llm_resp.completion_text), ), ) - if llm_resp.reasoning_content: - yield AgentResponse( - type="llm_result", - data=AgentResponseData( - chain=MessageChain(type="reasoning").message( - llm_resp.reasoning_content, - ), - ), - ) # 如果有工具调用,还需处理工具调用 if llm_resp.tools_call_name: @@ -821,6 +821,15 @@ async def step(self): logger.warning( "skills_like tool re-query returned no tool calls; fallback to assistant response." ) + if llm_resp.reasoning_content: + yield AgentResponse( + type="llm_result", + data=AgentResponseData( + chain=MessageChain(type="reasoning").message( + llm_resp.reasoning_content, + ), + ), + ) if llm_resp.result_chain: yield AgentResponse( type="llm_result", @@ -833,15 +842,7 @@ async def step(self): chain=MessageChain().message(llm_resp.completion_text), ), ) - if llm_resp.reasoning_content: - yield AgentResponse( - type="llm_result", - data=AgentResponseData( - chain=MessageChain(type="reasoning").message( - llm_resp.reasoning_content, - ), - ), - ) + await self._complete_with_assistant_response(llm_resp) return @@ -988,6 +989,7 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None: llm_response.tools_call_args, llm_response.tools_call_ids, ): + tool_result_blocks_start = len(tool_call_result_blocks) tool_call_streak = self._track_tool_call_streak(func_tool_name) yield _HandleFunctionToolsResult.from_message_chain( MessageChain( @@ -1201,24 +1203,23 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None: ), ) - # yield the last tool call result - if tool_call_result_blocks: - last_tcr_content = str(tool_call_result_blocks[-1].content) - yield _HandleFunctionToolsResult.from_message_chain( - MessageChain( - type="tool_call_result", - chain=[ - Json( - data={ - "id": func_tool_id, - "ts": time.time(), - "result": last_tcr_content, - } - ) - ], + if len(tool_call_result_blocks) > tool_result_blocks_start: + tool_result_content = str(tool_call_result_blocks[-1].content) + yield _HandleFunctionToolsResult.from_message_chain( + MessageChain( + type="tool_call_result", + chain=[ + Json( + data={ + "id": func_tool_id, + "ts": time.time(), + "result": tool_result_content, + } + ) + ], + ) ) - ) - logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}") + logger.info(f"Tool `{func_tool_name}` Result: {tool_result_content}") # 处理函数调用响应 if tool_call_result_blocks: diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index 62c60a4362..6bdf3011b6 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -235,6 +235,12 @@ async def run_agent( ) await astr_event.send(chain) continue + elif resp.type == "llm_result": + chain = resp.data["chain"] + if chain.type == "reasoning": + # For non-streaming mode, we handle reasoning in astrbot/core/astr_agent_hooks.py. + # For streaming mode, we yield content immediately when received a reasoning chunk but not in here, see below. + continue if stream_to_general and resp.type == "streaming_delta": continue diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 67971a2a93..f2d9474906 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -652,6 +652,8 @@ async def _query_stream( reasoning = self._extract_reasoning_content(chunk) _y = False llm_response.id = chunk.id + llm_response.reasoning_content = "" + llm_response.completion_text = "" if reasoning: llm_response.reasoning_content = reasoning _y = True diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index d7d4777acc..99d7a0e7b2 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -5,7 +5,7 @@ import uuid from contextlib import asynccontextmanager from copy import deepcopy -from typing import cast +from typing import Any, cast from quart import Response as QuartResponse from quart import g, make_response, request, send_file @@ -58,6 +58,179 @@ async def _poll_webchat_stream_result(back_queue, username: str): return result, False +def normalize_legacy_reasoning_message_parts( + message_parts: list[dict] | None, + reasoning: str = "", +) -> list[dict]: + parts: list[dict] = [] + for part in message_parts or []: + if not isinstance(part, dict): + continue + copied = dict(part) + if copied.get("type") == "reasoning": + copied = {"type": "think", "think": copied.get("text", "")} + parts.append(copied) + if reasoning and not any(part.get("type") == "think" for part in parts): + parts.insert(0, {"type": "think", "think": reasoning}) + return parts + + +def extract_reasoning_from_message_parts(message_parts: list[dict]) -> str: + reasoning_parts: list[str] = [] + for part in message_parts: + if part.get("type") != "think": + continue + think = part.get("think") + if isinstance(think, str) and think: + reasoning_parts.append(think) + return "".join(reasoning_parts) + + +def collect_plain_text_from_message_parts(message_parts: list[dict]) -> str: + text_parts: list[str] = [] + for part in message_parts: + if part.get("type") != "plain": + continue + text = part.get("text") + if isinstance(text, str) and text: + text_parts.append(text) + return "".join(text_parts) + + +def build_bot_history_content( + message_parts: list[dict], + *, + agent_stats: dict | None = None, + refs: dict | None = None, + include_legacy_reasoning_field: bool = True, +) -> dict[str, Any]: + normalized_parts = normalize_legacy_reasoning_message_parts(message_parts) + content: dict[str, Any] = {"type": "bot", "message": normalized_parts} + reasoning = extract_reasoning_from_message_parts(normalized_parts) + if reasoning and include_legacy_reasoning_field: + # Keep the legacy field for old clients while the canonical structure + # moves to message parts. + content["reasoning"] = reasoning + if agent_stats: + content["agent_stats"] = agent_stats + if refs: + content["refs"] = refs + return content + + +class BotMessageAccumulator: + def __init__(self) -> None: + self.parts: list[dict] = [] + self.pending_text = "" + self.pending_tool_calls: dict[str, dict] = {} + + def has_content(self) -> bool: + return bool(self.parts or self.pending_text or self.pending_tool_calls) + + def add_plain( + self, + result_text: str, + *, + chain_type: str | None, + streaming: bool, + ) -> None: + if chain_type == "tool_call": + self._flush_pending_text() + self._store_tool_call(result_text) + return + + if chain_type == "tool_call_result": + self._flush_pending_text() + self._store_tool_call_result(result_text) + return + + if chain_type == "reasoning": + self._flush_pending_text() + self._append_think_part(result_text) + return + + if streaming: + self.pending_text += result_text + else: + self.pending_text += result_text + + def add_attachment(self, part: dict | None) -> None: + if not part: + return + self._flush_pending_text() + self.parts.append(part) + + def build_message_parts( + self, *, include_pending_tool_calls: bool = False + ) -> list[dict]: + self._flush_pending_text() + if include_pending_tool_calls and self.pending_tool_calls: + for tool_call in self.pending_tool_calls.values(): + self.parts.append({"type": "tool_call", "tool_calls": [tool_call]}) + self.pending_tool_calls = {} + return self.parts + + def plain_text(self) -> str: + return collect_plain_text_from_message_parts(self.build_message_parts()) + + def reasoning_text(self) -> str: + return extract_reasoning_from_message_parts(self.build_message_parts()) + + def _flush_pending_text(self) -> None: + if not self.pending_text: + return + + if self.parts and self.parts[-1].get("type") == "plain": + last_text = self.parts[-1].get("text") + self.parts[-1]["text"] = f"{last_text or ''}{self.pending_text}" + else: + self.parts.append({"type": "plain", "text": self.pending_text}) + self.pending_text = "" + + def _append_think_part(self, text: str) -> None: + if not text: + return + + if self.parts and self.parts[-1].get("type") == "think": + last_text = self.parts[-1].get("think") + self.parts[-1]["think"] = f"{last_text or ''}{text}" + else: + self.parts.append({"type": "think", "think": text}) + + def _store_tool_call(self, result_text: str) -> None: + tool_call = self._parse_json_object(result_text) + if not tool_call: + return + tool_call_id = str(tool_call.get("id") or "") + if not tool_call_id: + return + self.pending_tool_calls[tool_call_id] = tool_call + + def _store_tool_call_result(self, result_text: str) -> None: + tool_result = self._parse_json_object(result_text) + if not tool_result: + return + + tool_call_id = str(tool_result.get("id") or "") + if not tool_call_id: + return + + tool_call = self.pending_tool_calls.pop(tool_call_id, None) or { + "id": tool_call_id + } + tool_call["result"] = tool_result.get("result") + tool_call["finished_ts"] = tool_result.get("ts") + self.parts.append({"type": "tool_call", "tool_calls": [tool_call]}) + + @staticmethod + def _parse_json_object(raw_text: str) -> dict | None: + try: + parsed = json.loads(raw_text) + except json.JSONDecodeError: + return None + return parsed if isinstance(parsed, dict) else None + + class ChatRoute(Route): def __init__( self, @@ -519,27 +692,18 @@ async def _delete_platform_history_after( async def _save_bot_message( self, webchat_conv_id: str, - text: str, - media_parts: list, - reasoning: str, + message_parts: list[dict], agent_stats: dict, refs: dict, llm_checkpoint_id: str | None = None, platform_history_id: str = "webchat", ): """保存 bot 消息到历史记录,返回保存的记录""" - bot_message_parts = [] - bot_message_parts.extend(media_parts) - if text: - bot_message_parts.append({"type": "plain", "text": text}) - - new_his = {"type": "bot", "message": bot_message_parts} - if reasoning: - new_his["reasoning"] = reasoning - if agent_stats: - new_his["agent_stats"] = agent_stats - if refs: - new_his["refs"] = refs + new_his = build_bot_history_content( + message_parts, + agent_stats=agent_stats, + refs=refs, + ) record = await self.platform_history_mgr.insert( platform_id=platform_history_id, @@ -599,10 +763,7 @@ async def chat(self, post_data: dict | None = None): async def stream(): client_disconnected = False - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" - tool_calls = {} + message_accumulator = BotMessageAccumulator() agent_stats = {} refs = {} try: @@ -683,76 +844,61 @@ async def stream(): # 累积消息部分 if msg_type == "plain": - chain_type = result.get("chain_type") - if chain_type == "tool_call": - tool_call = json.loads(result_text) - tool_calls[tool_call.get("id")] = tool_call - if accumulated_text: - # 如果累积了文本,则先保存文本 - accumulated_parts.append( - {"type": "plain", "text": accumulated_text} - ) - accumulated_text = "" - elif chain_type == "tool_call_result": - tcr = json.loads(result_text) - tc_id = tcr.get("id") - if tc_id in tool_calls: - tool_calls[tc_id]["result"] = tcr.get("result") - tool_calls[tc_id]["finished_ts"] = tcr.get("ts") - accumulated_parts.append( - { - "type": "tool_call", - "tool_calls": [tool_calls[tc_id]], - } - ) - tool_calls.pop(tc_id, None) - elif chain_type == "reasoning": - accumulated_reasoning += result_text - elif streaming: - accumulated_text += result_text - else: - accumulated_text = result_text + message_accumulator.add_plain( + result_text, + chain_type=chain_type, + streaming=streaming, + ) elif msg_type == "image": filename = result_text.replace("[IMAGE]", "") part = await self._create_attachment_from_file( filename, "image" ) - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) elif msg_type == "record": filename = result_text.replace("[RECORD]", "") part = await self._create_attachment_from_file( filename, "record" ) - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) elif msg_type == "file": # 格式: [FILE]filename filename = result_text.replace("[FILE]", "") part = await self._create_attachment_from_file( filename, "file" ) - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) + elif msg_type == "video": + filename = result_text.replace("[VIDEO]", "") + part = await self._create_attachment_from_file( + filename, "video" + ) + message_accumulator.add_attachment(part) - # 消息结束处理 + should_save = False if msg_type == "end": - break - elif ( - (streaming and msg_type == "complete") or not streaming - # or msg_type == "break" - ): - if ( - chain_type == "tool_call" - or chain_type == "tool_call_result" - ): - continue + should_save = message_accumulator.has_content() or bool( + refs or agent_stats + ) + elif (streaming and msg_type == "complete") or not streaming: + if chain_type not in ("tool_call", "tool_call_result"): + should_save = True + + if should_save: + message_parts_to_save = ( + message_accumulator.build_message_parts( + include_pending_tool_calls=True + ) + ) + plain_text = collect_plain_text_from_message_parts( + message_parts_to_save + ) # 提取 web_search_tavily 引用 try: refs = self._extract_web_search_refs( - accumulated_text, - accumulated_parts, + plain_text, + message_parts_to_save, ) except Exception as e: logger.exception( @@ -762,9 +908,7 @@ async def stream(): saved_record = await self._save_bot_message( webchat_conv_id, - accumulated_text, - accumulated_parts, - accumulated_reasoning, + message_parts_to_save, agent_stats, refs, llm_checkpoint_id, @@ -786,12 +930,12 @@ async def stream(): yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n" except Exception: pass - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" - # tool_calls = {} + message_accumulator = BotMessageAccumulator() agent_stats = {} refs = {} + + if msg_type == "end": + break except BaseException as e: logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) finally: diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 16c6058485..8f4dc26fab 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -23,6 +23,11 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path from astrbot.core.utils.datetime_utils import to_utc_isoformat +from .chat import ( + BotMessageAccumulator, + build_bot_history_content, + collect_plain_text_from_message_parts, +) from .route import Route, RouteContext @@ -250,26 +255,17 @@ def _extract_web_search_refs( async def _save_bot_message( self, webchat_conv_id: str, - text: str, - media_parts: list, - reasoning: str, + message_parts: list[dict], agent_stats: dict, refs: dict, llm_checkpoint_id: str | None = None, ): """保存 bot 消息到历史记录。""" - bot_message_parts = [] - bot_message_parts.extend(media_parts) - if text: - bot_message_parts.append({"type": "plain", "text": text}) - - new_his = {"type": "bot", "message": bot_message_parts} - if reasoning: - new_his["reasoning"] = reasoning - if agent_stats: - new_his["agent_stats"] = agent_stats - if refs: - new_his["refs"] = refs + new_his = build_bot_history_content( + message_parts, + agent_stats=agent_stats, + refs=refs, + ) return await self.platform_history_mgr.insert( platform_id="webchat", @@ -499,10 +495,7 @@ async def _handle_chat_message( }, ) - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" - tool_calls = {} + message_accumulator = BotMessageAccumulator() agent_stats = {} refs = {} @@ -545,68 +538,32 @@ async def _handle_chat_message( await self._send_chat_payload(session, outgoing) if msg_type == "plain": - if chain_type == "tool_call": - try: - tool_call = json.loads(result_text) - tool_calls[tool_call.get("id")] = tool_call - if accumulated_text: - accumulated_parts.append( - {"type": "plain", "text": accumulated_text} - ) - accumulated_text = "" - except Exception: - pass - elif chain_type == "tool_call_result": - try: - tcr = json.loads(result_text) - tc_id = tcr.get("id") - if tc_id in tool_calls: - tool_calls[tc_id]["result"] = tcr.get("result") - tool_calls[tc_id]["finished_ts"] = tcr.get("ts") - accumulated_parts.append( - { - "type": "tool_call", - "tool_calls": [tool_calls[tc_id]], - } - ) - tool_calls.pop(tc_id, None) - except Exception: - pass - elif chain_type == "reasoning": - accumulated_reasoning += result_text - elif streaming: - accumulated_text += result_text - else: - accumulated_text = result_text + message_accumulator.add_plain( + result_text, + chain_type=chain_type, + streaming=streaming, + ) elif msg_type == "image": filename = str(result_text).replace("[IMAGE]", "") part = await self._create_attachment_from_file(filename, "image") - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) elif msg_type == "record": filename = str(result_text).replace("[RECORD]", "") part = await self._create_attachment_from_file(filename, "record") - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) elif msg_type == "file": filename = str(result_text).replace("[FILE]", "").split("|", 1)[0] part = await self._create_attachment_from_file(filename, "file") - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) elif msg_type == "video": filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0] part = await self._create_attachment_from_file(filename, "video") - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) should_save = False if msg_type == "end": should_save = bool( - accumulated_parts - or accumulated_text - or accumulated_reasoning - or refs - or agent_stats + message_accumulator.has_content() or refs or agent_stats ) elif (streaming and msg_type == "complete") or not streaming: if chain_type not in ( @@ -617,10 +574,16 @@ async def _handle_chat_message( should_save = True if should_save: + message_parts_to_save = message_accumulator.build_message_parts( + include_pending_tool_calls=True + ) + plain_text = collect_plain_text_from_message_parts( + message_parts_to_save + ) try: refs = self._extract_web_search_refs( - accumulated_text, - accumulated_parts, + plain_text, + message_parts_to_save, ) except Exception as e: logger.exception( @@ -630,9 +593,7 @@ async def _handle_chat_message( saved_record = await self._save_bot_message( session_id, - accumulated_text, - accumulated_parts, - accumulated_reasoning, + message_parts_to_save, agent_stats, refs, llm_checkpoint_id, @@ -653,9 +614,7 @@ async def _handle_chat_message( }, ) - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" + message_accumulator = BotMessageAccumulator() agent_stats = {} refs = {} diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index 8f20473262..52b412b2b5 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -18,7 +18,11 @@ from astrbot.core.utils.datetime_utils import to_utc_isoformat from .api_key import ALL_OPEN_API_SCOPES -from .chat import ChatRoute +from .chat import ( + BotMessageAccumulator, + ChatRoute, + collect_plain_text_from_message_parts, +) from .route import Response, Route, RouteContext @@ -363,10 +367,7 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: } ) - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" - tool_calls = {} + message_accumulator = BotMessageAccumulator() agent_stats = {} refs = {} while True: @@ -402,68 +403,56 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: await websocket.send_json(result) if msg_type == "plain": - if chain_type == "tool_call": - tool_call = json.loads(result_text) - tool_calls[tool_call.get("id")] = tool_call - if accumulated_text: - accumulated_parts.append( - {"type": "plain", "text": accumulated_text} - ) - accumulated_text = "" - elif chain_type == "tool_call_result": - tcr = json.loads(result_text) - tc_id = tcr.get("id") - if tc_id in tool_calls: - tool_calls[tc_id]["result"] = tcr.get("result") - tool_calls[tc_id]["finished_ts"] = tcr.get("ts") - accumulated_parts.append( - {"type": "tool_call", "tool_calls": [tool_calls[tc_id]]} - ) - tool_calls.pop(tc_id, None) - elif chain_type == "reasoning": - accumulated_reasoning += result_text - elif streaming: - accumulated_text += result_text - else: - accumulated_text = result_text + message_accumulator.add_plain( + result_text, + chain_type=chain_type, + streaming=streaming, + ) elif msg_type == "image": filename = str(result_text).replace("[IMAGE]", "") part = await self.chat_route._create_attachment_from_file( filename, "image" ) - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) elif msg_type == "record": filename = str(result_text).replace("[RECORD]", "") part = await self.chat_route._create_attachment_from_file( filename, "record" ) - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) elif msg_type == "file": filename = str(result_text).replace("[FILE]", "") part = await self.chat_route._create_attachment_from_file( filename, "file" ) - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) elif msg_type == "video": filename = str(result_text).replace("[VIDEO]", "") part = await self.chat_route._create_attachment_from_file( filename, "video" ) - if part: - accumulated_parts.append(part) + message_accumulator.add_attachment(part) + should_save = False if msg_type == "end": - break - if (streaming and msg_type == "complete") or not streaming: - if chain_type in ("tool_call", "tool_call_result"): - continue + should_save = bool( + message_accumulator.has_content() or refs or agent_stats + ) + elif (streaming and msg_type == "complete") or not streaming: + if chain_type not in ("tool_call", "tool_call_result"): + should_save = True + + if should_save: + message_parts_to_save = message_accumulator.build_message_parts( + include_pending_tool_calls=True + ) + plain_text = collect_plain_text_from_message_parts( + message_parts_to_save + ) try: refs = self.chat_route._extract_web_search_refs( - accumulated_text, - accumulated_parts, + plain_text, + message_parts_to_save, ) except Exception as e: logger.exception( @@ -473,9 +462,7 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: saved_record = await self.chat_route._save_bot_message( session_id, - accumulated_text, - accumulated_parts, - accumulated_reasoning, + message_parts_to_save, agent_stats, refs, ) @@ -492,11 +479,11 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: "session_id": session_id, } ) - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" + message_accumulator = BotMessageAccumulator() agent_stats = {} refs = {} + if msg_type == "end": + break except Exception as e: logger.exception(f"Open API WS chat failed: {e}", exc_info=True) await self._send_chat_ws_error( diff --git a/dashboard/src/components/chat/ChatMessageList.vue b/dashboard/src/components/chat/ChatMessageList.vue index a0e70d65fe..6905ec5c8c 100644 --- a/dashboard/src/components/chat/ChatMessageList.vue +++ b/dashboard/src/components/chat/ChatMessageList.vue @@ -120,8 +120,8 @@