Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 115 additions & 10 deletions src/api/routes/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import asyncio
import logging
import threading
import time
from collections import defaultdict, deque
from typing import Any, Dict, List

from fastapi import APIRouter, Depends, Request, UploadFile, File
Expand Down Expand Up @@ -50,6 +52,7 @@
logger = logging.getLogger("xmem.api.routes.memory")

_ingest_semaphore = asyncio.Semaphore(5)
_latency_samples: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=200))

router = APIRouter(
prefix="/v1/memory",
Expand Down Expand Up @@ -108,6 +111,40 @@ def _error(request: Request, detail: str, code: int, elapsed_ms: float = 0) -> J
return JSONResponse(content=body.model_dump(), status_code=code)


def _record_latency(mode: str, elapsed_ms: float) -> None:
_latency_samples[mode].append(elapsed_ms)


def _percentile(sorted_values: List[float], percentile: float) -> float:
if not sorted_values:
return 0.0
index = min(len(sorted_values) - 1, int(round((len(sorted_values) - 1) * percentile)))
return round(sorted_values[index], 2)


def _latency_stats() -> Dict[str, Dict[str, float]]:
stats: Dict[str, Dict[str, float]] = {}
for mode, samples in _latency_samples.items():
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Iterating over _latency_samples.items() directly is not thread-safe because _latency_samples is a defaultdict. If a concurrent request records latency for a new mode (causing a new key to be added) while this loop is running, it will raise a RuntimeError: dictionary changed size during iteration.

Consider wrapping the items in a list() to create a snapshot of the keys before iteration.

Suggested change
for mode, samples in _latency_samples.items():
for mode, samples in list(_latency_samples.items()):

values = sorted(samples)
stats[mode] = {
"count": len(values),
"p50_ms": _percentile(values, 0.50),
"p95_ms": _percentile(values, 0.95),
"p99_ms": _percentile(values, 0.99),
}
return stats


async def _timed(mode: str, func, *args, **kwargs):
start = time.perf_counter()
result = func(*args, **kwargs)
if hasattr(result, "__await__"):
result = await result
elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
_record_latency(mode, elapsed_ms)
return result, elapsed_ms


def _detect_chat_provider(*urls: str) -> str:
for url in urls:
lowered = (url or "").lower()
Expand Down Expand Up @@ -145,8 +182,6 @@ async def _render_chat_share(url: str) -> tuple[str, str]:
# reuse it across scrape requests. The browser is thread-safe when each
# request uses its own BrowserContext.

import threading

_browser_lock = threading.Lock()
_pw_instance = None
_browser_instance = None
Expand Down Expand Up @@ -690,15 +725,57 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen

try:
all_results: List[SourceRecord] = []
latency_ms: Dict[str, float] = {}
plan = pipeline.raw_retrieval_plan(req.domains, answer=req.answer)

if "profile" in plan:
results, elapsed = await _timed("profile", _search_profile, pipeline, user_id)
latency_ms["profile"] = elapsed
all_results.extend(results)
if "temporal" in plan:
results, elapsed = await _timed("temporal", _search_temporal, pipeline, req.query, user_id, req.top_k)
latency_ms["temporal"] = elapsed
all_results.extend(results)
if "summary" in plan:
results, elapsed = await _timed("summary", _search_summary, pipeline, req.query, user_id, req.top_k)
latency_ms["summary"] = elapsed
all_results.extend(results)
if "snippet" in plan:
results, elapsed = await _timed("snippet", _search_snippet, pipeline, req.query, user_id, req.top_k)
latency_ms["snippet"] = elapsed
all_results.extend(results)
if "code" in plan:
results, elapsed = await _timed("code", _search_code, pipeline, req.query, user_id, req.top_k)
latency_ms["code"] = elapsed
all_results.extend(results)
Comment on lines +731 to +750
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The search domains are currently queried sequentially, which negates the performance benefits of having a "fast" raw search path. Since these operations are I/O bound, they should be executed in parallel using asyncio.gather to minimize total latency.

Suggested change
if "profile" in plan:
results, elapsed = await _timed("profile", _search_profile, pipeline, user_id)
latency_ms["profile"] = elapsed
all_results.extend(results)
if "temporal" in plan:
results, elapsed = await _timed("temporal", _search_temporal, pipeline, req.query, user_id, req.top_k)
latency_ms["temporal"] = elapsed
all_results.extend(results)
if "summary" in plan:
results, elapsed = await _timed("summary", _search_summary, pipeline, req.query, user_id, req.top_k)
latency_ms["summary"] = elapsed
all_results.extend(results)
if "snippet" in plan:
results, elapsed = await _timed("snippet", _search_snippet, pipeline, req.query, user_id, req.top_k)
latency_ms["snippet"] = elapsed
all_results.extend(results)
if "code" in plan:
results, elapsed = await _timed("code", _search_code, pipeline, req.query, user_id, req.top_k)
latency_ms["code"] = elapsed
all_results.extend(results)
search_tasks = []
if "profile" in plan:
search_tasks.append(_timed("profile", _search_profile, pipeline, user_id))
if "temporal" in plan:
search_tasks.append(_timed("temporal", _search_temporal, pipeline, req.query, user_id, req.top_k))
if "summary" in plan:
search_tasks.append(_timed("summary", _search_summary, pipeline, req.query, user_id, req.top_k))
if "snippet" in plan:
search_tasks.append(_timed("snippet", _search_snippet, pipeline, req.query, user_id, req.top_k))
if "code" in plan:
search_tasks.append(_timed("code", _search_code, pipeline, req.query, user_id, req.top_k))
if search_tasks:
task_results = await asyncio.gather(*search_tasks)
for (results, elapsed), mode in zip(task_results, plan):
latency_ms[mode] = elapsed
all_results.extend(results)


all_results.sort(key=lambda record: record.score, reverse=True)

answer = None
answer_sources: List[SourceRecord] = []
confidence = 0.0
if req.answer:
answer_result, elapsed = await _timed("answer", pipeline.run, req.query, user_id, req.top_k)
latency_ms["answer"] = elapsed
answer = answer_result.answer
confidence = answer_result.confidence
answer_sources = [
SourceRecord(
domain=s.domain, content=s.content,
score=round(s.score, 3), metadata=s.metadata,
)
for s in answer_result.sources
]

if "profile" in req.domains:
all_results.extend(_search_profile(pipeline, user_id))
if "temporal" in req.domains:
all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k))
if "summary" in req.domains:
all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k))

data = SearchResponse(results=all_results, total=len(all_results))
data = SearchResponse(
results=all_results,
total=len(all_results),
answer=answer,
answer_sources=answer_sources,
confidence=confidence,
latency_ms=latency_ms,
latency_stats=_latency_stats(),
)
elapsed = round((time.perf_counter() - start) * 1000, 2)
return _wrap(request, data, elapsed)

Expand Down Expand Up @@ -763,6 +840,34 @@ async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str,
return []


async def _search_snippet(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]:
try:
raw = await pipeline._search_snippet(query=query, user_id=user_id, top_k=top_k)
return [
SourceRecord(domain=r.domain, content=r.content, score=round(r.score, 3), metadata=r.metadata)
for r in raw
]
except Exception as exc:
logger.warning("Snippet search error: %s", exc)
return []


async def _search_code(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]:
try:
raw = await pipeline.vector_store.search_by_text(
query_text=query,
top_k=top_k,
filters={"user_id": user_id, "domain": "code"},
)
return [
SourceRecord(domain="code", content=r.content, score=round(r.score, 3), metadata={"id": r.id, **r.metadata})
for r in raw
]
except Exception as exc:
logger.warning("Code search error: %s", exc)
return []


# POST /v1/memory/scrape
@scrape_router.post(
"/scrape",
Expand Down
23 changes: 19 additions & 4 deletions src/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -159,24 +158,40 @@ class SearchRequest(BaseModel):
..., min_length=1, max_length=256, pattern=r"^[\w.\-@]+$",
)
domains: List[str] = Field(
default=["profile", "temporal", "summary"],
default=["profile", "temporal", "summary", "snippet", "code"],
description="Which memory domains to search",
)
top_k: int = Field(default=10, ge=1, le=100)
answer: bool = Field(
default=False,
description="When true, also generate a synthesized answer after returning raw ranked hits.",
)

@field_validator("domains")
@classmethod
def validate_domains(cls, v: List[str]) -> List[str]:
allowed = {"profile", "temporal", "summary"}
allowed = {"profile", "temporal", "summary", "snippet", "code"}
for d in v:
if d not in allowed:
raise ValueError(f"Invalid domain '{d}'. Allowed: {allowed}")
return v
return list(dict.fromkeys(v))


class SearchLatencySummary(BaseModel):
count: int = 0
p50_ms: float = 0.0
p95_ms: float = 0.0
p99_ms: float = 0.0


class SearchResponse(BaseModel):
results: List[SourceRecord] = Field(default_factory=list)
total: int = 0
answer: Optional[str] = None
answer_sources: List[SourceRecord] = Field(default_factory=list)
confidence: float = 0.0
latency_ms: Dict[str, float] = Field(default_factory=dict)
latency_stats: Dict[str, SearchLatencySummary] = Field(default_factory=dict)


# ── Scrape (extract from shared chat links) ────────────────────────────────
Expand Down
19 changes: 19 additions & 0 deletions src/pipelines/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import asyncio
import logging
import time
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv
Expand Down Expand Up @@ -133,6 +134,9 @@ def __init__(

self.embed_fn = embed_fn
self._snippet_stores: Dict[str, BaseVectorStore] = {}
self._profile_catalog_cache: Dict[str, tuple[float, List[Dict[str, str]], list]] = {}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _profile_catalog_cache uses user_id as a key but lacks an eviction policy. In a production environment with a large number of unique users, this dictionary will grow indefinitely, leading to a memory leak.

Consider using an LRU cache or implementing a simple size-based eviction mechanism to bound memory usage.

self._raw_retrieval_plan_cache: Dict[tuple[tuple[str, ...], bool], tuple[str, ...]] = {}
self._cache_ttl_seconds = 60.0

logger.info("RetrievalPipeline initialized")

Expand Down Expand Up @@ -494,6 +498,11 @@ def _fetch_profile_catalog(self, user_id: str):
catalog — list of {topic, sub_topic} for the prompt
raw_results — the full SearchResult list, cached for _search_profile
"""
now = time.monotonic()
cached = self._profile_catalog_cache.get(user_id)
if cached and now - cached[0] < self._cache_ttl_seconds:
return cached[1], cached[2]

try:
results = self.vector_store.search_by_metadata(
filters={"user_id": user_id, "domain": "profile"},
Expand Down Expand Up @@ -524,8 +533,18 @@ def _fetch_profile_catalog(self, user_id: str):
"sub_topic": "",
})

self._profile_catalog_cache[user_id] = (now, catalog, results)
return catalog, results

def raw_retrieval_plan(self, domains: List[str], answer: bool = False) -> tuple[str, ...]:
"""Return a cached deterministic raw-search plan for the requested domains."""
ordered_allowed = ("profile", "temporal", "summary", "snippet", "code")
normalized = tuple(d for d in ordered_allowed if d in set(domains))
key = (normalized, answer)
if key not in self._raw_retrieval_plan_cache:
self._raw_retrieval_plan_cache[key] = normalized
return self._raw_retrieval_plan_cache[key]

def _format_catalog(self, catalog: List[Dict[str, str]]) -> str:
"""Format profile catalog for the system prompt."""
if not catalog:
Expand Down
88 changes: 84 additions & 4 deletions tests/api/test_dependencies_and_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from __future__ import annotations

import asyncio
from types import SimpleNamespace

import pytest
Expand All @@ -10,7 +8,9 @@
from src.api import dependencies as deps
from src.api.middleware import RequestContextMiddleware, SecurityHeadersMiddleware
from src.api.routes.health import router as health_router
from src.schemas.retrieval import RetrievalResult
from src.api.routes.memory import router as memory_router
from src.schemas.retrieval import RetrievalResult, SourceRecord
from src.storage.base import SearchResult


class FakeIngestPipeline:
Expand All @@ -26,8 +26,58 @@ def close(self):
class FakeRetrievalPipeline:
model = SimpleNamespace(model="fake-retrieval")

def __init__(self):
self.vector_store = SimpleNamespace(
search_by_metadata=self._search_by_metadata,
search_by_text=self._search_by_text,
)
self.neo4j = SimpleNamespace(search_events_by_embedding=self._search_events_by_embedding)

def raw_retrieval_plan(self, domains, answer=False):
ordered = ("profile", "temporal", "summary", "snippet", "code")
return tuple(domain for domain in ordered if domain in set(domains))

async def run(self, query: str, user_id: str, top_k: int = 5):
return RetrievalResult(query=query, answer=f"answer for {user_id}", sources=[], confidence=0.1)
return RetrievalResult(
query=query,
answer=f"answer for {user_id}",
sources=[SourceRecord(domain="summary", content="answer source", score=0.7)],
confidence=0.7,
)

def _search_by_metadata(self, filters, top_k=10):
return [
SearchResult(
id="profile-1",
content="Profile fact",
score=0.4,
metadata={"domain": "profile", "user_id": filters.get("user_id")},
)
][:top_k]

async def _search_by_text(self, query_text, top_k=10, filters=None):
domain = (filters or {}).get("domain", "summary")
return [
SearchResult(
id=f"{domain}-1",
content=f"{domain} hit for {query_text}",
score=0.8,
metadata={"domain": domain},
)
][:top_k]

def _search_events_by_embedding(self, user_id, query_text, top_k=3, similarity_threshold=0.0):
return [
{
"event_name": "Demo event",
"desc": query_text,
"date": "2026-05-21",
"similarity_score": 0.9,
}
][:top_k]

async def _search_snippet(self, query: str, user_id: str, top_k: int = 5):
return [SourceRecord(domain="snippet", content=f"snippet hit for {query}", score=0.6)]

def close(self):
pass
Expand All @@ -44,6 +94,7 @@ def dependency_app(monkeypatch):
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(RequestContextMiddleware)
app.include_router(health_router)
app.include_router(memory_router)

@app.get("/protected")
async def protected(user: dict = Depends(deps.require_api_key)):
Expand Down Expand Up @@ -85,6 +136,35 @@ def test_dependency_injection_returns_configured_pipeline(dependency_app):
assert response.json() == {"ingest": "fake-ingest"}


def test_memory_search_returns_raw_hits_latency_and_optional_answer(dependency_app):
client = TestClient(dependency_app)

response = client.post(
"/v1/memory/search",
headers={"Authorization": "Bearer test-static-key"},
json={
"query": "fast lookup",
"user_id": "ignored-for-auth-user",
"domains": ["profile", "temporal", "summary", "snippet", "code"],
"answer": True,
},
)

assert response.status_code == 200
data = response.json()["data"]
assert data["answer"].startswith("answer for Static Key User")
assert data["total"] == 5
assert {item["domain"] for item in data["results"]} == {
"profile",
"temporal",
"summary",
"snippet",
"code",
}
assert {"profile", "temporal", "summary", "snippet", "code", "answer"} <= set(data["latency_ms"])
assert data["latency_stats"]["summary"]["count"] >= 1


@pytest.mark.asyncio
async def test_rate_limiter_blocks_after_limit(monkeypatch):
limiter = deps._SlidingWindowRateLimiter(max_requests=1, window_seconds=60)
Expand Down
Loading