diff --git a/src/config.ts b/src/config.ts index b9545a5..bf80a0a 100644 --- a/src/config.ts +++ b/src/config.ts @@ -8,6 +8,12 @@ import type { FallbackConfig, } from "./types.js"; +function safeParseInt(value: string | undefined, fallback: number): number { + if (!value) return fallback; + const parsed = parseInt(value, 10); + return Number.isNaN(parsed) ? fallback : parsed; +} + const DATA_DIR = join(homedir(), ".agentmemory"); const ENV_FILE = join(DATA_DIR, ".env"); @@ -71,14 +77,11 @@ export function loadConfig(): AgentMemoryConfig { return { engineUrl: env["III_ENGINE_URL"] || "ws://localhost:49134", - restPort: parseInt(env["III_REST_PORT"] || "3111", 10), - streamsPort: parseInt(env["III_STREAMS_PORT"] || "3112", 10), + restPort: parseInt(env["III_REST_PORT"] || "3111", 10) || 3111, + streamsPort: parseInt(env["III_STREAMS_PORT"] || "3112", 10) || 3112, provider, - tokenBudget: parseInt(env["TOKEN_BUDGET"] || "2000", 10), - maxObservationsPerSession: parseInt( - env["MAX_OBS_PER_SESSION"] || "500", - 10, - ), + tokenBudget: safeParseInt(env["TOKEN_BUDGET"], 2000), + maxObservationsPerSession: safeParseInt(env["MAX_OBS_PER_SESSION"], 500), compressionModel: provider.model, dataDir: DATA_DIR, }; diff --git a/src/functions/auto-forget.ts b/src/functions/auto-forget.ts index 510c51a..ce90d83 100644 --- a/src/functions/auto-forget.ts +++ b/src/functions/auto-forget.ts @@ -52,9 +52,13 @@ export function registerAutoForgetFunction(sdk: ISdk, kv: StateKV): void { } } - const latestMemories = memories.filter( - (m) => m.isLatest !== false && !deletedIds.has(m.id), - ); + const latestMemories = memories + .filter((m) => m.isLatest !== false && !deletedIds.has(m.id)) + .sort( + (a, b) => + new Date(b.createdAt).getTime() - new Date(a.createdAt).getTime(), + ) + .slice(0, 1000); for (let i = 0; i < latestMemories.length; i++) { for (let j = i + 1; j < latestMemories.length; j++) { const sim = jaccardSimilarity( diff --git a/src/functions/compress.ts b/src/functions/compress.ts index ead2775..8e8bf10 100644 --- a/src/functions/compress.ts +++ b/src/functions/compress.ts @@ -175,7 +175,7 @@ export function registerCompressFunction( obsId: data.observationId, error: msg, }); - return { success: false, error: msg }; + return { success: false, error: "compression_failed" }; } }, ); diff --git a/src/functions/export-import.ts b/src/functions/export-import.ts index 62b5a07..ac9dc41 100644 --- a/src/functions/export-import.ts +++ b/src/functions/export-import.ts @@ -85,6 +85,76 @@ export function registerExportImportFunction(sdk: ISdk, kv: StateKV): void { }; } + const MAX_SESSIONS = 10_000; + const MAX_MEMORIES = 50_000; + const MAX_SUMMARIES = 10_000; + const MAX_OBS_PER_SESSION = 5_000; + const MAX_TOTAL_OBSERVATIONS = 500_000; + + if (!Array.isArray(importData.sessions)) { + return { success: false, error: "sessions must be an array" }; + } + if (!Array.isArray(importData.memories)) { + return { success: false, error: "memories must be an array" }; + } + if (!Array.isArray(importData.summaries)) { + return { success: false, error: "summaries must be an array" }; + } + if ( + typeof importData.observations !== "object" || + importData.observations === null || + Array.isArray(importData.observations) + ) { + return { success: false, error: "observations must be an object" }; + } + + if (importData.sessions.length > MAX_SESSIONS) { + return { + success: false, + error: `Too many sessions (max ${MAX_SESSIONS})`, + }; + } + if (importData.memories.length > MAX_MEMORIES) { + return { + success: false, + error: `Too many memories (max ${MAX_MEMORIES})`, + }; + } + if (importData.summaries.length > MAX_SUMMARIES) { + return { + success: false, + error: `Too many summaries (max ${MAX_SUMMARIES})`, + }; + } + const MAX_OBS_BUCKETS = 10_000; + const obsBuckets = Object.keys(importData.observations); + if (obsBuckets.length > MAX_OBS_BUCKETS) { + return { + success: false, + error: `Too many observation buckets (max ${MAX_OBS_BUCKETS})`, + }; + } + + let totalObservations = 0; + for (const [, obs] of Object.entries(importData.observations)) { + if (!Array.isArray(obs)) { + return { success: false, error: "observation values must be arrays" }; + } + if (obs.length > MAX_OBS_PER_SESSION) { + return { + success: false, + error: `Too many observations per session (max ${MAX_OBS_PER_SESSION})`, + }; + } + totalObservations += obs.length; + } + if (totalObservations > MAX_TOTAL_OBSERVATIONS) { + return { + success: false, + error: `Too many total observations (max ${MAX_TOTAL_OBSERVATIONS})`, + }; + } + const stats = { sessions: 0, observations: 0, diff --git a/src/functions/migrate.ts b/src/functions/migrate.ts index f20c6bd..cb7f28a 100644 --- a/src/functions/migrate.ts +++ b/src/functions/migrate.ts @@ -10,10 +10,7 @@ import type { SessionSummary, } from "../types.js"; -const ALLOWED_DIRS = [ - resolve(homedir(), ".agentmemory"), - resolve(homedir(), ".claude"), -]; +const ALLOWED_DIRS = [resolve(homedir(), ".agentmemory")]; function isAllowedPath(dbPath: string): boolean { const resolved = resolve(dbPath); @@ -149,7 +146,7 @@ export function registerMigrateFunction(sdk: ISdk, kv: StateKV): void { } catch (err) { const msg = err instanceof Error ? err.message : String(err); ctx.logger.error("Migration failed", { error: msg }); - return { success: false, error: msg }; + return { success: false, error: "Migration failed" }; } }, ); diff --git a/src/functions/observe.ts b/src/functions/observe.ts index 389c0bf..1fa7e18 100644 --- a/src/functions/observe.ts +++ b/src/functions/observe.ts @@ -18,6 +18,22 @@ export function registerObserveFunction( }, async (payload: HookPayload) => { const ctx = getContext(); + + if ( + !payload?.sessionId || + typeof payload.sessionId !== "string" || + !payload.hookType || + typeof payload.hookType !== "string" || + !payload.timestamp || + typeof payload.timestamp !== "string" + ) { + return { + success: false, + error: + "Invalid payload: sessionId, hookType, and timestamp are required", + }; + } + const obsId = generateId("obs"); if (dedupMap) { diff --git a/src/functions/privacy.ts b/src/functions/privacy.ts index e93b506..3237356 100644 --- a/src/functions/privacy.ts +++ b/src/functions/privacy.ts @@ -1,30 +1,36 @@ -import type { ISdk } from 'iii-sdk' +import type { ISdk } from "iii-sdk"; -const PRIVATE_TAG_RE = /[\s\S]*?<\/private>/gi +const PRIVATE_TAG_RE = /[\s\S]*?<\/private>/gi; const SECRET_PATTERN_SOURCES = [ /(?:api[_-]?key|secret|token|password|credential|auth)[\s]*[=:]\s*["']?[A-Za-z0-9_\-/.+]{20,}["']?/gi, /(?:sk|pk|rk|ak)-[A-Za-z0-9]{20,}/g, + /sk-ant-[A-Za-z0-9\-_]{20,}/g, /ghp_[A-Za-z0-9]{36}/g, + /github_pat_[A-Za-z0-9_]{22,}/g, /xoxb-[A-Za-z0-9\-]+/g, /AKIA[0-9A-Z]{16}/g, + /AIza[A-Za-z0-9\-_]{35}/g, /eyJ[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}/g, -] +]; export function stripPrivateData(input: string): string { - let result = input.replace(PRIVATE_TAG_RE, '[REDACTED]') + let result = input.replace(PRIVATE_TAG_RE, "[REDACTED]"); for (const source of SECRET_PATTERN_SOURCES) { - const pattern = new RegExp(source.source, source.flags) - result = result.replace(pattern, '[REDACTED_SECRET]') + const pattern = new RegExp(source.source, source.flags); + result = result.replace(pattern, "[REDACTED_SECRET]"); } - return result + return result; } export function registerPrivacyFunction(sdk: ISdk): void { sdk.registerFunction( - { id: 'mem::privacy', description: 'Strip private tags and secrets from input' }, + { + id: "mem::privacy", + description: "Strip private tags and secrets from input", + }, async (data: { input: string }) => { - return { output: stripPrivateData(data.input) } - } - ) + return { output: stripPrivateData(data.input) }; + }, + ); } diff --git a/src/functions/relations.ts b/src/functions/relations.ts index 910c9ea..dcb6af7 100644 --- a/src/functions/relations.ts +++ b/src/functions/relations.ts @@ -112,7 +112,12 @@ export function registerRelationsFunction(sdk: ISdk, kv: StateKV): void { }, async (data: { memoryId: string; maxHops?: number }) => { const ctx = getContext(); - const maxHops = data.maxHops ?? 2; + const maxHops = Math.min(data.maxHops ?? 2, 5); + const MAX_VISITED = 500; + + const allRelations = await kv + .list(KV.relations) + .catch(() => []); const visited = new Set(); const result: Array<{ memory: Memory; hop: number }> = []; @@ -120,7 +125,7 @@ export function registerRelationsFunction(sdk: ISdk, kv: StateKV): void { { id: data.memoryId, hop: 0 }, ]; - while (queue.length > 0) { + while (queue.length > 0 && visited.size < MAX_VISITED) { const current = queue.shift()!; if (visited.has(current.id) || current.hop > maxHops) continue; visited.add(current.id); @@ -136,10 +141,7 @@ export function registerRelationsFunction(sdk: ISdk, kv: StateKV): void { const supersedes = memory.supersedes || []; const parentId = memory.parentId ? [memory.parentId] : []; - const kvRelations = await kv - .list(KV.relations) - .catch(() => []); - const kvLinked = kvRelations + const kvLinked = allRelations .filter((r) => r.sourceId === current.id || r.targetId === current.id) .map((r) => (r.sourceId === current.id ? r.targetId : r.sourceId)); diff --git a/src/functions/remember.ts b/src/functions/remember.ts index 739f633..1f42da7 100644 --- a/src/functions/remember.ts +++ b/src/functions/remember.ts @@ -14,9 +14,19 @@ export function registerRememberFunction(sdk: ISdk, kv: StateKV): void { files?: string[]; }) => { const ctx = getContext(); - if (!data.content || !data.content.trim()) { + if ( + !data.content || + typeof data.content !== "string" || + !data.content.trim() + ) { return { success: false, error: "content is required" }; } + if (data.files && !Array.isArray(data.files)) { + return { success: false, error: "files must be an array" }; + } + if (data.concepts && !Array.isArray(data.concepts)) { + return { success: false, error: "concepts must be an array" }; + } const validTypes = new Set([ "pattern", "preference", diff --git a/src/functions/smart-search.ts b/src/functions/smart-search.ts index 7c59c6e..d525d6a 100644 --- a/src/functions/smart-search.ts +++ b/src/functions/smart-search.ts @@ -23,13 +23,14 @@ export function registerSmartSearchFunction( const ctx = getContext(); if (data.expandIds && data.expandIds.length > 0) { + const ids = data.expandIds.slice(0, 20); const expanded: Array<{ obsId: string; sessionId: string; observation: CompressedObservation; }> = []; - for (const obsId of data.expandIds) { + for (const obsId of ids) { const obs = await findObservation(kv, obsId); if (obs) { expanded.push({ @@ -40,11 +41,14 @@ export function registerSmartSearchFunction( } } + const truncated = data.expandIds.length > ids.length; ctx.logger.info("Smart search expanded", { requested: data.expandIds.length, - found: expanded.length, + attempted: ids.length, + returned: expanded.length, + truncated, }); - return { mode: "expanded", results: expanded }; + return { mode: "expanded", results: expanded, truncated }; } if (!data.query || typeof data.query !== "string" || !data.query.trim()) { diff --git a/src/health/recovery.ts b/src/health/recovery.ts deleted file mode 100644 index 34be8bb..0000000 --- a/src/health/recovery.ts +++ /dev/null @@ -1,22 +0,0 @@ -import type { HookPayload } from "../types.js"; - -const MAX_QUEUE_SIZE = 500; - -export class ObservationQueue { - private queue: HookPayload[] = []; - - enqueue(payload: HookPayload): boolean { - if (this.queue.length >= MAX_QUEUE_SIZE) return false; - this.queue.push(payload); - return true; - } - - drain(): HookPayload[] { - const items = this.queue.splice(0); - return items; - } - - get size(): number { - return this.queue.length; - } -} diff --git a/src/index.ts b/src/index.ts index 90c120f..ac3f6b7 100644 --- a/src/index.ts +++ b/src/index.ts @@ -140,19 +140,20 @@ async function main() { console.warn(`[agentmemory] Failed to load persisted index:`, err); return null; }); - if (loaded?.bm25) { - const restoredCount = loaded.bm25.size; - if (restoredCount > 0) { - console.log( - `[agentmemory] Loaded persisted BM25 index (${restoredCount} docs)`, - ); - } + if (loaded?.bm25 && loaded.bm25.size > 0) { + bm25Index.restoreFrom(loaded.bm25); + console.log( + `[agentmemory] Loaded persisted BM25 index (${bm25Index.size} docs)`, + ); + } + if (loaded?.vector && vectorIndex && loaded.vector.size > 0) { + vectorIndex.restoreFrom(loaded.vector); + console.log( + `[agentmemory] Loaded persisted vector index (${vectorIndex.size} vectors)`, + ); } - const needsRebuild = - !loaded?.bm25 || - loaded.bm25.size === 0 || - (embeddingProvider && vectorIndex && vectorIndex.size === 0); + const needsRebuild = bm25Index.size === 0; if (needsRebuild) { const indexCount = await rebuildIndex(kv).catch((err) => { diff --git a/src/mcp/server.ts b/src/mcp/server.ts index 306633b..de00b0b 100644 --- a/src/mcp/server.ts +++ b/src/mcp/server.ts @@ -230,10 +230,11 @@ export function registerMcpEndpoints( sdk.registerFunction( { id: "mcp::tools::list" }, - async (): Promise => ({ - status_code: 200, - body: { tools: MCP_TOOLS }, - }), + async (req: ApiRequest): Promise => { + const authErr = checkAuth(req, secret); + if (authErr) return authErr; + return { status_code: 200, body: { tools: MCP_TOOLS } }; + }, ); sdk.registerTrigger({ type: "http", @@ -375,6 +376,7 @@ export function registerMcpEndpoints( ? (args.expandIds as string) .split(",") .map((id: string) => id.trim()) + .slice(0, 20) : []; const result = await sdk.trigger("mem::smart-search", { query: args.query, @@ -478,7 +480,7 @@ export function registerMcpEndpoints( return { status_code: 500, body: { - error: err instanceof Error ? err.message : "Internal error", + error: "Internal error", }, }; } diff --git a/src/state/search-index.ts b/src/state/search-index.ts index f9f08f8..2427b58 100644 --- a/src/state/search-index.ts +++ b/src/state/search-index.ts @@ -117,6 +117,25 @@ export class SearchIndex { this.totalDocLength = 0; } + restoreFrom(other: SearchIndex): void { + this.entries = new Map( + Array.from(other.entries.entries()).map(([k, v]) => [k, { ...v }]), + ); + this.invertedIndex = new Map( + Array.from(other.invertedIndex.entries()).map(([k, v]) => [ + k, + new Set(v), + ]), + ); + this.docTermCounts = new Map( + Array.from(other.docTermCounts.entries()).map(([k, v]) => [ + k, + new Map(v), + ]), + ); + this.totalDocLength = other.totalDocLength; + } + serialize(): string { const entries = Array.from(this.entries.entries()); const inverted = Array.from(this.invertedIndex.entries()).map( @@ -135,19 +154,26 @@ export class SearchIndex { } static deserialize(json: string): SearchIndex { - const idx = new SearchIndex(); - const data = JSON.parse(json); - for (const [key, val] of data.entries) { - idx.entries.set(key, val); - } - for (const [term, ids] of data.inverted) { - idx.invertedIndex.set(term, new Set(ids)); - } - for (const [id, counts] of data.docTerms) { - idx.docTermCounts.set(id, new Map(counts)); + try { + const idx = new SearchIndex(); + const data = JSON.parse(json); + if (!data?.entries || !data?.inverted || !data?.docTerms) return idx; + for (const [key, val] of data.entries) { + idx.entries.set(key, val); + } + for (const [term, ids] of data.inverted) { + idx.invertedIndex.set(term, new Set(ids)); + } + for (const [id, counts] of data.docTerms) { + idx.docTermCounts.set(id, new Map(counts)); + } + const rawLen = Number(data.totalDocLength); + idx.totalDocLength = + Number.isFinite(rawLen) && rawLen >= 0 ? Math.floor(rawLen) : 0; + return idx; + } catch { + return new SearchIndex(); } - idx.totalDocLength = data.totalDocLength; - return idx; } private extractTerms(obs: CompressedObservation): string[] { diff --git a/src/state/vector-index.ts b/src/state/vector-index.ts index 1bd0787..7a8f442 100644 --- a/src/state/vector-index.ts +++ b/src/state/vector-index.ts @@ -21,10 +21,8 @@ function cosineSimilarity(a: Float32Array, b: Float32Array): number { } export class VectorIndex { - private vectors: Map< - string, - { embedding: Float32Array; sessionId: string } - > = new Map(); + private vectors: Map = + new Map(); add(obsId: string, sessionId: string, embedding: Float32Array): void { this.vectors.set(obsId, { embedding, sessionId }); @@ -61,10 +59,22 @@ export class VectorIndex { this.vectors.clear(); } + restoreFrom(other: VectorIndex): void { + const src = (other as any).vectors as Map< + string, + { embedding: Float32Array; sessionId: string } + >; + this.vectors = new Map(); + for (const [obsId, entry] of src) { + this.vectors.set(obsId, { + embedding: new Float32Array(entry.embedding), + sessionId: entry.sessionId, + }); + } + } + serialize(): string { - const data: Array< - [string, { embedding: string; sessionId: string }] - > = []; + const data: Array<[string, { embedding: string; sessionId: string }]> = []; for (const [obsId, entry] of this.vectors) { data.push([ obsId, @@ -79,14 +89,30 @@ export class VectorIndex { static deserialize(json: string): VectorIndex { const idx = new VectorIndex(); - const data: Array< - [string, { embedding: string; sessionId: string }] - > = JSON.parse(json); - for (const [obsId, entry] of data) { - idx.vectors.set(obsId, { - embedding: base64ToFloat32(entry.embedding), - sessionId: entry.sessionId, - }); + let data: unknown; + try { + data = JSON.parse(json); + } catch { + return idx; + } + if (!Array.isArray(data)) return idx; + for (const row of data) { + try { + if (!Array.isArray(row) || row.length < 2) continue; + const [obsId, entry] = row; + if ( + typeof obsId !== "string" || + typeof entry?.embedding !== "string" || + typeof entry?.sessionId !== "string" + ) + continue; + idx.vectors.set(obsId, { + embedding: base64ToFloat32(entry.embedding), + sessionId: entry.sessionId, + }); + } catch { + continue; + } } return idx; } diff --git a/src/triggers/api.ts b/src/triggers/api.ts index 20078ee..3d555d8 100644 --- a/src/triggers/api.ts +++ b/src/triggers/api.ts @@ -16,7 +16,7 @@ type Response = { }; const VIEWER_CSP = - "default-src 'self'; script-src 'unsafe-inline'; style-src 'unsafe-inline'; connect-src 'self' ws://localhost:* wss://localhost:*"; + "default-src 'none'; script-src 'unsafe-inline'; style-src 'unsafe-inline'; connect-src 'self' ws://localhost:* wss://localhost:*; img-src 'self'; font-src 'self'"; function checkAuth( req: ApiRequest, @@ -37,27 +37,46 @@ export function registerApiTriggers( metricsStore?: MetricsStore, provider?: ResilientProvider | { circuitState?: unknown }, ): void { - sdk.registerFunction({ id: "api::health" }, async (): Promise => { - const health = await getLatestHealth(kv); - const functionMetrics = metricsStore ? await metricsStore.getAll() : []; - const circuitBreaker = - provider && "circuitState" in provider ? provider.circuitState : null; - - const status = health?.status || "healthy"; - const statusCode = status === "critical" ? 503 : 200; - - return { - status_code: statusCode, - body: { - status, - service: "agentmemory", - version: "0.3.0", - health: health || null, - functionMetrics, - circuitBreaker, - }, - }; + sdk.registerFunction( + { id: "api::liveness" }, + async (): Promise => ({ + status_code: 200, + body: { status: "ok", service: "agentmemory" }, + }), + ); + sdk.registerTrigger({ + type: "http", + function_id: "api::liveness", + config: { api_path: "/agentmemory/livez", http_method: "GET" }, }); + + sdk.registerFunction( + { id: "api::health" }, + async (req: ApiRequest): Promise => { + const authErr = checkAuth(req, secret); + if (authErr) return authErr; + + const health = await getLatestHealth(kv); + const functionMetrics = metricsStore ? await metricsStore.getAll() : []; + const circuitBreaker = + provider && "circuitState" in provider ? provider.circuitState : null; + + const status = health?.status || "healthy"; + const statusCode = status === "critical" ? 503 : 200; + + return { + status_code: statusCode, + body: { + status, + service: "agentmemory", + version: "0.3.0", + health: health || null, + functionMetrics, + circuitBreaker, + }, + }; + }, + ); sdk.registerTrigger({ type: "http", function_id: "api::health", @@ -183,10 +202,15 @@ export function registerApiTriggers( config: { api_path: "/agentmemory/summarize", http_method: "POST" }, }); - sdk.registerFunction({ id: "api::sessions" }, async (): Promise => { - const sessions = await kv.list(KV.sessions); - return { status_code: 200, body: { sessions } }; - }); + sdk.registerFunction( + { id: "api::sessions" }, + async (req: ApiRequest): Promise => { + const authErr = checkAuth(req, secret); + if (authErr) return authErr; + const sessions = await kv.list(KV.sessions); + return { status_code: 200, body: { sessions } }; + }, + ); sdk.registerTrigger({ type: "http", function_id: "api::sessions", @@ -196,6 +220,8 @@ export function registerApiTriggers( sdk.registerFunction( { id: "api::observations" }, async (req: ApiRequest): Promise => { + const authErr = checkAuth(req, secret); + if (authErr) return authErr; const sessionId = req.query_params["sessionId"] as string; if (!sessionId) return { status_code: 400, body: { error: "sessionId required" } };