Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions src/services/conversation_service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import uuid
from datetime import UTC, datetime

from fastapi import HTTPException, UploadFile
from sqlalchemy.ext.asyncio import AsyncSession

from src.agents import agent_manager
from src.repositories.conversation_repository import ConversationRepository
from src.services.doc_converter import (
ATTACHMENT_ALLOWED_EXTENSIONS,
Expand Down Expand Up @@ -41,6 +43,66 @@ def _make_attachment_path(file_name: str) -> str:
return f"/attachments/{safe_name}.md"


def _build_state_files(attachments: list[dict]) -> dict:
files = {}
for attachment in attachments:
if attachment.get("status") != "parsed":
continue

file_path = attachment.get("file_path")
markdown = attachment.get("markdown")
if not file_path or not markdown:
continue

now = datetime.now(UTC).isoformat()
files[file_path] = {
"content": markdown.split("\n"),
"created_at": attachment.get("uploaded_at", now),
"modified_at": attachment.get("uploaded_at", now),
}
return files


async def _sync_thread_attachment_state(
*,
thread_id: str,
user_id: str,
agent_id: str,
attachments: list[dict],
) -> None:
try:
agent = agent_manager.get_agent(agent_id)
if not agent:
logger.warning(f"Skip attachment state sync: agent not found ({agent_id})")
return

graph = await agent.get_graph()
config = {"configurable": {"thread_id": thread_id, "user_id": str(user_id)}}
state = await graph.aget_state(config)
state_values = getattr(state, "values", {}) if state else {}
existing_files = state_values.get("files", {}) if isinstance(state_values, dict) else {}
if not isinstance(existing_files, dict):
existing_files = {}

attachment_files = _build_state_files(attachments)
merged_files = {
path: file_data
for path, file_data in existing_files.items()
if isinstance(path, str) and not path.startswith("/attachments/")
}
merged_files.update(attachment_files)

await graph.aupdate_state(
config=config,
values={
"attachments": attachments,
"files": merged_files,
},
)
except Exception as e:
logger.warning(f"Failed to sync attachment state for thread {thread_id}: {e}")


def serialize_attachment(record: dict) -> dict:
"""序列化附件记录,返回给前端"""
return {
Expand Down Expand Up @@ -201,6 +263,13 @@ async def upload_thread_attachment_view(
"minio_url": minio_url,
}
await conv_repo.add_attachment(conversation.id, attachment_record)
all_attachments = await conv_repo.get_attachments(conversation.id)
await _sync_thread_attachment_state(
thread_id=thread_id,
user_id=str(current_user_id),
agent_id=conversation.agent_id,
attachments=all_attachments,
)

return serialize_attachment(attachment_record)

Expand Down Expand Up @@ -235,4 +304,11 @@ async def delete_thread_attachment_view(
removed = await conv_repo.remove_attachment(conversation.id, file_id)
if not removed:
raise HTTPException(status_code=404, detail="附件不存在或已被删除")
all_attachments = await conv_repo.get_attachments(conversation.id)
await _sync_thread_attachment_state(
thread_id=thread_id,
user_id=str(current_user_id),
agent_id=conversation.agent_id,
attachments=all_attachments,
)
return {"message": "附件已删除"}
104 changes: 104 additions & 0 deletions test/test_conversation_service_attachment_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

from types import SimpleNamespace

import pytest

from src.services import conversation_service as svc


def test_build_state_files_only_parsed_and_with_content():
attachments = [
{
"status": "parsed",
"file_path": "/attachments/a.md",
"markdown": "line1\nline2",
"uploaded_at": "2026-02-20T00:00:00+00:00",
},
{
"status": "pending",
"file_path": "/attachments/b.md",
"markdown": "ignored",
},
{
"status": "parsed",
"file_path": "/attachments/c.md",
"markdown": "",
},
]

files = svc._build_state_files(attachments)

assert list(files.keys()) == ["/attachments/a.md"]
assert files["/attachments/a.md"]["content"] == ["line1", "line2"]
assert files["/attachments/a.md"]["created_at"] == "2026-02-20T00:00:00+00:00"


@pytest.mark.asyncio
async def test_sync_thread_attachment_state_updates_graph(monkeypatch: pytest.MonkeyPatch):
captured: dict = {}
fake_state = SimpleNamespace(
values={
"files": {
"/attachments/old.md": {"content": ["old"]},
"/work/result.md": {"content": ["keep"]},
}
}
)

class FakeGraph:
async def aget_state(self, config):
captured["read_config"] = config
return fake_state

async def aupdate_state(self, *, config, values):
captured["write_config"] = config
captured["write_values"] = values

class FakeAgent:
async def get_graph(self):
return FakeGraph()

monkeypatch.setattr(svc.agent_manager, "get_agent", lambda _agent_id: FakeAgent())

attachments = [
{
"status": "parsed",
"file_path": "/attachments/resume.md",
"markdown": "hello\nworld",
"uploaded_at": "2026-02-20T00:00:00+00:00",
}
]
await svc._sync_thread_attachment_state(
thread_id="thread-1",
user_id="u1",
agent_id="ChatbotAgent",
attachments=attachments,
)

assert captured["read_config"] == {"configurable": {"thread_id": "thread-1", "user_id": "u1"}}
assert captured["write_config"] == {"configurable": {"thread_id": "thread-1", "user_id": "u1"}}
assert captured["write_values"]["attachments"] == attachments
assert "/attachments/resume.md" in captured["write_values"]["files"]
assert "/attachments/old.md" not in captured["write_values"]["files"]
assert "/work/result.md" in captured["write_values"]["files"]


@pytest.mark.asyncio
async def test_sync_thread_attachment_state_skips_when_agent_missing(monkeypatch: pytest.MonkeyPatch):
warnings: list[str] = []
fake_logger = SimpleNamespace(
warning=lambda message: warnings.append(message),
)

monkeypatch.setattr(svc, "logger", fake_logger)
monkeypatch.setattr(svc.agent_manager, "get_agent", lambda _agent_id: None)

await svc._sync_thread_attachment_state(
thread_id="thread-1",
user_id="u1",
agent_id="MissingAgent",
attachments=[],
)

assert any("agent not found" in msg for msg in warnings)