From 6702dd8ec18d202363ff84095e49ea1a485ccc9f Mon Sep 17 00:00:00 2001 From: qin-ctx Date: Wed, 8 Apr 2026 15:16:53 +0800 Subject: [PATCH] fix(embedder): reduce async contention in session flows Introduce native async embedding paths across providers, switch async retrieval/session hotspots to use them, and add a standalone mixed-load benchmark plus before/after benchmark evidence for the regression. --- benchmark/.gitignore | 1 + .../custom/session_contention_benchmark.py | 1672 +++++++++++++++++ openviking/models/embedder/base.py | 120 +- .../models/embedder/cohere_embedders.py | 103 +- .../models/embedder/gemini_embedders.py | 163 +- openviking/models/embedder/jina_embedders.py | 100 +- .../models/embedder/litellm_embedders.py | 40 + .../models/embedder/minimax_embedders.py | 118 +- .../models/embedder/openai_embedders.py | 105 +- .../models/embedder/vikingdb_embedders.py | 248 ++- .../models/embedder/volcengine_embedders.py | 200 ++ .../models/embedder/voyage_embedders.py | 84 +- openviking/retrieve/hierarchical_retriever.py | 4 +- openviking/session/compressor.py | 6 +- openviking/session/memory_deduplicator.py | 6 +- openviking/storage/collection_schemas.py | 8 +- .../utils/config/embedding_config.py | 34 +- 17 files changed, 2867 insertions(+), 145 deletions(-) create mode 100644 benchmark/.gitignore create mode 100644 benchmark/custom/session_contention_benchmark.py diff --git a/benchmark/.gitignore b/benchmark/.gitignore new file mode 100644 index 000000000..68bcbc960 --- /dev/null +++ b/benchmark/.gitignore @@ -0,0 +1 @@ +results/ \ No newline at end of file diff --git a/benchmark/custom/session_contention_benchmark.py b/benchmark/custom/session_contention_benchmark.py new file mode 100644 index 000000000..c351952ae --- /dev/null +++ b/benchmark/custom/session_contention_benchmark.py @@ -0,0 +1,1672 @@ +#!/usr/bin/env python3 +"""Daily session mixed-load contention benchmark for OpenViking.""" + +from __future__ import annotations + +import argparse 
import asyncio
import csv
import json
import math
import os
import random
import sys
import time
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import httpx

# Default read-path queries used when no --find-query is supplied on the CLI.
DEFAULT_FIND_QUERIES = [
    "how to authenticate users",
    "what is OpenViking",
    "session commit memory extraction",
]
# Latency buckets (ms) backing the slow_gt_1s / slow_gt_3s / slow_gt_5s counters.
DEFAULT_SLOW_THRESHOLDS_MS = (1000, 3000, 5000)
# Recorded error messages are clipped to this many characters.
MAX_ERROR_MESSAGE_LEN = 500


@dataclass
class BenchmarkConfig:
    """Fully-resolved run parameters (normalized from CLI args in parse_args)."""

    # Target server and request identity.
    server_url: str
    api_key: str
    account: str
    user: str
    request_timeout: float
    # Load shape.
    session_count: int
    writer_concurrency: int
    reader_concurrency: int
    extract_concurrency: int
    messages_per_commit: int
    extract_ratio: float  # probability a write cycle also calls /extract
    message_size: int  # characters per synthetic message
    # Phase durations (seconds) and sampling cadence.
    baseline_seconds: float
    mixed_seconds: float
    recovery_seconds: float
    window_seconds: float  # bucket width for request_windows.csv
    observer_interval: float  # 0 disables the sampler worker
    task_poll_interval: float
    task_drain_timeout: float
    # Output and behavior toggles.
    output_dir: str
    cleanup: bool
    require_extract_load: bool
    find_queries: List[str]
    find_limit: int
    find_target_uri: str
    find_score_threshold: Optional[float]
    seed: int


@dataclass
class PhaseMetadata:
    """Wall-clock bounds and measured duration of one benchmark phase."""

    phase: str
    started_at: str  # ISO-8601 UTC string (see utc_now)
    ended_at: str
    duration_seconds: float


@dataclass
class RequestEvent:
    """One HTTP request/response observation recorded by the client wrapper."""

    api: str  # logical API name, e.g. "find", "commit", "get_task"
    method: str
    path: str
    phase: str
    started_at: str
    ended_at: str
    elapsed_ms_since_run_start: float
    latency_ms: float
    success: bool
    status_code: Optional[int]  # None when no response was received
    timeout: bool
    exception_type: Optional[str]
    error_code: Optional[str]
    error_message: Optional[str]
    session_id: Optional[str] = None
    cycle_index: Optional[int] = None
    worker_id: Optional[int] = None
    task_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Plain-dict view for JSONL serialization."""
        return asdict(self)


@dataclass
class CommitTaskEvent:
    """Terminal (or force-finalized) state of one asynchronous commit task."""

    task_id: str
    session_id: str
    origin_phase: str  # phase in which the commit was issued
    completion_phase: str  # phase in which completion was observed
    status: str  # "completed" | "failed" | "incomplete"
    created_at: Optional[float]
    updated_at: Optional[float]
    server_duration_ms: Optional[float]  # updated_at - created_at, when both known
    local_duration_ms: float  # commit returned -> completion observed locally
    active_count_updated: Optional[int]
    memories_extracted: Optional[Dict[str, int]]
    error: Optional[str]
    cycle_index: Optional[int]
    polled_at: str

    def to_dict(self) -> Dict[str, Any]:
        """Plain-dict view for JSONL serialization."""
        return asdict(self)


@dataclass
class ObserverSample:
    """Periodic health/queue snapshot taken by the sampler worker."""

    api: str
    phase: str
    sampled_at: str
    elapsed_ms_since_run_start: float
    latency_ms: float
    success: bool
    is_healthy: Optional[bool]
    has_errors: Optional[bool]
    payload: Optional[Dict[str, Any]]
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Plain-dict view for JSONL serialization."""
        return asdict(self)


@dataclass
class PendingCommitTask:
    """Bookkeeping for a commit task that is awaiting completion polling."""

    task_id: str
    session_id: str
    origin_phase: str
    cycle_index: int
    local_started_monotonic: float  # perf_counter() taken when the commit returned


@dataclass
class Recorder:
    """In-memory sink for every event stream produced during a run."""

    request_events: List[RequestEvent] = field(default_factory=list)
    task_events: List[CommitTaskEvent] = field(default_factory=list)
    observer_samples: List[ObserverSample] = field(default_factory=list)
    notes: List[str] = field(default_factory=list)

    def add_request(self, event: RequestEvent) -> None:
        self.request_events.append(event)

    def add_task(self, event: CommitTaskEvent) -> None:
        self.task_events.append(event)

    def add_sample(self, sample: ObserverSample) -> None:
        self.observer_samples.append(sample)

    def add_note(self, note: str) -> None:
        self.notes.append(note)


class PhaseState:
    """Mutable current-phase marker shared by workers and the task poller."""

    def __init__(self, initial: str = "setup") -> None:
        self.current = initial


class BenchmarkHTTPClient:
    """httpx.AsyncClient wrapper that records every request as a RequestEvent."""

    def __init__(self, config: BenchmarkConfig, recorder: Recorder) -> None:
        self._config = config
        self._recorder = recorder
        self._run_start_monotonic = time.perf_counter()
        self._client = httpx.AsyncClient(
            base_url=config.server_url.rstrip("/"),
            headers=self._default_headers(),
            timeout=httpx.Timeout(config.request_timeout),
            follow_redirects=True,
            limits=httpx.Limits(
                # Size the pool for all workers plus poller/sampler headroom.
                max_connections=max(
                    32,
                    config.writer_concurrency
                    +
config.reader_concurrency + + config.extract_concurrency + + 8, + ), + max_keepalive_connections=max( + 16, + config.writer_concurrency + config.reader_concurrency + 4, + ), + ), + ) + + @property + def run_start_monotonic(self) -> float: + return self._run_start_monotonic + + async def aclose(self) -> None: + await self._client.aclose() + + def _default_headers(self) -> Dict[str, str]: + headers = { + "Accept": "*/*", + "Content-Type": "application/json", + "User-Agent": "OpenViking-Session-Contention-Benchmark/1.0", + "X-OpenViking-Account": self._config.account, + "X-OpenViking-User": self._config.user, + } + if self._config.api_key: + headers["Authorization"] = f"Bearer {self._config.api_key}" + return headers + + async def request_json( + self, + *, + api: str, + method: str, + path: str, + phase: str, + session_id: Optional[str] = None, + cycle_index: Optional[int] = None, + worker_id: Optional[int] = None, + task_id: Optional[str] = None, + json_payload: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + ) -> tuple[Optional[httpx.Response], Optional[Dict[str, Any]]]: + started_monotonic = time.perf_counter() + started_wall = utc_now() + response: Optional[httpx.Response] = None + body: Optional[Dict[str, Any]] = None + status_code: Optional[int] = None + success = False + timeout = False + exception_type: Optional[str] = None + error_code: Optional[str] = None + error_message: Optional[str] = None + + try: + response = await self._client.request( + method=method, + url=path, + json=json_payload, + params=params, + ) + status_code = response.status_code + body = maybe_json(response) + success = self._is_success(status_code, body) + if not success: + error_code, error_message = extract_error(body, status_code) + except httpx.TimeoutException as exc: + timeout = True + exception_type = type(exc).__name__ + error_message = truncate_error_message(str(exc)) + except Exception as exc: # pragma: no cover - exercised in real runs + 
exception_type = type(exc).__name__ + error_message = truncate_error_message(str(exc)) + + ended_wall = utc_now() + ended_monotonic = time.perf_counter() + latency_ms = (ended_monotonic - started_monotonic) * 1000.0 + elapsed_ms = (started_monotonic - self._run_start_monotonic) * 1000.0 + self._recorder.add_request( + RequestEvent( + api=api, + method=method.upper(), + path=path, + phase=phase, + started_at=started_wall, + ended_at=ended_wall, + elapsed_ms_since_run_start=elapsed_ms, + latency_ms=latency_ms, + success=success, + status_code=status_code, + timeout=timeout, + exception_type=exception_type, + error_code=error_code, + error_message=error_message, + session_id=session_id, + cycle_index=cycle_index, + worker_id=worker_id, + task_id=task_id, + ) + ) + return response, body + + @staticmethod + def _is_success(status_code: Optional[int], body: Optional[Dict[str, Any]]) -> bool: + if status_code is None or status_code >= 400: + return False + if not isinstance(body, dict): + return status_code < 400 + if "status" in body: + return body.get("status") == "ok" + return True + + +class CommitTaskPoller: + def __init__( + self, + client: BenchmarkHTTPClient, + recorder: Recorder, + phase_state: PhaseState, + poll_interval: float, + ) -> None: + self._client = client + self._recorder = recorder + self._phase_state = phase_state + self._poll_interval = poll_interval + self._pending: Dict[str, PendingCommitTask] = {} + self._closed = False + self._wake_event = asyncio.Event() + self._lock = asyncio.Lock() + + async def register(self, task: PendingCommitTask) -> None: + async with self._lock: + self._pending[task.task_id] = task + self._wake_event.set() + + async def close(self) -> None: + self._closed = True + self._wake_event.set() + + async def drain(self, timeout: float) -> None: + deadline = time.perf_counter() + timeout + while True: + async with self._lock: + remaining = len(self._pending) + if remaining == 0: + return + if time.perf_counter() >= deadline: + 
                return
            await asyncio.sleep(min(self._poll_interval, 0.5))

    async def finalize_incomplete(self) -> None:
        """Record every still-pending commit task as "incomplete" at shutdown."""
        async with self._lock:
            # Snapshot and clear under the lock; emit events outside it.
            leftovers = list(self._pending.values())
            self._pending.clear()
        for item in leftovers:
            local_duration_ms = (time.perf_counter() - item.local_started_monotonic) * 1000.0
            self._recorder.add_task(
                CommitTaskEvent(
                    task_id=item.task_id,
                    session_id=item.session_id,
                    origin_phase=item.origin_phase,
                    completion_phase=self._phase_state.current,
                    status="incomplete",
                    created_at=None,
                    updated_at=None,
                    server_duration_ms=None,
                    local_duration_ms=local_duration_ms,
                    active_count_updated=None,
                    memories_extracted=None,
                    error="task not completed before benchmark end",
                    cycle_index=item.cycle_index,
                    polled_at=utc_now(),
                )
            )

    async def run(self) -> None:
        """Main poller loop: sleep until register()/close() wakes it, then poll
        all pending tasks every poll_interval until none remain or close() is
        observed."""
        while True:
            # Block until there is work (or shutdown); clear so a later
            # register() re-arms the event.
            await self._wake_event.wait()
            self._wake_event.clear()

            while True:
                async with self._lock:
                    pending = list(self._pending.values())
                if not pending:
                    break
                await self._poll_pending(pending)
                if self._closed:
                    return
                await asyncio.sleep(self._poll_interval)

            if self._closed:
                return

    async def _poll_pending(self, pending: List[PendingCommitTask]) -> None:
        """Poll all pending tasks concurrently and prune the completed ones."""
        coroutines = [self._poll_one(item) for item in pending]
        # return_exceptions keeps one failed poll from cancelling the batch.
        results = await asyncio.gather(*coroutines, return_exceptions=True)
        completed_ids = [task_id for task_id in results if isinstance(task_id, str)]
        if not completed_ids:
            return
        async with self._lock:
            for task_id in completed_ids:
                self._pending.pop(task_id, None)

    async def _poll_one(self, item: PendingCommitTask) -> Optional[str]:
        """Fetch one task's status; return its task_id when terminal, else None."""
        _, body = await self._client.request_json(
            api="get_task",
            method="GET",
            path=f"/api/v1/tasks/{item.task_id}",
            phase=self._phase_state.current,
            session_id=item.session_id,
            cycle_index=item.cycle_index,
            task_id=item.task_id,
        )
        if not isinstance(body, dict) or body.get("status") != "ok":
            return None
        result = body.get("result") or {}
        task_status = 
result.get("status") + if task_status not in {"completed", "failed"}: + return None + + created_at = to_float(result.get("created_at")) + updated_at = to_float(result.get("updated_at")) + server_duration_ms = None + if created_at is not None and updated_at is not None: + server_duration_ms = max(updated_at - created_at, 0.0) * 1000.0 + local_duration_ms = (time.perf_counter() - item.local_started_monotonic) * 1000.0 + task_result = result.get("result") or {} + self._recorder.add_task( + CommitTaskEvent( + task_id=item.task_id, + session_id=item.session_id, + origin_phase=item.origin_phase, + completion_phase=self._phase_state.current, + status=task_status, + created_at=created_at, + updated_at=updated_at, + server_duration_ms=server_duration_ms, + local_duration_ms=local_duration_ms, + active_count_updated=task_result.get("active_count_updated"), + memories_extracted=task_result.get("memories_extracted"), + error=result.get("error"), + cycle_index=item.cycle_index, + polled_at=utc_now(), + ) + ) + return item.task_id + + +class BenchmarkRunner: + def __init__(self, config: BenchmarkConfig) -> None: + self.config = config + self.random = random.Random(config.seed) + self.recorder = Recorder() + self.phase_state = PhaseState() + self.phase_metadata: List[PhaseMetadata] = [] + self.phase_durations: Dict[str, float] = {} + self.session_ids: List[str] = [] + self.session_queue: asyncio.Queue[str] = asyncio.Queue() + self.session_cycle_counts: Dict[str, int] = {} + self.extract_semaphore = asyncio.Semaphore(max(1, config.extract_concurrency)) + self.client = BenchmarkHTTPClient(config, self.recorder) + self.task_poller = CommitTaskPoller( + client=self.client, + recorder=self.recorder, + phase_state=self.phase_state, + poll_interval=config.task_poll_interval, + ) + + async def run(self) -> int: + poller_task = asyncio.create_task(self.task_poller.run()) + exit_code = 0 + try: + await self._preflight() + await self._create_sessions() + await self._run_phase( + 
phase="baseline", + duration_seconds=self.config.baseline_seconds, + enable_readers=self.config.reader_concurrency > 0, + enable_writers=False, + enable_sampler=self.config.observer_interval > 0, + ) + await self._run_phase( + phase="mixed_load", + duration_seconds=self.config.mixed_seconds, + enable_readers=self.config.reader_concurrency > 0, + enable_writers=self.config.writer_concurrency > 0 and bool(self.session_ids), + enable_sampler=self.config.observer_interval > 0, + ) + await self._run_phase( + phase="recovery", + duration_seconds=self.config.recovery_seconds, + enable_readers=self.config.reader_concurrency > 0, + enable_writers=False, + enable_sampler=self.config.observer_interval > 0, + ) + if self.config.task_drain_timeout > 0: + self.phase_state.current = "drain" + await self.task_poller.drain(self.config.task_drain_timeout) + except RuntimeError as exc: + self.recorder.add_note(f"fatal: {exc}") + print(f"[fatal] {exc}", file=sys.stderr) + exit_code = 1 + finally: + await self.task_poller.close() + await poller_task + await self.task_poller.finalize_incomplete() + if self.config.cleanup and self.session_ids: + await self._cleanup_sessions() + await self.client.aclose() + + self._write_outputs() + self._print_summary() + return exit_code + + async def _preflight(self) -> None: + self.phase_state.current = "setup" + _, health_body = await self.client.request_json( + api="health", + method="GET", + path="/health", + phase="setup", + ) + if not isinstance(health_body, dict) or health_body.get("status") != "ok": + raise RuntimeError("server health check failed") + + _, status_body = await self.client.request_json( + api="system_status", + method="GET", + path="/api/v1/system/status", + phase="setup", + ) + if not isinstance(status_body, dict) or status_body.get("status") != "ok": + raise RuntimeError("authenticated system status request failed") + + _, models_body = await self.client.request_json( + api="observer_models", + method="GET", + 
path="/api/v1/observer/models", + phase="setup", + ) + model_result = (models_body or {}).get("result") if isinstance(models_body, dict) else None + model_note = self._extract_model_note(model_result) + if model_note: + self.recorder.add_note(model_note) + + if self.config.extract_ratio > 0: + preflight_result = await self._run_extract_preflight() + if preflight_result: + self.recorder.add_note(preflight_result) + if self.config.require_extract_load: + raise RuntimeError(preflight_result) + + async def _run_extract_preflight(self) -> Optional[str]: + _, create_body = await self.client.request_json( + api="create_session", + method="POST", + path="/api/v1/sessions", + phase="setup", + ) + session_id = extract_session_id(create_body) + if not session_id: + return "extract preflight could not create session" + + try: + payload = { + "role": "user", + "content": build_message_content( + session_id=session_id, + cycle_index=0, + message_index=0, + size=self.config.message_size, + ), + } + await self.client.request_json( + api="add_message", + method="POST", + path=f"/api/v1/sessions/{session_id}/messages", + phase="setup", + session_id=session_id, + cycle_index=0, + json_payload=payload, + ) + _, extract_body = await self.client.request_json( + api="extract", + method="POST", + path=f"/api/v1/sessions/{session_id}/extract", + phase="setup", + session_id=session_id, + cycle_index=0, + ) + if not isinstance(extract_body, dict) or extract_body.get("status") != "ok": + return "extract preflight request failed" + result = extract_body.get("result") + if isinstance(result, list) and not result: + return ( + "extract preflight returned empty result; long-tail load may be weak if models are " + "not configured" + ) + return None + finally: + await self.client.request_json( + api="delete_session", + method="DELETE", + path=f"/api/v1/sessions/{session_id}", + phase="setup", + session_id=session_id, + ) + + def _extract_model_note(self, model_result: Any) -> Optional[str]: + if 
not isinstance(model_result, dict): + return None + is_healthy = model_result.get("is_healthy") + status = model_result.get("status") + if is_healthy is False: + return f"observer/models reports unhealthy state; extract load may not be representative: {status}" + return None + + async def _create_sessions(self) -> None: + if self.config.session_count <= 0: + return + for _ in range(self.config.session_count): + _, body = await self.client.request_json( + api="create_session", + method="POST", + path="/api/v1/sessions", + phase="setup", + ) + session_id = extract_session_id(body) + if not session_id: + raise RuntimeError("failed to create benchmark sessions") + self.session_ids.append(session_id) + self.session_cycle_counts[session_id] = 0 + await self.session_queue.put(session_id) + + async def _cleanup_sessions(self) -> None: + self.phase_state.current = "cleanup" + for session_id in self.session_ids: + await self.client.request_json( + api="delete_session", + method="DELETE", + path=f"/api/v1/sessions/{session_id}", + phase="cleanup", + session_id=session_id, + ) + + async def _run_phase( + self, + *, + phase: str, + duration_seconds: float, + enable_readers: bool, + enable_writers: bool, + enable_sampler: bool, + ) -> None: + if duration_seconds <= 0: + return + + self.phase_state.current = phase + stop_event = asyncio.Event() + tasks: List[asyncio.Task[Any]] = [] + + if enable_readers: + for worker_id in range(self.config.reader_concurrency): + tasks.append(asyncio.create_task(self._reader_worker(phase, worker_id, stop_event))) + if enable_writers: + for worker_id in range(self.config.writer_concurrency): + tasks.append(asyncio.create_task(self._writer_worker(phase, worker_id, stop_event))) + if enable_sampler: + tasks.append(asyncio.create_task(self._sampler_worker(phase, stop_event))) + + phase_started = time.perf_counter() + started_wall = utc_now() + await asyncio.sleep(duration_seconds) + stop_event.set() + if tasks: + await asyncio.gather(*tasks, 
return_exceptions=True)
        phase_duration = time.perf_counter() - phase_started
        ended_wall = utc_now()
        self.phase_metadata.append(
            PhaseMetadata(
                phase=phase,
                started_at=started_wall,
                ended_at=ended_wall,
                duration_seconds=phase_duration,
            )
        )
        self.phase_durations[phase] = phase_duration

    async def _writer_worker(self, phase: str, worker_id: int, stop_event: asyncio.Event) -> None:
        """Write-path worker: borrow an exclusive session from the queue, run
        one add-messages/extract/commit cycle, then return the session."""
        while not stop_event.is_set():
            session_id = await self._borrow_session(stop_event)
            if not session_id:
                return
            try:
                cycle_index = self.session_cycle_counts[session_id]
                self.session_cycle_counts[session_id] += 1
                await self._run_session_cycle(
                    phase=phase,
                    worker_id=worker_id,
                    session_id=session_id,
                    cycle_index=cycle_index,
                )
            finally:
                # Always hand the session back, even if the cycle failed.
                await self.session_queue.put(session_id)

    async def _run_session_cycle(
        self,
        *,
        phase: str,
        worker_id: int,
        session_id: str,
        cycle_index: int,
    ) -> None:
        """One write cycle: post messages_per_commit messages, optionally call
        /extract (probability extract_ratio, bounded by extract_semaphore),
        then /commit; a returned task_id is handed to the poller."""
        successful_messages = 0
        for message_index in range(self.config.messages_per_commit):
            payload = {
                "role": "user",
                "content": build_message_content(
                    session_id=session_id,
                    cycle_index=cycle_index,
                    message_index=message_index,
                    size=self.config.message_size,
                ),
            }
            _, body = await self.client.request_json(
                api="add_message",
                method="POST",
                path=f"/api/v1/sessions/{session_id}/messages",
                phase=phase,
                session_id=session_id,
                cycle_index=cycle_index,
                worker_id=worker_id,
                json_payload=payload,
            )
            if isinstance(body, dict) and body.get("status") == "ok":
                successful_messages += 1

        # Nothing landed -> there is nothing to extract or commit.
        if successful_messages <= 0:
            return

        if self.config.extract_ratio > 0 and self.random.random() < self.config.extract_ratio:
            # Semaphore caps how many extract calls run at once across workers.
            async with self.extract_semaphore:
                await self.client.request_json(
                    api="extract",
                    method="POST",
                    path=f"/api/v1/sessions/{session_id}/extract",
                    phase=phase,
                    session_id=session_id,
                    cycle_index=cycle_index,
                    worker_id=worker_id,
                )

        _, body = await self.client.request_json(
api="commit", + method="POST", + path=f"/api/v1/sessions/{session_id}/commit", + phase=phase, + session_id=session_id, + cycle_index=cycle_index, + worker_id=worker_id, + ) + task_id = extract_task_id(body) + if task_id: + await self.task_poller.register( + PendingCommitTask( + task_id=task_id, + session_id=session_id, + origin_phase=phase, + cycle_index=cycle_index, + local_started_monotonic=time.perf_counter(), + ) + ) + + async def _reader_worker(self, phase: str, worker_id: int, stop_event: asyncio.Event) -> None: + while not stop_event.is_set(): + payload = { + "query": self.random.choice(self.config.find_queries), + "limit": self.config.find_limit, + } + if self.config.find_target_uri: + payload["target_uri"] = self.config.find_target_uri + if self.config.find_score_threshold is not None: + payload["score_threshold"] = self.config.find_score_threshold + await self.client.request_json( + api="find", + method="POST", + path="/api/v1/search/find", + phase=phase, + worker_id=worker_id, + json_payload=payload, + ) + + async def _sampler_worker(self, phase: str, stop_event: asyncio.Event) -> None: + sample_specs = [ + ("system_status", "GET", "/api/v1/system/status"), + ("observer_queue", "GET", "/api/v1/observer/queue"), + ("observer_system", "GET", "/api/v1/observer/system"), + ] + while not stop_event.is_set(): + for api, method, path in sample_specs: + started = time.perf_counter() + response, body = await self.client.request_json( + api=api, + method=method, + path=path, + phase=phase, + ) + latency_ms = (time.perf_counter() - started) * 1000.0 + success = response is not None and self.client._is_success( + response.status_code if response else None, + body, + ) + self.recorder.add_sample( + ObserverSample( + api=api, + phase=phase, + sampled_at=utc_now(), + elapsed_ms_since_run_start=( + time.perf_counter() - self.client.run_start_monotonic + ) + * 1000.0, + latency_ms=latency_ms, + success=success, + is_healthy=extract_boolean(body, "result", "is_healthy"), 
+ has_errors=extract_boolean(body, "result", "has_errors"), + payload=body if isinstance(body, dict) else None, + error_message=extract_error( + body, response.status_code if response else None + )[1] + if response is not None and not success + else None, + ) + ) + if stop_event.is_set(): + break + if stop_event.is_set(): + return + try: + await asyncio.wait_for(stop_event.wait(), timeout=self.config.observer_interval) + except asyncio.TimeoutError: + continue + + async def _borrow_session(self, stop_event: asyncio.Event) -> Optional[str]: + while not stop_event.is_set(): + try: + return await asyncio.wait_for(self.session_queue.get(), timeout=0.2) + except asyncio.TimeoutError: + continue + return None + + def _write_outputs(self) -> None: + output_dir = Path(self.config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + request_summary_rows = build_request_summary_rows( + events=self.recorder.request_events, + phase_durations=self.phase_durations, + total_run_duration=total_duration_seconds(self.phase_metadata), + ) + task_summary_rows = build_task_summary_rows(self.recorder.task_events) + human_summary_zh = render_human_summary_zh( + config=self.config, + output_dir=self.config.output_dir, + notes=self.recorder.notes, + phase_metadata=self.phase_metadata, + request_summary_rows=request_summary_rows, + request_events=self.recorder.request_events, + task_summary_rows=task_summary_rows, + task_events=self.recorder.task_events, + ) + + write_json(output_dir / "run_config.json", asdict(self.config)) + write_json( + output_dir / "phases.json", + [asdict(item) for item in self.phase_metadata], + ) + write_json( + output_dir / "run_summary.json", + self._build_run_summary( + request_summary_rows=request_summary_rows, + task_summary_rows=task_summary_rows, + human_summary_zh=human_summary_zh, + ), + ) + write_text(output_dir / "summary_zh.txt", human_summary_zh) + write_jsonl(output_dir / "request_events.jsonl", self.recorder.request_events) + 
write_jsonl(output_dir / "task_events.jsonl", self.recorder.task_events) + write_jsonl(output_dir / "observer_samples.jsonl", self.recorder.observer_samples) + + write_csv( + output_dir / "request_summary.csv", + request_summary_rows, + ) + write_csv( + output_dir / "request_windows.csv", + build_request_window_rows( + events=self.recorder.request_events, + window_seconds=self.config.window_seconds, + ), + ) + write_csv( + output_dir / "task_summary.csv", + task_summary_rows, + ) + + def _build_run_summary( + self, + *, + request_summary_rows: List[Dict[str, Any]], + task_summary_rows: List[Dict[str, Any]], + human_summary_zh: str, + ) -> Dict[str, Any]: + find_delta = build_find_phase_delta(request_summary_rows) + return { + "notes": self.recorder.notes, + "phase_metadata": [asdict(item) for item in self.phase_metadata], + "request_summary": request_summary_rows, + "task_summary": task_summary_rows, + "find_phase_delta": find_delta, + "human_summary_zh": human_summary_zh, + "created_at": utc_now(), + } + + def _print_summary(self) -> None: + request_summary_rows = build_request_summary_rows( + events=self.recorder.request_events, + phase_durations=self.phase_durations, + total_run_duration=total_duration_seconds(self.phase_metadata), + ) + task_summary_rows = build_task_summary_rows(self.recorder.task_events) + print( + "\n" + + render_human_summary_zh( + config=self.config, + output_dir=self.config.output_dir, + notes=self.recorder.notes, + phase_metadata=self.phase_metadata, + request_summary_rows=request_summary_rows, + request_events=self.recorder.request_events, + task_summary_rows=task_summary_rows, + task_events=self.recorder.task_events, + ) + ) + + +def parse_args(argv: Optional[List[str]] = None) -> BenchmarkConfig: + server_host = os.getenv("SERVER_HOST", "127.0.0.1") + server_port = int(os.getenv("SERVER_PORT", "1933")) + default_server_url = f"http://{server_host}:{server_port}" + default_output_dir = ( + Path(__file__).resolve().parents[1] + / 
"results" + / "session_contention" + / datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") + ) + + parser = argparse.ArgumentParser( + description="Reproduce session addMessage/extract/commit contention against concurrent find traffic.", + ) + parser.add_argument("--server-url", default=default_server_url) + parser.add_argument("--api-key", default=os.getenv("OPENVIKING_API_KEY", "test-root-api-key")) + parser.add_argument("--account", default=os.getenv("OPENVIKING_ACCOUNT", "default")) + parser.add_argument("--user", default=os.getenv("OPENVIKING_USER", "default")) + parser.add_argument("--request-timeout", type=float, default=30.0) + parser.add_argument("--sessions", type=int, default=8) + parser.add_argument("--writer-concurrency", type=int, default=8) + parser.add_argument("--reader-concurrency", type=int, default=4) + parser.add_argument("--extract-concurrency", type=int, default=4) + parser.add_argument("--messages-per-commit", type=int, default=5) + parser.add_argument("--extract-ratio", type=float, default=0.5) + parser.add_argument("--message-size", type=int, default=768) + parser.add_argument("--baseline-seconds", type=float, default=30.0) + parser.add_argument("--mixed-seconds", type=float, default=120.0) + parser.add_argument("--recovery-seconds", type=float, default=30.0) + parser.add_argument("--window-seconds", type=float, default=5.0) + parser.add_argument("--observer-interval", type=float, default=5.0) + parser.add_argument("--task-poll-interval", type=float, default=1.0) + parser.add_argument("--task-drain-timeout", type=float, default=30.0) + parser.add_argument("--output-dir", default=str(default_output_dir)) + parser.add_argument("--cleanup", action="store_true") + parser.add_argument("--require-extract-load", action="store_true") + parser.add_argument( + "--find-query", + action="append", + dest="find_queries", + default=[], + help="Repeat to add multiple find queries.", + ) + parser.add_argument("--find-limit", type=int, default=10) + 
parser.add_argument("--find-target-uri", default="") + parser.add_argument("--find-score-threshold", type=float, default=None) + parser.add_argument("--seed", type=int, default=42) + + args = parser.parse_args(argv) + find_queries = args.find_queries or list(DEFAULT_FIND_QUERIES) + + config = BenchmarkConfig( + server_url=args.server_url, + api_key=args.api_key, + account=args.account, + user=args.user, + request_timeout=args.request_timeout, + session_count=max(0, args.sessions), + writer_concurrency=max(0, args.writer_concurrency), + reader_concurrency=max(0, args.reader_concurrency), + extract_concurrency=max(1, args.extract_concurrency), + messages_per_commit=max(1, args.messages_per_commit), + extract_ratio=min(max(args.extract_ratio, 0.0), 1.0), + message_size=max(128, args.message_size), + baseline_seconds=max(0.0, args.baseline_seconds), + mixed_seconds=max(0.0, args.mixed_seconds), + recovery_seconds=max(0.0, args.recovery_seconds), + window_seconds=max(args.window_seconds, 1.0), + observer_interval=0.0 if args.observer_interval <= 0 else max(args.observer_interval, 0.1), + task_poll_interval=max(args.task_poll_interval, 0.1), + task_drain_timeout=max(0.0, args.task_drain_timeout), + output_dir=args.output_dir, + cleanup=args.cleanup, + require_extract_load=args.require_extract_load, + find_queries=find_queries, + find_limit=max(1, args.find_limit), + find_target_uri=args.find_target_uri, + find_score_threshold=args.find_score_threshold, + seed=args.seed, + ) + if config.writer_concurrency > 0 and config.session_count <= 0: + parser.error("--sessions must be > 0 when --writer-concurrency is enabled") + return config + + +def maybe_json(response: httpx.Response) -> Optional[Dict[str, Any]]: + try: + body = response.json() + except ValueError: + return None + return body if isinstance(body, dict) else {"value": body} + + +def extract_error( + body: Optional[Dict[str, Any]], status_code: Optional[int] +) -> tuple[Optional[str], Optional[str]]: + if not 
isinstance(body, dict): + if status_code is None: + return None, None + return None, f"http status {status_code}" + error = body.get("error") + if isinstance(error, dict): + return error.get("code"), truncate_error_message(error.get("message")) + if body.get("status") not in {None, "ok"}: + return body.get("status"), truncate_error_message(json.dumps(body, ensure_ascii=False)) + if status_code is not None and status_code >= 400: + return None, f"http status {status_code}" + return None, None + + +def extract_session_id(body: Optional[Dict[str, Any]]) -> Optional[str]: + if not isinstance(body, dict): + return None + result = body.get("result") + if not isinstance(result, dict): + return None + session_id = result.get("session_id") + return session_id if isinstance(session_id, str) else None + + +def extract_task_id(body: Optional[Dict[str, Any]]) -> Optional[str]: + if not isinstance(body, dict): + return None + result = body.get("result") + if not isinstance(result, dict): + return None + task_id = result.get("task_id") + return task_id if isinstance(task_id, str) and task_id else None + + +def build_message_content( + *, session_id: str, cycle_index: int, message_index: int, size: int +) -> str: + prefix = ( + f"session={session_id} cycle={cycle_index} message={message_index}. " + "We discussed project goals, deployment constraints, user preferences, debugging notes, " + "timelines, risks, and follow-up actions. " + ) + detail = ( + "The user prefers production-safe changes, wants clear rollback steps, and asked for " + "memory extraction to keep decisions, entities, and events. " + "We also covered resource bottlenecks, queue backlog, response latency, and how read " + "traffic regressed during heavy write pressure. 
" + ) + content = prefix + while len(content) < size: + content += detail + return content[:size] + + +def truncate_error_message(message: Optional[str]) -> Optional[str]: + if message is None: + return None + if len(message) <= MAX_ERROR_MESSAGE_LEN: + return message + return message[:MAX_ERROR_MESSAGE_LEN] + "...[truncated]" + + +def utc_now() -> str: + return datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z") + + +def total_duration_seconds(phases: List[PhaseMetadata]) -> float: + return sum(item.duration_seconds for item in phases) + + +def percentile(values: Iterable[float], pct: float) -> Optional[float]: + ordered = sorted(float(value) for value in values) + if not ordered: + return None + if len(ordered) == 1: + return ordered[0] + rank = (pct / 100.0) * (len(ordered) - 1) + lower = math.floor(rank) + upper = math.ceil(rank) + if lower == upper: + return ordered[int(rank)] + weight = rank - lower + return ordered[lower] + (ordered[upper] - ordered[lower]) * weight + + +def build_request_summary_rows( + *, + events: List[RequestEvent], + phase_durations: Dict[str, float], + total_run_duration: float, +) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + rows.extend( + _build_request_summary_for_groups( + events=events, + grouping=lambda event: (event.phase, event.api), + duration_lookup=phase_durations, + ) + ) + overall_groups = _build_request_summary_for_groups( + events=events, + grouping=lambda event: ("ALL", event.api), + duration_lookup={"ALL": total_run_duration}, + ) + rows.extend(overall_groups) + return sorted(rows, key=lambda row: (row["phase"], row["api"])) + + +def _build_request_summary_for_groups( + *, + events: List[RequestEvent], + grouping, + duration_lookup: Dict[str, float], +) -> List[Dict[str, Any]]: + groups: Dict[tuple[str, str], List[RequestEvent]] = {} + for event in events: + key = grouping(event) + groups.setdefault(key, []).append(event) + + rows: List[Dict[str, Any]] = [] + for (phase, api), 
api_events in groups.items(): + latencies = [event.latency_ms for event in api_events] + successes = sum(1 for event in api_events if event.success) + failures = len(api_events) - successes + timeouts = sum(1 for event in api_events if event.timeout) + exceptions = sum(1 for event in api_events if event.exception_type) + status_counts: Dict[str, int] = {} + for event in api_events: + key = str(event.status_code) if event.status_code is not None else "exception" + status_counts[key] = status_counts.get(key, 0) + 1 + duration = max(duration_lookup.get(phase, 0.0), 1e-9) + row = { + "phase": phase, + "api": api, + "requests": len(api_events), + "successes": successes, + "failures": failures, + "timeouts": timeouts, + "exceptions": exceptions, + "success_rate": round((successes / len(api_events)) * 100.0, 4), + "qps": round(len(api_events) / duration, 4), + "avg_ms": round(sum(latencies) / len(latencies), 4), + "p50_ms": round_optional(percentile(latencies, 50)), + "p90_ms": round_optional(percentile(latencies, 90)), + "p95_ms": round_optional(percentile(latencies, 95)), + "p99_ms": round_optional(percentile(latencies, 99)), + "max_ms": round_optional(max(latencies) if latencies else None), + "slow_gt_1s": sum( + 1 for latency in latencies if latency > DEFAULT_SLOW_THRESHOLDS_MS[0] + ), + "slow_gt_3s": sum( + 1 for latency in latencies if latency > DEFAULT_SLOW_THRESHOLDS_MS[1] + ), + "slow_gt_5s": sum( + 1 for latency in latencies if latency > DEFAULT_SLOW_THRESHOLDS_MS[2] + ), + "status_codes": json.dumps(status_counts, sort_keys=True), + } + rows.append(row) + return rows + + +def build_request_window_rows( + *, + events: List[RequestEvent], + window_seconds: float, +) -> List[Dict[str, Any]]: + groups: Dict[tuple[int, str, str], List[RequestEvent]] = {} + for event in events: + window_index = int((event.elapsed_ms_since_run_start / 1000.0) // window_seconds) + key = (window_index, event.phase, event.api) + groups.setdefault(key, []).append(event) + + rows: 
List[Dict[str, Any]] = [] + for (window_index, phase, api), window_events in sorted(groups.items()): + latencies = [event.latency_ms for event in window_events] + successes = sum(1 for event in window_events if event.success) + rows.append( + { + "window_index": window_index, + "window_start_sec": round(window_index * window_seconds, 4), + "window_end_sec": round((window_index + 1) * window_seconds, 4), + "phase": phase, + "api": api, + "requests": len(window_events), + "successes": successes, + "failures": len(window_events) - successes, + "success_rate": round((successes / len(window_events)) * 100.0, 4), + "qps": round(len(window_events) / window_seconds, 4), + "p95_ms": round_optional(percentile(latencies, 95)), + "p99_ms": round_optional(percentile(latencies, 99)), + "max_ms": round_optional(max(latencies) if latencies else None), + } + ) + return rows + + +def build_task_summary_rows(events: List[CommitTaskEvent]) -> List[Dict[str, Any]]: + groups: Dict[str, List[CommitTaskEvent]] = {} + for event in events: + groups.setdefault(event.status, []).append(event) + + rows: List[Dict[str, Any]] = [] + for status, status_events in sorted(groups.items()): + server_latencies = [ + event.server_duration_ms + for event in status_events + if event.server_duration_ms is not None + ] + local_latencies = [event.local_duration_ms for event in status_events] + successes = sum(1 for event in status_events if event.status == "completed") + rows.append( + { + "status": status, + "tasks": len(status_events), + "successes": successes, + "success_rate": round((successes / len(status_events)) * 100.0, 4), + "p50_server_duration_ms": round_optional(percentile(server_latencies, 50)), + "p95_server_duration_ms": round_optional(percentile(server_latencies, 95)), + "p99_server_duration_ms": round_optional(percentile(server_latencies, 99)), + "max_server_duration_ms": round_optional( + max(server_latencies) if server_latencies else None + ), + "p50_local_duration_ms": 
round_optional(percentile(local_latencies, 50)), + "p95_local_duration_ms": round_optional(percentile(local_latencies, 95)), + "p99_local_duration_ms": round_optional(percentile(local_latencies, 99)), + "max_local_duration_ms": round_optional( + max(local_latencies) if local_latencies else None + ), + } + ) + return rows + + +def build_find_phase_delta(summary_rows: List[Dict[str, Any]]) -> Optional[Dict[str, float]]: + baseline = next( + (row for row in summary_rows if row["phase"] == "baseline" and row["api"] == "find"), + None, + ) + mixed = next( + (row for row in summary_rows if row["phase"] == "mixed_load" and row["api"] == "find"), + None, + ) + if not baseline or not mixed: + return None + baseline_p95 = baseline.get("p95_ms") + baseline_p99 = baseline.get("p99_ms") + mixed_p95 = mixed.get("p95_ms") + mixed_p99 = mixed.get("p99_ms") + if not all(metric is not None for metric in [baseline_p95, baseline_p99, mixed_p95, mixed_p99]): + return None + return { + "baseline_p95_ms": baseline_p95, + "mixed_p95_ms": mixed_p95, + "p95_delta_percent": percent_change(baseline_p95, mixed_p95), + "baseline_p99_ms": baseline_p99, + "mixed_p99_ms": mixed_p99, + "p99_delta_percent": percent_change(baseline_p99, mixed_p99), + "baseline_success_rate": baseline["success_rate"], + "mixed_success_rate": mixed["success_rate"], + "success_rate_delta_percent": mixed["success_rate"] - baseline["success_rate"], + } + + +def find_request_summary_row( + summary_rows: List[Dict[str, Any]], + *, + api: str, + phase: str, +) -> Optional[Dict[str, Any]]: + return next((row for row in summary_rows if row["api"] == api and row["phase"] == phase), None) + + +def phase_target_seconds(config: BenchmarkConfig, phase: str) -> Optional[float]: + mapping = { + "baseline": config.baseline_seconds, + "mixed_load": config.mixed_seconds, + "recovery": config.recovery_seconds, + } + return mapping.get(phase) + + +def build_phase_overview_rows( + config: BenchmarkConfig, + phase_metadata: 
List[PhaseMetadata], +) -> List[Dict[str, Optional[float]]]: + rows: List[Dict[str, Optional[float]]] = [] + for item in phase_metadata: + target = phase_target_seconds(config, item.phase) + delta = None if target is None else item.duration_seconds - target + rows.append( + { + "phase": item.phase, + "target_seconds": round_optional(target), + "actual_seconds": round_optional(item.duration_seconds), + "delta_seconds": round_optional(delta), + } + ) + return rows + + +def build_api_error_breakdown( + events: List[RequestEvent], + *, + api: str, + phase: Optional[str] = None, +) -> Dict[str, Any]: + filtered = [ + event for event in events if event.api == api and (phase is None or event.phase == phase) + ] + exception_counts: Dict[str, int] = {} + error_counts: Dict[str, int] = {} + for event in filtered: + if event.exception_type: + exception_counts[event.exception_type] = ( + exception_counts.get(event.exception_type, 0) + 1 + ) + key = event.error_code or event.exception_type + if key: + error_counts[key] = error_counts.get(key, 0) + 1 + return { + "requests": len(filtered), + "successes": sum(1 for event in filtered if event.success), + "failures": sum(1 for event in filtered if not event.success), + "timeouts": sum(1 for event in filtered if event.timeout), + "exception_counts": exception_counts, + "error_counts": error_counts, + } + + +def format_phase_name_cn(phase: str) -> str: + mapping = { + "setup": "预热", + "baseline": "基线阶段", + "mixed_load": "混合压测阶段", + "recovery": "恢复阶段", + "drain": "收尾等待阶段", + "cleanup": "清理阶段", + "ALL": "全程", + } + return mapping.get(phase, phase) + + +def format_seconds(value: Optional[float]) -> str: + if value is None: + return "n/a" + return f"{value:.1f}s" + + +def format_percent(value: Optional[float]) -> str: + if value is None: + return "n/a" + return f"{value:.2f}%" + + +def format_delta_percent(value: Optional[float]) -> str: + if value is None: + return "n/a" + sign = "+" if value >= 0 else "" + return f"{sign}{value:.2f}%" 
+ + +def format_delta_seconds(value: Optional[float]) -> str: + if value is None: + return "n/a" + sign = "+" if value >= 0 else "" + return f"{sign}{value:.1f}s" + + +def format_change(old: Optional[float], new: Optional[float], *, unit: str = "ms") -> str: + if old is None or new is None: + return "n/a" + if unit == "ms": + return ( + f"{old:.2f}{unit} -> {new:.2f}{unit} ({format_delta_percent(percent_change(old, new))})" + ) + return f"{old:.2f} -> {new:.2f} ({format_delta_percent(percent_change(old, new))})" + + +def format_qps_change(old: Optional[float], new: Optional[float]) -> str: + if old is None or new is None: + return "n/a" + return f"{old:.2f} -> {new:.2f} ({format_delta_percent(percent_change(old, new))})" + + +def render_human_summary_zh( + *, + config: BenchmarkConfig, + output_dir: str, + notes: List[str], + phase_metadata: List[PhaseMetadata], + request_summary_rows: List[Dict[str, Any]], + request_events: List[RequestEvent], + task_summary_rows: List[Dict[str, Any]], + task_events: List[CommitTaskEvent], +) -> str: + lines: List[str] = [] + lines.append("=== OpenViking Session 竞争压测摘要 ===") + lines.append(f"结果目录: {output_dir}") + + if notes: + lines.append("") + lines.append("说明:") + for note in notes: + lines.append(f"- {note}") + + phase_rows = build_phase_overview_rows(config, phase_metadata) + baseline_find = find_request_summary_row(request_summary_rows, api="find", phase="baseline") + mixed_find = find_request_summary_row(request_summary_rows, api="find", phase="mixed_load") + recovery_find = find_request_summary_row(request_summary_rows, api="find", phase="recovery") + mixed_add = find_request_summary_row( + request_summary_rows, api="add_message", phase="mixed_load" + ) + mixed_commit = find_request_summary_row(request_summary_rows, api="commit", phase="mixed_load") + mixed_extract = find_request_summary_row( + request_summary_rows, api="extract", phase="mixed_load" + ) + baseline_status = find_request_summary_row( + request_summary_rows, 
api="system_status", phase="baseline" + ) + mixed_status = find_request_summary_row( + request_summary_rows, api="system_status", phase="mixed_load" + ) + baseline_queue = find_request_summary_row( + request_summary_rows, api="observer_queue", phase="baseline" + ) + mixed_queue = find_request_summary_row( + request_summary_rows, api="observer_queue", phase="mixed_load" + ) + find_delta = build_find_phase_delta(request_summary_rows) + extract_breakdown = build_api_error_breakdown(request_events, api="extract", phase="mixed_load") + completed_tasks = next( + (row for row in task_summary_rows if row["status"] == "completed"), + None, + ) + incomplete_tasks = next( + (row for row in task_summary_rows if row["status"] == "incomplete"), + None, + ) + total_task_count = len(task_events) + + lines.append("") + lines.append("一、核心结论") + if baseline_find and mixed_find and find_delta: + lines.append( + "- 已明确复现读接口退化:`find` 在混合压测阶段的 p95 从 " + f"{baseline_find['p95_ms']:.2f}ms 升到 {mixed_find['p95_ms']:.2f}ms," + f"增幅 {find_delta['p95_delta_percent']:.2f}%;p99 从 " + f"{baseline_find['p99_ms']:.2f}ms 升到 {mixed_find['p99_ms']:.2f}ms," + f"增幅 {find_delta['p99_delta_percent']:.2f}%。" + ) + lines.append( + "- `find` 吞吐也下降了:QPS 从 " + f"{baseline_find['qps']:.2f} 降到 {mixed_find['qps']:.2f}," + f"变化 {format_delta_percent(percent_change(baseline_find['qps'], mixed_find['qps']))}。" + ) + if recovery_find and baseline_find and mixed_find: + lines.append( + "- 恢复阶段没有完全回到基线:`find` p95 为 " + f"{recovery_find['p95_ms']:.2f}ms,仍高于基线 " + f"{format_delta_percent(percent_change(baseline_find['p95_ms'], recovery_find['p95_ms']))};" + "但相比混合压测阶段已经有明显回落。" + ) + if mixed_extract: + lines.append( + "- 长尾压力主要来自 `extract`:混合压测阶段共 " + f"{mixed_extract['requests']} 次调用,成功率 {mixed_extract['success_rate']:.2f}%," + f"p95 {mixed_extract['p95_ms']:.2f}ms。" + ) + if extract_breakdown["timeouts"] > 0: + lines.append( + "- `extract` 失败几乎全是客户端超时:" + f"{extract_breakdown['timeouts']}/{extract_breakdown['requests']} 
次超时," + f"主异常是 {format_top_counts(extract_breakdown['exception_counts'])}。" + ) + if mixed_commit and completed_tasks: + lines.append( + "- `commit` 接口本身不是最重的部分:前台 `commit` p95 只有 " + f"{mixed_commit['p95_ms']:.2f}ms;真正重的是后台任务,已完成任务的后台 p95 达 " + f"{completed_tasks['p95_server_duration_ms']:.2f}ms。" + ) + if incomplete_tasks: + lines.append( + "- 后台积压明显:本次共跟踪到 " + f"{total_task_count} 个 `commit` 背景任务,其中 {incomplete_tasks['tasks']} 个在压测结束" + "并等待 drain 后仍未完成。" + ) + + lines.append("") + lines.append("二、阶段时长") + for row in phase_rows: + extra = "" + if row["delta_seconds"] is not None and row["delta_seconds"] > 1: + extra = ",实际时长明显长于目标值,通常说明脚本在等待 in-flight 会话周期收尾" + lines.append( + f"- {format_phase_name_cn(row['phase'])}: 目标 {format_seconds(row['target_seconds'])}," + f"实际 {format_seconds(row['actual_seconds'])},偏差 {format_delta_seconds(row['delta_seconds'])}{extra}" + ) + + lines.append("") + lines.append("三、关键指标对比") + if baseline_find and mixed_find and recovery_find: + lines.append( + "- `find`:" + f" 基线 p95={baseline_find['p95_ms']:.2f}ms / p99={baseline_find['p99_ms']:.2f}ms / qps={baseline_find['qps']:.2f};" + f" 压测中 p95={mixed_find['p95_ms']:.2f}ms / p99={mixed_find['p99_ms']:.2f}ms / qps={mixed_find['qps']:.2f};" + f" 恢复期 p95={recovery_find['p95_ms']:.2f}ms / p99={recovery_find['p99_ms']:.2f}ms / qps={recovery_find['qps']:.2f}。" + ) + if mixed_add: + lines.append( + "- `add_message`: 混合压测阶段 " + f"requests={mixed_add['requests']},p50={mixed_add['p50_ms']:.2f}ms," + f"p95={mixed_add['p95_ms']:.2f}ms,p99={mixed_add['p99_ms']:.2f}ms。" + ) + if mixed_commit: + lines.append( + "- `commit`: 混合压测阶段 " + f"requests={mixed_commit['requests']},p50={mixed_commit['p50_ms']:.2f}ms," + f"p95={mixed_commit['p95_ms']:.2f}ms,p99={mixed_commit['p99_ms']:.2f}ms。" + ) + if mixed_extract: + lines.append( + "- `extract`: 混合压测阶段 " + f"requests={mixed_extract['requests']},success_rate={mixed_extract['success_rate']:.2f}%," + 
f"timeouts={extract_breakdown['timeouts']},p95={mixed_extract['p95_ms']:.2f}ms。" + ) + if completed_tasks: + lines.append( + "- `commit` 背景任务(completed):" + f" tasks={completed_tasks['tasks']},p50={format_metric(completed_tasks['p50_server_duration_ms'])}," + f" p95={format_metric(completed_tasks['p95_server_duration_ms'])}," + f" p99={format_metric(completed_tasks['p99_server_duration_ms'])}。" + ) + if incomplete_tasks: + lines.append( + "- `commit` 背景任务(incomplete):" + f" tasks={incomplete_tasks['tasks']},本地等待 p95={format_metric(incomplete_tasks['p95_local_duration_ms'])}。" + ) + if baseline_status and mixed_status: + lines.append( + "- `system_status`: p95 " + f"{format_change(baseline_status['p95_ms'], mixed_status['p95_ms'])}。" + ) + if baseline_queue and mixed_queue: + lines.append( + "- `observer_queue`: p95 " + f"{format_change(baseline_queue['p95_ms'], mixed_queue['p95_ms'])}。" + ) + + lines.append("") + lines.append("四、怎么理解这次结果") + lines.append( + "- `find` 没有报错,但延迟和吞吐同时变差,这比“报错”更说明问题:读请求被明显挤压了。" + ) + lines.append("- `extract` 的大量 30 秒超时说明长尾请求已经被稳定制造出来了,压测目标基本达成。") + lines.append( + "- `commit` 前台接口看起来还好,但后台任务非常慢,说明资源竞争更可能发生在后续提取/索引阶段,而不是 HTTP 返回这一步。" + ) + lines.append( + "- 如果你要拿这次结果给别人看,最应该盯的是三组数字:" + "`find` 基线 vs 压测 p95/p99、`extract` 超时比例、`commit` 背景任务完成时长。" + ) + + return "\n".join(lines) + + +def format_top_counts(counts: Dict[str, int], limit: int = 3) -> str: + if not counts: + return "无" + ordered = sorted(counts.items(), key=lambda item: (-item[1], item[0])) + return ", ".join(f"{key}={value}" for key, value in ordered[:limit]) + + +def percent_change(old: float, new: float) -> float: + if old == 0: + return 0.0 if new == 0 else 100.0 + return ((new - old) / old) * 100.0 + + +def round_optional(value: Optional[float], ndigits: int = 4) -> Optional[float]: + if value is None: + return None + return round(value, ndigits) + + +def write_json(path: Path, data: Any) -> None: + with path.open("w", encoding="utf-8") as handle: + json.dump(data, 
handle, indent=2, ensure_ascii=False) + + +def write_text(path: Path, content: str) -> None: + path.write_text(content, encoding="utf-8") + + +def write_jsonl(path: Path, rows: Iterable[Any]) -> None: + with path.open("w", encoding="utf-8") as handle: + for row in rows: + if hasattr(row, "to_dict"): + row = row.to_dict() + handle.write(json.dumps(row, ensure_ascii=False) + "\n") + + +def write_csv(path: Path, rows: List[Dict[str, Any]]) -> None: + if not rows: + path.write_text("", encoding="utf-8") + return + fieldnames = list(rows[0].keys()) + with path.open("w", encoding="utf-8", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def extract_boolean(body: Optional[Dict[str, Any]], *keys: str) -> Optional[bool]: + current: Any = body + for key in keys: + if not isinstance(current, dict): + return None + current = current.get(key) + return current if isinstance(current, bool) else None + + +def to_float(value: Any) -> Optional[float]: + if isinstance(value, (float, int)): + return float(value) + return None + + +def format_metric(value: Optional[float]) -> str: + if value is None: + return "n/a" + return f"{value:.2f}ms" + + +async def async_main(argv: Optional[List[str]] = None) -> int: + config = parse_args(argv) + runner = BenchmarkRunner(config) + return await runner.run() + + +def main(argv: Optional[List[str]] = None) -> int: + try: + return asyncio.run(async_main(argv)) + except KeyboardInterrupt: + print("\n[stopped] benchmark interrupted by user", file=sys.stderr) + return 130 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/openviking/models/embedder/base.py b/openviking/models/embedder/base.py index 46e3b5b9b..f50597344 100644 --- a/openviking/models/embedder/base.py +++ b/openviking/models/embedder/base.py @@ -1,17 +1,35 @@ # Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. 
# SPDX-License-Identifier: AGPL-3.0 +import asyncio import random import time +import weakref from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, TypeVar +from threading import Lock +from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar -from openviking.utils.model_retry import retry_sync +from openviking.telemetry import get_current_telemetry +from openviking.utils.model_retry import retry_async, retry_sync T = TypeVar("T") _token_tracker_instance = None +_ASYNC_EMBED_SEMAPHORES: "weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, Dict[int, asyncio.Semaphore]]" = weakref.WeakKeyDictionary() +_ASYNC_EMBED_LOCK = Lock() + + +def _get_async_embed_semaphore(limit: int) -> asyncio.Semaphore: + loop = asyncio.get_running_loop() + normalized_limit = max(1, limit) + with _ASYNC_EMBED_LOCK: + semaphores_by_limit = _ASYNC_EMBED_SEMAPHORES.setdefault(loop, {}) + semaphore = semaphores_by_limit.get(normalized_limit) + if semaphore is None: + semaphore = asyncio.Semaphore(normalized_limit) + semaphores_by_limit[normalized_limit] = semaphore + return semaphore def _get_token_tracker(): @@ -24,6 +42,24 @@ def _get_token_tracker(): return _token_tracker_instance +async def embed_compat(embedder: Any, text: str, *, is_query: bool = False) -> "EmbedResult": + """Call async embedding when available, otherwise fall back to sync embed().""" + embed_async = getattr(embedder, "embed_async", None) + if callable(embed_async): + return await embed_async(text, is_query=is_query) + return embedder.embed(text, is_query=is_query) + + +async def embed_batch_compat( + embedder: Any, texts: List[str], *, is_query: bool = False +) -> List["EmbedResult"]: + """Call async batch embedding when available, otherwise fall back to sync embed_batch().""" + embed_batch_async = getattr(embedder, "embed_batch_async", None) + if callable(embed_batch_async): + return await embed_batch_async(texts, 
is_query=is_query) + return embedder.embed_batch(texts, is_query=is_query) + + def truncate_and_normalize(embedding: List[float], dimension: Optional[int]) -> List[float]: """Truncate and L2 normalize embedding vector @@ -90,6 +126,7 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): self.model_name = model_name self.config = config or {} self.max_retries = int(self.config.get("max_retries", 3)) + self.max_concurrent = int(self.config.get("max_concurrent", 10)) self.provider = self.config.get("provider", "unknown") # Token usage tracking @@ -120,6 +157,24 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes """ return [self.embed(text, is_query=is_query) for text in texts] + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + """Async embed single text. + + Subclasses should override this with a non-blocking implementation. + The default implementation preserves compatibility for test doubles and + third-party embedders that only implement the sync interface. 
+ """ + return self.embed(text, is_query=is_query) + + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + """Async batch embedding.""" + results: List[EmbedResult] = [] + for text in texts: + results.append(await self.embed_async(text, is_query=is_query)) + return results + def close(self): """Release resources, subclasses can override as needed""" pass @@ -132,6 +187,46 @@ def _run_with_retry(self, func: Callable[[], T], *, logger=None, operation_name: operation_name=operation_name, ) + async def _run_with_async_retry( + self, + func: Callable[[], Awaitable[T]], + *, + logger=None, + operation_name: str, + ) -> T: + async def _wrapped() -> T: + semaphore = _get_async_embed_semaphore(self.max_concurrent) + wait_started = time.monotonic() + await semaphore.acquire() + wait_elapsed = time.monotonic() - wait_started + telemetry = get_current_telemetry() + telemetry.set("embedding.async.max_concurrent", self.max_concurrent) + telemetry.set("embedding.async.wait_ms", round(wait_elapsed * 1000, 3)) + + started = time.monotonic() + try: + return await func() + finally: + elapsed = time.monotonic() - started + telemetry.set("embedding.async.duration_ms", round(elapsed * 1000, 3)) + if logger and elapsed >= 1.0: + logger.warning( + "%s slow call provider=%s model=%s wait_ms=%.2f duration_ms=%.2f", + operation_name, + self.provider, + self.model_name, + wait_elapsed * 1000, + elapsed * 1000, + ) + semaphore.release() + + return await retry_async( + _wrapped, + max_retries=self.max_retries, + logger=logger, + operation_name=operation_name, + ) + @property def is_dense(self) -> bool: """Check if result contains dense vector""" @@ -337,6 +432,27 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes for d, s in zip(dense_results, sparse_results, strict=True) ] + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + dense_res, sparse_res = await asyncio.gather( + 
self.dense_embedder.embed_async(text, is_query=is_query), + self.sparse_embedder.embed_async(text, is_query=is_query), + ) + return EmbedResult( + dense_vector=dense_res.dense_vector, sparse_vector=sparse_res.sparse_vector + ) + + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + dense_results, sparse_results = await asyncio.gather( + self.dense_embedder.embed_batch_async(texts, is_query=is_query), + self.sparse_embedder.embed_batch_async(texts, is_query=is_query), + ) + return [ + EmbedResult(dense_vector=d.dense_vector, sparse_vector=s.sparse_vector) + for d, s in zip(dense_results, sparse_results, strict=True) + ] + def get_dimension(self) -> int: return self.dense_embedder.get_dimension() diff --git a/openviking/models/embedder/cohere_embedders.py b/openviking/models/embedder/cohere_embedders.py index d80226d3b..5188ab9d5 100644 --- a/openviking/models/embedder/cohere_embedders.py +++ b/openviking/models/embedder/cohere_embedders.py @@ -7,12 +7,16 @@ for asymmetric retrieval. 
""" +import asyncio +import logging from typing import Any, Dict, List, Optional import httpx from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult, truncate_and_normalize +logger = logging.getLogger(__name__) + COHERE_MODEL_DIMENSIONS = { "embed-v4.0": 1536, "embed-multilingual-v3.0": 1024, @@ -86,8 +90,9 @@ def __init__( }, timeout=60.0, ) + self._async_client: Optional[httpx.AsyncClient] = None - def _call_api(self, texts: List[str], input_type: str) -> List[List[float]]: + def _build_payload(self, texts: List[str], input_type: str) -> Dict[str, Any]: payload: Dict[str, Any] = { "model": self.model_name, "texts": texts, @@ -96,7 +101,27 @@ def _call_api(self, texts: List[str], input_type: str) -> List[List[float]]: } if self._use_server_dim: payload["output_dimension"] = self._dimension - resp = self._client.post("/v2/embed", json=payload) + return payload + + def _call_api(self, texts: List[str], input_type: str) -> List[List[float]]: + resp = self._client.post("/v2/embed", json=self._build_payload(texts, input_type)) + resp.raise_for_status() + data = resp.json() + return data["embeddings"]["float"] + + async def _call_api_async(self, texts: List[str], input_type: str) -> List[List[float]]: + if self._async_client is None: + self._async_client = httpx.AsyncClient( + base_url=self.api_base, + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + timeout=60.0, + ) + resp = await self._async_client.post( + "/v2/embed", json=self._build_payload(texts, input_type) + ) resp.raise_for_status() data = resp.json() return data["embeddings"]["float"] @@ -128,6 +153,34 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: except Exception as e: raise RuntimeError(f"Cohere embedding failed: {e}") from e + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + input_type = "search_query" if is_query else "search_document" + + async def _call() -> EmbedResult: + 
vectors = await self._call_api_async([text], input_type) + return EmbedResult(dense_vector=self._normalize_vector(vectors[0])) + + try: + result = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Cohere async embedding", + ) + estimated_tokens = self._estimate_tokens(text) + self.update_token_usage( + model_name=self.model_name, + provider="cohere", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + except httpx.HTTPStatusError as e: + raise RuntimeError( + f"Cohere API error: {e.response.status_code} {e.response.text}" + ) from e + except Exception as e: + raise RuntimeError(f"Cohere embedding failed: {e}") from e + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] @@ -154,9 +207,55 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes except Exception as e: raise RuntimeError(f"Cohere batch embedding failed: {e}") from e + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + input_type = "search_query" if is_query else "search_document" + + async def _call() -> List[EmbedResult]: + results: List[EmbedResult] = [] + for i in range(0, len(texts), 96): + batch = texts[i : i + 96] + vectors = await self._call_api_async(batch, input_type) + results.extend(EmbedResult(dense_vector=self._normalize_vector(v)) for v in vectors) + return results + + try: + results = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Cohere async batch embedding", + ) + total_tokens = sum(self._estimate_tokens(text) for text in texts) + self.update_token_usage( + model_name=self.model_name, + provider="cohere", + prompt_tokens=total_tokens, + completion_tokens=0, + ) + return results + except httpx.HTTPStatusError as e: + raise RuntimeError( + f"Cohere API error: {e.response.status_code} {e.response.text}" + ) from e + except Exception as 
e: + raise RuntimeError(f"Cohere batch embedding failed: {e}") from e + def close(self): """Close the httpx client connection pool.""" self._client.close() + if self._async_client is not None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + loop.create_task(self._async_client.aclose()) + else: + asyncio.run(self._async_client.aclose()) def get_dimension(self) -> int: return self._dimension diff --git a/openviking/models/embedder/gemini_embedders.py b/openviking/models/embedder/gemini_embedders.py index 0d8ae74af..65a0e4a51 100644 --- a/openviking/models/embedder/gemini_embedders.py +++ b/openviking/models/embedder/gemini_embedders.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: AGPL-3.0 """Gemini Embedding 2 provider using the official google-genai SDK.""" +import asyncio from typing import Any, Dict, List, Optional from google import genai @@ -17,13 +18,6 @@ import logging -try: - import anyio - - _ANYIO_AVAILABLE = True -except ImportError: - _ANYIO_AVAILABLE = False - from openviking.models.embedder.base import ( DenseEmbedderBase, EmbedResult, @@ -182,6 +176,19 @@ def _build_config( kwargs["title"] = title return types.EmbedContentConfig(**kwargs) + def _resolve_task_type( + self, + *, + is_query: bool = False, + task_type: Optional[str] = None, + ) -> Optional[str]: + if task_type is None: + if is_query and self.query_param: + task_type = self.query_param + elif not is_query and self.document_param: + task_type = self.document_param + return task_type + def __repr__(self) -> str: return ( f"GeminiDenseEmbedder(" @@ -201,12 +208,7 @@ def embed( if not text or not text.strip(): logger.warning("Empty text passed to embed(), returning zero vector") return EmbedResult(dense_vector=[0.0] * self._dimension) - # Resolve effective task_type from is_query when no explicit override - if task_type is None: - if is_query and self.query_param: - task_type = self.query_param - elif not is_query and 
self.document_param: - task_type = self.document_param + task_type = self._resolve_task_type(is_query=is_query, task_type=task_type) # SDK accepts plain str; converts to REST Parts format internally. def _call() -> EmbedResult: @@ -240,6 +242,46 @@ def _call() -> EmbedResult: except (APIError, ClientError) as e: _raise_api_error(e, self.model_name) + async def embed_async( + self, + text: str, + is_query: bool = False, + *, + task_type: Optional[str] = None, + title: Optional[str] = None, + ) -> EmbedResult: + if not text or not text.strip(): + logger.warning("Empty text passed to embed_async(), returning zero vector") + return EmbedResult(dense_vector=[0.0] * self._dimension) + + task_type = self._resolve_task_type(is_query=is_query, task_type=task_type) + + async def _call() -> EmbedResult: + result = await self.client.aio.models.embed_content( + model=self.model_name, + contents=text, + config=self._build_config(task_type=task_type, title=title), + ) + vector = truncate_and_normalize(list(result.embeddings[0].values), self._dimension) + return EmbedResult(dense_vector=vector) + + try: + result = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Gemini async embedding", + ) + estimated_tokens = self._estimate_tokens(text) + self.update_token_usage( + model_name=self.model_name, + provider="gemini", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + except (APIError, ClientError) as e: + _raise_api_error(e, self.model_name) + def embed_batch( self, texts: List[str], @@ -256,12 +298,7 @@ def embed_batch( self.embed(text, is_query=is_query, task_type=task_type, title=title) for text, title in zip(texts, titles, strict=True) ] - # Resolve effective task_type from is_query when no explicit override - if task_type is None: - if is_query and self.query_param: - task_type = self.query_param - elif not is_query and self.document_param: - task_type = self.document_param + task_type = 
self._resolve_task_type(is_query=is_query, task_type=task_type) results: List[EmbedResult] = [] config = self._build_config(task_type=task_type) for i in range(0, len(texts), _TEXT_BATCH_SIZE): @@ -315,37 +352,64 @@ def _call_batch( # No need to track here to avoid double counting return results - async def async_embed_batch(self, texts: List[str]) -> List[EmbedResult]: - """Concurrent batch embedding via client.aio — requires anyio to be installed. - - Dispatches all 100-text chunks in parallel, bounded by max_concurrent_batches. - Per-batch APIError falls back to individual embed() calls via thread pool. - Raises ImportError if anyio is not installed. - """ - if not _ANYIO_AVAILABLE: - raise ImportError( - "anyio is required for async_embed_batch: pip install 'openviking[gemini-async]'" - ) + async def embed_batch_async( + self, + texts: List[str], + is_query: bool = False, + *, + task_type: Optional[str] = None, + titles: Optional[List[str]] = None, + ) -> List[EmbedResult]: if not texts: return [] + if titles is not None: + return [ + await self.embed_async( + text, + is_query=is_query, + task_type=task_type, + title=title, + ) + for text, title in zip(texts, titles, strict=True) + ] + + task_type = self._resolve_task_type(is_query=is_query, task_type=task_type) batches = [texts[i : i + _TEXT_BATCH_SIZE] for i in range(0, len(texts), _TEXT_BATCH_SIZE)] results: List[Optional[List[EmbedResult]]] = [None] * len(batches) - sem = anyio.Semaphore(self._max_concurrent_batches) + sem = asyncio.Semaphore(self._max_concurrent_batches) async def _embed_one(idx: int, batch: List[str]) -> None: async with sem: + non_empty_indices = [j for j, t in enumerate(batch) if t and t.strip()] + empty_indices = [j for j, t in enumerate(batch) if not (t and t.strip())] + batch_results: List[Optional[EmbedResult]] = [None] * len(batch) + for j in empty_indices: + batch_results[j] = EmbedResult(dense_vector=[0.0] * self._dimension) + + if not non_empty_indices: + results[idx] = [r for r 
in batch_results if r is not None] + return + + non_empty_texts = [batch[j] for j in non_empty_indices] + + async def _call_batch() -> Any: + return await self.client.aio.models.embed_content( + model=self.model_name, + contents=non_empty_texts, + config=self._build_config(task_type=task_type), + ) + try: - response = await self.client.aio.models.embed_content( - model=self.model_name, contents=batch, config=self._build_config() + response = await self._run_with_async_retry( + _call_batch, + logger=logger, + operation_name="Gemini async batch embedding", ) - results[idx] = [ - EmbedResult( + for j, emb in zip(non_empty_indices, response.embeddings, strict=True): + batch_results[j] = EmbedResult( dense_vector=truncate_and_normalize(list(emb.values), self._dimension) ) - for emb in response.embeddings - ] - # Track token usage for successful API call - total_tokens = sum(self._estimate_tokens(text) for text in batch) + total_tokens = sum(self._estimate_tokens(text) for text in non_empty_texts) self.update_token_usage( model_name=self.model_name, provider="gemini", @@ -354,21 +418,26 @@ async def _embed_one(idx: int, batch: List[str]) -> None: ) except (APIError, ClientError) as e: logger.warning( - "Gemini async batch embed failed (HTTP %d) for batch of %d, falling back", + "Gemini async batch embed failed (HTTP %d) for batch of %d, falling back to per-item async calls", e.code, len(batch), ) - # Token usage will be tracked via self.embed() calls - results[idx] = [ - await anyio.to_thread.run_sync(self.embed, text) for text in batch - ] + for j in non_empty_indices: + batch_results[j] = await self.embed_async( + batch[j], + is_query=is_query, + task_type=task_type, + ) - async with anyio.create_task_group() as tg: - for idx, batch in enumerate(batches): - tg.start_soon(_embed_one, idx, batch) + results[idx] = [r for r in batch_results if r is not None] + await asyncio.gather(*(_embed_one(idx, batch) for idx, batch in enumerate(batches))) return [r for batch_results 
in results for r in (batch_results or [])] + async def async_embed_batch(self, texts: List[str]) -> List[EmbedResult]: + """Backward-compatible alias for the standardized async batch API.""" + return await self.embed_batch_async(texts) + def get_dimension(self) -> int: return self._dimension diff --git a/openviking/models/embedder/jina_embedders.py b/openviking/models/embedder/jina_embedders.py index e13765421..5394c0c85 100644 --- a/openviking/models/embedder/jina_embedders.py +++ b/openviking/models/embedder/jina_embedders.py @@ -120,6 +120,7 @@ def __init__( api_key=self.api_key, base_url=self.api_base, ) + self._async_client = None # Determine dimension max_dim = JINA_MODEL_DIMENSIONS.get(model_name, 1024) @@ -145,6 +146,24 @@ def _build_extra_body(self, is_query: bool = False) -> Optional[Dict[str, Any]]: extra_body["late_chunking"] = self.late_chunking return extra_body if extra_body else None + def _build_kwargs(self, text_input: str | List[str], is_query: bool = False) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {"input": text_input, "model": self.model_name} + if self.dimension: + kwargs["dimensions"] = self.dimension + + extra_body = self._build_extra_body(is_query=is_query) + if extra_body: + kwargs["extra_body"] = extra_body + return kwargs + + def _get_async_client(self): + if self._async_client is None: + self._async_client = openai.AsyncOpenAI( + api_key=self.api_key, + base_url=self.api_base, + ) + return self._async_client + def _raise_task_error(self, error: openai.APIError) -> None: """Raise an actionable error if a 422 indicates an invalid task type.""" if getattr(error, "status_code", None) == 422 and "task" in str(error.body): @@ -170,15 +189,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: """ def _call() -> EmbedResult: - kwargs: Dict[str, Any] = {"input": text, "model": self.model_name} - if self.dimension: - kwargs["dimensions"] = self.dimension - - extra_body = self._build_extra_body(is_query=is_query) - if 
extra_body: - kwargs["extra_body"] = extra_body - - response = self.client.embeddings.create(**kwargs) + response = self.client.embeddings.create(**self._build_kwargs(text, is_query=is_query)) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) @@ -204,6 +215,33 @@ def _call() -> EmbedResult: except Exception as e: raise RuntimeError(f"Embedding failed: {str(e)}") from e + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + client = self._get_async_client() + + async def _call() -> EmbedResult: + response = await client.embeddings.create(**self._build_kwargs(text, is_query=is_query)) + return EmbedResult(dense_vector=response.data[0].embedding) + + try: + result = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Jina async embedding", + ) + estimated_tokens = self._estimate_tokens(text) + self.update_token_usage( + model_name=self.model_name, + provider="jina", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + except openai.APIError as e: + self._raise_task_error(e) + raise RuntimeError(f"Jina API error: {e.message}") from e + except Exception as e: + raise RuntimeError(f"Embedding failed: {str(e)}") from e + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: """Batch embedding (Jina native support) @@ -221,15 +259,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes return [] def _call() -> List[EmbedResult]: - kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name} - if self.dimension: - kwargs["dimensions"] = self.dimension - - extra_body = self._build_extra_body(is_query=is_query) - if extra_body: - kwargs["extra_body"] = extra_body - - response = self.client.embeddings.create(**kwargs) + response = self.client.embeddings.create(**self._build_kwargs(texts, is_query=is_query)) return [EmbedResult(dense_vector=item.embedding) for item in response.data] @@ -254,6 +284,40 @@ 
def _call() -> List[EmbedResult]: except Exception as e: raise RuntimeError(f"Batch embedding failed: {str(e)}") from e + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + client = self._get_async_client() + + async def _call() -> List[EmbedResult]: + response = await client.embeddings.create( + **self._build_kwargs(texts, is_query=is_query) + ) + return [EmbedResult(dense_vector=item.embedding) for item in response.data] + + try: + results = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Jina async batch embedding", + ) + total_tokens = sum(self._estimate_tokens(text) for text in texts) + self.update_token_usage( + model_name=self.model_name, + provider="jina", + prompt_tokens=total_tokens, + completion_tokens=0, + ) + return results + except openai.APIError as e: + self._raise_task_error(e) + raise RuntimeError(f"Jina API error: {e.message}") from e + except Exception as e: + raise RuntimeError(f"Batch embedding failed: {str(e)}") from e + def get_dimension(self) -> int: """Get embedding dimension diff --git a/openviking/models/embedder/litellm_embedders.py b/openviking/models/embedder/litellm_embedders.py index 441b85fa0..4f7419619 100644 --- a/openviking/models/embedder/litellm_embedders.py +++ b/openviking/models/embedder/litellm_embedders.py @@ -182,6 +182,24 @@ def _call() -> EmbedResult: except Exception as e: raise RuntimeError(f"LiteLLM embedding failed: {e}") from e + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + async def _call() -> EmbedResult: + kwargs = self._build_kwargs(is_query=is_query) + kwargs["input"] = [text] + response = await litellm.aembedding(**kwargs) + self._update_telemetry_token_usage(response) + vector = response.data[0]["embedding"] + return EmbedResult(dense_vector=vector) + + try: + return await self._run_with_async_retry( + _call, + logger=logger, + operation_name="LiteLLM async 
embedding", + ) + except Exception as e: + raise RuntimeError(f"LiteLLM embedding failed: {e}") from e + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: """Batch embedding via litellm. @@ -214,6 +232,28 @@ def _call() -> List[EmbedResult]: except Exception as e: raise RuntimeError(f"LiteLLM batch embedding failed: {e}") from e + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + async def _call() -> List[EmbedResult]: + kwargs = self._build_kwargs(is_query=is_query) + kwargs["input"] = texts + response = await litellm.aembedding(**kwargs) + self._update_telemetry_token_usage(response) + return [EmbedResult(dense_vector=item["embedding"]) for item in response.data] + + try: + return await self._run_with_async_retry( + _call, + logger=logger, + operation_name="LiteLLM async batch embedding", + ) + except Exception as e: + raise RuntimeError(f"LiteLLM batch embedding failed: {e}") from e + def get_dimension(self) -> int: """Get embedding dimension. 
diff --git a/openviking/models/embedder/minimax_embedders.py b/openviking/models/embedder/minimax_embedders.py index 7a5d79f3c..84542685c 100644 --- a/openviking/models/embedder/minimax_embedders.py +++ b/openviking/models/embedder/minimax_embedders.py @@ -2,8 +2,10 @@ # SPDX-License-Identifier: AGPL-3.0 """MiniMax Embedder Implementation via HTTP API""" +import asyncio from typing import Any, Dict, List, Optional +import httpx import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry @@ -77,6 +79,7 @@ def __init__( # Initialize session with retry logic self.session = self._create_session() + self._async_client: Optional[httpx.AsyncClient] = None # Auto-detect dimension if not provided if self._dimension is None: @@ -107,45 +110,80 @@ def _detect_dimension(self) -> int: def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float]]: """Call MiniMax API""" + headers = self._build_headers() + params = self._build_params() + payload = self._build_payload(texts, is_query=is_query) + + try: + response = self.session.post( + self.api_base, + headers=headers, + params=params, + json=payload, + timeout=60, # 60s timeout + ) + response.raise_for_status() + data = response.json() + + # Check for business error code + base_resp = data.get("base_resp", {}) + if base_resp.get("status_code") != 0: + raise RuntimeError(f"MiniMax API error: {base_resp.get('status_msg')}") + + vectors = data.get("vectors", []) + if not vectors: + raise RuntimeError("MiniMax API returned empty vectors") + + return vectors + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"MiniMax network error: {str(e)}") from e + except Exception as e: + raise RuntimeError(f"MiniMax embedding failed: {str(e)}") from e + + def _build_headers(self) -> Dict[str, str]: headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } - - # Merge extra headers if self.extra_headers: for k, v in 
self.extra_headers.items(): if k.lower() not in ["authorization", "content-type", "groupid", "group_id"]: headers[k] = v + return headers - params = {} + def _build_params(self) -> Dict[str, str]: + params: Dict[str, str] = {} if self.group_id: params["GroupId"] = self.group_id + return params + def _build_payload(self, texts: List[str], is_query: bool = False) -> Dict[str, Any]: embed_type = "db" if is_query: embed_type = self.query_param if self.query_param is not None else "query" else: embed_type = self.document_param if self.document_param is not None else "db" - - payload = { + return { "model": self.model_name, "type": embed_type, "texts": texts, } + async def _call_api_async(self, texts: List[str], is_query: bool = False) -> List[List[float]]: + if self._async_client is None: + self._async_client = httpx.AsyncClient(timeout=60.0) + try: - response = self.session.post( + response = await self._async_client.post( self.api_base, - headers=headers, - params=params, - json=payload, - timeout=60, # 60s timeout + headers=self._build_headers(), + params=self._build_params(), + json=self._build_payload(texts, is_query=is_query), ) response.raise_for_status() data = response.json() - # Check for business error code base_resp = data.get("base_resp", {}) if base_resp.get("status_code") != 0: raise RuntimeError(f"MiniMax API error: {base_resp.get('status_msg')}") @@ -155,8 +193,7 @@ def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float raise RuntimeError("MiniMax API returned empty vectors") return vectors - - except requests.exceptions.RequestException as e: + except httpx.HTTPError as e: raise RuntimeError(f"MiniMax network error: {str(e)}") from e except Exception as e: raise RuntimeError(f"MiniMax embedding failed: {str(e)}") from e @@ -175,6 +212,25 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: ) return result + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + async def _call() -> 
EmbedResult: + vectors = await self._call_api_async([text], is_query=is_query) + return EmbedResult(dense_vector=vectors[0]) + + result = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="MiniMax async embedding", + ) + estimated_tokens = self._estimate_tokens(text) + self.update_token_usage( + model_name=self.model_name, + provider="minimax", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: """Batch embedding""" if not texts: @@ -194,6 +250,42 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes ) return results + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + async def _call() -> List[EmbedResult]: + vectors = await self._call_api_async(texts, is_query=is_query) + return [EmbedResult(dense_vector=v) for v in vectors] + + results = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="MiniMax async batch embedding", + ) + total_tokens = sum(self._estimate_tokens(text) for text in texts) + self.update_token_usage( + model_name=self.model_name, + provider="minimax", + prompt_tokens=total_tokens, + completion_tokens=0, + ) + return results + def get_dimension(self) -> int: """Get embedding dimension""" return self._dimension + + def close(self): + self.session.close() + if self._async_client is not None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + loop.create_task(self._async_client.aclose()) + else: + asyncio.run(self._async_client.aclose()) diff --git a/openviking/models/embedder/openai_embedders.py b/openviking/models/embedder/openai_embedders.py index 0e67069de..7a716f737 100644 --- a/openviking/models/embedder/openai_embedders.py +++ b/openviking/models/embedder/openai_embedders.py @@ -116,26 +116,27 @@ def 
__init__( self.query_param = query_param self.document_param = document_param self._provider = provider.lower() + self._client_kwargs: Dict[str, Any] = {"api_key": self.api_key or "no-key"} # Allow missing api_key when api_base is set (e.g. local OpenAI-compatible servers) if not self.api_key and not self.api_base: raise ValueError("api_key is required") - client_kwargs: Dict[str, Any] = {"api_key": self.api_key or "no-key"} if self._provider == "azure": if not self.api_base: raise ValueError("api_base (Azure endpoint) is required for Azure provider") - client_kwargs["azure_endpoint"] = self.api_base - client_kwargs["api_version"] = self.api_version or DEFAULT_AZURE_API_VERSION + self._client_kwargs["azure_endpoint"] = self.api_base + self._client_kwargs["api_version"] = self.api_version or DEFAULT_AZURE_API_VERSION if extra_headers: - client_kwargs["default_headers"] = extra_headers - self.client = openai.AzureOpenAI(**client_kwargs) + self._client_kwargs["default_headers"] = extra_headers + self.client = openai.AzureOpenAI(**self._client_kwargs) else: if self.api_base: - client_kwargs["base_url"] = self.api_base + self._client_kwargs["base_url"] = self.api_base if extra_headers: - client_kwargs["default_headers"] = extra_headers - self.client = openai.OpenAI(**client_kwargs) + self._client_kwargs["default_headers"] = extra_headers + self.client = openai.OpenAI(**self._client_kwargs) + self._async_client = None # Auto-detect dimension self._dimension = dimension @@ -235,6 +236,29 @@ def _build_extra_body(self, is_query: bool = False) -> Optional[Dict[str, Any]]: return extra_body if extra_body else None + def _build_kwargs(self, text_input: str | List[str], is_query: bool = False) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {"input": text_input, "model": self.model_name} + if self.dimension and self._should_send_dimensions(): + kwargs["dimensions"] = self.dimension + + extra_body = self._build_extra_body(is_query=is_query) + if extra_body: + kwargs["extra_body"] 
= extra_body + return kwargs + + def _should_send_dimensions(self) -> bool: + # Preserve existing behavior for official OpenAI embeddings: only custom + # OpenAI-compatible backends and Azure send explicit dimensions. + return self._provider != "openai" or bool(self.api_base) + + def _get_async_client(self): + if self._async_client is None: + if self._provider == "azure": + self._async_client = openai.AsyncAzureOpenAI(**self._client_kwargs) + else: + self._async_client = openai.AsyncOpenAI(**self._client_kwargs) + return self._async_client + def embed(self, text: str, is_query: bool = False) -> EmbedResult: """Perform dense embedding on text @@ -250,15 +274,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: """ def _call() -> EmbedResult: - kwargs: Dict[str, Any] = {"input": text, "model": self.model_name} - if self.dimension: - kwargs["dimensions"] = self.dimension - - extra_body = self._build_extra_body(is_query=is_query) - if extra_body: - kwargs["extra_body"] = extra_body - - response = self.client.embeddings.create(**kwargs) + response = self.client.embeddings.create(**self._build_kwargs(text, is_query=is_query)) self._update_telemetry_token_usage(response) vector = response.data[0].embedding @@ -275,6 +291,25 @@ def _call() -> EmbedResult: except Exception as e: raise RuntimeError(f"Embedding failed: {str(e)}") from e + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + client = self._get_async_client() + + async def _call() -> EmbedResult: + response = await client.embeddings.create(**self._build_kwargs(text, is_query=is_query)) + self._update_telemetry_token_usage(response) + return EmbedResult(dense_vector=response.data[0].embedding) + + try: + return await self._run_with_async_retry( + _call, + logger=logger, + operation_name="OpenAI async embedding", + ) + except openai.APIError as e: + raise RuntimeError(f"OpenAI API error: {e.message}") from e + except Exception as e: + raise RuntimeError(f"Embedding 
failed: {str(e)}") from e + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: """Batch embedding (OpenAI native support) @@ -292,15 +327,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes return [] def _call() -> List[EmbedResult]: - kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name} - if self.dimension: - kwargs["dimensions"] = self.dimension - - extra_body = self._build_extra_body(is_query=is_query) - if extra_body: - kwargs["extra_body"] = extra_body - - response = self.client.embeddings.create(**kwargs) + response = self.client.embeddings.create(**self._build_kwargs(texts, is_query=is_query)) self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item.embedding) for item in response.data] @@ -316,6 +343,32 @@ def _call() -> List[EmbedResult]: except Exception as e: raise RuntimeError(f"Batch embedding failed: {str(e)}") from e + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + client = self._get_async_client() + + async def _call() -> List[EmbedResult]: + response = await client.embeddings.create( + **self._build_kwargs(texts, is_query=is_query) + ) + self._update_telemetry_token_usage(response) + return [EmbedResult(dense_vector=item.embedding) for item in response.data] + + try: + return await self._run_with_async_retry( + _call, + logger=logger, + operation_name="OpenAI async batch embedding", + ) + except openai.APIError as e: + raise RuntimeError(f"OpenAI API error: {e.message}") from e + except Exception as e: + raise RuntimeError(f"Batch embedding failed: {str(e)}") from e + def get_dimension(self) -> int: """Get embedding dimension diff --git a/openviking/models/embedder/vikingdb_embedders.py b/openviking/models/embedder/vikingdb_embedders.py index a6042316d..6fd07fdd0 100644 --- a/openviking/models/embedder/vikingdb_embedders.py +++ 
b/openviking/models/embedder/vikingdb_embedders.py @@ -2,15 +2,21 @@ # SPDX-License-Identifier: AGPL-3.0 """VikingDB Embedder Implementation via HTTP API""" +import asyncio from typing import Any, Dict, List, Optional +import httpx + from openviking.models.embedder.base import ( DenseEmbedderBase, EmbedResult, HybridEmbedderBase, SparseEmbedderBase, ) -from openviking.storage.vectordb.collection.volcengine_clients import ClientForDataApi +from openviking.storage.vectordb.collection.volcengine_clients import ( + DEFAULT_TIMEOUT, + ClientForDataApi, +) from openviking_cli.utils.logger import default_logger as logger @@ -33,6 +39,7 @@ raise ValueError("AK and SK are required for VikingDB Embedder") self.client = ClientForDataApi(self.ak, self.sk, self.region, self.host) + self._async_client: Optional[httpx.AsyncClient] = None def _call_api( self, @@ -66,6 +73,42 @@ logger.error(f"Failed to get embeddings: {e}") raise e + async def _call_api_async( + self, + texts: List[str], + dense_model: Optional[Dict[str, Any]] = None, + sparse_model: Optional[Dict[str, Any]] = None, + ) -> List[Dict[str, Any]]: + path = "/api/vikingdb/embedding" + data_items = [{"text": text} for text in texts] + + req_body = {"data": data_items} + if dense_model: + req_body["dense_model"] = dense_model + if sparse_model: + req_body["sparse_model"] = sparse_model + + req = self.client.prepare_request(method="POST", path=path, data=req_body) + if self._async_client is None: + self._async_client = httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) + + response = await self._async_client.request( + method=req.method, + url=f"https://{self.host}{req.path}", + headers=req.headers, + content=req.body, + ) + if response.status_code != 200: + logger.warning( + "VikingDB API returned bad code: %s, message: %s", + response.status_code, + response.text, + ) + raise RuntimeError(f"VikingDB API error {response.status_code}: {response.text}") + + result = response.json() + return result.get("result", {}).get("data", []) + def _truncate_and_normalize( self,
embedding: List[float], dimension: Optional[int] ) -> List[float]: @@ -181,6 +224,63 @@ def _call() -> List[EmbedResult]: ) return results + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + async def _call() -> EmbedResult: + results = await self._call_api_async([text], dense_model=self.dense_model) + if not results: + return EmbedResult(dense_vector=[]) + + item = results[0] + dense_vector = [] + if "dense_embedding" in item: + dense_vector = self._truncate_and_normalize(item["dense_embedding"], self.dimension) + return EmbedResult(dense_vector=dense_vector) + + result = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="VikingDB async embedding", + ) + estimated_tokens = self._estimate_tokens(text) + self.update_token_usage( + model_name=self.model_name, + provider="volcengine", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + async def _call() -> List[EmbedResult]: + raw_results = await self._call_api_async(texts, dense_model=self.dense_model) + return [ + EmbedResult( + dense_vector=self._truncate_and_normalize( + item.get("dense_embedding", []), self.dimension + ) + ) + for item in raw_results + ] + + results = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="VikingDB async batch embedding", + ) + total_tokens = sum(self._estimate_tokens(text) for text in texts) + self.update_token_usage( + model_name=self.model_name, + provider="volcengine", + prompt_tokens=total_tokens, + completion_tokens=0, + ) + return results + def get_dimension(self) -> int: return self.dimension if self.dimension else 2048 @@ -262,6 +362,65 @@ def _call() -> List[EmbedResult]: ) return results + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + async def _call() -> EmbedResult: + results = await 
self._call_api_async([text], sparse_model=self.sparse_model) + if not results: + return EmbedResult(sparse_vector={}) + + item = results[0] + sparse_vector = {} + if "sparse" in item: + sparse_vector = item["sparse"] + elif "sparse_embedding" in item: + sparse_vector = self._process_sparse_embedding(item["sparse_embedding"]) + return EmbedResult(sparse_vector=sparse_vector) + + result = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="VikingDB async sparse embedding", + ) + estimated_tokens = self._estimate_tokens(text) + self.update_token_usage( + model_name=self.model_name, + provider="volcengine", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + async def _call() -> List[EmbedResult]: + raw_results = await self._call_api_async(texts, sparse_model=self.sparse_model) + return [ + EmbedResult( + sparse_vector=self._process_sparse_embedding( + item.get("sparse_embedding", item.get("sparse", {})) + ) + ) + for item in raw_results + ] + + results = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="VikingDB async sparse batch embedding", + ) + total_tokens = sum(self._estimate_tokens(text) for text in texts) + self.update_token_usage( + model_name=self.model_name, + provider="volcengine", + prompt_tokens=total_tokens, + completion_tokens=0, + ) + return results + class VikingDBHybridEmbedder(HybridEmbedderBase, VikingDBClientMixin): """VikingDB Hybrid Embedder""" @@ -357,5 +516,92 @@ def _call() -> List[EmbedResult]: ) return results + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + async def _call() -> EmbedResult: + results = await self._call_api_async( + [text], dense_model=self.dense_model, sparse_model=self.sparse_model + ) + if not results: + return EmbedResult(dense_vector=[], sparse_vector={}) + + item = 
results[0] + dense_vector = [] + sparse_vector = {} + if "dense" in item: + dense_vector = self._truncate_and_normalize(item["dense"], self.dimension) + elif "dense_embedding" in item: + dense_vector = self._truncate_and_normalize(item["dense_embedding"], self.dimension) + if "sparse" in item: + sparse_vector = item["sparse"] + elif "sparse_embedding" in item: + sparse_vector = self._process_sparse_embedding(item["sparse_embedding"]) + return EmbedResult(dense_vector=dense_vector, sparse_vector=sparse_vector) + + result = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="VikingDB async hybrid embedding", + ) + estimated_tokens = self._estimate_tokens(text) + self.update_token_usage( + model_name=self.model_name, + provider="volcengine", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + async def _call() -> List[EmbedResult]: + raw_results = await self._call_api_async( + texts, dense_model=self.dense_model, sparse_model=self.sparse_model + ) + results = [] + for item in raw_results: + dense_vector = [] + sparse_vector = {} + if "dense" in item: + dense_vector = self._truncate_and_normalize(item["dense"], self.dimension) + elif "dense_embedding" in item: + dense_vector = self._truncate_and_normalize( + item["dense_embedding"], self.dimension + ) + if "sparse" in item: + sparse_vector = item["sparse"] + elif "sparse_embedding" in item: + sparse_vector = self._process_sparse_embedding(item["sparse_embedding"]) + results.append(EmbedResult(dense_vector=dense_vector, sparse_vector=sparse_vector)) + return results + + results = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="VikingDB async hybrid batch embedding", + ) + total_tokens = sum(self._estimate_tokens(text) for text in texts) + self.update_token_usage( + model_name=self.model_name, + 
provider="volcengine", + prompt_tokens=total_tokens, + completion_tokens=0, + ) + return results + def get_dimension(self) -> int: return self.dimension if self.dimension else 2048 + + def close(self): + if getattr(self, "_async_client", None) is not None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + self._aclose_task = loop.create_task(self._async_client.aclose()) + else: + asyncio.run(self._async_client.aclose()) diff --git a/openviking/models/embedder/volcengine_embedders.py b/openviking/models/embedder/volcengine_embedders.py index 03f9915f5..8d1be0199 100644 --- a/openviking/models/embedder/volcengine_embedders.py +++ b/openviking/models/embedder/volcengine_embedders.py @@ -95,6 +95,8 @@ def __init__( if self.api_base: ark_kwargs["base_url"] = self.api_base self.client = volcenginesdkarkruntime.Ark(**ark_kwargs) + self._ark_kwargs = ark_kwargs + self._async_client = None # Auto-detect dimension self._dimension = dimension @@ -178,6 +182,37 @@ def _embed_call(): except Exception as e: raise RuntimeError(f"Volcengine embedding failed: {str(e)}") from e + def _get_async_client(self): + if self._async_client is None: + self._async_client = volcenginesdkarkruntime.AsyncArk(**self._ark_kwargs) + return self._async_client + + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + client = self._get_async_client() + + async def _embed_call() -> EmbedResult: + if self.input_type == "multimodal": + response = await client.multimodal_embeddings.create( + input=[{"type": "text", "text": text}], model=self.model_name + ) + self._update_telemetry_token_usage(response) + vector = response.data.embedding + else: + response = await client.embeddings.create(input=text, model=self.model_name) + self._update_telemetry_token_usage(response) + vector = response.data[0].embedding + + return EmbedResult(dense_vector=truncate_and_normalize(vector,
self.dimension)) + + try: + return await self._run_with_async_retry( + _embed_call, + logger=logger, + operation_name="Volcengine async embedding", + ) + except Exception as e: + raise RuntimeError(f"Volcengine embedding failed: {str(e)}") from e + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: """Batch embedding @@ -224,6 +259,44 @@ def _call() -> List[EmbedResult]: ) raise RuntimeError(f"Volcengine batch embedding failed: {str(e)}") from e + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + client = self._get_async_client() + + async def _call() -> List[EmbedResult]: + if self.input_type == "multimodal": + multimodal_inputs = [{"type": "text", "text": text} for text in texts] + response = await client.multimodal_embeddings.create( + input=multimodal_inputs, model=self.model_name + ) + self._update_telemetry_token_usage(response) + data = response.data + else: + response = await client.embeddings.create(input=texts, model=self.model_name) + self._update_telemetry_token_usage(response) + data = response.data + + return [ + EmbedResult(dense_vector=truncate_and_normalize(item.embedding, self.dimension)) + for item in data + ] + + try: + return await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Volcengine async batch embedding", + ) + except Exception as e: + logger.error( + f"Volcengine async batch embedding failed, texts length: {len(texts)}, input_type: {self.input_type}, model_name: {self.model_name}" + ) + raise RuntimeError(f"Volcengine batch embedding failed: {str(e)}") from e + def get_dimension(self) -> int: return self._dimension @@ -329,6 +402,34 @@ def _embed_call(): except Exception as e: raise RuntimeError(f"Volcengine sparse embedding failed: {str(e)}") from e + def _get_async_client(self): + if self._async_client is None: + self._async_client = volcenginesdkarkruntime.AsyncArk(**self._ark_kwargs) + 
return self._async_client + + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + client = self._get_async_client() + + async def _embed_call() -> EmbedResult: + response = await client.multimodal_embeddings.create( + input=[{"type": "text", "text": text}], + model=self.model_name, + sparse_embedding={"type": "enabled"}, + ) + self._update_telemetry_token_usage(response) + item = response.data + sparse_vector = getattr(item, "sparse_embedding", None) + return EmbedResult(sparse_vector=process_sparse_embedding(sparse_vector)) + + try: + return await self._run_with_async_retry( + _embed_call, + logger=logger, + operation_name="Volcengine async sparse embedding", + ) + except Exception as e: + raise RuntimeError(f"Volcengine sparse embedding failed: {str(e)}") from e + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: """Batch sparse embedding @@ -346,6 +447,38 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes return [] return [self.embed(text) for text in texts] + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + client = self._get_async_client() + + async def _call() -> List[EmbedResult]: + response = await client.multimodal_embeddings.create( + input=[{"type": "text", "text": text} for text in texts], + model=self.model_name, + sparse_embedding={"type": "enabled"}, + ) + self._update_telemetry_token_usage(response) + data = response.data + return [ + EmbedResult( + sparse_vector=process_sparse_embedding(getattr(item, "sparse_embedding", None)) + ) + for item in data + ] + + try: + return await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Volcengine async sparse batch embedding", + ) + except Exception as e: + raise RuntimeError(f"Volcengine sparse embedding failed: {str(e)}") from e + class VolcengineHybridEmbedder(HybridEmbedderBase): """Volcengine Hybrid 
Embedder Implementation @@ -389,6 +522,8 @@ def __init__( if self.api_base: ark_kwargs["base_url"] = self.api_base self.client = volcenginesdkarkruntime.Ark(**ark_kwargs) + self._ark_kwargs = ark_kwargs + self._async_client = None self._dimension = dimension or 2048 def _update_telemetry_token_usage(self, response) -> None: @@ -460,6 +595,38 @@ def _embed_call(): except Exception as e: raise RuntimeError(f"Volcengine hybrid embedding failed: {str(e)}") from e + def _get_async_client(self): + if self._async_client is None: + self._async_client = volcenginesdkarkruntime.AsyncArk(**self._ark_kwargs) + return self._async_client + + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + client = self._get_async_client() + + async def _embed_call() -> EmbedResult: + response = await client.multimodal_embeddings.create( + input=[{"type": "text", "text": text}], + model=self.model_name, + sparse_embedding={"type": "enabled"}, + ) + self._update_telemetry_token_usage(response) + item = response.data + dense_vector = truncate_and_normalize(item.embedding, self.dimension) + sparse_vector = getattr(item, "sparse_embedding", None) + return EmbedResult( + dense_vector=dense_vector, + sparse_vector=process_sparse_embedding(sparse_vector), + ) + + try: + return await self._run_with_async_retry( + _embed_call, + logger=logger, + operation_name="Volcengine async hybrid embedding", + ) + except Exception as e: + raise RuntimeError(f"Volcengine hybrid embedding failed: {str(e)}") from e + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: """Batch hybrid embedding @@ -477,5 +644,38 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes return [] return [self.embed(text, is_query=is_query) for text in texts] + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + client = self._get_async_client() + + async def 
_call() -> List[EmbedResult]: + response = await client.multimodal_embeddings.create( + input=[{"type": "text", "text": text} for text in texts], + model=self.model_name, + sparse_embedding={"type": "enabled"}, + ) + self._update_telemetry_token_usage(response) + data = response.data + return [ + EmbedResult( + dense_vector=truncate_and_normalize(item.embedding, self.dimension), + sparse_vector=process_sparse_embedding(getattr(item, "sparse_embedding", None)), + ) + for item in data + ] + + try: + return await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Volcengine async hybrid batch embedding", + ) + except Exception as e: + raise RuntimeError(f"Volcengine hybrid embedding failed: {str(e)}") from e + def get_dimension(self) -> int: return self._dimension diff --git a/openviking/models/embedder/voyage_embedders.py b/openviking/models/embedder/voyage_embedders.py index c4c366942..55415b140 100644 --- a/openviking/models/embedder/voyage_embedders.py +++ b/openviking/models/embedder/voyage_embedders.py @@ -81,18 +81,29 @@ def __init__( api_key=self.api_key, base_url=self.api_base, ) + self._async_client = None self._dimension = dimension or get_voyage_model_default_dimension(normalized_model_name) + def _build_kwargs(self, text_input: str | List[str]) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {"input": text_input, "model": self.model_name} + if self.dimension is not None: + kwargs["extra_body"] = {"output_dimension": self.dimension} + return kwargs + + def _get_async_client(self): + if self._async_client is None: + self._async_client = openai.AsyncOpenAI( + api_key=self.api_key, + base_url=self.api_base, + ) + return self._async_client + def embed(self, text: str, is_query: bool = False) -> EmbedResult: """Perform dense embedding on text.""" def _call() -> EmbedResult: - kwargs: Dict[str, Any] = {"input": text, "model": self.model_name} - if self.dimension is not None: - kwargs["extra_body"] = {"output_dimension": self.dimension} - - 
response = self.client.embeddings.create(**kwargs) + response = self.client.embeddings.create(**self._build_kwargs(text)) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) @@ -116,17 +127,39 @@ def _call() -> EmbedResult: except Exception as e: raise RuntimeError(f"Embedding failed: {str(e)}") from e + async def embed_async(self, text: str, is_query: bool = False) -> EmbedResult: + client = self._get_async_client() + + async def _call() -> EmbedResult: + response = await client.embeddings.create(**self._build_kwargs(text)) + return EmbedResult(dense_vector=response.data[0].embedding) + + try: + result = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Voyage async embedding", + ) + estimated_tokens = self._estimate_tokens(text) + self.update_token_usage( + model_name=self.model_name, + provider="voyage", + prompt_tokens=estimated_tokens, + completion_tokens=0, + ) + return result + except openai.APIError as e: + raise RuntimeError(f"Voyage API error: {e.message}") from e + except Exception as e: + raise RuntimeError(f"Embedding failed: {str(e)}") from e + def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: """Batch embedding.""" if not texts: return [] def _call() -> List[EmbedResult]: - kwargs: Dict[str, Any] = {"input": texts, "model": self.model_name} - if self.dimension is not None: - kwargs["extra_body"] = {"output_dimension": self.dimension} - - response = self.client.embeddings.create(**kwargs) + response = self.client.embeddings.create(**self._build_kwargs(texts)) return [EmbedResult(dense_vector=item.embedding) for item in response.data] try: @@ -149,6 +182,37 @@ def _call() -> List[EmbedResult]: except Exception as e: raise RuntimeError(f"Batch embedding failed: {str(e)}") from e + async def embed_batch_async( + self, texts: List[str], is_query: bool = False + ) -> List[EmbedResult]: + if not texts: + return [] + + client = self._get_async_client() + + async def 
_call() -> List[EmbedResult]: + response = await client.embeddings.create(**self._build_kwargs(texts)) + return [EmbedResult(dense_vector=item.embedding) for item in response.data] + + try: + results = await self._run_with_async_retry( + _call, + logger=logger, + operation_name="Voyage async batch embedding", + ) + total_tokens = sum(self._estimate_tokens(text) for text in texts) + self.update_token_usage( + model_name=self.model_name, + provider="voyage", + prompt_tokens=total_tokens, + completion_tokens=0, + ) + return results + except openai.APIError as e: + raise RuntimeError(f"Voyage API error: {e.message}") from e + except Exception as e: + raise RuntimeError(f"Batch embedding failed: {str(e)}") from e + def get_dimension(self) -> int: """Get embedding dimension.""" return self._dimension diff --git a/openviking/retrieve/hierarchical_retriever.py b/openviking/retrieve/hierarchical_retriever.py index 3192ccff8..5c7419200 100644 --- a/openviking/retrieve/hierarchical_retriever.py +++ b/openviking/retrieve/hierarchical_retriever.py @@ -14,7 +14,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Tuple -from openviking.models.embedder.base import EmbedResult +from openviking.models.embedder.base import EmbedResult, embed_compat from openviking.models.rerank import RerankClient from openviking.retrieve.memory_lifecycle import hotness_score from openviking.retrieve.retrieval_stats import get_stats_collector @@ -129,7 +129,7 @@ async def retrieve( query_vector = None sparse_query_vector = None if self.embedder: - result: EmbedResult = self.embedder.embed(query.query, is_query=True) + result: EmbedResult = await embed_compat(self.embedder, query.query, is_query=True) query_vector = result.dense_vector sparse_query_vector = result.sparse_vector diff --git a/openviking/session/compressor.py b/openviking/session/compressor.py index 46b56ed41..40583ae65 100644 --- a/openviking/session/compressor.py +++ b/openviking/session/compressor.py @@ 
-12,6 +12,7 @@ from openviking.core.context import Context, Vectorize from openviking.message import Message +from openviking.models.embedder.base import embed_compat from openviking.server.identity import RequestContext from openviking.storage import VikingDBManager from openviking.storage.viking_fs import get_viking_fs @@ -472,8 +473,9 @@ async def extract_long_term_memories( merged_text = ( f"{action.memory.abstract} {candidate.content}" ) - merged_embed = self.deduplicator.embedder.embed( - merged_text + merged_embed = await embed_compat( + self.deduplicator.embedder, + merged_text, ) batch_memories.append( (merged_embed.dense_vector, action.memory) diff --git a/openviking/session/memory_deduplicator.py b/openviking/session/memory_deduplicator.py index 1459292ee..393873149 100644 --- a/openviking/session/memory_deduplicator.py +++ b/openviking/session/memory_deduplicator.py @@ -15,7 +15,7 @@ from typing import Dict, List, Optional from openviking.core.context import Context -from openviking.models.embedder.base import EmbedResult +from openviking.models.embedder.base import EmbedResult, embed_compat from openviking.prompts import render_prompt from openviking.server.identity import RequestContext from openviking.storage import VikingDBManager @@ -151,7 +151,7 @@ async def _find_similar_memories( # Generate embedding for candidate query_text = f"{candidate.abstract} {candidate.content}" - embed_result: EmbedResult = self.embedder.embed(query_text, is_query=True) + embed_result: EmbedResult = await embed_compat(self.embedder, query_text, is_query=True) query_vector = embed_result.dense_vector category_uri_prefix = self._category_uri_prefix(candidate.category.value, candidate.user) @@ -439,7 +439,7 @@ def _cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float: if len(vec_a) != len(vec_b): return 0.0 - dot = sum(a * b for a, b in zip(vec_a, vec_b)) + dot = sum(a * b for a, b in zip(vec_a, vec_b, strict=False)) mag_a = sum(a * a for a in vec_a) ** 0.5 
mag_b = sum(b * b for b in vec_b) ** 0.5 diff --git a/openviking/storage/collection_schemas.py b/openviking/storage/collection_schemas.py index ae09e8219..b2c41db53 100644 --- a/openviking/storage/collection_schemas.py +++ b/openviking/storage/collection_schemas.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional -from openviking.models.embedder.base import EmbedResult +from openviking.models.embedder.base import EmbedResult, embed_compat from openviking.server.identity import RequestContext, Role from openviking.storage.errors import CollectionNotFoundError from openviking.storage.queuefs.embedding_msg import EmbeddingMsg @@ -279,13 +279,11 @@ async def on_dequeue(self, data: Optional[Dict[str, Any]]) -> Optional[Dict[str, # Generate embedding vector(s) if self._embedder: try: - # embed() is a blocking HTTP call; offload to thread pool to avoid - # blocking the event loop and allow real concurrency. import time as _time _embed_t0 = _time.monotonic() - result: EmbedResult = await asyncio.to_thread( - self._embedder.embed, embedding_msg.message + result: EmbedResult = await embed_compat( + self._embedder, embedding_msg.message ) _embed_elapsed = _time.monotonic() - _embed_t0 try: diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 2392198c9..9b7e40a96 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -333,6 +333,11 @@ def _create_embedder( raise ValueError("LiteLLM is not installed. 
Install it with: pip install litellm") # Factory registry: (provider, type) -> (embedder_class, param_builder) + runtime_config = { + "max_retries": self.max_retries, + "max_concurrent": self.max_concurrent, + } + factory_registry = { ("openai", "dense"): ( OpenAIDenseEmbedder, @@ -344,7 +349,7 @@ def _create_embedder( "api_version": cfg.api_version, "dimension": cfg.dimension, "provider": "openai", - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), @@ -359,7 +364,7 @@ def _create_embedder( "api_version": cfg.api_version, "dimension": cfg.dimension, "provider": "azure", - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), @@ -373,7 +378,7 @@ def _create_embedder( "api_base": cfg.api_base, "dimension": cfg.dimension, "input_type": cfg.input, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), }, ), ("volcengine", "sparse"): ( @@ -382,7 +387,7 @@ def _create_embedder( "model_name": cfg.model, "api_key": cfg.api_key, "api_base": cfg.api_base, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), }, ), ("volcengine", "hybrid"): ( @@ -393,7 +398,7 @@ def _create_embedder( "api_base": cfg.api_base, "dimension": cfg.dimension, "input_type": cfg.input, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), }, ), ("vikingdb", "dense"): ( @@ -407,7 +412,7 @@ def _create_embedder( "host": cfg.host, "dimension": cfg.dimension, "input_type": cfg.input, - "config": {"max_retries": self.max_retries}, + "config": 
dict(runtime_config), }, ), ("vikingdb", "sparse"): ( @@ -419,7 +424,7 @@ def _create_embedder( "sk": cfg.sk, "region": cfg.region, "host": cfg.host, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), }, ), ("vikingdb", "hybrid"): ( @@ -433,7 +438,7 @@ def _create_embedder( "host": cfg.host, "dimension": cfg.dimension, "input_type": cfg.input, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), }, ), ("jina", "dense"): ( @@ -443,7 +448,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), }, @@ -454,7 +459,7 @@ def _create_embedder( "model_name": cfg.model, "api_key": cfg.api_key, "dimension": cfg.dimension, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), }, @@ -468,7 +473,7 @@ def _create_embedder( or "no-key", # Ollama ignores the key, but client requires non-empty "api_base": cfg.api_base or "http://localhost:11434/v1", "dimension": cfg.dimension, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), }, ), ("voyage", "dense"): ( @@ -478,7 +483,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), }, ), ("minimax", "dense"): ( @@ -488,7 +493,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": 
cfg.document_param} if cfg.document_param else {}), **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}), @@ -501,6 +506,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, + "config": dict(runtime_config), }, ), ("litellm", "dense"): ( @@ -510,7 +516,7 @@ def _create_embedder( "api_key": cfg.api_key, "api_base": cfg.api_base, "dimension": cfg.dimension, - "config": {"max_retries": self.max_retries}, + "config": dict(runtime_config), **({"query_param": cfg.query_param} if cfg.query_param else {}), **({"document_param": cfg.document_param} if cfg.document_param else {}), **({"extra_headers": cfg.extra_headers} if cfg.extra_headers else {}),