4 changes: 2 additions & 2 deletions src/knowledge/implementations/milvus.py
@@ -191,13 +191,13 @@ def _get_async_embedding(self, embed_info: dict):
     def _get_async_embedding_function(self, embed_info: dict):
         """Get the embedding function."""
         embedding_model = self._get_async_embedding(embed_info)
-        return partial(embedding_model.abatch_encode, batch_size=40)
+        return partial(embedding_model.abatch_encode, batch_size=10)

     def _get_embedding_function(self, embed_info: dict):
         """Get the embedding function."""
         embedding_model = self._get_async_embedding(embed_info)

-        return partial(embedding_model.batch_encode, batch_size=40)
+        return partial(embedding_model.batch_encode, batch_size=10)

     async def _get_milvus_collection(self, db_id: str):
         """Get or create the Milvus collection."""
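For context, `partial` pre-binds `batch_size` so downstream code calls the returned embedding function with only the texts. A minimal sketch of the pattern, with a hypothetical stand-in encoder (`FakeEmbedder` is illustrative, not this repo's model class):

```python
import asyncio
from functools import partial

class FakeEmbedder:
    """Stand-in for the real embedding model; returns fixed-size vectors."""
    async def abatch_encode(self, messages, batch_size=10):
        return [[0.0] * 3 for _ in messages]

async def main():
    # batch_size is frozen at 10; callers only pass the texts.
    embed_fn = partial(FakeEmbedder().abatch_encode, batch_size=10)
    vectors = await embed_fn(["hello", "world"])
    print(len(vectors))  # 2

asyncio.run(main())
```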
2 changes: 1 addition & 1 deletion src/knowledge/services/upload_graph_service.py
@@ -488,7 +488,7 @@ def load_graph_info(self):
             logger.error(f"Failed to load graph database info: {e}")
             return False

-    async def aget_embedding(self, text, batch_size=40):
+    async def aget_embedding(self, text, batch_size=10):
         if isinstance(text, list):
             outputs = await self.embed_model.abatch_encode(text, batch_size=batch_size)
             return outputs
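The non-list branch of `aget_embedding` is cut off by the diff. A sketch of the function's likely shape, where the single-string path is an assumption rather than the PR's verbatim code:

```python
async def aget_embedding(self, text, batch_size=10):
    if isinstance(text, list):
        # List input is embedded in groups of batch_size.
        return await self.embed_model.abatch_encode(text, batch_size=batch_size)
    # Assumed (branch truncated in this diff): a single string is encoded directly.
    return await self.embed_model.aencode(text)
```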
31 changes: 23 additions & 8 deletions src/models/embed.py
@@ -46,7 +46,7 @@ async def aencode_queries(self, queries: list[str] | str) -> list[list[float]]:
         """Equivalent to aencode."""
         return await self.aencode(queries)

-    def batch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[float]]:
+    def batch_encode(self, messages: list[str], batch_size: int = 10) -> list[list[float]]:
         # logger.info(f"Batch encoding {len(messages)} messages")
         data = []
         task_id = None
@@ -67,24 +67,39 @@ def batch_encode(self, messages: list[str], batch_size: int = 10) -> list[list[float]]:

         return data

-    async def abatch_encode(self, messages: list[str], batch_size: int = 40) -> list[list[float]]:
+    async def abatch_encode(self, messages: list[str], batch_size: int = 10) -> list[list[float]]:
         data = []
         task_id = None
         if len(messages) > batch_size:
             task_id = hashstr(messages)
             self.embed_state[task_id] = {"status": "in-progress", "total": len(messages), "progress": 0}

-        tasks = []
+        # Original logic, kept for reference:
+        # it used asyncio.gather to run all embedding batch requests concurrently.
+        # tasks = []
+        # for i in range(0, len(messages), batch_size):
+        #     group_msg = messages[i : i + batch_size]
+        #     tasks.append(self.aencode(group_msg))

+        # results = await asyncio.gather(*tasks)
+        # for res in results:
+        #     data.extend(res)

+        # if task_id:
+        #     self.embed_state[task_id]["progress"] = len(messages)
+        #     self.embed_state[task_id]["status"] = "completed"

+        # return data

         for i in range(0, len(messages), batch_size):
             group_msg = messages[i : i + batch_size]
-            tasks.append(self.aencode(group_msg))

-        results = await asyncio.gather(*tasks)
-        for res in results:
+            logger.info(f"Async encoding [{i}/{len(messages)}] messages (bsz={batch_size})")
+            res = await self.aencode(group_msg)
             data.extend(res)
+            if task_id:
+                self.embed_state[task_id]["progress"] = i + len(group_msg)

         if task_id:
             self.embed_state[task_id]["progress"] = len(messages)
             self.embed_state[task_id]["status"] = "completed"

         return data
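This change trades the fully concurrent `asyncio.gather` fan-out for strictly sequential batches: at most one batch is in flight at a time, which is gentler on API rate limits and enables incremental progress reporting, at the cost of throughput. A possible middle ground, sketched here as a standalone function rather than the PR's code (`max_in_flight` is an illustrative parameter), is to bound concurrency with a semaphore:

```python
import asyncio

async def bounded_batch_encode(aencode, messages, batch_size=10, max_in_flight=4):
    """Encode batches concurrently, but never more than max_in_flight at once."""
    sem = asyncio.Semaphore(max_in_flight)

    async def encode_group(group):
        async with sem:  # limits simultaneous requests to the embedding API
            return await aencode(group)

    groups = [messages[i:i + batch_size] for i in range(0, len(messages), batch_size)]
    results = await asyncio.gather(*(encode_group(g) for g in groups))
    # gather preserves input order, so flattening keeps vectors aligned with messages
    return [vec for res in results for vec in res]
```

With `max_in_flight=1` this degenerates to the sequential behavior the PR adopts; larger values recover some of the old throughput while still capping load.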
2 changes: 1 addition & 1 deletion src/services/evaluation_service.py
@@ -336,7 +336,7 @@ async def _generate_benchmark_task(self, context: TaskContext):
         # Currently, we re-calculate embeddings for ALL chunks in the KB for every benchmark generation.
         # This is inefficient for large KBs (O(N) embedding calls).
         # Optimization: reuse existing embeddings from the vector DB if embedding_model_id matches the KB's embedding model.
-        embeddings = await embed_model.abatch_encode(contents, batch_size=40)
+        embeddings = await embed_model.abatch_encode(contents, batch_size=10)
         norms = [math.sqrt(sum(x * x for x in vec)) or 1.0 for vec in embeddings]

         def cosine(a, b, na, nb):
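The body of the `cosine` helper is cut off by the diff. Given that `norms` precomputes each vector's magnitude (clamped to 1.0 to avoid division by zero), a plausible completion is the following; this is an assumption, not the file's verbatim code:

```python
def cosine(a, b, na, nb):
    # na and nb are the precomputed norms of a and b (see the norms list above).
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (na * nb)
```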