diff --git a/src/services/conversation_service.py b/src/services/conversation_service.py index 642bce473..3f44d1010 100644 --- a/src/services/conversation_service.py +++ b/src/services/conversation_service.py @@ -1,8 +1,10 @@ import uuid +from datetime import UTC, datetime from fastapi import HTTPException, UploadFile from sqlalchemy.ext.asyncio import AsyncSession +from src.agents import agent_manager from src.repositories.conversation_repository import ConversationRepository from src.services.doc_converter import ( ATTACHMENT_ALLOWED_EXTENSIONS, @@ -41,6 +43,66 @@ def _make_attachment_path(file_name: str) -> str: return f"/attachments/{safe_name}.md" +def _build_state_files(attachments: list[dict]) -> dict: + files = {} + for attachment in attachments: + if attachment.get("status") != "parsed": + continue + + file_path = attachment.get("file_path") + markdown = attachment.get("markdown") + if not file_path or not markdown: + continue + + now = datetime.now(UTC).isoformat() + files[file_path] = { + "content": markdown.split("\n"), + "created_at": attachment.get("uploaded_at", now), + "modified_at": attachment.get("uploaded_at", now), + } + return files + + +async def _sync_thread_attachment_state( + *, + thread_id: str, + user_id: str, + agent_id: str, + attachments: list[dict], +) -> None: + try: + agent = agent_manager.get_agent(agent_id) + if not agent: + logger.warning(f"Skip attachment state sync: agent not found ({agent_id})") + return + + graph = await agent.get_graph() + config = {"configurable": {"thread_id": thread_id, "user_id": str(user_id)}} + state = await graph.aget_state(config) + state_values = getattr(state, "values", {}) if state else {} + existing_files = state_values.get("files", {}) if isinstance(state_values, dict) else {} + if not isinstance(existing_files, dict): + existing_files = {} + + attachment_files = _build_state_files(attachments) + merged_files = { + path: file_data + for path, file_data in existing_files.items() + if isinstance(path, str) and not path.startswith("/attachments/") + } + merged_files.update(attachment_files) + + await graph.aupdate_state( + config=config, + values={ + "attachments": attachments, + "files": merged_files, + }, + ) + except Exception as e: + logger.warning(f"Failed to sync attachment state for thread {thread_id}: {e}") + + def serialize_attachment(record: dict) -> dict: """序列化附件记录,返回给前端""" return { @@ -201,6 +263,13 @@ async def upload_thread_attachment_view( "minio_url": minio_url, } await conv_repo.add_attachment(conversation.id, attachment_record) + all_attachments = await conv_repo.get_attachments(conversation.id) + await _sync_thread_attachment_state( + thread_id=thread_id, + user_id=str(current_user_id), + agent_id=conversation.agent_id, + attachments=all_attachments, + ) return serialize_attachment(attachment_record) @@ -235,4 +304,11 @@ async def delete_thread_attachment_view( removed = await conv_repo.remove_attachment(conversation.id, file_id) if not removed: raise HTTPException(status_code=404, detail="附件不存在或已被删除") + all_attachments = await conv_repo.get_attachments(conversation.id) + await _sync_thread_attachment_state( + thread_id=thread_id, + user_id=str(current_user_id), + agent_id=conversation.agent_id, + attachments=all_attachments, + ) return {"message": "附件已删除"} diff --git a/test/test_conversation_service_attachment_state.py b/test/test_conversation_service_attachment_state.py new file mode 100644 index 000000000..fb04048aa --- /dev/null +++ b/test/test_conversation_service_attachment_state.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from src.services import conversation_service as svc + + +def test_build_state_files_only_parsed_and_with_content(): + attachments = [ + { + "status": "parsed", + "file_path": "/attachments/a.md", + "markdown": "line1\nline2", + "uploaded_at": "2026-02-20T00:00:00+00:00", + }, + { + "status": "pending", + "file_path": "/attachments/b.md", + "markdown": "ignored", + }, + { + "status": "parsed", + "file_path": "/attachments/c.md", + "markdown": "", + }, + ] + + files = svc._build_state_files(attachments) + + assert list(files.keys()) == ["/attachments/a.md"] + assert files["/attachments/a.md"]["content"] == ["line1", "line2"] + assert files["/attachments/a.md"]["created_at"] == "2026-02-20T00:00:00+00:00" + + +@pytest.mark.asyncio +async def test_sync_thread_attachment_state_updates_graph(monkeypatch: pytest.MonkeyPatch): + captured: dict = {} + fake_state = SimpleNamespace( + values={ + "files": { + "/attachments/old.md": {"content": ["old"]}, + "/work/result.md": {"content": ["keep"]}, + } + } + ) + + class FakeGraph: + async def aget_state(self, config): + captured["read_config"] = config + return fake_state + + async def aupdate_state(self, *, config, values): + captured["write_config"] = config + captured["write_values"] = values + + class FakeAgent: + async def get_graph(self): + return FakeGraph() + + monkeypatch.setattr(svc.agent_manager, "get_agent", lambda _agent_id: FakeAgent()) + + attachments = [ + { + "status": "parsed", + "file_path": "/attachments/resume.md", + "markdown": "hello\nworld", + "uploaded_at": "2026-02-20T00:00:00+00:00", + } + ] + await svc._sync_thread_attachment_state( + thread_id="thread-1", + user_id="u1", + agent_id="ChatbotAgent", + attachments=attachments, + ) + + assert captured["read_config"] == {"configurable": {"thread_id": "thread-1", "user_id": "u1"}} + assert captured["write_config"] == {"configurable": {"thread_id": "thread-1", "user_id": "u1"}} + assert captured["write_values"]["attachments"] == attachments + assert "/attachments/resume.md" in captured["write_values"]["files"] + assert "/attachments/old.md" not in captured["write_values"]["files"] + assert "/work/result.md" in captured["write_values"]["files"] + + +@pytest.mark.asyncio +async def test_sync_thread_attachment_state_skips_when_agent_missing(monkeypatch: pytest.MonkeyPatch): + warnings: list[str] = [] + fake_logger = SimpleNamespace( + warning=lambda message: warnings.append(message), + ) + + monkeypatch.setattr(svc, "logger", fake_logger) + monkeypatch.setattr(svc.agent_manager, "get_agent", lambda _agent_id: None) + + await svc._sync_thread_attachment_state( + thread_id="thread-1", + user_id="u1", + agent_id="MissingAgent", + attachments=[], + ) + + assert any("agent not found" in msg for msg in warnings)