Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 123 additions & 100 deletions src/agentic_layer/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@
MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository,
}

# MemoryType -> Milvus repository class mapping.
# Vector search resolves each requested memory type through
# MILVUS_REPO_MAP.get(mem_type); a type with no entry here is logged at
# info level and skipped instead of raising.
MILVUS_REPO_MAP = {
MemoryType.FORESIGHT: ForesightMilvusRepository,
MemoryType.EVENT_LOG: EventLogMilvusRepository,
MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository,
}


@dataclass
class EventLogCandidate:
Expand Down Expand Up @@ -299,7 +306,7 @@ async def retrieve_mem_keyword(
"""Keyword-based memory retrieval"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -340,7 +347,7 @@ async def get_keyword_search_results(
"""Keyword search with stage-level metrics"""
stage_start = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -375,32 +382,35 @@ async def get_keyword_search_results(
if end_time is not None:
date_range["lte"] = end_time

mem_type = memory_types[0]

repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.warning(f"Unsupported memory_type: {mem_type}")
return []
all_results = []
for mem_type in memory_types:
repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.info(
f"Skipping unsupported memory_type for keyword search: {mem_type}"
)
continue

es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")
es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")

results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)
results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)

# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
all_results.extend(results)

# Record stage metrics
record_retrieve_stage(
Expand All @@ -410,7 +420,7 @@ async def get_keyword_search_results(
duration_seconds=time.perf_counter() - stage_start,
)

return results or []
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand All @@ -434,7 +444,7 @@ async def retrieve_mem_vector(
"""Vector-based memory retrieval"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -474,7 +484,7 @@ async def get_vector_search_results(
) -> List[Dict[str, Any]]:
"""Vector search with stage-level metrics (embedding + milvus_search)"""
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand All @@ -497,7 +507,7 @@ async def get_vector_search_results(
top_k = retrieve_mem_request.top_k
start_time = retrieve_mem_request.start_time
end_time = retrieve_mem_request.end_time
mem_type = retrieve_mem_request.memory_types[0]
memory_types = retrieve_mem_request.memory_types

logger.debug(
f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}"
Expand All @@ -506,7 +516,7 @@ async def get_vector_search_results(
# Get vectorization service
vectorize_service = get_vectorize_service()

# Convert query text to vector (embedding stage)
# Convert query text to vector (embedding stage) - only once for all types
logger.debug(f"Starting to vectorize query text: {query}")
embedding_start = time.perf_counter()
query_vector = await vectorize_service.get_embedding(query)
Expand All @@ -521,21 +531,9 @@ async def get_vector_search_results(
f"Query text vectorization completed, vector dimension: {len(query_vector_list)}"
)

# Select Milvus repository based on memory type
match mem_type:
case MemoryType.FORESIGHT:
milvus_repo = get_bean_by_type(ForesightMilvusRepository)
case MemoryType.EVENT_LOG:
milvus_repo = get_bean_by_type(EventLogMilvusRepository)
case MemoryType.EPISODIC_MEMORY:
milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository)
case _:
raise ValueError(f"Unsupported memory type: {mem_type}")

# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None

if start_time is not None:
start_time_dt = (
Expand All @@ -553,62 +551,75 @@ async def get_vector_search_results(
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
# Search all supported memory types
all_results = []
for mem_type in memory_types:
repo_class = MILVUS_REPO_MAP.get(mem_type)
if not repo_class:
logger.info(
f"Skipping unsupported memory_type for vector search: {mem_type}"
)
continue

milvus_repo = get_bean_by_type(repo_class)

# Handle foresight-specific time range
type_start_time = start_time_dt
type_end_time = end_time_dt
current_time_dt = None
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
type_start_time = from_iso_format(
retrieve_mem_request.start_time
)
if retrieve_mem_request.end_time:
type_end_time = from_iso_format(
retrieve_mem_request.end_time
)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(
retrieve_mem_request.current_time
)

# Call Milvus vector search
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=type_start_time,
end_time=type_end_time,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=mem_type.value,
duration_seconds=time.perf_counter() - milvus_start,
)
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
)

for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Milvus already uses 'score', no need to rename
for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
all_results.extend(search_results)

return search_results
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
stage=RetrieveMethod.VECTOR.value,
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
)
record_retrieve_error(
retrieve_method=retrieve_method,
stage=RetrieveMethod.VECTOR.value,
Expand All @@ -625,7 +636,7 @@ async def retrieve_mem_hybrid(
"""Hybrid memory retrieval: keyword + vector + rerank"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -700,7 +711,9 @@ async def _search_hybrid(
) -> List[Dict]:
"""Core hybrid search: keyword + vector + rerank, returns flat list"""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
','.join(mt.value for mt in request.memory_types)
if request.memory_types
else 'unknown'
)
# Run keyword and vector search concurrently
kw_results, vec_results = await asyncio.gather(
Expand All @@ -723,7 +736,9 @@ async def _search_rrf(
) -> List[Dict]:
"""Core RRF search: keyword + vector + RRF fusion, returns flat list"""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
','.join(mt.value for mt in request.memory_types)
if request.memory_types
else 'unknown'
)

# Run keyword and vector search concurrently
Expand Down Expand Up @@ -766,7 +781,11 @@ async def _to_response(
"""Convert flat hits list to grouped RetrieveMemResponse"""
user_id = req.user_id if req else ""
source_type = req.retrieve_method.value
memory_type = req.memory_types[0].value
memory_type = (
','.join(mt.value for mt in req.memory_types)
if req.memory_types
else 'unknown'
)

if not hits:
return RetrieveMemResponse(
Expand Down Expand Up @@ -809,7 +828,7 @@ async def retrieve_mem_rrf(
"""RRF-based memory retrieval: keyword + vector + RRF fusion"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -855,7 +874,11 @@ async def retrieve_mem_agentic(
req = retrieve_mem_request # alias
top_k = req.top_k
config = AgenticConfig()
memory_type = req.memory_types[0].value if req.memory_types else 'unknown'
memory_type = (
','.join(mt.value for mt in req.memory_types)
if req.memory_types
else 'unknown'
)

try:
llm_provider = LLMProvider(
Expand Down
Loading