From d18468a16ca8e43d5f2e9790453b710a96f23998 Mon Sep 17 00:00:00 2001 From: jonasHanhan Date: Mon, 26 Jan 2026 16:10:57 +0800 Subject: [PATCH] feat: add keyword and hybrid retrieval for milvus --- src/knowledge/implementations/milvus.py | 197 ++++++++++++++++++------ 1 file changed, 152 insertions(+), 45 deletions(-) diff --git a/src/knowledge/implementations/milvus.py b/src/knowledge/implementations/milvus.py index 0a05ebd33..e50bf3b56 100644 --- a/src/knowledge/implementations/milvus.py +++ b/src/knowledge/implementations/milvus.py @@ -1,5 +1,6 @@ import asyncio import os +import re import time import traceback from functools import partial @@ -477,6 +478,9 @@ async def aquery(self, query_text: str, db_id: str, agent_call: bool = False, ** similarity_threshold = float(merged_kwargs.get("similarity_threshold", 0.2)) metric_type = merged_kwargs.get("metric_type", "COSINE") include_distances = bool(merged_kwargs.get("include_distances", True)) + search_mode = str(merged_kwargs.get("search_mode", "vector")).lower() + if search_mode not in {"vector", "keyword", "hybrid"}: + search_mode = "vector" use_reranker = bool(merged_kwargs.get("use_reranker", False)) if use_reranker: @@ -485,58 +489,140 @@ async def aquery(self, query_text: str, db_id: str, agent_call: bool = False, ** else: recall_top_k = final_top_k - embed_info = self.databases_meta[db_id].get("embed_info", {}) - embedding_function = self._get_embedding_function(embed_info) - query_embedding = embedding_function([query_text]) - - search_params = {"metric_type": metric_type, "params": {"nprobe": 10}} - - # 构建过滤表达式 - expr = None + # 构建过滤表达式(文件名) + file_expr = None if file_name := merged_kwargs.get("file_name"): - # 使用 like 支持模糊匹配 - # 注意:需要转义双引号以防止注入 safe_file_name = file_name.replace('"', '\\"') - # 如果没有提供通配符,默认前后添加 % if "%" not in safe_file_name: - expr = f'source like "%{safe_file_name}%"' + file_expr = f'source like "%{safe_file_name}%"' else: - expr = f'source like "{safe_file_name}"' - logger.debug(f"Using filter expression: {expr}") - - results = collection.search( - data=query_embedding, - anns_field="embedding", - param=search_params, - limit=recall_top_k, - expr=expr, - output_fields=["content", "source", "chunk_id", "file_id", "chunk_index"], - ) - - if not results or len(results) == 0 or len(results[0]) == 0: - return [] - - retrieved_chunks = [] - for hit in results[0]: - similarity = hit.distance if metric_type == "COSINE" else 1 / (1 + hit.distance) - - if similarity < similarity_threshold: - continue + file_expr = f'source like "{safe_file_name}"' + logger.debug(f"Using filter expression: {file_expr}") + + vector_results: list[dict] = [] + if search_mode in {"vector", "hybrid"}: + embed_info = self.databases_meta[db_id].get("embed_info", {}) + embedding_function = self._get_embedding_function(embed_info) + query_embedding = embedding_function([query_text]) + + search_params = {"metric_type": metric_type, "params": {"nprobe": 10}} + + results = collection.search( + data=query_embedding, + anns_field="embedding", + param=search_params, + limit=recall_top_k, + expr=file_expr, + output_fields=["content", "source", "chunk_id", "file_id", "chunk_index"], + ) - entity = hit.entity - metadata = { - "source": entity.get("source", "未知来源"), - "chunk_id": entity.get("chunk_id"), - "file_id": entity.get("file_id"), - "chunk_index": entity.get("chunk_index"), - } + if results and len(results) > 0 and len(results[0]) > 0: + for hit in results[0]: + similarity = hit.distance if metric_type == "COSINE" else 1 / (1 + hit.distance) + if similarity < similarity_threshold: + continue + + entity = hit.entity + metadata = { + "source": entity.get("source", "未知来源"), + "chunk_id": entity.get("chunk_id"), + "file_id": entity.get("file_id"), + "chunk_index": entity.get("chunk_index"), + } + + chunk = {"content": entity.get("content", ""), "metadata": metadata, "score": similarity} + if include_distances: + chunk["distance"] = hit.distance + vector_results.append(chunk) + + logger.debug( + f"Milvus vector query response: {len(vector_results)} chunks found (after similarity filtering)" + ) - chunk = {"content": entity.get("content", ""), "metadata": metadata, "score": similarity} - if include_distances: - chunk["distance"] = hit.distance - retrieved_chunks.append(chunk) + keyword_results: list[dict] = [] + if search_mode in {"keyword", "hybrid"}: + keyword_top_k = int(merged_kwargs.get("keyword_top_k", final_top_k)) + keyword_top_k = max(keyword_top_k, 1) + raw_keywords = re.split(r"[\s,,;;]+", str(query_text)) + keywords = [kw.strip() for kw in raw_keywords if kw and kw.strip()] + + if keywords: + keyword_clauses = [] + for kw in keywords: + safe_kw = kw.replace('"', '\\"') + keyword_clauses.append(f'content like "%{safe_kw}%"') + + keyword_expr = " or ".join(keyword_clauses) + if file_expr: + keyword_expr = f"({keyword_expr}) and ({file_expr})" + + results = collection.query( + expr=keyword_expr, + output_fields=["content", "source", "chunk_id", "file_id", "chunk_index"], + limit=keyword_top_k, + ) - logger.debug(f"Milvus query response: {len(retrieved_chunks)} chunks found (after similarity filtering)") + keyword_scores = [] + for result in results or []: + content = result.get("content", "") + text_lower = content.lower() + match_count = sum(text_lower.count(kw.lower()) for kw in keywords if kw) + if match_count <= 0: + continue + + metadata = { + "source": result.get("source", "未知来源"), + "chunk_id": result.get("chunk_id"), + "file_id": result.get("file_id"), + "chunk_index": result.get("chunk_index"), + } + keyword_scores.append((result, metadata, match_count)) + + if keyword_scores: + max_count = max(item[2] for item in keyword_scores) + for result, metadata, match_count in keyword_scores: + score = match_count / max_count if max_count > 0 else 0.0 + keyword_results.append( + { + "content": result.get("content", ""), + "metadata": metadata, + "score": score, + "keyword_score": score, + } + ) + keyword_results.sort(key=lambda item: item.get("score", 0.0), reverse=True) + + logger.debug(f"Milvus keyword query response: {len(keyword_results)} chunks found") + + if search_mode == "vector": + retrieved_chunks = vector_results + elif search_mode == "keyword": + retrieved_chunks = keyword_results + else: + merged: dict[str, dict] = {} + for item in vector_results: + chunk_id = item.get("metadata", {}).get("chunk_id") + if chunk_id: + merged[chunk_id] = item + else: + merged[id(item)] = item + + for item in keyword_results: + chunk_id = item.get("metadata", {}).get("chunk_id") + if chunk_id in merged: + existing = merged[chunk_id] + keyword_score = item.get("keyword_score", item.get("score", 0.0)) + existing_score = existing.get("score", 0.0) + existing["score"] = max(existing_score, keyword_score) + existing["keyword_score"] = keyword_score + else: + merged[chunk_id or id(item)] = item + + retrieved_chunks = list(merged.values()) + retrieved_chunks.sort(key=lambda item: item.get("score", 0.0), reverse=True) + + if not retrieved_chunks: + return [] if not use_reranker: return retrieved_chunks[:final_top_k] @@ -703,6 +789,18 @@ def get_query_params_config(self, db_id: str, **kwargs) -> dict: """获取 Milvus 知识库的查询参数配置""" # 构建 Milvus 特定参数(不再从 reranker_config 读取) options = [ + { + "key": "search_mode", + "label": "检索模式", + "type": "select", + "default": "vector", + "options": [ + {"value": "vector", "label": "向量检索", "description": "仅使用向量相似度检索"}, + {"value": "keyword", "label": "关键词检索", "description": "仅使用关键词匹配检索"}, + {"value": "hybrid", "label": "混合检索", "description": "向量检索与关键词检索融合"}, + ], + "description": "选择检索模式", + }, { "key": "final_top_k", "label": "最终返回数", @@ -722,6 +820,15 @@ def get_query_params_config(self, db_id: str, **kwargs) -> dict: "step": 0.1, "description": "过滤相似度低于此值的结果", }, + { + "key": "keyword_top_k", + "label": "关键词召回数量", + "type": "number", + "default": 50, + "min": 1, + "max": 200, + "description": "关键词/混合检索时的候选数量", + }, { "key": "include_distances", "label": "显示相似度",