197 changes: 152 additions & 45 deletions src/knowledge/implementations/milvus.py
@@ -1,5 +1,6 @@
import asyncio
import os
import re
import time
import traceback
from functools import partial
@@ -477,6 +478,9 @@ async def aquery(self, query_text: str, db_id: str, agent_call: bool = False, **
similarity_threshold = float(merged_kwargs.get("similarity_threshold", 0.2))
metric_type = merged_kwargs.get("metric_type", "COSINE")
include_distances = bool(merged_kwargs.get("include_distances", True))
search_mode = str(merged_kwargs.get("search_mode", "vector")).lower()
if search_mode not in {"vector", "keyword", "hybrid"}:
search_mode = "vector"

use_reranker = bool(merged_kwargs.get("use_reranker", False))
if use_reranker:
@@ -485,58 +489,140 @@ async def aquery(self, query_text: str, db_id: str, agent_call: bool = False, **
else:
recall_top_k = final_top_k

embed_info = self.databases_meta[db_id].get("embed_info", {})
embedding_function = self._get_embedding_function(embed_info)
query_embedding = embedding_function([query_text])

search_params = {"metric_type": metric_type, "params": {"nprobe": 10}}

# Build the filter expression
expr = None
# Build the filter expression (file name)
file_expr = None
if file_name := merged_kwargs.get("file_name"):
# Use `like` to support fuzzy matching
# Note: double quotes must be escaped to prevent expression injection
safe_file_name = file_name.replace('"', '\\"')
# If no wildcard is provided, wrap the name with % on both sides by default
if "%" not in safe_file_name:
expr = f'source like "%{safe_file_name}%"'
file_expr = f'source like "%{safe_file_name}%"'
else:
expr = f'source like "{safe_file_name}"'
logger.debug(f"Using filter expression: {expr}")

results = collection.search(
data=query_embedding,
anns_field="embedding",
param=search_params,
limit=recall_top_k,
expr=expr,
output_fields=["content", "source", "chunk_id", "file_id", "chunk_index"],
)

if not results or len(results) == 0 or len(results[0]) == 0:
return []

retrieved_chunks = []
for hit in results[0]:
similarity = hit.distance if metric_type == "COSINE" else 1 / (1 + hit.distance)

if similarity < similarity_threshold:
continue
file_expr = f'source like "{safe_file_name}"'
logger.debug(f"Using filter expression: {file_expr}")

vector_results: list[dict] = []
if search_mode in {"vector", "hybrid"}:
embed_info = self.databases_meta[db_id].get("embed_info", {})
embedding_function = self._get_embedding_function(embed_info)
query_embedding = embedding_function([query_text])

search_params = {"metric_type": metric_type, "params": {"nprobe": 10}}

results = collection.search(
data=query_embedding,
anns_field="embedding",
param=search_params,
limit=recall_top_k,
expr=file_expr,
output_fields=["content", "source", "chunk_id", "file_id", "chunk_index"],
)

entity = hit.entity
metadata = {
"source": entity.get("source", "未知来源"),
"chunk_id": entity.get("chunk_id"),
"file_id": entity.get("file_id"),
"chunk_index": entity.get("chunk_index"),
}
if results and len(results) > 0 and len(results[0]) > 0:
for hit in results[0]:
similarity = hit.distance if metric_type == "COSINE" else 1 / (1 + hit.distance)
if similarity < similarity_threshold:
continue

entity = hit.entity
metadata = {
"source": entity.get("source", "未知来源"),
"chunk_id": entity.get("chunk_id"),
"file_id": entity.get("file_id"),
"chunk_index": entity.get("chunk_index"),
}

chunk = {"content": entity.get("content", ""), "metadata": metadata, "score": similarity}
if include_distances:
chunk["distance"] = hit.distance
vector_results.append(chunk)

logger.debug(
f"Milvus vector query response: {len(vector_results)} chunks found (after similarity filtering)"
)

chunk = {"content": entity.get("content", ""), "metadata": metadata, "score": similarity}
if include_distances:
chunk["distance"] = hit.distance
retrieved_chunks.append(chunk)
keyword_results: list[dict] = []
if search_mode in {"keyword", "hybrid"}:
keyword_top_k = int(merged_kwargs.get("keyword_top_k", final_top_k))
keyword_top_k = max(keyword_top_k, 1)
raw_keywords = re.split(r"[\s,,;;]+", str(query_text))
keywords = [kw.strip() for kw in raw_keywords if kw and kw.strip()]

if keywords:
keyword_clauses = []
for kw in keywords:
safe_kw = kw.replace('"', '\\"')
keyword_clauses.append(f'content like "%{safe_kw}%"')

keyword_expr = " or ".join(keyword_clauses)
if file_expr:
keyword_expr = f"({keyword_expr}) and ({file_expr})"

results = collection.query(
expr=keyword_expr,
output_fields=["content", "source", "chunk_id", "file_id", "chunk_index"],
limit=keyword_top_k,
)

logger.debug(f"Milvus query response: {len(retrieved_chunks)} chunks found (after similarity filtering)")
keyword_scores = []
for result in results or []:
content = result.get("content", "")
text_lower = content.lower()
match_count = sum(text_lower.count(kw.lower()) for kw in keywords if kw)
if match_count <= 0:
continue

metadata = {
"source": result.get("source", "未知来源"),
"chunk_id": result.get("chunk_id"),
"file_id": result.get("file_id"),
"chunk_index": result.get("chunk_index"),
}
keyword_scores.append((result, metadata, match_count))

if keyword_scores:
max_count = max(item[2] for item in keyword_scores)
for result, metadata, match_count in keyword_scores:
score = match_count / max_count if max_count > 0 else 0.0
keyword_results.append(
{
"content": result.get("content", ""),
"metadata": metadata,
"score": score,
"keyword_score": score,
}
)
keyword_results.sort(key=lambda item: item.get("score", 0.0), reverse=True)

logger.debug(f"Milvus keyword query response: {len(keyword_results)} chunks found")

if search_mode == "vector":
retrieved_chunks = vector_results
elif search_mode == "keyword":
retrieved_chunks = keyword_results
else:
merged: dict[str, dict] = {}
for item in vector_results:
chunk_id = item.get("metadata", {}).get("chunk_id")
if chunk_id:
merged[chunk_id] = item
else:
merged[id(item)] = item

for item in keyword_results:
chunk_id = item.get("metadata", {}).get("chunk_id")
if chunk_id in merged:
existing = merged[chunk_id]
keyword_score = item.get("keyword_score", item.get("score", 0.0))
existing_score = existing.get("score", 0.0)
existing["score"] = max(existing_score, keyword_score)
existing["keyword_score"] = keyword_score
else:
merged[chunk_id or id(item)] = item

retrieved_chunks = list(merged.values())
retrieved_chunks.sort(key=lambda item: item.get("score", 0.0), reverse=True)

if not retrieved_chunks:
return []

if not use_reranker:
return retrieved_chunks[:final_top_k]
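Aside: a minimal, self-contained sketch of the keyword scoring and hybrid merge added in the hunk above. The function names are hypothetical; the splitting regex and max-count normalisation mirror the diff, while the merge is a simplified variant that keeps whichever duplicate scored higher instead of annotating the vector hit with `keyword_score` as the PR does.

```python
import re

def score_keywords(query: str, chunks: list[dict]) -> list[dict]:
    # Split the query on whitespace and ASCII/CJK commas and semicolons, as in the diff.
    keywords = [kw.strip() for kw in re.split(r"[\s,,;;]+", query) if kw.strip()]
    scored = []
    for chunk in chunks:
        text = chunk.get("content", "").lower()
        count = sum(text.count(kw.lower()) for kw in keywords)
        if count > 0:  # chunks with no keyword hits are dropped
            scored.append((chunk, count))
    if not scored:
        return []
    max_count = max(count for _, count in scored)
    # Normalise by the best match so keyword scores land in [0, 1].
    return [
        {**chunk, "score": count / max_count, "keyword_score": count / max_count}
        for chunk, count in scored
    ]

def merge_hybrid(vector_results: list[dict], keyword_results: list[dict]) -> list[dict]:
    # Deduplicate by chunk_id (falling back to object identity), keep the higher score,
    # then sort descending -- the same ordering the diff applies before reranking.
    merged: dict = {}
    for item in vector_results + keyword_results:
        key = item.get("metadata", {}).get("chunk_id") or id(item)
        if key not in merged or item.get("score", 0.0) > merged[key].get("score", 0.0):
            merged[key] = item
    return sorted(merged.values(), key=lambda it: it.get("score", 0.0), reverse=True)
```

For example, `score_keywords("vector database", [{"content": "a vector database"}, {"content": "unrelated"}])` yields one chunk with score 1.0 and drops the other.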
@@ -703,6 +789,18 @@ def get_query_params_config(self, db_id: str, **kwargs) -> dict:
"""获取 Milvus 知识库的查询参数配置"""
# Build Milvus-specific parameters (no longer read from reranker_config)
options = [
{
"key": "search_mode",
"label": "检索模式",
"type": "select",
"default": "vector",
"options": [
{"value": "vector", "label": "向量检索", "description": "仅使用向量相似度检索"},
{"value": "keyword", "label": "关键词检索", "description": "仅使用关键词匹配检索"},
{"value": "hybrid", "label": "混合检索", "description": "向量检索与关键词检索融合"},
],
"description": "选择检索模式",
},
{
"key": "final_top_k",
"label": "最终返回数",
@@ -722,6 +820,15 @@ def get_query_params_config(self, db_id: str, **kwargs) -> dict:
"step": 0.1,
"description": "过滤相似度低于此值的结果",
},
{
"key": "keyword_top_k",
"label": "关键词召回数量",
"type": "number",
"default": 50,
"min": 1,
"max": 200,
"description": "关键词/混合检索时的候选数量",
},
{
"key": "include_distances",
"label": "显示相似度",
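For orientation, a hedged usage sketch of how the new parameters surface at query time. The `kb` object and its construction are assumptions (backend wiring is outside this diff); the keyword arguments, modes, and defaults are taken from `aquery` and `get_query_params_config` above.

```python
import asyncio

async def demo(kb, db_id: str) -> None:
    # Default behaviour: pure vector search, unchanged by this PR.
    vector_hits = await kb.aquery("retrieval augmented generation", db_id)

    # Keyword-only: the query is split on spaces/commas/semicolons and matched
    # against chunk content with `like` expressions; keyword_top_k caps candidates.
    keyword_hits = await kb.aquery(
        "retrieval, generation", db_id,
        search_mode="keyword", keyword_top_k=50,
    )

    # Hybrid: vector and keyword candidates are merged by chunk_id,
    # the higher score wins, and the result is sorted before top-k / reranking.
    hybrid_hits = await kb.aquery(
        "retrieval augmented generation", db_id,
        search_mode="hybrid", final_top_k=5, similarity_threshold=0.2,
    )
    print(len(vector_hits), len(keyword_hits), len(hybrid_hits))

# asyncio.run(demo(kb, "my_db_id"))  # construction of `kb` is backend-specific and omitted
```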