From 53ae8cd7cf37fb3ed3194b4014a3366d907c2f19 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Tue, 24 Feb 2026 21:14:16 +0800
Subject: [PATCH 1/3] feat: implement websockets transport mode selection for
chat
- Added transport mode selection (SSE/WebSocket) in the chat component.
- Updated conversation sidebar to include transport mode options.
- Integrated transport mode handling in message sending logic.
- Refactored message sending functions to support both SSE and WebSocket.
- Enhanced WebSocket connection management and message handling.
- Updated localization files for transport mode labels.
- Configured Vite to support WebSocket proxying.
---
.../sources/webchat/message_parts_helper.py | 333 ++++++
.../sources/webchat/webchat_adapter.py | 85 +-
.../platform/sources/webchat/webchat_event.py | 13 +-
.../sources/webchat/webchat_queue_mgr.py | 4 +
astrbot/dashboard/routes/chat.py | 118 +-
astrbot/dashboard/routes/live_chat.py | 512 +++++++-
astrbot/dashboard/routes/open_api.py | 89 +-
dashboard/src/components/chat/Chat.vue | 8 +-
.../components/chat/ConversationSidebar.vue | 37 +
dashboard/src/composables/useMessages.ts | 1028 ++++++++++++-----
.../src/i18n/locales/en-US/features/chat.json | 11 +-
.../src/i18n/locales/zh-CN/features/chat.json | 5 +
dashboard/vite.config.ts | 1 +
13 files changed, 1757 insertions(+), 487 deletions(-)
create mode 100644 astrbot/core/platform/sources/webchat/message_parts_helper.py
diff --git a/astrbot/core/platform/sources/webchat/message_parts_helper.py b/astrbot/core/platform/sources/webchat/message_parts_helper.py
new file mode 100644
index 0000000000..608e3448a2
--- /dev/null
+++ b/astrbot/core/platform/sources/webchat/message_parts_helper.py
@@ -0,0 +1,333 @@
+import json
+import mimetypes
+import shutil
+import uuid
+from collections.abc import Awaitable, Callable, Sequence
+from pathlib import Path
+
+from astrbot.core.db.po import Attachment
+from astrbot.core.message.components import (
+ File,
+ Image,
+ Json,
+ Plain,
+ Record,
+ Reply,
+ Video,
+)
+from astrbot.core.message.message_event_result import MessageChain
+
+AttachmentGetter = Callable[[str], Awaitable[Attachment | None]]
+AttachmentInserter = Callable[[str, str, str], Awaitable[Attachment | None]]
+
+MEDIA_PART_TYPES = {"image", "record", "file", "video"}
+
+
+def strip_message_parts_path_fields(message_parts: list[dict]) -> list[dict]:
+ return [{k: v for k, v in part.items() if k != "path"} for part in message_parts]
+
+
+def webchat_message_parts_have_content(message_parts: list[dict]) -> bool:
+ return any(
+ part.get("type") in ("plain", "image", "record", "file", "video")
+ and (part.get("text") or part.get("attachment_id") or part.get("filename"))
+ for part in message_parts
+ )
+
+
+async def build_webchat_message_parts(
+ message_payload: str | list,
+ *,
+ get_attachment_by_id: AttachmentGetter,
+ strict: bool = False,
+) -> list[dict]:
+ if isinstance(message_payload, str):
+ text = message_payload.strip()
+ return [{"type": "plain", "text": text}] if text else []
+
+ if not isinstance(message_payload, list):
+ if strict:
+ raise ValueError("message must be a string or list")
+ return []
+
+ message_parts: list[dict] = []
+ for part in message_payload:
+ if not isinstance(part, dict):
+ if strict:
+ raise ValueError("message part must be an object")
+ continue
+
+ part_type = str(part.get("type", "")).strip()
+ if part_type == "plain":
+ text = str(part.get("text", ""))
+ if text:
+ message_parts.append({"type": "plain", "text": text})
+ continue
+
+ if part_type == "reply":
+ message_id = part.get("message_id")
+ if message_id is None:
+ if strict:
+ raise ValueError("reply part missing message_id")
+ continue
+ message_parts.append(
+ {
+ "type": "reply",
+ "message_id": message_id,
+ "selected_text": str(part.get("selected_text", "")),
+ }
+ )
+ continue
+
+ if part_type not in MEDIA_PART_TYPES:
+ if strict:
+ raise ValueError(f"unsupported message part type: {part_type}")
+ continue
+
+ attachment_id = part.get("attachment_id")
+ if not attachment_id:
+ if strict:
+ raise ValueError(f"{part_type} part missing attachment_id")
+ continue
+
+ attachment = await get_attachment_by_id(str(attachment_id))
+ if not attachment:
+ if strict:
+ raise ValueError(f"attachment not found: {attachment_id}")
+ continue
+
+ attachment_path = Path(attachment.path)
+ message_parts.append(
+ {
+ "type": attachment.type,
+ "attachment_id": attachment.attachment_id,
+ "filename": attachment_path.name,
+ "path": str(attachment_path),
+ }
+ )
+
+ return message_parts
+
+
+def webchat_message_parts_to_message_chain(
+ message_parts: list[dict],
+ *,
+ strict: bool = False,
+) -> MessageChain:
+ components = []
+ has_content = False
+
+ for part in message_parts:
+ if not isinstance(part, dict):
+ if strict:
+ raise ValueError("message part must be an object")
+ continue
+
+ part_type = str(part.get("type", "")).strip()
+ if part_type == "plain":
+ text = str(part.get("text", ""))
+ if text:
+ components.append(Plain(text=text))
+ has_content = True
+ continue
+
+ if part_type == "reply":
+ message_id = part.get("message_id")
+ if message_id is None:
+ if strict:
+ raise ValueError("reply part missing message_id")
+ continue
+ components.append(
+ Reply(
+ id=str(message_id),
+ message_str=str(part.get("selected_text", "")),
+ chain=[],
+ )
+ )
+ continue
+
+ if part_type not in MEDIA_PART_TYPES:
+ if strict:
+ raise ValueError(f"unsupported message part type: {part_type}")
+ continue
+
+ path = part.get("path")
+ if not path:
+ if strict:
+ raise ValueError(f"{part_type} part missing path")
+ continue
+
+ file_path = Path(str(path))
+ if not file_path.exists():
+ if strict:
+ raise ValueError(f"file not found: {file_path!s}")
+ continue
+
+ file_path_str = str(file_path.resolve())
+ has_content = True
+ if part_type == "image":
+ components.append(Image.fromFileSystem(file_path_str))
+ elif part_type == "record":
+ components.append(Record.fromFileSystem(file_path_str))
+ elif part_type == "video":
+ components.append(Video.fromFileSystem(file_path_str))
+ else:
+ filename = str(part.get("filename", "")).strip() or file_path.name
+ components.append(File(name=filename, file=file_path_str))
+
+ if strict and (not components or not has_content):
+ raise ValueError("Message content is empty (reply only is not allowed)")
+
+ return MessageChain(chain=components)
+
+
+async def build_message_chain_from_payload(
+ message_payload: str | list,
+ *,
+ get_attachment_by_id: AttachmentGetter,
+ strict: bool = True,
+) -> MessageChain:
+ message_parts = await build_webchat_message_parts(
+ message_payload,
+ get_attachment_by_id=get_attachment_by_id,
+ strict=strict,
+ )
+ return webchat_message_parts_to_message_chain(message_parts, strict=strict)
+
+
+async def create_attachment_part_from_existing_file(
+ filename: str,
+ *,
+ attach_type: str,
+ insert_attachment: AttachmentInserter,
+ attachments_dir: str | Path,
+ fallback_dirs: Sequence[str | Path] = (),
+) -> dict | None:
+ basename = Path(filename).name
+ candidate_paths = [Path(attachments_dir) / basename]
+ candidate_paths.extend(Path(p) / basename for p in fallback_dirs)
+
+ file_path = next((path for path in candidate_paths if path.exists()), None)
+ if not file_path:
+ return None
+
+ mime_type, _ = mimetypes.guess_type(str(file_path))
+ attachment = await insert_attachment(
+ str(file_path),
+ attach_type,
+ mime_type or "application/octet-stream",
+ )
+ if not attachment:
+ return None
+
+ return {
+ "type": attach_type,
+ "attachment_id": attachment.attachment_id,
+ "filename": file_path.name,
+ }
+
+
+async def message_chain_to_storage_message_parts(
+ message_chain: MessageChain,
+ *,
+ insert_attachment: AttachmentInserter,
+ attachments_dir: str | Path,
+) -> list[dict]:
+ target_dir = Path(attachments_dir)
+ target_dir.mkdir(parents=True, exist_ok=True)
+
+ parts: list[dict] = []
+ for comp in message_chain.chain:
+ if isinstance(comp, Plain):
+ if comp.text:
+ parts.append({"type": "plain", "text": comp.text})
+ continue
+
+ if isinstance(comp, Json):
+ parts.append(
+ {"type": "plain", "text": json.dumps(comp.data, ensure_ascii=False)}
+ )
+ continue
+
+ if isinstance(comp, Image):
+ file_path = await comp.convert_to_file_path()
+ attachment_part = await _copy_file_to_attachment_part(
+ file_path=file_path,
+ attach_type="image",
+ insert_attachment=insert_attachment,
+ attachments_dir=target_dir,
+ )
+ if attachment_part:
+ parts.append(attachment_part)
+ continue
+
+ if isinstance(comp, Record):
+ file_path = await comp.convert_to_file_path()
+ attachment_part = await _copy_file_to_attachment_part(
+ file_path=file_path,
+ attach_type="record",
+ insert_attachment=insert_attachment,
+ attachments_dir=target_dir,
+ )
+ if attachment_part:
+ parts.append(attachment_part)
+ continue
+
+ if isinstance(comp, Video):
+ file_path = await comp.convert_to_file_path()
+ attachment_part = await _copy_file_to_attachment_part(
+ file_path=file_path,
+ attach_type="video",
+ insert_attachment=insert_attachment,
+ attachments_dir=target_dir,
+ )
+ if attachment_part:
+ parts.append(attachment_part)
+ continue
+
+ if isinstance(comp, File):
+ file_path = await comp.get_file()
+ attachment_part = await _copy_file_to_attachment_part(
+ file_path=file_path,
+ attach_type="file",
+ insert_attachment=insert_attachment,
+ attachments_dir=target_dir,
+ display_name=comp.name,
+ )
+ if attachment_part:
+ parts.append(attachment_part)
+ continue
+
+ return parts
+
+
+async def _copy_file_to_attachment_part(
+ *,
+ file_path: str,
+ attach_type: str,
+ insert_attachment: AttachmentInserter,
+ attachments_dir: Path,
+ display_name: str | None = None,
+) -> dict | None:
+ src_path = Path(file_path)
+ if not src_path.exists() or not src_path.is_file():
+ return None
+
+ suffix = src_path.suffix
+ target_path = attachments_dir / f"{uuid.uuid4().hex}{suffix}"
+ shutil.copy2(src_path, target_path)
+
+ mime_type, _ = mimetypes.guess_type(target_path.name)
+ attachment = await insert_attachment(
+ str(target_path),
+ attach_type,
+ mime_type or "application/octet-stream",
+ )
+ if not attachment:
+ return None
+
+ return {
+ "type": attach_type,
+ "attachment_id": attachment.attachment_id,
+ "filename": display_name or src_path.name,
+ }
diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py
index 047417aaaa..e72594d8a0 100644
--- a/astrbot/core/platform/sources/webchat/webchat_adapter.py
+++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py
@@ -3,12 +3,20 @@
import time
import uuid
from collections.abc import Callable, Coroutine
+from pathlib import Path
from typing import Any
from astrbot import logger
from astrbot.core import db_helper
from astrbot.core.db.po import PlatformMessageHistory
-from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
+from astrbot.core.message.components import (
+ File,
+ Image,
+ Plain,
+ Record,
+ Reply,
+ Video,
+)
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import (
AstrBotMessage,
@@ -21,10 +29,20 @@
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ...register import register_platform_adapter
+from .message_parts_helper import message_chain_to_storage_message_parts
from .webchat_event import WebChatMessageEvent
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
+def _extract_conversation_id(session_id: str) -> str:
+ """Extract raw webchat conversation id from event/session id."""
+ if session_id.startswith("webchat!"):
+ parts = session_id.split("!", 2)
+ if len(parts) == 3:
+ return parts[2]
+ return session_id
+
+
class QueueListener:
def __init__(
self,
@@ -57,13 +75,15 @@ def __init__(
self.settings = platform_settings
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
+ self.attachments_dir = Path(get_astrbot_data_path()) / "attachments"
os.makedirs(self.imgs_dir, exist_ok=True)
+ self.attachments_dir.mkdir(parents=True, exist_ok=True)
self.metadata = PlatformMetadata(
name="webchat",
description="webchat",
id="webchat",
- support_proactive_message=False,
+ support_proactive_message=True,
)
self._shutdown_event = asyncio.Event()
self._webchat_queue_mgr = webchat_queue_mgr
@@ -73,10 +93,67 @@ async def send_by_session(
session: MessageSesion,
message_chain: MessageChain,
) -> None:
- message_id = f"active_{str(uuid.uuid4())}"
- await WebChatMessageEvent._send(message_id, message_chain, session.session_id)
+ conversation_id = _extract_conversation_id(session.session_id)
+ active_request_ids = self._webchat_queue_mgr.list_back_request_ids(
+ conversation_id
+ )
+ subscription_request_ids = [
+ req_id for req_id in active_request_ids if req_id.startswith("ws_sub_")
+ ]
+ target_request_ids = subscription_request_ids or active_request_ids
+
+ if target_request_ids:
+ for request_id in target_request_ids:
+ await WebChatMessageEvent._send(
+ request_id,
+ message_chain,
+ session.session_id,
+ )
+ else:
+ message_id = f"active_{uuid.uuid4()!s}"
+ await WebChatMessageEvent._send(
+ message_id,
+ message_chain,
+ session.session_id,
+ )
+
+ should_persist = (
+ bool(subscription_request_ids)
+ or not active_request_ids
+ or all(req_id.startswith("active_") for req_id in active_request_ids)
+ )
+ if should_persist:
+ try:
+ await self._save_proactive_message(conversation_id, message_chain)
+ except Exception as e:
+ logger.error(
+ f"[WebChatAdapter] Failed to save proactive message: {e}",
+ exc_info=True,
+ )
+
await super().send_by_session(session, message_chain)
+ async def _save_proactive_message(
+ self,
+ conversation_id: str,
+ message_chain: MessageChain,
+ ) -> None:
+ message_parts = await message_chain_to_storage_message_parts(
+ message_chain,
+ insert_attachment=db_helper.insert_attachment,
+ attachments_dir=self.attachments_dir,
+ )
+ if not message_parts:
+ return
+
+ await db_helper.insert_platform_message_history(
+ platform_id="webchat",
+ user_id=conversation_id,
+ content={"type": "bot", "message": message_parts},
+ sender_id="bot",
+ sender_name="bot",
+ )
+
async def _get_message_history(
self, message_id: int
) -> PlatformMessageHistory | None:
diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py
index a680f76174..b7da864aae 100644
--- a/astrbot/core/platform/sources/webchat/webchat_event.py
+++ b/astrbot/core/platform/sources/webchat/webchat_event.py
@@ -14,6 +14,15 @@
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
+def _extract_conversation_id(session_id: str) -> str:
+ """Extract raw webchat conversation id from event/session id."""
+ if session_id.startswith("webchat!"):
+ parts = session_id.split("!", 2)
+ if len(parts) == 3:
+ return parts[2]
+ return session_id
+
+
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
@@ -27,7 +36,7 @@ async def _send(
streaming: bool = False,
) -> str | None:
request_id = str(message_id)
- conversation_id = session_id.split("!")[-1]
+ conversation_id = _extract_conversation_id(session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
@@ -130,7 +139,7 @@ async def send_streaming(self, generator, use_fallback: bool = False) -> None:
reasoning_content = ""
message_id = self.message_obj.message_id
request_id = str(message_id)
- conversation_id = self.session_id.split("!")[-1]
+ conversation_id = _extract_conversation_id(self.session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
index fd35e837c8..f3ade1589a 100644
--- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
+++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
@@ -75,6 +75,10 @@ def remove_queue(self, conversation_id: str):
if task is not None:
task.cancel()
+ def list_back_request_ids(self, conversation_id: str) -> list[str]:
+ """List active back-queue request IDs for a conversation."""
+ return list(self._conversation_back_requests.get(conversation_id, set()))
+
def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index 1235dd3814..0602cc0745 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -1,6 +1,5 @@
import asyncio
import json
-import mimetypes
import os
import re
import uuid
@@ -14,6 +13,12 @@
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_type import MessageType
+from astrbot.core.platform.sources.webchat.message_parts_helper import (
+ build_webchat_message_parts,
+ create_attachment_part_from_existing_file,
+ strip_message_parts_path_fields,
+ webchat_message_parts_have_content,
+)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.active_event_registry import active_event_registry
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -166,83 +171,24 @@ async def post_file(self):
)
async def _build_user_message_parts(self, message: str | list) -> list[dict]:
- """构建用户消息的部分列表
-
- Args:
- message: 文本消息 (str) 或消息段列表 (list)
- """
- parts = []
-
- if isinstance(message, list):
- for part in message:
- part_type = part.get("type")
- if part_type == "plain":
- parts.append({"type": "plain", "text": part.get("text", "")})
- elif part_type == "reply":
- parts.append(
- {
- "type": "reply",
- "message_id": part.get("message_id"),
- "selected_text": part.get("selected_text", ""),
- }
- )
- elif attachment_id := part.get("attachment_id"):
- attachment = await self.db.get_attachment_by_id(attachment_id)
- if attachment:
- parts.append(
- {
- "type": attachment.type,
- "attachment_id": attachment.attachment_id,
- "filename": os.path.basename(attachment.path),
- "path": attachment.path, # will be deleted
- }
- )
- return parts
-
- if message:
- parts.append({"type": "plain", "text": message})
-
- return parts
+ """构建用户消息的部分列表。"""
+ return await build_webchat_message_parts(
+ message,
+ get_attachment_by_id=self.db.get_attachment_by_id,
+ strict=False,
+ )
async def _create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
- """从本地文件创建 attachment 并返回消息部分
-
- 用于处理 bot 回复中的媒体文件
-
- Args:
- filename: 存储的文件名
- attach_type: 附件类型 (image, record, file, video)
- """
- basename = os.path.basename(filename)
- candidate_paths = [
- os.path.join(self.attachments_dir, basename),
- os.path.join(self.legacy_img_dir, basename),
- ]
- file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
- if not file_path:
- return None
-
- # guess mime type
- mime_type, _ = mimetypes.guess_type(filename)
- if not mime_type:
- mime_type = "application/octet-stream"
-
- # insert attachment
- attachment = await self.db.insert_attachment(
- path=file_path,
- type=attach_type,
- mime_type=mime_type,
+ """从本地文件创建 attachment 并返回消息部分。"""
+ return await create_attachment_part_from_existing_file(
+ filename,
+ attach_type=attach_type,
+ insert_attachment=self.db.insert_attachment,
+ attachments_dir=self.attachments_dir,
+ fallback_dirs=[self.legacy_img_dir],
)
- if not attachment:
- return None
-
- return {
- "type": attach_type,
- "attachment_id": attachment.attachment_id,
- "filename": os.path.basename(file_path),
- }
def _extract_web_search_refs(
self, accumulated_text: str, accumulated_parts: list
@@ -356,21 +302,6 @@ async def chat(self, post_data: dict | None = None):
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
- # 检查消息是否为空
- if isinstance(message, list):
- has_content = any(
- part.get("type") in ("plain", "image", "record", "file", "video")
- for part in message
- )
- if not has_content:
- return (
- Response()
- .error("Message content is empty (reply only is not allowed)")
- .__dict__
- )
- elif not message:
- return Response().error("Message are both empty").__dict__
-
if not session_id:
return Response().error("session_id is empty").__dict__
@@ -378,6 +309,12 @@ async def chat(self, post_data: dict | None = None):
# 构建用户消息段(包含 path 用于传递给 adapter)
message_parts = await self._build_user_message_parts(message)
+ if not webchat_message_parts_have_content(message_parts):
+ return (
+ Response()
+ .error("Message content is empty (reply only is not allowed)")
+ .__dict__
+ )
message_id = str(uuid.uuid4())
back_queue = webchat_queue_mgr.get_or_create_back_queue(
@@ -583,10 +520,7 @@ async def stream():
),
)
- message_parts_for_storage = []
- for part in message_parts:
- part_copy = {k: v for k, v in part.items() if k != "path"}
- message_parts_for_storage.append(part_copy)
+ message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.platform_history_mgr.insert(
platform_id="webchat",
diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py
index 8c922ab69a..25438565e1 100644
--- a/astrbot/dashboard/routes/live_chat.py
+++ b/astrbot/dashboard/routes/live_chat.py
@@ -1,6 +1,7 @@
import asyncio
import json
import os
+import re
import time
import uuid
import wave
@@ -10,9 +11,16 @@
from quart import websocket
from astrbot import logger
+from astrbot.core import sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
+from astrbot.core.platform.sources.webchat.message_parts_helper import (
+ build_webchat_message_parts,
+ create_attachment_part_from_existing_file,
+ strip_message_parts_path_fields,
+ webchat_message_parts_have_content,
+)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
-from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
+from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from .route import Route, RouteContext
@@ -30,6 +38,9 @@ def __init__(self, session_id: str, username: str) -> None:
self.audio_frames: list[bytes] = []
self.current_stamp: str | None = None
self.temp_audio_path: str | None = None
+ self.chat_subscriptions: dict[str, str] = {}
+ self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
+ self.ws_send_lock = asyncio.Lock()
def start_speaking(self, stamp: str) -> None:
"""开始说话"""
@@ -106,13 +117,26 @@ def __init__(
self.core_lifecycle = core_lifecycle
self.db = db
self.plugin_manager = core_lifecycle.plugin_manager
+ self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.sessions: dict[str, LiveChatSession] = {}
+ self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
+ self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
+ os.makedirs(self.attachments_dir, exist_ok=True)
# 注册 WebSocket 路由
self.app.websocket("/api/live_chat/ws")(self.live_chat_ws)
+ self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws)
async def live_chat_ws(self) -> None:
- """Live Chat WebSocket 处理器"""
+ """Legacy Live Chat WebSocket 处理器(默认 ct=live)"""
+ await self._unified_ws_loop(force_ct="live")
+
+ async def unified_chat_ws(self) -> None:
+ """Unified Chat WebSocket 处理器(支持 ct=live/chat)"""
+ await self._unified_ws_loop(force_ct=None)
+
+ async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
+ """统一 WebSocket 循环"""
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
# 注意:WebSocket 上下文使用 websocket.args 而不是 request.args
token = websocket.args.get("token")
@@ -140,7 +164,11 @@ async def live_chat_ws(self) -> None:
try:
while True:
message = await websocket.receive_json()
- await self._handle_message(live_session, message)
+ ct = force_ct or message.get("ct", "live")
+ if ct == "chat":
+ await self._handle_chat_message(live_session, message)
+ else:
+ await self._handle_message(live_session, message)
except Exception as e:
logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True)
@@ -148,10 +176,488 @@ async def live_chat_ws(self) -> None:
finally:
# 清理会话
if session_id in self.sessions:
+ await self._cleanup_chat_subscriptions(live_session)
live_session.cleanup()
del self.sessions[session_id]
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
+ async def _create_attachment_from_file(
+ self, filename: str, attach_type: str
+ ) -> dict | None:
+ """从本地文件创建 attachment 并返回消息部分。"""
+ return await create_attachment_part_from_existing_file(
+ filename,
+ attach_type=attach_type,
+ insert_attachment=self.db.insert_attachment,
+ attachments_dir=self.attachments_dir,
+ fallback_dirs=[self.legacy_img_dir],
+ )
+
+ def _extract_web_search_refs(
+ self, accumulated_text: str, accumulated_parts: list
+ ) -> dict:
+ """从消息中提取 web_search 引用。"""
+ supported = ["web_search_tavily", "web_search_bocha"]
+ web_search_results = {}
+ tool_call_parts = [
+ p
+ for p in accumulated_parts
+ if p.get("type") == "tool_call" and p.get("tool_calls")
+ ]
+
+ for part in tool_call_parts:
+ for tool_call in part["tool_calls"]:
+ if tool_call.get("name") not in supported or not tool_call.get(
+ "result"
+ ):
+ continue
+ try:
+ result_data = json.loads(tool_call["result"])
+ for item in result_data.get("results", []):
+ if idx := item.get("index"):
+ web_search_results[idx] = {
+ "url": item.get("url"),
+ "title": item.get("title"),
+ "snippet": item.get("snippet"),
+ }
+ except (json.JSONDecodeError, KeyError):
+ pass
+
+ if not web_search_results:
+ return {}
+
+ ref_indices = {
+ m.strip() for m in re.findall(r"[(.*?)]", accumulated_text)
+ }
+
+ used_refs = []
+ for ref_index in ref_indices:
+ if ref_index not in web_search_results:
+ continue
+ payload = {"index": ref_index, **web_search_results[ref_index]}
+ if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
+ payload["favicon"] = favicon
+ used_refs.append(payload)
+
+ return {"used": used_refs} if used_refs else {}
+
+ async def _save_bot_message(
+ self,
+ webchat_conv_id: str,
+ text: str,
+ media_parts: list,
+ reasoning: str,
+ agent_stats: dict,
+ refs: dict,
+ ):
+ """保存 bot 消息到历史记录。"""
+ bot_message_parts = []
+ bot_message_parts.extend(media_parts)
+ if text:
+ bot_message_parts.append({"type": "plain", "text": text})
+
+ new_his = {"type": "bot", "message": bot_message_parts}
+ if reasoning:
+ new_his["reasoning"] = reasoning
+ if agent_stats:
+ new_his["agent_stats"] = agent_stats
+ if refs:
+ new_his["refs"] = refs
+
+ return await self.platform_history_mgr.insert(
+ platform_id="webchat",
+ user_id=webchat_conv_id,
+ content=new_his,
+ sender_id="bot",
+ sender_name="bot",
+ )
+
+ async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None:
+ async with session.ws_send_lock:
+ await websocket.send_json(payload)
+
+ async def _forward_chat_subscription(
+ self,
+ session: LiveChatSession,
+ chat_session_id: str,
+ request_id: str,
+ ) -> None:
+ back_queue = webchat_queue_mgr.get_or_create_back_queue(
+ request_id, chat_session_id
+ )
+ try:
+ while True:
+ result = await back_queue.get()
+ if not result:
+ continue
+ await self._send_chat_payload(session, {"ct": "chat", **result})
+ except asyncio.CancelledError:
+ pass
+ except Exception as e:
+ logger.error(
+ f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}",
+ exc_info=True,
+ )
+ finally:
+ webchat_queue_mgr.remove_back_queue(request_id)
+ if session.chat_subscriptions.get(chat_session_id) == request_id:
+ session.chat_subscriptions.pop(chat_session_id, None)
+ session.chat_subscription_tasks.pop(chat_session_id, None)
+
+ async def _ensure_chat_subscription(
+ self,
+ session: LiveChatSession,
+ chat_session_id: str,
+ ) -> str:
+ existing_request_id = session.chat_subscriptions.get(chat_session_id)
+ existing_task = session.chat_subscription_tasks.get(chat_session_id)
+ if existing_request_id and existing_task and not existing_task.done():
+ return existing_request_id
+
+ request_id = f"ws_sub_{uuid.uuid4().hex}"
+ session.chat_subscriptions[chat_session_id] = request_id
+ task = asyncio.create_task(
+ self._forward_chat_subscription(session, chat_session_id, request_id),
+ name=f"chat_ws_sub_{chat_session_id}",
+ )
+ session.chat_subscription_tasks[chat_session_id] = task
+ return request_id
+
+ async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
+ tasks = list(session.chat_subscription_tasks.values())
+ for task in tasks:
+ task.cancel()
+ if tasks:
+ await asyncio.gather(*tasks, return_exceptions=True)
+
+ for request_id in list(session.chat_subscriptions.values()):
+ webchat_queue_mgr.remove_back_queue(request_id)
+ session.chat_subscriptions.clear()
+ session.chat_subscription_tasks.clear()
+
+ async def _handle_chat_message(
+ self, session: LiveChatSession, message: dict
+ ) -> None:
+ """处理 Chat Mode 消息(ct=chat)"""
+ msg_type = message.get("t")
+
+ if msg_type == "bind":
+ chat_session_id = message.get("session_id")
+ if not isinstance(chat_session_id, str) or not chat_session_id:
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "session_id is required",
+ "code": "INVALID_MESSAGE_FORMAT",
+ },
+ )
+ return
+
+ request_id = await self._ensure_chat_subscription(session, chat_session_id)
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "type": "session_bound",
+ "session_id": chat_session_id,
+ "message_id": request_id,
+ },
+ )
+ return
+
+ if msg_type == "interrupt":
+ session.should_interrupt = True
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "INTERRUPTED",
+ "code": "INTERRUPTED",
+ },
+ )
+ return
+
+ if msg_type != "send":
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": f"Unsupported message type: {msg_type}",
+ "code": "INVALID_MESSAGE_FORMAT",
+ },
+ )
+ return
+
+ if session.is_processing:
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "Session is busy",
+ "code": "PROCESSING_ERROR",
+ },
+ )
+ return
+
+ payload = message.get("message")
+ session_id = message.get("session_id") or session.session_id
+ message_id = message.get("message_id") or str(uuid.uuid4())
+ selected_provider = message.get("selected_provider")
+ selected_model = message.get("selected_model")
+ selected_stt_provider = message.get("selected_stt_provider")
+ selected_tts_provider = message.get("selected_tts_provider")
+ persona_prompt = message.get("persona_prompt")
+ show_reasoning = message.get("show_reasoning")
+ enable_streaming = message.get("enable_streaming", True)
+
+ if not isinstance(payload, list):
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "message must be list",
+ "code": "INVALID_MESSAGE_FORMAT",
+ },
+ )
+ return
+
+ message_parts = await self._build_chat_message_parts(payload)
+ has_content = webchat_message_parts_have_content(message_parts)
+ if not has_content:
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "Message content is empty",
+ "code": "INVALID_MESSAGE_FORMAT",
+ },
+ )
+ return
+
+ await self._ensure_chat_subscription(session, session_id)
+
+ session.is_processing = True
+ session.should_interrupt = False
+ back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
+
+ try:
+ chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
+ await chat_queue.put(
+ (
+ session.username,
+ session_id,
+ {
+ "message": message_parts,
+ "selected_provider": selected_provider,
+ "selected_model": selected_model,
+ "selected_stt_provider": selected_stt_provider,
+ "selected_tts_provider": selected_tts_provider,
+ "persona_prompt": persona_prompt,
+ "show_reasoning": show_reasoning,
+ "enable_streaming": enable_streaming,
+ "message_id": message_id,
+ },
+ ),
+ )
+
+ message_parts_for_storage = strip_message_parts_path_fields(message_parts)
+ await self.platform_history_mgr.insert(
+ platform_id="webchat",
+ user_id=session_id,
+ content={"type": "user", "message": message_parts_for_storage},
+ sender_id=session.username,
+ sender_name=session.username,
+ )
+
+ accumulated_parts = []
+ accumulated_text = ""
+ accumulated_reasoning = ""
+ tool_calls = {}
+ agent_stats = {}
+ refs = {}
+
+ while True:
+ if session.should_interrupt:
+ session.should_interrupt = False
+ break
+
+ try:
+ result = await asyncio.wait_for(back_queue.get(), timeout=1)
+ except asyncio.TimeoutError:
+ continue
+
+ if not result:
+ continue
+ if result.get("message_id") and result.get("message_id") != message_id:
+ continue
+
+ result_text = result.get("data", "")
+ msg_type = result.get("type")
+ streaming = result.get("streaming", False)
+ chain_type = result.get("chain_type")
+ if chain_type == "agent_stats":
+ try:
+ parsed_agent_stats = json.loads(result_text)
+ agent_stats = parsed_agent_stats
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "type": "agent_stats",
+ "data": parsed_agent_stats,
+ },
+ )
+ except Exception:
+ pass
+ continue
+
+ outgoing = {"ct": "chat", **result}
+ await self._send_chat_payload(session, outgoing)
+
+ if msg_type == "plain":
+ if chain_type == "tool_call":
+ try:
+ tool_call = json.loads(result_text)
+ tool_calls[tool_call.get("id")] = tool_call
+ if accumulated_text:
+ accumulated_parts.append(
+ {"type": "plain", "text": accumulated_text}
+ )
+ accumulated_text = ""
+ except Exception:
+ pass
+ elif chain_type == "tool_call_result":
+ try:
+ tcr = json.loads(result_text)
+ tc_id = tcr.get("id")
+ if tc_id in tool_calls:
+ tool_calls[tc_id]["result"] = tcr.get("result")
+ tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
+ accumulated_parts.append(
+ {
+ "type": "tool_call",
+ "tool_calls": [tool_calls[tc_id]],
+ }
+ )
+ tool_calls.pop(tc_id, None)
+ except Exception:
+ pass
+ elif chain_type == "reasoning":
+ accumulated_reasoning += result_text
+ elif streaming:
+ accumulated_text += result_text
+ else:
+ accumulated_text = result_text
+ elif msg_type == "image":
+ filename = str(result_text).replace("[IMAGE]", "")
+ part = await self._create_attachment_from_file(filename, "image")
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "record":
+ filename = str(result_text).replace("[RECORD]", "")
+ part = await self._create_attachment_from_file(filename, "record")
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "file":
+ filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
+ part = await self._create_attachment_from_file(filename, "file")
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "video":
+ filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
+ part = await self._create_attachment_from_file(filename, "video")
+ if part:
+ accumulated_parts.append(part)
+
+ should_save = False
+ if msg_type == "end":
+ should_save = bool(
+ accumulated_parts
+ or accumulated_text
+ or accumulated_reasoning
+ or refs
+ or agent_stats
+ )
+ elif (streaming and msg_type == "complete") or not streaming:
+ if chain_type not in (
+ "tool_call",
+ "tool_call_result",
+ "agent_stats",
+ ):
+ should_save = True
+
+ if should_save:
+ try:
+ refs = self._extract_web_search_refs(
+ accumulated_text,
+ accumulated_parts,
+ )
+ except Exception as e:
+ logger.exception(
+ f"[Live Chat] Failed to extract web search refs: {e}",
+ exc_info=True,
+ )
+
+ saved_record = await self._save_bot_message(
+ session_id,
+ accumulated_text,
+ accumulated_parts,
+ accumulated_reasoning,
+ agent_stats,
+ refs,
+ )
+ if saved_record:
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "type": "message_saved",
+ "data": {
+ "id": saved_record.id,
+ "created_at": saved_record.created_at.astimezone().isoformat(),
+ },
+ },
+ )
+
+ accumulated_parts = []
+ accumulated_text = ""
+ accumulated_reasoning = ""
+ agent_stats = {}
+ refs = {}
+
+ if msg_type == "end":
+ break
+
+ except Exception as e:
+ logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True)
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": f"处理失败: {str(e)}",
+ "code": "PROCESSING_ERROR",
+ },
+ )
+ finally:
+ session.is_processing = False
+ webchat_queue_mgr.remove_back_queue(message_id)
+
+ async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]:
+ """构建 chat websocket 用户消息段(复用 webchat 逻辑)"""
+ return await build_webchat_message_parts(
+ message,
+ get_attachment_by_id=self.db.get_attachment_by_id,
+ strict=False,
+ )
+
async def _handle_message(self, session: LiveChatSession, message: dict) -> None:
"""处理 WebSocket 消息"""
msg_type = message.get("t") # 使用 t 代替 type
diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py
index c25870ebb1..055de6732e 100644
--- a/astrbot/dashboard/routes/open_api.py
+++ b/astrbot/dashboard/routes/open_api.py
@@ -1,4 +1,3 @@
-from pathlib import Path
from uuid import uuid4
from quart import g, request
@@ -6,9 +5,10 @@
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
-from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
-from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.message_session import MessageSesion
+from astrbot.core.platform.sources.webchat.message_parts_helper import (
+ build_message_chain_from_payload,
+)
from .chat import ChatRoute
from .route import Response, Route, RouteContext
@@ -254,83 +254,12 @@ async def get_chat_configs(self):
async def _build_message_chain_from_payload(
self,
message_payload: str | list,
- ) -> MessageChain:
- if isinstance(message_payload, str):
- text = message_payload.strip()
- if not text:
- raise ValueError("Message is empty")
- return MessageChain(chain=[Plain(text=text)])
-
- if not isinstance(message_payload, list):
- raise ValueError("message must be a string or list")
-
- components = []
- has_content = False
-
- for part in message_payload:
- if not isinstance(part, dict):
- raise ValueError("message part must be an object")
-
- part_type = str(part.get("type", "")).strip()
- if part_type == "plain":
- text = str(part.get("text", ""))
- if text:
- has_content = True
- components.append(Plain(text=text))
- continue
-
- if part_type == "reply":
- message_id = part.get("message_id")
- if message_id is None:
- raise ValueError("reply part missing message_id")
- components.append(
- Reply(
- id=str(message_id),
- message_str=str(part.get("selected_text", "")),
- chain=[],
- )
- )
- continue
-
- if part_type not in {"image", "record", "file", "video"}:
- raise ValueError(f"unsupported message part type: {part_type}")
-
- has_content = True
- file_path: Path | None = None
- resolved_type = part_type
- filename = str(part.get("filename", "")).strip()
-
- attachment_id = part.get("attachment_id")
- if attachment_id:
- attachment = await self.db.get_attachment_by_id(str(attachment_id))
- if not attachment:
- raise ValueError(f"attachment not found: {attachment_id}")
- file_path = Path(attachment.path)
- resolved_type = attachment.type
- if not filename:
- filename = file_path.name
- else:
- raise ValueError(f"{part_type} part missing attachment_id")
-
- if not file_path.exists():
- raise ValueError(f"file not found: {file_path!s}")
-
- file_path_str = str(file_path.resolve())
- if resolved_type == "image":
- components.append(Image.fromFileSystem(file_path_str))
- elif resolved_type == "record":
- components.append(Record.fromFileSystem(file_path_str))
- elif resolved_type == "video":
- components.append(Video.fromFileSystem(file_path_str))
- else:
- components.append(
- File(name=filename or file_path.name, file=file_path_str)
- )
-
- if not components or not has_content:
- raise ValueError("Message content is empty (reply only is not allowed)")
-
- return MessageChain(chain=components)
+ ):
+ return await build_message_chain_from_payload(
+ message_payload,
+ get_attachment_by_id=self.db.get_attachment_by_id,
+ strict=True,
+ )
async def send_message(self):
post_data = await request.json or {}
diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue
index 803c5d826a..054a186629 100644
--- a/dashboard/src/components/chat/Chat.vue
+++ b/dashboard/src/components/chat/Chat.vue
@@ -10,6 +10,7 @@
:selectedSessions="selectedSessions"
:currSessionId="currSessionId"
:selectedProjectId="selectedProjectId"
+ :transportMode="transportMode"
:isDark="isDark"
:chatboxMode="chatboxMode"
:isMobile="isMobile"
@@ -26,6 +27,7 @@
@createProject="showCreateProjectDialog"
@editProject="showEditProjectDialog"
@deleteProject="handleDeleteProject"
+ @updateTransportMode="setTransportMode"
/>
@@ -301,11 +303,14 @@ const {
isStreaming,
isConvRunning,
enableStreaming,
+ transportMode,
currentSessionProject,
getSessionMessages: getSessionMsg,
sendMessage: sendMsg,
stopMessage: stopMsg,
- toggleStreaming
+ toggleStreaming,
+ setTransportMode,
+ cleanupTransport
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
// 组件引用
@@ -695,6 +700,7 @@ onMounted(() => {
onBeforeUnmount(() => {
window.removeEventListener('resize', checkMobile);
cleanupMediaCache();
+ cleanupTransport();
});
diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue
index a728930d9f..97f2179e76 100644
--- a/dashboard/src/components/chat/ConversationSidebar.vue
+++ b/dashboard/src/components/chat/ConversationSidebar.vue
@@ -117,6 +117,27 @@
{{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }}
+
+
+