diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index 8d06826be3..634828e955 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -38,6 +38,18 @@ def __init__( ] self.register_routes() + async def _reload_all_pipeline_schedulers(self) -> None: + """热重载所有配置对应的 pipeline scheduler。""" + for conf_id in self.core_lifecycle.astrbot_config_mgr.confs: + await self.core_lifecycle.reload_pipeline_scheduler(conf_id) + + async def _sync_active_template_to_all_configs(self, name: str) -> None: + """同步当前激活模板到所有配置文件,并热重载对应流水线。""" + for config in self.core_lifecycle.astrbot_config_mgr.confs.values(): + config["t2i_active_template"] = name + config.save_config() + await self._reload_all_pipeline_schedulers() + async def list_templates(self): """获取所有T2I模板列表""" try: @@ -133,7 +145,7 @@ async def update_template(self, name: str): # 检查更新的是否为当前激活的模板,如果是,则热重载 active_template = self.config.get("t2i_active_template", "base") if name == active_template: - await self.core_lifecycle.reload_pipeline_scheduler("default") + await self._reload_all_pipeline_schedulers() message = f"模板 '{name}' 已更新并重新加载。" else: message = f"模板 '{name}' 已更新。" @@ -182,13 +194,8 @@ async def set_active_template(self): # 验证模板文件是否存在 self.manager.get_template(name) - # 更新配置 - config = self.config - config["t2i_active_template"] = name - config.save_config(config) - - # 热重载以应用更改 - await self.core_lifecycle.reload_pipeline_scheduler("default") + # 更新所有配置并热重载以应用更改 + await self._sync_active_template_to_all_configs(name) return jsonify(asdict(Response().ok(message=f"模板 '{name}' 已成功应用。"))) @@ -209,13 +216,8 @@ async def reset_default_template(self): try: self.manager.reset_default_template() - # 更新配置,将激活模板也重置为'base' - config = self.config - config["t2i_active_template"] = "base" - config.save_config(config) - - # 热重载以应用更改 - await self.core_lifecycle.reload_pipeline_scheduler("default") + # 更新所有配置,将激活模板也重置为'base' + await self._sync_active_template_to_all_configs("base") return jsonify( asdict( diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 8c505ad2c3..7fba6893e5 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -3,6 +3,7 @@ import io import os import sys +import uuid import zipfile from datetime import datetime from types import SimpleNamespace @@ -155,6 +156,7 @@ async def test_subagent_config_accepts_default_persona( headers=authenticated_header, ) + @pytest.mark.asyncio @pytest.mark.parametrize("payload", [[], "x"]) async def test_batch_delete_sessions_rejects_non_object_payload( @@ -427,6 +429,224 @@ async def test_commands_api(app: Quart, authenticated_header: dict): assert isinstance(data["data"], list) +@pytest.mark.asyncio +async def test_t2i_set_active_template_syncs_all_configs( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, +): + test_client = app.test_client() + template_name = f"sync_tpl_{uuid.uuid4().hex[:8]}" + created_conf_ids: list[str] = [] + + try: + for name in ("sync-a", "sync-b"): + response = await test_client.post( + "/api/config/abconf/new", + json={"name": name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + created_conf_ids.append(data["data"]["conf_id"]) + + response = await test_client.post( + "/api/t2i/templates/create", + json={ + "name": template_name, + "content": "{{ content }}", + }, + headers=authenticated_header, + ) + assert response.status_code == 201 + data = await response.get_json() + assert data["status"] == "ok" + + response = await test_client.post( + "/api/t2i/templates/set_active", + json={"name": template_name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + conf_ids = set(core_lifecycle_td.astrbot_config_mgr.confs.keys()) + assert "default" in conf_ids + for conf_id in conf_ids: + conf = core_lifecycle_td.astrbot_config_mgr.confs[conf_id] + assert conf.get("t2i_active_template") == template_name + assert conf_id in core_lifecycle_td.pipeline_scheduler_mapping + finally: + await test_client.post( + "/api/t2i/templates/set_active", + json={"name": "base"}, + headers=authenticated_header, + ) + await test_client.delete( + f"/api/t2i/templates/{template_name}", + headers=authenticated_header, + ) + for conf_id in created_conf_ids: + await test_client.post( + "/api/config/abconf/delete", + json={"id": conf_id}, + headers=authenticated_header, + ) + + +@pytest.mark.asyncio +async def test_t2i_reset_default_template_syncs_all_configs( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, +): + test_client = app.test_client() + template_name = f"reset_tpl_{uuid.uuid4().hex[:8]}" + created_conf_ids: list[str] = [] + + try: + for name in ("reset-a", "reset-b"): + response = await test_client.post( + "/api/config/abconf/new", + json={"name": name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + created_conf_ids.append(data["data"]["conf_id"]) + + response = await test_client.post( + "/api/t2i/templates/create", + json={ + "name": template_name, + "content": "{{ content }} reset", + }, + headers=authenticated_header, + ) + assert response.status_code == 201 + data = await response.get_json() + assert data["status"] == "ok" + + response = await test_client.post( + "/api/t2i/templates/set_active", + json={"name": template_name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + + response = await test_client.post( + "/api/t2i/templates/reset_default", + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + conf_ids = set(core_lifecycle_td.astrbot_config_mgr.confs.keys()) + assert "default" in conf_ids + for conf_id in conf_ids: + conf = core_lifecycle_td.astrbot_config_mgr.confs[conf_id] + assert conf.get("t2i_active_template") == "base" + assert conf_id in core_lifecycle_td.pipeline_scheduler_mapping + finally: + await test_client.post( + "/api/t2i/templates/set_active", + json={"name": "base"}, + headers=authenticated_header, + ) + await test_client.delete( + f"/api/t2i/templates/{template_name}", + headers=authenticated_header, + ) + for conf_id in created_conf_ids: + await test_client.post( + "/api/config/abconf/delete", + json={"id": conf_id}, + headers=authenticated_header, + ) + + +@pytest.mark.asyncio +async def test_t2i_update_active_template_reloads_all_schedulers( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, +): + test_client = app.test_client() + template_name = f"update_tpl_{uuid.uuid4().hex[:8]}" + created_conf_ids: list[str] = [] + + try: + for name in ("update-a", "update-b"): + response = await test_client.post( + "/api/config/abconf/new", + json={"name": name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + created_conf_ids.append(data["data"]["conf_id"]) + + response = await test_client.post( + "/api/t2i/templates/create", + json={ + "name": template_name, + "content": "{{ content }} v1", + }, + headers=authenticated_header, + ) + assert response.status_code == 201 + + response = await test_client.post( + "/api/t2i/templates/set_active", + json={"name": template_name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + + conf_ids = list(core_lifecycle_td.astrbot_config_mgr.confs.keys()) + old_schedulers = { + conf_id: core_lifecycle_td.pipeline_scheduler_mapping[conf_id] + for conf_id in conf_ids + } + + response = await test_client.put( + f"/api/t2i/templates/{template_name}", + json={"content": "{{ content }} v2"}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + for conf_id in conf_ids: + assert conf_id in core_lifecycle_td.pipeline_scheduler_mapping + assert ( + core_lifecycle_td.pipeline_scheduler_mapping[conf_id] + is not old_schedulers[conf_id] + ) + finally: + await test_client.post( + "/api/t2i/templates/set_active", + json={"name": "base"}, + headers=authenticated_header, + ) + await test_client.delete( + f"/api/t2i/templates/{template_name}", + headers=authenticated_header, + ) + for conf_id in created_conf_ids: + await test_client.post( + "/api/config/abconf/delete", + json={"id": conf_id}, + headers=authenticated_header, + ) + + @pytest.mark.asyncio async def test_check_update( app: Quart,