From 88d963daa0ae534b7b69ba6a9c3791dc294a1c6f Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 7 Feb 2026 13:48:09 +0800 Subject: [PATCH 1/2] perf: optimize webchat and wecom ai queue lifecycle --- .../sources/webchat/webchat_adapter.py | 38 +-- .../platform/sources/webchat/webchat_event.py | 8 +- .../sources/webchat/webchat_queue_mgr.py | 117 ++++++++- .../sources/wecom_ai_bot/wecomai_adapter.py | 48 +--- .../sources/wecom_ai_bot/wecomai_queue_mgr.py | 127 ++++++++- astrbot/dashboard/routes/chat.py | 4 +- astrbot/dashboard/routes/live_chat.py | 243 +++++++++--------- 7 files changed, 366 insertions(+), 219 deletions(-) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 316c95d814..a9dff522d4 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -29,43 +29,11 @@ class QueueListener: def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None: self.webchat_queue_mgr = webchat_queue_mgr self.callback = callback - self.running_tasks = set() - - async def listen_to_queue(self, conversation_id: str): - """Listen to a specific conversation queue""" - queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id) - while True: - try: - data = await queue.get() - await self.callback(data) - except Exception as e: - logger.error( - f"Error processing message from conversation {conversation_id}: {e}", - ) - break async def run(self): - """Monitor for new conversation queues and start listeners""" - monitored_conversations = set() - - while True: - # Check for new conversations - current_conversations = set(self.webchat_queue_mgr.queues.keys()) - new_conversations = current_conversations - monitored_conversations - - # Start listeners for new conversations - for conversation_id in new_conversations: - task = asyncio.create_task(self.listen_to_queue(conversation_id)) - self.running_tasks.add(task) - task.add_done_callback(self.running_tasks.discard) - monitored_conversations.add(conversation_id) - logger.debug(f"Started listener for conversation: {conversation_id}") - - # Clean up monitored conversations that no longer exist - removed_conversations = monitored_conversations - current_conversations - monitored_conversations -= removed_conversations - - await asyncio.sleep(1) # Check for new conversations every second + """Register callback and keep adapter task alive.""" + self.webchat_queue_mgr.set_listener(self.callback) + await asyncio.Event().wait() @register_platform_adapter("webchat", "webchat") diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 6e7201c6dc..82c1a7580e 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -26,8 +26,8 @@ async def _send( session_id: str, streaming: bool = False, ) -> str | None: - cid = session_id.split("!")[-1] - web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) + request_id = str(message_id) + web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(request_id) if not message: await web_chat_back_queue.put( { @@ -124,9 +124,9 @@ async def send(self, message: MessageChain | None): async def send_streaming(self, generator, use_fallback: bool = False): final_data = "" reasoning_content = "" - cid = self.session_id.split("!")[-1] - web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) message_id = self.message_obj.message_id + request_id = str(message_id) + web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(request_id) async for chain in generator: # 处理音频流(Live Mode) if chain.type == "audio_chunk": diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index 6c365cb3a9..6acf61b730 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -1,35 +1,124 @@ import asyncio +from collections.abc import Awaitable, Callable + +from astrbot import logger class WebChatQueueMgr: - def __init__(self) -> None: - self.queues = {} + def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None: + self.queues: dict[str, asyncio.Queue] = {} """Conversation ID to asyncio.Queue mapping""" - self.back_queues = {} - """Conversation ID to asyncio.Queue mapping for responses""" + self.back_queues: dict[str, asyncio.Queue] = {} + """Request ID to asyncio.Queue mapping for responses""" + self._queue_close_events: dict[str, asyncio.Event] = {} + self._listener_tasks: dict[str, asyncio.Task] = {} + self._listener_callback: Callable[[tuple], Awaitable[None]] | None = None + self.queue_maxsize = queue_maxsize + self.back_queue_maxsize = back_queue_maxsize def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue: """Get or create a queue for the given conversation ID""" if conversation_id not in self.queues: - self.queues[conversation_id] = asyncio.Queue() + self.queues[conversation_id] = asyncio.Queue(maxsize=self.queue_maxsize) + self._queue_close_events[conversation_id] = asyncio.Event() + self._start_listener_if_needed(conversation_id) return self.queues[conversation_id] - def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue: - """Get or create a back queue for the given conversation ID""" - if conversation_id not in self.back_queues: - self.back_queues[conversation_id] = asyncio.Queue() - return self.back_queues[conversation_id] + def get_or_create_back_queue(self, request_id: str) -> asyncio.Queue: + """Get or create a back queue for the given request ID""" + if request_id not in self.back_queues: + self.back_queues[request_id] = asyncio.Queue( + maxsize=self.back_queue_maxsize + ) + return self.back_queues[request_id] + + def remove_back_queue(self, request_id: str): + """Remove back queue for the given request ID""" + self.back_queues.pop(request_id, None) def remove_queues(self, conversation_id: str): """Remove queues for the given conversation ID""" - if conversation_id in self.queues: - del self.queues[conversation_id] - if conversation_id in self.back_queues: - del self.back_queues[conversation_id] + self.remove_queue(conversation_id) + + def remove_queue(self, conversation_id: str): + """Remove input queue and listener for the given conversation ID""" + self.queues.pop(conversation_id, None) + + close_event = self._queue_close_events.pop(conversation_id, None) + if close_event is not None: + close_event.set() + + task = self._listener_tasks.pop(conversation_id, None) + if task is not None: + task.cancel() def has_queue(self, conversation_id: str) -> bool: """Check if a queue exists for the given conversation ID""" return conversation_id in self.queues + def set_listener( + self, + callback: Callable[[tuple], Awaitable[None]], + ): + self._listener_callback = callback + for conversation_id in list(self.queues.keys()): + self._start_listener_if_needed(conversation_id) + + def _start_listener_if_needed(self, conversation_id: str): + if self._listener_callback is None: + return + if conversation_id in self._listener_tasks: + task = self._listener_tasks[conversation_id] + if not task.done(): + return + queue = self.queues.get(conversation_id) + close_event = self._queue_close_events.get(conversation_id) + if queue is None or close_event is None: + return + task = asyncio.create_task( + self._listen_to_queue(conversation_id, queue, close_event), + name=f"webchat_listener_{conversation_id}", + ) + self._listener_tasks[conversation_id] = task + task.add_done_callback( + lambda _: self._listener_tasks.pop(conversation_id, None) + ) + logger.debug(f"Started listener for conversation: {conversation_id}") + + async def _listen_to_queue( + self, + conversation_id: str, + queue: asyncio.Queue, + close_event: asyncio.Event, + ): + while True: + get_task = asyncio.create_task(queue.get()) + close_task = asyncio.create_task(close_event.wait()) + try: + done, pending = await asyncio.wait( + {get_task, close_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + if close_task in done: + break + data = get_task.result() + if self._listener_callback is None: + continue + try: + await self._listener_callback(data) + except Exception as e: + logger.error( + f"Error processing message from conversation {conversation_id}: {e}" + ) + except asyncio.CancelledError: + break + finally: + if not get_task.done(): + get_task.cancel() + if not close_task.done(): + close_task.cancel() + webchat_queue_mgr = WebChatQueueMgr() diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 57da5176ba..3b68bc2840 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -51,44 +51,13 @@ def __init__( ) -> None: self.queue_mgr = queue_mgr self.callback = callback - self.running_tasks = set() - - async def listen_to_queue(self, session_id: str): - """监听特定会话的队列""" - queue = self.queue_mgr.get_or_create_queue(session_id) - while True: - try: - data = await queue.get() - await self.callback(data) - except Exception as e: - logger.error(f"处理会话 {session_id} 消息时发生错误: {e}") - break async def run(self): - """监控新会话队列并启动监听器""" - monitored_sessions = set() - + """注册监听回调并定期清理过期响应。""" + self.queue_mgr.set_listener(self.callback) while True: - # 检查新会话 - current_sessions = set(self.queue_mgr.queues.keys()) - new_sessions = current_sessions - monitored_sessions - - # 为新会话启动监听器 - for session_id in new_sessions: - task = asyncio.create_task(self.listen_to_queue(session_id)) - self.running_tasks.add(task) - task.add_done_callback(self.running_tasks.discard) - monitored_sessions.add(session_id) - logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}") - - # 清理已不存在的会话 - removed_sessions = monitored_sessions - current_sessions - monitored_sessions -= removed_sessions - - # 清理过期的待处理响应 self.queue_mgr.cleanup_expired_responses() - - await asyncio.sleep(1) # 每秒检查一次新会话 + await asyncio.sleep(1) @register_platform_adapter( @@ -212,7 +181,12 @@ async def _process_message( # wechat server is requesting for updates of a stream stream_id = message_data["stream"]["id"] if not self.queue_mgr.has_back_queue(stream_id): - logger.error(f"Cannot find back queue for stream_id: {stream_id}") + if self.queue_mgr.is_stream_finished(stream_id): + logger.debug( + f"Stream already finished, returning end message: {stream_id}" + ) + else: + logger.warning(f"Cannot find back queue for stream_id: {stream_id}") # 返回结束标志,告诉微信服务器流已结束 end_message = WecomAIBotStreamMessageBuilder.make_text_stream( @@ -243,10 +217,10 @@ async def _process_message( latest_plain_content = msg["data"] or "" elif msg["type"] == "image": image_base64.append(msg["image_data"]) - elif msg["type"] == "end": + elif msg["type"] in {"end", "complete"}: # stream end finish = True - self.queue_mgr.remove_queues(stream_id) + self.queue_mgr.remove_queues(stream_id, mark_finished=True) break logger.debug( diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index 3a982bdf73..b9dbf5b6a8 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -4,6 +4,7 @@ """ import asyncio +from collections.abc import Awaitable, Callable from typing import Any from astrbot.api import logger @@ -12,7 +13,7 @@ class WecomAIQueueMgr: """企业微信智能机器人队列管理器""" - def __init__(self) -> None: + def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None: self.queues: dict[str, asyncio.Queue] = {} """StreamID 到输入队列的映射 - 用于接收用户消息""" @@ -21,6 +22,13 @@ def __init__(self) -> None: self.pending_responses: dict[str, dict[str, Any]] = {} """待处理的响应缓存,用于流式响应""" + self.completed_streams: dict[str, float] = {} + """已结束的 stream 缓存,用于兼容平台后续重复轮询""" + self._queue_close_events: dict[str, asyncio.Event] = {} + self._listener_tasks: dict[str, asyncio.Task] = {} + self._listener_callback: Callable[[dict], Awaitable[None]] | None = None + self.queue_maxsize = queue_maxsize + self.back_queue_maxsize = back_queue_maxsize def get_or_create_queue(self, session_id: str) -> asyncio.Queue: """获取或创建指定会话的输入队列 @@ -33,7 +41,9 @@ def get_or_create_queue(self, session_id: str) -> asyncio.Queue: """ if session_id not in self.queues: - self.queues[session_id] = asyncio.Queue() + self.queues[session_id] = asyncio.Queue(maxsize=self.queue_maxsize) + self._queue_close_events[session_id] = asyncio.Event() + self._start_listener_if_needed(session_id) logger.debug(f"[WecomAI] 创建输入队列: {session_id}") return self.queues[session_id] @@ -48,20 +58,21 @@ def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue: """ if session_id not in self.back_queues: - self.back_queues[session_id] = asyncio.Queue() + self.back_queues[session_id] = asyncio.Queue( + maxsize=self.back_queue_maxsize + ) logger.debug(f"[WecomAI] 创建输出队列: {session_id}") return self.back_queues[session_id] - def remove_queues(self, session_id: str): + def remove_queues(self, session_id: str, mark_finished: bool = False): """移除指定会话的所有队列 Args: session_id: 会话ID + mark_finished: 是否标记为已正常结束 """ - if session_id in self.queues: - del self.queues[session_id] - logger.debug(f"[WecomAI] 移除输入队列: {session_id}") + self.remove_queue(session_id) if session_id in self.back_queues: del self.back_queues[session_id] @@ -70,6 +81,23 @@ def remove_queues(self, session_id: str): if session_id in self.pending_responses: del self.pending_responses[session_id] logger.debug(f"[WecomAI] 移除待处理响应: {session_id}") + if mark_finished: + self.completed_streams[session_id] = asyncio.get_event_loop().time() + logger.debug(f"[WecomAI] 标记流已结束: {session_id}") + + def remove_queue(self, session_id: str): + """仅移除输入队列和对应监听任务""" + if session_id in self.queues: + del self.queues[session_id] + logger.debug(f"[WecomAI] 移除输入队列: {session_id}") + + close_event = self._queue_close_events.pop(session_id, None) + if close_event is not None: + close_event.set() + + task = self._listener_tasks.pop(session_id, None) + if task is not None: + task.cancel() def has_queue(self, session_id: str) -> bool: """检查是否存在指定会话的队列 @@ -121,6 +149,20 @@ def get_pending_response(self, session_id: str) -> dict[str, Any] | None: """ return self.pending_responses.get(session_id) + def is_stream_finished( + self, + session_id: str, + max_age_seconds: int = 60, + ) -> bool: + """判断 stream 是否在短期内已结束""" + finished_at = self.completed_streams.get(session_id) + if finished_at is None: + return False + if asyncio.get_event_loop().time() - finished_at > max_age_seconds: + self.completed_streams.pop(session_id, None) + return False + return True + def cleanup_expired_responses(self, max_age_seconds: int = 300): """清理过期的待处理响应 @@ -136,8 +178,75 @@ def cleanup_expired_responses(self, max_age_seconds: int = 300): expired_sessions.append(session_id) for session_id in expired_sessions: - del self.pending_responses[session_id] - logger.debug(f"[WecomAI] 清理过期响应: {session_id}") + self.remove_queues(session_id) + logger.debug(f"[WecomAI] 清理过期响应及队列: {session_id}") + expired_finished = [ + session_id + for session_id, finished_at in self.completed_streams.items() + if current_time - finished_at > 60 + ] + for session_id in expired_finished: + self.completed_streams.pop(session_id, None) + + def set_listener( + self, + callback: Callable[[dict], Awaitable[None]], + ): + self._listener_callback = callback + for session_id in list(self.queues.keys()): + self._start_listener_if_needed(session_id) + + def _start_listener_if_needed(self, session_id: str): + if self._listener_callback is None: + return + if session_id in self._listener_tasks: + task = self._listener_tasks[session_id] + if not task.done(): + return + queue = self.queues.get(session_id) + close_event = self._queue_close_events.get(session_id) + if queue is None or close_event is None: + return + task = asyncio.create_task( + self._listen_to_queue(session_id, queue, close_event), + name=f"wecomai_listener_{session_id}", + ) + self._listener_tasks[session_id] = task + task.add_done_callback(lambda _: self._listener_tasks.pop(session_id, None)) + logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}") + + async def _listen_to_queue( + self, + session_id: str, + queue: asyncio.Queue, + close_event: asyncio.Event, + ): + while True: + get_task = asyncio.create_task(queue.get()) + close_task = asyncio.create_task(close_event.wait()) + try: + done, pending = await asyncio.wait( + {get_task, close_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + if close_task in done: + break + data = get_task.result() + if self._listener_callback is None: + continue + try: + await self._listener_callback(data) + except Exception as e: + logger.error(f"处理会话 {session_id} 消息时发生错误: {e}") + except asyncio.CancelledError: + break + finally: + if not get_task.done(): + get_task.cancel() + if not close_task.done(): + close_task.cancel() def get_stats(self) -> dict[str, int]: """获取队列统计信息 diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 696ffe6139..2fc74af0ff 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -354,12 +354,12 @@ async def chat(self): return Response().error("session_id is empty").__dict__ webchat_conv_id = session_id - back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) # 构建用户消息段(包含 path 用于传递给 adapter) message_parts = await self._build_user_message_parts(message) message_id = str(uuid.uuid4()) + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id) async def stream(): client_disconnected = False @@ -532,6 +532,8 @@ async def stream(): refs = {} except BaseException as e: logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) + finally: + webchat_queue_mgr.remove_back_queue(message_id) # 将消息放入会话特定的队列 chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id) diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 0c3ddcc2e1..22bad47033 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -256,143 +256,148 @@ async def _process_audio( await queue.put((session.username, cid, payload)) # 3. 等待响应并流式发送 TTS 音频 - back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id) bot_text = "" audio_playing = False - while True: - if session.should_interrupt: - # 用户打断,停止处理 - logger.info("[Live Chat] 检测到用户打断") - await websocket.send_json({"t": "stop_play"}) - # 保存消息并标记为被打断 - await self._save_interrupted_message(session, user_text, bot_text) - # 清空队列中未处理的消息 - while not back_queue.empty(): - try: - back_queue.get_nowait() - except asyncio.QueueEmpty: - break - break - - try: - result = await asyncio.wait_for(back_queue.get(), timeout=0.5) - except asyncio.TimeoutError: - continue - - if not result: - continue - - result_message_id = result.get("message_id") - if result_message_id != message_id: - logger.warning( - f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}" - ) - continue - - result_type = result.get("type") - result_chain_type = result.get("chain_type") - data = result.get("data", "") - - if result_chain_type == "agent_stats": - try: - stats = json.loads(data) - await websocket.send_json( - { - "t": "metrics", - "data": { - "llm_ttft": stats.get("time_to_first_token", 0), - "llm_total_time": stats.get("end_time", 0) - - stats.get("start_time", 0), - }, - } + try: + while True: + if session.should_interrupt: + # 用户打断,停止处理 + logger.info("[Live Chat] 检测到用户打断") + await websocket.send_json({"t": "stop_play"}) + # 保存消息并标记为被打断 + await self._save_interrupted_message( + session, user_text, bot_text ) - except Exception as e: - logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}") - continue + # 清空队列中未处理的消息 + while not back_queue.empty(): + try: + back_queue.get_nowait() + except asyncio.QueueEmpty: + break + break - if result_chain_type == "tts_stats": try: - stats = json.loads(data) - await websocket.send_json( - { - "t": "metrics", - "data": stats, - } - ) - except Exception as e: - logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}") - continue - - if result_type == "plain": - # 普通文本消息 - bot_text += data - - elif result_type == "audio_chunk": - # 流式音频数据 - if not audio_playing: - audio_playing = True - logger.debug("[Live Chat] 开始播放音频流") - - # Calculate latency from wav assembly finish to first audio chunk - speak_to_first_frame_latency = ( - time.time() - wav_assembly_finish_time - ) - await websocket.send_json( - { - "t": "metrics", - "data": { - "speak_to_first_frame": speak_to_first_frame_latency - }, - } + result = await asyncio.wait_for(back_queue.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + + if not result: + continue + + result_message_id = result.get("message_id") + if result_message_id != message_id: + logger.warning( + f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}" ) + continue + + result_type = result.get("type") + result_chain_type = result.get("chain_type") + data = result.get("data", "") - text = result.get("text") - if text: + if result_chain_type == "agent_stats": + try: + stats = json.loads(data) + await websocket.send_json( + { + "t": "metrics", + "data": { + "llm_ttft": stats.get("time_to_first_token", 0), + "llm_total_time": stats.get("end_time", 0) + - stats.get("start_time", 0), + }, + } + ) + except Exception as e: + logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}") + continue + + if result_chain_type == "tts_stats": + try: + stats = json.loads(data) + await websocket.send_json( + { + "t": "metrics", + "data": stats, + } + ) + except Exception as e: + logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}") + continue + + if result_type == "plain": + # 普通文本消息 + bot_text += data + + elif result_type == "audio_chunk": + # 流式音频数据 + if not audio_playing: + audio_playing = True + logger.debug("[Live Chat] 开始播放音频流") + + # Calculate latency from wav assembly finish to first audio chunk + speak_to_first_frame_latency = ( + time.time() - wav_assembly_finish_time + ) + await websocket.send_json( + { + "t": "metrics", + "data": { + "speak_to_first_frame": speak_to_first_frame_latency + }, + } + ) + + text = result.get("text") + if text: + await websocket.send_json( + { + "t": "bot_text_chunk", + "data": {"text": text}, + } + ) + + # 发送音频数据给前端 await websocket.send_json( { - "t": "bot_text_chunk", - "data": {"text": text}, + "t": "response", + "data": data, # base64 编码的音频数据 } ) - # 发送音频数据给前端 - await websocket.send_json( - { - "t": "response", - "data": data, # base64 编码的音频数据 - } - ) - - elif result_type in ["complete", "end"]: - # 处理完成 - logger.info(f"[Live Chat] Bot 回复完成: {bot_text}") - - # 如果没有音频流,发送 bot 消息文本 - if not audio_playing: + elif result_type in ["complete", "end"]: + # 处理完成 + logger.info(f"[Live Chat] Bot 回复完成: {bot_text}") + + # 如果没有音频流,发送 bot 消息文本 + if not audio_playing: + await websocket.send_json( + { + "t": "bot_msg", + "data": { + "text": bot_text, + "ts": int(time.time() * 1000), + }, + } + ) + + # 发送结束标记 + await websocket.send_json({"t": "end"}) + + # 发送总耗时 + wav_to_tts_duration = time.time() - wav_assembly_finish_time await websocket.send_json( { - "t": "bot_msg", - "data": { - "text": bot_text, - "ts": int(time.time() * 1000), - }, + "t": "metrics", + "data": {"wav_to_tts_total_time": wav_to_tts_duration}, } ) - - # 发送结束标记 - await websocket.send_json({"t": "end"}) - - # 发送总耗时 - wav_to_tts_duration = time.time() - wav_assembly_finish_time - await websocket.send_json( - { - "t": "metrics", - "data": {"wav_to_tts_total_time": wav_to_tts_duration}, - } - ) - break + break + finally: + webchat_queue_mgr.remove_back_queue(message_id) except Exception as e: logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True) From a0a39372bc72fd83bc0e4ee40224948c0d4ce7df Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 7 Feb 2026 14:01:46 +0800 Subject: [PATCH 2/2] perf: enhance webchat back queue management with conversation ID support --- .../platform/sources/webchat/webchat_event.py | 12 +++++++-- .../sources/webchat/webchat_queue_mgr.py | 25 ++++++++++++++++++- astrbot/dashboard/routes/chat.py | 5 +++- astrbot/dashboard/routes/live_chat.py | 2 +- 4 files changed, 39 insertions(+), 5 deletions(-) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 82c1a7580e..0dcc9cc0c4 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -27,7 +27,11 @@ async def _send( streaming: bool = False, ) -> str | None: request_id = str(message_id) - web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(request_id) + conversation_id = session_id.split("!")[-1] + web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( + request_id, + conversation_id, + ) if not message: await web_chat_back_queue.put( { @@ -126,7 +130,11 @@ async def send_streaming(self, generator, use_fallback: bool = False): reasoning_content = "" message_id = self.message_obj.message_id request_id = str(message_id) - web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(request_id) + conversation_id = self.session_id.split("!")[-1] + web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( + request_id, + conversation_id, + ) async for chain in generator: # 处理音频流(Live Mode) if chain.type == "audio_chunk": diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index 6acf61b730..c7636faacd 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -10,6 +10,8 @@ def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> N """Conversation ID to asyncio.Queue mapping""" self.back_queues: dict[str, asyncio.Queue] = {} """Request ID to asyncio.Queue mapping for responses""" + self._conversation_back_requests: dict[str, set[str]] = {} + self._request_conversation: dict[str, str] = {} self._queue_close_events: dict[str, asyncio.Event] = {} self._listener_tasks: dict[str, asyncio.Task] = {} self._listener_callback: Callable[[tuple], Awaitable[None]] | None = None @@ -24,20 +26,41 @@ def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue: self._start_listener_if_needed(conversation_id) return self.queues[conversation_id] - def get_or_create_back_queue(self, request_id: str) -> asyncio.Queue: + def get_or_create_back_queue( + self, + request_id: str, + conversation_id: str | None = None, + ) -> asyncio.Queue: """Get or create a back queue for the given request ID""" if request_id not in self.back_queues: self.back_queues[request_id] = asyncio.Queue( maxsize=self.back_queue_maxsize ) + if conversation_id: + self._request_conversation[request_id] = conversation_id + if conversation_id not in self._conversation_back_requests: + self._conversation_back_requests[conversation_id] = set() + self._conversation_back_requests[conversation_id].add(request_id) return self.back_queues[request_id] def remove_back_queue(self, request_id: str): """Remove back queue for the given request ID""" self.back_queues.pop(request_id, None) + conversation_id = self._request_conversation.pop(request_id, None) + if conversation_id: + request_ids = self._conversation_back_requests.get(conversation_id) + if request_ids is not None: + request_ids.discard(request_id) + if not request_ids: + self._conversation_back_requests.pop(conversation_id, None) def remove_queues(self, conversation_id: str): """Remove queues for the given conversation ID""" + for request_id in list( + self._conversation_back_requests.get(conversation_id, set()) + ): + self.remove_back_queue(request_id) + self._conversation_back_requests.pop(conversation_id, None) self.remove_queue(conversation_id) def remove_queue(self, conversation_id: str): diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 2fc74af0ff..55b279fe11 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -359,7 +359,10 @@ async def chat(self): message_parts = await self._build_user_message_parts(message) message_id = str(uuid.uuid4()) - back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id) + back_queue = webchat_queue_mgr.get_or_create_back_queue( + message_id, + webchat_conv_id, + ) async def stream(): client_disconnected = False diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 22bad47033..b6336a7974 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -256,7 +256,7 @@ async def _process_audio( await queue.put((session.username, cid, payload)) # 3. 等待响应并流式发送 TTS 音频 - back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id) + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, cid) bot_text = "" audio_playing = False