diff --git a/src/functions/search.ts b/src/functions/search.ts index 74af9ff1..b4444b48 100644 --- a/src/functions/search.ts +++ b/src/functions/search.ts @@ -86,6 +86,99 @@ export async function vectorIndexAddGuarded( } } +// Batched variant: calls EmbeddingProvider.embedBatch ONCE for the whole +// batch, then writes each resulting vector. Use this for bulk paths +// (rebuildIndex, future bulk-add APIs) where per-item serial awaits +// dominate wallclock. A batch of N has roughly the latency of a single +// embed (network + GPU setup amortized), so backfilling a 500k-obs +// corpus drops from days to hours on a per-batch endpoint like vLLM. +// +// Per-item failure shape: +// - whole-batch network/provider error → all skipped, single warn line +// - per-item dimension mismatch → that item skipped, others continue +export async function vectorIndexAddBatchGuarded( + items: Array<{ + id: string + sessionId: string + text: string + context: { kind: "memory" | "observation" | "synthetic"; logId: string } + }>, +): Promise<{ ok: number; fail: number }> { + const vi = vectorIndex + const ep = currentEmbeddingProvider + if (!vi || !ep || items.length === 0) return { ok: 0, fail: 0 } + + let embeddings: Float32Array[] + try { + embeddings = await ep.embedBatch(items.map((i) => clipEmbedInput(i.text))) + } catch (err) { + logger.warn("vector-index add batch: embed failed — skipping batch", { + batchSize: items.length, + provider: ep.name, + error: err instanceof Error ? err.message : String(err), + }) + return { ok: 0, fail: items.length } + } + + if (embeddings.length !== items.length) { + logger.warn( + "vector-index add batch: provider returned wrong length — skipping batch", + { + batchSize: items.length, + returned: embeddings.length, + provider: ep.name, + }, + ) + return { ok: 0, fail: items.length } + } + + let ok = 0 + let fail = 0 + for (let i = 0; i < items.length; i++) { + const item = items[i] + const embedding = embeddings[i] + if (embedding.length !== ep.dimensions) { + logger.warn("vector-index add batch: dimension mismatch — skipping item", { + kind: item.context.kind, + id: item.context.logId, + provider: ep.name, + expected: ep.dimensions, + received: embedding.length, + }) + fail++ + continue + } + try { + vi.add(item.id, item.sessionId, embedding) + ok++ + } catch (err) { + logger.warn("vector-index add batch: index write failed — skipping item", { + kind: item.context.kind, + id: item.context.logId, + error: err instanceof Error ? err.message : String(err), + }) + fail++ + } + } + return { ok, fail } +} + +// Embed-batch size for rebuild. Each item is one /v1/embeddings call's +// `input` array element; the provider sees the whole batch as one HTTP +// round-trip. 32 fits comfortably under typical per-request token budgets +// (32 × ~110 tok/item ≈ 3.5k tokens) and gets close to per-call +// throughput for GPU-backed endpoints (vLLM, Triton, etc.). Override via +// REBUILD_EMBED_BATCH_SIZE for endpoints that prefer smaller/larger +// batches. Set to 1 to fall back to the legacy per-item path. +const DEFAULT_REBUILD_EMBED_BATCH = 32 + +function getRebuildEmbedBatchSize(): number { + const raw = process.env.REBUILD_EMBED_BATCH_SIZE + if (!raw) return DEFAULT_REBUILD_EMBED_BATCH + const n = parseInt(raw, 10) + return Number.isFinite(n) && n > 0 ? n : DEFAULT_REBUILD_EMBED_BATCH +} + export async function rebuildIndex(kv: StateKV): Promise { const idx = getSearchIndex() idx.clear() @@ -96,8 +189,28 @@ export async function rebuildIndex(kv: StateKV): Promise { // repopulation loops run, so BM25 and vector stay in sync. vectorIndex?.clear() + const batchSize = getRebuildEmbedBatchSize() + // Accumulator for the batched embed flush. BM25 add is synchronous and + // doesn't need batching — only the vector path benefits. + type EmbedJob = { + id: string + sessionId: string + text: string + context: { kind: "memory" | "observation" | "synthetic"; logId: string } + } + const pending: EmbedJob[] = [] let count = 0 + const flush = async (): Promise => { + if (pending.length === 0) return + await vectorIndexAddBatchGuarded(pending) + pending.length = 0 + } + const enqueue = async (job: EmbedJob): Promise => { + pending.push(job) + if (pending.length >= batchSize) await flush() + } + // Memories live in their own KV scope outside per-session observation // scopes, so they need a separate walk. Without this, mem::remember // entries vanish from BM25 on every restart even after the live-write @@ -108,12 +221,12 @@ export async function rebuildIndex(kv: StateKV): Promise { if (memory.isLatest === false) continue if (!memory.title || !memory.content) continue idx.add(memoryToObservation(memory)) - await vectorIndexAddGuarded( - memory.id, - memory.sessionIds[0] ?? 'memory', - memory.title + ' ' + memory.content, - { kind: "memory", logId: memory.id }, - ) + await enqueue({ + id: memory.id, + sessionId: memory.sessionIds[0] ?? 'memory', + text: memory.title + ' ' + memory.content, + context: { kind: "memory", logId: memory.id }, + }) count++ } } catch (err) { @@ -123,7 +236,10 @@ export async function rebuildIndex(kv: StateKV): Promise { } const sessions = await kv.list(KV.sessions) - if (!sessions.length) return count + if (!sessions.length) { + await flush() + return count + } const obsPerSession: CompressedObservation[][] = [] const failedSessions: string[] = [] @@ -148,16 +264,19 @@ export async function rebuildIndex(kv: StateKV): Promise { for (const obs of observations) { if (obs.title && obs.narrative) { idx.add(obs) - await vectorIndexAddGuarded( - obs.id, - obs.sessionId, - obs.title + ' ' + obs.narrative, - { kind: "observation", logId: obs.id }, - ) + await enqueue({ + id: obs.id, + sessionId: obs.sessionId, + text: obs.title + ' ' + obs.narrative, + context: { kind: "observation", logId: obs.id }, + }) count++ } } } + + // Drain the last partial batch. + await flush() return count }