From 9cd64708499b7cce29c016fb8450a04fdd3a179f Mon Sep 17 00:00:00 2001 From: holenzzz Date: Thu, 26 Feb 2026 12:43:38 +0800 Subject: [PATCH] Fix embedding 400 errors --- src/knowledge/implementations/milvus.py | 4 +-- .../services/upload_graph_service.py | 2 +- src/models/embed.py | 31 ++++++++++++++----- src/services/evaluation_service.py | 2 +- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/knowledge/implementations/milvus.py b/src/knowledge/implementations/milvus.py index 9896e4c29..f72ac80ba 100644 --- a/src/knowledge/implementations/milvus.py +++ b/src/knowledge/implementations/milvus.py @@ -191,13 +191,13 @@ def _get_async_embedding(self, embed_info: dict): def _get_async_embedding_function(self, embed_info: dict): """获取 embedding 函数""" embedding_model = self._get_async_embedding(embed_info) - return partial(embedding_model.abatch_encode, batch_size=40) + return partial(embedding_model.abatch_encode, batch_size=10) def _get_embedding_function(self, embed_info: dict): """获取 embedding 函数""" embedding_model = self._get_async_embedding(embed_info) - return partial(embedding_model.batch_encode, batch_size=40) + return partial(embedding_model.batch_encode, batch_size=10) async def _get_milvus_collection(self, db_id: str): """获取或创建 Milvus 集合""" diff --git a/src/knowledge/services/upload_graph_service.py b/src/knowledge/services/upload_graph_service.py index da3556ae2..e1480cd66 100644 --- a/src/knowledge/services/upload_graph_service.py +++ b/src/knowledge/services/upload_graph_service.py @@ -488,7 +488,7 @@ def load_graph_info(self): logger.error(f"加载图数据库信息失败:{e}") return False - async def aget_embedding(self, text, batch_size=40): + async def aget_embedding(self, text, batch_size=10): if isinstance(text, list): outputs = await self.embed_model.abatch_encode(text, batch_size=batch_size) return outputs diff --git a/src/models/embed.py b/src/models/embed.py index a63bbc587..08f42e4ab 100644 --- a/src/models/embed.py +++ b/src/models/embed.py @@ -46,7 +46,7 @@ async def aencode_queries(self, queries: list[str] | str) -> list[list[float]]: """等同于aencode""" return await self.aencode(queries) - def batch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[float]]: + def batch_encode(self, messages: list[str], batch_size: int = 10) -> list[list[float]]: # logger.info(f"Batch encoding {len(messages)} messages") data = [] task_id = None @@ -67,24 +67,39 @@ def batch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[f return data - async def abatch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[float]]: + async def abatch_encode(self, messages: list[str], batch_size: int = 10) -> list[list[float]]: data = [] task_id = None if len(messages) > batch_size: task_id = hashstr(messages) self.embed_state[task_id] = {"status": "in-progress", "total": len(messages), "progress": 0} - tasks = [] + #保留原有逻辑: + #使用 asyncio.gather 并发执行所有 embedding 批次请求: + # tasks = [] + # for i in range(0, len(messages), batch_size): + # group_msg = messages[i : i + batch_size] + # tasks.append(self.aencode(group_msg)) + + # results = await asyncio.gather(*tasks) + # for res in results: + # data.extend(res) + + # if task_id: + # self.embed_state[task_id]["progress"] = len(messages) + # self.embed_state[task_id]["status"] = "completed" + + # return data + for i in range(0, len(messages), batch_size): group_msg = messages[i : i + batch_size] - tasks.append(self.aencode(group_msg)) - - results = await asyncio.gather(*tasks) - for res in results: + logger.info(f"Async encoding [{i}/{len(messages)}] messages (bsz={batch_size})") + res = await self.aencode(group_msg) data.extend(res) + if task_id: + self.embed_state[task_id]["progress"] = i + len(group_msg) if task_id: - self.embed_state[task_id]["progress"] = len(messages) self.embed_state[task_id]["status"] = "completed" return data diff --git a/src/services/evaluation_service.py b/src/services/evaluation_service.py index e1b6968bc..836229b6c 100644 --- a/src/services/evaluation_service.py +++ b/src/services/evaluation_service.py @@ -336,7 +336,7 @@ async def _generate_benchmark_task(self, context: TaskContext): # Currently, we re-calculate embeddings for ALL chunks in the KB for every benchmark generation. # This is inefficient for large KBs (O(N) embedding calls). # Optimization: Reuse existing embeddings from Vector DB if embedding_model_id matches the KB's embedding model. - embeddings = await embed_model.abatch_encode(contents, batch_size=40) + embeddings = await embed_model.abatch_encode(contents, batch_size=10) norms = [math.sqrt(sum(x * x for x in vec)) or 1.0 for vec in embeddings] def cosine(a, b, na, nb):