104 changes: 103 additions & 1 deletion server/routers/knowledge_router.py
@@ -14,7 +14,7 @@
from server.services.tasker import TaskContext, tasker
from src import config, knowledge_base
from src.knowledge.indexing import SUPPORTED_FILE_EXTENSIONS, is_supported_file_extension, process_file_to_markdown
from src.knowledge.utils import calculate_content_hash
from src.knowledge.utils import calculate_content_hash, merge_processing_params
from src.models.embed import test_embedding_model_status, test_all_embedding_models_status
from src.utils import hashstr, logger

@@ -361,6 +361,108 @@ async def delete_document(db_id: str, doc_id: str, current_user: User = Depends(
raise HTTPException(status_code=400, detail=f"Failed to delete document: {e}")


@knowledge.post("/databases/{db_id}/documents/rechunks")
async def rechunks_documents(
db_id: str, file_ids: list[str] = Body(...), params: dict = Body(...), current_user: User = Depends(get_admin_user)
):
"""重新分块文档"""
logger.debug(f"Rechunks documents for db_id {db_id}: {file_ids} {params=}")

async def run_rechunks(context: TaskContext):
await context.set_message("任务初始化")
await context.set_progress(5.0, "准备重新分块文档")

total = len(file_ids)
processed_items = []

try:
# Process documents one by one, updating progress as we go
for idx, file_id in enumerate(file_ids, 1):
await context.raise_if_cancelled()

# Update progress
progress = 5.0 + (idx / total) * 90.0 # 5% ~ 95%
await context.set_progress(progress, f"Re-chunking document {idx}/{total}")

# Fetch the processing params stored in the document's metadata
metadata_params = None
try:
file_info = await knowledge_base.get_file_basic_info(db_id, file_id)
metadata_params = file_info.get("meta", {}).get("processing_params")
except Exception as meta_error:
logger.warning(f"Failed to get metadata for file {file_id}: {meta_error}")

# Merge params: request params take precedence; fall back to metadata params where missing
merged_params = merge_processing_params(metadata_params, params)

# Process a single document
try:
result = await knowledge_base.update_content(db_id, [file_id], params=merged_params)
processed_items.extend(result)
except Exception as doc_error:
# Handle all exceptions from processing a single document (including timeouts)
logger.error(f"Document rechunking failed for {file_id}: {doc_error}")

# Distinguish timeout errors from other processing failures
error_type = "timeout" if isinstance(doc_error, TimeoutError) else "processing_error"
error_msg = "Processing timed out" if isinstance(doc_error, TimeoutError) else "Processing failed"

processed_items.append({
"file_id": file_id,
"status": "failed",
"error": f"{error_msg}: {str(doc_error)}",
"error_type": error_type
})

except asyncio.CancelledError:
await context.set_progress(100.0, "任务已取消")
raise
except Exception as task_error:
# Handle task-level exceptions (e.g. out of memory, network errors)
logger.exception(f"Task rechunking failed: {task_error}")
await context.set_progress(100.0, f"任务处理失败: {str(task_error)}")
# 将所有未处理的文档标记为失败
for file_id in file_ids[len(processed_items):]:
processed_items.append({
"file_id": file_id,
"status": "failed",
"error": f"任务失败: {str(task_error)}",
"error_type": "task_failed"
})
raise

failed_count = len([_p for _p in processed_items if _p.get("status") == "failed"])
summary = {
"db_id": db_id,
"submitted": len(processed_items),
"failed": failed_count,
}
message = f"文档重新分块完成,失败 {failed_count} 个" if failed_count else "文档重新分块完成"
await context.set_result(summary | {"items": processed_items})
await context.set_progress(100.0, message)
return summary | {"items": processed_items}

try:
task = await tasker.enqueue(
name=f"文档重新分块({db_id})",
task_type="knowledge_rechunks",
payload={
"db_id": db_id,
"file_ids": file_ids,
"params": params,
},
coroutine=run_rechunks,
)
return {
"message": "任务已提交,请在任务中心查看进度",
"status": "queued",
"task_id": task.id,
}
except Exception as e: # noqa: BLE001
logger.error(f"Failed to enqueue rechunks task: {e}, {traceback.format_exc()}")
return {"message": f"Failed to enqueue task: {e}", "status": "failed"}


@knowledge.get("/databases/{db_id}/documents/{doc_id}/download")
async def download_document(db_id: str, doc_id: str, request: Request, current_user: User = Depends(get_admin_user)):
"""下载原始文件"""
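Taken together, the new endpoint accepts a list of file_ids plus a params dict in the request body and returns a queued task id. A minimal client-side sketch of calling it (the host, router prefix, auth token, and the specific params keys are assumptions; only the route path and body fields come from this diff):

import requests

BASE_URL = "http://localhost:8000"  # hypothetical host; prepend whatever prefix the knowledge router is mounted under
headers = {"Authorization": "Bearer <admin-token>"}  # admin auth is required by get_admin_user

resp = requests.post(
    f"{BASE_URL}/databases/my_db/documents/rechunks",
    json={
        "file_ids": ["file_abc123", "file_def456"],
        "params": {"chunk_size": 512, "chunk_overlap": 50},  # hypothetical processing params
    },
    headers=headers,
    timeout=30,
)
print(resp.json())  # expected shape: {"message": "...", "status": "queued", "task_id": "..."}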
16 changes: 16 additions & 0 deletions src/knowledge/base.py
@@ -212,6 +212,21 @@ async def add_content(self, db_id: str, items: list[str], params: dict | None =
"""
pass

@abstractmethod
async def update_content(self, db_id: str, file_ids: list[str], params: dict | None = None) -> list[dict]:
"""
Update content - re-parse the files identified by file_ids and update the vector store.

Args:
db_id: Database ID
file_ids: List of file IDs
params: Processing parameters

Returns:
List of update results
"""
pass

@abstractmethod
async def aquery(self, query_text: str, db_id: str, **kwargs) -> list[dict]:
"""
@@ -278,6 +293,7 @@ def get_database_info(self, db_id: str) -> dict | None:
"type": file_info.get("file_type", ""),
"status": file_info.get("status", "done"),
"created_at": created_at,
"processing_params": file_info.get("processing_params", None),
}

# Sort the file list by creation time, newest first
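Side note: the get_database_info change above surfaces each file's stored processing_params, which the router earlier in this diff merges with the request params via merge_processing_params (imported from src.knowledge.utils, not shown here). A minimal sketch of the merge behavior the router's comment describes, purely as an assumption about that helper:

def merge_processing_params(metadata_params: dict | None, request_params: dict | None) -> dict:
    """Sketch only: request params take precedence; metadata params fill in anything missing."""
    merged = dict(metadata_params or {})
    for key, value in (request_params or {}).items():
        if value is not None:  # assumption: an explicit None means "not provided"
            merged[key] = value
    return merged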
118 changes: 115 additions & 3 deletions src/knowledge/implementations/chroma.py
@@ -187,7 +187,7 @@ async def add_content(self, db_id: str, items: list[str], params: dict | None) -

for item in items:
# Prepare file metadata
metadata = prepare_item_metadata(item, content_type, db_id)
metadata = prepare_item_metadata(item, content_type, db_id, params=params)
file_id = metadata["file_id"]
filename = metadata["filename"]

@@ -252,6 +252,112 @@ async def add_content(self, db_id: str, items: list[str], params: dict | None) -

return processed_items_info

async def update_content(self, db_id: str, file_ids: list[str], params: dict | None = None) -> list[dict]:
"""更新内容 - 根据file_ids重新解析文件并更新向量库"""
if db_id not in self.databases_meta:
raise ValueError(f"Database {db_id} not found")

collection = await self._get_chroma_collection(db_id)
if not collection:
raise ValueError(f"Failed to get ChromaDB collection for {db_id}")

# Handle default params
if params is None:
params = {}
content_type = params.get("content_type", "file")
processed_items_info = []

for file_id in file_ids:
# Look up the file info from metadata
if file_id not in self.files_meta:
logger.warning(f"File {file_id} not found in metadata, skipping")
continue

file_meta = self.files_meta[file_id]
file_path = file_meta.get("path")
filename = file_meta.get("filename")

if not file_path:
logger.warning(f"File path not found for {file_id}, skipping")
continue

# Add to the processing queue
self._add_to_processing_queue(file_id)

try:
# Record the params and mark the file as processing
self.files_meta[file_id]["processing_params"] = params.copy()
self.files_meta[file_id]["status"] = "processing"
self._save_metadata()

# Re-parse the file into markdown
if content_type == "file":
markdown_content = await process_file_to_markdown(file_path, params=params)
else:
markdown_content = await process_url_to_markdown(file_path, params=params)

# First delete the existing ChromaDB data (chunks only; metadata is kept)
await self.delete_file_chunks_only(db_id, file_id)

# Regenerate chunks
chunks = self._split_text_into_chunks(markdown_content, file_id, filename, params)
logger.info(f"Split {filename} into {len(chunks)} chunks")

if chunks:
documents = [chunk["content"] for chunk in chunks]
metadatas = [chunk["metadata"] for chunk in chunks]
ids = [chunk["id"] for chunk in chunks]

# Insert into ChromaDB - process in batches to stay under the OpenAI batch size limit
batch_size = 64 # OpenAI's maximum batch size limit
total_batches = (len(chunks) + batch_size - 1) // batch_size

for i in range(0, len(chunks), batch_size):
batch_documents = documents[i : i + batch_size]
batch_metadatas = metadatas[i : i + batch_size]
batch_ids = ids[i : i + batch_size]

await asyncio.to_thread(
collection.add,
documents=batch_documents,
metadatas=batch_metadatas,
ids=batch_ids,
)

batch_num = i // batch_size + 1
logger.info(f"Processed batch {batch_num}/{total_batches} for {filename}")

logger.info(f"Updated {content_type} {file_path} in ChromaDB. Done.")

# Update metadata status
self.files_meta[file_id]["status"] = "done"
self._save_metadata()

# Remove from the processing queue
self._remove_from_processing_queue(file_id)

# Collect the updated file info for the return value
updated_file_meta = file_meta.copy()
updated_file_meta["status"] = "done"
updated_file_meta["file_id"] = file_id
processed_items_info.append(updated_file_meta)

except Exception as e:
logger.error(f"更新{content_type} {file_path} 失败: {e}, {traceback.format_exc()}")
self.files_meta[file_id]["status"] = "failed"
self._save_metadata()

# Remove from the processing queue
self._remove_from_processing_queue(file_id)

# Collect the failed file info for the return value
failed_file_meta = file_meta.copy()
failed_file_meta["status"] = "failed"
failed_file_meta["file_id"] = file_id
processed_items_info.append(failed_file_meta)

return processed_items_info

async def aquery(self, query_text: str, db_id: str, **kwargs) -> list[dict]:
"""异步查询知识库"""
collection = await self._get_chroma_collection(db_id)
@@ -346,8 +452,8 @@ async def aquery(self, query_text: str, db_id: str, **kwargs) -> list[dict]:
logger.error(f"ChromaDB query error: {e}, {traceback.format_exc()}")
return []

async def delete_file(self, db_id: str, file_id: str) -> None:
"""删除文件"""
async def delete_file_chunks_only(self, db_id: str, file_id: str) -> None:
"""仅删除文件的chunks数据,保留元数据(用于更新操作)"""
collection = await self._get_chroma_collection(db_id)
if collection:
try:
@@ -361,6 +467,12 @@ async def delete_file(self, db_id: str, file_id: str) -> None:

except Exception as e:
logger.error(f"Error deleting file {file_id} from ChromaDB: {e}")
# Note: files_meta[file_id] is not deleted here; the metadata is kept for subsequent operations

async def delete_file(self, db_id: str, file_id: str) -> None:
"""删除文件(包括元数据)"""
# 先删除 ChromaDB 中的 chunks 数据
await self.delete_file_chunks_only(db_id, file_id)

# Delete the file record
if file_id in self.files_meta:
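The batching loop in update_content mirrors add_content: chunks are written with collection.add in groups of batch_size so a single call never exceeds the embedding backend's batch limit. A standalone illustration of the same arithmetic (the chunk count is invented for the example):

chunks = [f"chunk-{i}" for i in range(150)]  # 150 hypothetical chunks
batch_size = 64
total_batches = (len(chunks) + batch_size - 1) // batch_size  # ceiling division -> 3

for i in range(0, len(chunks), batch_size):
    batch = chunks[i : i + batch_size]  # batch sizes: 64, 64, 22
    batch_num = i // batch_size + 1
    print(f"batch {batch_num}/{total_batches}: {len(batch)} chunks")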