diff --git a/docs/content/docs/(core)/cortex.mdx b/docs/content/docs/(core)/cortex.mdx index 41aef7c5f..7f47c213f 100644 --- a/docs/content/docs/(core)/cortex.mdx +++ b/docs/content/docs/(core)/cortex.mdx @@ -36,7 +36,7 @@ Each bulletin generation pass does: This design avoids the problem of an LLM formulating search queries without conversation context. The retrieval phase uses `SearchMode::Typed`, `SearchMode::Recent`, and `SearchMode::Important` — metadata-based modes that query SQLite directly without needing vector embeddings or search terms. The LLM only gets involved for the part it's good at: turning structured data into readable prose. -On startup, Spacebot runs a best-effort warmup pass before adapters accept traffic (bounded wait), so the first bulletin is usually already present when the first user message arrives. If generation fails, the previous bulletin is preserved. If the memory graph is empty, an empty bulletin is stored without invoking the LLM. +On startup, Spacebot runs a best-effort warmup pass before adapters accept traffic (bounded wait), so the first bulletin is usually already present when the first user message arrives. The same pass also refreshes a warm recall cache of high-importance memories used only as degraded fallback context when hybrid branch recall fails. If generation fails, the previous bulletin is preserved. If the memory graph is empty, an empty bulletin is stored without invoking the LLM. ### What Channels See diff --git a/docs/content/docs/(core)/memory.mdx b/docs/content/docs/(core)/memory.mdx index 6577bba84..ed5781c6c 100644 --- a/docs/content/docs/(core)/memory.mdx +++ b/docs/content/docs/(core)/memory.mdx @@ -99,7 +99,7 @@ Memory recall is always delegated to a worker. No LLM process ever queries the d The `memory_recall` tool supports four search modes, each suited to different retrieval needs: -**Hybrid** (default) -- Full pipeline: vector similarity (LanceDB HNSW) + full-text search (Tantivy) + graph traversal, merged via Reciprocal Rank Fusion (RRF). Requires a query string. Best when you have a specific topic to search for and conversation context to inform the query. +**Hybrid** (default) -- Full pipeline: vector similarity (LanceDB HNSW) + full-text search (Tantivy) + graph traversal, merged via Reciprocal Rank Fusion (RRF). Requires a query string. Best when you have a specific topic to search for and conversation context to inform the query. If the hybrid search path errors, Spacebot can return a degraded fallback from a warm importance-sorted cache populated during warmup. This is availability hardening, not a replacement for normal hybrid retrieval. **Recent** -- Returns the most recent memories ordered by `created_at`. No query needed, no vector/FTS overhead. Pure SQLite. Best for temporal awareness -- "what just happened?" diff --git a/docs/content/docs/(deployment)/roadmap.mdx b/docs/content/docs/(deployment)/roadmap.mdx index 61afccb9e..cc110b929 100644 --- a/docs/content/docs/(deployment)/roadmap.mdx +++ b/docs/content/docs/(deployment)/roadmap.mdx @@ -21,6 +21,7 @@ The full message-in → LLM → response-out pipeline is wired end-to-end across - **LLM** — `SpacebotModel` implements Rig's `CompletionModel`, routes through `LlmManager` via HTTP with retries and fallback chains across 13 providers (Anthropic, OpenAI, OpenRouter, Kilo Gateway, Z.ai, Groq, Together, Fireworks, DeepSeek, xAI, Mistral, OpenCode Zen, OpenCode Go) - **Model routing** — `RoutingConfig` with process-type defaults, task overrides, fallback chains - **Memory** — full stack: types, SQLite store (CRUD + graph), LanceDB (embeddings + vector + FTS), fastembed, hybrid search (RRF fusion). `memory_type` filter wired end-to-end through SearchConfig. `total_cmp` for safe sorting. +- **Warm recall degraded fallback** — warmup refreshes an importance-sorted cache that branch recall can use when the hybrid search path fails, with epoch/lock coordination to avoid reintroducing forgotten memories during concurrent cache mutation. - **Memory maintenance** — decay + prune implemented - **Identity** — `Identity` struct loads SOUL.md/IDENTITY.md/ROLE.md from agent root, `Prompts` with fallback chain - **Agent loops** — all three process types run real Rig loops: diff --git a/src/agent/cortex.rs b/src/agent/cortex.rs index 8e350e6e1..e7e3a9efa 100644 --- a/src/agent/cortex.rs +++ b/src/agent/cortex.rs @@ -36,6 +36,8 @@ use std::sync::atomic::{AtomicU8, Ordering}; use std::time::{Duration, Instant}; use tokio::sync::{RwLock, broadcast}; +const WARM_RECALL_MEMORY_LIMIT: usize = 50; + fn update_warmup_status(deps: &AgentDeps, update: F) where F: FnOnce(&mut crate::config::WarmupStatus), @@ -1528,6 +1530,19 @@ pub async fn run_warmup_once(deps: &AgentDeps, logger: &CortexLogger, reason: &s errors.push("bulletin generation failed".to_string()); } + let warm_recall_count = match refresh_warm_recall_memories(deps).await { + Ok(count) => Some(count), + Err(error) => { + let refresh_error = format!("warm recall refresh failed: {error}"); + tracing::warn!( + error = %refresh_error, + "warm recall refresh failed during warmup, keeping previous cache" + ); + errors.push(refresh_error); + None + } + }; + let now_ms = chrono::Utc::now().timestamp_millis(); if errors.is_empty() { update_warmup_status(deps, |status| { @@ -1544,6 +1559,7 @@ pub async fn run_warmup_once(deps: &AgentDeps, logger: &CortexLogger, reason: &s Some(serde_json::json!({ "reason": reason, "embedding_ready": embedding_ready, + "warm_recall_count": warm_recall_count, "forced": force, })), ); @@ -1562,12 +1578,76 @@ pub async fn run_warmup_once(deps: &AgentDeps, logger: &CortexLogger, reason: &s Some(serde_json::json!({ "reason": reason, "errors": errors, + "warm_recall_count": warm_recall_count, "forced": force, })), ); } } +async fn refresh_warm_recall_memories(deps: &AgentDeps) -> Result { + let start_epoch = deps + .runtime_config + .warm_recall_cache_epoch + .load(std::sync::atomic::Ordering::SeqCst); + let config = SearchConfig { + mode: SearchMode::Important, + sort_by: SearchSort::Importance, + max_results: WARM_RECALL_MEMORY_LIMIT, + max_results_per_source: WARM_RECALL_MEMORY_LIMIT, + ..Default::default() + }; + + let results = deps.memory_search.search("", &config).await?; + let memories = results + .into_iter() + .map(|result| result.memory) + .collect::>(); + + Ok(apply_warm_recall_refresh( + deps.runtime_config.as_ref(), + start_epoch, + memories, + chrono::Utc::now().timestamp_millis(), + ) + .await) +} + +async fn apply_warm_recall_refresh( + runtime_config: &crate::config::RuntimeConfig, + start_epoch: u64, + memories: Vec, + refreshed_at_unix_ms: i64, +) -> usize { + let count = memories.len(); + let _warm_recall_cache_guard = runtime_config.warm_recall_cache_lock.lock().await; + let current_epoch = runtime_config + .warm_recall_cache_epoch + .load(std::sync::atomic::Ordering::SeqCst); + if current_epoch != start_epoch { + let retained_count = runtime_config.warm_recall_memories.load().len(); + tracing::debug!( + start_epoch, + current_epoch, + retained_count, + "skipping warm recall refresh apply due to concurrent cache mutation" + ); + return retained_count; + } + + runtime_config + .warm_recall_memories + .store(Arc::new(memories)); + runtime_config + .warm_recall_refreshed_at_unix_ms + .store(Arc::new(Some(refreshed_at_unix_ms))); + runtime_config + .warm_recall_cache_epoch + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + + count +} + /// Trigger a forced warmup pass in the background from a dispatch path. /// /// This helper never blocks the caller. It is intended for readiness guards on @@ -3253,9 +3333,10 @@ mod tests { use super::{ BULLETIN_REFRESH_CIRCUIT_OPEN_SECS, BULLETIN_REFRESH_CIRCUIT_OPEN_THRESHOLD, BranchTracker, BulletinRefreshOutcome, CortexReceiverOutcome, HealthRuntimeState, ReceiverClosedBehavior, - Signal, WorkerTracker, apply_cancelled_warmup_status, build_kill_targets, - claim_detached_completion, detached_timeout_transition, handle_cortex_receiver_result, - has_completed_initial_warmup, is_cancelled_control_result, is_terminal_control_result, + Signal, WARM_RECALL_MEMORY_LIMIT, WorkerTracker, apply_cancelled_warmup_status, + apply_warm_recall_refresh, build_kill_targets, claim_detached_completion, + detached_timeout_transition, handle_cortex_receiver_result, has_completed_initial_warmup, + is_cancelled_control_result, is_terminal_control_result, maybe_close_bulletin_refresh_circuit, maybe_generate_bulletin_under_lock, parse_structured_success_flag, push_signal_into_buffer, record_bulletin_refresh_failure, should_execute_warmup, should_generate_bulletin_from_bulletin_loop, signal_from_event, @@ -3263,17 +3344,66 @@ mod tests { }; use crate::ProcessEvent; use crate::agent::process_control::ControlActionResult; + use crate::config::{Config, RuntimeConfig}; + use crate::identity::Identity; use crate::memory::MemoryType; + use crate::memory::search::{SearchConfig, SearchMode, SearchSort}; + use crate::memory::{EmbeddingModel, EmbeddingTable, Memory, MemorySearch, MemoryStore}; + use crate::prompts::PromptEngine; + use crate::skills::SkillSet; use crate::tasks::TaskStatus; use crate::tasks::TaskStore; + use crate::tools::{MemoryDeleteArgs, MemoryDeleteTool}; use futures::FutureExt; use futures::future; + use rig::tool::Tool as _; use sqlx::sqlite::SqlitePoolOptions; use std::collections::VecDeque; use std::sync::Arc; use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering}; use std::time::{Duration, Instant}; + fn test_runtime_config(instance_dir: &std::path::Path) -> Arc { + let config = Config::load_from_env(instance_dir).expect("failed to build config"); + let resolved = config + .resolve_agents() + .into_iter() + .next() + .expect("missing resolved agent config"); + let prompts = PromptEngine::new("en").expect("failed to build prompt engine"); + + Arc::new(RuntimeConfig::new( + instance_dir, + &resolved, + &config.defaults, + prompts, + Identity::default(), + SkillSet::default(), + )) + } + + async fn test_memory_search() -> Arc { + let store = MemoryStore::connect_in_memory().await; + let lance_dir = tempfile::tempdir().expect("failed to create lance temp dir"); + let lance_conn = lancedb::connect( + lance_dir + .path() + .to_str() + .expect("temp path should be valid UTF-8"), + ) + .execute() + .await + .expect("failed to connect lancedb"); + let embedding_table = EmbeddingTable::open_or_create(&lance_conn) + .await + .expect("failed to open embedding table"); + let embedding_model = Arc::new( + EmbeddingModel::new(lance_dir.path()).expect("failed to init embedding model"), + ); + + Arc::new(MemorySearch::new(store, embedding_table, embedding_model)) + } + #[test] fn run_warmup_once_semantics_skip_when_disabled_without_force() { let warmup_config = crate::config::WarmupConfig { @@ -3467,6 +3597,187 @@ mod tests { assert_eq!(calls.load(Ordering::SeqCst), 0); } + #[tokio::test] + async fn apply_warm_recall_refresh_skips_stale_results_after_concurrent_epoch_change() { + let temp_dir = tempfile::tempdir().expect("failed to create temp dir"); + let runtime_config = test_runtime_config(temp_dir.path()); + let retained = Memory::new("retained cache entry", MemoryType::Fact).with_importance(0.9); + + runtime_config + .warm_recall_memories + .store(Arc::new(vec![retained.clone()])); + runtime_config + .warm_recall_refreshed_at_unix_ms + .store(Arc::new(Some(111))); + runtime_config + .warm_recall_cache_epoch + .fetch_add(1, Ordering::SeqCst); + + let applied_count = apply_warm_recall_refresh( + runtime_config.as_ref(), + 0, + vec![Memory::new("stale refresh result", MemoryType::Fact)], + 222, + ) + .await; + + assert_eq!(applied_count, 1); + assert_eq!( + runtime_config.warm_recall_memories.load().as_ref(), + &vec![retained] + ); + assert_eq!( + *runtime_config + .warm_recall_refreshed_at_unix_ms + .load() + .as_ref(), + Some(111) + ); + assert_eq!( + runtime_config + .warm_recall_cache_epoch + .load(Ordering::SeqCst), + 1 + ); + } + + #[tokio::test] + async fn apply_warm_recall_refresh_overwrites_cache_when_epoch_is_stable() { + let temp_dir = tempfile::tempdir().expect("failed to create temp dir"); + let runtime_config = test_runtime_config(temp_dir.path()); + let fresh = Memory::new("fresh refresh result", MemoryType::Decision).with_importance(0.8); + + let applied_count = + apply_warm_recall_refresh(runtime_config.as_ref(), 0, vec![fresh.clone()], 333).await; + + assert_eq!(applied_count, 1); + assert_eq!( + runtime_config.warm_recall_memories.load().as_ref(), + &vec![fresh] + ); + assert_eq!( + *runtime_config + .warm_recall_refreshed_at_unix_ms + .load() + .as_ref(), + Some(333) + ); + assert_eq!( + runtime_config + .warm_recall_cache_epoch + .load(Ordering::SeqCst), + 1 + ); + } + + #[tokio::test] + async fn stale_warm_refresh_does_not_reintroduce_memory_forgotten_via_tool() { + let temp_dir = tempfile::tempdir().expect("failed to create temp dir"); + let runtime_config = test_runtime_config(temp_dir.path()); + let memory_search = test_memory_search().await; + + let forgotten_candidate = + Memory::new("highest importance memory to forget", MemoryType::Decision) + .with_importance(0.95); + let retained_memory = Memory::new("second memory retained in warm cache", MemoryType::Fact) + .with_importance(0.85); + + memory_search + .store() + .save(&forgotten_candidate) + .await + .expect("failed to save forget candidate"); + memory_search + .store() + .save(&retained_memory) + .await + .expect("failed to save retained memory"); + + let stale_snapshot = memory_search + .search( + "", + &SearchConfig { + mode: SearchMode::Important, + sort_by: SearchSort::Importance, + max_results: WARM_RECALL_MEMORY_LIMIT, + max_results_per_source: WARM_RECALL_MEMORY_LIMIT, + ..Default::default() + }, + ) + .await + .expect("failed to build warm refresh snapshot") + .into_iter() + .map(|result| result.memory) + .collect::>(); + assert!( + stale_snapshot + .iter() + .any(|memory| memory.id == forgotten_candidate.id), + "sanity check: stale snapshot should include forget candidate" + ); + + runtime_config + .warm_recall_memories + .store(Arc::new(stale_snapshot.clone())); + runtime_config + .warm_recall_refreshed_at_unix_ms + .store(Arc::new(Some(111))); + + let start_epoch = runtime_config + .warm_recall_cache_epoch + .load(Ordering::SeqCst); + let delete_tool = + MemoryDeleteTool::with_runtime(Arc::clone(&memory_search), Arc::clone(&runtime_config)); + let delete_result = delete_tool + .call(MemoryDeleteArgs { + memory_id: forgotten_candidate.id.clone(), + reason: Some("integration test".to_string()), + }) + .await + .expect("memory delete should succeed"); + assert!( + delete_result.forgotten, + "tool should forget the target memory" + ); + + let applied_count = + apply_warm_recall_refresh(runtime_config.as_ref(), start_epoch, stale_snapshot, 222) + .await; + assert_eq!( + applied_count, 1, + "stale apply should report retained cache size after delete/evict" + ); + + let warm_cache = runtime_config.warm_recall_memories.load(); + assert_eq!(warm_cache.len(), 1); + assert_eq!(warm_cache[0].id, retained_memory.id); + assert_eq!( + *runtime_config + .warm_recall_refreshed_at_unix_ms + .load() + .as_ref(), + Some(111), + "stale refresh must not overwrite the retained-cache timestamp" + ); + assert_eq!( + runtime_config + .warm_recall_cache_epoch + .load(Ordering::SeqCst), + 1, + "delete/evict should be the only cache mutation that advanced the epoch" + ); + assert!( + memory_search + .store() + .load(&forgotten_candidate.id) + .await + .expect("memory load should succeed") + .expect("forgotten candidate should still exist in store") + .forgotten, + "store record should remain forgotten after stale refresh apply" + ); + } + #[test] fn summarize_signal_text_uses_first_non_empty_line() { let text = "\n\nfirst line\nsecond line"; diff --git a/src/config/runtime.rs b/src/config/runtime.rs index 58c36ce47..d8aecd06f 100644 --- a/src/config/runtime.rs +++ b/src/config/runtime.rs @@ -51,6 +51,18 @@ pub struct RuntimeConfig { /// Cached memory bulletin generated by the cortex. Injected into every /// channel's system prompt. Empty string until the first cortex run. pub memory_bulletin: ArcSwap, + /// Warm memory set used as degraded fallback context for branch recall. + pub warm_recall_memories: ArcSwap>, + /// Last refresh timestamp for the warm recall memory set. + pub warm_recall_refreshed_at_unix_ms: ArcSwap>, + /// Synchronizes warm recall cache writers (warmup refresh and memory eviction). + pub warm_recall_cache_lock: Arc>, + /// Memory IDs currently being forgotten with per-ID in-flight reference + /// counts; warm fallback excludes any ID with a count > 0. + pub warm_recall_inflight_forget_counts: + Arc>>, + /// Monotonic epoch for warm recall cache mutations. + pub warm_recall_cache_epoch: std::sync::atomic::AtomicU64, pub prompts: ArcSwap, pub identity: ArcSwap, pub skills: ArcSwap, @@ -128,6 +140,13 @@ impl RuntimeConfig { warmup_status: ArcSwap::from_pointee(WarmupStatus::default()), warmup_lock: Arc::new(tokio::sync::Mutex::new(())), memory_bulletin: ArcSwap::from_pointee(String::new()), + warm_recall_memories: ArcSwap::from_pointee(Vec::new()), + warm_recall_refreshed_at_unix_ms: ArcSwap::from_pointee(None), + warm_recall_cache_lock: Arc::new(tokio::sync::Mutex::new(())), + warm_recall_inflight_forget_counts: Arc::new(std::sync::Mutex::new( + std::collections::BTreeMap::new(), + )), + warm_recall_cache_epoch: std::sync::atomic::AtomicU64::new(0), prompts: ArcSwap::from_pointee(prompts), identity: ArcSwap::from_pointee(identity), skills: ArcSwap::from_pointee(skills), diff --git a/src/tools.rs b/src/tools.rs index 4ed3bd7f5..cf28ceccc 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -499,8 +499,14 @@ pub fn create_branch_tool_server( agent_id.clone(), memory_event_tx.clone(), )) - .tool(MemoryRecallTool::new(memory_search.clone())) - .tool(MemoryDeleteTool::new(memory_search)) + .tool(MemoryRecallTool::with_runtime( + memory_search.clone(), + runtime_config.clone(), + )) + .tool(MemoryDeleteTool::with_runtime( + memory_search, + runtime_config.clone(), + )) .tool(ChannelRecallTool::new(conversation_logger, channel_store)) .tool(SpacebotDocsTool::new()) .tool(EmailSearchTool::new(runtime_config)) @@ -640,8 +646,14 @@ pub fn create_cortex_chat_tool_server( agent_id.clone(), memory_event_tx, )) - .tool(MemoryRecallTool::new(memory_search.clone())) - .tool(MemoryDeleteTool::new(memory_search)) + .tool(MemoryRecallTool::with_runtime( + memory_search.clone(), + runtime_config.clone(), + )) + .tool(MemoryDeleteTool::with_runtime( + memory_search, + runtime_config.clone(), + )) .tool(ChannelRecallTool::new(conversation_logger, channel_store)) .tool(SpacebotDocsTool::new()) .tool(ConfigInspectTool::new( diff --git a/src/tools/memory_delete.rs b/src/tools/memory_delete.rs index af0383f9b..e72c46b2b 100644 --- a/src/tools/memory_delete.rs +++ b/src/tools/memory_delete.rs @@ -3,26 +3,141 @@ //! Soft-deletes a memory by setting its `forgotten` flag. The memory stays in //! the database but is excluded from all search and recall operations. +use crate::config::RuntimeConfig; +use crate::memory::Memory; use crate::memory::MemorySearch; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; use std::sync::Arc; +use std::sync::atomic::Ordering; /// Tool for soft-deleting memories. #[derive(Debug, Clone)] pub struct MemoryDeleteTool { memory_search: Arc, + runtime_config: Option>, } impl MemoryDeleteTool { /// Create a new memory delete tool. pub fn new(memory_search: Arc) -> Self { - Self { memory_search } + Self { + memory_search, + runtime_config: None, + } + } + + /// Create a memory delete tool with runtime warm-recall cache support. + pub fn with_runtime( + memory_search: Arc, + runtime_config: Arc, + ) -> Self { + Self { + memory_search, + runtime_config: Some(runtime_config), + } + } + + fn evict_from_warm_cache(runtime_config: &RuntimeConfig, memory_id: &str) -> bool { + let current_memories = runtime_config.warm_recall_memories.load(); + let refreshed_at_unix_ms = *runtime_config + .warm_recall_refreshed_at_unix_ms + .load() + .as_ref(); + let (updated_memories, refreshed_at_unix_ms, removed) = remove_warm_cache_memory_by_id( + current_memories.as_ref(), + memory_id, + refreshed_at_unix_ms, + ); + if !removed { + return false; + } + + runtime_config + .warm_recall_memories + .store(Arc::new(updated_memories)); + // Eviction is a partial mutation, not a full refresh. Keep the existing timestamp. + runtime_config + .warm_recall_refreshed_at_unix_ms + .store(Arc::new(refreshed_at_unix_ms)); + true } } +struct InflightForgetGuard { + counts: Arc>>, + memory_id: String, +} + +impl InflightForgetGuard { + fn new(runtime_config: &RuntimeConfig, memory_id: &str) -> Self { + Self::from_counts( + Arc::clone(&runtime_config.warm_recall_inflight_forget_counts), + memory_id, + ) + } + + fn from_counts( + counts: Arc>>, + memory_id: &str, + ) -> Self { + { + let mut counts = lock_inflight_counts(&counts); + *counts.entry(memory_id.to_string()).or_insert(0) += 1; + } + + Self { + counts, + memory_id: memory_id.to_string(), + } + } +} + +impl Drop for InflightForgetGuard { + fn drop(&mut self) { + let mut counts = lock_inflight_counts(&self.counts); + if let Some(count) = counts.get_mut(&self.memory_id) { + if *count <= 1 { + counts.remove(&self.memory_id); + } else { + *count -= 1; + } + } + } +} + +fn lock_inflight_counts( + counts: &std::sync::Mutex>, +) -> std::sync::MutexGuard<'_, BTreeMap> { + match counts.lock() { + Ok(guard) => guard, + Err(poisoned) => { + tracing::warn!("warm recall inflight forget counts lock poisoned; recovering state"); + poisoned.into_inner() + } + } +} + +fn remove_warm_cache_memory_by_id( + memories: &[Memory], + memory_id: &str, + refreshed_at_unix_ms: Option, +) -> (Vec, Option, bool) { + let mut removed = false; + let mut updated = Vec::with_capacity(memories.len()); + for memory in memories { + if memory.id == memory_id { + removed = true; + continue; + } + updated.push(memory.clone()); + } + (updated, refreshed_at_unix_ms, removed) +} + /// Error type for memory delete tool. #[derive(Debug, thiserror::Error)] #[error("Memory delete failed: {0}")] @@ -91,23 +206,74 @@ impl Tool for MemoryDeleteTool { }; if memory.forgotten { + let removed_from_warm_cache = if let Some(runtime_config) = self.runtime_config.as_ref() + { + let _warm_recall_cache_guard = runtime_config.warm_recall_cache_lock.lock().await; + let removed = + MemoryDeleteTool::evict_from_warm_cache(runtime_config, &args.memory_id); + if removed { + runtime_config + .warm_recall_cache_epoch + .fetch_add(1, Ordering::SeqCst); + } + removed + } else { + false + }; return Ok(MemoryDeleteOutput { forgotten: false, - message: format!("Memory {} is already forgotten.", args.memory_id), + message: format!( + "Memory {} is already forgotten. Warm cache evicted: {}.", + args.memory_id, removed_from_warm_cache + ), }); } - let was_forgotten = store - .forget(&args.memory_id) - .await - .map_err(|e| MemoryDeleteError(format!("Failed to forget memory: {e}")))?; - let reason_suffix = args .reason .as_deref() .map(|r| format!(" Reason: {r}")) .unwrap_or_default(); + let (was_forgotten, removed_from_warm_cache) = + if let Some(runtime_config) = self.runtime_config.as_ref() { + let runtime_config = Arc::clone(runtime_config); + let memory_search = Arc::clone(&self.memory_search); + let memory_id = args.memory_id.clone(); + let inflight_forget_guard = InflightForgetGuard::new(&runtime_config, &memory_id); + let forget_and_evict_task = tokio::spawn(async move { + let store = memory_search.store(); + let was_forgotten = store.forget(&memory_id).await.map_err(|error| { + MemoryDeleteError(format!("Failed to forget memory: {error}")) + })?; + + let mut removed_from_warm_cache = false; + if was_forgotten { + let _warm_recall_cache_guard = + runtime_config.warm_recall_cache_lock.lock().await; + removed_from_warm_cache = + MemoryDeleteTool::evict_from_warm_cache(&runtime_config, &memory_id); + runtime_config + .warm_recall_cache_epoch + .fetch_add(1, Ordering::SeqCst); + } else { + drop(inflight_forget_guard); + } + + Ok::<(bool, bool), MemoryDeleteError>((was_forgotten, removed_from_warm_cache)) + }); + + forget_and_evict_task.await.map_err(|error| { + MemoryDeleteError(format!("Forget and eviction task failed: {error}")) + })?? + } else { + let was_forgotten = store + .forget(&args.memory_id) + .await + .map_err(|e| MemoryDeleteError(format!("Failed to forget memory: {e}")))?; + (was_forgotten, false) + }; + if was_forgotten { #[cfg(feature = "metrics")] crate::telemetry::Metrics::global() @@ -119,6 +285,7 @@ impl Tool for MemoryDeleteTool { memory_id = %args.memory_id, memory_type = %memory.memory_type, reason = ?args.reason, + removed_from_warm_cache, "memory forgotten" ); @@ -147,3 +314,115 @@ fn truncate(s: &str, max: usize) -> &str { &s[..s.floor_char_boundary(max)] } } + +#[cfg(test)] +mod tests { + use super::{InflightForgetGuard, lock_inflight_counts, remove_warm_cache_memory_by_id}; + use crate::memory::{Memory, MemoryType}; + use std::collections::BTreeMap; + use std::sync::{Arc, Mutex}; + + #[test] + fn remove_warm_cache_memory_by_id_removes_matching_memory() { + let keep = Memory::new("keep", MemoryType::Fact); + let remove = Memory::new("remove", MemoryType::Fact); + let refreshed_at_unix_ms = Some(11_000); + + let (updated, updated_refreshed_at_unix_ms, removed_flag) = remove_warm_cache_memory_by_id( + &[keep.clone(), remove.clone()], + &remove.id, + refreshed_at_unix_ms, + ); + + assert!(removed_flag); + assert_eq!(updated.len(), 1); + assert_eq!(updated[0].id, keep.id); + assert_eq!(updated_refreshed_at_unix_ms, refreshed_at_unix_ms); + } + + #[test] + fn remove_warm_cache_memory_by_id_keeps_cache_when_id_missing() { + let memory = Memory::new("keep", MemoryType::Fact); + let refreshed_at_unix_ms = Some(22_000); + + let (updated, updated_refreshed_at_unix_ms, removed_flag) = remove_warm_cache_memory_by_id( + std::slice::from_ref(&memory), + "missing-id", + refreshed_at_unix_ms, + ); + + assert!(!removed_flag); + assert_eq!(updated.len(), 1); + assert_eq!(updated[0].id, memory.id); + assert_eq!(updated_refreshed_at_unix_ms, refreshed_at_unix_ms); + } + + #[test] + fn remove_warm_cache_memory_by_id_compose_evictions_without_reintroducing() { + let first = Memory::new("first", MemoryType::Fact); + let second = Memory::new("second", MemoryType::Fact); + let third = Memory::new("third", MemoryType::Fact); + let refreshed_at_unix_ms = Some(33_000); + + let (after_first, after_first_refreshed_at_unix_ms, removed_first) = + remove_warm_cache_memory_by_id( + &[first.clone(), second.clone(), third.clone()], + &first.id, + refreshed_at_unix_ms, + ); + let (after_second, after_second_refreshed_at_unix_ms, removed_second) = + remove_warm_cache_memory_by_id( + &after_first, + &second.id, + after_first_refreshed_at_unix_ms, + ); + + assert!(removed_first); + assert!(removed_second); + assert_eq!(after_second.len(), 1); + assert_eq!(after_second[0].id, third.id); + assert_eq!(after_second_refreshed_at_unix_ms, refreshed_at_unix_ms); + } + + #[test] + fn inflight_forget_guard_refcounts_same_memory_id() { + let counts = Arc::new(Mutex::new(BTreeMap::new())); + let guard_one = InflightForgetGuard::from_counts(Arc::clone(&counts), "memory-1"); + let guard_two = InflightForgetGuard::from_counts(Arc::clone(&counts), "memory-1"); + + assert_eq!( + lock_inflight_counts(&counts).get("memory-1").copied(), + Some(2) + ); + + drop(guard_one); + assert_eq!( + lock_inflight_counts(&counts).get("memory-1").copied(), + Some(1) + ); + + drop(guard_two); + assert!(lock_inflight_counts(&counts).get("memory-1").is_none()); + } + + #[test] + fn inflight_forget_guard_tracks_multiple_ids_independently() { + let counts = Arc::new(Mutex::new(BTreeMap::new())); + let guard_one = InflightForgetGuard::from_counts(Arc::clone(&counts), "memory-1"); + let guard_two = InflightForgetGuard::from_counts(Arc::clone(&counts), "memory-2"); + + let snapshot = lock_inflight_counts(&counts).clone(); + assert_eq!(snapshot.get("memory-1"), Some(&1)); + assert_eq!(snapshot.get("memory-2"), Some(&1)); + + drop(guard_one); + assert!(lock_inflight_counts(&counts).get("memory-1").is_none()); + assert_eq!( + lock_inflight_counts(&counts).get("memory-2").copied(), + Some(1) + ); + + drop(guard_two); + assert!(lock_inflight_counts(&counts).is_empty()); + } +} diff --git a/src/tools/memory_recall.rs b/src/tools/memory_recall.rs index 42bce96e7..a982e4f91 100644 --- a/src/tools/memory_recall.rs +++ b/src/tools/memory_recall.rs @@ -1,27 +1,205 @@ //! Memory recall tool for branches. +use crate::config::RuntimeConfig; use crate::error::Result; use crate::memory::MemorySearch; use crate::memory::search::{SearchConfig, SearchMode, SearchSort, curate_results}; -use crate::memory::types::Memory; +use crate::memory::types::{Memory, MemorySearchResult, MemoryType}; use rig::completion::ToolDefinition; use rig::tool::Tool; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::collections::BTreeSet; use std::sync::Arc; /// Tool for recalling memories using hybrid search. #[derive(Debug, Clone)] pub struct MemoryRecallTool { memory_search: Arc, + runtime_config: Option>, } impl MemoryRecallTool { /// Create a new memory recall tool. pub fn new(memory_search: Arc) -> Self { - Self { memory_search } + Self { + memory_search, + runtime_config: None, + } + } + + /// Create a memory recall tool with runtime warm-recall cache support. + pub fn with_runtime( + memory_search: Arc, + runtime_config: Arc, + ) -> Self { + Self { + memory_search, + runtime_config: Some(runtime_config), + } + } + + async fn warm_cache_results( + &self, + query: &str, + memory_type: Option, + max_results: usize, + ) -> Vec { + let Some(runtime_config) = self.runtime_config.as_ref() else { + return Vec::new(); + }; + let _warm_recall_cache_guard = runtime_config.warm_recall_cache_lock.lock().await; + + let now_unix_ms = chrono::Utc::now().timestamp_millis(); + let refreshed_at_unix_ms = *runtime_config + .warm_recall_refreshed_at_unix_ms + .load() + .as_ref(); + let Some(age_secs) = warm_cache_age_secs(refreshed_at_unix_ms, now_unix_ms) else { + return Vec::new(); + }; + let warmup_refresh_secs = runtime_config.warmup.load().as_ref().refresh_secs.max(1); + if age_secs > warmup_refresh_secs { + return Vec::new(); + } + + let warm_memories = runtime_config.warm_recall_memories.load(); + let inflight_forget_ids = snapshot_inflight_forget_ids(runtime_config); + score_warm_memories( + query, + warm_memories.as_ref(), + memory_type, + max_results, + &inflight_forget_ids, + ) + } +} + +fn snapshot_inflight_forget_ids(runtime_config: &RuntimeConfig) -> BTreeSet { + let counts = match runtime_config.warm_recall_inflight_forget_counts.lock() { + Ok(guard) => guard, + Err(poisoned) => { + tracing::warn!("warm recall inflight forget counts lock poisoned; recovering state"); + poisoned.into_inner() + } + }; + counts.keys().cloned().collect() +} + +fn warm_cache_age_secs(refreshed_at_unix_ms: Option, now_unix_ms: i64) -> Option { + refreshed_at_unix_ms.map(|refresh_ms| { + if now_unix_ms > refresh_ms { + ((now_unix_ms - refresh_ms) / 1000) as u64 + } else { + 0 + } + }) +} + +fn tokenize_query_terms(query: &str) -> Vec { + query + .split(|c: char| !c.is_alphanumeric()) + .filter(|term| term.len() >= 2) + .map(|term| term.to_lowercase()) + .collect() +} + +fn score_warm_memories( + query: &str, + warm_memories: &[Memory], + memory_type: Option, + max_results: usize, + inflight_forget_ids: &BTreeSet, +) -> Vec { + if max_results == 0 { + return Vec::new(); + } + + let query_terms = tokenize_query_terms(query); + if query_terms.is_empty() { + return Vec::new(); + } + + let mut results = Vec::new(); + + for memory in warm_memories { + if let Some(score) = + score_memory_with_terms(&query_terms, memory, memory_type, inflight_forget_ids) + { + results.push(MemorySearchResult { + memory: memory.clone(), + score, + rank: 0, + }); + } + } + + results.sort_by(|left, right| { + right + .score + .total_cmp(&left.score) + .then_with(|| right.memory.importance.total_cmp(&left.memory.importance)) + }); + results.truncate(max_results); + + for (index, result) in results.iter_mut().enumerate() { + result.rank = index + 1; + } + + results +} + +fn score_memory_with_terms( + query_terms: &[String], + memory: &Memory, + memory_type: Option, + inflight_forget_ids: &BTreeSet, +) -> Option { + if memory.forgotten { + return None; + } + + if inflight_forget_ids.contains(&memory.id) { + return None; + } + + if memory_type.is_some_and(|kind| memory.memory_type != kind) { + return None; + } + + let content = memory.content.to_lowercase(); + let matched_terms = query_terms + .iter() + .filter(|term| content.contains(term.as_str())) + .count(); + if matched_terms == 0 { + return None; + } + + let coverage = matched_terms as f32 / query_terms.len() as f32; + Some(coverage * 0.85 + memory.importance.clamp(0.0, 1.0) * 0.15) +} + +fn select_hybrid_results( + warm_results: Vec, + search_result: std::result::Result, MemoryRecallError>, +) -> std::result::Result<(Vec, Option), MemoryRecallError> { + match search_result { + Ok(search_results) => Ok((search_results, None)), + Err(error) => { + if warm_results.is_empty() { + return Err(error); + } + let search_error = error.0; + tracing::warn!( + error = %search_error, + warm_matches = warm_results.len(), + "hybrid search failed, returning partial warm recall results" + ); + Ok((warm_results, Some(search_error))) + } } } @@ -111,6 +289,10 @@ pub struct MemoryRecallOutput { pub memories: Vec, /// Total number of results found before curation. pub total_found: usize, + /// True when hybrid search failed and warm cache fallback was returned. + pub degraded_fallback_used: bool, + /// Root-cause error from hybrid search when degraded fallback was used. + pub degraded_fallback_error: Option, /// Formatted summary of the memories. pub summary: String, } @@ -221,11 +403,36 @@ impl Tool for MemoryRecallTool { }; let query = args.query.as_deref().unwrap_or(""); - let search_results = self - .memory_search - .search(query, &config) - .await - .map_err(|e| MemoryRecallError(format!("Search failed: {e}")))?; + let (search_results, degraded_fallback_error) = if mode == SearchMode::Hybrid { + let search_result = self + .memory_search + .search(query, &config) + .await + .map_err(|e| MemoryRecallError(format!("Search failed: {e}"))); + let warm_results = if search_result.is_err() { + self.warm_cache_results(query, memory_type, config.max_results_per_source) + .await + } else { + Vec::new() + }; + let (selected, degraded_fallback_error) = + select_hybrid_results(warm_results, search_result)?; + if selected.len() < args.max_results { + tracing::debug!( + matched = selected.len(), + requested = args.max_results, + "memory recall returned partial warm or hybrid results" + ); + } + (selected, degraded_fallback_error) + } else { + let search_results = self + .memory_search + .search(query, &config) + .await + .map_err(|e| MemoryRecallError(format!("Search failed: {e}")))?; + (search_results, None) + }; let curated = curate_results(&search_results, args.max_results); @@ -271,6 +478,8 @@ impl Tool for MemoryRecallTool { Ok(MemoryRecallOutput { memories, total_found, + degraded_fallback_used: degraded_fallback_error.is_some(), + degraded_fallback_error, summary, }) } @@ -335,6 +544,7 @@ pub async fn memory_recall( #[cfg(test)] mod tests { use super::*; + use crate::memory::MemoryType; #[test] fn test_parse_search_mode_valid() { @@ -385,4 +595,165 @@ mod tests { fn test_parse_memory_type_invalid() { assert!(parse_memory_type("invalid").is_err()); } + + #[test] + fn test_warm_cache_age_secs_none_stays_none() { + assert_eq!(warm_cache_age_secs(None, 10_000), None); + } + + #[test] + fn test_warm_cache_age_secs_clamps_future_to_zero() { + assert_eq!(warm_cache_age_secs(Some(5_000), 4_000), Some(0)); + } + + #[test] + fn test_score_warm_memories_prefers_stronger_term_match() { + let mut auth_memory = Memory::new("auth token rotation policy", MemoryType::Fact); + auth_memory.importance = 0.4; + let mut cache_memory = Memory::new("cache eviction details", MemoryType::Fact); + cache_memory.importance = 0.9; + + let results = score_warm_memories( + "auth token", + &[cache_memory, auth_memory.clone()], + None, + 10, + &BTreeSet::new(), + ); + + assert!(!results.is_empty()); + assert_eq!(results[0].memory.id, auth_memory.id); + } + + #[test] + fn test_score_warm_memories_applies_memory_type_filter() { + let fact_memory = Memory::new("auth strategy", MemoryType::Fact); + let decision_memory = Memory::new("auth decision", MemoryType::Decision); + + let results = score_warm_memories( + "auth decision", + &[fact_memory, decision_memory.clone()], + Some(MemoryType::Decision), + 10, + &BTreeSet::new(), + ); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].memory.memory_type, MemoryType::Decision); + assert_eq!(results[0].memory.id, decision_memory.id); + } + + #[test] + fn test_score_warm_memories_returns_empty_for_non_word_query() { + let memory = Memory::new("auth strategy", MemoryType::Fact); + let results = score_warm_memories("...", &[memory], None, 10, &BTreeSet::new()); + assert!(results.is_empty()); + } + + #[test] + fn test_score_warm_memories_excludes_inflight_forget_ids() { + let memory = Memory::new("auth strategy", MemoryType::Fact); + let mut inflight_forget_ids = BTreeSet::new(); + inflight_forget_ids.insert(memory.id.clone()); + + let results = score_warm_memories("auth", &[memory], None, 10, &inflight_forget_ids); + assert!(results.is_empty()); + } + + #[test] + fn test_select_hybrid_results_uses_partial_warm_results_on_search_error() { + let warm_memory = Memory::new("auth strategy", MemoryType::Fact); + let warm_results = vec![MemorySearchResult { + memory: warm_memory, + score: 0.9, + rank: 1, + }]; + + let selected = select_hybrid_results( + warm_results.clone(), + Err(MemoryRecallError("Search failed: boom".to_string())), + ) + .expect("expected warm fallback"); + + assert_eq!(selected.0.len(), warm_results.len()); + assert_eq!(selected.0[0].rank, warm_results[0].rank); + assert_eq!(selected.1, Some("Search failed: boom".to_string())); + } + + #[test] + fn test_select_hybrid_results_prefers_search_results_when_search_succeeds() { + let warm_memory = Memory::new("warm cache result", MemoryType::Fact); + let search_memory = Memory::new("hybrid result", MemoryType::Fact); + let warm_results = vec![MemorySearchResult { + memory: warm_memory, + score: 0.95, + rank: 1, + }]; + let search_results = vec![MemorySearchResult { + memory: search_memory.clone(), + score: 0.5, + rank: 1, + }]; + + let selected = + select_hybrid_results(warm_results, Ok(search_results.clone())).expect("select"); + + assert_eq!(selected.0.len(), 1); + assert_eq!(selected.0[0].memory.id, search_memory.id); + assert!(selected.1.is_none()); + } + + #[test] + fn test_select_hybrid_results_returns_error_when_search_fails_and_warm_is_empty() { + let result = select_hybrid_results( + Vec::new(), + Err(MemoryRecallError("Search failed: boom".to_string())), + ); + assert!(result.is_err()); + } + + #[test] + fn test_score_memory_with_terms_filters_type_and_content() { + let query_terms = vec!["auth".to_string()]; + let memory = Memory::new("auth strategy", MemoryType::Fact); + let wrong_type_score = score_memory_with_terms( + &query_terms, + &memory, + Some(MemoryType::Decision), + &BTreeSet::new(), + ); + let no_match_score = score_memory_with_terms( + &["billing".to_string()], + &memory, + Some(MemoryType::Fact), + &BTreeSet::new(), + ); + let match_score = score_memory_with_terms( + &query_terms, + &memory, + Some(MemoryType::Fact), + &BTreeSet::new(), + ); + + assert!(wrong_type_score.is_none()); + assert!(no_match_score.is_none()); + assert!(match_score.is_some()); + } + + #[test] + fn test_score_memory_with_terms_filters_inflight_forget() { + let query_terms = vec!["auth".to_string()]; + let memory = Memory::new("auth strategy", MemoryType::Fact); + let mut inflight_forget_ids = BTreeSet::new(); + inflight_forget_ids.insert(memory.id.clone()); + + let score = score_memory_with_terms( + &query_terms, + &memory, + Some(MemoryType::Fact), + &inflight_forget_ids, + ); + + assert!(score.is_none()); + } }