Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions astrbot/dashboard/routes/t2i.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ def __init__(
]
self.register_routes()

async def _reload_all_pipeline_schedulers(self) -> None:
"""热重载所有配置对应的 pipeline scheduler。"""
for conf_id in self.core_lifecycle.astrbot_config_mgr.confs:
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)

async def _sync_active_template_to_all_configs(self, name: str) -> None:
"""同步当前激活模板到所有配置文件,并热重载对应流水线。"""
for config in self.core_lifecycle.astrbot_config_mgr.confs.values():
config["t2i_active_template"] = name
config.save_config()
await self._reload_all_pipeline_schedulers()

async def list_templates(self):
"""获取所有T2I模板列表"""
try:
Expand Down Expand Up @@ -133,7 +145,7 @@ async def update_template(self, name: str):
# 检查更新的是否为当前激活的模板,如果是,则热重载
active_template = self.config.get("t2i_active_template", "base")
if name == active_template:
await self.core_lifecycle.reload_pipeline_scheduler("default")
await self._reload_all_pipeline_schedulers()
message = f"模板 '{name}' 已更新并重新加载。"
else:
message = f"模板 '{name}' 已更新。"
Expand Down Expand Up @@ -182,13 +194,8 @@ async def set_active_template(self):
# 验证模板文件是否存在
self.manager.get_template(name)

# 更新配置
config = self.config
config["t2i_active_template"] = name
config.save_config(config)

# 热重载以应用更改
await self.core_lifecycle.reload_pipeline_scheduler("default")
# 更新所有配置并热重载以应用更改
await self._sync_active_template_to_all_configs(name)

return jsonify(asdict(Response().ok(message=f"模板 '{name}' 已成功应用。")))

Expand All @@ -209,13 +216,8 @@ async def reset_default_template(self):
try:
self.manager.reset_default_template()

# 更新配置,将激活模板也重置为'base'
config = self.config
config["t2i_active_template"] = "base"
config.save_config(config)

# 热重载以应用更改
await self.core_lifecycle.reload_pipeline_scheduler("default")
# 更新所有配置,将激活模板也重置为'base'
await self._sync_active_template_to_all_configs("base")

return jsonify(
asdict(
Expand Down
220 changes: 220 additions & 0 deletions tests/test_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import os
import sys
import uuid
import zipfile
from datetime import datetime
from types import SimpleNamespace
Expand Down Expand Up @@ -155,6 +156,7 @@ async def test_subagent_config_accepts_default_persona(
headers=authenticated_header,
)


@pytest.mark.asyncio
@pytest.mark.parametrize("payload", [[], "x"])
async def test_batch_delete_sessions_rejects_non_object_payload(
Comment thread
RC-CHN marked this conversation as resolved.
Expand Down Expand Up @@ -427,6 +429,224 @@ async def test_commands_api(app: Quart, authenticated_header: dict):
assert isinstance(data["data"], list)


@pytest.mark.asyncio
async def test_t2i_set_active_template_syncs_all_configs(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
):
test_client = app.test_client()
template_name = f"sync_tpl_{uuid.uuid4().hex[:8]}"
created_conf_ids: list[str] = []

try:
for name in ("sync-a", "sync-b"):
response = await test_client.post(
"/api/config/abconf/new",
json={"name": name},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
created_conf_ids.append(data["data"]["conf_id"])

response = await test_client.post(
"/api/t2i/templates/create",
json={
"name": template_name,
"content": "<html><body>{{ content }}</body></html>",
},
headers=authenticated_header,
)
assert response.status_code == 201
data = await response.get_json()
assert data["status"] == "ok"

response = await test_client.post(
"/api/t2i/templates/set_active",
json={"name": template_name},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"

conf_ids = set(core_lifecycle_td.astrbot_config_mgr.confs.keys())
assert "default" in conf_ids
for conf_id in conf_ids:
conf = core_lifecycle_td.astrbot_config_mgr.confs[conf_id]
assert conf.get("t2i_active_template") == template_name
assert conf_id in core_lifecycle_td.pipeline_scheduler_mapping
finally:
await test_client.post(
"/api/t2i/templates/set_active",
json={"name": "base"},
headers=authenticated_header,
)
await test_client.delete(
f"/api/t2i/templates/{template_name}",
headers=authenticated_header,
)
for conf_id in created_conf_ids:
await test_client.post(
"/api/config/abconf/delete",
json={"id": conf_id},
headers=authenticated_header,
)


@pytest.mark.asyncio
async def test_t2i_reset_default_template_syncs_all_configs(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
):
test_client = app.test_client()
template_name = f"reset_tpl_{uuid.uuid4().hex[:8]}"
created_conf_ids: list[str] = []

try:
for name in ("reset-a", "reset-b"):
response = await test_client.post(
"/api/config/abconf/new",
json={"name": name},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
created_conf_ids.append(data["data"]["conf_id"])

response = await test_client.post(
"/api/t2i/templates/create",
json={
"name": template_name,
"content": "<html><body>{{ content }} reset</body></html>",
},
headers=authenticated_header,
)
assert response.status_code == 201
data = await response.get_json()
assert data["status"] == "ok"

response = await test_client.post(
"/api/t2i/templates/set_active",
json={"name": template_name},
headers=authenticated_header,
)
assert response.status_code == 200

response = await test_client.post(
"/api/t2i/templates/reset_default",
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"

conf_ids = set(core_lifecycle_td.astrbot_config_mgr.confs.keys())
assert "default" in conf_ids
for conf_id in conf_ids:
conf = core_lifecycle_td.astrbot_config_mgr.confs[conf_id]
assert conf.get("t2i_active_template") == "base"
assert conf_id in core_lifecycle_td.pipeline_scheduler_mapping
finally:
await test_client.post(
"/api/t2i/templates/set_active",
json={"name": "base"},
headers=authenticated_header,
)
await test_client.delete(
f"/api/t2i/templates/{template_name}",
headers=authenticated_header,
)
for conf_id in created_conf_ids:
await test_client.post(
"/api/config/abconf/delete",
json={"id": conf_id},
headers=authenticated_header,
)


@pytest.mark.asyncio
async def test_t2i_update_active_template_reloads_all_schedulers(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
):
test_client = app.test_client()
template_name = f"update_tpl_{uuid.uuid4().hex[:8]}"
created_conf_ids: list[str] = []

try:
for name in ("update-a", "update-b"):
response = await test_client.post(
"/api/config/abconf/new",
json={"name": name},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
created_conf_ids.append(data["data"]["conf_id"])

response = await test_client.post(
"/api/t2i/templates/create",
json={
"name": template_name,
"content": "<html><body>{{ content }} v1</body></html>",
},
headers=authenticated_header,
)
assert response.status_code == 201

response = await test_client.post(
"/api/t2i/templates/set_active",
json={"name": template_name},
headers=authenticated_header,
)
assert response.status_code == 200

conf_ids = list(core_lifecycle_td.astrbot_config_mgr.confs.keys())
old_schedulers = {
conf_id: core_lifecycle_td.pipeline_scheduler_mapping[conf_id]
for conf_id in conf_ids
}

response = await test_client.put(
f"/api/t2i/templates/{template_name}",
json={"content": "<html><body>{{ content }} v2</body></html>"},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"

for conf_id in conf_ids:
assert conf_id in core_lifecycle_td.pipeline_scheduler_mapping
assert (
core_lifecycle_td.pipeline_scheduler_mapping[conf_id]
is not old_schedulers[conf_id]
)
finally:
await test_client.post(
"/api/t2i/templates/set_active",
json={"name": "base"},
headers=authenticated_header,
)
await test_client.delete(
f"/api/t2i/templates/{template_name}",
headers=authenticated_header,
)
for conf_id in created_conf_ids:
await test_client.post(
"/api/config/abconf/delete",
json={"id": conf_id},
headers=authenticated_header,
)


@pytest.mark.asyncio
async def test_check_update(
app: Quart,
Expand Down
Loading