-
Notifications
You must be signed in to change notification settings - Fork 40
Add fast raw memory search path #189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,7 +8,9 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import threading | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from collections import defaultdict, deque | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict, List | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from fastapi import APIRouter, Depends, Request, UploadFile, File | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -50,6 +52,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger = logging.getLogger("xmem.api.routes.memory") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _ingest_semaphore = asyncio.Semaphore(5) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _latency_samples: dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=200)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| router = APIRouter( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prefix="/v1/memory", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -108,6 +111,40 @@ def _error(request: Request, detail: str, code: int, elapsed_ms: float = 0) -> J | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return JSONResponse(content=body.model_dump(), status_code=code) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _record_latency(mode: str, elapsed_ms: float) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _latency_samples[mode].append(elapsed_ms) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _percentile(sorted_values: List[float], percentile: float) -> float: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not sorted_values: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 0.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| index = min(len(sorted_values) - 1, int(round((len(sorted_values) - 1) * percentile))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return round(sorted_values[index], 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _latency_stats() -> Dict[str, Dict[str, float]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stats: Dict[str, Dict[str, float]] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for mode, samples in _latency_samples.items(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| values = sorted(samples) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stats[mode] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "count": len(values), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "p50_ms": _percentile(values, 0.50), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "p95_ms": _percentile(values, 0.95), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "p99_ms": _percentile(values, 0.99), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return stats | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _timed(mode: str, func, *args, **kwargs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = time.perf_counter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = func(*args, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(result, "__await__"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = await result | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elapsed_ms = round((time.perf_counter() - start) * 1000, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _record_latency(mode, elapsed_ms) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return result, elapsed_ms | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _detect_chat_provider(*urls: str) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for url in urls: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lowered = (url or "").lower() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -145,8 +182,6 @@ async def _render_chat_share(url: str) -> tuple[str, str]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # reuse it across scrape requests. The browser is thread-safe when each | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # request uses its own BrowserContext. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import threading | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _browser_lock = threading.Lock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _pw_instance = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _browser_instance = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -690,15 +725,57 @@ async def search_memory(req: SearchRequest, request: Request, user: dict = Depen | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results: List[SourceRecord] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms: Dict[str, float] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| plan = pipeline.raw_retrieval_plan(req.domains, answer=req.answer) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "profile" in plan: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results, elapsed = await _timed("profile", _search_profile, pipeline, user_id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms["profile"] = elapsed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "temporal" in plan: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results, elapsed = await _timed("temporal", _search_temporal, pipeline, req.query, user_id, req.top_k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms["temporal"] = elapsed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "summary" in plan: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results, elapsed = await _timed("summary", _search_summary, pipeline, req.query, user_id, req.top_k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms["summary"] = elapsed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "snippet" in plan: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results, elapsed = await _timed("snippet", _search_snippet, pipeline, req.query, user_id, req.top_k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms["snippet"] = elapsed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "code" in plan: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results, elapsed = await _timed("code", _search_code, pipeline, req.query, user_id, req.top_k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms["code"] = elapsed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+731
to
+750
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The search domains are currently queried sequentially, which negates the performance benefits of having a "fast" raw search path. Since these operations are I/O bound, they should be executed in parallel using
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.sort(key=lambda record: record.score, reverse=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer_sources: List[SourceRecord] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| confidence = 0.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if req.answer: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer_result, elapsed = await _timed("answer", pipeline.run, req.query, user_id, req.top_k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms["answer"] = elapsed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer = answer_result.answer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| confidence = answer_result.confidence | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer_sources = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SourceRecord( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| domain=s.domain, content=s.content, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| score=round(s.score, 3), metadata=s.metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for s in answer_result.sources | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "profile" in req.domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(_search_profile(pipeline, user_id)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "temporal" in req.domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(_search_temporal(pipeline, req.query, user_id, req.top_k)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "summary" in req.domains: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_results.extend(await _search_summary(pipeline, req.query, user_id, req.top_k)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data = SearchResponse(results=all_results, total=len(all_results)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data = SearchResponse( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results=all_results, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total=len(all_results), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer=answer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| answer_sources=answer_sources, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| confidence=confidence, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_ms=latency_ms, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| latency_stats=_latency_stats(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elapsed = round((time.perf_counter() - start) * 1000, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return _wrap(request, data, elapsed) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -763,6 +840,34 @@ async def _search_summary(pipeline: RetrievalPipeline, query: str, user_id: str, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _search_snippet(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raw = await pipeline._search_snippet(query=query, user_id=user_id, top_k=top_k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SourceRecord(domain=r.domain, content=r.content, score=round(r.score, 3), metadata=r.metadata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for r in raw | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("Snippet search error: %s", exc) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _search_code(pipeline: RetrievalPipeline, query: str, user_id: str, top_k: int) -> List[SourceRecord]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raw = await pipeline.vector_store.search_by_text( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| query_text=query, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| top_k=top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| filters={"user_id": user_id, "domain": "code"}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SourceRecord(domain="code", content=r.content, score=round(r.score, 3), metadata={"id": r.id, **r.metadata}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for r in raw | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("Code search error: %s", exc) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # POST /v1/memory/scrape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @scrape_router.post( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "/scrape", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
|
|
||
| import asyncio | ||
| import logging | ||
| import time | ||
| from typing import Any, Callable, Dict, List, Optional | ||
|
|
||
| from dotenv import load_dotenv | ||
|
|
@@ -133,6 +134,9 @@ def __init__( | |
|
|
||
| self.embed_fn = embed_fn | ||
| self._snippet_stores: Dict[str, BaseVectorStore] = {} | ||
| self._profile_catalog_cache: Dict[str, tuple[float, List[Dict[str, str]], list]] = {} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Consider using an LRU cache or implementing a simple size-based eviction mechanism to bound memory usage. |
||
| self._raw_retrieval_plan_cache: Dict[tuple[tuple[str, ...], bool], tuple[str, ...]] = {} | ||
| self._cache_ttl_seconds = 60.0 | ||
|
|
||
| logger.info("RetrievalPipeline initialized") | ||
|
|
||
|
|
@@ -494,6 +498,11 @@ def _fetch_profile_catalog(self, user_id: str): | |
| catalog — list of {topic, sub_topic} for the prompt | ||
| raw_results — the full SearchResult list, cached for _search_profile | ||
| """ | ||
| now = time.monotonic() | ||
| cached = self._profile_catalog_cache.get(user_id) | ||
| if cached and now - cached[0] < self._cache_ttl_seconds: | ||
| return cached[1], cached[2] | ||
|
|
||
| try: | ||
| results = self.vector_store.search_by_metadata( | ||
| filters={"user_id": user_id, "domain": "profile"}, | ||
|
|
@@ -524,8 +533,18 @@ def _fetch_profile_catalog(self, user_id: str): | |
| "sub_topic": "", | ||
| }) | ||
|
|
||
| self._profile_catalog_cache[user_id] = (now, catalog, results) | ||
| return catalog, results | ||
|
|
||
| def raw_retrieval_plan(self, domains: List[str], answer: bool = False) -> tuple[str, ...]: | ||
| """Return a cached deterministic raw-search plan for the requested domains.""" | ||
| ordered_allowed = ("profile", "temporal", "summary", "snippet", "code") | ||
| normalized = tuple(d for d in ordered_allowed if d in set(domains)) | ||
| key = (normalized, answer) | ||
| if key not in self._raw_retrieval_plan_cache: | ||
| self._raw_retrieval_plan_cache[key] = normalized | ||
| return self._raw_retrieval_plan_cache[key] | ||
|
|
||
| def _format_catalog(self, catalog: List[Dict[str, str]]) -> str: | ||
| """Format profile catalog for the system prompt.""" | ||
| if not catalog: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Iterating over
_latency_samples.items()directly is not thread-safe because_latency_samplesis adefaultdict. If a concurrent request records latency for a new mode (causing a new key to be added) while this loop is running, it will raise aRuntimeError: dictionary changed size during iteration.Consider wrapping the items in a
list()to create a snapshot of the keys before iteration.