From 53ae8cd7cf37fb3ed3194b4014a3366d907c2f19 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 24 Feb 2026 21:14:16 +0800 Subject: [PATCH 1/3] feat: implement websockets transport mode selection for chat - Added transport mode selection (SSE/WebSocket) in the chat component. - Updated conversation sidebar to include transport mode options. - Integrated transport mode handling in message sending logic. - Refactored message sending functions to support both SSE and WebSocket. - Enhanced WebSocket connection management and message handling. - Updated localization files for transport mode labels. - Configured Vite to support WebSocket proxying. --- .../sources/webchat/message_parts_helper.py | 333 ++++++ .../sources/webchat/webchat_adapter.py | 85 +- .../platform/sources/webchat/webchat_event.py | 13 +- .../sources/webchat/webchat_queue_mgr.py | 4 + astrbot/dashboard/routes/chat.py | 118 +- astrbot/dashboard/routes/live_chat.py | 512 +++++++- astrbot/dashboard/routes/open_api.py | 89 +- dashboard/src/components/chat/Chat.vue | 8 +- .../components/chat/ConversationSidebar.vue | 37 + dashboard/src/composables/useMessages.ts | 1028 ++++++++++++----- .../src/i18n/locales/en-US/features/chat.json | 11 +- .../src/i18n/locales/zh-CN/features/chat.json | 5 + dashboard/vite.config.ts | 1 + 13 files changed, 1757 insertions(+), 487 deletions(-) create mode 100644 astrbot/core/platform/sources/webchat/message_parts_helper.py diff --git a/astrbot/core/platform/sources/webchat/message_parts_helper.py b/astrbot/core/platform/sources/webchat/message_parts_helper.py new file mode 100644 index 0000000000..608e3448a2 --- /dev/null +++ b/astrbot/core/platform/sources/webchat/message_parts_helper.py @@ -0,0 +1,333 @@ +import json +import mimetypes +import shutil +import uuid +from collections.abc import Awaitable, Callable, Sequence +from pathlib import Path + +from astrbot.core.db.po import Attachment +from astrbot.core.message.components import ( + File, + Image, + Json, + Plain, + Record, + Reply, + Video, +) +from astrbot.core.message.message_event_result import MessageChain + +AttachmentGetter = Callable[[str], Awaitable[Attachment | None]] +AttachmentInserter = Callable[[str, str, str], Awaitable[Attachment | None]] + +MEDIA_PART_TYPES = {"image", "record", "file", "video"} + + +def strip_message_parts_path_fields(message_parts: list[dict]) -> list[dict]: + return [{k: v for k, v in part.items() if k != "path"} for part in message_parts] + + +def webchat_message_parts_have_content(message_parts: list[dict]) -> bool: + return any( + part.get("type") in ("plain", "image", "record", "file", "video") + and (part.get("text") or part.get("attachment_id") or part.get("filename")) + for part in message_parts + ) + + +async def build_webchat_message_parts( + message_payload: str | list, + *, + get_attachment_by_id: AttachmentGetter, + strict: bool = False, +) -> list[dict]: + if isinstance(message_payload, str): + text = message_payload.strip() + return [{"type": "plain", "text": text}] if text else [] + + if not isinstance(message_payload, list): + if strict: + raise ValueError("message must be a string or list") + return [] + + message_parts: list[dict] = [] + for part in message_payload: + if not isinstance(part, dict): + if strict: + raise ValueError("message part must be an object") + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "plain": + text = str(part.get("text", "")) + if text: + message_parts.append({"type": "plain", "text": text}) + continue + + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + if strict: + raise ValueError("reply part missing message_id") + continue + message_parts.append( + { + "type": "reply", + "message_id": message_id, + "selected_text": str(part.get("selected_text", "")), + } + ) + continue + + if part_type not in MEDIA_PART_TYPES: + if strict: + raise ValueError(f"unsupported message part type: {part_type}") + continue + + attachment_id = part.get("attachment_id") + if not attachment_id: + if strict: + raise ValueError(f"{part_type} part missing attachment_id") + continue + + attachment = await get_attachment_by_id(str(attachment_id)) + if not attachment: + if strict: + raise ValueError(f"attachment not found: {attachment_id}") + continue + + attachment_path = Path(attachment.path) + message_parts.append( + { + "type": attachment.type, + "attachment_id": attachment.attachment_id, + "filename": attachment_path.name, + "path": str(attachment_path), + } + ) + + return message_parts + + +def webchat_message_parts_to_message_chain( + message_parts: list[dict], + *, + strict: bool = False, +) -> MessageChain: + components = [] + has_content = False + + for part in message_parts: + if not isinstance(part, dict): + if strict: + raise ValueError("message part must be an object") + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "plain": + text = str(part.get("text", "")) + if text: + components.append(Plain(text=text)) + has_content = True + continue + + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + if strict: + raise ValueError("reply part missing message_id") + continue + components.append( + Reply( + id=str(message_id), + message_str=str(part.get("selected_text", "")), + chain=[], + ) + ) + continue + + if part_type not in MEDIA_PART_TYPES: + if strict: + raise ValueError(f"unsupported message part type: {part_type}") + continue + + path = part.get("path") + if not path: + if strict: + raise ValueError(f"{part_type} part missing path") + continue + + file_path = Path(str(path)) + if not file_path.exists(): + if strict: + raise ValueError(f"file not found: {file_path!s}") + continue + + file_path_str = str(file_path.resolve()) + has_content = True + if part_type == "image": + components.append(Image.fromFileSystem(file_path_str)) + elif part_type == "record": + components.append(Record.fromFileSystem(file_path_str)) + elif part_type == "video": + components.append(Video.fromFileSystem(file_path_str)) + else: + filename = str(part.get("filename", "")).strip() or file_path.name + components.append(File(name=filename, file=file_path_str)) + + if strict and (not components or not has_content): + raise ValueError("Message content is empty (reply only is not allowed)") + + return MessageChain(chain=components) + + +async def build_message_chain_from_payload( + message_payload: str | list, + *, + get_attachment_by_id: AttachmentGetter, + strict: bool = True, +) -> MessageChain: + message_parts = await build_webchat_message_parts( + message_payload, + get_attachment_by_id=get_attachment_by_id, + strict=strict, + ) + return webchat_message_parts_to_message_chain(message_parts, strict=strict) + + +async def create_attachment_part_from_existing_file( + filename: str, + *, + attach_type: str, + insert_attachment: AttachmentInserter, + attachments_dir: str | Path, + fallback_dirs: Sequence[str | Path] = (), +) -> dict | None: + basename = Path(filename).name + candidate_paths = [Path(attachments_dir) / basename] + candidate_paths.extend(Path(p) / basename for p in fallback_dirs) + + file_path = next((path for path in candidate_paths if path.exists()), None) + if not file_path: + return None + + mime_type, _ = mimetypes.guess_type(str(file_path)) + attachment = await insert_attachment( + str(file_path), + attach_type, + mime_type or "application/octet-stream", + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": file_path.name, + } + + +async def message_chain_to_storage_message_parts( + message_chain: MessageChain, + *, + insert_attachment: AttachmentInserter, + attachments_dir: str | Path, +) -> list[dict]: + target_dir = Path(attachments_dir) + target_dir.mkdir(parents=True, exist_ok=True) + + parts: list[dict] = [] + for comp in message_chain.chain: + if isinstance(comp, Plain): + if comp.text: + parts.append({"type": "plain", "text": comp.text}) + continue + + if isinstance(comp, Json): + parts.append( + {"type": "plain", "text": json.dumps(comp.data, ensure_ascii=False)} + ) + continue + + if isinstance(comp, Image): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="image", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, Record): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="record", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, Video): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="video", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, File): + file_path = await comp.get_file() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="file", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + display_name=comp.name, + ) + if attachment_part: + parts.append(attachment_part) + continue + + return parts + + +async def _copy_file_to_attachment_part( + *, + file_path: str, + attach_type: str, + insert_attachment: AttachmentInserter, + attachments_dir: Path, + display_name: str | None = None, +) -> dict | None: + src_path = Path(file_path) + if not src_path.exists() or not src_path.is_file(): + return None + + suffix = src_path.suffix + target_path = attachments_dir / f"{uuid.uuid4().hex}{suffix}" + shutil.copy2(src_path, target_path) + + mime_type, _ = mimetypes.guess_type(target_path.name) + attachment = await insert_attachment( + str(target_path), + attach_type, + mime_type or "application/octet-stream", + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": display_name or src_path.name, + } diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 047417aaaa..e72594d8a0 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -3,12 +3,20 @@ import time import uuid from collections.abc import Callable, Coroutine +from pathlib import Path from typing import Any from astrbot import logger from astrbot.core import db_helper from astrbot.core.db.po import PlatformMessageHistory -from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video +from astrbot.core.message.components import ( + File, + Image, + Plain, + Record, + Reply, + Video, +) from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform import ( AstrBotMessage, @@ -21,10 +29,20 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path from ...register import register_platform_adapter +from .message_parts_helper import message_chain_to_storage_message_parts from .webchat_event import WebChatMessageEvent from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr +def _extract_conversation_id(session_id: str) -> str: + """Extract raw webchat conversation id from event/session id.""" + if session_id.startswith("webchat!"): + parts = session_id.split("!", 2) + if len(parts) == 3: + return parts[2] + return session_id + + class QueueListener: def __init__( self, @@ -57,13 +75,15 @@ def __init__( self.settings = platform_settings self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + self.attachments_dir = Path(get_astrbot_data_path()) / "attachments" os.makedirs(self.imgs_dir, exist_ok=True) + self.attachments_dir.mkdir(parents=True, exist_ok=True) self.metadata = PlatformMetadata( name="webchat", description="webchat", id="webchat", - support_proactive_message=False, + support_proactive_message=True, ) self._shutdown_event = asyncio.Event() self._webchat_queue_mgr = webchat_queue_mgr @@ -73,10 +93,67 @@ async def send_by_session( session: MessageSesion, message_chain: MessageChain, ) -> None: - message_id = f"active_{str(uuid.uuid4())}" - await WebChatMessageEvent._send(message_id, message_chain, session.session_id) + conversation_id = _extract_conversation_id(session.session_id) + active_request_ids = self._webchat_queue_mgr.list_back_request_ids( + conversation_id + ) + subscription_request_ids = [ + req_id for req_id in active_request_ids if req_id.startswith("ws_sub_") + ] + target_request_ids = subscription_request_ids or active_request_ids + + if target_request_ids: + for request_id in target_request_ids: + await WebChatMessageEvent._send( + request_id, + message_chain, + session.session_id, + ) + else: + message_id = f"active_{uuid.uuid4()!s}" + await WebChatMessageEvent._send( + message_id, + message_chain, + session.session_id, + ) + + should_persist = ( + bool(subscription_request_ids) + or not active_request_ids + or all(req_id.startswith("active_") for req_id in active_request_ids) + ) + if should_persist: + try: + await self._save_proactive_message(conversation_id, message_chain) + except Exception as e: + logger.error( + f"[WebChatAdapter] Failed to save proactive message: {e}", + exc_info=True, + ) + await super().send_by_session(session, message_chain) + async def _save_proactive_message( + self, + conversation_id: str, + message_chain: MessageChain, + ) -> None: + message_parts = await message_chain_to_storage_message_parts( + message_chain, + insert_attachment=db_helper.insert_attachment, + attachments_dir=self.attachments_dir, + ) + if not message_parts: + return + + await db_helper.insert_platform_message_history( + platform_id="webchat", + user_id=conversation_id, + content={"type": "bot", "message": message_parts}, + sender_id="bot", + sender_name="bot", + ) + async def _get_message_history( self, message_id: int ) -> PlatformMessageHistory | None: diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index a680f76174..b7da864aae 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -14,6 +14,15 @@ attachments_dir = os.path.join(get_astrbot_data_path(), "attachments") +def _extract_conversation_id(session_id: str) -> str: + """Extract raw webchat conversation id from event/session id.""" + if session_id.startswith("webchat!"): + parts = session_id.split("!", 2) + if len(parts) == 3: + return parts[2] + return session_id + + class WebChatMessageEvent(AstrMessageEvent): def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) @@ -27,7 +36,7 @@ async def _send( streaming: bool = False, ) -> str | None: request_id = str(message_id) - conversation_id = session_id.split("!")[-1] + conversation_id = _extract_conversation_id(session_id) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( request_id, conversation_id, @@ -130,7 +139,7 @@ async def send_streaming(self, generator, use_fallback: bool = False) -> None: reasoning_content = "" message_id = self.message_obj.message_id request_id = str(message_id) - conversation_id = self.session_id.split("!")[-1] + conversation_id = _extract_conversation_id(self.session_id) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( request_id, conversation_id, diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index fd35e837c8..f3ade1589a 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -75,6 +75,10 @@ def remove_queue(self, conversation_id: str): if task is not None: task.cancel() + def list_back_request_ids(self, conversation_id: str) -> list[str]: + """List active back-queue request IDs for a conversation.""" + return list(self._conversation_back_requests.get(conversation_id, set())) + def has_queue(self, conversation_id: str) -> bool: """Check if a queue exists for the given conversation ID""" return conversation_id in self.queues diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 1235dd3814..0602cc0745 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -1,6 +1,5 @@ import asyncio import json -import mimetypes import os import re import uuid @@ -14,6 +13,12 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_webchat_message_parts, + create_attachment_part_from_existing_file, + strip_message_parts_path_fields, + webchat_message_parts_have_content, +) from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.utils.active_event_registry import active_event_registry from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -166,83 +171,24 @@ async def post_file(self): ) async def _build_user_message_parts(self, message: str | list) -> list[dict]: - """构建用户消息的部分列表 - - Args: - message: 文本消息 (str) 或消息段列表 (list) - """ - parts = [] - - if isinstance(message, list): - for part in message: - part_type = part.get("type") - if part_type == "plain": - parts.append({"type": "plain", "text": part.get("text", "")}) - elif part_type == "reply": - parts.append( - { - "type": "reply", - "message_id": part.get("message_id"), - "selected_text": part.get("selected_text", ""), - } - ) - elif attachment_id := part.get("attachment_id"): - attachment = await self.db.get_attachment_by_id(attachment_id) - if attachment: - parts.append( - { - "type": attachment.type, - "attachment_id": attachment.attachment_id, - "filename": os.path.basename(attachment.path), - "path": attachment.path, # will be deleted - } - ) - return parts - - if message: - parts.append({"type": "plain", "text": message}) - - return parts + """构建用户消息的部分列表。""" + return await build_webchat_message_parts( + message, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=False, + ) async def _create_attachment_from_file( self, filename: str, attach_type: str ) -> dict | None: - """从本地文件创建 attachment 并返回消息部分 - - 用于处理 bot 回复中的媒体文件 - - Args: - filename: 存储的文件名 - attach_type: 附件类型 (image, record, file, video) - """ - basename = os.path.basename(filename) - candidate_paths = [ - os.path.join(self.attachments_dir, basename), - os.path.join(self.legacy_img_dir, basename), - ] - file_path = next((p for p in candidate_paths if os.path.exists(p)), None) - if not file_path: - return None - - # guess mime type - mime_type, _ = mimetypes.guess_type(filename) - if not mime_type: - mime_type = "application/octet-stream" - - # insert attachment - attachment = await self.db.insert_attachment( - path=file_path, - type=attach_type, - mime_type=mime_type, + """从本地文件创建 attachment 并返回消息部分。""" + return await create_attachment_part_from_existing_file( + filename, + attach_type=attach_type, + insert_attachment=self.db.insert_attachment, + attachments_dir=self.attachments_dir, + fallback_dirs=[self.legacy_img_dir], ) - if not attachment: - return None - - return { - "type": attach_type, - "attachment_id": attachment.attachment_id, - "filename": os.path.basename(file_path), - } def _extract_web_search_refs( self, accumulated_text: str, accumulated_parts: list @@ -356,21 +302,6 @@ async def chat(self, post_data: dict | None = None): selected_model = post_data.get("selected_model") enable_streaming = post_data.get("enable_streaming", True) - # 检查消息是否为空 - if isinstance(message, list): - has_content = any( - part.get("type") in ("plain", "image", "record", "file", "video") - for part in message - ) - if not has_content: - return ( - Response() - .error("Message content is empty (reply only is not allowed)") - .__dict__ - ) - elif not message: - return Response().error("Message are both empty").__dict__ - if not session_id: return Response().error("session_id is empty").__dict__ @@ -378,6 +309,12 @@ async def chat(self, post_data: dict | None = None): # 构建用户消息段(包含 path 用于传递给 adapter) message_parts = await self._build_user_message_parts(message) + if not webchat_message_parts_have_content(message_parts): + return ( + Response() + .error("Message content is empty (reply only is not allowed)") + .__dict__ + ) message_id = str(uuid.uuid4()) back_queue = webchat_queue_mgr.get_or_create_back_queue( @@ -583,10 +520,7 @@ async def stream(): ), ) - message_parts_for_storage = [] - for part in message_parts: - part_copy = {k: v for k, v in part.items() if k != "path"} - message_parts_for_storage.append(part_copy) + message_parts_for_storage = strip_message_parts_path_fields(message_parts) await self.platform_history_mgr.insert( platform_id="webchat", diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 8c922ab69a..25438565e1 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -1,6 +1,7 @@ import asyncio import json import os +import re import time import uuid import wave @@ -10,9 +11,16 @@ from quart import websocket from astrbot import logger +from astrbot.core import sp from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_webchat_message_parts, + create_attachment_part_from_existing_file, + strip_message_parts_path_fields, + webchat_message_parts_have_content, +) from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path from .route import Route, RouteContext @@ -30,6 +38,9 @@ def __init__(self, session_id: str, username: str) -> None: self.audio_frames: list[bytes] = [] self.current_stamp: str | None = None self.temp_audio_path: str | None = None + self.chat_subscriptions: dict[str, str] = {} + self.chat_subscription_tasks: dict[str, asyncio.Task] = {} + self.ws_send_lock = asyncio.Lock() def start_speaking(self, stamp: str) -> None: """开始说话""" @@ -106,13 +117,26 @@ def __init__( self.core_lifecycle = core_lifecycle self.db = db self.plugin_manager = core_lifecycle.plugin_manager + self.platform_history_mgr = core_lifecycle.platform_message_history_manager self.sessions: dict[str, LiveChatSession] = {} + self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments") + self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + os.makedirs(self.attachments_dir, exist_ok=True) # 注册 WebSocket 路由 self.app.websocket("/api/live_chat/ws")(self.live_chat_ws) + self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws) async def live_chat_ws(self) -> None: - """Live Chat WebSocket 处理器""" + """Legacy Live Chat WebSocket 处理器(默认 ct=live)""" + await self._unified_ws_loop(force_ct="live") + + async def unified_chat_ws(self) -> None: + """Unified Chat WebSocket 处理器(支持 ct=live/chat)""" + await self._unified_ws_loop(force_ct=None) + + async def _unified_ws_loop(self, force_ct: str | None = None) -> None: + """统一 WebSocket 循环""" # WebSocket 不能通过 header 传递 token,需要从 query 参数获取 # 注意:WebSocket 上下文使用 websocket.args 而不是 request.args token = websocket.args.get("token") @@ -140,7 +164,11 @@ async def live_chat_ws(self) -> None: try: while True: message = await websocket.receive_json() - await self._handle_message(live_session, message) + ct = force_ct or message.get("ct", "live") + if ct == "chat": + await self._handle_chat_message(live_session, message) + else: + await self._handle_message(live_session, message) except Exception as e: logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True) @@ -148,10 +176,488 @@ async def live_chat_ws(self) -> None: finally: # 清理会话 if session_id in self.sessions: + await self._cleanup_chat_subscriptions(live_session) live_session.cleanup() del self.sessions[session_id] logger.info(f"[Live Chat] WebSocket 连接关闭: {username}") + async def _create_attachment_from_file( + self, filename: str, attach_type: str + ) -> dict | None: + """从本地文件创建 attachment 并返回消息部分。""" + return await create_attachment_part_from_existing_file( + filename, + attach_type=attach_type, + insert_attachment=self.db.insert_attachment, + attachments_dir=self.attachments_dir, + fallback_dirs=[self.legacy_img_dir], + ) + + def _extract_web_search_refs( + self, accumulated_text: str, accumulated_parts: list + ) -> dict: + """从消息中提取 web_search 引用。""" + supported = ["web_search_tavily", "web_search_bocha"] + web_search_results = {} + tool_call_parts = [ + p + for p in accumulated_parts + if p.get("type") == "tool_call" and p.get("tool_calls") + ] + + for part in tool_call_parts: + for tool_call in part["tool_calls"]: + if tool_call.get("name") not in supported or not tool_call.get( + "result" + ): + continue + try: + result_data = json.loads(tool_call["result"]) + for item in result_data.get("results", []): + if idx := item.get("index"): + web_search_results[idx] = { + "url": item.get("url"), + "title": item.get("title"), + "snippet": item.get("snippet"), + } + except (json.JSONDecodeError, KeyError): + pass + + if not web_search_results: + return {} + + ref_indices = { + m.strip() for m in re.findall(r"(.*?)", accumulated_text) + } + + used_refs = [] + for ref_index in ref_indices: + if ref_index not in web_search_results: + continue + payload = {"index": ref_index, **web_search_results[ref_index]} + if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]): + payload["favicon"] = favicon + used_refs.append(payload) + + return {"used": used_refs} if used_refs else {} + + async def _save_bot_message( + self, + webchat_conv_id: str, + text: str, + media_parts: list, + reasoning: str, + agent_stats: dict, + refs: dict, + ): + """保存 bot 消息到历史记录。""" + bot_message_parts = [] + bot_message_parts.extend(media_parts) + if text: + bot_message_parts.append({"type": "plain", "text": text}) + + new_his = {"type": "bot", "message": bot_message_parts} + if reasoning: + new_his["reasoning"] = reasoning + if agent_stats: + new_his["agent_stats"] = agent_stats + if refs: + new_his["refs"] = refs + + return await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id="bot", + sender_name="bot", + ) + + async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None: + async with session.ws_send_lock: + await websocket.send_json(payload) + + async def _forward_chat_subscription( + self, + session: LiveChatSession, + chat_session_id: str, + request_id: str, + ) -> None: + back_queue = webchat_queue_mgr.get_or_create_back_queue( + request_id, chat_session_id + ) + try: + while True: + result = await back_queue.get() + if not result: + continue + await self._send_chat_payload(session, {"ct": "chat", **result}) + except asyncio.CancelledError: + pass + except Exception as e: + logger.error( + f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}", + exc_info=True, + ) + finally: + webchat_queue_mgr.remove_back_queue(request_id) + if session.chat_subscriptions.get(chat_session_id) == request_id: + session.chat_subscriptions.pop(chat_session_id, None) + session.chat_subscription_tasks.pop(chat_session_id, None) + + async def _ensure_chat_subscription( + self, + session: LiveChatSession, + chat_session_id: str, + ) -> str: + existing_request_id = session.chat_subscriptions.get(chat_session_id) + existing_task = session.chat_subscription_tasks.get(chat_session_id) + if existing_request_id and existing_task and not existing_task.done(): + return existing_request_id + + request_id = f"ws_sub_{uuid.uuid4().hex}" + session.chat_subscriptions[chat_session_id] = request_id + task = asyncio.create_task( + self._forward_chat_subscription(session, chat_session_id, request_id), + name=f"chat_ws_sub_{chat_session_id}", + ) + session.chat_subscription_tasks[chat_session_id] = task + return request_id + + async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None: + tasks = list(session.chat_subscription_tasks.values()) + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + for request_id in list(session.chat_subscriptions.values()): + webchat_queue_mgr.remove_back_queue(request_id) + session.chat_subscriptions.clear() + session.chat_subscription_tasks.clear() + + async def _handle_chat_message( + self, session: LiveChatSession, message: dict + ) -> None: + """处理 Chat Mode 消息(ct=chat)""" + msg_type = message.get("t") + + if msg_type == "bind": + chat_session_id = message.get("session_id") + if not isinstance(chat_session_id, str) or not chat_session_id: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "session_id is required", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + request_id = await self._ensure_chat_subscription(session, chat_session_id) + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "session_bound", + "session_id": chat_session_id, + "message_id": request_id, + }, + ) + return + + if msg_type == "interrupt": + session.should_interrupt = True + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "INTERRUPTED", + "code": "INTERRUPTED", + }, + ) + return + + if msg_type != "send": + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": f"Unsupported message type: {msg_type}", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + if session.is_processing: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "Session is busy", + "code": "PROCESSING_ERROR", + }, + ) + return + + payload = message.get("message") + session_id = message.get("session_id") or session.session_id + message_id = message.get("message_id") or str(uuid.uuid4()) + selected_provider = message.get("selected_provider") + selected_model = message.get("selected_model") + selected_stt_provider = message.get("selected_stt_provider") + selected_tts_provider = message.get("selected_tts_provider") + persona_prompt = message.get("persona_prompt") + show_reasoning = message.get("show_reasoning") + enable_streaming = message.get("enable_streaming", True) + + if not isinstance(payload, list): + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "message must be list", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + message_parts = await self._build_chat_message_parts(payload) + has_content = webchat_message_parts_have_content(message_parts) + if not has_content: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "Message content is empty", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + await self._ensure_chat_subscription(session, session_id) + + session.is_processing = True + session.should_interrupt = False + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id) + + try: + chat_queue = webchat_queue_mgr.get_or_create_queue(session_id) + await chat_queue.put( + ( + session.username, + session_id, + { + "message": message_parts, + "selected_provider": selected_provider, + "selected_model": selected_model, + "selected_stt_provider": selected_stt_provider, + "selected_tts_provider": selected_tts_provider, + "persona_prompt": persona_prompt, + "show_reasoning": show_reasoning, + "enable_streaming": enable_streaming, + "message_id": message_id, + }, + ), + ) + + message_parts_for_storage = strip_message_parts_path_fields(message_parts) + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=session_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=session.username, + sender_name=session.username, + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + tool_calls = {} + agent_stats = {} + refs = {} + + while True: + if session.should_interrupt: + session.should_interrupt = False + break + + try: + result = await asyncio.wait_for(back_queue.get(), timeout=1) + except asyncio.TimeoutError: + continue + + if not result: + continue + if result.get("message_id") and result.get("message_id") != message_id: + continue + + result_text = result.get("data", "") + msg_type = result.get("type") + streaming = result.get("streaming", False) + chain_type = result.get("chain_type") + if chain_type == "agent_stats": + try: + parsed_agent_stats = json.loads(result_text) + agent_stats = parsed_agent_stats + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "agent_stats", + "data": parsed_agent_stats, + }, + ) + except Exception: + pass + continue + + outgoing = {"ct": "chat", **result} + await self._send_chat_payload(session, outgoing) + + if msg_type == "plain": + if chain_type == "tool_call": + try: + tool_call = json.loads(result_text) + tool_calls[tool_call.get("id")] = tool_call + if accumulated_text: + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + except Exception: + pass + elif chain_type == "tool_call_result": + try: + tcr = json.loads(result_text) + tc_id = tcr.get("id") + if tc_id in tool_calls: + tool_calls[tc_id]["result"] = tcr.get("result") + tool_calls[tc_id]["finished_ts"] = tcr.get("ts") + accumulated_parts.append( + { + "type": "tool_call", + "tool_calls": [tool_calls[tc_id]], + } + ) + tool_calls.pop(tc_id, None) + except Exception: + pass + elif chain_type == "reasoning": + accumulated_reasoning += result_text + elif streaming: + accumulated_text += result_text + else: + accumulated_text = result_text + elif msg_type == "image": + filename = str(result_text).replace("[IMAGE]", "") + part = await self._create_attachment_from_file(filename, "image") + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = str(result_text).replace("[RECORD]", "") + part = await self._create_attachment_from_file(filename, "record") + if part: + accumulated_parts.append(part) + elif msg_type == "file": + filename = str(result_text).replace("[FILE]", "").split("|", 1)[0] + part = await self._create_attachment_from_file(filename, "file") + if part: + accumulated_parts.append(part) + elif msg_type == "video": + filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0] + part = await self._create_attachment_from_file(filename, "video") + if part: + accumulated_parts.append(part) + + should_save = False + if msg_type == "end": + should_save = bool( + accumulated_parts + or accumulated_text + or accumulated_reasoning + or refs + or agent_stats + ) + elif (streaming and msg_type == "complete") or not streaming: + if chain_type not in ( + "tool_call", + "tool_call_result", + "agent_stats", + ): + should_save = True + + if should_save: + try: + refs = self._extract_web_search_refs( + accumulated_text, + accumulated_parts, + ) + except Exception as e: + logger.exception( + f"[Live Chat] Failed to extract web search refs: {e}", + exc_info=True, + ) + + saved_record = await self._save_bot_message( + session_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, + agent_stats, + refs, + ) + if saved_record: + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + }, + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + agent_stats = {} + refs = {} + + if msg_type == "end": + break + + except Exception as e: + logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True) + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": f"处理失败: {str(e)}", + "code": "PROCESSING_ERROR", + }, + ) + finally: + session.is_processing = False + webchat_queue_mgr.remove_back_queue(message_id) + + async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]: + """构建 chat websocket 用户消息段(复用 webchat 逻辑)""" + return await build_webchat_message_parts( + message, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=False, + ) + async def _handle_message(self, session: LiveChatSession, message: dict) -> None: """处理 WebSocket 消息""" msg_type = message.get("t") # 使用 t 代替 type diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index c25870ebb1..055de6732e 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -1,4 +1,3 @@ -from pathlib import Path from uuid import uuid4 from quart import g, request @@ -6,9 +5,10 @@ from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video -from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.message_session import MessageSesion +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_message_chain_from_payload, +) from .chat import ChatRoute from .route import Response, Route, RouteContext @@ -254,83 +254,12 @@ async def get_chat_configs(self): async def _build_message_chain_from_payload( self, message_payload: str | list, - ) -> MessageChain: - if isinstance(message_payload, str): - text = message_payload.strip() - if not text: - raise ValueError("Message is empty") - return MessageChain(chain=[Plain(text=text)]) - - if not isinstance(message_payload, list): - raise ValueError("message must be a string or list") - - components = [] - has_content = False - - for part in message_payload: - if not isinstance(part, dict): - raise ValueError("message part must be an object") - - part_type = str(part.get("type", "")).strip() - if part_type == "plain": - text = str(part.get("text", "")) - if text: - has_content = True - components.append(Plain(text=text)) - continue - - if part_type == "reply": - message_id = part.get("message_id") - if message_id is None: - raise ValueError("reply part missing message_id") - components.append( - Reply( - id=str(message_id), - message_str=str(part.get("selected_text", "")), - chain=[], - ) - ) - continue - - if part_type not in {"image", "record", "file", "video"}: - raise ValueError(f"unsupported message part type: {part_type}") - - has_content = True - file_path: Path | None = None - resolved_type = part_type - filename = str(part.get("filename", "")).strip() - - attachment_id = part.get("attachment_id") - if attachment_id: - attachment = await self.db.get_attachment_by_id(str(attachment_id)) - if not attachment: - raise ValueError(f"attachment not found: {attachment_id}") - file_path = Path(attachment.path) - resolved_type = attachment.type - if not filename: - filename = file_path.name - else: - raise ValueError(f"{part_type} part missing attachment_id") - - if not file_path.exists(): - raise ValueError(f"file not found: {file_path!s}") - - file_path_str = str(file_path.resolve()) - if resolved_type == "image": - components.append(Image.fromFileSystem(file_path_str)) - elif resolved_type == "record": - components.append(Record.fromFileSystem(file_path_str)) - elif resolved_type == "video": - components.append(Video.fromFileSystem(file_path_str)) - else: - components.append( - File(name=filename or file_path.name, file=file_path_str) - ) - - if not components or not has_content: - raise ValueError("Message content is empty (reply only is not allowed)") - - return MessageChain(chain=components) + ): + return await build_message_chain_from_payload( + message_payload, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=True, + ) async def send_message(self): post_data = await request.json or {} diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 803c5d826a..054a186629 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -10,6 +10,7 @@ :selectedSessions="selectedSessions" :currSessionId="currSessionId" :selectedProjectId="selectedProjectId" + :transportMode="transportMode" :isDark="isDark" :chatboxMode="chatboxMode" :isMobile="isMobile" @@ -26,6 +27,7 @@ @createProject="showCreateProjectDialog" @editProject="showEditProjectDialog" @deleteProject="handleDeleteProject" + @updateTransportMode="setTransportMode" /> @@ -301,11 +303,14 @@ const { isStreaming, isConvRunning, enableStreaming, + transportMode, currentSessionProject, getSessionMessages: getSessionMsg, sendMessage: sendMsg, stopMessage: stopMsg, - toggleStreaming + toggleStreaming, + setTransportMode, + cleanupTransport } = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); // 组件引用 @@ -695,6 +700,7 @@ onMounted(() => { onBeforeUnmount(() => { window.removeEventListener('resize', checkMobile); cleanupMediaCache(); + cleanupTransport(); }); diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue index a728930d9f..97f2179e76 100644 --- a/dashboard/src/components/chat/ConversationSidebar.vue +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -117,6 +117,27 @@ {{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }} + + + + {{ tm('transport.title') }} + + +