diff --git a/astrbot/builtin_stars/astrbot/long_term_memory.py b/astrbot/builtin_stars/astrbot/long_term_memory.py index 610995db24..e08cdc5157 100644 --- a/astrbot/builtin_stars/astrbot/long_term_memory.py +++ b/astrbot/builtin_stars/astrbot/long_term_memory.py @@ -17,7 +17,7 @@ class LongTermMemory: - def __init__(self, acm: AstrBotConfigManager, context: star.Context): + def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: self.acm = acm self.context = context self.session_chats = defaultdict(list) @@ -111,7 +111,7 @@ async def need_active_reply(self, event: AstrMessageEvent) -> bool: return False - async def handle_message(self, event: AstrMessageEvent): + async def handle_message(self, event: AstrMessageEvent) -> None: """仅支持群聊""" if event.get_message_type() == MessageType.GROUP_MESSAGE: datetime_str = datetime.datetime.now().strftime("%H:%M:%S") @@ -148,7 +148,7 @@ async def handle_message(self, event: AstrMessageEvent): if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: self.session_chats[event.unified_msg_origin].pop(0) - async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): + async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: """当触发 LLM 请求前,调用此方法修改 req""" if event.unified_msg_origin not in self.session_chats: return @@ -171,7 +171,9 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): ) req.system_prompt += chats_str - async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse): + async def after_req_llm( + self, event: AstrMessageEvent, llm_resp: LLMResponse + ) -> None: if event.unified_msg_origin not in self.session_chats: return diff --git a/astrbot/builtin_stars/astrbot/main.py b/astrbot/builtin_stars/astrbot/main.py index 773d039393..da2a008354 100644 --- a/astrbot/builtin_stars/astrbot/main.py +++ b/astrbot/builtin_stars/astrbot/main.py @@ -85,7 +85,9 @@ async def on_message(self, event: AstrMessageEvent): logger.error(f"主动回复失败: {e}") @filter.on_llm_request() - async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): + async def decorate_llm_req( + self, event: AstrMessageEvent, req: ProviderRequest + ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" if self.ltm and self.ltm_enabled(event): try: @@ -94,7 +96,9 @@ async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): logger.error(f"ltm: {e}") @filter.on_llm_response() - async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse): + async def record_llm_resp_to_ltm( + self, event: AstrMessageEvent, resp: LLMResponse + ) -> None: """在 LLM 响应后记录对话""" if self.ltm and self.ltm_enabled(event): try: @@ -103,7 +107,7 @@ async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMRespons logger.error(f"ltm: {e}") @filter.after_message_sent() - async def after_message_sent(self, event: AstrMessageEvent): + async def after_message_sent(self, event: AstrMessageEvent) -> None: """消息发送后处理""" if self.ltm and self.ltm_enabled(event): try: diff --git a/astrbot/builtin_stars/builtin_commands/commands/admin.py b/astrbot/builtin_stars/builtin_commands/commands/admin.py index 83d4b5974d..a4f46b6036 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/admin.py +++ b/astrbot/builtin_stars/builtin_commands/commands/admin.py @@ -5,10 +5,10 @@ class AdminCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def op(self, event: AstrMessageEvent, admin_id: str = ""): + async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None: """授权管理员。op """ if not admin_id: event.set_result( @@ -21,7 +21,7 @@ async def op(self, event: AstrMessageEvent, admin_id: str = ""): self.context.get_config().save_config() event.set_result(MessageEventResult().message("授权成功。")) - async def deop(self, event: AstrMessageEvent, admin_id: str = ""): + async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None: """取消授权管理员。deop """ if not admin_id: event.set_result( @@ -39,7 +39,7 @@ async def deop(self, event: AstrMessageEvent, admin_id: str = ""): MessageEventResult().message("此用户 ID 不在管理员名单内。"), ) - async def wl(self, event: AstrMessageEvent, sid: str = ""): + async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: """添加白名单。wl """ if not sid: event.set_result( @@ -53,7 +53,7 @@ async def wl(self, event: AstrMessageEvent, sid: str = ""): cfg.save_config() event.set_result(MessageEventResult().message("添加白名单成功。")) - async def dwl(self, event: AstrMessageEvent, sid: str = ""): + async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None: """删除白名单。dwl """ if not sid: event.set_result( @@ -70,7 +70,7 @@ async def dwl(self, event: AstrMessageEvent, sid: str = ""): except ValueError: event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) - async def update_dashboard(self, event: AstrMessageEvent): + async def update_dashboard(self, event: AstrMessageEvent) -> None: """更新管理面板""" await event.send(MessageChain().message("正在尝试更新管理面板...")) await download_dashboard(version=f"v{VERSION}", latest=False) diff --git a/astrbot/builtin_stars/builtin_commands/commands/alter_cmd.py b/astrbot/builtin_stars/builtin_commands/commands/alter_cmd.py index 50007f6c09..ba31c3326c 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/alter_cmd.py +++ b/astrbot/builtin_stars/builtin_commands/commands/alter_cmd.py @@ -11,10 +11,10 @@ class AlterCmdCommands(CommandParserMixin): - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def update_reset_permission(self, scene_key: str, perm_type: str): + async def update_reset_permission(self, scene_key: str, perm_type: str) -> None: """更新reset命令在特定场景下的权限设置""" from astrbot.api import sp @@ -26,7 +26,7 @@ async def update_reset_permission(self, scene_key: str, perm_type: str): alter_cmd_cfg["astrbot"] = plugin_cfg await sp.global_put("alter_cmd", alter_cmd_cfg) - async def alter_cmd(self, event: AstrMessageEvent): + async def alter_cmd(self, event: AstrMessageEvent) -> None: token = self.parse_commands(event.message_str) if token.len < 3: await event.send( diff --git a/astrbot/builtin_stars/builtin_commands/commands/conversation.py b/astrbot/builtin_stars/builtin_commands/commands/conversation.py index de3d11ac82..eb8cfdefa5 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/conversation.py +++ b/astrbot/builtin_stars/builtin_commands/commands/conversation.py @@ -16,7 +16,7 @@ class ConversationCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def _get_current_persona_id(self, session_id): @@ -33,7 +33,7 @@ async def _get_current_persona_id(self, session_id): return None return conv.persona_id - async def reset(self, message: AstrMessageEvent): + async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" umo = message.unified_msg_origin cfg = self.context.get_config(umo=message.unified_msg_origin) @@ -98,7 +98,7 @@ async def reset(self, message: AstrMessageEvent): message.set_result(MessageEventResult().message(ret)) - async def his(self, message: AstrMessageEvent, page: int = 1): + async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" if not self.context.get_using_provider(message.unified_msg_origin): message.set_result( @@ -141,7 +141,7 @@ async def his(self, message: AstrMessageEvent, page: int = 1): message.set_result(MessageEventResult().message(ret).use_t2i(False)) - async def convs(self, message: AstrMessageEvent, page: int = 1): + async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] @@ -216,7 +216,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1): message.set_result(MessageEventResult().message(ret).use_t2i(False)) return - async def new_conv(self, message: AstrMessageEvent): + async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] @@ -242,7 +242,7 @@ async def new_conv(self, message: AstrMessageEvent): MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"), ) - async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""): + async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None: """创建新群聊对话""" if sid: session = str( @@ -273,7 +273,7 @@ async def switch_conv( self, message: AstrMessageEvent, index: int | None = None, - ): + ) -> None: """通过 /ls 前面的序号切换对话""" if not isinstance(index, int): message.set_result( @@ -308,7 +308,7 @@ async def switch_conv( ), ) - async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): + async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None: """重命名对话""" if not new_name: message.set_result(MessageEventResult().message("请输入新的对话名称。")) @@ -319,7 +319,7 @@ async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): ) message.set_result(MessageEventResult().message("重命名对话成功。")) - async def del_conv(self, message: AstrMessageEvent): + async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" cfg = self.context.get_config(umo=message.unified_msg_origin) is_unique_session = cfg["platform_settings"]["unique_session"] diff --git a/astrbot/builtin_stars/builtin_commands/commands/help.py b/astrbot/builtin_stars/builtin_commands/commands/help.py index 092fc59ec0..ae2f4c787e 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/help.py +++ b/astrbot/builtin_stars/builtin_commands/commands/help.py @@ -8,7 +8,7 @@ class HelpCommand: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def _query_astrbot_notice(self): @@ -34,7 +34,7 @@ async def _build_reserved_command_lines(self) -> list[str]: lines: list[str] = [] hidden_commands = {"set", "unset", "websearch"} - def walk(items: list[dict], indent: int = 0): + def walk(items: list[dict], indent: int = 0) -> None: for item in items: if not item.get("reserved") or not item.get("enabled"): continue @@ -62,7 +62,7 @@ def walk(items: list[dict], indent: int = 0): walk(commands) return lines - async def help(self, event: AstrMessageEvent): + async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" notice = "" try: diff --git a/astrbot/builtin_stars/builtin_commands/commands/llm.py b/astrbot/builtin_stars/builtin_commands/commands/llm.py index 85977df407..ba9ba5c9b2 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/llm.py +++ b/astrbot/builtin_stars/builtin_commands/commands/llm.py @@ -3,10 +3,10 @@ class LLMCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def llm(self, event: AstrMessageEvent): + async def llm(self, event: AstrMessageEvent) -> None: """开启/关闭 LLM""" cfg = self.context.get_config(umo=event.unified_msg_origin) enable = cfg["provider_settings"].get("enable", True) diff --git a/astrbot/builtin_stars/builtin_commands/commands/persona.py b/astrbot/builtin_stars/builtin_commands/commands/persona.py index 169c9e2b6f..cf99988a2e 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/persona.py +++ b/astrbot/builtin_stars/builtin_commands/commands/persona.py @@ -9,7 +9,7 @@ class PersonaCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context def _build_tree_output( @@ -50,7 +50,7 @@ def _build_tree_output( return lines - async def persona(self, message: AstrMessageEvent): + async def persona(self, message: AstrMessageEvent) -> None: l = message.message_str.split(" ") # noqa: E741 umo = message.unified_msg_origin diff --git a/astrbot/builtin_stars/builtin_commands/commands/plugin.py b/astrbot/builtin_stars/builtin_commands/commands/plugin.py index ab45efc116..49bee94627 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/plugin.py +++ b/astrbot/builtin_stars/builtin_commands/commands/plugin.py @@ -8,10 +8,10 @@ class PluginCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def plugin_ls(self, event: AstrMessageEvent): + async def plugin_ls(self, event: AstrMessageEvent) -> None: """获取已经安装的插件列表。""" parts = ["已加载的插件:\n"] for plugin in self.context.get_all_stars(): @@ -30,7 +30,7 @@ async def plugin_ls(self, event: AstrMessageEvent): MessageEventResult().message(f"{plugin_list_info}").use_t2i(False), ) - async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """禁用插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法禁用插件。")) @@ -43,7 +43,7 @@ async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。")) - async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """启用插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法启用插件。")) @@ -56,7 +56,7 @@ async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。")) - async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None: """安装插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法安装插件。")) @@ -77,7 +77,7 @@ async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): event.set_result(MessageEventResult().message(f"安装插件失败: {e}")) return - async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """获取插件帮助""" if not plugin_name: event.set_result( diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 60b81ebe5c..ae20eb8e1c 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -8,7 +8,7 @@ class ProviderCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context def _log_reachability_failure( @@ -17,7 +17,7 @@ def _log_reachability_failure( provider_capability_type: ProviderType | None, err_code: str, err_reason: str, - ): + ) -> None: """记录不可达原因到日志。""" meta = provider.meta() logger.warning( @@ -49,7 +49,7 @@ async def provider( event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ): + ) -> None: """查看或者切换 LLM Provider""" umo = event.unified_msg_origin cfg = self.context.get_config(umo).get("provider_settings", {}) @@ -228,7 +228,7 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ): + ) -> None: """查看或者切换模型""" prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: @@ -293,7 +293,7 @@ async def model_ls( MessageEventResult().message(f"切换模型到 {prov.get_model()}。"), ) - async def key(self, message: AstrMessageEvent, index: int | None = None): + async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: message.set_result( diff --git a/astrbot/builtin_stars/builtin_commands/commands/setunset.py b/astrbot/builtin_stars/builtin_commands/commands/setunset.py index 79e5d5d1ca..096698844d 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/setunset.py +++ b/astrbot/builtin_stars/builtin_commands/commands/setunset.py @@ -3,10 +3,10 @@ class SetUnsetCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None: """设置会话变量""" uid = event.unified_msg_origin session_var = await sp.session_get(uid, "session_variables", {}) @@ -19,7 +19,7 @@ async def set_variable(self, event: AstrMessageEvent, key: str, value: str): ), ) - async def unset_variable(self, event: AstrMessageEvent, key: str): + async def unset_variable(self, event: AstrMessageEvent, key: str) -> None: """移除会话变量""" uid = event.unified_msg_origin session_var = await sp.session_get(uid, "session_variables", {}) diff --git a/astrbot/builtin_stars/builtin_commands/commands/sid.py b/astrbot/builtin_stars/builtin_commands/commands/sid.py index 4d95c5a606..e8bdbffb19 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/sid.py +++ b/astrbot/builtin_stars/builtin_commands/commands/sid.py @@ -7,10 +7,10 @@ class SIDCommand: """会话ID命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def sid(self, event: AstrMessageEvent): + async def sid(self, event: AstrMessageEvent) -> None: """获取消息来源信息""" sid = event.unified_msg_origin user_id = str(event.get_sender_id()) diff --git a/astrbot/builtin_stars/builtin_commands/commands/t2i.py b/astrbot/builtin_stars/builtin_commands/commands/t2i.py index 7766b342fc..78d6b0df7b 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/t2i.py +++ b/astrbot/builtin_stars/builtin_commands/commands/t2i.py @@ -7,10 +7,10 @@ class T2ICommand: """文本转图片命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def t2i(self, event: AstrMessageEvent): + async def t2i(self, event: AstrMessageEvent) -> None: """开关文本转图片""" config = self.context.get_config(umo=event.unified_msg_origin) if config["t2i"]: diff --git a/astrbot/builtin_stars/builtin_commands/commands/tts.py b/astrbot/builtin_stars/builtin_commands/commands/tts.py index dee8e31de3..13049ac22e 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/tts.py +++ b/astrbot/builtin_stars/builtin_commands/commands/tts.py @@ -8,10 +8,10 @@ class TTSCommand: """文本转语音命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def tts(self, event: AstrMessageEvent): + async def tts(self, event: AstrMessageEvent) -> None: """开关文本转语音(会话级别)""" umo = event.unified_msg_origin ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo) diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py index 207a14b4a5..9b839ca881 100644 --- a/astrbot/builtin_stars/builtin_commands/main.py +++ b/astrbot/builtin_stars/builtin_commands/main.py @@ -35,84 +35,84 @@ def __init__(self, context: star.Context) -> None: self.sid_c = SIDCommand(self.context) @filter.command("help") - async def help(self, event: AstrMessageEvent): + async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" await self.help_c.help(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("llm") - async def llm(self, event: AstrMessageEvent): + async def llm(self, event: AstrMessageEvent) -> None: """开启/关闭 LLM""" await self.llm_c.llm(event) @filter.command_group("plugin") - def plugin(self): + def plugin(self) -> None: """插件管理""" @plugin.command("ls") - async def plugin_ls(self, event: AstrMessageEvent): + async def plugin_ls(self, event: AstrMessageEvent) -> None: """获取已经安装的插件列表。""" await self.plugin_c.plugin_ls(event) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("off") - async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """禁用插件""" await self.plugin_c.plugin_off(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("on") - async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """启用插件""" await self.plugin_c.plugin_on(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("get") - async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None: """安装插件""" await self.plugin_c.plugin_get(event, plugin_repo) @plugin.command("help") - async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """获取插件帮助""" await self.plugin_c.plugin_help(event, plugin_name) @filter.command("t2i") - async def t2i(self, event: AstrMessageEvent): + async def t2i(self, event: AstrMessageEvent) -> None: """开关文本转图片""" await self.t2i_c.t2i(event) @filter.command("tts") - async def tts(self, event: AstrMessageEvent): + async def tts(self, event: AstrMessageEvent) -> None: """开关文本转语音(会话级别)""" await self.tts_c.tts(event) @filter.command("sid") - async def sid(self, event: AstrMessageEvent): + async def sid(self, event: AstrMessageEvent) -> None: """获取会话 ID 和 管理员 ID""" await self.sid_c.sid(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("op") - async def op(self, event: AstrMessageEvent, admin_id: str = ""): + async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None: """授权管理员。op """ await self.admin_c.op(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("deop") - async def deop(self, event: AstrMessageEvent, admin_id: str): + async def deop(self, event: AstrMessageEvent, admin_id: str) -> None: """取消授权管理员。deop """ await self.admin_c.deop(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("wl") - async def wl(self, event: AstrMessageEvent, sid: str = ""): + async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: """添加白名单。wl """ await self.admin_c.wl(event, sid) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dwl") - async def dwl(self, event: AstrMessageEvent, sid: str): + async def dwl(self, event: AstrMessageEvent, sid: str) -> None: """删除白名单。dwl """ await self.admin_c.dwl(event, sid) @@ -123,12 +123,12 @@ async def provider( event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ): + ) -> None: """查看或者切换 LLM Provider""" await self.provider_c.provider(event, idx, idx2) @filter.command("reset") - async def reset(self, message: AstrMessageEvent): + async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" await self.conversation_c.reset(message) @@ -138,74 +138,76 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ): + ) -> None: """查看或者切换模型""" await self.provider_c.model_ls(message, idx_or_name) @filter.command("history") - async def his(self, message: AstrMessageEvent, page: int = 1): + async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" await self.conversation_c.his(message, page) @filter.command("ls") - async def convs(self, message: AstrMessageEvent, page: int = 1): + async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" await self.conversation_c.convs(message, page) @filter.command("new") - async def new_conv(self, message: AstrMessageEvent): + async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" await self.conversation_c.new_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("groupnew") - async def groupnew_conv(self, message: AstrMessageEvent, sid: str): + async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None: """创建新群聊对话""" await self.conversation_c.groupnew_conv(message, sid) @filter.command("switch") - async def switch_conv(self, message: AstrMessageEvent, index: int | None = None): + async def switch_conv( + self, message: AstrMessageEvent, index: int | None = None + ) -> None: """通过 /ls 前面的序号切换对话""" await self.conversation_c.switch_conv(message, index) @filter.command("rename") - async def rename_conv(self, message: AstrMessageEvent, new_name: str): + async def rename_conv(self, message: AstrMessageEvent, new_name: str) -> None: """重命名对话""" await self.conversation_c.rename_conv(message, new_name) @filter.command("del") - async def del_conv(self, message: AstrMessageEvent): + async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" await self.conversation_c.del_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("key") - async def key(self, message: AstrMessageEvent, index: int | None = None): + async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: """查看或者切换 Key""" await self.provider_c.key(message, index) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("persona") - async def persona(self, message: AstrMessageEvent): + async def persona(self, message: AstrMessageEvent) -> None: """查看或者切换 Persona""" await self.persona_c.persona(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dashboard_update") - async def update_dashboard(self, event: AstrMessageEvent): + async def update_dashboard(self, event: AstrMessageEvent) -> None: """更新管理面板""" await self.admin_c.update_dashboard(event) @filter.command("set") - async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None: await self.setunset_c.set_variable(event, key, value) @filter.command("unset") - async def unset_variable(self, event: AstrMessageEvent, key: str): + async def unset_variable(self, event: AstrMessageEvent, key: str) -> None: await self.setunset_c.unset_variable(event, key) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("alter_cmd", alias={"alter"}) - async def alter_cmd(self, event: AstrMessageEvent): + async def alter_cmd(self, event: AstrMessageEvent) -> None: """修改命令权限""" await self.alter_cmd_c.alter_cmd(event) diff --git a/astrbot/builtin_stars/session_controller/main.py b/astrbot/builtin_stars/session_controller/main.py index cb8c8bf58d..70081e03a6 100644 --- a/astrbot/builtin_stars/session_controller/main.py +++ b/astrbot/builtin_stars/session_controller/main.py @@ -17,11 +17,11 @@ class Main(Star): """会话控制""" - def __init__(self, context: Context): + def __init__(self, context: Context) -> None: super().__init__(context) @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) - async def handle_session_control_agent(self, event: AstrMessageEvent): + async def handle_session_control_agent(self, event: AstrMessageEvent) -> None: """会话控制代理""" for session_filter in FILTERS: session_id = session_filter.filter(event) @@ -90,7 +90,7 @@ async def handle_empty_mention(self, event: AstrMessageEvent): async def empty_mention_waiter( controller: SessionController, event: AstrMessageEvent, - ): + ) -> None: event.message_obj.message.insert( 0, Comp.At(qq=event.get_self_id(), name=event.get_self_id()), diff --git a/astrbot/builtin_stars/web_searcher/engines/__init__.py b/astrbot/builtin_stars/web_searcher/engines/__init__.py index 82def138f8..55d2abffd7 100644 --- a/astrbot/builtin_stars/web_searcher/engines/__init__.py +++ b/astrbot/builtin_stars/web_searcher/engines/__init__.py @@ -49,7 +49,7 @@ def __init__(self) -> None: def _set_selector(self, selector: str) -> str: raise NotImplementedError - def _get_next_page(self, query: str): + async def _get_next_page(self, query: str) -> str: raise NotImplementedError async def _get_html(self, url: str, data: dict | None = None) -> str: diff --git a/astrbot/builtin_stars/web_searcher/main.py b/astrbot/builtin_stars/web_searcher/main.py index 12c8f68b38..85eeffd94e 100644 --- a/astrbot/builtin_stars/web_searcher/main.py +++ b/astrbot/builtin_stars/web_searcher/main.py @@ -199,7 +199,7 @@ async def _extract_tavily(self, cfg: AstrBotConfig, payload: dict) -> list[dict] return results @filter.command("websearch") - async def websearch(self, event: AstrMessageEvent, oper: str | None = None): + async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None: """网页搜索指令(已废弃)""" event.set_result( MessageEventResult().message( @@ -246,7 +246,7 @@ async def search_from_search_engine( return ret - async def ensure_baidu_ai_search_mcp(self, umo: str | None = None): + async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None: if self.baidu_initialized: return cfg = self.context.get_config(umo=umo) @@ -553,7 +553,7 @@ async def edit_web_search_tools( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Get the session conversation for the given event.""" cfg = self.context.get_config(umo=event.unified_msg_origin) prov_settings = cfg.get("provider_settings", {}) diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index a9bd40f000..703c9b8995 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -127,7 +127,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any: @click.group(name="conf") -def conf(): +def conf() -> None: """配置管理命令 支持的配置项: @@ -149,7 +149,7 @@ def conf(): @conf.command(name="set") @click.argument("key") @click.argument("value") -def set_config(key: str, value: str): +def set_config(key: str, value: str) -> None: """设置配置项的值""" if key not in CONFIG_VALIDATORS: raise click.ClickException(f"不支持的配置项: {key}") @@ -178,7 +178,7 @@ def set_config(key: str, value: str): @conf.command(name="get") @click.argument("key", required=False) -def get_config(key: str | None = None): +def get_config(key: str | None = None) -> None: """获取配置项的值,不提供key则显示所有可配置项""" config = _load_config() diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index a1099de1d6..9cf94365af 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -15,7 +15,7 @@ @click.group() -def plug(): +def plug() -> None: """插件管理""" @@ -28,7 +28,7 @@ def _get_data_path() -> Path: return (base / "data").resolve() -def display_plugins(plugins, title=None, color=None): +def display_plugins(plugins, title=None, color=None) -> None: if title: click.echo(click.style(title, fg=color, bold=True)) @@ -45,7 +45,7 @@ def display_plugins(plugins, title=None, color=None): @plug.command() @click.argument("name") -def new(name: str): +def new(name: str) -> None: """创建新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" / name @@ -100,7 +100,7 @@ def new(name: str): @plug.command() @click.option("--all", "-a", is_flag=True, help="列出未安装的插件") -def list(all: bool): +def list(all: bool) -> None: """列出插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -141,7 +141,7 @@ def list(all: bool): @plug.command() @click.argument("name") @click.option("--proxy", help="代理服务器地址") -def install(name: str, proxy: str | None): +def install(name: str, proxy: str | None) -> None: """安装插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -164,7 +164,7 @@ def install(name: str, proxy: str | None): @plug.command() @click.argument("name") -def remove(name: str): +def remove(name: str) -> None: """卸载插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -187,7 +187,7 @@ def remove(name: str): @plug.command() @click.argument("name", required=False) @click.option("--proxy", help="Github代理地址") -def update(name: str, proxy: str | None): +def update(name: str, proxy: str | None) -> None: """更新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -225,7 +225,7 @@ def update(name: str, proxy: str | None): @plug.command() @click.argument("query") -def search(query: str): +def search(query: str) -> None: """搜索插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 9333f1b870..23665dff3d 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -10,7 +10,7 @@ from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root -async def run_astrbot(astrbot_root: Path): +async def run_astrbot(astrbot_root: Path) -> None: """运行 AstrBot""" from astrbot.core import LogBroker, LogManager, db_helper, logger from astrbot.core.initial_loader import InitialLoader diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index cd76a07c8d..81f59e0bfc 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -19,7 +19,7 @@ class PluginStatus(str, Enum): NOT_PUBLISHED = "未发布" -def get_git_repo(url: str, target_path: Path, proxy: str | None = None): +def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None: """从 Git 仓库下载代码并解压到指定路径""" temp_dir = Path(tempfile.mkdtemp()) try: diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py index 792835181f..31a0b0b48d 100644 --- a/astrbot/core/agent/context/compressor.py +++ b/astrbot/core/agent/context/compressor.py @@ -57,7 +57,9 @@ class TruncateByTurnsCompressor: Truncates the message list by removing older turns. """ - def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82): + def __init__( + self, truncate_turns: int = 1, compression_threshold: float = 0.82 + ) -> None: """Initialize the truncate by turns compressor. Args: @@ -152,7 +154,7 @@ def __init__( keep_recent: int = 4, instruction_text: str | None = None, compression_threshold: float = 0.82, - ): + ) -> None: """Initialize the LLM summary compressor. Args: diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index b8e131d983..216a3e7e15 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -13,7 +13,7 @@ class ContextManager: def __init__( self, config: ContextConfig, - ): + ) -> None: """Initialize the context manager. There are two strategies to handle context limit reached: diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 5812766c85..68bccb8f95 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -14,7 +14,7 @@ def __init__( parameters: dict | None = None, tool_description: str | None = None, **kwargs, - ): + ) -> None: self.agent = agent # Avoid passing duplicate `description` to the FunctionTool dataclass. diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index d834240b73..74ca6335b3 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -9,22 +9,22 @@ class BaseAgentRunHooks(Generic[TContext]): - async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... + async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, - ): ... + ) -> None: ... async def on_tool_end( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, tool_result: mcp.types.CallToolResult | None, - ): ... + ) -> None: ... async def on_agent_done( self, run_context: ContextWrapper[TContext], llm_response: LLMResponse, - ): ... + ) -> None: ... diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index c5ff123b24..18f4d47e04 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -108,7 +108,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class MCPClient: - def __init__(self): + def __init__(self) -> None: # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() @@ -126,7 +126,7 @@ def __init__(self): self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection self._reconnecting: bool = False # For logging and debugging - async def connect_to_server(self, mcp_server_config: dict, name: str): + async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """Connect to MCP server If `url` parameter exists: @@ -144,7 +144,7 @@ async def connect_to_server(self, mcp_server_config: dict, name: str): cfg = _prepare_config(mcp_server_config.copy()) - def logging_callback(msg: str): + def logging_callback(msg: str) -> None: # Handle MCP service error logs print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -214,7 +214,7 @@ def logging_callback(msg: str): **cfg, ) - def callback(msg: str): + def callback(msg: str) -> None: # Handle MCP service error logs self.server_errlogs.append(msg) @@ -343,7 +343,7 @@ async def _call_with_retry(): return await _call_with_retry() - async def cleanup(self): + async def cleanup(self) -> None: """Clean up resources including old exit stacks from reconnections""" # Close current exit stack try: @@ -365,7 +365,7 @@ class MCPTool(FunctionTool, Generic[TContext]): def __init__( self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs - ): + ) -> None: super().__init__( name=mcp_tool.name, description=mcp_tool.description or "", diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index e8f3a1e24b..f5799dfbb7 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -10,7 +10,7 @@ class CozeAPIClient: - def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"): + def __init__(self, api_key: str, api_base: str = "https://api.coze.cn") -> None: self.api_key = api_key self.api_base = api_base self.session = None @@ -277,7 +277,7 @@ async def get_message_list( logger.error(f"获取Coze消息列表失败: {e!s}") raise Exception(f"获取Coze消息列表失败: {e!s}") - async def close(self): + async def close(self) -> None: """关闭会话""" if self.session: await self.session.close() @@ -288,7 +288,7 @@ async def close(self): import asyncio import os - async def test_coze_api_client(): + async def test_coze_api_client() -> None: api_key = os.getenv("COZE_API_KEY", "") bot_id = os.getenv("COZE_BOT_ID", "") client = CozeAPIClient(api_key=api_key) diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 7a095a60b6..1aaf6e3b9c 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -67,7 +67,7 @@ async def reset( if isinstance(self.timeout, str): self.timeout = int(self.timeout) - def has_rag_options(self): + def has_rag_options(self) -> bool: """判断是否有 RAG 选项 Returns: diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index d9c6556cf5..26da6dfe9a 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -31,7 +31,7 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: class DifyAPIClient: - def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): + def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1") -> None: self.api_key = api_key self.api_base = api_base self.session = ClientSession(trust_env=True) @@ -155,7 +155,7 @@ async def file_upload( raise Exception(f"Dify 文件上传失败:{resp.status}. {text}") return await resp.json() # {"id": "xxx", ...} - async def close(self): + async def close(self) -> None: await self.session.close() async def get_chat_convs(self, user: str, limit: int = 20): diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 50899ff80b..04a47b35f1 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -64,7 +64,7 @@ class FunctionTool(ToolSchema, Generic[TContext]): with a task identifier while the real work continues asynchronously. """ - def __repr__(self): + def __repr__(self) -> str: return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult: @@ -88,7 +88,7 @@ def empty(self) -> bool: """Check if the tool set is empty.""" return len(self.tools) == 0 - def add_tool(self, tool: FunctionTool): + def add_tool(self, tool: FunctionTool) -> None: """Add a tool to the set.""" # 检查是否已存在同名工具 for i, existing_tool in enumerate(self.tools): @@ -97,7 +97,7 @@ def add_tool(self, tool: FunctionTool): return self.tools.append(tool) - def remove_tool(self, name: str): + def remove_tool(self, name: str) -> None: """Remove a tool by its name.""" self.tools = [tool for tool in self.tools if tool.name != name] @@ -156,7 +156,7 @@ def add_func( func_args: list, desc: str, handler: Callable[..., Awaitable[Any]], - ): + ) -> None: """Add a function tool to the set.""" params = { "type": "object", # hard-coded here @@ -176,7 +176,7 @@ def add_func( self.add_tool(_func) @deprecated(reason="Use remove_tool() instead", version="4.0.0") - def remove_func(self, name: str): + def remove_func(self, name: str) -> None: """Remove a function tool by its name.""" self.remove_tool(name) @@ -325,22 +325,22 @@ def names(self) -> list[str]: """获取所有工具的名称列表""" return [tool.name for tool in self.tools] - def merge(self, other: "ToolSet"): + def merge(self, other: "ToolSet") -> None: """Merge another ToolSet into this one.""" for tool in other.tools: self.add_tool(tool) - def __len__(self): + def __len__(self) -> int: return len(self.tools) - def __bool__(self): + def __bool__(self) -> bool: return len(self.tools) > 0 def __iter__(self): return iter(self.tools) - def __repr__(self): + def __repr__(self) -> str: return f"ToolSet(tools={self.tools})" - def __str__(self): + def __str__(self) -> str: return f"ToolSet(tools={self.tools})" diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py index 1bd8c451e2..09bf32deb4 100644 --- a/astrbot/core/astr_agent_hooks.py +++ b/astrbot/core/astr_agent_hooks.py @@ -12,7 +12,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): - async def on_agent_done(self, run_context, llm_response): + async def on_agent_done(self, run_context, llm_response) -> None: # 执行事件钩子 if llm_response and llm_response.reasoning_content: # we will use this in result_decorate stage to inject reasoning content to chain @@ -31,7 +31,7 @@ async def on_tool_start( run_context: ContextWrapper[AstrAgentContext], tool: FunctionTool[Any], tool_args: dict | None, - ): + ) -> None: await call_event_hook( run_context.context.event, EventType.OnUsingLLMToolEvent, @@ -45,7 +45,7 @@ async def on_tool_end( tool: FunctionTool[Any], tool_args: dict | None, tool_result: CallToolResult | None, - ): + ) -> None: run_context.context.event.clear_result() await call_event_hook( run_context.context.event, diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index 5556aa6b40..379e62d6a5 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -295,7 +295,7 @@ async def _run_agent_feeder( max_step: int, show_tool_use: bool, show_reasoning: bool, -): +) -> None: """运行 Agent 并将文本输出分句放入队列""" buffer = "" try: @@ -352,7 +352,7 @@ async def _safe_tts_stream_wrapper( tts_provider: TTSProvider, text_queue: asyncio.Queue[str | None], audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]", -): +) -> None: """包装原生流式 TTS 确保异常处理和队列关闭""" try: await tts_provider.get_audio_stream(text_queue, audio_queue) @@ -366,7 +366,7 @@ async def _simulated_stream_tts( tts_provider: TTSProvider, text_queue: asyncio.Queue[str | None], audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]", -): +) -> None: """模拟流式 TTS 分句生成音频""" try: while True: diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 460cab3324..230faaf1c1 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -57,7 +57,7 @@ async def execute(cls, tool, run_context, **tool_args): elif tool.is_background_task: task_id = uuid.uuid4().hex - async def _run_in_background(): + async def _run_in_background() -> None: try: await cls._execute_background( tool=tool, @@ -153,7 +153,7 @@ async def _execute_background( run_context: ContextWrapper[AstrAgentContext], task_id: str, **tool_args, - ): + ) -> None: from astrbot.core.astr_main_agent import ( MainAgentBuildConfig, _get_session_conv, diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 3a1353ce5e..c2bfb1c37b 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -36,7 +36,7 @@ def __init__( default_config: AstrBotConfig, ucr: UmopConfigRouter, sp: SharedPreferences, - ): + ) -> None: self.sp = sp self.ucr = ucr self.confs: dict[str, AstrBotConfig] = {} @@ -56,7 +56,7 @@ def _get_abconf_data(self) -> dict: ) return self.abconf_data - def _load_all_configs(self): + def _load_all_configs(self) -> None: """Load all configurations from the shared preferences.""" abconf_data = self._get_abconf_data() self.abconf_data = abconf_data diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index 51c4a4650f..a922375998 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -59,7 +59,7 @@ def __init__( main_db: BaseDatabase, kb_manager: "KnowledgeBaseManager | None" = None, config_path: str = CMD_CONFIG_FILE_PATH, - ): + ) -> None: self.main_db = main_db self.kb_manager = kb_manager self.config_path = config_path diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index f36a79cf53..2e67f85e5c 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -110,7 +110,7 @@ def to_dict(self) -> dict: class ImportResult: """导入结果""" - def __init__(self): + def __init__(self) -> None: self.success = True self.imported_tables: dict[str, int] = {} self.imported_files: dict[str, int] = {} @@ -161,7 +161,7 @@ def __init__( kb_manager: "KnowledgeBaseManager | None" = None, config_path: str = CMD_CONFIG_FILE_PATH, kb_root_dir: str = KB_PATH, - ): + ) -> None: self.main_db = main_db self.kb_manager = kb_manager self.config_path = config_path diff --git a/astrbot/core/computer/booters/base.py b/astrbot/core/computer/booters/base.py index 55af866a3e..ea93a3d6d9 100644 --- a/astrbot/core/computer/booters/base.py +++ b/astrbot/core/computer/booters/base.py @@ -22,7 +22,7 @@ async def upload_file(self, path: str, file_name: str) -> dict: """ ... - async def download_file(self, remote_path: str, local_path: str): + async def download_file(self, remote_path: str, local_path: str) -> None: """Download file from the computer.""" ... diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index c89c40187b..a80ef0da28 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -225,7 +225,7 @@ async def upload_file(self, path: str, file_name: str) -> dict: "LocalBooter does not support upload_file operation. Use shell instead." ) - async def download_file(self, remote_path: str, local_path: str): + async def download_file(self, remote_path: str, local_path: str) -> None: raise NotImplementedError( "LocalBooter does not support download_file operation. Use shell instead." ) diff --git a/astrbot/core/computer/tools/fs.py b/astrbot/core/computer/tools/fs.py index 9acc371b2c..9cf590a616 100644 --- a/astrbot/core/computer/tools/fs.py +++ b/astrbot/core/computer/tools/fs.py @@ -100,7 +100,7 @@ async def call( self, context: ContextWrapper[AstrAgentContext], local_path: str, - ): + ) -> str | None: sb = await get_booter( context.context.context, context.context.event.unified_msg_origin, diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 2208ee766d..bda02324e5 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -33,7 +33,7 @@ def __init__( config_path: str = ASTRBOT_CONFIG_PATH, default_config: dict = DEFAULT_CONFIG, schema: dict | None = None, - ): + ) -> None: super().__init__() # 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件 @@ -66,7 +66,7 @@ def _config_schema_to_default_config(self, schema: dict) -> dict: """将 Schema 转换成 Config""" conf = {} - def _parse_schema(schema: dict, conf: dict): + def _parse_schema(schema: dict, conf: dict) -> None: for k, v in schema.items(): if v["type"] not in DEFAULT_VALUE_MAP: raise TypeError( @@ -148,7 +148,7 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): return has_new - def save_config(self, replace_config: dict | None = None): + def save_config(self, replace_config: dict | None = None) -> None: """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config @@ -164,14 +164,14 @@ def __getattr__(self, item): except KeyError: return None - def __delattr__(self, key): + def __delattr__(self, key) -> None: try: del self[key] self.save_config() except KeyError: raise AttributeError(f"没有找到 Key: '{key}'") - def __setattr__(self, key, value): + def __setattr__(self, key, value) -> None: self[key] = value def check_exist(self) -> bool: diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index a0a0c0e2fa..6fcb3608c3 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -16,7 +16,7 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 @@ -106,7 +106,9 @@ async def new_conversation( await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) return conv.conversation_id - async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): + async def switch_conversation( + self, unified_msg_origin: str, conversation_id: str + ) -> None: """切换会话的对话 Args: @@ -121,7 +123,7 @@ async def delete_conversation( self, unified_msg_origin: str, conversation_id: str | None = None, - ): + ) -> None: """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 Args: @@ -138,7 +140,7 @@ async def delete_conversation( self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") - async def delete_conversations_by_user_id(self, unified_msg_origin: str): + async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None: """删除会话的所有对话 Args: diff --git a/astrbot/core/cron/events.py b/astrbot/core/cron/events.py index d4f0e01e27..a90ca38227 100644 --- a/astrbot/core/cron/events.py +++ b/astrbot/core/cron/events.py @@ -24,7 +24,7 @@ def __init__( sender_name: str = "Scheduler", extras: dict[str, Any] | None = None, message_type: MessageType = MessageType.FRIEND_MESSAGE, - ): + ) -> None: platform_meta = PlatformMetadata( name="cron", description="CronJob", @@ -53,13 +53,13 @@ def __init__( if extras: self._extras.update(extras) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: if message is None: return await self.context_obj.send_message(self.session, message) await super().send(message) - async def send_streaming(self, generator, use_fallback: bool = False): + async def send_streaming(self, generator, use_fallback: bool = False) -> None: async for chain in generator: await self.send(chain) diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index 0572fa03a6..d12878be3e 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -25,14 +25,14 @@ class CronJobManager: """Central scheduler for BasicCronJob and ActiveAgentCronJob.""" - def __init__(self, db: BaseDatabase): + def __init__(self, db: BaseDatabase) -> None: self.db = db self.scheduler = AsyncIOScheduler() self._basic_handlers: dict[str, Callable[..., Any]] = {} self._lock = asyncio.Lock() self._started = False - async def start(self, ctx: "Context"): + async def start(self, ctx: "Context") -> None: self.ctx: Context = ctx # star context async with self._lock: if self._started: @@ -41,14 +41,14 @@ async def start(self, ctx: "Context"): self._started = True await self.sync_from_db() - async def shutdown(self): + async def shutdown(self) -> None: async with self._lock: if not self._started: return self.scheduler.shutdown(wait=False) self._started = False - async def sync_from_db(self): + async def sync_from_db(self) -> None: jobs = await self.db.list_cron_jobs() for job in jobs: if not job.enabled or not job.persistent: @@ -136,11 +136,11 @@ async def delete_job(self, job_id: str) -> None: async def list_jobs(self, job_type: str | None = None) -> list[CronJob]: return await self.db.list_cron_jobs(job_type) - def _remove_scheduled(self, job_id: str): + def _remove_scheduled(self, job_id: str) -> None: if self.scheduler.get_job(job_id): self.scheduler.remove_job(job_id) - def _schedule_job(self, job: CronJob): + def _schedule_job(self, job: CronJob) -> None: if not self._started: self.scheduler.start() self._started = True @@ -188,7 +188,7 @@ def _get_next_run_time(self, job_id: str): aps_job = self.scheduler.get_job(job_id) return aps_job.next_run_time if aps_job else None - async def _run_job(self, job_id: str): + async def _run_job(self, job_id: str) -> None: job = await self.db.get_cron_job(job_id) if not job or not job.enabled: return @@ -222,7 +222,7 @@ async def _run_job(self, job_id: str): # one-shot: remove after execution regardless of success await self.delete_job(job_id) - async def _run_basic_job(self, job: CronJob): + async def _run_basic_job(self, job: CronJob) -> None: handler = self._basic_handlers.get(job.job_id) if not handler: raise RuntimeError(f"Basic cron job handler not found for {job.job_id}") @@ -231,7 +231,7 @@ async def _run_basic_job(self, job: CronJob): if asyncio.iscoroutine(result): await result - async def _run_active_agent_job(self, job: CronJob, start_time: datetime): + async def _run_active_agent_job(self, job: CronJob, start_time: datetime) -> None: payload = job.payload or {} session_str = payload.get("session") if not session_str: @@ -266,7 +266,7 @@ async def _woke_main_agent( message: str, session_str: str, extras: dict, - ): + ) -> None: """Woke the main agent to handle the cron job message.""" from astrbot.core.astr_main_agent import ( MainAgentBuildConfig, diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 7b67b87554..72837e178d 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -43,7 +43,7 @@ def __init__(self) -> None: expire_on_commit=False, ) - async def initialize(self): + async def initialize(self) -> None: """初始化数据库连接""" @asynccontextmanager diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 66b72d5cb1..727d97b29b 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -43,7 +43,7 @@ def get_platform_type( async def migration_conversation_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) @@ -101,7 +101,7 @@ async def migration_conversation_table( async def migration_platform_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) @@ -180,7 +180,7 @@ async def migration_platform_table( async def migration_webchat_data( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), @@ -236,7 +236,7 @@ async def migration_webchat_data( async def migration_persona_data( db_helper: BaseDatabase, astrbot_config: AstrBotConfig, -): +) -> None: """迁移 Persona 数据到新的表中。 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 """ @@ -279,7 +279,7 @@ async def migration_persona_data( async def migration_preferences( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: # 1. global scope migration keys = [ "inactivated_llm_tools", diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index dc70026f9f..58736ab51f 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -3,7 +3,7 @@ from astrbot.core.umop_config_router import UmopConfigRouter -async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter): +async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None: abconf_data = acm.abconf_data if not isinstance(abconf_data, dict): diff --git a/astrbot/core/db/migration/migra_token_usage.py b/astrbot/core/db/migration/migra_token_usage.py index 07938301d7..76bf8ce01c 100644 --- a/astrbot/core/db/migration/migra_token_usage.py +++ b/astrbot/core/db/migration/migra_token_usage.py @@ -12,7 +12,7 @@ from astrbot.core.db import BaseDatabase -async def migrate_token_usage(db_helper: BaseDatabase): +async def migrate_token_usage(db_helper: BaseDatabase) -> None: """Add token_usage column to conversations table. This migration adds a new column to track token consumption in conversations. diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py index ff0b5ca6fc..46025fc646 100644 --- a/astrbot/core/db/migration/migra_webchat_session.py +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -17,7 +17,7 @@ from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession -async def migrate_webchat_session(db_helper: BaseDatabase): +async def migrate_webchat_session(db_helper: BaseDatabase) -> None: """Create PlatformSession records from platform_message_history. This migration extracts all unique user_ids from platform_message_history diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 3abcb1a66a..05b514583d 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -8,7 +8,7 @@ class SharedPreferences: - def __init__(self, path=None): + def __init__(self, path=None) -> None: if path is None: path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") self.path = path @@ -23,7 +23,7 @@ def _load_preferences(self): os.remove(self.path) return {} - def _save_preferences(self): + def _save_preferences(self) -> None: with open(self.path, "w") as f: json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() @@ -31,16 +31,16 @@ def _save_preferences(self): def get(self, key, default: _VT = None) -> _VT: return self._data.get(key, default) - def put(self, key, value): + def put(self, key, value) -> None: self._data[key] = value self._save_preferences() - def remove(self, key): + def remove(self, key) -> None: if key in self._data: del self._data[key] self._save_preferences() - def clear(self): + def clear(self) -> None: self._data.clear() self._save_preferences() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index b1a780d488..b326ebb449 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -127,7 +127,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: tuple | None = None): + def _exec_sql(self, sql: str, params: tuple | None = None) -> None: conn = self.conn try: c = self.conn.cursor() @@ -144,7 +144,7 @@ def _exec_sql(self, sql: str, params: tuple | None = None): conn.commit() - def insert_platform_metrics(self, metrics: dict): + def insert_platform_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -153,7 +153,7 @@ def insert_platform_metrics(self, metrics: dict): (k, v, int(time.time())), ) - def insert_llm_metrics(self, metrics: dict): + def insert_llm_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -249,7 +249,7 @@ def get_conversation_by_user_id( return Conversation(*res) - def new_conversation(self, user_id: str, cid: str): + def new_conversation(self, user_id: str, cid: str) -> None: history = "[]" updated_at = int(time.time()) created_at = updated_at @@ -287,7 +287,7 @@ def get_conversations(self, user_id: str) -> list[Conversation]: ) return conversations - def update_conversation(self, user_id: str, cid: str, history: str): + def update_conversation(self, user_id: str, cid: str, history: str) -> None: """更新对话,并且同时更新时间""" updated_at = int(time.time()) self._exec_sql( @@ -297,7 +297,7 @@ def update_conversation(self, user_id: str, cid: str, history: str): (history, updated_at, user_id, cid), ) - def update_conversation_title(self, user_id: str, cid: str, title: str): + def update_conversation_title(self, user_id: str, cid: str, title: str) -> None: self._exec_sql( """ UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? @@ -305,7 +305,9 @@ def update_conversation_title(self, user_id: str, cid: str, title: str): (title, user_id, cid), ) - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + def update_conversation_persona_id( + self, user_id: str, cid: str, persona_id: str + ) -> None: self._exec_sql( """ UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? @@ -313,7 +315,7 @@ def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str (persona_id, user_id, cid), ) - def delete_conversation(self, user_id: str, cid: str): + def delete_conversation(self, user_id: str, cid: str) -> None: self._exec_sql( """ DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 153e13e8b3..922a7299f3 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -305,7 +305,7 @@ async def update_conversation( await session.execute(query) return await self.get_conversation_by_id(cid) - async def delete_conversation(self, cid): + async def delete_conversation(self, cid) -> None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -461,7 +461,7 @@ async def delete_platform_message_offset( platform_id, user_id, offset_sec=86400, - ): + ) -> None: """Delete platform message history records newer than the specified offset.""" async with self.get_db() as session: session: AsyncSession @@ -645,7 +645,7 @@ async def update_persona( await session.execute(query) return await self.get_persona_by_id(persona_id) - async def delete_persona(self, persona_id): + async def delete_persona(self, persona_id) -> None: """Delete a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -903,7 +903,7 @@ async def get_preferences(self, scope, scope_id=None, key=None): result = await session.execute(query) return result.scalars().all() - async def remove_preference(self, scope, scope_id, key): + async def remove_preference(self, scope, scope_id, key) -> None: """Remove a preference by scope ID and key.""" async with self.get_db() as session: session: AsyncSession @@ -917,7 +917,7 @@ async def remove_preference(self, scope, scope_id, key): ) await session.commit() - async def clear_preferences(self, scope, scope_id): + async def clear_preferences(self, scope, scope_id) -> None: """Clear all preferences for a specific scope ID.""" async with self.get_db() as session: session: AsyncSession @@ -1195,7 +1195,7 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) @@ -1218,7 +1218,7 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) @@ -1253,7 +1253,7 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index 7440b6f2a7..04f8903b15 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -9,7 +9,7 @@ class Result: class BaseVecDB: - async def initialize(self): + async def initialize(self) -> None: """初始化向量数据库""" @abc.abstractmethod diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index e27eb6fe88..2adae69ccc 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -33,7 +33,7 @@ class Document(BaseDocModel, table=True): class DocumentStorage: - def __init__(self, db_path: str): + def __init__(self, db_path: str) -> None: self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.engine: AsyncEngine | None = None @@ -43,7 +43,7 @@ def __init__(self, db_path: str): "sqlite_init.sql", ) - async def initialize(self): + async def initialize(self) -> None: """Initialize the SQLite database and create the documents table if it doesn't exist.""" await self.connect() async with self.engine.begin() as conn: # type: ignore @@ -80,7 +80,7 @@ async def initialize(self): await conn.commit() - async def connect(self): + async def connect(self) -> None: """Connect to the SQLite database.""" if self.engine is None: self.engine = create_async_engine( @@ -211,7 +211,7 @@ async def insert_documents_batch( await session.flush() # Flush to get all IDs return [doc.id for doc in documents] # type: ignore - async def delete_document_by_doc_id(self, doc_id: str): + async def delete_document_by_doc_id(self, doc_id: str) -> None: """Delete a document by its doc_id. Args: @@ -249,7 +249,7 @@ async def get_document_by_doc_id(self, doc_id: str): return self._document_to_dict(document) return None - async def update_document_by_doc_id(self, doc_id: str, new_text: str): + async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None: """Update a document by its doc_id. Args: @@ -269,7 +269,7 @@ async def update_document_by_doc_id(self, doc_id: str, new_text: str): document.updated_at = datetime.now() session.add(document) - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """Delete documents by their metadata filters. Args: @@ -384,7 +384,7 @@ async def tuple_to_dict(self, row): "updated_at": row[5], } - async def close(self): + async def close(self) -> None: """Close the connection to the SQLite database.""" if self.engine: await self.engine.dispose() diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 564454cb14..dc6977cf8a 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -10,7 +10,7 @@ class EmbeddingStorage: - def __init__(self, dimension: int, path: str | None = None): + def __init__(self, dimension: int, path: str | None = None) -> None: self.dimension = dimension self.path = path self.index = None @@ -20,7 +20,7 @@ def __init__(self, dimension: int, path: str | None = None): base_index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexIDMap(base_index) - async def insert(self, vector: np.ndarray, id: int): + async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 Args: @@ -38,7 +38,7 @@ async def insert(self, vector: np.ndarray, id: int): self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) await self.save_index() - async def insert_batch(self, vectors: np.ndarray, ids: list[int]): + async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: """批量插入向量 Args: @@ -71,7 +71,7 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: distances, indices = self.index.search(vector, k) return distances, indices - async def delete(self, ids: list[int]): + async def delete(self, ids: list[int]) -> None: """删除向量 Args: @@ -83,7 +83,7 @@ async def delete(self, ids: list[int]): self.index.remove_ids(id_array) await self.save_index() - async def save_index(self): + async def save_index(self) -> None: """保存索引 Args: diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 14221f1e85..3fca246ef5 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -20,7 +20,7 @@ def __init__( index_store_path: str, embedding_provider: EmbeddingProvider, rerank_provider: RerankProvider | None = None, - ): + ) -> None: self.doc_store_path = doc_store_path self.index_store_path = index_store_path self.embedding_provider = embedding_provider @@ -32,7 +32,7 @@ def __init__( self.embedding_provider = embedding_provider self.rerank_provider = rerank_provider - async def initialize(self): + async def initialize(self) -> None: await self.document_storage.initialize() async def insert( @@ -165,7 +165,7 @@ async def retrieve( return top_k_results - async def delete(self, doc_id: str): + async def delete(self, doc_id: str) -> None: """删除一条文档块(chunk)""" # 获得对应的 int id result = await self.document_storage.get_document_by_doc_id(doc_id) @@ -177,7 +177,7 @@ async def delete(self, doc_id: str): await self.document_storage.delete_document_by_doc_id(doc_id) await self.embedding_storage.delete([int_id]) - async def close(self): + async def close(self) -> None: await self.document_storage.close() async def count_documents(self, metadata_filter: dict | None = None) -> int: @@ -192,7 +192,7 @@ async def count_documents(self, metadata_filter: dict | None = None) -> int: ) return count - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """根据元数据过滤器删除文档""" docs = await self.document_storage.get_documents( metadata_filters=metadata_filters, diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 0017e65fa0..44cdccb83a 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -28,13 +28,13 @@ def __init__( event_queue: Queue, pipeline_scheduler_mapping: dict[str, PipelineScheduler], astrbot_config_mgr: AstrBotConfigManager, - ): + ) -> None: self.event_queue = event_queue # 事件队列 # abconf uuid -> scheduler self.pipeline_scheduler_mapping = pipeline_scheduler_mapping self.astrbot_config_mgr = astrbot_config_mgr - async def dispatch(self): + async def dispatch(self) -> None: while True: event: AstrMessageEvent = await self.event_queue.get() conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) @@ -47,7 +47,7 @@ async def dispatch(self): continue asyncio.create_task(scheduler.execute(event)) - def _print_event(self, event: AstrMessageEvent, conf_name: str): + def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None: """用于记录事件信息 Args: diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index ea97759c18..42fbd23dfe 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -9,12 +9,12 @@ class FileTokenService: """维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。""" - def __init__(self, default_timeout: float = 300): + def __init__(self, default_timeout: float = 300) -> None: self.lock = asyncio.Lock() self.staged_files = {} # token: (file_path, expire_time) self.default_timeout = default_timeout - async def _cleanup_expired_tokens(self): + async def _cleanup_expired_tokens(self) -> None: """清理过期的令牌""" now = time.time() expired_tokens = [ diff --git a/astrbot/core/initial_loader.py b/astrbot/core/initial_loader.py index f54d186416..3f836a4c42 100644 --- a/astrbot/core/initial_loader.py +++ b/astrbot/core/initial_loader.py @@ -17,13 +17,13 @@ class InitialLoader: """AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。""" - def __init__(self, db: BaseDatabase, log_broker: LogBroker): + def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None: self.db = db self.logger = logger self.log_broker = log_broker self.webui_dir: str | None = None - async def start(self): + async def start(self) -> None: core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) try: diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index 5439f070ff..c0eb17865f 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -12,7 +12,7 @@ class FixedSizeChunker(BaseChunker): 按照固定的字符数分块,并支持块之间的重叠。 """ - def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): + def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None: """初始化分块器 Args: diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 3882b0871a..e27ffbd1b7 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -11,7 +11,7 @@ def __init__( length_function: Callable[[str], int] = len, is_separator_regex: bool = False, separators: list[str] | None = None, - ): + ) -> None: """初始化递归字符文本分割器 Args: diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 5e1db842fe..ba25ed7e54 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -253,7 +253,7 @@ async def get_document_with_metadata(self, doc_id: str) -> dict | None: "knowledge_base": row[1], } - async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB): + async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None: """删除单个文档及其相关数据""" # 在知识库表中删除 async with self.get_db() as session, session.begin(): diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 4adfb60b8a..1e9127d72a 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -31,7 +31,7 @@ class RateLimiter: """一个简单的速率限制器""" - def __init__(self, max_rpm: int): + def __init__(self, max_rpm: int) -> None: self.max_per_minute = max_rpm self.interval = 60.0 / max_rpm if max_rpm > 0 else 0 self.last_call_time = 0 @@ -116,7 +116,7 @@ def __init__( provider_manager: ProviderManager, kb_root_dir: str, chunker: BaseChunker, - ): + ) -> None: self.kb_db = kb_db self.kb = kb self.prov_mgr = provider_manager @@ -130,7 +130,7 @@ def __init__( self.kb_medias_dir.mkdir(parents=True, exist_ok=True) self.kb_files_dir.mkdir(parents=True, exist_ok=True) - async def initialize(self): + async def initialize(self) -> None: await self._ensure_vec_db() async def get_ep(self) -> EmbeddingProvider: @@ -174,7 +174,7 @@ async def _ensure_vec_db(self) -> FaissVecDB: self.vec_db = vec_db return vec_db - async def delete_vec_db(self): + async def delete_vec_db(self) -> None: """删除知识库的向量数据库和所有相关文件""" import shutil @@ -182,7 +182,7 @@ async def delete_vec_db(self): if self.kb_dir.exists(): shutil.rmtree(self.kb_dir) - async def terminate(self): + async def terminate(self) -> None: if self.vec_db: await self.vec_db.close() @@ -293,7 +293,7 @@ async def upload_document( await progress_callback("chunking", 100, 100) # 阶段3: 生成向量(带进度回调) - async def embedding_progress_callback(current, total): + async def embedding_progress_callback(current, total) -> None: if progress_callback: await progress_callback("embedding", current, total) @@ -360,7 +360,7 @@ async def get_document(self, doc_id: str) -> KBDocument | None: doc = await self.kb_db.get_document_by_id(doc_id) return doc - async def delete_document(self, doc_id: str): + async def delete_document(self, doc_id: str) -> None: """删除单个文档及其相关数据""" await self.kb_db.delete_document_by_id( doc_id=doc_id, @@ -372,7 +372,7 @@ async def delete_document(self, doc_id: str): ) await self.refresh_kb() - async def delete_chunk(self, chunk_id: str, doc_id: str): + async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore await vec_db.delete(chunk_id) @@ -383,7 +383,7 @@ async def delete_chunk(self, chunk_id: str, doc_id: str): await self.refresh_kb() await self.refresh_document(doc_id) - async def refresh_kb(self): + async def refresh_kb(self) -> None: if self.kb: kb = await self.kb_db.get_kb_by_id(self.kb.kb_id) if kb: diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index b085924ca4..ae5a1b9e7c 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -26,14 +26,14 @@ class KnowledgeBaseManager: def __init__( self, provider_manager: ProviderManager, - ): + ) -> None: Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True) self.provider_manager = provider_manager self._session_deleted_callback_registered = False self.kb_insts: dict[str, KBHelper] = {} - async def initialize(self): + async def initialize(self) -> None: """初始化知识库模块""" try: logger.info("正在初始化知识库模块...") @@ -58,13 +58,13 @@ async def initialize(self): logger.error(f"知识库模块初始化失败: {e}") logger.error(traceback.format_exc()) - async def _init_kb_database(self): + async def _init_kb_database(self) -> None: self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix()) await self.kb_db.initialize() await self.kb_db.migrate_to_v1() logger.info(f"KnowledgeBase database initialized: {DB_PATH}") - async def load_kbs(self): + async def load_kbs(self) -> None: """加载所有知识库实例""" kb_records = await self.kb_db.list_kbs() for record in kb_records: @@ -275,7 +275,7 @@ def _format_context(self, results: list[RetrievalResult]) -> str: return "\n".join(lines) - async def terminate(self): + async def terminate(self) -> None: """终止所有知识库实例,关闭数据库连接""" for kb_id, kb_helper in self.kb_insts.items(): try: diff --git a/astrbot/core/knowledge_base/parsers/url_parser.py b/astrbot/core/knowledge_base/parsers/url_parser.py index f68e2e0c4f..2867164a96 100644 --- a/astrbot/core/knowledge_base/parsers/url_parser.py +++ b/astrbot/core/knowledge_base/parsers/url_parser.py @@ -6,7 +6,7 @@ class URLExtractor: """URL 内容提取器,封装了 Tavily API 调用和密钥管理""" - def __init__(self, tavily_keys: list[str]): + def __init__(self, tavily_keys: list[str]) -> None: """ 初始化 URL 提取器 diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 746406e907..d406ceabce 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -44,7 +44,7 @@ def __init__( sparse_retriever: SparseRetriever, rank_fusion: RankFusion, kb_db: KBSQLiteDatabase, - ): + ) -> None: """初始化检索管理器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 26203f94b4..40afd97484 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -31,7 +31,7 @@ class RankFusion: - 使用 Reciprocal Rank Fusion (RRF) 算法 """ - def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60): + def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None: """初始化结果融合器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index ea5da1c9ec..d453251d17 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -34,7 +34,7 @@ class SparseRetriever: - 使用 BM25 算法计算相关度 """ - def __init__(self, kb_db: KBSQLiteDatabase): + def __init__(self, kb_db: KBSQLiteDatabase) -> None: """初始化稀疏检索器 Args: diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 49eeb30e4b..264c431972 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -91,7 +91,7 @@ class LogBroker: 发布-订阅模式 """ - def __init__(self): + def __init__(self) -> None: self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 self.subscribers: list[Queue] = [] # 订阅者列表 @@ -106,7 +106,7 @@ def register(self) -> Queue: self.subscribers.append(q) return q - def unregister(self, q: Queue): + def unregister(self, q: Queue) -> None: """取消订阅 Args: @@ -115,7 +115,7 @@ def unregister(self, q: Queue): """ self.subscribers.remove(q) - def publish(self, log_entry: dict): + def publish(self, log_entry: dict) -> None: """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统 Args: @@ -137,11 +137,11 @@ class LogQueueHandler(logging.Handler): 继承自 logging.Handler """ - def __init__(self, log_broker: LogBroker): + def __init__(self, log_broker: LogBroker) -> None: super().__init__() self.log_broker = log_broker - def emit(self, record): + def emit(self, record) -> None: """日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布 这个方法会在每次日志记录时被调用 @@ -201,7 +201,7 @@ def GetLogger(cls, log_name: str = "default"): class PluginFilter(logging.Filter): """插件过滤器类, 用于标记日志来源是插件还是核心组件""" - def filter(self, record): + def filter(self, record) -> bool: record.plugin_tag = ( "[Plug]" if is_plugin_path(record.pathname) else "[Core]" ) @@ -213,7 +213,7 @@ class FileNameFilter(logging.Filter): """ # 获取这个文件和父文件夹的名字:. 并且去除 .py - def filter(self, record): + def filter(self, record) -> bool: dirname = os.path.dirname(record.pathname) record.filename = ( os.path.basename(dirname) @@ -226,14 +226,14 @@ class LevelNameFilter(logging.Filter): """短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写""" # 添加短日志级别名称 - def filter(self, record): + def filter(self, record) -> bool: record.short_levelname = get_short_level_name(record.levelname) return True class AstrBotVersionTagFilter(logging.Filter): """在 WARNING 及以上级别日志后追加当前 AstrBot 版本号。""" - def filter(self, record): + def filter(self, record) -> bool: if record.levelno >= logging.WARNING: record.astrbot_version_tag = f" [v{VERSION}]" else: @@ -251,7 +251,7 @@ def filter(self, record): return logger @classmethod - def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): + def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None: """设置队列处理器, 用于将日志消息发送到 LogBroker Args: @@ -301,7 +301,7 @@ def _get_trace_file_handlers(cls, logger: logging.Logger) -> list[logging.Handle ] @classmethod - def _remove_file_handlers(cls, logger: logging.Logger): + def _remove_file_handlers(cls, logger: logging.Logger) -> None: for handler in cls._get_file_handlers(logger): logger.removeHandler(handler) try: @@ -310,7 +310,7 @@ def _remove_file_handlers(cls, logger: logging.Logger): pass @classmethod - def _remove_trace_file_handlers(cls, logger: logging.Logger): + def _remove_trace_file_handlers(cls, logger: logging.Logger) -> None: for handler in cls._get_trace_file_handlers(logger): logger.removeHandler(handler) try: @@ -326,7 +326,7 @@ def _add_file_handler( max_mb: int | None = None, backup_count: int = 3, trace: bool = False, - ): + ) -> None: os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True) max_bytes = 0 if max_mb and max_mb > 0: @@ -365,7 +365,7 @@ def configure_logger( logger: logging.Logger, config: dict | None, override_level: str | None = None, - ): + ) -> None: """根据配置设置日志级别和文件日志。 Args: @@ -413,7 +413,7 @@ def configure_logger( cls._add_file_handler(logger, file_path, max_mb=max_mb) @classmethod - def configure_trace_logger(cls, config: dict | None): + def configure_trace_logger(cls, config: dict | None) -> None: """为 trace 事件配置独立的文件日志,不向控制台输出。""" if not config: return diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 280276089c..a192025dc9 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -66,7 +66,7 @@ class ComponentType(str, Enum): class BaseMessageComponent(BaseModel): type: ComponentType - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) def toDict(self): @@ -89,7 +89,7 @@ class Plain(BaseMessageComponent): text: str convert: bool | None = True - def __init__(self, text: str, convert: bool = True, **_): + def __init__(self, text: str, convert: bool = True, **_) -> None: super().__init__(text=text, convert=convert, **_) def toDict(self): @@ -103,7 +103,7 @@ class Face(BaseMessageComponent): type = ComponentType.Face id: int - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -118,7 +118,7 @@ class Record(BaseMessageComponent): # 额外 path: str | None - def __init__(self, file: str | None, **_): + def __init__(self, file: str | None, **_) -> None: for k in _: if k == "url": pass @@ -221,7 +221,7 @@ class Video(BaseMessageComponent): # 额外 path: str | None = "" - def __init__(self, file: str, **_): + def __init__(self, file: str, **_) -> None: super().__init__(file=file, **_) @staticmethod @@ -255,7 +255,7 @@ async def convert_to_file_path(self) -> str: return os.path.abspath(url) raise Exception(f"not a valid file: {url}") - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """将视频注册到文件服务。 Returns: @@ -303,7 +303,7 @@ class At(BaseMessageComponent): qq: int | str # 此处str为all时代表所有人 name: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) def toDict(self): @@ -316,28 +316,28 @@ def toDict(self): class AtAll(At): qq: str = "all" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) class RPS(BaseMessageComponent): # TODO type = ComponentType.RPS - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) class Dice(BaseMessageComponent): # TODO type = ComponentType.Dice - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) class Shake(BaseMessageComponent): # TODO type = ComponentType.Shake - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -348,7 +348,7 @@ class Share(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -357,7 +357,7 @@ class Contact(BaseMessageComponent): # TODO _type: str # type 字段冲突 id: int | None = 0 - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -368,7 +368,7 @@ class Location(BaseMessageComponent): # TODO title: str | None = "" content: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -382,7 +382,7 @@ class Music(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: # for k in _.keys(): # if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") @@ -402,7 +402,7 @@ class Image(BaseMessageComponent): path: str | None = "" file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识 - def __init__(self, file: str | None, **_): + def __init__(self, file: str | None, **_) -> None: super().__init__(file=file, **_) @staticmethod @@ -525,7 +525,7 @@ class Reply(BaseMessageComponent): seq: int | None = 0 """deprecated""" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -534,7 +534,7 @@ class Poke(BaseMessageComponent): id: int | None = 0 qq: int | None = 0 - def __init__(self, type: str, **_): + def __init__(self, type: str, **_) -> None: type = f"Poke:{type}" super().__init__(type=type, **_) @@ -543,7 +543,7 @@ class Forward(BaseMessageComponent): type = ComponentType.Forward id: str - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -558,7 +558,7 @@ class Node(BaseMessageComponent): seq: str | list | None = "" # 忽略 time: int | None = 0 # 忽略 - def __init__(self, content: list[BaseMessageComponent], **_): + def __init__(self, content: list[BaseMessageComponent], **_) -> None: if isinstance(content, Node): # back content = [content] @@ -605,7 +605,7 @@ class Nodes(BaseMessageComponent): type = ComponentType.Nodes nodes: list[Node] - def __init__(self, nodes: list[Node], **_): + def __init__(self, nodes: list[Node], **_) -> None: super().__init__(nodes=nodes, **_) def toDict(self): @@ -631,7 +631,7 @@ class Json(BaseMessageComponent): type = ComponentType.Json data: dict - def __init__(self, data: str | dict, **_): + def __init__(self, data: str | dict, **_) -> None: if isinstance(data, str): data = json.loads(data) super().__init__(data=data, **_) @@ -650,7 +650,7 @@ class File(BaseMessageComponent): file_: str | None = "" # 本地路径 url: str | None = "" # url - def __init__(self, name: str, file: str = "", url: str = ""): + def __init__(self, name: str, file: str = "", url: str = "") -> None: """文件消息段。""" super().__init__(name=name, file_=file, url=url) @@ -686,7 +686,7 @@ def file(self) -> str: return "" @file.setter - def file(self, value: str): + def file(self, value: str) -> None: """向前兼容, 设置file属性, 传入的参数可能是文件路径或URL Args: @@ -721,7 +721,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: return "" - async def _download_file(self): + async def _download_file(self) -> None: """下载文件""" if not self.url: raise ValueError("Download failed: No URL provided in File component.") @@ -736,7 +736,7 @@ async def _download_file(self): await download_file(self.url, file_path) self.file_ = os.path.abspath(file_path) - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """将文件注册到文件服务。 Returns: @@ -786,7 +786,7 @@ class WechatEmoji(BaseMessageComponent): md5_len: int | None = 0 cdnurl: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index a2f4065538..f3633f20fa 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -17,7 +17,7 @@ class PersonaManager: - def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager): + def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager) -> None: self.db = db_helper self.acm = acm default_ps = acm.default_conf.get("provider_settings", {}) @@ -29,7 +29,7 @@ def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager): self.selected_default_persona_v3: Personality | None = None self.persona_v3_config: list[dict] = [] - async def initialize(self): + async def initialize(self) -> None: self.personas = await self.get_all_personas() self.get_v3_persona_data() logger.info(f"已加载 {len(self.personas)} 个人格。") @@ -58,7 +58,7 @@ async def get_default_persona_v3( except Exception: return DEFAULT_PERSONALITY - async def delete_persona(self, persona_id: str): + async def delete_persona(self, persona_id: str) -> None: """删除指定 persona""" if not await self.db.get_persona_by_id(persona_id): raise ValueError(f"Persona with ID {persona_id} does not exist.") diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index b089c48e08..19037eb081 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -16,7 +16,7 @@ class ContentSafetyCheckStage(Stage): 当前只会检查文本的。 """ - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: config = ctx.astrbot_config["content_safety"] self.strategy_selector = StrategySelector(config) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index b095a7cbb1..d26f67add9 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -336,7 +336,7 @@ async def _save_to_history( llm_response: LLMResponse | None, all_messages: list[Message], runner_stats: AgentStats | None, - ): + ) -> None: if ( not req or not req.conversation diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index 64e21dd7e5..392bceff30 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -19,7 +19,7 @@ class RateLimitStage(Stage): 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 """ - def __init__(self): + def __init__(self) -> None: # 存储每个会话的请求时间队列 self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque) # 为每个会话设置一个锁,避免并发冲突 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 60ab168b3f..b57fed29e2 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -35,7 +35,7 @@ class RespondStage(Stage): Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 } - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.platform_settings: dict = self.config.get("platform_settings", {}) @@ -91,7 +91,7 @@ async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: # random return random.uniform(self.interval[0], self.interval[1]) - async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]): + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: """检查消息链是否为空 Args: diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index e0bcd5ac9d..823aa0eaa1 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -20,7 +20,7 @@ @register_stage class ResultDecorateStage(Stage): - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] self.reply_with_mention = ctx.astrbot_config["platform_settings"][ diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 8569f945aa..71c98778fa 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -15,21 +15,21 @@ class PipelineScheduler: """管道调度器,负责调度各个阶段的执行""" - def __init__(self, context: PipelineContext): + def __init__(self, context: PipelineContext) -> None: registered_stages.sort( key=lambda x: STAGES_ORDER.index(x.__name__), ) # 按照顺序排序 self.ctx = context # 上下文对象 self.stages = [] # 存储阶段实例 - async def initialize(self): + async def initialize(self) -> None: """初始化管道调度器时, 初始化所有阶段""" for stage_cls in registered_stages: stage_instance = stage_cls() # 创建实例 await stage_instance.initialize(self.ctx) self.stages.append(stage_instance) - async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + async def _process_stages(self, event: AstrMessageEvent, from_stage=0) -> None: """依次执行各个阶段 Args: @@ -72,7 +72,7 @@ async def _process_stages(self, event: AstrMessageEvent, from_stage=0): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - async def execute(self, event: AstrMessageEvent): + async def execute(self, event: AstrMessageEvent) -> None: """执行 pipeline Args: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index b99a5778be..83b9813e07 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -38,7 +38,7 @@ def __init__( message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, - ): + ) -> None: self.message_str = message_str """纯文本的消息""" self.message_obj = message_obj @@ -91,7 +91,7 @@ def unified_msg_origin(self) -> str: return str(self.session) @unified_msg_origin.setter - def unified_msg_origin(self, value: str): + def unified_msg_origin(self, value: str) -> None: """设置统一的消息来源字符串。格式为 platform_name:message_type:session_id""" self.new_session = MessageSession.from_str(value) self.session = self.new_session @@ -102,7 +102,7 @@ def session_id(self) -> str: return self.session.session_id @session_id.setter - def session_id(self, value: str): + def session_id(self, value: str) -> None: """设置用户的会话 ID。可以直接使用下面的 unified_msg_origin""" self.session.session_id = value @@ -191,7 +191,7 @@ def get_sender_name(self) -> str: return self.message_obj.sender.nickname return "" - def set_extra(self, key, value): + def set_extra(self, key, value) -> None: """设置额外的信息。""" self._extras[key] = value @@ -201,7 +201,7 @@ def get_extra(self, key: str | None = None, default=None) -> Any: return self._extras return self._extras.get(key, default) - def clear_extra(self): + def clear_extra(self) -> None: """清除额外的信息。""" logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}") self._extras.clear() @@ -234,7 +234,7 @@ async def send_streaming( self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False, - ): + ) -> None: """发送流式消息到消息平台,使用异步生成器。 目前仅支持: telegram,qq official 私聊。 Fallback仅支持 aiocqhttp。 @@ -244,13 +244,13 @@ async def send_streaming( ) self._has_send_oper = True - async def _pre_send(self): + async def _pre_send(self) -> None: """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18""" - async def _post_send(self): + async def _post_send(self) -> None: """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" - def set_result(self, result: MessageEventResult | str): + def set_result(self, result: MessageEventResult | str) -> None: """设置消息事件的结果。 Note: @@ -279,14 +279,14 @@ async def check_count(self, event: AstrMessageEvent): result.chain = [] self._result = result - def stop_event(self): + def stop_event(self) -> None: """终止事件传播。""" if self._result is None: self.set_result(MessageEventResult().stop_event()) else: self._result.stop_event() - def continue_event(self): + def continue_event(self) -> None: """继续事件传播。""" if self._result is None: self.set_result(MessageEventResult().continue_event()) @@ -299,7 +299,7 @@ def is_stopped(self) -> bool: return False # 默认是继续传播 return self._result.is_stopped() - def should_call_llm(self, call_llm: bool): + def should_call_llm(self, call_llm: bool) -> None: """是否在此消息事件中禁止默认的 LLM 请求。 只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。 @@ -310,7 +310,7 @@ def get_result(self) -> MessageEventResult | None: """获取消息事件的结果。""" return self._result - def clear_result(self): + def clear_result(self) -> None: """清除消息事件的结果。""" self._result = None @@ -404,7 +404,7 @@ def request_llm( """平台适配器""" - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息到消息平台。 Args: @@ -423,7 +423,7 @@ async def send(self, message: MessageChain): ) self._has_send_oper = True - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: """对消息添加表情回应。 默认实现为发送一条包含该表情的消息。 diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 253963322b..3db53fd484 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -11,7 +11,7 @@ class MessageMember: user_id: str # 发送者id nickname: str | None = None - def __str__(self): + def __str__(self) -> str: # 使用 f-string 来构建返回的字符串表示形式 return ( f"User ID: {self.user_id}," @@ -34,7 +34,7 @@ class Group: members: list[MessageMember] | None = None """所有群成员""" - def __str__(self): + def __str__(self) -> str: # 使用 f-string 来构建返回的字符串表示形式 return ( f"Group ID: {self.group_id}\n" @@ -78,7 +78,7 @@ def group_id(self) -> str: return "" @group_id.setter - def group_id(self, value: str | None): + def group_id(self, value: str | None) -> None: """设置 group_id""" if value: if self.group: diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index c8043e56b2..9cec8a9421 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -13,7 +13,7 @@ class PlatformManager: - def __init__(self, config: AstrBotConfig, event_queue: Queue): + def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None: self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" @@ -38,7 +38,7 @@ def _sanitize_platform_id(self, platform_id: str | None) -> tuple[str | None, bo sanitized = platform_id.replace(":", "_").replace("!", "_") return sanitized, sanitized != platform_id - async def initialize(self): + async def initialize(self) -> None: """初始化所有平台适配器""" for platform in self.platforms_config: try: @@ -58,7 +58,7 @@ async def initialize(self): ), ) - async def load_platform(self, platform_config: dict): + async def load_platform(self, platform_config: dict) -> None: """实例化一个平台""" # 动态导入 try: @@ -176,7 +176,9 @@ async def load_platform(self, platform_config: dict): except Exception: logger.error(traceback.format_exc()) - async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None): + async def _task_wrapper( + self, task: asyncio.Task, platform: Platform | None = None + ) -> None: # 设置平台状态为运行中 if platform: platform.status = PlatformStatus.RUNNING @@ -198,7 +200,7 @@ async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = No if platform: platform.record_error(error_msg, tb_str) - async def reload(self, platform_config: dict): + async def reload(self, platform_config: dict) -> None: await self.terminate_platform(platform_config["id"]) if platform_config["enable"]: await self.load_platform(platform_config) @@ -209,7 +211,7 @@ async def reload(self, platform_config: dict): if key not in config_ids: await self.terminate_platform(key) - async def terminate_platform(self, platform_id: str): + async def terminate_platform(self, platform_id: str) -> None: if platform_id in self._inst_map: logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") @@ -231,7 +233,7 @@ async def terminate_platform(self, platform_id: str): if getattr(inst, "terminate", None): await inst.terminate() - async def terminate(self): + async def terminate(self) -> None: for inst in self.platform_insts: if getattr(inst, "terminate", None): await inst.terminate() diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py index b282b307ae..89639941eb 100644 --- a/astrbot/core/platform/message_session.py +++ b/astrbot/core/platform/message_session.py @@ -15,7 +15,7 @@ class MessageSession: session_id: str platform_id: str = field(init=False) - def __str__(self): + def __str__(self) -> str: return f"{self.platform_id}:{self.message_type.value}:{self.session_id}" def __post_init__(self): diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 8592273d18..a7c181217d 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -34,7 +34,7 @@ class PlatformError: class Platform(abc.ABC): - def __init__(self, config: dict, event_queue: Queue): + def __init__(self, config: dict, event_queue: Queue) -> None: super().__init__() # 平台配置 self.config = config @@ -53,7 +53,7 @@ def status(self) -> PlatformStatus: return self._status @status.setter - def status(self, value: PlatformStatus): + def status(self, value: PlatformStatus) -> None: """设置平台运行状态""" self._status = value if value == PlatformStatus.RUNNING and self._started_at is None: @@ -69,12 +69,12 @@ def last_error(self) -> PlatformError | None: """获取最近的错误""" return self._errors[-1] if self._errors else None - def record_error(self, message: str, traceback_str: str | None = None): + def record_error(self, message: str, traceback_str: str | None = None) -> None: """记录一个错误""" self._errors.append(PlatformError(message=message, traceback=traceback_str)) self._status = PlatformStatus.ERROR - def clear_errors(self): + def clear_errors(self) -> None: """清除错误记录""" self._errors.clear() if self._status == PlatformStatus.ERROR: @@ -121,7 +121,7 @@ def run(self) -> Coroutine[Any, Any, None]: """得到一个平台的运行实例,需要返回一个协程对象。""" raise NotImplementedError - async def terminate(self): + async def terminate(self) -> None: """终止一个平台的运行实例。""" @abc.abstractmethod @@ -140,11 +140,11 @@ async def send_by_session( """ await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name) - def commit_event(self, event: AstrMessageEvent): + def commit_event(self, event: AstrMessageEvent) -> None: """提交一个事件到事件队列。""" self._event_queue.put_nowait(event) - def get_client(self): + def get_client(self) -> object: """获取平台的客户端对象。""" async def webhook_callback(self, request: Any) -> Any: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 3d84cbd441..99ea727315 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta, session_id, bot: CQHttp, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @@ -72,7 +72,7 @@ async def _dispatch_send( is_group: bool, session_id: str | None, messages: list[dict], - ): + ) -> None: # session_id 必须是纯数字字符串 session_id_int = ( int(session_id) if session_id and session_id.isdigit() else None @@ -97,7 +97,7 @@ async def send_message( event: Event | None = None, is_group: bool = False, session_id: str | None = None, - ): + ) -> None: """发送消息至 QQ 协议端(aiocqhttp)。 Args: @@ -143,7 +143,7 @@ async def send_message( await cls._dispatch_send(bot, event, is_group, session_id, messages) await asyncio.sleep(0.5) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息""" event = getattr(self.message_obj, "raw_message", None) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index d4d8e1d62a..fb6c997848 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -61,7 +61,7 @@ def __init__( ) @self.bot.on_request() - async def request(event: Event): + async def request(event: Event) -> None: try: abm = await self.convert_message(event) if not abm: @@ -72,7 +72,7 @@ async def request(event: Event): return @self.bot.on_notice() - async def notice(event: Event): + async def notice(event: Event) -> None: try: abm = await self.convert_message(event) if abm: @@ -82,7 +82,7 @@ async def notice(event: Event): return @self.bot.on_message("group") - async def group(event: Event): + async def group(event: Event) -> None: try: abm = await self.convert_message(event) if abm: @@ -92,7 +92,7 @@ async def group(event: Event): return @self.bot.on_message("private") - async def private(event: Event): + async def private(event: Event) -> None: try: abm = await self.convert_message(event) if abm: @@ -102,14 +102,14 @@ async def private(event: Event): return @self.bot.on_websocket_connection - def on_websocket_connection(_): + def on_websocket_connection(_) -> None: logger.info("aiocqhttp(OneBot v11) 适配器已连接。") async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: is_group = session.message_type == MessageType.GROUP_MESSAGE if is_group: session_id = session.session_id.split("_")[-1] @@ -435,17 +435,17 @@ def run(self) -> Awaitable[Any]: self.shutdown_event = asyncio.Event() return coro - async def terminate(self): + async def terminate(self) -> None: self.shutdown_event.set() - async def shutdown_trigger_placeholder(self): + async def shutdown_trigger_placeholder(self) -> None: await self.shutdown_event.wait() logger.info("aiocqhttp 适配器已被关闭") def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = AiocqhttpMessageEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 8c93ab40f8..41f4aa6117 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -2,7 +2,7 @@ import os import threading import uuid -from typing import cast +from typing import NoReturn, cast import aiohttp import dingtalk_stream @@ -90,7 +90,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: raise NotImplementedError("钉钉机器人适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -104,7 +104,7 @@ def meta(self) -> PlatformMetadata: async def create_message_card( self, message_id: str, incoming_message: dingtalk_stream.ChatbotMessage - ): + ) -> bool | None: if not self.card_template_id: return False @@ -122,7 +122,9 @@ async def create_message_card( logger.error(f"创建钉钉卡片失败: {e}") return False - async def send_card_message(self, message_id: str, content: str, is_final: bool): + async def send_card_message( + self, message_id: str, content: str, is_final: bool + ) -> None: if message_id not in self.card_instance_id_dict: return @@ -276,7 +278,7 @@ async def get_access_token(self) -> str: return "" return (await resp.json())["data"]["accessToken"] - async def handle_msg(self, abm: AstrBotMessage): + async def handle_msg(self, abm: AstrBotMessage) -> None: event = DingtalkMessageEvent( message_str=abm.message_str, message_obj=abm, @@ -288,10 +290,10 @@ async def handle_msg(self, abm: AstrBotMessage): self._event_queue.put_nowait(event) - async def run(self): + async def run(self) -> None: # await self.client_.start() # 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。 - def start_client(loop: asyncio.AbstractEventLoop): + def start_client(loop: asyncio.AbstractEventLoop) -> None: try: self._shutdown_event = threading.Event() task = loop.create_task(self.client_.start()) @@ -307,8 +309,8 @@ def start_client(loop: asyncio.AbstractEventLoop): loop = asyncio.get_event_loop() await loop.run_in_executor(None, start_client, loop) - async def terminate(self): - def monkey_patch_close(): + async def terminate(self) -> None: + def monkey_patch_close() -> NoReturn: raise KeyboardInterrupt("Graceful shutdown") if self.client_.websocket is not None: diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index 5af0d6eb0b..6c7d2ff4ce 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -17,7 +17,7 @@ def __init__( session_id, client: dingtalk_stream.ChatbotHandler, adapter: "Any" = None, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client self.adapter = adapter @@ -26,7 +26,7 @@ async def send_with_client( self, client: dingtalk_stream.ChatbotHandler, message: MessageChain, - ): + ) -> None: icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message) ats = [] # fixes: #4218 @@ -80,7 +80,7 @@ async def send_with_client( logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送") continue - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: await self.send_with_client(self.client, message) await super().send(message) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index ac0610f2a8..ebd32c471a 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -15,7 +15,7 @@ class DiscordBotClient(discord.Bot): """Discord客户端封装""" - def __init__(self, token: str, proxy: str | None = None): + def __init__(self, token: str, proxy: str | None = None) -> None: self.token = token self.proxy = proxy @@ -32,7 +32,7 @@ def __init__(self, token: str, proxy: str | None = None): self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None self._ready_once_fired = False - async def on_ready(self): + async def on_ready(self) -> None: """当机器人成功连接并准备就绪时触发""" if self.user is None: logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)") @@ -93,7 +93,7 @@ def _create_interaction_data(self, interaction: discord.Interaction) -> dict: "type": "interaction", } - async def on_message(self, message: discord.Message): + async def on_message(self, message: discord.Message) -> None: """当接收到消息时触发""" if message.author.bot: return @@ -130,12 +130,12 @@ def _extract_interaction_content(self, interaction: discord.Interaction) -> str: return str(interaction_data) - async def start_polling(self): + async def start_polling(self) -> None: """开始轮询消息,这是个阻塞方法""" await self.start(self.token) @override - async def close(self): + async def close(self) -> None: """关闭客户端""" if not self.is_closed(): await super().close() diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index f875652a07..433509f5e1 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -19,7 +19,7 @@ def __init__( image: str | None = None, footer: str | None = None, fields: list[dict] | None = None, - ): + ) -> None: self.title = title self.description = description self.color = color @@ -71,7 +71,7 @@ def __init__( emoji: str | None = None, url: str | None = None, disabled: bool = False, - ): + ) -> None: self.label = label self.custom_id = custom_id self.style = style @@ -85,7 +85,7 @@ class DiscordReference(BaseMessageComponent): type: str = "discord_reference" - def __init__(self, message_id: str, channel_id: str): + def __init__(self, message_id: str, channel_id: str) -> None: self.message_id = message_id self.channel_id = channel_id @@ -99,7 +99,7 @@ def __init__( self, components: list[BaseMessageComponent] | None = None, timeout: float | None = None, - ): + ) -> None: self.components = components or [] self.timeout = timeout diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index ed9899f6f6..7657962a11 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -60,7 +60,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息""" if self.client.user is None: logger.error( @@ -122,11 +122,11 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: """主要运行逻辑""" # 初始化回调函数 - async def on_received(message_data): + async def on_received(message_data) -> None: logger.debug(f"[Discord] 收到消息: {message_data}") if self.client_self_id is None: self.client_self_id = message_data.get("bot_id") @@ -143,7 +143,7 @@ async def on_received(message_data): self.client = DiscordBotClient(token, proxy) self.client.on_message_received = on_received - async def callback(): + async def callback() -> None: if self.enable_command_register: await self._collect_and_register_commands() if self.activity_name: @@ -251,7 +251,7 @@ async def convert_message(self, data: dict) -> AstrBotMessage: # 由于 on_interaction 已被禁用,我们只处理普通消息 return self._convert_message_to_abm(data) - async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): + async def handle_msg(self, message: AstrBotMessage, followup_webhook=None) -> None: """处理消息""" message_event = DiscordPlatformEvent( message_str=message.message_str, @@ -323,7 +323,7 @@ async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): self.commit_event(message_event) @override - async def terminate(self): + async def terminate(self) -> None: """终止适配器""" logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)") self.shutdown_event.set() @@ -358,11 +358,11 @@ async def terminate(self): logger.warning(f"[Discord] 客户端关闭异常: {e}") logger.info("[Discord] 适配器已终止。") - def register_handler(self, handler_info): + def register_handler(self, handler_info) -> None: """注册处理器信息""" self.registered_handlers.append(handler_info) - async def _collect_and_register_commands(self): + async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") registered_commands = [] @@ -420,7 +420,7 @@ def _create_dynamic_callback(self, cmd_name: str): async def dynamic_callback( ctx: discord.ApplicationContext, params: str | None = None - ): + ) -> None: # 将平台特定的前缀'/'剥离,以适配通用的CommandFilter logger.debug(f"[Discord] 回调函数触发: {cmd_name}") logger.debug(f"[Discord] 回调函数参数: {ctx}") diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 0530182259..02d4dae868 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -28,7 +28,7 @@ class DiscordViewComponent(BaseMessageComponent): type: str = "discord_view" - def __init__(self, view: discord.ui.View): + def __init__(self, view: discord.ui.View) -> None: self.view = view @@ -41,12 +41,12 @@ def __init__( session_id: str, client: DiscordBotClient, interaction_followup_webhook: discord.Webhook | None = None, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client self.interaction_followup_webhook = interaction_followup_webhook - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息到Discord平台""" # 解析消息链为 Discord 所需的对象 try: @@ -267,7 +267,7 @@ async def _parse_to_discord( content = content[:2000] return content, files, view, embeds, reference_message_id - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: """对原消息添加反应""" try: if hasattr(self.message_obj, "raw_message") and hasattr( diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 09044c56bd..e76572768e 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -53,10 +53,10 @@ def __init__( logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。") # 初始化 WebSocket 长连接相关配置 - async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1): + async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1) -> None: await self.convert_msg(event) - def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1): + def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1) -> None: asyncio.create_task(on_msg_event_recv(event)) self.event_handler = ( @@ -91,7 +91,7 @@ def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1): self.event_id_timestamps: dict[str, float] = {} - def _clean_expired_events(self): + def _clean_expired_events(self) -> None: """清理超过 30 分钟的事件记录""" current_time = time.time() expired_keys = [ @@ -121,7 +121,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: if session.message_type == MessageType.GROUP_MESSAGE: id_type = "chat_id" receive_id = session.session_id @@ -149,7 +149,7 @@ def meta(self) -> PlatformMetadata: support_streaming_message=False, ) - async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): + async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None: if event.event is None: logger.debug("[Lark] 收到空事件(event.event is None)") return @@ -299,7 +299,7 @@ async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): logger.debug(abm) await self.handle_msg(abm) - async def handle_msg(self, abm: AstrBotMessage): + async def handle_msg(self, abm: AstrBotMessage) -> None: event = LarkMessageEvent( message_str=abm.message_str, message_obj=abm, @@ -310,7 +310,7 @@ async def handle_msg(self, abm: AstrBotMessage): self._event_queue.put_nowait(event) - async def handle_webhook_event(self, event_data: dict): + async def handle_webhook_event(self, event_data: dict) -> None: """处理 Webhook 事件 Args: @@ -332,7 +332,7 @@ async def handle_webhook_event(self, event_data: dict): except Exception as e: logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True) - async def run(self): + async def run(self) -> None: if self.connection_mode == "webhook": # Webhook 模式 if self.webhook_server is None: @@ -355,7 +355,7 @@ async def webhook_callback(self, request: Any) -> Any: return await self.webhook_server.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.connection_mode == "socket": await self.client._disconnect() logger.info("飞书(Lark) 适配器已关闭") diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index b5d4eb1c4d..83a4558889 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -38,7 +38,7 @@ def __init__( platform_meta, session_id, bot: lark.Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @@ -274,7 +274,7 @@ async def send_message_chain( reply_message_id: str | None = None, receive_id: str | None = None, receive_id_type: str | None = None, - ): + ) -> None: """通用的消息链发送方法 Args: @@ -342,7 +342,7 @@ async def send_message_chain( media_comp, lark_client, reply_message_id, receive_id, receive_id_type ) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息链到飞书,然后交给父类做框架级发送/记录""" await LarkMessageEvent.send_message_chain( message, @@ -358,7 +358,7 @@ async def _send_file_message( reply_message_id: str | None = None, receive_id: str | None = None, receive_id_type: str | None = None, - ): + ) -> None: """发送文件消息 Args: @@ -392,7 +392,7 @@ async def _send_audio_message( reply_message_id: str | None = None, receive_id: str | None = None, receive_id_type: str | None = None, - ): + ) -> None: """发送音频消息 Args: @@ -465,7 +465,7 @@ async def _send_media_message( reply_message_id: str | None = None, receive_id: str | None = None, receive_id_type: str | None = None, - ): + ) -> None: """发送视频消息 Args: @@ -531,7 +531,7 @@ async def _send_media_message( receive_id_type=receive_id_type, ) - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: if self.bot.im is None: logger.error("[Lark] API Client im 模块未初始化,无法发送表情") return diff --git a/astrbot/core/platform/sources/lark/server.py b/astrbot/core/platform/sources/lark/server.py index 3921eb8be2..52177ebb0c 100644 --- a/astrbot/core/platform/sources/lark/server.py +++ b/astrbot/core/platform/sources/lark/server.py @@ -21,7 +21,7 @@ class AESCipher: """AES 加密/解密工具类""" - def __init__(self, key: str): + def __init__(self, key: str) -> None: self.bs = AES.block_size self.key = hashlib.sha256(self.str_to_bytes(key)).digest() @@ -52,7 +52,7 @@ class LarkWebhookServer: 仅支持统一 Webhook 模式 """ - def __init__(self, config: dict, event_queue: asyncio.Queue): + def __init__(self, config: dict, event_queue: asyncio.Queue) -> None: """初始化 Webhook 服务器 Args: @@ -197,7 +197,7 @@ async def handle_callback(self, request) -> tuple[dict, int] | dict: return {} - def set_callback(self, callback: Callable[[dict], Awaitable[None]]): + def set_callback(self, callback: Callable[[dict], Awaitable[None]]) -> None: """设置事件回调函数 Args: diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index d8f560b1bd..34415b855b 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -121,7 +121,7 @@ def meta(self) -> PlatformMetadata: support_streaming_message=False, ) - async def run(self): + async def run(self) -> None: if not self.instance_url or not self.access_token: logger.error("[Misskey] 配置不完整,无法启动") return @@ -150,7 +150,7 @@ async def run(self): await self._start_websocket_connection() - def _register_event_handlers(self, streaming): + def _register_event_handlers(self, streaming) -> None: """注册事件处理器""" streaming.add_message_handler("notification", self._handle_notification) streaming.add_message_handler("main:notification", self._handle_notification) @@ -194,7 +194,7 @@ def _process_poll_data( message: AstrBotMessage, poll: dict[str, Any], message_parts: list[str], - ): + ) -> None: """处理投票数据,将其添加到消息中""" try: if not isinstance(message.raw_message, dict): @@ -233,7 +233,7 @@ def _extract_additional_fields(self, session, message_chain) -> dict[str, Any]: return fields - async def _start_websocket_connection(self): + async def _start_websocket_connection(self) -> None: backoff_delay = 1.0 max_backoff = 300.0 backoff_multiplier = 1.5 @@ -281,7 +281,7 @@ async def _start_websocket_connection(self): await asyncio.sleep(sleep_time) backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff) - async def _handle_notification(self, data: dict[str, Any]): + async def _handle_notification(self, data: dict[str, Any]) -> None: try: notification_type = data.get("type") logger.debug( @@ -305,7 +305,7 @@ async def _handle_notification(self, data: dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理通知失败: {e}") - async def _handle_chat_message(self, data: dict[str, Any]): + async def _handle_chat_message(self, data: dict[str, Any]) -> None: try: sender_id = str( data.get("fromUserId", "") or data.get("fromUser", {}).get("id", ""), @@ -340,7 +340,7 @@ async def _handle_chat_message(self, data: dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理聊天消息失败: {e}") - async def _debug_handler(self, data: dict[str, Any]): + async def _debug_handler(self, data: dict[str, Any]) -> None: event_type = data.get("type", "unknown") logger.debug( f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}", @@ -754,7 +754,7 @@ async def convert_room_message(self, raw_data: dict[str, Any]) -> AstrBotMessage ) return message - async def terminate(self): + async def terminate(self) -> None: self._running = False if self.api: await self.api.close() diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 06dc6304d3..3e5eb9a90e 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -3,7 +3,7 @@ import random import uuid from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, NoReturn try: import aiohttp @@ -43,7 +43,7 @@ class WebSocketError(APIError): class StreamingClient: - def __init__(self, instance_url: str, access_token: str): + def __init__(self, instance_url: str, access_token: str) -> None: self.instance_url = instance_url.rstrip("/") self.access_token = access_token self.websocket: Any | None = None @@ -90,7 +90,7 @@ async def connect(self) -> bool: self.is_connected = False return False - async def disconnect(self): + async def disconnect(self) -> None: self._running = False if self.websocket: await self.websocket.close() @@ -116,7 +116,7 @@ async def subscribe_channel( self.channels[channel_id] = channel_type return channel_id - async def unsubscribe_channel(self, channel_id: str): + async def unsubscribe_channel(self, channel_id: str) -> None: if ( not self.is_connected or not self.websocket @@ -136,10 +136,10 @@ def add_message_handler( self, event_type: str, handler: Callable[[dict], Awaitable[None]], - ): + ) -> None: self.message_handlers[event_type] = handler - async def listen(self): + async def listen(self) -> None: if not self.is_connected or not self.websocket: raise WebSocketError("WebSocket 未连接") @@ -187,7 +187,7 @@ async def listen(self): except Exception: pass - async def _handle_message(self, data: dict[str, Any]): + async def _handle_message(self, data: dict[str, Any]) -> None: message_type = data.get("type") body = data.get("body", {}) @@ -334,7 +334,7 @@ def __init__( download_timeout: int = 15, chunk_size: int = 64 * 1024, max_download_bytes: int | None = None, - ): + ) -> None: self.instance_url = instance_url.rstrip("/") self.access_token = access_token self._session: aiohttp.ClientSession | None = None @@ -375,7 +375,7 @@ def session(self) -> aiohttp.ClientSession: self._session = aiohttp.ClientSession(headers=headers) return self._session - def _handle_response_status(self, status: int, endpoint: str): + def _handle_response_status(self, status: int, endpoint: str) -> NoReturn: """处理 HTTP 响应状态码""" if status == 400: logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})") @@ -449,7 +449,6 @@ async def _process_response( ) self._handle_response_status(response.status, endpoint) - raise APIConnectionError(f"Request failed for {endpoint}") @retry_async( max_retries=API_MAX_RETRIES, diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index 7975f0ec72..068f7e7a28 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -40,7 +40,7 @@ def _is_system_command(self, message_str: str) -> bool: return any(message_trimmed.startswith(prefix) for prefix in system_prefixes) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息,使用适配器的完整上传和发送逻辑""" try: logger.debug( diff --git a/astrbot/core/platform/sources/misskey/misskey_utils.py b/astrbot/core/platform/sources/misskey/misskey_utils.py index d9388598dd..dd02c13c01 100644 --- a/astrbot/core/platform/sources/misskey/misskey_utils.py +++ b/astrbot/core/platform/sources/misskey/misskey_utils.py @@ -403,7 +403,7 @@ def cache_user_info( raw_data: dict[str, Any], client_self_id: str, is_chat: bool = False, -): +) -> None: """缓存用户信息""" if is_chat: user_cache_data = { @@ -429,7 +429,7 @@ def cache_room_info( user_cache: dict[str, Any], raw_data: dict[str, Any], client_self_id: str, -): +) -> None: """缓存房间信息""" room_data = raw_data.get("toRoom") room_id = raw_data.get("toRoomId") diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 6076bfc1b7..88c8fc225e 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -32,12 +32,12 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot self.send_buffer = None - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: self.send_buffer = message await self._post_send() diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 6f1164faf1..86ca76db84 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -35,11 +35,13 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: QQOfficialPlatformAdapter): + def set_platform(self, platform: QQOfficialPlatformAdapter) -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + async def on_group_at_message_create( + self, message: botpy.message.GroupMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -49,7 +51,7 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): self._commit(abm) # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): + async def on_at_message_create(self, message: botpy.message.Message) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -59,7 +61,9 @@ async def on_at_message_create(self, message: botpy.message.Message): self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): + async def on_direct_message_create( + self, message: botpy.message.DirectMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -68,7 +72,7 @@ async def on_direct_message_create(self, message: botpy.message.DirectMessage): self._commit(abm) # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -76,7 +80,7 @@ async def on_c2c_message_create(self, message: botpy.message.C2CMessage): abm.session_id = abm.sender.user_id self._commit(abm) - def _commit(self, abm: AstrBotMessage): + def _commit(self, abm: AstrBotMessage) -> None: self.platform.commit_event( QQOfficialMessageEvent( abm.message_str, @@ -128,7 +132,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -222,6 +226,6 @@ def run(self): def get_client(self) -> botClient: return self.client - async def terminate(self): + async def terminate(self) -> None: await self.client.close() logger.info("QQ 官方机器人接口 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index af160f1b5c..c709f2cece 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -26,11 +26,13 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter"): + def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter") -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + async def on_group_at_message_create( + self, message: botpy.message.GroupMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -40,7 +42,7 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): self._commit(abm) # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): + async def on_at_message_create(self, message: botpy.message.Message) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -50,7 +52,9 @@ async def on_at_message_create(self, message: botpy.message.Message): self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): + async def on_direct_message_create( + self, message: botpy.message.DirectMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -59,7 +63,7 @@ async def on_direct_message_create(self, message: botpy.message.DirectMessage): self._commit(abm) # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -67,7 +71,7 @@ async def on_c2c_message_create(self, message: botpy.message.C2CMessage): abm.session_id = abm.sender.user_id self._commit(abm) - def _commit(self, abm: AstrBotMessage): + def _commit(self, abm: AstrBotMessage) -> None: self.platform.commit_event( QQOfficialWebhookMessageEvent( abm.message_str, @@ -110,7 +114,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -121,7 +125,7 @@ def meta(self) -> PlatformMetadata: support_proactive_message=False, ) - async def run(self): + async def run(self) -> None: self.webhook_helper = QQOfficialWebhook( self.config, self._event_queue, @@ -149,7 +153,7 @@ async def webhook_callback(self, request: Any) -> Any: # 复用 webhook_helper 的回调处理逻辑 return await self.webhook_helper.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.webhook_helper: self.webhook_helper.shutdown_event.set() await self.client.close() diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py index 306db5e563..5ceeb2c707 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py @@ -13,5 +13,5 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id, bot) diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 2eda11a6cf..5f35471eea 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -14,7 +14,9 @@ class QQOfficialWebhook: - def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client): + def __init__( + self, config: dict, event_queue: asyncio.Queue, botpy_client: Client + ) -> None: self.appid = config["appid"] self.secret = config["secret"] self.port = config.get("port", 6196) @@ -38,7 +40,7 @@ def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Clien self.event_queue = event_queue self.shutdown_event = asyncio.Event() - async def initialize(self): + async def initialize(self) -> None: logger.info("正在登录到 QQ 官方机器人...") self.user = await self.http.login(self.token) logger.info(f"已登录 QQ 官方机器人账号: {self.user}") @@ -46,7 +48,7 @@ async def initialize(self): self.client.api = self.api self.client.http = self.http - async def bot_connect(): + async def bot_connect() -> None: pass self._connection = ConnectionSession( @@ -115,7 +117,7 @@ async def handle_callback(self, request) -> dict: return {"opcode": 12} - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。", ) @@ -125,5 +127,5 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index 10912dc8e1..5c2f7a37f3 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -73,7 +73,7 @@ async def send_by_session( self, session: MessageSession, message_chain: MessageChain, - ): + ) -> None: from .satori_event import SatoriPlatformEvent await SatoriPlatformEvent.send_with_adapter( @@ -99,7 +99,7 @@ def _is_websocket_closed(self, ws) -> bool: except AttributeError: return False - async def run(self): + async def run(self) -> None: self.running = True self.session = ClientSession(timeout=ClientTimeout(total=30)) @@ -133,7 +133,7 @@ async def run(self): if self.session: await self.session.close() - async def connect_websocket(self): + async def connect_websocket(self) -> None: logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}") logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}") @@ -181,7 +181,7 @@ async def connect_websocket(self): except Exception as e: logger.error(f"Satori WebSocket 关闭异常: {e}") - async def send_identify(self): + async def send_identify(self) -> None: if not self.ws: raise Exception("WebSocket连接未建立") @@ -209,7 +209,7 @@ async def send_identify(self): logger.error(f"发送 IDENTIFY 信令失败: {e}") raise - async def heartbeat_loop(self): + async def heartbeat_loop(self) -> None: try: while self.running and self.ws: await asyncio.sleep(self.heartbeat_interval) @@ -234,7 +234,7 @@ async def heartbeat_loop(self): except Exception as e: logger.error(f"心跳任务异常: {e}") - async def handle_message(self, message: str): + async def handle_message(self, message: str) -> None: try: data = json.loads(message) op = data.get("op") @@ -275,7 +275,7 @@ async def handle_message(self, message: str): except Exception as e: logger.error(f"处理 WebSocket 消息异常: {e}") - async def handle_event(self, event_data: dict): + async def handle_event(self, event_data: dict) -> None: try: event_type = event_data.get("type") sn = event_data.get("sn") @@ -720,7 +720,7 @@ async def _parse_xml_node(self, node: ET.Element, elements: list) -> None: if child.tail and child.tail.strip(): elements.append(Plain(text=child.tail)) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: from .satori_event import SatoriPlatformEvent message_event = SatoriPlatformEvent( @@ -780,7 +780,7 @@ async def send_http_request( logger.error(f"Satori HTTP 请求异常: {e}") return {} - async def terminate(self): + async def terminate(self) -> None: self.running = False if self.heartbeat_task: diff --git a/astrbot/core/platform/sources/satori/satori_event.py b/astrbot/core/platform/sources/satori/satori_event.py index 81a0d222ca..0214222837 100644 --- a/astrbot/core/platform/sources/satori/satori_event.py +++ b/astrbot/core/platform/sources/satori/satori_event.py @@ -28,7 +28,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, adapter: "SatoriPlatformAdapter", - ): + ) -> None: # 更新平台元数据 if adapter and hasattr(adapter, "logins") and adapter.logins: current_login = adapter.logins[0] @@ -110,7 +110,7 @@ async def send_with_adapter( logger.error(f"Satori 消息发送异常: {e}") return None - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: platform = getattr(self, "platform", None) user_id = getattr(self, "user_id", None) diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index fbdc71759a..efd7a6f3d2 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -27,7 +27,7 @@ def __init__( port: int = 3000, path: str = "/slack/events", event_handler: Callable | None = None, - ): + ) -> None: self.web_client = web_client self.signing_secret = signing_secret self.host = host @@ -44,7 +44,7 @@ def __init__( self.shutdown_event = asyncio.Event() - def _setup_routes(self): + def _setup_routes(self) -> None: """设置路由""" @self.app.route(self.path, methods=["POST"]) @@ -105,7 +105,7 @@ async def handle_callback(self, req): logger.error(f"处理 Slack 事件时出错: {e}") return Response("Internal Server Error", status=500) - async def start(self): + async def start(self) -> None: """启动 Webhook 服务器""" logger.info( f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}...", @@ -118,10 +118,10 @@ async def start(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() - async def stop(self): + async def stop(self) -> None: """停止 Webhook 服务器""" self.shutdown_event.set() logger.info("Slack Webhook 服务器已停止") @@ -135,7 +135,7 @@ def __init__( web_client: AsyncWebClient, app_token: str, event_handler: Callable | None = None, - ): + ) -> None: self.web_client = web_client self.app_token = app_token self.event_handler = event_handler @@ -143,7 +143,7 @@ def __init__( async def _handle_events( self, _: AsyncBaseSocketModeClient, req: SocketModeRequest - ): + ) -> None: """处理 Socket Mode 事件""" try: if self.socket_client is None: @@ -160,7 +160,7 @@ async def _handle_events( except Exception as e: logger.error(f"处理 Socket Mode 事件时出错: {e}") - async def start(self): + async def start(self) -> None: """启动 Socket Mode 连接""" self.socket_client = SocketModeClient( app_token=self.app_token, @@ -174,7 +174,7 @@ async def start(self): logger.info("Slack Socket Mode 客户端启动中...") await self.socket_client.connect() - async def stop(self): + async def stop(self) -> None: """停止 Socket Mode 连接""" if self.socket_client: await self.socket_client.disconnect() diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index afd80a8fe1..13e317e49c 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -81,7 +81,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: blocks, text = await SlackMessageEvent._parse_slack_blocks( message_chain=message_chain, web_client=self.web_client, @@ -285,7 +285,7 @@ def _parse_blocks(self, blocks: list) -> list: return message_components - async def _handle_socket_event(self, req: SocketModeRequest): + async def _handle_socket_event(self, req: SocketModeRequest) -> None: """处理 Socket Mode 事件""" if req.type == "events_api": # 事件 API @@ -374,7 +374,7 @@ async def run(self) -> None: f"不支持的连接模式: {self.connection_mode},请使用 'socket' 或 'webhook'", ) - async def _handle_webhook_event(self, event_data: dict): + async def _handle_webhook_event(self, event_data: dict) -> None: """处理 Webhook 事件""" event = event_data.get("event", {}) @@ -401,7 +401,7 @@ async def webhook_callback(self, request: Any) -> Any: return await self.webhook_client.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.socket_client: await self.socket_client.stop() if self.webhook_client: @@ -411,7 +411,7 @@ async def terminate(self): def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = SlackMessageEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 822e6fdeb4..3f62690b53 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -24,7 +24,7 @@ def __init__( platform_meta, session_id, web_client: AsyncWebClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.web_client = web_client @@ -126,7 +126,7 @@ async def _parse_slack_blocks( return blocks, "" if blocks else text_content - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: blocks, text = await SlackMessageEvent._parse_slack_blocks( message, self.web_client, diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 1e92972d65..3381c14f3a 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -104,7 +104,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: from_username = session.session_id await TelegramPlatformEvent.send_with_client( self.client, @@ -119,7 +119,7 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_) @override - async def run(self): + async def run(self) -> None: await self.application.initialize() await self.application.start() @@ -144,7 +144,7 @@ async def run(self): logger.info("Telegram Platform Adapter is running.") await queue - async def register_commands(self): + async def register_commands(self) -> None: """收集所有注册的指令并注册到 Telegram""" try: commands = self.collect_commands() @@ -222,7 +222,7 @@ def _extract_command_info( description = description[:30] + "..." return cmd_name, description - async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: if not update.effective_chat: logger.warning( "Received a start command without an effective chat, skipping /start reply.", @@ -233,7 +233,9 @@ async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): text=self.config["start_message"], ) - async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def message_handler( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: logger.debug(f"Telegram message: {update.message}") # Handle media group messages @@ -477,7 +479,7 @@ async def handle_media_group_message( replace_existing=True, ) - async def process_media_group(self, media_group_id: str): + async def process_media_group(self, media_group_id: str) -> None: """Process a complete media group by merging all collected messages. Args: @@ -523,7 +525,7 @@ async def process_media_group(self, media_group_id: str): # Process the merged message await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = TelegramPlatformEvent( message_str=message.message_str, message_obj=message, @@ -536,7 +538,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> ExtBot: return self.client - async def terminate(self): + async def terminate(self) -> None: try: if self.scheduler.running: self.scheduler.shutdown() diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 5faba68038..1df289d831 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -38,7 +38,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: ExtBot, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -73,7 +73,7 @@ async def send_with_client( client: ExtBot, message: MessageChain, user_name: str, - ): + ) -> None: image_path = None has_reply = False @@ -134,14 +134,14 @@ async def send_with_client( path = await i.convert_to_file_path() await client.send_voice(voice=path, **cast(Any, payload)) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: if self.get_message_type() == MessageType.GROUP_MESSAGE: await self.send_with_client(self.client, message, self.message_obj.group_id) else: await self.send_with_client(self.client, message, self.get_sender_id()) await super().send(message) - async def react(self, emoji: str | None, big: bool = False): + async def react(self, emoji: str | None, big: bool = False) -> None: """给原消息添加 Telegram 反应: - 普通 emoji:传入 '👍'、'😂' 等 - 自定义表情:传入其 custom_emoji_id(纯数字字符串) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index a9dff522d4..5eb62e6b34 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -30,7 +30,7 @@ def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> No self.webchat_queue_mgr = webchat_queue_mgr self.callback = callback - async def run(self): + async def run(self) -> None: """Register callback and keep adapter task alive.""" self.webchat_queue_mgr.set_listener(self.callback) await asyncio.Event().wait() @@ -61,7 +61,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: message_id = f"active_{str(uuid.uuid4())}" await WebChatMessageEvent._send(message_id, message_chain, session.session_id) await super().send_by_session(session, message_chain) @@ -180,7 +180,7 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: return abm def run(self) -> Coroutine[Any, Any, None]: - async def callback(data: tuple): + async def callback(data: tuple) -> None: abm = await self.convert_message(data) await self.handle_msg(abm) @@ -190,7 +190,7 @@ async def callback(data: tuple): def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WebChatMessageEvent( message_str=message.message_str, message_obj=message, @@ -208,6 +208,6 @@ async def handle_msg(self, message: AstrBotMessage): self.commit_event(message_event) - async def terminate(self): + async def terminate(self) -> None: # Do nothing pass diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 0dcc9cc0c4..a3d1cc3c35 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -15,7 +15,7 @@ class WebChatMessageEvent(AstrMessageEvent): - def __init__(self, message_str, message_obj, platform_meta, session_id): + def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) os.makedirs(imgs_dir, exist_ok=True) @@ -120,12 +120,12 @@ async def _send( return data - async def send(self, message: MessageChain | None): + async def send(self, message: MessageChain | None) -> None: message_id = self.message_obj.message_id await WebChatMessageEvent._send(message_id, message, session_id=self.session_id) await super().send(MessageChain([])) - async def send_streaming(self, generator, use_fallback: bool = False): + async def send_streaming(self, generator, use_fallback: bool = False) -> None: final_data = "" reasoning_content = "" message_id = self.message_obj.message_id diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index c7636faacd..688d83e2c9 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -54,7 +54,7 @@ def remove_back_queue(self, request_id: str): if not request_ids: self._conversation_back_requests.pop(conversation_id, None) - def remove_queues(self, conversation_id: str): + def remove_queues(self, conversation_id: str) -> None: """Remove queues for the given conversation ID""" for request_id in list( self._conversation_back_requests.get(conversation_id, set()) diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index adc24578fd..7d7c47ae2c 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -39,7 +39,7 @@ class WecomServer: - def __init__(self, event_queue: asyncio.Queue, config: dict): + def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.server = quart.Quart(__name__) self.port = int(cast(str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") @@ -123,7 +123,7 @@ async def handle_callback(self, request) -> str: return "success" - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。", ) @@ -133,7 +133,7 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() @@ -182,7 +182,7 @@ def __init__( self.client.__setattr__("API_BASE_URL", self.api_base_url) - async def callback(msg: BaseMessage): + async def callback(msg: BaseMessage) -> None: if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": def get_latest_msg_item() -> dict | None: @@ -214,7 +214,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await super().send_by_session(session, message_chain) @override @@ -228,7 +228,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: loop = asyncio.get_event_loop() if self.kf_name: try: @@ -404,7 +404,7 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: return await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WecomPlatformEvent( message_str=message.message_str, message_obj=message, @@ -417,7 +417,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> WeChatClient: return self.client - async def terminate(self): + async def terminate(self) -> None: self.server.shutdown_event.set() try: await self.server.server.shutdown() diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 0b5dae2729..865a142341 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -28,7 +28,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: WeChatClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -37,7 +37,7 @@ async def send_with_client( client: WeChatClient, message: MessageChain, user_name: str, - ): + ) -> None: pass async def split_plain(self, plain: str) -> list[str]: @@ -86,7 +86,7 @@ async def split_plain(self, plain: str) -> list[str]: return result - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: message_obj = self.message_obj is_wechat_kf = hasattr(self.client, "kf_message") diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 2df09a7637..260b950d19 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -14,6 +14,7 @@ import socket import struct import time +from typing import NoReturn from Crypto.Cipher import AES @@ -30,7 +31,7 @@ class FormatException(Exception): pass -def throw_exception(message, exception_class=FormatException): +def throw_exception(message, exception_class=FormatException) -> NoReturn: """My define raise exception function""" raise exception_class(message) @@ -145,7 +146,7 @@ class Prpcrypt: MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位) RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位) - def __init__(self, key): + def __init__(self, key) -> None: # self.key = base64.b64decode(key+"=") self.key = key # 设置加解密模式为AES的CBC模式 @@ -220,7 +221,7 @@ def get_random_str(self): class WXBizJsonMsgCrypt: # 构造函数 - def __init__(self, sToken, sEncodingAESKey, sReceiveId): + def __init__(self, sToken, sEncodingAESKey, sReceiveId) -> None: try: self.key = base64.b64decode(sEncodingAESKey + "=") assert len(self.key) == 32 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 3b68bc2840..19c1399996 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -52,7 +52,7 @@ def __init__( self.queue_mgr = queue_mgr self.callback = callback - async def run(self): + async def run(self) -> None: """注册监听回调并定期清理过期响应。""" self.queue_mgr.set_listener(self.callback) while True: @@ -123,7 +123,7 @@ def __init__( self._handle_queued_message, ) - async def _handle_queued_message(self, data: dict): + async def _handle_queued_message(self, data: dict) -> None: """处理队列中的消息,类似webchat的callback""" try: abm = await self.convert_message(data) @@ -288,7 +288,7 @@ async def _enqueue_message( callback_params: dict[str, str], stream_id: str, session_id: str, - ): + ) -> None: """将消息放入队列进行异步处理""" input_queue = self.queue_mgr.get_or_create_queue(stream_id) _ = self.queue_mgr.get_or_create_back_queue(stream_id) @@ -392,7 +392,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息""" # 企业微信智能机器人主要通过回调响应,这里记录日志 logger.info("会话发送消息: %s -> %s", session.session_id, message_chain) @@ -401,7 +401,7 @@ async def send_by_session( def run(self) -> Awaitable[Any]: """运行适配器,同时启动HTTP服务器和队列监听器""" - async def run_both(): + async def run_both() -> None: # 如果启用统一 webhook 模式,则不启动独立服务器 webhook_uuid = self.config.get("webhook_uuid") if self.unified_webhook_mode and webhook_uuid: @@ -428,7 +428,7 @@ async def webhook_callback(self, request: Any) -> Any: else: return await self.server.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: """终止适配器""" logger.info("企业微信智能机器人适配器正在关闭...") self.shutdown_event.set() @@ -438,7 +438,7 @@ def meta(self) -> PlatformMetadata: """获取平台元数据""" return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: """处理消息,创建消息事件并提交到事件队列""" try: message_event = WecomAIBotMessageEvent( diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 6c448a97e6..97831fbb22 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -19,7 +19,7 @@ class WecomAIBotAPIClient: """企业微信智能机器人 API 客户端""" - def __init__(self, token: str, encoding_aes_key: str): + def __init__(self, token: str, encoding_aes_key: str) -> None: """初始化 API 客户端 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index fd11d7ceb9..90a9e363b7 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -22,7 +22,7 @@ def __init__( session_id: str, api_client: WecomAIBotAPIClient, queue_mgr: WecomAIQueueMgr, - ): + ) -> None: """初始化消息事件 Args: @@ -90,7 +90,7 @@ async def _send( return data - async def send(self, message: MessageChain | None): + async def send(self, message: MessageChain | None) -> None: """发送消息""" raw = self.message_obj.raw_message assert isinstance(raw, dict), ( @@ -100,7 +100,7 @@ async def send(self, message: MessageChain | None): await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await super().send(MessageChain([])) - async def send_streaming(self, generator, use_fallback=False): + async def send_streaming(self, generator, use_fallback=False) -> None: """流式发送消息,参考webchat的send_streaming设计""" final_data = "" raw = self.message_obj.raw_message diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index b9dbf5b6a8..9b6e6b9683 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -64,7 +64,7 @@ def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue: logger.debug(f"[WecomAI] 创建输出队列: {session_id}") return self.back_queues[session_id] - def remove_queues(self, session_id: str, mark_finished: bool = False): + def remove_queues(self, session_id: str, mark_finished: bool = False) -> None: """移除指定会话的所有队列 Args: @@ -123,7 +123,9 @@ def has_back_queue(self, session_id: str) -> bool: """ return session_id in self.back_queues - def set_pending_response(self, session_id: str, callback_params: dict[str, str]): + def set_pending_response( + self, session_id: str, callback_params: dict[str, str] + ) -> None: """设置待处理的响应参数 Args: @@ -163,7 +165,7 @@ def is_stream_finished( return False return True - def cleanup_expired_responses(self, max_age_seconds: int = 300): + def cleanup_expired_responses(self, max_age_seconds: int = 300) -> None: """清理过期的待处理响应 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 5cbdd1130e..80ec5179e3 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -23,7 +23,7 @@ def __init__( port: int, api_client: WecomAIBotAPIClient, message_handler: Callable[[dict[str, Any], dict[str, str]], Any] | None = None, - ): + ) -> None: """初始化服务器 Args: @@ -43,7 +43,7 @@ def __init__( self.shutdown_event = asyncio.Event() - def _setup_routes(self): + def _setup_routes(self) -> None: """设置 Quart 路由""" # 使用 Quart 的 add_url_rule 方法添加路由 self.app.add_url_rule( @@ -162,7 +162,7 @@ async def handle_callback(self, request): logger.error("处理消息时发生异常: %s", e) return "内部服务器错误", 500 - async def start_server(self): + async def start_server(self) -> None: """启动服务器""" logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port) @@ -176,11 +176,11 @@ async def start_server(self): logger.error("服务器运行异常: %s", e) raise - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: """关闭触发器""" await self.shutdown_event.wait() - async def shutdown(self): + async def shutdown(self) -> None: """关闭服务器""" logger.info("企业微信智能机器人服务器正在关闭...") self.shutdown_event.set() diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index a38952127e..858e61544b 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -35,7 +35,7 @@ class WeixinOfficialAccountServer: - def __init__(self, event_queue: asyncio.Queue, config: dict): + def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.server = quart.Quart(__name__) self.port = int(cast(int | str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") @@ -129,7 +129,7 @@ async def handle_callback(self, request) -> str: return "success" - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。", ) @@ -139,7 +139,7 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() @@ -218,7 +218,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await super().send_by_session(session, message_chain) @override @@ -232,7 +232,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: # 如果启用统一 webhook 模式,则不启动独立服务器 webhook_uuid = self.config.get("webhook_uuid") if self.unified_webhook_mode and webhook_uuid: @@ -331,7 +331,7 @@ async def convert_message( logger.info(f"abm: {abm}") await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WeixinOfficialAccountPlatformEvent( message_str=message.message_str, message_obj=message, @@ -344,7 +344,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> WeChatClient: return self.client - async def terminate(self): + async def terminate(self) -> None: self.server.shutdown_event.set() try: await self.server.server.shutdown() diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index c1f137a415..995b166900 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: WeChatClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -35,7 +35,7 @@ async def send_with_client( client: WeChatClient, message: MessageChain, user_name: str, - ): + ) -> None: pass async def split_plain(self, plain: str) -> list[str]: @@ -84,7 +84,7 @@ async def split_plain(self, plain: str) -> list[str]: return result - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: message_obj = self.message_obj active_send_mode = cast(dict, message_obj.raw_message).get( "active_send_mode", False diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index d6d5246984..ad8bb44f6d 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -3,7 +3,7 @@ class PlatformMessageHistoryManager: - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.db = db_helper async def insert( @@ -40,7 +40,9 @@ async def get( history.reverse() return history - async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400): + async def delete( + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: """Delete platform message history records older than the specified offset.""" await self.db.delete_platform_message_offset( platform_id=platform_id, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 7c568626d5..20c5a7947d 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -111,7 +111,7 @@ class ProviderRequest: model: str | None = None """模型名称,为 None 时使用提供商的默认模型""" - def __repr__(self): + def __repr__(self) -> str: return ( f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, " f"image_count={len(self.image_urls or [])}, " @@ -121,10 +121,10 @@ def __repr__(self): f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, " ) - def __str__(self): + def __str__(self) -> str: return self.__repr__() - def append_tool_calls_result(self, tool_calls_result: ToolCallsResult): + def append_tool_calls_result(self, tool_calls_result: ToolCallsResult) -> None: """添加工具调用结果到请求中""" if not self.tool_calls_result: self.tool_calls_result = [] @@ -309,7 +309,7 @@ def __init__( is_chunk: bool = False, id: str | None = None, usage: TokenUsage | None = None, - ): + ) -> None: """初始化 LLMResponse Args: @@ -356,7 +356,7 @@ def completion_text(self): return self._completion_text @completion_text.setter - def completion_text(self, value): + def completion_text(self, value) -> None: if self.result_chain: self.result_chain.chain = [ comp diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 7aad86bddf..106b42cc5b 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -500,7 +500,7 @@ def load_mcp_config(self): logger.error(f"加载 MCP 配置失败: {e}") return DEFAULT_MCP_CONFIG - def save_mcp_config(self, config: dict): + def save_mcp_config(self, config: dict) -> bool: try: with open(self.mcp_config_path, "w", encoding="utf-8") as f: json.dump(config, f, ensure_ascii=False, indent=4) @@ -575,10 +575,10 @@ async def sync_modelscope_mcp_servers(self, access_token: str) -> None: except Exception as e: raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}") - def __str__(self): + def __str__(self) -> str: return str(self.func_list) - def __repr__(self): + def __repr__(self) -> str: return str(self.func_list) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7ec8c36ff4..ff0bb303dd 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -32,7 +32,7 @@ def __init__( acm: AstrBotConfigManager, db_helper: BaseDatabase, persona_mgr: PersonaManager, - ): + ) -> None: self.reload_lock = asyncio.Lock() self.resource_lock = asyncio.Lock() self.persona_mgr = persona_mgr @@ -92,7 +92,7 @@ async def set_provider( provider_id: str, provider_type: ProviderType, umo: str | None = None, - ): + ) -> None: """设置提供商。 Args: @@ -213,7 +213,7 @@ def get_using_provider( return provider - async def initialize(self): + async def initialize(self) -> None: # 逐个初始化提供商 for provider_config in self.providers_config: try: @@ -277,7 +277,7 @@ async def initialize(self): # 初始化 MCP Client 连接 asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") - def dynamic_import_provider(self, type: str): + def dynamic_import_provider(self, type: str) -> None: """动态导入提供商适配器模块 Args: @@ -434,7 +434,7 @@ def _resolve_env_key_list(self, provider_config: dict) -> dict: provider_config["key"] = resolved_keys return provider_config - async def load_provider(self, provider_config: dict): + async def load_provider(self, provider_config: dict) -> None: # 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并 provider_config = self.get_merged_provider_config(provider_config) @@ -591,7 +591,7 @@ async def load_provider(self, provider_config: dict): f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}", ) - async def reload(self, provider_config: dict): + async def reload(self, provider_config: dict) -> None: async with self.reload_lock: await self.terminate_provider(provider_config["id"]) if provider_config["enable"]: @@ -637,7 +637,7 @@ async def reload(self, provider_config: dict): def get_insts(self): return self.provider_insts - async def terminate_provider(self, provider_id: str): + async def terminate_provider(self, provider_id: str) -> None: if provider_id in self.inst_map: logger.info( f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...", @@ -673,7 +673,7 @@ async def terminate_provider(self, provider_id: str): async def delete_provider( self, provider_id: str | None = None, provider_source_id: str | None = None - ): + ) -> None: """Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion.""" async with self.resource_lock: # delete from config @@ -693,7 +693,7 @@ async def delete_provider( config.save_config() logger.info(f"Provider {target_prov_ids} 已从配置中删除。") - async def update_provider(self, origin_provider_id: str, new_config: dict): + async def update_provider(self, origin_provider_id: str, new_config: dict) -> None: """Update provider config and reload the instance. Config will be saved after update.""" async with self.resource_lock: npid = new_config.get("id", None) @@ -717,7 +717,7 @@ async def update_provider(self, origin_provider_id: str, new_config: dict): # reload instance await self.reload(new_config) - async def create_provider(self, new_config: dict): + async def create_provider(self, new_config: dict) -> None: """Add new provider config and load the instance. Config will be saved after addition.""" async with self.resource_lock: npid = new_config.get("id", None) @@ -733,7 +733,7 @@ async def create_provider(self, new_config: dict): # load instance await self.load_provider(new_config) - async def terminate(self): + async def terminate(self) -> None: for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): await provider_inst.terminate() # type: ignore diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 623ff508ba..901efd0052 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -32,7 +32,7 @@ def __init__(self, provider_config: dict) -> None: self.model_name = "" self.provider_config = provider_config - def set_model(self, model_name: str): + def set_model(self, model_name: str) -> None: """Set the current model name""" self.model_name = model_name @@ -54,7 +54,7 @@ def meta(self) -> ProviderMeta: ) return meta - async def test(self): + async def test(self) -> None: """test the provider is a raises: @@ -84,7 +84,7 @@ def get_keys(self) -> list[str]: return keys or [""] @abc.abstractmethod - def set_key(self, key: str): + def set_key(self, key: str) -> None: raise NotImplementedError @abc.abstractmethod @@ -157,7 +157,7 @@ async def text_chat_stream( yield None # type: ignore raise NotImplementedError() - async def pop_record(self, context: list): + async def pop_record(self, context: list) -> None: """弹出 context 第一条非系统提示词对话记录""" poped = 0 indexs_to_pop = [] @@ -188,7 +188,7 @@ def _ensure_message_to_dicts( return dicts - async def test(self, timeout: float = 45.0): + async def test(self, timeout: float = 45.0) -> None: await asyncio.wait_for( self.text_chat(prompt="REPLY `PONG` ONLY"), timeout=timeout, @@ -206,7 +206,7 @@ async def get_text(self, audio_url: str) -> str: """获取音频的文本""" raise NotImplementedError - async def test(self): + async def test(self) -> None: sample_audio_path = os.path.join( get_astrbot_path(), "samples", @@ -280,7 +280,7 @@ async def get_audio_stream( accumulated_text += text_part - async def test(self): + async def test(self) -> None: await self.get_audio("hi") @@ -305,7 +305,7 @@ def get_dim(self) -> int: """获取向量的维度""" ... - async def test(self): + async def test(self) -> None: await self.get_embedding("astrbot") async def get_embeddings_batch( @@ -335,7 +335,7 @@ async def get_embeddings_batch( completed_count = 0 total_count = len(texts) - async def process_batch(batch_idx: int, batch_texts: list[str]): + async def process_batch(batch_idx: int, batch_texts: list[str]) -> None: nonlocal completed_count async with semaphore: for attempt in range(max_retries): @@ -392,7 +392,7 @@ async def rerank( """获取查询和文档的重排序分数""" ... - async def test(self): + async def test(self) -> None: result = await self.rerank("Apple", documents=["apple", "banana"]) if not result: raise Exception("Rerank provider test failed, no results returned") diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index e6658b74ee..80684aca6d 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -642,7 +642,7 @@ async def get_models(self) -> list[str]: models_str.append(model.id) return models_str - def set_key(self, key: str): + def set_key(self, key: str) -> None: self.chosen_api_key = key async def terminate(self): diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index eccbdf05fd..29dc3f7a48 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -22,7 +22,7 @@ class OTTSProvider: - def __init__(self, config: dict): + def __init__(self, config: dict) -> None: self.skey = config["OTTS_SKEY"] self.api_url = config["OTTS_URL"] self.auth_time_url = config["OTTS_AUTH_TIME"] @@ -54,7 +54,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() self._client = None - async def _sync_time(self): + async def _sync_time(self) -> None: try: response = await self.client.get(self.auth_time_url) response.raise_for_status() @@ -109,7 +109,7 @@ async def get_audio(self, text: str, voice_params: dict) -> str: class AzureNativeProvider(TTSProvider): - def __init__(self, provider_config: dict, provider_settings: dict): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.subscription_key = provider_config.get( "azure_tts_subscription_key", @@ -159,7 +159,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() self._client = None - async def _refresh_token(self): + async def _refresh_token(self) -> None: token_url = ( f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" ) @@ -205,7 +205,7 @@ async def get_audio(self, text: str) -> str: @register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH) class AzureTTSProvider(TTSProvider): - def __init__(self, provider_config: dict, provider_settings: dict): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) key_value = provider_config.get("azure_tts_subscription_key", "") self.provider = self._parse_provider(key_value, provider_config) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index d4074ab553..9557f3dbcd 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -846,7 +846,7 @@ def get_current_key(self) -> str: def get_keys(self) -> list[str]: return self.api_keys - def set_key(self, key): + def set_key(self, key) -> None: self.chosen_api_key = key self._init_client() @@ -928,6 +928,6 @@ async def encode_image_bs64(self, image_url: str) -> str: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - async def terminate(self): + async def terminate(self) -> None: if self.client: await self.client.aclose() diff --git a/astrbot/core/provider/sources/genie_tts.py b/astrbot/core/provider/sources/genie_tts.py index 36436919cc..bca92deb74 100644 --- a/astrbot/core/provider/sources/genie_tts.py +++ b/astrbot/core/provider/sources/genie_tts.py @@ -61,7 +61,7 @@ async def get_audio(self, text: str) -> str: loop = asyncio.get_event_loop() - def _generate(save_path: str): + def _generate(save_path: str) -> None: assert genie is not None genie.tts( character_name=self.character_name, @@ -99,7 +99,7 @@ async def get_audio_stream( filename = f"genie_tts_{uuid.uuid4()}.wav" path = os.path.join(temp_dir, filename) - def _generate(save_path: str, t: str): + def _generate(save_path: str, t: str) -> None: assert genie is not None genie.tts( character_name=self.character_name, diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index 7f8d39eacd..029f6af10d 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -39,7 +39,7 @@ def __init__( self.timeout = provider_config.get("timeout", 60) self._session: aiohttp.ClientSession | None = None - async def initialize(self): + async def initialize(self) -> None: """异步初始化:在 ProviderManager 中被调用""" self._session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self.timeout), @@ -85,7 +85,7 @@ async def _make_request( logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}") raise - async def _set_model_weights(self): + async def _set_model_weights(self) -> None: """设置模型路径""" try: if self.gpt_weights_path: @@ -144,7 +144,7 @@ def build_synthesis_params(self, text: str) -> dict: # TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text) return params - async def terminate(self): + async def terminate(self) -> None: """终止释放资源:在 ProviderManager 中被调用""" if self._session and not self._session.closed: await self._session.close() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index ce118417a4..0708c09c72 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -375,7 +375,7 @@ async def _prepare_chat_payload( return payloads, context_query - def _finally_convert_payload(self, payloads: dict): + def _finally_convert_payload(self, payloads: dict) -> None: """Finally convert the payload. Such as think part conversion, tool inject.""" for message in payloads.get("messages", []): if message.get("role") == "assistant" and isinstance( @@ -626,7 +626,7 @@ def get_current_key(self) -> str: def get_keys(self) -> list[str]: return self.api_keys - def set_key(self, key): + def set_key(self, key) -> None: self.client.api_key = key async def assemble_context( diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index a41bd72fd0..965b83a5a1 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -37,7 +37,7 @@ def __init__( self.model = None self.is_emotion = provider_config.get("is_emotion", False) - async def initialize(self): + async def initialize(self) -> None: logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...") # 将模型加载放到线程池中执行 @@ -52,7 +52,7 @@ async def get_timestamped_path(self) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") return os.path.join("data", "temp", f"{timestamp}") - async def _is_silk_file(self, file_path): + async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" with open(file_path, "rb") as f: file_header = f.read(8) diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 1473cdbbec..875881b7bf 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -38,7 +38,7 @@ def __init__( self.set_model(provider_config["model"]) - async def _get_audio_format(self, file_path): + async def _get_audio_format(self, file_path) -> str | None: # 定义要检测的头部字节 silk_header = b"SILK" amr_header = b"#!AMR" diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index a14f93f146..d5d2dc3405 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -30,7 +30,7 @@ def __init__( self.set_model(provider_config["model"]) self.model = None - async def initialize(self): + async def initialize(self) -> None: loop = asyncio.get_event_loop() logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") self.model = await loop.run_in_executor( @@ -40,7 +40,7 @@ async def initialize(self): ) logger.info("Whisper 模型加载完成。") - async def _is_silk_file(self, file_path): + async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" with open(file_path, "rb") as f: file_header = f.read(8) diff --git a/astrbot/core/provider/sources/xai_source.py b/astrbot/core/provider/sources/xai_source.py index a050412d36..b7b432b49a 100644 --- a/astrbot/core/provider/sources/xai_source.py +++ b/astrbot/core/provider/sources/xai_source.py @@ -13,7 +13,7 @@ def __init__( ) -> None: super().__init__(provider_config, provider_settings) - def _maybe_inject_xai_search(self, payloads: dict): + def _maybe_inject_xai_search(self, payloads: dict) -> None: """当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。 - 仅在 provider_config.xai_native_search 为 True 时生效 @@ -24,6 +24,6 @@ def _maybe_inject_xai_search(self, payloads: dict): # OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body payloads["search_parameters"] = {"mode": "auto"} - def _finally_convert_payload(self, payloads: dict): + def _finally_convert_payload(self, payloads: dict) -> None: self._maybe_inject_xai_search(payloads) super()._finally_convert_payload(payloads) diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py index 960408550e..9c3a77c158 100644 --- a/astrbot/core/provider/sources/xinference_rerank_source.py +++ b/astrbot/core/provider/sources/xinference_rerank_source.py @@ -37,7 +37,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.model: AsyncRESTfulRerankModelHandle | None = None self.model_uid = None - async def initialize(self): + async def initialize(self) -> None: if self.api_key: logger.info("Xinference Rerank: Using API key for authentication.") self.client = Client(self.base_url, api_key=self.api_key) diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index 4b947b3f0c..a3e5be352d 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -40,7 +40,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.client = None self.model_uid = None - async def initialize(self): + async def initialize(self) -> None: if self.api_key: logger.info("Xinference STT: Using API key for authentication.") self.client = Client(self.base_url, api_key=self.api_key) diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index c474962c58..2bf86872e3 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -15,7 +15,7 @@ class Star(CommandParserMixin, PluginKVStoreMixin): author: str name: str - def __init__(self, context: Context, config: dict | None = None): + def __init__(self, context: Context, config: dict | None = None) -> None: StarTools.initialize(context) self.context = context @@ -55,13 +55,13 @@ async def html_render( options=options, ) - async def initialize(self): + async def initialize(self) -> None: """当插件被激活时会调用这个方法""" - async def terminate(self): + async def terminate(self) -> None: """当插件被禁用、重载插件时会调用这个方法""" - def __del__(self): + def __del__(self) -> None: """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index 2b590921d3..429a05d5ee 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -22,7 +22,7 @@ def load_config(namespace: str) -> dict | bool: return ret -def put_config(namespace: str, name: str, key: str, value, description: str): +def put_config(namespace: str, name: str, key: str, value, description: str) -> None: """将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 name: str, 配置项的显示名字。 @@ -64,7 +64,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): f.flush() -def update_config(namespace: str, key: str, value): +def update_config(namespace: str, key: str, value) -> None: """更新配置文件中的配置项。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 key: str, 配置项的键。 diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index c7438baf22..6a74580f6e 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -69,7 +69,7 @@ def __init__( knowledge_base_manager: KnowledgeBaseManager, cron_manager: CronJobManager, subagent_orchestrator: SubAgentOrchestrator | None = None, - ): + ) -> None: self._event_queue = event_queue """事件队列。消息平台通过事件队列传递消息事件。""" self._config = config @@ -491,7 +491,7 @@ def register_web_api( view_handler: Awaitable, methods: list, desc: str, - ): + ) -> None: """注册 Web API。 Args: @@ -565,7 +565,7 @@ def get_db(self) -> BaseDatabase: """ return self._db - def register_provider(self, provider: Provider): + def register_provider(self, provider: Provider) -> None: """注册一个 LLM Provider(Chat_Completion 类型)。 Args: @@ -626,7 +626,7 @@ def register_commands( awaitable: Callable[..., Awaitable[Any]], use_regex=False, ignore_prefix=False, - ): + ) -> None: """[DEPRECATED]注册一个命令。 Args: @@ -658,7 +658,7 @@ def register_commands( ) star_handlers_registry.append(md) - def register_task(self, task: Awaitable, desc: str): + def register_task(self, task: Awaitable, desc: str) -> None: """[DEPRECATED]注册一个异步任务。 Args: diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index e86ee85af8..31949b674c 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -37,7 +37,7 @@ def __init__( alias: set | None = None, handler_md: StarHandlerMetadata | None = None, parent_command_names: list[str] | None = None, - ): + ) -> None: self.command_name = command_name self.alias = alias if alias else set() self._original_command_name = command_name @@ -63,7 +63,7 @@ def print_types(self): result = "".join(parts).rstrip(",") return result - def init_handler_md(self, handle_md: StarHandlerMetadata): + def init_handler_md(self, handle_md: StarHandlerMetadata) -> None: self.handler_md = handle_md signature = inspect.signature(self.handler_md.handler) self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值 @@ -81,7 +81,7 @@ def init_handler_md(self, handle_md: StarHandlerMetadata): def get_handler_md(self) -> StarHandlerMetadata: return self.handler_md - def add_custom_filter(self, custom_filter: CustomFilter): + def add_custom_filter(self, custom_filter: CustomFilter) -> None: self.custom_filter_list.append(custom_filter) def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 4cbd2c0079..52fb6a4521 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -15,7 +15,7 @@ def __init__( group_name: str, alias: set | None = None, parent_group: CommandGroupFilter | None = None, - ): + ) -> None: self.group_name = group_name self.alias = alias if alias else set() self._original_group_name = group_name @@ -29,10 +29,10 @@ def __init__( def add_sub_command_filter( self, sub_command_filter: CommandFilter | CommandGroupFilter, - ): + ) -> None: self.sub_command_filters.append(sub_command_filter) - def add_custom_filter(self, custom_filter: CustomFilter): + def add_custom_filter(self, custom_filter: CustomFilter) -> None: self.custom_filter_list.append(custom_filter) def get_complete_command_names(self) -> list[str]: diff --git a/astrbot/core/star/filter/custom_filter.py b/astrbot/core/star/filter/custom_filter.py index 54d03632d5..d99cb5f13e 100644 --- a/astrbot/core/star/filter/custom_filter.py +++ b/astrbot/core/star/filter/custom_filter.py @@ -19,7 +19,7 @@ def __or__(cls, other): class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta): - def __init__(self, raise_error: bool = True, **kwargs): + def __init__(self, raise_error: bool = True, **kwargs) -> None: self.raise_error = raise_error @abstractmethod @@ -35,7 +35,7 @@ def __and__(self, other): class CustomFilterOr(CustomFilter): - def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter) -> None: super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): raise ValueError( @@ -49,7 +49,7 @@ def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: class CustomFilterAnd(CustomFilter): - def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter) -> None: super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): raise ValueError( diff --git a/astrbot/core/star/filter/event_message_type.py b/astrbot/core/star/filter/event_message_type.py index 7f350bd38d..604fc3ed3b 100644 --- a/astrbot/core/star/filter/event_message_type.py +++ b/astrbot/core/star/filter/event_message_type.py @@ -22,7 +22,7 @@ class EventMessageType(enum.Flag): class EventMessageTypeFilter(HandlerFilter): - def __init__(self, event_message_type: EventMessageType): + def __init__(self, event_message_type: EventMessageType) -> None: self.event_message_type = event_message_type def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: diff --git a/astrbot/core/star/filter/permission.py b/astrbot/core/star/filter/permission.py index 3374544c26..a70299fa95 100644 --- a/astrbot/core/star/filter/permission.py +++ b/astrbot/core/star/filter/permission.py @@ -14,7 +14,9 @@ class PermissionType(enum.Flag): class PermissionTypeFilter(HandlerFilter): - def __init__(self, permission_type: PermissionType, raise_error: bool = True): + def __init__( + self, permission_type: PermissionType, raise_error: bool = True + ) -> None: self.permission_type = permission_type self.raise_error = raise_error diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index 241662bca4..ff1affa245 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -55,7 +55,7 @@ class PlatformAdapterType(enum.Flag): class PlatformAdapterTypeFilter(HandlerFilter): - def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str): + def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str) -> None: if isinstance(platform_adapter_type_or_str, str): self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str) else: diff --git a/astrbot/core/star/filter/regex.py b/astrbot/core/star/filter/regex.py index cd5bebdb42..abec5a4882 100644 --- a/astrbot/core/star/filter/regex.py +++ b/astrbot/core/star/filter/regex.py @@ -10,7 +10,7 @@ class RegexFilter(HandlerFilter): """正则表达式过滤器""" - def __init__(self, regex: str): + def __init__(self, regex: str) -> None: self.regex_str = regex self.regex = re.compile(regex) diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index eefbcedb78..dfca5a25c0 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -250,7 +250,7 @@ class RegisteringCommandable: command: Callable[..., Callable[..., None]] = register_command custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter - def __init__(self, parent_group: CommandGroupFilter): + def __init__(self, parent_group: CommandGroupFilter) -> None: self.parent_group = parent_group @@ -565,7 +565,7 @@ def llm_tool(self, *args, **kwargs): kwargs["registering_agent"] = self return register_llm_tool(*args, **kwargs) - def __init__(self, agent: Agent[AstrAgentContext]): + def __init__(self, agent: Agent[AstrAgentContext]) -> None: self._agent = agent diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 6f5ce60905..ced4d7739f 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -12,11 +12,11 @@ class StarHandlerRegistry(Generic[T]): - def __init__(self): + def __init__(self) -> None: self.star_handlers_map: dict[str, StarHandlerMetadata] = {} self._handlers: list[StarHandlerMetadata] = [] - def append(self, handler: StarHandlerMetadata): + def append(self, handler: StarHandlerMetadata) -> None: """添加一个 Handler,并保持按优先级有序""" if "priority" not in handler.extras_configs: handler.extras_configs["priority"] = 0 @@ -25,7 +25,7 @@ def append(self, handler: StarHandlerMetadata): self._handlers.append(handler) self._handlers.sort(key=lambda h: -h.extras_configs["priority"]) - def _print_handlers(self): + def _print_handlers(self) -> None: for handler in self._handlers: print(handler.handler_full_name) @@ -156,18 +156,18 @@ def get_handlers_by_module_name( if handler.handler_module_path == module_name ] - def clear(self): + def clear(self) -> None: self.star_handlers_map.clear() self._handlers.clear() - def remove(self, handler: StarHandlerMetadata): + def remove(self, handler: StarHandlerMetadata) -> None: self.star_handlers_map.pop(handler.handler_full_name, None) self._handlers = [h for h in self._handlers if h != handler] def __iter__(self): return iter(self._handlers) - def __len__(self): + def __len__(self) -> int: return len(self._handlers) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 567397107b..2c8c940f24 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -41,7 +41,7 @@ class PluginManager: - def __init__(self, context: Context, config: AstrBotConfig): + def __init__(self, context: Context, config: AstrBotConfig) -> None: self.updator = PluginUpdator() self.context = context @@ -66,7 +66,7 @@ def __init__(self, context: Context, config: AstrBotConfig): if os.getenv("ASTRBOT_RELOAD", "0") == "1": asyncio.create_task(self._watch_plugins_changes()) - async def _watch_plugins_changes(self): + async def _watch_plugins_changes(self) -> None: """监视插件文件变化""" try: async for changes in awatch( @@ -83,7 +83,7 @@ async def _watch_plugins_changes(self): logger.error(f"插件热重载监视任务异常: {e!s}") logger.error(traceback.format_exc()) - async def _handle_file_changes(self, changes): + async def _handle_file_changes(self, changes) -> None: """处理文件变化""" logger.info(f"检测到文件变化: {changes}") plugins_to_check = [] @@ -167,7 +167,9 @@ def _get_plugin_modules(self) -> list[dict]: plugins.extend(_p) return plugins - async def _check_plugin_dept_update(self, target_plugin: str | None = None): + async def _check_plugin_dept_update( + self, target_plugin: str | None = None + ) -> bool | None: """检查插件的依赖 如果 target_plugin 为 None,则检查所有插件的依赖 """ @@ -264,7 +266,7 @@ def _purge_modules( module_patterns: list[str] | None = None, root_dir_name: str | None = None, is_reserved: bool = False, - ): + ) -> None: """从 sys.modules 中移除指定的模块 可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存 @@ -711,7 +713,7 @@ async def uninstall_plugin( plugin_name: str, delete_config: bool = False, delete_data: bool = False, - ): + ) -> None: """卸载指定的插件。 Args: @@ -800,7 +802,7 @@ async def uninstall_plugin( except Exception as e: logger.warning(f"删除插件持久化数据失败 (plugins_data): {e!s}") - async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): + async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str) -> None: """解绑并移除一个插件。 Args: @@ -863,7 +865,7 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): is_reserved=plugin.reserved, ) - async def update_plugin(self, plugin_name: str, proxy=""): + async def update_plugin(self, plugin_name: str, proxy="") -> None: """升级一个插件""" plugin = self.context.get_registered_star(plugin_name) if not plugin: @@ -874,7 +876,7 @@ async def update_plugin(self, plugin_name: str, proxy=""): await self.updator.update(plugin, proxy=proxy) await self.reload(plugin_name) - async def turn_off_plugin(self, plugin_name: str): + async def turn_off_plugin(self, plugin_name: str) -> None: """禁用一个插件。 调用插件的 terminate() 方法, 将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。 @@ -916,7 +918,7 @@ async def turn_off_plugin(self, plugin_name: str): plugin.activated = False @staticmethod - async def _terminate_plugin(star_metadata: StarMetadata): + async def _terminate_plugin(star_metadata: StarMetadata) -> None: """终止插件,调用插件的 terminate() 和 __del__() 方法""" logger.info(f"正在终止插件 {star_metadata.name} ...") @@ -936,7 +938,7 @@ async def _terminate_plugin(star_metadata: StarMetadata): elif "terminate" in star_metadata.star_cls_type.__dict__: await star_metadata.star_cls.terminate() - async def turn_on_plugin(self, plugin_name: str): + async def turn_on_plugin(self, plugin_name: str) -> None: plugin = self.context.get_registered_star(plugin_name) if plugin is None: raise Exception(f"插件 {plugin_name} 不存在。") diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 7a66449b4f..4d85131fc6 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -89,7 +89,7 @@ async def send_message_by_id( id: str, message_chain: MessageChain, platform: str = "aiocqhttp", - ): + ) -> None: """根据 id(例如qq号, 群号等) 直接, 主动地发送消息 Args: diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index 8793ad5055..1a0c5fc260 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -52,7 +52,7 @@ async def update(self, plugin: StarMetadata, proxy="") -> str: return plugin_path - def unzip_file(self, zip_path: str, target_dir: str): + def unzip_file(self, zip_path: str, target_dir: str) -> None: os.makedirs(target_dir, exist_ok=True) update_dir = "" logger.info(f"正在解压压缩包: {zip_path}") diff --git a/astrbot/core/subagent_orchestrator.py b/astrbot/core/subagent_orchestrator.py index 62ddc0fd3a..205c554cb8 100644 --- a/astrbot/core/subagent_orchestrator.py +++ b/astrbot/core/subagent_orchestrator.py @@ -16,7 +16,9 @@ class SubAgentOrchestrator: Execution happens via HandoffTool in FunctionToolExecutor. """ - def __init__(self, tool_mgr: FunctionToolManager, persona_mgr: PersonaManager): + def __init__( + self, tool_mgr: FunctionToolManager, persona_mgr: PersonaManager + ) -> None: self._tool_mgr = tool_mgr self._persona_mgr = persona_mgr self.handoffs: list[HandoffTool] = [] diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index 1f2289f4df..d8b010d507 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -6,15 +6,15 @@ class UmopConfigRouter: """UMOP 配置路由器""" - def __init__(self, sp: SharedPreferences): + def __init__(self, sp: SharedPreferences) -> None: self.umop_to_conf_id: dict[str, str] = {} """UMOP 到配置文件 ID 的映射""" self.sp = sp - async def initialize(self): + async def initialize(self) -> None: await self._load_routing_table() - async def _load_routing_table(self): + async def _load_routing_table(self) -> None: """加载路由表""" # 从 SharedPreferences 中加载 umop_to_conf_id 映射 sp_data = await self.sp.get_async( @@ -50,7 +50,7 @@ def get_conf_id_for_umop(self, umo: str) -> str | None: return conf_id return None - async def update_routing_data(self, new_routing: dict[str, str]): + async def update_routing_data(self, new_routing: dict[str, str]) -> None: """更新路由表 Args: @@ -70,7 +70,7 @@ async def update_routing_data(self, new_routing: dict[str, str]): self.umop_to_conf_id = new_routing await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) - async def update_route(self, umo: str, conf_id: str): + async def update_route(self, umo: str, conf_id: str) -> None: """更新一条路由 Args: @@ -89,7 +89,7 @@ async def update_route(self, umo: str, conf_id: str): self.umop_to_conf_id[umo] = conf_id await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) - async def delete_route(self, umo: str): + async def delete_route(self, umo: str) -> None: """删除一条路由 Args: diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index e993cdc8b6..87268c6099 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -23,7 +23,7 @@ def __init__(self, repo_mirror: str = "") -> None: self.MAIN_PATH = get_astrbot_path() self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" - def terminate_child_processes(self): + def terminate_child_processes(self) -> None: """终止当前进程的所有子进程 使用 psutil 库获取当前进程的所有子进程,并尝试终止它们 """ @@ -44,7 +44,7 @@ def terminate_child_processes(self): except psutil.NoSuchProcess: pass - def _reboot(self, delay: int = 3): + def _reboot(self, delay: int = 3) -> None: """重启当前程序 在指定的延迟后,终止所有子进程并重新启动程序 这里只能使用 os.exec* 来重启程序 @@ -91,7 +91,7 @@ async def check_update( async def get_releases(self) -> list: return await self.fetch_release_info(self.ASTRBOT_RELEASE_API) - async def update(self, reboot=False, latest=True, version=None, proxy=""): + async def update(self, reboot=False, latest=True, version=None, proxy="") -> None: update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) file_url = None diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index fcf5bb3c7b..24b919bdd9 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -19,7 +19,7 @@ logger = logging.getLogger("astrbot") -def on_error(func, path, exc_info): +def on_error(func, path, exc_info) -> None: """A callback of the rmtree function.""" import stat @@ -37,7 +37,7 @@ def remove_dir(file_path: str) -> bool: return True -def port_checker(port: int, host: str = "localhost"): +def port_checker(port: int, host: str = "localhost") -> bool: sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sk.settimeout(1) try: @@ -134,7 +134,7 @@ async def download_image_by_url( raise e -async def download_file(url: str, path: str, show_progress: bool = False): +async def download_file(url: str, path: str, show_progress: bool = False) -> None: """从指定 url 下载文件到指定路径 path""" try: ssl_context = ssl.create_default_context( diff --git a/astrbot/core/utils/llm_metadata.py b/astrbot/core/utils/llm_metadata.py index 540c1efd98..915d8d8f9a 100644 --- a/astrbot/core/utils/llm_metadata.py +++ b/astrbot/core/utils/llm_metadata.py @@ -29,7 +29,7 @@ class LLMMetadata(TypedDict): LLM_METADATAS: dict[str, LLMMetadata] = {} -async def update_llm_metadata(): +async def update_llm_metadata() -> None: url = "https://models.dev/api.json" try: async with aiohttp.ClientSession() as session: diff --git a/astrbot/core/utils/log_pipe.py b/astrbot/core/utils/log_pipe.py index 2e931dd816..6f40f09420 100644 --- a/astrbot/core/utils/log_pipe.py +++ b/astrbot/core/utils/log_pipe.py @@ -10,7 +10,7 @@ def __init__( logger: Logger, identifier=None, callback=None, - ): + ) -> None: threading.Thread.__init__(self) self.daemon = True self.level = level @@ -24,7 +24,7 @@ def __init__( def fileno(self): return self.fd_write - def run(self): + def run(self) -> None: for line in iter(self.reader.readline, ""): if self.callback: self.callback(line.strip()) @@ -32,5 +32,5 @@ def run(self): self.reader.close() - def close(self): + def close(self) -> None: os.close(self.fd_write) diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index d3dc732d2a..8fb1464284 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -40,7 +40,7 @@ def get_installation_id(): return "null" @staticmethod - async def upload(**kwargs): + async def upload(**kwargs) -> None: """上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。 Powered by TickStats. diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index e9e9556d99..afde805ffa 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -61,7 +61,7 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No class PipInstaller: - def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None): + def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None: self.pip_install_arg = pip_install_arg self.pypi_index_url = pypi_index_url @@ -70,7 +70,7 @@ async def install( package_name: str | None = None, requirements_path: str | None = None, mirror: str | None = None, - ): + ) -> None: args = ["install"] if package_name: args.append(package_name) diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py index 912d91e532..7810d6ce45 100644 --- a/astrbot/core/utils/session_lock.py +++ b/astrbot/core/utils/session_lock.py @@ -4,7 +4,7 @@ class SessionLockManager: - def __init__(self): + def __init__(self) -> None: self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._lock_count: dict[str, int] = defaultdict(int) self._access_lock = asyncio.Lock() diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index e1f2fbef74..b327a61843 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -18,7 +18,7 @@ class SessionController: """控制一个 Session 是否已经结束""" - def __init__(self): + def __init__(self) -> None: self.future = asyncio.Future() self.current_event: asyncio.Event | None = None """当前正在等待的所用的异步事件""" @@ -29,7 +29,7 @@ def __init__(self): self.history_chains: list[list[Comp.BaseMessageComponent]] = [] - def stop(self, error: Exception | None = None): + def stop(self, error: Exception | None = None) -> None: """立即结束这个会话""" if not self.future.done(): if error: @@ -37,7 +37,7 @@ def stop(self, error: Exception | None = None): else: self.future.set_result(None) - def keep(self, timeout: float = 0, reset_timeout=False): + def keep(self, timeout: float = 0, reset_timeout=False) -> None: """保持这个会话 Args: @@ -71,7 +71,7 @@ def keep(self, timeout: float = 0, reset_timeout=False): asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout: float): + async def _holding(self, event: asyncio.Event, timeout: float) -> None: """等待事件结束或超时""" try: await asyncio.wait_for(event.wait(), timeout) @@ -107,7 +107,7 @@ def __init__( session_filter: SessionFilter, session_id: str, record_history_chains: bool, - ): + ) -> None: self.session_id = session_id self.session_filter = session_filter self.handler: ( @@ -141,7 +141,7 @@ async def register_wait( finally: self._cleanup() - def _cleanup(self, error: Exception | None = None): + def _cleanup(self, error: Exception | None = None) -> None: """清理会话""" USER_SESSIONS.pop(self.session_id, None) try: @@ -151,7 +151,7 @@ def _cleanup(self, error: Exception | None = None): self.session_controller.stop(error) @classmethod - async def trigger(cls, session_id: str, event: AstrMessageEvent): + async def trigger(cls, session_id: str, event: AstrMessageEvent) -> None: """外部输入触发会话处理""" session = USER_SESSIONS.get(session_id) if not session or session.session_controller.future.done(): diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index 788d54a474..344808cbd3 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -15,7 +15,7 @@ class SharedPreferences: - def __init__(self, db_helper: BaseDatabase, json_storage_path=None): + def __init__(self, db_helper: BaseDatabase, json_storage_path=None) -> None: if json_storage_path is None: json_storage_path = os.path.join( get_astrbot_data_path(), @@ -36,7 +36,7 @@ def __init__(self, db_helper: BaseDatabase, json_storage_path=None): ) self._scheduler.start() - def _clear_temporary_cache(self): + def _clear_temporary_cache(self) -> None: self.temporary_cache.clear() async def get_async( @@ -132,7 +132,7 @@ async def global_get( return await self.range_get_async("global", "global", key) return await self.get_async("global", "global", key, default) - async def put_async(self, scope: str, scope_id: str, key: str, value: Any): + async def put_async(self, scope: str, scope_id: str, key: str, value: Any) -> None: """设置指定范围和键的偏好设置""" await self.db_helper.insert_preference_or_update( scope, @@ -141,24 +141,24 @@ async def put_async(self, scope: str, scope_id: str, key: str, value: Any): {"val": value}, ) - async def session_put(self, umo: str, key: str, value: Any): + async def session_put(self, umo: str, key: str, value: Any) -> None: await self.put_async("umo", umo, key, value) - async def global_put(self, key: str, value: Any): + async def global_put(self, key: str, value: Any) -> None: await self.put_async("global", "global", key, value) - async def remove_async(self, scope: str, scope_id: str, key: str): + async def remove_async(self, scope: str, scope_id: str, key: str) -> None: """删除指定范围和键的偏好设置""" await self.db_helper.remove_preference(scope, scope_id, key) - async def session_remove(self, umo: str, key: str): + async def session_remove(self, umo: str, key: str) -> None: await self.remove_async("umo", umo, key) - async def global_remove(self, key: str): + async def global_remove(self, key: str) -> None: """删除全局偏好设置""" await self.remove_async("global", "global", key) - async def clear_async(self, scope: str, scope_id: str): + async def clear_async(self, scope: str, scope_id: str) -> None: """清空指定范围的所有偏好设置""" await self.db_helper.clear_preferences(scope, scope_id) @@ -202,21 +202,25 @@ def range_get( return result - def put(self, key, value, scope: str | None = None, scope_id: str | None = None): + def put( + self, key, value, scope: str | None = None, scope_id: str | None = None + ) -> None: """设置偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.put_async(scope or "unknown", scope_id or "unknown", key, value), self._sync_loop, ).result() - def remove(self, key, scope: str | None = None, scope_id: str | None = None): + def remove( + self, key, scope: str | None = None, scope_id: str | None = None + ) -> None: """删除偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.remove_async(scope or "unknown", scope_id or "unknown", key), self._sync_loop, ).result() - def clear(self, scope: str | None = None, scope_id: str | None = None): + def clear(self, scope: str | None = None, scope_id: str | None = None) -> None: """清空偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.clear_async(scope or "unknown", scope_id or "unknown"), diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 7ebba56692..2abb229171 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -28,7 +28,7 @@ def __init__(self, base_url: str | None = None) -> None: self.endpoints = [self.BASE_RENDER_URL] self.template_manager = TemplateManager() - async def initialize(self): + async def initialize(self) -> None: if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT: asyncio.create_task(self.get_official_endpoints()) @@ -36,7 +36,7 @@ async def get_template(self, name: str = "base") -> str: """通过名称获取文转图 HTML 模板""" return self.template_manager.get_template(name) - async def get_official_endpoints(self): + async def get_official_endpoints(self) -> None: """获取官方的 t2i 端点列表。""" try: async with aiohttp.ClientSession() as session: diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 2ce7a5ebf6..e3118d7e86 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -7,11 +7,11 @@ class HtmlRenderer: - def __init__(self, endpoint_url: str | None = None): + def __init__(self, endpoint_url: str | None = None) -> None: self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() - async def initialize(self): + async def initialize(self) -> None: await self.network_strategy.initialize() async def render_custom_template( diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index 6d44f735be..b3eb0c9ffb 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -14,7 +14,7 @@ class TemplateManager: CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"] - def __init__(self): + def __init__(self) -> None: self.builtin_template_dir = os.path.join( get_astrbot_path(), "astrbot", @@ -28,7 +28,7 @@ def __init__(self): os.makedirs(self.user_template_dir, exist_ok=True) self._initialize_user_templates() - def _copy_core_templates(self, overwrite: bool = False): + def _copy_core_templates(self, overwrite: bool = False) -> None: """从内置目录复制核心模板到用户目录。""" for filename in self.CORE_TEMPLATES: src = os.path.join(self.builtin_template_dir, filename) @@ -36,7 +36,7 @@ def _copy_core_templates(self, overwrite: bool = False): if os.path.exists(src) and (overwrite or not os.path.exists(dst)): shutil.copyfile(src, dst) - def _initialize_user_templates(self): + def _initialize_user_templates(self) -> None: """如果用户目录下缺少核心模板,则进行复制。""" self._copy_core_templates(overwrite=False) @@ -80,7 +80,7 @@ def get_template(self, name: str) -> str: raise FileNotFoundError("模板不存在。") - def create_template(self, name: str, content: str): + def create_template(self, name: str, content: str) -> None: """在用户目录中创建一个新的模板文件。""" path = self._get_user_template_path(name) if os.path.exists(path): @@ -88,7 +88,7 @@ def create_template(self, name: str, content: str): with open(path, "w", encoding="utf-8") as f: f.write(content) - def update_template(self, name: str, content: str): + def update_template(self, name: str, content: str) -> None: """更新一个模板。此操作始终写入用户目录。 如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本, 从而实现对内置模板的“覆盖”。 @@ -97,7 +97,7 @@ def update_template(self, name: str, content: str): with open(path, "w", encoding="utf-8") as f: f.write(content) - def delete_template(self, name: str): + def delete_template(self, name: str) -> None: """仅删除用户目录中的模板文件。 如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。 """ @@ -106,6 +106,6 @@ def delete_template(self, name: str): raise FileNotFoundError("用户模板不存在,无法删除。") os.remove(path) - def reset_default_template(self): + def reset_default_template(self) -> None: """将核心模板从内置目录强制重置到用户目录。""" self._copy_core_templates(overwrite=True) diff --git a/astrbot/core/utils/webhook_utils.py b/astrbot/core/utils/webhook_utils.py index 0e1c3f9cdd..07abc115a0 100644 --- a/astrbot/core/utils/webhook_utils.py +++ b/astrbot/core/utils/webhook_utils.py @@ -20,7 +20,7 @@ def _get_dashboard_port() -> int: return 6185 -def log_webhook_info(platform_name: str, webhook_uuid: str): +def log_webhook_info(platform_name: str, webhook_uuid: str) -> None: """打印美观的 webhook 信息日志 Args: diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 728dfdabb3..6cea6b38d5 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -3,6 +3,7 @@ import shutil import ssl import zipfile +from typing import NoReturn import aiohttp import certifi @@ -101,10 +102,10 @@ def github_api_release_parser(self, releases: list) -> list: ) return ret - def unzip(self): + def unzip(self) -> NoReturn: raise NotImplementedError - async def update(self): + async def update(self) -> NoReturn: raise NotImplementedError def compare_version(self, v1: str, v2: str) -> int: @@ -148,7 +149,9 @@ async def check_update( body=f"{tag_name}\n\n{sel_release_data['body']}", ) - async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""): + async def download_from_repo_url( + self, target_path: str, repo_url: str, proxy="" + ) -> None: author, repo, branch = self.parse_github_url(repo_url) logger.info(f"正在下载更新 {repo} ...") @@ -203,7 +206,7 @@ def parse_github_url(self, url: str): return author, repo, branch raise ValueError("无效的 GitHub URL") - def unzip_file(self, zip_path: str, target_dir: str): + def unzip_file(self, zip_path: str, target_dir: str) -> None: """解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir""" os.makedirs(target_dir, exist_ok=True) update_dir = "" diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index ee39399dc1..952806beb7 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -183,7 +183,9 @@ def _update_progress( def _make_progress_callback(self, task_id: str): """创建进度回调函数""" - async def _callback(stage: str, current: int, total: int, message: str = ""): + async def _callback( + stage: str, current: int, total: int, message: str = "" + ) -> None: self._update_progress( task_id, status="processing", @@ -195,7 +197,7 @@ async def _callback(stage: str, current: int, total: int, message: str = ""): return _callback - def _ensure_cleanup_task_started(self): + def _ensure_cleanup_task_started(self) -> None: """确保后台清理任务已启动(在异步上下文中延迟启动)""" if self._cleanup_task is None or self._cleanup_task.done(): try: @@ -206,7 +208,7 @@ def _ensure_cleanup_task_started(self): # 如果没有运行中的事件循环,跳过(等待下次异步调用时启动) pass - async def _cleanup_expired_uploads(self): + async def _cleanup_expired_uploads(self) -> None: """定期清理过期的上传会话 基于 last_activity 字段判断过期,避免清理活跃的上传会话。 @@ -233,7 +235,7 @@ async def _cleanup_expired_uploads(self): except Exception as e: logger.error(f"清理过期上传会话失败: {e}") - async def _cleanup_upload_session(self, upload_id: str): + async def _cleanup_upload_session(self, upload_id: str) -> None: """清理上传会话""" if upload_id in self.upload_sessions: session = self.upload_sessions[upload_id] @@ -371,7 +373,7 @@ async def export_backup(self): logger.error(traceback.format_exc()) return Response().error(f"创建备份失败: {e!s}").__dict__ - async def _background_export_task(self, task_id: str): + async def _background_export_task(self, task_id: str) -> None: """后台导出任务""" try: self._update_progress(task_id, status="processing", message="正在初始化...") @@ -866,7 +868,7 @@ async def import_backup(self): logger.error(traceback.format_exc()) return Response().error(f"导入备份失败: {e!s}").__dict__ - async def _background_import_task(self, task_id: str, zip_path: str): + async def _background_import_task(self, task_id: str, zip_path: str) -> None: """后台导入任务""" try: self._update_progress(task_id, status="processing", message="正在初始化...") diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 55b279fe11..a7c0e3a573 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -651,7 +651,7 @@ def _extract_attachment_ids(self, history_list) -> list[str]: attachment_ids.append(part["attachment_id"]) return attachment_ids - async def _delete_attachments(self, attachment_ids: list[str]): + async def _delete_attachments(self, attachment_ids: list[str]) -> None: """删除附件(包括数据库记录和磁盘文件)""" try: attachments = await self.db.get_attachments(attachment_ids) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index e018d64755..efea4c7cf8 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -59,7 +59,7 @@ def try_cast(value: Any, type_: str): return None -def _expect_type(value, expected_type, path_key, errors, expected_name=None): +def _expect_type(value, expected_type, path_key, errors, expected_name=None) -> bool: if not isinstance(value, expected_type): errors.append( f"错误的类型 {path_key}: 期望是 {expected_name or expected_type.__name__}, " @@ -69,7 +69,7 @@ def _expect_type(value, expected_type, path_key, errors, expected_name=None): return True -def _validate_template_list(value, meta, path_key, errors, validate_fn): +def _validate_template_list(value, meta, path_key, errors, validate_fn) -> None: if not _expect_type(value, list, path_key, errors, "list"): return @@ -102,7 +102,7 @@ def _validate_template_list(value, meta, path_key, errors, validate_fn): def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]: errors = [] - def validate(data: dict, metadata: dict = schema, path=""): + def validate(data: dict, metadata: dict = schema, path="") -> None: for key, value in data.items(): if key not in metadata: continue @@ -206,7 +206,9 @@ def validate(data: dict, metadata: dict = schema, path=""): return errors, data -def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False): +def save_config( + post_config: dict, config: AstrBotConfig, is_core: bool = False +) -> None: """验证并保存配置""" errors = None logger.info(f"Saving config, is_core={is_core}") @@ -1221,7 +1223,7 @@ async def get_llm_tools(self): tools = tool_mgr.get_func_desc_openai_style() return Response().ok(tools).__dict__ - async def _register_platform_logo(self, platform, platform_default_tmpl): + async def _register_platform_logo(self, platform, platform_default_tmpl) -> None: """注册平台logo文件并生成访问令牌""" if not platform.logo_path: return @@ -1357,7 +1359,7 @@ async def _get_plugin_config(self, plugin_name: str): async def _save_astrbot_configs( self, post_configs: dict, conf_id: str | None = None - ): + ) -> None: try: if conf_id not in self.acm.confs: raise ValueError(f"配置文件 {conf_id} 不存在") @@ -1373,7 +1375,7 @@ async def _save_astrbot_configs( except Exception as e: raise e - async def _save_plugin_configs(self, post_configs: dict, plugin_name: str): + async def _save_plugin_configs(self, post_configs: dict, plugin_name: str) -> None: md = None for plugin_md in star_registry: if plugin_md.name == plugin_name: diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 25bc2cf347..efdc3bc6a9 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -114,7 +114,7 @@ def _update_progress( p["total"] = total def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): - async def _callback(stage: str, current: int, total: int): + async def _callback(stage: str, current: int, total: int) -> None: self._update_progress( task_id, status="processing", @@ -137,7 +137,7 @@ async def _background_upload_task( batch_size: int, tasks_limit: int, max_retries: int, - ): + ) -> None: """后台上传任务""" try: # 初始化任务状态 @@ -216,7 +216,7 @@ async def _background_import_task( batch_size: int, tasks_limit: int, max_retries: int, - ): + ) -> None: """后台导入预切片文档任务""" try: # 初始化任务状态 @@ -1215,7 +1215,7 @@ async def _background_upload_from_url_task( max_retries: int, enable_cleaning: bool, cleaning_provider_id: str | None, - ): + ) -> None: """后台上传URL任务""" try: # 初始化任务状态 diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index b6336a7974..85af3cef82 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -20,7 +20,7 @@ class LiveChatSession: """Live Chat 会话管理器""" - def __init__(self, session_id: str, username: str): + def __init__(self, session_id: str, username: str) -> None: self.session_id = session_id self.username = username self.conversation_id = str(uuid.uuid4()) @@ -31,14 +31,14 @@ def __init__(self, session_id: str, username: str): self.current_stamp: str | None = None self.temp_audio_path: str | None = None - def start_speaking(self, stamp: str): + def start_speaking(self, stamp: str) -> None: """开始说话""" self.is_speaking = True self.current_stamp = stamp self.audio_frames = [] logger.debug(f"[Live Chat] {self.username} 开始说话 stamp={stamp}") - def add_audio_frame(self, data: bytes): + def add_audio_frame(self, data: bytes) -> None: """添加音频帧""" if self.is_speaking: self.audio_frames.append(data) @@ -82,7 +82,7 @@ async def end_speaking(self, stamp: str) -> tuple[str | None, float]: logger.error(f"[Live Chat] 组装 WAV 文件失败: {e}", exc_info=True) return None, 0.0 - def cleanup(self): + def cleanup(self) -> None: """清理临时文件""" if self.temp_audio_path and os.path.exists(self.temp_audio_path): try: @@ -111,7 +111,7 @@ def __init__( # 注册 WebSocket 路由 self.app.websocket("/api/live_chat/ws")(self.live_chat_ws) - async def live_chat_ws(self): + async def live_chat_ws(self) -> None: """Live Chat WebSocket 处理器""" # WebSocket 不能通过 header 传递 token,需要从 query 参数获取 # 注意:WebSocket 上下文使用 websocket.args 而不是 request.args @@ -152,7 +152,7 @@ async def live_chat_ws(self): del self.sessions[session_id] logger.info(f"[Live Chat] WebSocket 连接关闭: {username}") - async def _handle_message(self, session: LiveChatSession, message: dict): + async def _handle_message(self, session: LiveChatSession, message: dict) -> None: """处理 WebSocket 消息""" msg_type = message.get("t") # 使用 t 代替 type @@ -201,7 +201,7 @@ async def _handle_message(self, session: LiveChatSession, message: dict): async def _process_audio( self, session: LiveChatSession, audio_path: str, assemble_duration: float - ): + ) -> None: """处理音频:STT -> LLM -> 流式 TTS""" try: # 发送 WAV 组装耗时 @@ -409,7 +409,7 @@ async def _process_audio( async def _save_interrupted_message( self, session: LiveChatSession, user_text: str, bot_text: str - ): + ) -> None: """保存被打断的消息""" interrupted_text = bot_text + " [用户打断]" logger.info(f"[Live Chat] 保存打断消息: {interrupted_text}") diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py index 4d8fdddfe8..874bc19db7 100644 --- a/astrbot/dashboard/routes/platform.py +++ b/astrbot/dashboard/routes/platform.py @@ -26,7 +26,7 @@ def __init__( self._register_webhook_routes() - def _register_webhook_routes(self): + def _register_webhook_routes(self) -> None: """注册 webhook 路由""" # 统一 webhook 入口,支持 GET 和 POST self.app.add_url_rule( diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index b6160ff1e4..f9f8961b4a 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -261,7 +261,7 @@ def _load_plugin_cache(self, cache_file: str): logger.warning(f"加载插件市场缓存失败: {e}") return None - def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None): + def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None) -> None: """保存插件市场数据到本地缓存""" try: # 确保目录存在 diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 01ab292d48..53c6234439 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -14,12 +14,12 @@ class RouteContext: class Route: routes: list | dict - def __init__(self, context: RouteContext): + def __init__(self, context: RouteContext) -> None: self.app = context.app self.config = context.config - def register_routes(self): - def _add_rule(path, method, func): + def register_routes(self) -> None: + def _add_rule(path, method, func) -> None: # 统一添加 /api 前缀 full_path = f"/api{path}" self.app.add_url_rule(full_path, view_func=func, methods=[method]) diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 3d3d0ca512..e056b6c5ac 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -30,7 +30,7 @@ def __init__(self, context: RouteContext) -> None: self.app.add_url_rule(i, view_func=self.index) @self.app.errorhandler(404) - async def page_not_found(e): + async def page_not_found(e) -> str: return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。如果你正在测试回调地址可达性,显示这段文字说明测试成功了。" async def index(self): diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index db70a88205..8d06826be3 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -12,7 +12,9 @@ class T2iRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle): + def __init__( + self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 5378f1a856..604866a873 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -190,7 +190,7 @@ def get_process_using_port(self, port: int) -> str: except Exception as e: return f"获取进程信息失败: {e!s}" - def _init_jwt_secret(self): + def _init_jwt_secret(self) -> None: if not self.config.get("dashboard", {}).get("jwt_secret", None): # 如果没有设置 JWT 密钥,则生成一个新的密钥 jwt_secret = os.urandom(32).hex() @@ -271,6 +271,6 @@ def run(self): return serve(self.app, config, shutdown_trigger=self.shutdown_trigger) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() logger.info("AstrBot WebUI 已经被优雅地关闭") diff --git a/main.py b/main.py index 0e01d0e2d7..68e03dc9a1 100644 --- a/main.py +++ b/main.py @@ -32,7 +32,7 @@ """ -def check_env(): +def check_env() -> None: if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): logger.error("请使用 Python3.10+ 运行本项目。") exit() diff --git a/scripts/generate_changelog.py b/scripts/generate_changelog.py index 446e0bc564..75b6ca88c9 100755 --- a/scripts/generate_changelog.py +++ b/scripts/generate_changelog.py @@ -185,7 +185,7 @@ def generate_simple_changelog(commits): return changelog_zh + changelog_en -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Generate changelog from git commits") parser.add_argument( "--version", help="Version number for the changelog (e.g., v4.13.3)"