From 4ca12c755a7a72a6d342850ad3a8eb77029453d4 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Mon, 25 Aug 2025 12:35:47 -0400 Subject: [PATCH 01/16] Persist extension config so we can resume recipe sessions w/ extensions --- crates/goose-cli/src/session/builder.rs | 40 ++++++++++++++++++- crates/goose-cli/src/session/mod.rs | 3 +- crates/goose/src/agents/agent.rs | 4 ++ crates/goose/src/agents/extension_manager.rs | 9 +++++ crates/goose/src/context_mgmt/auto_compact.rs | 1 + crates/goose/src/scheduler.rs | 1 + crates/goose/src/session/mod.rs | 4 +- crates/goose/src/session/storage.rs | 26 ++++++++++++ crates/goose/tests/test_support.rs | 1 + 9 files changed, 85 insertions(+), 4 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 92c316abd1d7..59a5f5501ace 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -348,6 +348,27 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { let extensions_to_run: Vec<_> = if let Some(extensions) = session_config.extensions_override { agent.disable_router_for_recipe().await; extensions.into_iter().collect() + } else if session_config.resume { + if let Some(session_file) = session_file.as_ref() { + match session::read_metadata(session_file) { + Ok(metadata) if metadata.enabled_extensions.is_some() => { + metadata.enabled_extensions.unwrap().into_iter().collect() + } + _ => ExtensionConfigManager::get_all() + .expect("should load extensions") + .into_iter() + .filter(|ext| ext.enabled) + .map(|ext| ext.config) + .collect(), + } + } else { + ExtensionConfigManager::get_all() + .expect("should load extensions") + .into_iter() + .filter(|ext| ext.enabled) + .map(|ext| ext.config) + .collect() + } } else { ExtensionConfigManager::get_all() .expect("should load extensions") @@ -432,7 +453,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { 
session_config.retry_config.clone(), ); - // Add extensions if provided + // Add CLI extensions if provided and track their source for extension_str in session_config.extensions { if let Err(e) = session.add_extension(extension_str.clone()).await { eprintln!( @@ -585,6 +606,23 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { session.agent.override_system_prompt(override_prompt).await; } + // Save all extension configurations for session resume (after all extensions are added) + if let Some(session_file) = &session_file { + let all_extension_configs = session.agent.get_extension_configs().await; + if !all_extension_configs.is_empty() { + if let Err(e) = + goose::session::update_metadata_with_extensions(session_file, all_extension_configs) + .await + { + tracing::error!("Failed to persist extension configuration: {}", e); + if !session_config.quiet { + println!("Warning: Extension configuration could not be saved. Session resume may not work correctly."); + } + // Non-fatal: continue session even if we can't persist extensions + } + } + } + // Display session information unless in quiet mode if !session_config.quiet { output::display_session_info( diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 7e9ec62b5d9f..93535ae46d50 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -300,8 +300,9 @@ impl Session { /// * `builtin_name` - Name of the builtin extension(s), comma separated pub async fn add_builtin(&mut self, builtin_name: String) -> Result<()> { for name in builtin_name.split(',') { + let extension_name = name.trim().to_string(); let config = ExtensionConfig::Builtin { - name: name.trim().to_string(), + name: extension_name, display_name: None, // TODO: should set a timeout timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index cb676d025ca4..37c9595b561d 100644 
--- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -856,6 +856,10 @@ impl Agent { .expect("Failed to list extensions") } + pub async fn get_extension_configs(&self) -> Vec { + self.extension_manager.get_extension_configs().await + } + /// Handle a confirmation response for a tool request pub async fn handle_confirmation( &self, diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 657c9f62b534..064fd1eee88f 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -500,6 +500,15 @@ impl ExtensionManager { Ok(self.extensions.lock().await.keys().cloned().collect()) } + pub async fn get_extension_configs(&self) -> Vec { + self.extensions + .lock() + .await + .values() + .map(|ext| ext.config.clone()) + .collect() + } + /// Get all tools from all clients with proper prefixing pub async fn get_prefixed_tools( &self, diff --git a/crates/goose/src/context_mgmt/auto_compact.rs b/crates/goose/src/context_mgmt/auto_compact.rs index 2c3c5a4c5a1e..42e442270427 100644 --- a/crates/goose/src/context_mgmt/auto_compact.rs +++ b/crates/goose/src/context_mgmt/auto_compact.rs @@ -270,6 +270,7 @@ mod tests { accumulated_input_tokens: Some(50), accumulated_output_tokens: Some(50), todo_content: None, + enabled_extensions: None, } } diff --git a/crates/goose/src/scheduler.rs b/crates/goose/src/scheduler.rs index bb0da404591a..25f11110e12f 100644 --- a/crates/goose/src/scheduler.rs +++ b/crates/goose/src/scheduler.rs @@ -1299,6 +1299,7 @@ async fn run_scheduled_job_internal( accumulated_input_tokens: None, accumulated_output_tokens: None, todo_content: None, + enabled_extensions: None, }; if let Err(e_fb) = crate::session::storage::save_messages_with_metadata( &session_file_path, diff --git a/crates/goose/src/session/mod.rs b/crates/goose/src/session/mod.rs index 5f4537fe7e6a..e7e9d1d8a392 100644 --- a/crates/goose/src/session/mod.rs +++ 
b/crates/goose/src/session/mod.rs @@ -5,8 +5,8 @@ pub mod storage; pub use storage::{ ensure_session_dir, generate_description, generate_description_with_schedule_id, generate_session_id, get_most_recent_session, get_path, list_sessions, persist_messages, - persist_messages_with_schedule_id, read_messages, read_metadata, update_metadata, Identifier, - SessionMetadata, + persist_messages_with_schedule_id, read_messages, read_metadata, update_metadata, + update_metadata_with_extensions, Identifier, SessionMetadata, }; pub use info::{get_valid_sorted_sessions, SessionInfo}; diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs index 2da5b1112318..7d10b81cbf66 100644 --- a/crates/goose/src/session/storage.rs +++ b/crates/goose/src/session/storage.rs @@ -5,6 +5,7 @@ // - Backup creation // Additional debug logging can be added if needed for troubleshooting. +use crate::config::ExtensionConfig; use crate::conversation::message::Message; use crate::conversation::Conversation; use crate::providers::base::Provider; @@ -66,6 +67,8 @@ pub struct SessionMetadata { pub accumulated_output_tokens: Option, /// Session-scoped TODO list content pub todo_content: Option, + /// Extensions that were active in this session + pub enabled_extensions: Option>, } // Custom deserializer to handle old sessions without working_dir and todo_content @@ -87,6 +90,7 @@ impl<'de> Deserialize<'de> for SessionMetadata { accumulated_output_tokens: Option, working_dir: Option, todo_content: Option, // For backward compatibility + enabled_extensions: Option>, // For backward compatibility } let helper = Helper::deserialize(deserializer)?; @@ -109,6 +113,7 @@ impl<'de> Deserialize<'de> for SessionMetadata { accumulated_output_tokens: helper.accumulated_output_tokens, working_dir, todo_content: helper.todo_content, + enabled_extensions: helper.enabled_extensions, }) } } @@ -134,6 +139,7 @@ impl SessionMetadata { accumulated_input_tokens: None, accumulated_output_tokens: 
None, todo_content: None, + enabled_extensions: None, } } } @@ -1345,6 +1351,26 @@ pub async fn update_metadata(session_file: &Path, metadata: &SessionMetadata) -> save_messages_with_metadata(&secure_path, metadata, &messages) } +/// Update session metadata with current extension state +/// +/// This reads the current metadata, updates the extensions, and rewrites the session file. +pub async fn update_metadata_with_extensions( + session_file: &Path, + extension_records: Vec, +) -> Result<()> { + // Validate the path for security + let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; + + // Read current metadata + let mut metadata = read_metadata(&secure_path)?; + + // Update the extensions + metadata.enabled_extensions = Some(extension_records); + + // Update the metadata in the file + update_metadata(&secure_path, &metadata).await +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/tests/test_support.rs b/crates/goose/tests/test_support.rs index eeaca03253b4..1d47e20ed118 100644 --- a/crates/goose/tests/test_support.rs +++ b/crates/goose/tests/test_support.rs @@ -412,5 +412,6 @@ pub fn create_test_session_metadata(message_count: usize, working_dir: &str) -> accumulated_input_tokens: Some(50), accumulated_output_tokens: Some(50), todo_content: None, + enabled_extensions: None, } } From fb06fb8e6aea84e1f6f45627771dd4d7e435d600 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Tue, 26 Aug 2025 11:58:10 -0400 Subject: [PATCH 02/16] Clarify comment and consolidate redundant print statements --- crates/goose-cli/src/session/builder.rs | 38 ++++--------------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 59a5f5501ace..0e2a864c0c7a 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -453,21 +453,17 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> 
Session { session_config.retry_config.clone(), ); - // Add CLI extensions if provided and track their source + // Add stdio extensions if provided for extension_str in session_config.extensions { if let Err(e) = session.add_extension(extension_str.clone()).await { eprintln!( "{}", style(format!( - "Warning: Failed to start extension '{}': {}", + "Warning: Failed to start stdio extension '{}' ({}), continuing without it", extension_str, e )) .yellow() ); - eprintln!( - "{}", - style(format!("Continuing without extension '{}'", extension_str)).yellow() - ); // Offer debugging help if let Err(debug_err) = offer_extension_debugging_help( @@ -489,19 +485,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { eprintln!( "{}", style(format!( - "Warning: Failed to start remote extension '{}': {}", + "Warning: Failed to start remote extension '{}' ({}), continuing without it", extension_str, e )) .yellow() ); - eprintln!( - "{}", - style(format!( - "Continuing without remote extension '{}'", - extension_str - )) - .yellow() - ); // Offer debugging help if let Err(debug_err) = offer_extension_debugging_help( @@ -526,19 +514,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { eprintln!( "{}", style(format!( - "Warning: Failed to start streamable HTTP extension '{}': {}", + "Warning: Failed to start streamable HTTP extension '{}' ({}), continuing without it", extension_str, e )) .yellow() ); - eprintln!( - "{}", - style(format!( - "Continuing without streamable HTTP extension '{}'", - extension_str - )) - .yellow() - ); // Offer debugging help if let Err(debug_err) = offer_extension_debugging_help( @@ -560,19 +540,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { eprintln!( "{}", style(format!( - "Warning: Failed to start builtin extension '{}': {}", + "Warning: Failed to start builtin extension '{}' ({}), continuing without it", builtin, e )) .yellow() ); - eprintln!( - "{}", - 
style(format!( - "Continuing without builtin extension '{}'", - builtin - )) - .yellow() - ); // Offer debugging help if let Err(debug_err) = offer_extension_debugging_help( From 8013c7f5bb13eb11ec72a82b6c354242e164ad41 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Tue, 26 Aug 2025 12:25:25 -0400 Subject: [PATCH 03/16] Extract get_enabled() method for ExtensionConfigManager type to avoid duplicated code Also included a few changes picked up by my linter --- crates/goose-cli/src/commands/configure.rs | 5 ++--- crates/goose-cli/src/commands/web.rs | 14 ++++---------- crates/goose-cli/src/session/builder.rs | 21 +++------------------ crates/goose/src/agents/agent.rs | 7 +------ crates/goose/src/agents/subagent.rs | 8 +------- crates/goose/src/config/extensions.rs | 8 ++++++++ 6 files changed, 19 insertions(+), 44 deletions(-) diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index bcee3d2309e3..58f0302d0c10 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -1427,11 +1427,10 @@ pub fn toggle_experiments_dialog() -> Result<(), Box> { } pub async fn configure_tool_permissions_dialog() -> Result<(), Box> { - let mut extensions: Vec = ExtensionConfigManager::get_all() + let mut extensions: Vec = ExtensionConfigManager::get_enabled() .unwrap_or_default() .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config.name().clone()) + .map(|ext| ext.name().clone()) .collect(); extensions.push("platform".to_string()); diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index 6f175d57e3ae..ab5ead780c5a 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -108,16 +108,10 @@ pub async fn handle_web(port: u16, host: String, open: bool) -> Result<()> { agent.update_provider(provider).await?; // Load and enable extensions from config - let extensions = 
goose::config::ExtensionConfigManager::get_all()?; - for ext_config in extensions { - if ext_config.enabled { - if let Err(e) = agent.add_extension(ext_config.config.clone()).await { - eprintln!( - "Warning: Failed to load extension {}: {}", - ext_config.config.name(), - e - ); - } + let enabled_configs = goose::config::ExtensionConfigManager::get_enabled()?; + for config in enabled_configs { + if let Err(e) = agent.add_extension(config.clone()).await { + eprintln!("Warning: Failed to load extension {}: {}", config.name(), e); } } diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 0e2a864c0c7a..c6bce6622864 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -354,28 +354,13 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { Ok(metadata) if metadata.enabled_extensions.is_some() => { metadata.enabled_extensions.unwrap().into_iter().collect() } - _ => ExtensionConfigManager::get_all() - .expect("should load extensions") - .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config) - .collect(), + _ => ExtensionConfigManager::get_enabled().expect("should load extensions"), } } else { - ExtensionConfigManager::get_all() - .expect("should load extensions") - .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config) - .collect() + ExtensionConfigManager::get_enabled().expect("should load extensions") } } else { - ExtensionConfigManager::get_all() - .expect("should load extensions") - .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config) - .collect() + ExtensionConfigManager::get_enabled().expect("should load extensions") }; let mut set = JoinSet::new(); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 37c9595b561d..59da205a1e75 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1512,12 +1512,7 @@ impl Agent { (instructions, activities) }; - 
let extensions = ExtensionConfigManager::get_all().unwrap_or_default(); - let extension_configs: Vec<_> = extensions - .iter() - .filter(|e| e.enabled) - .map(|e| e.config.clone()) - .collect(); + let extension_configs = ExtensionConfigManager::get_enabled().unwrap_or_default(); let author = Author { contact: std::env::var("USER") diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 50eded2b30fd..5be856b3d3ff 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -1,6 +1,5 @@ use crate::agents::subagent_task_config::DEFAULT_SUBAGENT_MAX_TURNS; use crate::{ - agents::extension::ExtensionConfig, agents::{extension_manager::ExtensionManager, Agent, TaskConfig}, config::ExtensionConfigManager, prompt_template::render_global_file, @@ -64,12 +63,7 @@ impl SubAgent { // 2. (TODO) If executing a sub-recipe task, only use recipe extensions // Get all enabled extensions from config - let enabled_extensions = ExtensionConfigManager::get_all() - .unwrap_or_default() - .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config) - .collect::>(); + let enabled_extensions = ExtensionConfigManager::get_enabled().unwrap_or_default(); // Add enabled extensions to the subagent's extension manager for extension in enabled_extensions { diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index 3019f81ca537..e2b89e942314 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -85,4 +85,12 @@ impl ExtensionConfigManager { let extensions = Self::get_extensions_map()?; Ok(extensions.get(key).map(|e| e.enabled).unwrap_or(false)) } + + pub fn get_enabled() -> Result> { + Ok(Self::get_all()? 
+ .into_iter() + .filter(|ext| ext.enabled) + .map(|ext| ext.config) + .collect()) + } } From 7d56e05152bea612bb7bac2085c230160469dcd3 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Fri, 29 Aug 2025 13:50:32 -0400 Subject: [PATCH 04/16] Build extension metadata object and persist in the same write as the first session message --- .../src/scenario_tests/scenario_runner.rs | 2 +- crates/goose-cli/src/session/builder.rs | 36 +++++++++------ crates/goose-cli/src/session/mod.rs | 44 ++++++++++++++----- crates/goose/src/session/mod.rs | 4 +- crates/goose/src/session/storage.rs | 20 --------- 5 files changed, 60 insertions(+), 46 deletions(-) diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 130077c9fdc7..f95831d7d2a5 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -218,7 +218,7 @@ where .update_provider(provider_arc as Arc) .await?; - let mut session = Session::new(agent, None, false, None, None, None, None); + let mut session = Session::new(agent, None, false, None, None, None, None, None); let mut error = None; for message in &messages { diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index c6bce6622864..e9d64dd84ced 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -4,10 +4,11 @@ use goose::agents::Agent; use goose::config::{Config, ExtensionConfig, ExtensionConfigManager}; use goose::providers::create; use goose::recipe::{Response, SubRecipe}; -use goose::session; use goose::session::Identifier; +use goose::session::{self, SessionMetadata}; use rustyline::EditMode; use std::collections::HashSet; +use std::path::PathBuf; use std::process; use std::sync::Arc; use tokio::task::JoinSet; @@ -142,6 +143,7 @@ async fn offer_extension_debugging_help( None, None, None, + None, // No startup metadata for debug 
sessions ); // Process the debugging request @@ -427,7 +429,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { } }); - // Create new session + // Create new session (initially without metadata) let mut session = Session::new( Arc::try_unwrap(agent_ptr).unwrap_or_else(|_| panic!("There should be no more references")), session_file.clone(), @@ -436,6 +438,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { session_config.max_turns, edit_mode, session_config.retry_config.clone(), + None, // Will be set after extensions are loaded ); // Add stdio extensions if provided @@ -563,21 +566,28 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { session.agent.override_system_prompt(override_prompt).await; } - // Save all extension configurations for session resume (after all extensions are added) + // Prepare metadata with extension configurations (will be persisted with first message) if let Some(session_file) = &session_file { let all_extension_configs = session.agent.get_extension_configs().await; + + // Prepare metadata to be persisted with first message + let mut startup_metadata = if session_config.resume { + // For resumed sessions, load existing metadata or create new one + session::read_metadata(session_file).unwrap_or_else(|_| { + SessionMetadata::new(std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))) + }) + } else { + // For new sessions, create fresh metadata + SessionMetadata::new(std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))) + }; + + // Update metadata with extension configurations if !all_extension_configs.is_empty() { - if let Err(e) = - goose::session::update_metadata_with_extensions(session_file, all_extension_configs) - .await - { - tracing::error!("Failed to persist extension configuration: {}", e); - if !session_config.quiet { - println!("Warning: Extension configuration could not be saved. 
Session resume may not work correctly."); - } - // Non-fatal: continue session even if we can't persist extensions - } + startup_metadata.enabled_extensions = Some(all_extension_configs); } + + // Set the prepared metadata on the session + session.startup_metadata = Some(startup_metadata); } // Display session information unless in quiet mode diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 93535ae46d50..8a693fc3a71f 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -32,7 +32,7 @@ use goose::agents::types::RetryConfig; use goose::agents::{Agent, SessionConfig}; use goose::config::Config; use goose::providers::pricing::initialize_pricing_cache; -use goose::session; +use goose::session::{self, SessionMetadata}; use input::InputResult; use rmcp::model::PromptMessage; use rmcp::model::ServerNotification; @@ -55,7 +55,7 @@ pub enum RunMode { } pub struct Session { - agent: Agent, + pub agent: Agent, messages: Conversation, session_file: Option, // Cache for completion data - using std::sync for thread safety without async @@ -66,6 +66,7 @@ pub struct Session { max_turns: Option, edit_mode: Option, retry_config: Option, + startup_metadata: Option, // Metadata to write with first message } // Cache structure for completion data @@ -130,6 +131,7 @@ impl Session { max_turns: Option, edit_mode: Option, retry_config: Option, + startup_metadata: Option, ) -> Self { let messages = if let Some(session_file) = &session_file { session::read_messages(session_file).unwrap_or_else(|e| { @@ -152,6 +154,7 @@ impl Session { max_turns, edit_mode, retry_config, + startup_metadata, } } @@ -387,14 +390,35 @@ impl Session { std::env::current_dir().expect("failed to get current session working directory"), ); - session::persist_messages_with_schedule_id( - session_file, - &self.messages, - Some(provider), - self.scheduled_job_id.clone(), - working_dir, - ) - .await?; + // First-write optimization: Save 
complete metadata (including extensions) only once + // + // - startup_metadata contains our metadata with extensions (set during session init) + // - .take() removes the value and returns it, leaving None behind + // - After .take(), startup_metadata is None forever + // + // Result: First message write uses cached metadata, all others use normal flow + // This avoids multiple file writes during startup without needing a "dirty" flag to track the cached metadata + if let Some(startup_metadata) = self.startup_metadata.take() { + // First write: Use cached metadata with extensions + let secure_path = + session::get_path(session::Identifier::Path(session_file.to_path_buf()))?; + session::storage::save_messages_with_metadata( + &secure_path, + &startup_metadata, + &self.messages, + )?; + } else { + // All other writes: Use normal persistence + // (startup_metadata is None after .take() was called above) + session::persist_messages_with_schedule_id( + session_file, + &self.messages, + Some(provider), + self.scheduled_job_id.clone(), + working_dir, + ) + .await?; + } } // Track the current directory and last instruction in projects.json diff --git a/crates/goose/src/session/mod.rs b/crates/goose/src/session/mod.rs index e7e9d1d8a392..5f4537fe7e6a 100644 --- a/crates/goose/src/session/mod.rs +++ b/crates/goose/src/session/mod.rs @@ -5,8 +5,8 @@ pub mod storage; pub use storage::{ ensure_session_dir, generate_description, generate_description_with_schedule_id, generate_session_id, get_most_recent_session, get_path, list_sessions, persist_messages, - persist_messages_with_schedule_id, read_messages, read_metadata, update_metadata, - update_metadata_with_extensions, Identifier, SessionMetadata, + persist_messages_with_schedule_id, read_messages, read_metadata, update_metadata, Identifier, + SessionMetadata, }; pub use info::{get_valid_sorted_sessions, SessionInfo}; diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs index 
7d10b81cbf66..5422472ffa3a 100644 --- a/crates/goose/src/session/storage.rs +++ b/crates/goose/src/session/storage.rs @@ -1351,26 +1351,6 @@ pub async fn update_metadata(session_file: &Path, metadata: &SessionMetadata) -> save_messages_with_metadata(&secure_path, metadata, &messages) } -/// Update session metadata with current extension state -/// -/// This reads the current metadata, updates the extensions, and rewrites the session file. -pub async fn update_metadata_with_extensions( - session_file: &Path, - extension_records: Vec, -) -> Result<()> { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Read current metadata - let mut metadata = read_metadata(&secure_path)?; - - // Update the extensions - metadata.enabled_extensions = Some(extension_records); - - // Update the metadata in the file - update_metadata(&secure_path, &metadata).await -} - #[cfg(test)] mod tests { use super::*; From 3240ed99403f4c5b20396ea030f6a1f7acec7e72 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Fri, 29 Aug 2025 17:00:46 -0400 Subject: [PATCH 05/16] Simplify ExtensionConfigManager struct into free functions --- crates/goose-cli/src/commands/configure.rs | 75 ++++---- crates/goose-cli/src/commands/web.rs | 2 +- crates/goose-cli/src/session/builder.rs | 29 ++- .../src/routes/config_management.rs | 43 ++--- crates/goose/src/agents/agent.rs | 20 +- crates/goose/src/agents/extension_manager.rs | 4 +- crates/goose/src/agents/subagent.rs | 4 +- crates/goose/src/config/extensions.rs | 172 ++++++++++++------ crates/goose/src/config/mod.rs | 5 +- 9 files changed, 195 insertions(+), 159 deletions(-) diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 58f0302d0c10..8bda366e1758 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -9,12 +9,12 @@ use goose::agents::platform_tools::{ use goose::agents::Agent; use 
goose::agents::{extension::Envs, ExtensionConfig}; use goose::config::custom_providers::CustomProviderConfig; -use goose::config::extensions::name_to_key; -use goose::config::permission::PermissionLevel; -use goose::config::{ - Config, ConfigError, ExperimentManager, ExtensionConfigManager, ExtensionEntry, - PermissionManager, +use goose::config::extensions::{ + get_all_extension_names, get_all_extensions, get_enabled_extensions, get_extension_by_name, + name_to_key, remove_extension, set_extension, set_extension_enabled, }; +use goose::config::permission::PermissionLevel; +use goose::config::{Config, ConfigError, ExperimentManager, ExtensionEntry, PermissionManager}; use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::providers::{create, providers}; @@ -126,7 +126,7 @@ pub async fn handle_configure() -> Result<(), Box> { ); // Since we are setting up for the first time, we'll also enable the developer system // This operation is best-effort and errors are ignored - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Builtin { name: "developer".to_string(), @@ -136,7 +136,7 @@ pub async fn handle_configure() -> Result<(), Box> { description: None, available_tools: Vec::new(), }, - })?; + }); } Ok(false) => { let _ = config.clear(); @@ -669,7 +669,7 @@ pub async fn configure_provider_dialog() -> Result> { /// Configure extensions that can be used with goose /// Dialog for toggling which extensions are enabled/disabled pub fn toggle_extensions_dialog() -> Result<(), Box> { - let extensions = ExtensionConfigManager::get_all()?; + let extensions = get_all_extensions(); if extensions.is_empty() { cliclack::outro( @@ -710,10 +710,10 @@ pub fn toggle_extensions_dialog() -> Result<(), Box> { // Update enabled status for each extension for name in extension_status.iter().map(|(name, _)| name) { - ExtensionConfigManager::set_enabled( + set_extension_enabled( &name_to_key(name), 
selected.iter().any(|s| s.as_str() == name), - )?; + ); } cliclack::outro("Extension settings updated successfully")?; @@ -787,7 +787,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { let display_name = get_display_name(&extension); - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Builtin { name: extension.clone(), @@ -797,12 +797,12 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { description: None, available_tools: Vec::new(), }, - })?; + }); cliclack::outro(format!("Enabled {} extension", style(extension).green()))?; } "stdio" => { - let extensions = ExtensionConfigManager::get_all_names()?; + let extensions = get_all_extension_names(); let name: String = cliclack::input("What would you like to call this extension?") .placeholder("my-extension") .validate(move |input: &String| { @@ -892,7 +892,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } } - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Stdio { name: name.clone(), @@ -905,12 +905,12 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { bundled: None, available_tools: Vec::new(), }, - })?; + }); cliclack::outro(format!("Added {} extension", style(name).green()))?; } "sse" => { - let extensions = ExtensionConfigManager::get_all_names()?; + let extensions = get_all_extension_names(); let name: String = cliclack::input("What would you like to call this extension?") .placeholder("my-remote-extension") .validate(move |input: &String| { @@ -996,7 +996,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } } - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Sse { name: name.clone(), @@ -1008,12 +1008,12 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { bundled: None, available_tools: Vec::new(), }, - })?; + }); 
cliclack::outro(format!("Added {} extension", style(name).green()))?; } "streamable_http" => { - let extensions = ExtensionConfigManager::get_all_names()?; + let extensions = get_all_extension_names(); let name: String = cliclack::input("What would you like to call this extension?") .placeholder("my-remote-extension") .validate(move |input: &String| { @@ -1123,7 +1123,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } } - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::StreamableHttp { name: name.clone(), @@ -1136,7 +1136,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { bundled: None, available_tools: Vec::new(), }, - })?; + }); cliclack::outro(format!("Added {} extension", style(name).green()))?; } @@ -1147,7 +1147,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } pub fn remove_extension_dialog() -> Result<(), Box> { - let extensions = ExtensionConfigManager::get_all()?; + let extensions = get_all_extensions(); // Create a list of extension names and their enabled status let mut extension_status: Vec<(String, bool)> = extensions @@ -1192,7 +1192,7 @@ pub fn remove_extension_dialog() -> Result<(), Box> { .interact()?; for name in selected { - ExtensionConfigManager::remove(&name_to_key(name))?; + remove_extension(&name_to_key(name)); let mut permission_manager = PermissionManager::default(); permission_manager.remove_extension(&name_to_key(name)); cliclack::outro(format!("Removed {} extension", style(name).green()))?; @@ -1427,8 +1427,7 @@ pub fn toggle_experiments_dialog() -> Result<(), Box> { } pub async fn configure_tool_permissions_dialog() -> Result<(), Box> { - let mut extensions: Vec = ExtensionConfigManager::get_enabled() - .unwrap_or_default() + let mut extensions: Vec = get_enabled_extensions() .into_iter() .map(|ext| ext.name().clone()) .collect(); @@ -1463,7 +1462,7 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box> { let 
agent = Agent::new(); let new_provider = create(&provider_name, model_config)?; agent.update_provider(new_provider).await?; - if let Ok(Some(config)) = ExtensionConfigManager::get_config_by_name(&selected_extension_name) { + if let Some(config) = get_extension_by_name(&selected_extension_name) { agent .add_extension(config.clone()) .await @@ -1746,13 +1745,13 @@ pub async fn handle_openrouter_auth() -> Result<(), Box> { println!("✓ Configuration test passed!"); // Enable the developer extension by default if not already enabled - let entries = ExtensionConfigManager::get_all()?; + let entries = get_all_extensions(); let has_developer = entries .iter() .any(|e| e.config.name() == "developer" && e.enabled); if !has_developer { - match ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Builtin { name: "developer".to_string(), @@ -1764,12 +1763,8 @@ pub async fn handle_openrouter_auth() -> Result<(), Box> { description: None, available_tools: Vec::new(), }, - }) { - Ok(_) => println!("✓ Developer extension enabled"), - Err(e) => { - eprintln!("⚠️ Failed to enable developer extension: {}", e) - } - } + }); + println!("✓ Developer extension enabled"); } cliclack::outro("OpenRouter setup complete! 
You can now use Goose.")?; @@ -1849,13 +1844,13 @@ pub async fn handle_tetrate_auth() -> Result<(), Box> { println!("✓ Configuration test passed!"); // Enable the developer extension by default if not already enabled - let entries = ExtensionConfigManager::get_all()?; + let entries = get_all_extensions(); let has_developer = entries .iter() .any(|e| e.config.name() == "developer" && e.enabled); if !has_developer { - match ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Builtin { name: "developer".to_string(), @@ -1867,12 +1862,8 @@ pub async fn handle_tetrate_auth() -> Result<(), Box> { description: None, available_tools: Vec::new(), }, - }) { - Ok(_) => println!("✓ Developer extension enabled"), - Err(e) => { - eprintln!("⚠️ Failed to enable developer extension: {}", e) - } - } + }); + println!("✓ Developer extension enabled"); } cliclack::outro("Tetrate Agent Router Service setup complete! You can now use Goose.")?; diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index ab5ead780c5a..ba01f12d34cb 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -108,7 +108,7 @@ pub async fn handle_web(port: u16, host: String, open: bool) -> Result<()> { agent.update_provider(provider).await?; // Load and enable extensions from config - let enabled_configs = goose::config::ExtensionConfigManager::get_enabled()?; + let enabled_configs = goose::config::get_enabled_extensions(); for config in enabled_configs { if let Err(e) = agent.add_extension(config.clone()).await { eprintln!("Warning: Failed to load extension {}: {}", config.name(), e); diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index e9d64dd84ced..e5e6befb61d5 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -1,7 +1,7 @@ use console::style; use 
goose::agents::types::RetryConfig; use goose::agents::Agent; -use goose::config::{Config, ExtensionConfig, ExtensionConfigManager}; +use goose::config::{get_all_extensions, get_enabled_extensions, Config, ExtensionConfig}; use goose::providers::create; use goose::recipe::{Response, SubRecipe}; use goose::session::Identifier; @@ -115,18 +115,17 @@ async fn offer_extension_debugging_help( debug_agent.update_provider(provider).await?; // Add the developer extension if available to help with debugging - if let Ok(extensions) = ExtensionConfigManager::get_all() { - for ext_wrapper in extensions { - if ext_wrapper.enabled && ext_wrapper.config.name() == "developer" { - if let Err(e) = debug_agent.add_extension(ext_wrapper.config).await { - // If we can't add developer extension, continue without it - eprintln!( - "Note: Could not load developer extension for debugging: {}", - e - ); - } - break; + let extensions = get_all_extensions(); + for ext_wrapper in extensions { + if ext_wrapper.enabled && ext_wrapper.config.name() == "developer" { + if let Err(e) = debug_agent.add_extension(ext_wrapper.config).await { + // If we can't add developer extension, continue without it + eprintln!( + "Note: Could not load developer extension for debugging: {}", + e + ); } + break; } } @@ -356,13 +355,13 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { Ok(metadata) if metadata.enabled_extensions.is_some() => { metadata.enabled_extensions.unwrap().into_iter().collect() } - _ => ExtensionConfigManager::get_enabled().expect("should load extensions"), + _ => get_enabled_extensions(), } } else { - ExtensionConfigManager::get_enabled().expect("should load extensions") + get_enabled_extensions() } } else { - ExtensionConfigManager::get_enabled().expect("should load extensions") + get_enabled_extensions() }; let mut set = JoinSet::new(); diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 
79a177943d57..eab4c000d661 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -7,9 +7,9 @@ use axum::{ Json, Router, }; use etcetera::{choose_app_strategy, AppStrategy}; +use goose::config::ExtensionEntry; use goose::config::APP_STRATEGY; use goose::config::{Config, ConfigError}; -use goose::config::{ExtensionConfigManager, ExtensionEntry}; use goose::model::ModelConfig; use goose::providers::base::ProviderMetadata; use goose::providers::pricing::{ @@ -201,19 +201,8 @@ pub async fn get_extensions( ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - match ExtensionConfigManager::get_all() { - Ok(extensions) => Ok(Json(ExtensionResponse { extensions })), - Err(err) => { - if err - .downcast_ref::() - .is_some_and(|e| matches!(e, goose::config::base::ConfigError::DeserializeError(_))) - { - Err(StatusCode::UNPROCESSABLE_ENTITY) - } else { - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } - } + let extensions = goose::config::get_all_extensions(); + Ok(Json(ExtensionResponse { extensions })) } #[utoipa::path( @@ -234,24 +223,20 @@ pub async fn add_extension( ) -> Result, StatusCode> { verify_secret_key(&headers, &state)?; - let extensions = - ExtensionConfigManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let extensions = goose::config::get_all_extensions(); let key = goose::config::extensions::name_to_key(&extension_query.name); let is_update = extensions.iter().any(|e| e.config.key() == key); - match ExtensionConfigManager::set(ExtensionEntry { + goose::config::set_extension(ExtensionEntry { enabled: extension_query.enabled, config: extension_query.config, - }) { - Ok(_) => { - if is_update { - Ok(Json(format!("Updated extension {}", extension_query.name))) - } else { - Ok(Json(format!("Added extension {}", extension_query.name))) - } - } - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + }); + + if is_update { + Ok(Json(format!("Updated extension {}", 
extension_query.name))) + } else { + Ok(Json(format!("Added extension {}", extension_query.name))) } } @@ -272,10 +257,8 @@ pub async fn remove_extension( verify_secret_key(&headers, &state)?; let key = goose::config::extensions::name_to_key(&name); - match ExtensionConfigManager::remove(&key) { - Ok(_) => Ok(Json(format!("Removed extension {}", name))), - Err(_) => Err(StatusCode::NOT_FOUND), - } + goose::config::remove_extension(&key); + Ok(Json(format!("Removed extension {}", name))) } #[utoipa::path( diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 59da205a1e75..818176ad0dd7 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -31,7 +31,7 @@ use crate::agents::tool_route_manager::ToolRouteManager; use crate::agents::tool_router_index_manager::ToolRouterIndexManager; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; -use crate::config::{Config, ExtensionConfigManager, PermissionManager}; +use crate::config::{get_enabled_extensions, get_extension_by_name, Config, PermissionManager}; use crate::context_mgmt::auto_compact; use crate::conversation::{debug_conversation_fix, fix_conversation, Conversation}; use crate::permission::permission_judge::{check_tool_permissions, PermissionCheckResult}; @@ -653,9 +653,9 @@ impl Agent { return (request_id, result); } - let config = match ExtensionConfigManager::get_config_by_name(&extension_name) { - Ok(Some(config)) => config, - Ok(None) => { + let config = match get_extension_by_name(&extension_name) { + Some(config) => config, + None => { return ( request_id, Err(ErrorData::new( @@ -668,16 +668,6 @@ impl Agent { )), ) } - Err(e) => { - return ( - request_id, - Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - format!("Failed to get extension config: {}", e), - None, - )), - ) - } }; let result = self .extension_manager @@ -1512,7 +1502,7 @@ impl Agent { (instructions, activities) }; - let 
extension_configs = ExtensionConfigManager::get_enabled().unwrap_or_default(); + let extension_configs = get_enabled_extensions(); let author = Author { contact: std::env::var("USER") diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 064fd1eee88f..ad636d5b295a 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -26,7 +26,7 @@ use tracing::{error, warn}; use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo}; use super::tool_execution::ToolCallResult; use crate::agents::extension::{Envs, ProcessExit}; -use crate::config::{Config, ExtensionConfigManager}; +use crate::config::{get_all_extensions, Config}; use crate::oauth::oauth_flow; use crate::prompt_template; use mcp_client::client::{McpClient, McpClientTrait}; @@ -965,7 +965,7 @@ impl ExtensionManager { // First get disabled extensions from current config let mut disabled_extensions: Vec = vec![]; - for extension in ExtensionConfigManager::get_all().expect("should load extensions") { + for extension in get_all_extensions() { if !extension.enabled { let config = extension.config.clone(); let description = match &config { diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 5be856b3d3ff..ec557dc11dd9 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -1,7 +1,7 @@ use crate::agents::subagent_task_config::DEFAULT_SUBAGENT_MAX_TURNS; use crate::{ agents::{extension_manager::ExtensionManager, Agent, TaskConfig}, - config::ExtensionConfigManager, + config::get_enabled_extensions, prompt_template::render_global_file, providers::errors::ProviderError, }; @@ -63,7 +63,7 @@ impl SubAgent { // 2. 
(TODO) If executing a sub-recipe task, only use recipe extensions // Get all enabled extensions from config - let enabled_extensions = ExtensionConfigManager::get_enabled().unwrap_or_default(); + let enabled_extensions = get_enabled_extensions(); // Add enabled extensions to the subagent's extension manager for extension in enabled_extensions { diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index e2b89e942314..c30d7df7a523 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -1,8 +1,8 @@ use super::base::Config; use crate::agents::ExtensionConfig; -use anyhow::Result; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use tracing; use utoipa::ToSchema; pub const DEFAULT_EXTENSION: &str = "developer"; @@ -25,72 +25,142 @@ pub fn name_to_key(name: &str) -> String { .to_lowercase() } -pub struct ExtensionConfigManager; +fn get_extensions_map() -> HashMap { + let config = Config::global(); + config + .get_param(EXTENSIONS_CONFIG_KEY) + .unwrap_or_else(|_| HashMap::new()) +} -impl ExtensionConfigManager { - fn get_extensions_map() -> Result> { - let config = Config::global(); - Ok(config - .get_param(EXTENSIONS_CONFIG_KEY) - .unwrap_or_else(|_| HashMap::new())) +fn save_extensions_map(extensions: HashMap) { + let config = Config::global(); + match serde_json::to_value(extensions) { + Ok(value) => { + if let Err(e) = config.set_param(EXTENSIONS_CONFIG_KEY, value) { + tracing::debug!("Failed to save extensions config: {}", e); + } + } + Err(e) => { + tracing::debug!("Failed to serialize extensions: {}", e); + } } +} - fn save_extensions_map(extensions: HashMap) -> Result<()> { - let config = Config::global(); - config.set_param(EXTENSIONS_CONFIG_KEY, serde_json::to_value(extensions)?)?; - Ok(()) - } +pub fn get_extension_by_name(name: &str) -> Option { + let extensions = get_extensions_map(); + extensions + .values() + .find(|entry| entry.config.name() == name) + 
.map(|entry| entry.config.clone()) +} - pub fn get_config_by_name(name: &str) -> Result> { - let extensions = Self::get_extensions_map()?; - Ok(extensions - .values() - .find(|entry| entry.config.name() == name) - .map(|entry| entry.config.clone())) - } +pub fn set_extension(entry: ExtensionEntry) { + let mut extensions = get_extensions_map(); + let key = entry.config.key(); + extensions.insert(key, entry); + save_extensions_map(extensions); +} - pub fn set(entry: ExtensionEntry) -> Result<()> { - let mut extensions = Self::get_extensions_map()?; - let key = entry.config.key(); - extensions.insert(key, entry); - Self::save_extensions_map(extensions) - } +pub fn remove_extension(key: &str) { + let mut extensions = get_extensions_map(); + extensions.remove(key); + save_extensions_map(extensions); +} - pub fn remove(key: &str) -> Result<()> { - let mut extensions = Self::get_extensions_map()?; - extensions.remove(key); - Self::save_extensions_map(extensions) +pub fn set_extension_enabled(key: &str, enabled: bool) { + let mut extensions = get_extensions_map(); + if let Some(entry) = extensions.get_mut(key) { + entry.enabled = enabled; + save_extensions_map(extensions); } +} - pub fn set_enabled(key: &str, enabled: bool) -> Result<()> { - let mut extensions = Self::get_extensions_map()?; - if let Some(entry) = extensions.get_mut(key) { - entry.enabled = enabled; - Self::save_extensions_map(extensions)?; +pub fn get_all_extensions() -> Vec { + let extensions = get_extensions_map(); + extensions.into_values().collect() +} + +pub fn get_all_extension_names() -> Vec { + let extensions = get_extensions_map(); + extensions.keys().cloned().collect() +} + +pub fn is_extension_enabled(key: &str) -> bool { + let extensions = get_extensions_map(); + extensions.get(key).map(|e| e.enabled).unwrap_or(false) +} + +pub fn get_enabled_extensions() -> Vec { + get_all_extensions() + .into_iter() + .filter(|ext| ext.enabled) + .map(|ext| ext.config) + .collect() +} + +#[cfg(test)] +mod 
tests { + use super::*; + use crate::agents::ExtensionConfig; + + fn create_test_extension_config() -> ExtensionConfig { + ExtensionConfig::Builtin { + name: "test_extension".to_string(), + display_name: Some("Test Extension".to_string()), + description: Some("A test extension".to_string()), + timeout: None, + bundled: None, + available_tools: vec![], } - Ok(()) } - pub fn get_all() -> Result> { - let extensions = Self::get_extensions_map()?; - Ok(extensions.into_values().collect()) + #[test] + fn test_name_to_key_function() { + assert_eq!(name_to_key("Test Extension"), "testextension"); + assert_eq!(name_to_key("Developer Tools"), "developertools"); + assert_eq!(name_to_key("simple"), "simple"); + assert_eq!(name_to_key("UPPER_case MiXeD"), "upper_casemixed"); } - pub fn get_all_names() -> Result> { - let extensions = Self::get_extensions_map()?; - Ok(extensions.keys().cloned().collect()) + #[test] + fn test_extension_config_key_generation() { + let config = create_test_extension_config(); + assert_eq!(config.key(), "test_extension"); + + let config_with_spaces = ExtensionConfig::Builtin { + name: "Test Extension Name".to_string(), + display_name: Some("Test Extension".to_string()), + description: Some("A test extension".to_string()), + timeout: None, + bundled: None, + available_tools: vec![], + }; + assert_eq!(config_with_spaces.key(), "testextensionname"); } - pub fn is_enabled(key: &str) -> Result { - let extensions = Self::get_extensions_map()?; - Ok(extensions.get(key).map(|e| e.enabled).unwrap_or(false)) + #[test] + fn test_extension_entry_serialization() { + let config = create_test_extension_config(); + let entry = ExtensionEntry { + enabled: true, + config, + }; + + // Test that ExtensionEntry can be serialized/deserialized + let json = serde_json::to_string(&entry).unwrap(); + assert!(json.contains("\"enabled\":true")); + assert!(json.contains("\"name\":\"test_extension\"")); + + let deserialized: ExtensionEntry = serde_json::from_str(&json).unwrap(); + 
assert_eq!(deserialized.enabled, true); + assert_eq!(deserialized.config.name(), "test_extension"); } - pub fn get_enabled() -> Result> { - Ok(Self::get_all()? - .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config) - .collect()) + #[test] + fn test_get_extensions_map_returns_hashmap() { + // Test that get_extensions_map returns a HashMap (may be empty or not depending on global config) + let extensions = get_extensions_map(); + // Just verify it returns a HashMap - don't assert on contents since global config may vary + assert!(extensions.is_empty() || !extensions.is_empty()); } } diff --git a/crates/goose/src/config/mod.rs b/crates/goose/src/config/mod.rs index f060299a6d88..e5e3dfcd0ffe 100644 --- a/crates/goose/src/config/mod.rs +++ b/crates/goose/src/config/mod.rs @@ -10,7 +10,10 @@ pub use crate::agents::ExtensionConfig; pub use base::{Config, ConfigError, APP_STRATEGY}; pub use custom_providers::CustomProviderConfig; pub use experiments::ExperimentManager; -pub use extensions::{ExtensionConfigManager, ExtensionEntry}; +pub use extensions::{ + get_all_extension_names, get_all_extensions, get_enabled_extensions, get_extension_by_name, + is_extension_enabled, remove_extension, set_extension, set_extension_enabled, ExtensionEntry, +}; pub use permission::PermissionManager; pub use signup_openrouter::configure_openrouter; pub use signup_tetrate::configure_tetrate; From 086cb88569112dbc5e902b99ddca448db58c69cb Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Fri, 29 Aug 2025 17:16:22 -0400 Subject: [PATCH 06/16] Simplify pattern matching per PR feedback --- crates/goose-cli/src/session/builder.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index e5e6befb61d5..6b0c06d53105 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -352,9 +352,10 @@ pub async fn build_session(session_config: 
SessionBuilderConfig) -> Session { } else if session_config.resume { if let Some(session_file) = session_file.as_ref() { match session::read_metadata(session_file) { - Ok(metadata) if metadata.enabled_extensions.is_some() => { - metadata.enabled_extensions.unwrap().into_iter().collect() - } + Ok(SessionMetadata { + enabled_extensions: Some(extensions), + .. + }) => extensions.into_iter().collect(), _ => get_enabled_extensions(), } } else { From e2ccdf570142f8317b1a6baadef9d878ce13997f Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Fri, 29 Aug 2025 18:25:25 -0400 Subject: [PATCH 07/16] Cruft cleanup --- crates/goose/src/session/storage.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs index 5ff1a8361dc9..be6cb6a955e7 100644 --- a/crates/goose/src/session/storage.rs +++ b/crates/goose/src/session/storage.rs @@ -88,7 +88,6 @@ impl<'de> Deserialize<'de> for SessionMetadata { accumulated_input_tokens: Option, accumulated_output_tokens: Option, working_dir: Option, - todo_content: Option, // For backward compatibility #[serde(default)] extension_data: ExtensionData, } From f24eb00f793b7993dc1dd4deceff0601e29253a7 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Wed, 1 Oct 2025 14:02:12 -0400 Subject: [PATCH 08/16] Cleanup after addressing merge conflicts --- crates/goose-cli/src/commands/acp.rs | 5 +- crates/goose-cli/src/session/builder.rs | 67 +---- crates/goose-cli/src/session/mod.rs | 5 +- .../src/routes/config_management.rs | 18 +- crates/goose/src/agents/agent.rs | 34 ++- crates/goose/src/agents/extension_manager.rs | 2 +- .../agents/recipe_tools/dynamic_task_tools.rs | 24 +- crates/goose/src/agents/subagent.rs | 7 +- crates/goose/src/session/extension_data.rs | 131 ++++++++- crates/goose/src/session/mod.rs | 2 +- crates/goose/src/session/session_manager.rs | 259 ++++++++++++++++++ 11 files changed, 443 insertions(+), 111 deletions(-) diff --git a/crates/goose-cli/src/commands/acp.rs 
b/crates/goose-cli/src/commands/acp.rs index f6664a048c5c..7717c5d5c92b 100644 --- a/crates/goose-cli/src/commands/acp.rs +++ b/crates/goose-cli/src/commands/acp.rs @@ -4,7 +4,7 @@ use agent_client_protocol::{ }; use anyhow::Result; use goose::agents::Agent; -use goose::config::{Config, ExtensionConfigManager}; +use goose::config::{get_all_extensions, Config}; use goose::conversation::message::{Message, MessageContent}; use goose::conversation::Conversation; use goose::providers::create; @@ -124,8 +124,7 @@ impl GooseAcpAgent { agent.update_provider(provider.clone()).await?; // Load and add extensions just like the normal CLI - let extensions_to_run: Vec<_> = ExtensionConfigManager::get_all() - .map_err(|e| anyhow::anyhow!("Failed to load extensions: {}", e))? + let extensions_to_run: Vec<_> = get_all_extensions() .into_iter() .filter(|ext| ext.enabled) .map(|ext| ext.config) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index ca7b1ecff2d9..c2394ae44f5c 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -1,7 +1,7 @@ use super::output; use super::CliSession; use console::style; -use goose::agents::types::RetryConfig; +use goose::agents::types::{RetryConfig, SessionConfig}; use goose::agents::Agent; use goose::config::{get_all_extensions, get_enabled_extensions, Config, ExtensionConfig}; use goose::providers::create; @@ -10,7 +10,6 @@ use goose::session::SessionManager; use goose::session::{EnabledExtensionsState, ExtensionState}; use rustyline::EditMode; use std::collections::HashSet; -use std::path::PathBuf; use std::process; use std::sync::Arc; use tokio::task::JoinSet; @@ -327,7 +326,9 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { match SessionManager::get_session(session_id, false).await { Ok(session_data) => { // Try to load saved extension configs directly - if let Some(saved_state) = 
EnabledExtensionsState::from_extension_data(&session_data.extension_data) { + if let Some(saved_state) = + EnabledExtensionsState::from_extension_data(&session_data.extension_data) + { // Use the saved configs as-is (no lookup needed!) saved_state.extensions } else { @@ -397,28 +398,17 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { // Save extension state after loading all extensions if let Some(session_id) = session_id.as_ref() { - let loaded_extension_configs = agent_ptr - .extension_manager - .get_extension_configs() - .await; - - if !loaded_extension_configs.is_empty() { - let extensions_state = EnabledExtensionsState::new(loaded_extension_configs); - - // Load current session - if let Ok(mut session_data) = SessionManager::get_session(session_id, false).await { - // Update extension data - if extensions_state.to_extension_data(&mut session_data.extension_data).is_ok() { - // Save back to database - if let Err(e) = SessionManager::update_session(session_id) - .extension_data(session_data.extension_data) - .apply() - .await - { - tracing::warn!("Failed to save initial extension state: {}", e); - } - } - } + let session_config = SessionConfig { + id: session_id.clone(), + working_dir: std::env::current_dir().unwrap_or_default(), + schedule_id: None, + execution_mode: None, + max_turns: None, + retry_config: None, + }; + + if let Err(e) = agent_ptr.save_extension_state(&Some(session_config)).await { + tracing::warn!("Failed to save initial extension state: {}", e); } } @@ -571,33 +561,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { session.agent.override_system_prompt(override_prompt).await; } - // Prepare metadata with extension configurations (will be persisted with first message) - if let Some(session_file) = &session_file { - let all_extension_configs = session.agent.get_extension_configs().await; - - // Prepare metadata to be persisted with first message - let mut startup_metadata = if 
session_config.resume { - // For resumed sessions, load existing metadata or create new one - session::read_metadata(session_file).unwrap_or_else(|_| { - SessionMetadata::new(std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))) - }) - } else { - // For new sessions, create fresh metadata - SessionMetadata::new(std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))) - }; - - // Update metadata with extension configurations - if !all_extension_configs.is_empty() { - let enabled_extensions_state = EnabledExtensionsState::new(all_extension_configs); - if let Err(e) = enabled_extensions_state.to_extension_data(&mut startup_metadata.extension_data) { - tracing::warn!("Failed to save enabled extensions to metadata: {}", e); - } - } - - // Set the prepared metadata on the session - session.startup_metadata = Some(startup_metadata); - } - // Display session information unless in quiet mode if !session_config.quiet { output::display_session_info( diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 9be08caad606..8c02bb40ae75 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -31,14 +31,13 @@ use goose::agents::types::RetryConfig; use goose::agents::{Agent, SessionConfig}; use goose::config::Config; use goose::providers::pricing::initialize_pricing_cache; -use goose::session::{self, SessionMetadata}; +use goose::session::SessionManager; use input::InputResult; use rmcp::model::PromptMessage; use rmcp::model::ServerNotification; use rmcp::model::{ErrorCode, ErrorData}; use goose::conversation::message::{Message, MessageContent}; -use goose::session::SessionManager; use rand::{distributions::Alphanumeric, Rng}; use rustyline::EditMode; use serde_json::Value; @@ -1460,7 +1459,7 @@ impl CliSession { ); } - pub async fn get_metadata(&self) -> Result { + pub async fn get_metadata(&self) -> Result { match &self.session_id { Some(id) => SessionManager::get_session(id, false).await, 
None => Err(anyhow::anyhow!("No session available")), diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 3c4483124d67..cd429b378f9e 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -182,19 +182,8 @@ pub async fn read_config(Json(query): Json) -> Result Result, StatusCode> { - match ExtensionConfigManager::get_all() { - Ok(extensions) => Ok(Json(ExtensionResponse { extensions })), - Err(err) => { - if err - .downcast_ref::() - .is_some_and(|e| matches!(e, goose::config::base::ConfigError::DeserializeError(_))) - { - Err(StatusCode::UNPROCESSABLE_ENTITY) - } else { - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } - } + let extensions = goose::config::get_all_extensions(); + Ok(Json(ExtensionResponse { extensions })) } #[utoipa::path( @@ -211,8 +200,7 @@ pub async fn get_extensions() -> Result, StatusCode> { pub async fn add_extension( Json(extension_query): Json, ) -> Result, StatusCode> { - let extensions = - ExtensionConfigManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let extensions = goose::config::get_all_extensions(); let key = goose::config::extensions::name_to_key(&extension_query.name); let is_update = extensions.iter().any(|e| e.config.key() == key); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index d6310df70f38..84f574717992 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -31,7 +31,7 @@ use crate::agents::tool_route_manager::ToolRouteManager; use crate::agents::tool_router_index_manager::ToolRouterIndexManager; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; -use crate::config::{Config, ExtensionConfigManager}; +use crate::config::{get_enabled_extensions, get_extension_by_name, Config}; use crate::context_mgmt::auto_compact; use 
crate::conversation::{debug_conversation_fix, fix_conversation, Conversation}; use crate::mcp_utils::ToolResult; @@ -649,25 +649,26 @@ impl Agent { /// Save current extension state to session metadata /// Should be called after any extension add/remove operation - async fn save_extension_state(&self, session: &Option) -> Result<()> { + pub async fn save_extension_state(&self, session: &Option) -> Result<()> { if let Some(session_config) = session { - let extension_configs = self.extension_manager - .get_extension_configs() - .await; + let extension_configs = self.extension_manager.get_extension_configs().await; let extensions_state = EnabledExtensionsState::new(extension_configs); // Load current session - if let Ok(mut session_data) = SessionManager::get_session(&session_config.id, false).await { - // Update extension data - if extensions_state.to_extension_data(&mut session_data.extension_data).is_ok() { - // Save back to database - SessionManager::update_session(&session_config.id) - .extension_data(session_data.extension_data) - .apply() - .await?; - } + let mut session_data = SessionManager::get_session(&session_config.id, false).await?; + + // Update extension data + if let Err(e) = extensions_state.to_extension_data(&mut session_data.extension_data) { + warn!("Failed to serialize extension state: {}", e); + return Err(anyhow!("Extension state serialization failed: {}", e)); } + + // Save back to database + SessionManager::update_session(&session_config.id) + .extension_data(session_data.extension_data) + .apply() + .await?; } Ok(()) } @@ -776,7 +777,10 @@ impl Agent { // Save extension state after successful operation if result.is_ok() { if let Err(e) = self.save_extension_state(session).await { - warn!("Failed to save extension state after manage_extensions: {}", e); + warn!( + "Failed to save extension state after manage_extensions: {}", + e + ); } } diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs 
index 23ec63290177..0fe27913d01d 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -26,7 +26,7 @@ use super::tool_execution::ToolCallResult; use crate::agents::extension::{Envs, ProcessExit}; use crate::agents::extension_malware_check; use crate::agents::mcp_client::{McpClient, McpClientTrait}; -use crate::config::{Config, ExtensionConfigManager}; +use crate::config::{get_all_extensions, Config}; use crate::oauth::oauth_flow; use crate::prompt_template; use rmcp::model::{ diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs index bda327fdf7df..69fcb81975fa 100644 --- a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -109,23 +109,15 @@ fn process_extensions( for ext in arr { if let Some(name_str) = ext.as_str() { // Look up the full extension config by name - match crate::config::ExtensionConfigManager::get_config_by_name(name_str) { - Ok(Some(config)) => { - // Check if the extension is enabled - if crate::config::ExtensionConfigManager::is_enabled(&config.key()) - .unwrap_or(false) - { - converted_extensions.push(config); - } else { - tracing::warn!("Extension '{}' is disabled, skipping", name_str); - } - } - Ok(None) => { - tracing::warn!("Extension '{}' not found in configuration", name_str); - } - Err(e) => { - tracing::warn!("Error looking up extension '{}': {}", name_str, e); + if let Some(config) = crate::config::get_extension_by_name(name_str) { + // Check if the extension is enabled + if crate::config::is_extension_enabled(&config.key()) { + converted_extensions.push(config); + } else { + tracing::warn!("Extension '{}' is disabled, skipping", name_str); } + } else { + tracing::warn!("Extension '{}' not found in configuration", name_str); } } else if let Ok(ext_config) = serde_json::from_value::(ext.clone()) { 
converted_extensions.push(ext_config); diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 5363e48f99e4..7daac47255d8 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -1,7 +1,7 @@ use crate::agents::subagent_task_config::DEFAULT_SUBAGENT_MAX_TURNS; use crate::{ agents::{extension_manager::ExtensionManager, Agent, TaskConfig}, - config::get_enabled_extensions, + config::get_all_extensions, prompt_template::render_global_file, providers::errors::ProviderError, }; @@ -67,12 +67,11 @@ impl SubAgent { extensions.clone() } else { // Default behavior: use all enabled extensions - ExtensionConfigManager::get_all() - .unwrap_or_default() + get_all_extensions() .into_iter() .filter(|ext| ext.enabled) .map(|ext| ext.config) - .collect::>() + .collect() }; // Add the determined extensions to the subagent's extension manager diff --git a/crates/goose/src/session/extension_data.rs b/crates/goose/src/session/extension_data.rs index 7eeb1d98afe1..b90776bd9ff9 100644 --- a/crates/goose/src/session/extension_data.rs +++ b/crates/goose/src/session/extension_data.rs @@ -1,12 +1,12 @@ // Extension data management for sessions // Provides a simple way to store extension-specific data with versioned keys +use crate::config::ExtensionConfig; use anyhow::Result; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; use utoipa::ToSchema; -use crate::config::ExtensionConfig; /// Extension data containing all extension states /// Keys are in format "extension_name.version" (e.g., "todo.v0") @@ -189,4 +189,133 @@ mod tests { Some(&json!({"key": "value"})) ); } + + #[test] + fn test_enabled_extensions_state_with_full_configs() { + use crate::agents::extension::Envs; + use std::collections::HashMap; + + // Create multiple ExtensionConfig objects with different types + let configs = vec![ + ExtensionConfig::Builtin { + name: "developer".to_string(), + display_name: 
Some("Developer Tools".to_string()), + description: Some("Built-in developer extension".to_string()), + timeout: Some(30), + bundled: Some(true), + available_tools: vec!["read_file".to_string(), "write_file".to_string()], + }, + ExtensionConfig::Stdio { + name: "custom_mcp".to_string(), + cmd: "python".to_string(), + args: vec!["-m".to_string(), "mcp_server".to_string()], + envs: { + let mut map = HashMap::new(); + map.insert("API_KEY".to_string(), "test123".to_string()); + Envs::new(map) + }, + env_keys: vec!["API_KEY".to_string()], + timeout: Some(60), + description: Some("Custom MCP server".to_string()), + bundled: Some(false), + available_tools: vec!["custom_tool".to_string()], + }, + ]; + + // Create EnabledExtensionsState + let state = EnabledExtensionsState::new(configs.clone()); + + // Verify basic properties + assert_eq!(state.extensions.len(), 2); + assert_eq!(state.extensions[0].name(), "developer"); + assert_eq!(state.extensions[1].name(), "custom_mcp"); + + // Test round-trip serialization through ExtensionData + let mut data = ExtensionData::default(); + state.to_extension_data(&mut data).unwrap(); + + // Verify the state was saved + assert!(data + .get_extension_state("enabled_extensions", "v0") + .is_some()); + + // Restore from ExtensionData + let restored = EnabledExtensionsState::from_extension_data(&data).unwrap(); + + // Verify all extensions were restored + assert_eq!(restored.extensions.len(), 2); + + // Verify first extension (Builtin) details preserved + match &restored.extensions[0] { + ExtensionConfig::Builtin { + name, + display_name, + description, + timeout, + bundled, + available_tools, + } => { + assert_eq!(name, "developer"); + assert_eq!(display_name, &Some("Developer Tools".to_string())); + assert_eq!( + description, + &Some("Built-in developer extension".to_string()) + ); + assert_eq!(timeout, &Some(30)); + assert_eq!(bundled, &Some(true)); + assert_eq!(available_tools.len(), 2); + assert_eq!(available_tools[0], "read_file"); + } 
+ _ => panic!("Expected Builtin variant"), + } + + // Verify second extension (Stdio) details preserved + match &restored.extensions[1] { + ExtensionConfig::Stdio { + name, + cmd, + args, + envs, + env_keys, + timeout, + description, + bundled, + available_tools, + } => { + assert_eq!(name, "custom_mcp"); + assert_eq!(cmd, "python"); + assert_eq!(args.len(), 2); + assert_eq!(args[0], "-m"); + assert_eq!(envs.get_env().get("API_KEY"), Some(&"test123".to_string())); + assert_eq!(env_keys[0], "API_KEY"); + assert_eq!(timeout, &Some(60)); + assert_eq!(description, &Some("Custom MCP server".to_string())); + assert_eq!(bundled, &Some(false)); + assert_eq!(available_tools[0], "custom_tool"); + } + _ => panic!("Expected Stdio variant"), + } + } + + #[test] + fn test_enabled_extensions_state_missing_data() { + // Test loading from ExtensionData without enabled_extensions + let data = ExtensionData::default(); + let result = EnabledExtensionsState::from_extension_data(&data); + + // Should return None when the key doesn't exist + assert!(result.is_none()); + } + + #[test] + fn test_enabled_extensions_state_corrupt_data() { + // Test loading from ExtensionData with corrupt data + let mut data = ExtensionData::default(); + data.set_extension_state("enabled_extensions", "v0", json!("invalid json string")); + + let result = EnabledExtensionsState::from_extension_data(&data); + + // Should return None when deserialization fails + assert!(result.is_none()); + } } diff --git a/crates/goose/src/session/mod.rs b/crates/goose/src/session/mod.rs index a8641bae066d..221f89d70b1b 100644 --- a/crates/goose/src/session/mod.rs +++ b/crates/goose/src/session/mod.rs @@ -2,5 +2,5 @@ pub mod extension_data; mod legacy; pub mod session_manager; -pub use session_manager::{Session, SessionInsights, SessionManager}; pub use extension_data::{EnabledExtensionsState, ExtensionData, ExtensionState, TodoState}; +pub use session_manager::{Session, SessionInsights, SessionManager}; diff --git 
a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index 9a31c4ea43e0..39b5180611f9 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -858,3 +858,262 @@ impl SessionStorage { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::agents::extension::Envs; + use crate::config::ExtensionConfig; + use crate::session::extension_data::{EnabledExtensionsState, ExtensionState}; + use std::collections::HashMap; + use tempfile::tempdir; + + /// Helper to create a SessionStorage instance with a temporary database + async fn create_test_storage() -> Result<(SessionStorage, tempfile::TempDir)> { + let temp_dir = tempdir()?; + let db_path = temp_dir.path().join("test_sessions.db"); + let storage = SessionStorage::create(&db_path).await?; + Ok((storage, temp_dir)) + } + + #[tokio::test] + async fn test_extension_data_persistence() -> Result<()> { + let (storage, _temp_dir) = create_test_storage().await?; + + // Create a session + let session_id = "test_session_001"; + let working_dir = PathBuf::from("/test/dir"); + + sqlx::query( + r#" + INSERT INTO sessions (id, description, working_dir, extension_data) + VALUES (?, ?, ?, '{}') + "#, + ) + .bind(session_id) + .bind("Test session") + .bind(working_dir.to_string_lossy().as_ref()) + .execute(&storage.pool) + .await?; + + // Create extension_data with some state + let mut extension_data = ExtensionData::default(); + extension_data.set_extension_state("test_ext", "v1", serde_json::json!({"key": "value"})); + + // Update the session with extension_data + let extension_data_json = serde_json::to_string(&extension_data)?; + sqlx::query("UPDATE sessions SET extension_data = ? 
WHERE id = ?") + .bind(&extension_data_json) + .bind(session_id) + .execute(&storage.pool) + .await?; + + // Read back the session + let session = storage.get_session(session_id, false).await?; + + // Verify extension_data was persisted correctly + assert_eq!( + session.extension_data.get_extension_state("test_ext", "v1"), + Some(&serde_json::json!({"key": "value"})) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_enabled_extensions_state_database_roundtrip() -> Result<()> { + let (storage, _temp_dir) = create_test_storage().await?; + + // Create test extension configs + let configs = vec![ + ExtensionConfig::Builtin { + name: "developer".to_string(), + display_name: Some("Developer Tools".to_string()), + description: Some("Built-in developer extension".to_string()), + timeout: Some(30), + bundled: Some(true), + available_tools: vec!["read_file".to_string(), "write_file".to_string()], + }, + ExtensionConfig::Stdio { + name: "custom_mcp".to_string(), + cmd: "python".to_string(), + args: vec!["-m".to_string(), "mcp_server".to_string()], + envs: { + let mut map = HashMap::new(); + map.insert("API_KEY".to_string(), "test123".to_string()); + Envs::new(map) + }, + env_keys: vec!["API_KEY".to_string()], + timeout: Some(60), + description: Some("Custom MCP server".to_string()), + bundled: Some(false), + available_tools: vec!["custom_tool".to_string()], + }, + ]; + + let extensions_state = EnabledExtensionsState::new(configs.clone()); + + // Create a session + let session_id = "test_session_002"; + let working_dir = PathBuf::from("/test/dir"); + + sqlx::query( + r#" + INSERT INTO sessions (id, description, working_dir, extension_data) + VALUES (?, ?, ?, '{}') + "#, + ) + .bind(session_id) + .bind("Test session with extensions") + .bind(working_dir.to_string_lossy().as_ref()) + .execute(&storage.pool) + .await?; + + // Save extension state to database + let mut extension_data = ExtensionData::default(); + extensions_state.to_extension_data(&mut extension_data)?; + + 
let extension_data_json = serde_json::to_string(&extension_data)?; + sqlx::query("UPDATE sessions SET extension_data = ? WHERE id = ?") + .bind(&extension_data_json) + .bind(session_id) + .execute(&storage.pool) + .await?; + + // Read back the session + let session = storage.get_session(session_id, false).await?; + + // Restore EnabledExtensionsState from database + let restored_state = EnabledExtensionsState::from_extension_data(&session.extension_data) + .expect("Failed to restore extension state"); + + // Verify all extensions were restored correctly + assert_eq!(restored_state.extensions.len(), 2); + + // Verify first extension (Builtin) + match &restored_state.extensions[0] { + ExtensionConfig::Builtin { + name, + display_name, + timeout, + bundled, + available_tools, + .. + } => { + assert_eq!(name, "developer"); + assert_eq!(display_name, &Some("Developer Tools".to_string())); + assert_eq!(timeout, &Some(30)); + assert_eq!(bundled, &Some(true)); + assert_eq!(available_tools.len(), 2); + } + _ => panic!("Expected Builtin variant"), + } + + // Verify second extension (Stdio) + match &restored_state.extensions[1] { + ExtensionConfig::Stdio { + name, + cmd, + envs, + timeout, + .. 
+ } => { + assert_eq!(name, "custom_mcp"); + assert_eq!(cmd, "python"); + assert_eq!(envs.get_env().get("API_KEY"), Some(&"test123".to_string())); + assert_eq!(timeout, &Some(60)); + } + _ => panic!("Expected Stdio variant"), + } + + Ok(()) + } + + #[tokio::test] + async fn test_multiple_extension_states_in_database() -> Result<()> { + let (storage, _temp_dir) = create_test_storage().await?; + + // Create a session + let session_id = "test_session_003"; + let working_dir = PathBuf::from("/test/dir"); + + sqlx::query( + r#" + INSERT INTO sessions (id, description, working_dir, extension_data) + VALUES (?, ?, ?, '{}') + "#, + ) + .bind(session_id) + .bind("Test session") + .bind(working_dir.to_string_lossy().as_ref()) + .execute(&storage.pool) + .await?; + + // Create extension_data with multiple states + let mut extension_data = ExtensionData::default(); + extension_data.set_extension_state( + "state_one", + "v0", + serde_json::json!({"data": "value1"}), + ); + extension_data.set_extension_state( + "state_two", + "v1", + serde_json::json!({"data": "value2"}), + ); + + // Update via direct SQL + let extension_data_json = serde_json::to_string(&extension_data)?; + sqlx::query("UPDATE sessions SET extension_data = ? 
WHERE id = ?") + .bind(&extension_data_json) + .bind(session_id) + .execute(&storage.pool) + .await?; + + // Verify the update + let session = storage.get_session(session_id, false).await?; + assert_eq!( + session + .extension_data + .get_extension_state("state_one", "v0"), + Some(&serde_json::json!({"data": "value1"})) + ); + assert_eq!( + session + .extension_data + .get_extension_state("state_two", "v1"), + Some(&serde_json::json!({"data": "value2"})) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_extension_data_empty_by_default() -> Result<()> { + let (storage, _temp_dir) = create_test_storage().await?; + + // Create a session without explicitly setting extension_data + let session_id = "test_session_004"; + let working_dir = PathBuf::from("/test/dir"); + + sqlx::query( + r#" + INSERT INTO sessions (id, description, working_dir) + VALUES (?, ?, ?) + "#, + ) + .bind(session_id) + .bind("Test session") + .bind(working_dir.to_string_lossy().as_ref()) + .execute(&storage.pool) + .await?; + + // Read the session + let session = storage.get_session(session_id, false).await?; + + // Verify extension_data is empty by default + assert_eq!(session.extension_data.extension_states.len(), 0); + + Ok(()) + } +} From 973dd9897eee32b9672abe2b02d2b9156aafa6ef Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Wed, 1 Oct 2025 15:16:46 -0400 Subject: [PATCH 09/16] Botched merge conflict resolution, this shouldn't have been recreated --- crates/goose/src/session/storage.rs | 1978 --------------------------- 1 file changed, 1978 deletions(-) delete mode 100644 crates/goose/src/session/storage.rs diff --git a/crates/goose/src/session/storage.rs b/crates/goose/src/session/storage.rs deleted file mode 100644 index be6cb6a955e7..000000000000 --- a/crates/goose/src/session/storage.rs +++ /dev/null @@ -1,1978 +0,0 @@ -// IMPORTANT: This file includes session recovery functionality to handle corrupted session files. 
-// Only essential logging is included with the [SESSION] prefix to track: -// - Total message counts -// - Corruption detection and recovery -// - Backup creation -// Additional debug logging can be added if needed for troubleshooting. - -use crate::conversation::message::Message; -use crate::conversation::Conversation; -use crate::providers::base::Provider; -use crate::session::extension_data::ExtensionData; -use crate::utils::safe_truncate; -use anyhow::Result; -use chrono::Local; -use etcetera::{choose_app_strategy, AppStrategy, AppStrategyArgs}; -use regex::Regex; -use serde::{Deserialize, Serialize}; -use std::fs; -use std::io::{self, BufRead, Write}; -use std::ops::DerefMut; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use utoipa::ToSchema; - -// Security limits -const MAX_FILE_SIZE: u64 = 10 * 1024 * 1024; // 10MB -const MAX_MESSAGE_COUNT: usize = 5000; -const MAX_LINE_LENGTH: usize = 1024 * 1024; // 1MB per line - -fn get_home_dir() -> PathBuf { - choose_app_strategy(crate::config::APP_STRATEGY.clone()) - .expect("goose requires a home dir") - .home_dir() - .to_path_buf() -} - -fn get_current_working_dir() -> PathBuf { - std::env::current_dir() - .or_else(|_| Ok::(get_home_dir())) - .expect("could not determine the current working directory") -} - -/// Metadata for a session, stored as the first line in the session file -#[derive(Debug, Clone, Serialize, ToSchema)] -pub struct SessionMetadata { - /// Working directory for the session - #[schema(value_type = String, example = "/home/user/sessions/session1")] - pub working_dir: PathBuf, - /// A short description of the session, typically 3 words or less - pub description: String, - /// ID of the schedule that triggered this session, if any - pub schedule_id: Option, - - /// Number of messages in the session - pub message_count: usize, - /// The total number of tokens used in the session. Retrieved from the provider's last usage. 
- pub total_tokens: Option, - /// The number of input tokens used in the session. Retrieved from the provider's last usage. - pub input_tokens: Option, - /// The number of output tokens used in the session. Retrieved from the provider's last usage. - pub output_tokens: Option, - /// The total number of tokens used in the session. Accumulated across all messages (useful for tracking cost over an entire session). - pub accumulated_total_tokens: Option, - /// The number of input tokens used in the session. Accumulated across all messages. - pub accumulated_input_tokens: Option, - /// The number of output tokens used in the session. Accumulated across all messages. - pub accumulated_output_tokens: Option, - /// Extension data containing extension states - #[serde(default)] - pub extension_data: ExtensionData, -} - -// Custom deserializer to handle old sessions without working_dir -impl<'de> Deserialize<'de> for SessionMetadata { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - #[derive(Deserialize)] - struct Helper { - description: String, - message_count: usize, - schedule_id: Option, - total_tokens: Option, - input_tokens: Option, - output_tokens: Option, - accumulated_total_tokens: Option, - accumulated_input_tokens: Option, - accumulated_output_tokens: Option, - working_dir: Option, - #[serde(default)] - extension_data: ExtensionData, - } - - let helper = Helper::deserialize(deserializer)?; - - // Get working dir, falling back to home if not specified or if specified dir doesn't exist - let working_dir = helper - .working_dir - .filter(|path| path.exists()) - .unwrap_or_else(get_current_working_dir); - - Ok(SessionMetadata { - description: helper.description, - message_count: helper.message_count, - schedule_id: helper.schedule_id, - total_tokens: helper.total_tokens, - input_tokens: helper.input_tokens, - output_tokens: helper.output_tokens, - accumulated_total_tokens: helper.accumulated_total_tokens, - 
accumulated_input_tokens: helper.accumulated_input_tokens, - accumulated_output_tokens: helper.accumulated_output_tokens, - working_dir, - extension_data: helper.extension_data, - }) - } -} - -impl SessionMetadata { - pub fn new(working_dir: PathBuf) -> Self { - // If working_dir doesn't exist, fall back to home directory - let working_dir = if !working_dir.exists() { - get_home_dir() - } else { - working_dir - }; - - Self { - working_dir, - description: String::new(), - schedule_id: None, - message_count: 0, - total_tokens: None, - input_tokens: None, - output_tokens: None, - accumulated_total_tokens: None, - accumulated_input_tokens: None, - accumulated_output_tokens: None, - extension_data: ExtensionData::new(), - } - } -} - -impl Default for SessionMetadata { - fn default() -> Self { - Self::new(get_current_working_dir()) - } -} - -// The single app name used for all Goose applications -const APP_NAME: &str = "goose"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum Identifier { - Name(String), - Path(PathBuf), -} - -pub fn get_path(id: Identifier) -> Result { - let path = match id { - Identifier::Name(name) => { - // Validate session name for security - if name.is_empty() || name.len() > 255 { - return Err(anyhow::anyhow!("Invalid session name length")); - } - - // Check for path traversal attempts - if name.contains("..") || name.contains('/') || name.contains('\\') { - return Err(anyhow::anyhow!("Invalid characters in session name")); - } - - let session_dir = ensure_session_dir().map_err(|e| { - tracing::error!("Failed to create session directory: {}", e); - anyhow::anyhow!("Failed to access session directory") - })?; - session_dir.join(format!("{}.jsonl", name)) - } - Identifier::Path(path) => { - // In test mode, allow temporary directory paths - #[cfg(test)] - { - if let Some(path_str) = path.to_str() { - if path_str.contains("/tmp") || path_str.contains("/.tmp") { - // Allow test temporary directories - return Ok(path); - } - } - } - - // 
Validate that the path is within allowed directories - let session_dir = ensure_session_dir().map_err(|e| { - tracing::error!("Failed to create session directory: {}", e); - anyhow::anyhow!("Failed to access session directory") - })?; - - // Handle path validation with Windows-compatible logic - let is_path_allowed = validate_path_within_session_dir(&path, &session_dir)?; - if !is_path_allowed { - tracing::warn!( - "Attempted access outside session directory: {:?} not within {:?}", - path, - session_dir - ); - return Err(anyhow::anyhow!("Path not allowed")); - } - - path - } - }; - - // Additional security check for file extension (skip for special no-session paths) - if let Some(ext) = path.extension() { - if ext != "jsonl" { - return Err(anyhow::anyhow!("Invalid file extension")); - } - } - - Ok(path) -} - -/// Validate that a path is within the session directory, with Windows-compatible logic -/// -/// This function handles Windows-specific path issues like: -/// - UNC path conversion during canonicalization -/// - Case sensitivity differences -/// - Path separator normalization -/// - Drive letter casing inconsistencies -fn validate_path_within_session_dir(path: &Path, session_dir: &Path) -> Result { - // First, try the simple case - if canonicalization works cleanly - if let (Ok(canonical_path), Ok(canonical_session_dir)) = - (path.canonicalize(), session_dir.canonicalize()) - { - if canonical_path.starts_with(&canonical_session_dir) { - return Ok(true); - } - } - - // Fallback approach for Windows: normalize paths manually - let normalized_path = normalize_path_for_comparison(path); - let normalized_session_dir = normalize_path_for_comparison(session_dir); - - // Check if the normalized path starts with the normalized session directory - if normalized_path.starts_with(&normalized_session_dir) { - return Ok(true); - } - - // Additional check: if the path doesn't exist yet, check its parent directory - if !path.exists() { - if let Some(parent) = path.parent() { 
- return validate_path_within_session_dir(parent, session_dir); - } - } - - Ok(false) -} - -/// Normalize a path for cross-platform comparison -/// -/// This handles Windows-specific issues like: -/// - Converting to absolute paths -/// - Normalizing path separators -/// - Handling case sensitivity -fn normalize_path_for_comparison(path: &Path) -> PathBuf { - // Try to canonicalize first, but fall back to absolute path if that fails - let absolute_path = if let Ok(canonical) = path.canonicalize() { - canonical - } else if let Ok(absolute) = path.to_path_buf().canonicalize() { - absolute - } else { - // Last resort: try to make it absolute manually - if path.is_absolute() { - path.to_path_buf() - } else { - // If we can't make it absolute, use the current directory - std::env::current_dir() - .unwrap_or_else(|_| PathBuf::from(".")) - .join(path) - } - }; - - // On Windows, normalize the path representation - #[cfg(windows)] - { - // Convert the path to components and rebuild it normalized - let components: Vec<_> = absolute_path.components().collect(); - let mut normalized = PathBuf::new(); - - for component in components { - match component { - std::path::Component::Prefix(prefix) => { - // Handle drive letters and UNC paths - let prefix_str = prefix.as_os_str().to_string_lossy(); - if prefix_str.starts_with("\\\\?\\") { - // Remove UNC prefix and add the drive letter normally - let clean_prefix = &prefix_str[4..]; - normalized.push(clean_prefix); - } else { - normalized.push(component); - } - } - std::path::Component::RootDir => { - normalized.push(component); - } - std::path::Component::CurDir | std::path::Component::ParentDir => { - // Skip these as they should be resolved by canonicalization - continue; - } - std::path::Component::Normal(name) => { - // Normalize case for Windows - let name_str = name.to_string_lossy().to_lowercase(); - normalized.push(name_str); - } - } - } - - normalized - } - - #[cfg(not(windows))] - { - absolute_path - } -} - -/// Ensure 
the session directory exists and return its path -pub fn ensure_session_dir() -> Result { - let app_strategy = AppStrategyArgs { - top_level_domain: "Block".to_string(), - author: "Block".to_string(), - app_name: APP_NAME.to_string(), - }; - - let data_dir = choose_app_strategy(app_strategy) - .expect("goose requires a home dir") - .data_dir() - .join("sessions"); - - if !data_dir.exists() { - fs::create_dir_all(&data_dir)?; - } - - Ok(data_dir) -} - -/// Get the path to the most recently modified session file -pub fn get_most_recent_session() -> Result { - let session_dir = ensure_session_dir()?; - let mut entries = fs::read_dir(&session_dir)? - .filter_map(|entry| entry.ok()) - .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "jsonl")) - .collect::>(); - - if entries.is_empty() { - return Err(anyhow::anyhow!("No session files found")); - } - - // Sort by modification time, most recent first - entries.sort_by(|a, b| { - b.metadata() - .and_then(|m| m.modified()) - .unwrap_or(std::time::SystemTime::UNIX_EPOCH) - .cmp( - &a.metadata() - .and_then(|m| m.modified()) - .unwrap_or(std::time::SystemTime::UNIX_EPOCH), - ) - }); - - Ok(entries[0].path()) -} - -/// List all available session files -pub fn list_sessions() -> Result> { - let session_dir = ensure_session_dir()?; - let entries = fs::read_dir(&session_dir)? - .filter_map(|entry| { - let entry = entry.ok()?; - let path = entry.path(); - - if path.extension().is_some_and(|ext| ext == "jsonl") { - let name = path.file_stem()?.to_string_lossy().to_string(); - Some((name, path)) - } else { - None - } - }) - .collect::>(); - - Ok(entries) -} - -/// Generate a session ID using timestamp format (yyyymmdd_hhmmss) -pub fn generate_session_id() -> String { - Local::now().format("%Y%m%d_%H%M%S").to_string() -} - -/// Read messages from a session file with corruption recovery -/// -/// Creates the file if it doesn't exist, reads and deserializes all messages if it does. 
-/// The first line of the file is expected to be metadata, and the rest are messages. -/// Large messages are automatically truncated to prevent memory issues. -/// Includes recovery mechanisms for corrupted files. -/// -/// Security features: -/// - Validates file paths to prevent directory traversal -/// - Includes all security limits from read_messages_with_truncation -pub fn read_messages(session_file: &Path) -> Result { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - let result = read_messages_with_truncation(&secure_path, Some(50000)); // 50KB limit per message content - match &result { - Ok(_messages) => {} - Err(e) => println!( - "[SESSION] Failed to read messages from {:?}: {}", - secure_path, e - ), - } - result -} - -/// Read messages from a session file with optional content truncation and corruption recovery -/// -/// Creates the file if it doesn't exist, reads and deserializes all messages if it does. -/// The first line of the file is expected to be metadata, and the rest are messages. -/// If max_content_size is Some, large message content will be truncated during loading. -/// Includes robust error handling and corruption recovery mechanisms. 
-/// -/// Security features: -/// - File size limits to prevent resource exhaustion -/// - Message count limits to prevent DoS attacks -/// - Line length restrictions to prevent memory issues -pub fn read_messages_with_truncation( - session_file: &Path, - max_content_size: Option, -) -> Result { - // Security check: file size limit - if session_file.exists() { - let metadata = fs::metadata(session_file)?; - if metadata.len() > MAX_FILE_SIZE { - tracing::warn!("Session file exceeds size limit: {} bytes", metadata.len()); - return Err(anyhow::anyhow!("Session file too large")); - } - } - - // Check if there's a backup file we should restore from - let backup_file = session_file.with_extension("backup"); - if !session_file.exists() && backup_file.exists() { - println!( - "[SESSION] Session file missing but backup exists, restoring from backup: {:?}", - backup_file - ); - tracing::warn!( - "Session file missing but backup exists, restoring from backup: {:?}", - backup_file - ); - if let Err(e) = fs::copy(&backup_file, session_file) { - println!("[SESSION] Failed to restore from backup: {}", e); - tracing::error!("Failed to restore from backup: {}", e); - } - } - - // Open the file with appropriate options - let file = fs::OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(false) - .open(session_file)?; - - let reader = io::BufReader::new(file); - let mut lines = reader.lines(); - let mut messages = Vec::new(); - let mut corrupted_lines = Vec::new(); - let mut line_number = 1; - let mut message_count = 0; - - // Read the first line as metadata or create default if empty/missing - if let Some(line_result) = lines.next() { - match line_result { - Ok(line) => { - // Security check: line length - if line.len() > MAX_LINE_LENGTH { - tracing::warn!("Line {} exceeds length limit", line_number); - return Err(anyhow::anyhow!("Line too long")); - } - - // Try to parse as metadata, but if it fails, treat it as a message - if let Ok(_metadata) = 
serde_json::from_str::(&line) { - // Metadata successfully parsed, continue with the rest of the lines as messages - } else { - // This is not metadata, it's a message - match parse_message_with_truncation(&line, max_content_size) { - Ok(message) => { - messages.push(message); - message_count += 1; - } - Err(e) => { - println!("[SESSION] Failed to parse first line as message: {}", e); - println!("[SESSION] Attempting to recover corrupted first line..."); - tracing::warn!("Failed to parse first line as message: {}", e); - - // Try to recover the corrupted line - match attempt_corruption_recovery(&line, max_content_size) { - Ok(recovered) => { - println!( - "[SESSION] Successfully recovered corrupted first line!" - ); - messages.push(recovered); - message_count += 1; - } - Err(recovery_err) => { - println!( - "[SESSION] Failed to recover corrupted first line: {}", - recovery_err - ); - corrupted_lines.push((line_number, line)); - } - } - } - } - } - } - Err(e) => { - println!("[SESSION] Failed to read first line: {}", e); - tracing::error!("Failed to read first line: {}", e); - corrupted_lines.push((line_number, "[Unreadable line]".to_string())); - } - } - line_number += 1; - } - - // Read the rest of the lines as messages - for line_result in lines { - // Security check: message count limit - if message_count >= MAX_MESSAGE_COUNT { - tracing::warn!("Message count limit reached: {}", MAX_MESSAGE_COUNT); - println!( - "[SESSION] Message count limit reached, stopping at {}", - MAX_MESSAGE_COUNT - ); - break; - } - - match line_result { - Ok(line) => { - // Security check: line length - if line.len() > MAX_LINE_LENGTH { - tracing::warn!("Line {} exceeds length limit", line_number); - corrupted_lines.push(( - line_number, - "[Line too long - truncated for security]".to_string(), - )); - line_number += 1; - continue; - } - - match parse_message_with_truncation(&line, max_content_size) { - Ok(message) => { - messages.push(message); - message_count += 1; - } - Err(e) => { - 
println!("[SESSION] Failed to parse line {}: {}", line_number, e); - println!( - "[SESSION] Attempting to recover corrupted line {}...", - line_number - ); - tracing::warn!("Failed to parse line {}: {}", line_number, e); - - // Try to recover the corrupted line - match attempt_corruption_recovery(&line, max_content_size) { - Ok(recovered) => { - println!( - "[SESSION] Successfully recovered corrupted line {}!", - line_number - ); - messages.push(recovered); - message_count += 1; - } - Err(recovery_err) => { - println!( - "[SESSION] Failed to recover corrupted line {}: {}", - line_number, recovery_err - ); - corrupted_lines.push((line_number, line)); - } - } - } - } - } - Err(e) => { - println!("[SESSION] Failed to read line {}: {}", line_number, e); - tracing::error!("Failed to read line {}: {}", line_number, e); - corrupted_lines.push((line_number, "[Unreadable line]".to_string())); - } - } - line_number += 1; - } - - // If we found corrupted lines, create a backup and log the issues - if !corrupted_lines.is_empty() { - println!( - "[SESSION] Found {} corrupted lines, creating backup", - corrupted_lines.len() - ); - tracing::warn!( - "Found {} corrupted lines in session file, creating backup", - corrupted_lines.len() - ); - - // Create a backup of the original file - if !backup_file.exists() { - if let Err(e) = fs::copy(session_file, &backup_file) { - println!("[SESSION] Failed to create backup file: {}", e); - tracing::error!("Failed to create backup file: {}", e); - } else { - println!("[SESSION] Created backup file: {:?}", backup_file); - tracing::info!("Created backup file: {:?}", backup_file); - } - } - - // Log details about corrupted lines (with limited detail for security) - for (num, line) in &corrupted_lines { - let preview = if line.len() > 50 { - format!("{}... 
(truncated)", safe_truncate(line, 50)) - } else { - line.clone() - }; - tracing::debug!("Corrupted line {}: {}", num, preview); - } - } - - Ok(Conversation::new_unvalidated(messages)) -} - -/// Parse a message from JSON string with optional content truncation -fn parse_message_with_truncation( - json_str: &str, - max_content_size: Option, -) -> Result { - // First try to parse normally - match serde_json::from_str::(json_str) { - Ok(mut message) => { - // If we have a size limit, check and truncate if needed - if let Some(max_size) = max_content_size { - truncate_message_content_in_place(&mut message, max_size); - } - Ok(message) - } - Err(_e) => { - // If parsing fails and the string is very long, it might be due to size - if json_str.len() > 100000 { - println!( - "[SESSION] Very large message detected ({}KB), attempting truncation", - json_str.len() / 1024 - ); - tracing::warn!( - "Failed to parse very large message ({}KB), attempting truncation", - json_str.len() / 1024 - ); - - // Try to truncate the JSON string itself before parsing - let truncated_json = if let Some(max_size) = max_content_size { - truncate_json_string(json_str, max_size) - } else { - json_str.to_string() - }; - - match serde_json::from_str::(&truncated_json) { - Ok(message) => { - tracing::info!("Successfully parsed message after JSON truncation"); - Ok(message) - } - Err(_) => { - println!( - "[SESSION] Failed to parse even after truncation, attempting recovery" - ); - tracing::error!("Failed to parse message even after truncation"); - attempt_corruption_recovery(json_str, max_content_size) - } - } - } else { - // Try intelligent corruption recovery - attempt_corruption_recovery(json_str, max_content_size) - } - } - } -} - -/// Truncate content within a message in place -fn truncate_message_content_in_place(message: &mut Message, max_content_size: usize) { - use crate::conversation::message::MessageContent; - use rmcp::model::{RawContent, ResourceContents}; - - for content in &mut 
message.content { - match content { - MessageContent::Text(text_content) => { - if text_content.text.chars().count() > max_content_size { - let truncated = format!( - "{}\n\n[... content truncated during session loading from {} to {} characters ...]", - safe_truncate(&text_content.text, max_content_size), - text_content.text.chars().count(), - max_content_size - ); - text_content.text = truncated; - } - } - MessageContent::ToolResponse(tool_response) => { - if let Ok(ref mut result) = tool_response.tool_result { - for content_item in result { - match content_item.deref_mut() { - RawContent::Text(ref mut text_content) => { - if text_content.text.chars().count() > max_content_size { - let truncated = format!( - "{}\n\n[... tool response truncated during session loading from {} to {} characters ...]", - safe_truncate(&text_content.text, max_content_size), - text_content.text.chars().count(), - max_content_size - ); - text_content.text = truncated; - } - } - RawContent::Resource(ref mut resource_content) => { - if let ResourceContents::TextResourceContents { text, .. } = - &mut resource_content.resource - { - if text.chars().count() > max_content_size { - let truncated = format!( - "{}\n\n[... 
resource content truncated during session loading from {} to {} characters ...]", - safe_truncate(text, max_content_size), - text.chars().count(), - max_content_size - ); - *text = truncated; - } - } - } - _ => {} // Other content types are typically smaller - } - } - } - } - _ => {} // Other content types are typically smaller - } - } -} - -/// Attempt to recover corrupted JSON lines using various strategies -fn attempt_corruption_recovery(json_str: &str, max_content_size: Option) -> Result { - // Strategy 1: Try to fix common JSON corruption issues - if let Ok(message) = try_fix_json_corruption(json_str, max_content_size) { - println!("[SESSION] Recovered using JSON corruption fix"); - return Ok(message); - } - - // Strategy 2: Try to extract partial content if it looks like a message - if let Ok(message) = try_extract_partial_message(json_str) { - println!("[SESSION] Recovered using partial message extraction"); - return Ok(message); - } - - // Strategy 3: Try to fix truncated JSON - if let Ok(message) = try_fix_truncated_json(json_str, max_content_size) { - println!("[SESSION] Recovered using truncated JSON fix"); - return Ok(message); - } - - // Strategy 4: Create a placeholder message with the raw content - println!("[SESSION] All recovery strategies failed, creating placeholder message"); - let preview = if json_str.len() > 200 { - format!("{}...", safe_truncate(json_str, 200)) - } else { - json_str.to_string() - }; - - Ok(Message::user().with_text(format!( - "[RECOVERED FROM CORRUPTED LINE]\nOriginal content preview: {}\n\n[This message was recovered from a corrupted session file line. 
The original data may be incomplete.]", - preview - ))) -} - -/// Try to fix common JSON corruption patterns -fn try_fix_json_corruption(json_str: &str, max_content_size: Option) -> Result { - let mut fixed_json = json_str.to_string(); - let mut fixes_applied = Vec::new(); - - // Fix 1: Remove trailing commas before closing braces/brackets - if fixed_json.contains(",}") || fixed_json.contains(",]") { - fixed_json = fixed_json.replace(",}", "}").replace(",]", "]"); - fixes_applied.push("trailing commas"); - } - - // Fix 2: Try to close unclosed quotes in text fields - if let Some(text_start) = fixed_json.find("\"text\":\"") { - let content_start = text_start + 8; - if let Some(remaining) = fixed_json.get(content_start..) { - // Count quotes to see if we have an odd number (unclosed quote) - let quote_count = remaining.matches('"').count(); - if quote_count % 2 == 1 { - // Find the last quote and see if we need to close it - if let Some(last_quote_pos) = remaining.rfind('"') { - let after_last_quote = &remaining[last_quote_pos + 1..]; - if !after_last_quote.trim_start().starts_with(',') - && !after_last_quote.trim_start().starts_with('}') - { - // Insert a closing quote before the next field or end - if let Some(next_field) = after_last_quote.find(',') { - fixed_json.insert(content_start + last_quote_pos + 1 + next_field, '"'); - fixes_applied.push("unclosed quotes"); - } else if after_last_quote.contains('}') { - if let Some(brace_pos) = after_last_quote.find('}') { - fixed_json - .insert(content_start + last_quote_pos + 1 + brace_pos, '"'); - fixes_applied.push("unclosed quotes"); - } - } - } - } - } - } - } - - // Fix 3: Try to close unclosed JSON objects/arrays - let open_braces = fixed_json.matches('{').count(); - let close_braces = fixed_json.matches('}').count(); - let open_brackets = fixed_json.matches('[').count(); - let close_brackets = fixed_json.matches(']').count(); - - if open_braces > close_braces { - for _ in 0..(open_braces - close_braces) { - 
fixed_json.push('}'); - } - fixes_applied.push("unclosed braces"); - } - - if open_brackets > close_brackets { - for _ in 0..(open_brackets - close_brackets) { - fixed_json.push(']'); - } - fixes_applied.push("unclosed brackets"); - } - - // Fix 4: Remove control characters that might break JSON parsing - let original_len = fixed_json.len(); - fixed_json = fixed_json - .chars() - .filter(|c| !c.is_control() || *c == '\n' || *c == '\r' || *c == '\t') - .collect(); - if fixed_json.len() != original_len { - fixes_applied.push("control characters"); - } - - if !fixes_applied.is_empty() { - match serde_json::from_str::(&fixed_json) { - Ok(mut message) => { - if let Some(max_size) = max_content_size { - truncate_message_content_in_place(&mut message, max_size); - } - return Ok(message); - } - Err(e) => { - println!("[SESSION] JSON fixes didn't work: {}", e); - } - } - } - - Err(anyhow::anyhow!("JSON corruption fixes failed")) -} - -/// Try to extract a partial message from corrupted JSON -fn try_extract_partial_message(json_str: &str) -> Result { - // Look for recognizable patterns that indicate this was a message - - // Try to extract role - let role = if json_str.contains("\"role\":\"user\"") { - rmcp::model::Role::User - } else if json_str.contains("\"role\":\"assistant\"") { - rmcp::model::Role::Assistant - } else { - rmcp::model::Role::User // Default fallback - }; - - // Try to extract text content - let mut extracted_text = String::new(); - - // Look for text field content - if let Some(text_start) = json_str.find("\"text\":\"") { - let content_start = text_start + 8; - if let Some(content_end) = json_str[content_start..].find("\",") { - extracted_text = json_str[content_start..content_start + content_end].to_string(); - } else if let Some(content_end) = json_str[content_start..].find("\"") { - extracted_text = json_str[content_start..content_start + content_end].to_string(); - } else { - // Take everything after "text":" until we hit a likely end - let remaining 
= &json_str[content_start..]; - if let Some(end_pos) = remaining.find('}') { - extracted_text = remaining[..end_pos].trim_end_matches('"').to_string(); - } else { - extracted_text = remaining.to_string(); - } - } - } - - // If we couldn't extract text, try to find any readable content - if extracted_text.is_empty() { - // Look for any quoted strings that might be content - let quote_pattern = Regex::new(r#""([^"]{10,})""#).unwrap(); - if let Some(captures) = quote_pattern.find(json_str) { - extracted_text = captures.as_str().trim_matches('"').to_string(); - } - } - - if !extracted_text.is_empty() { - let message = match role { - rmcp::model::Role::User => Message::user(), - rmcp::model::Role::Assistant => Message::assistant(), - }; - - return Ok(message.with_text(format!("[PARTIALLY RECOVERED] {}", extracted_text))); - } - - Err(anyhow::anyhow!("Could not extract partial message")) -} - -/// Try to fix truncated JSON by completing it -fn try_fix_truncated_json(json_str: &str, max_content_size: Option) -> Result { - let mut completed_json = json_str.to_string(); - - // If the JSON appears to be cut off mid-field, try to complete it - if !completed_json.trim().ends_with('}') && !completed_json.trim().ends_with(']') { - // Try to find where it was likely cut off - if let Some(last_quote) = completed_json.rfind('"') { - let after_quote = &completed_json[last_quote + 1..]; - if !after_quote.contains('"') && !after_quote.contains('}') { - // Looks like it was cut off in the middle of a string value - completed_json.push('"'); - - // Try to close the JSON structure - let open_braces = completed_json.matches('{').count(); - let close_braces = completed_json.matches('}').count(); - - for _ in 0..(open_braces - close_braces) { - completed_json.push('}'); - } - - match serde_json::from_str::(&completed_json) { - Ok(mut message) => { - if let Some(max_size) = max_content_size { - truncate_message_content_in_place(&mut message, max_size); - } - return Ok(message); - } - Err(e) 
=> { - println!("[SESSION] Truncation fix didn't work: {}", e); - } - } - } - } - } - - Err(anyhow::anyhow!("Truncation fix failed")) -} - -/// Attempt to truncate a JSON string by finding and truncating large text values -fn truncate_json_string(json_str: &str, max_content_size: usize) -> String { - // This is a heuristic approach - look for large text values in the JSON - // and truncate them. This is not perfect but should handle the common case - // of large tool responses. - - if json_str.len() <= max_content_size * 2 { - return json_str.to_string(); - } - - // Try to find patterns that look like large text content - // Look for "text":"..." patterns and truncate the content - let mut result = json_str.to_string(); - - // Simple regex-like approach to find and truncate large text values - if let Some(start) = result.find("\"text\":\"") { - let text_start = start + 8; // Length of "text":" - if let Some(end) = result[text_start..].find("\",") { - let text_end = text_start + end; - let text_content = &result[text_start..text_end]; - - if text_content.len() > max_content_size { - let truncated_text = format!( - "{}\n\n[... content truncated during JSON parsing from {} to {} characters ...]", - safe_truncate(text_content, max_content_size), - text_content.len(), - max_content_size - ); - result.replace_range(text_start..text_end, &truncated_text); - } - } - } - - result -} - -/// Read session metadata from a session file with security validation -/// -/// Returns default empty metadata if the file doesn't exist or has no metadata. -/// Includes security checks for file access and content validation. 
-pub fn read_metadata(session_file: &Path) -> Result { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - if !secure_path.exists() { - return Ok(SessionMetadata::default()); - } - - // Security check: file size - let file_metadata = fs::metadata(&secure_path)?; - if file_metadata.len() > MAX_FILE_SIZE { - tracing::warn!("Session file exceeds size limit during metadata read"); - return Err(anyhow::anyhow!("Session file too large")); - } - - let file = fs::File::open(&secure_path).map_err(|e| { - tracing::error!("Failed to open session file for metadata read: {}", e); - anyhow::anyhow!("Failed to access session file") - })?; - let mut reader = io::BufReader::new(file); - let mut first_line = String::new(); - - // Read just the first line - if reader.read_line(&mut first_line)? > 0 { - // Security check: line length - if first_line.len() > MAX_LINE_LENGTH { - tracing::warn!("Metadata line exceeds length limit"); - return Err(anyhow::anyhow!("Metadata line too long")); - } - - // Try to parse as metadata - match serde_json::from_str::(&first_line) { - Ok(metadata) => Ok(metadata), - Err(e) => { - // If the first line isn't metadata, return default - tracing::debug!("Metadata parse error: {}", e); - Ok(SessionMetadata::default()) - } - } - } else { - // Empty file, return default - Ok(SessionMetadata::default()) - } -} - -/// Write messages to a session file with metadata -/// -/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format. -/// If a provider is supplied, it will automatically generate a description when appropriate. 
-/// -/// Security features: -/// - Validates file paths to prevent directory traversal -pub async fn persist_messages( - session_file: &Path, - messages: &Conversation, - provider: Option>, - working_dir: Option, -) -> Result<()> { - persist_messages_with_schedule_id(session_file, messages, provider, None, working_dir).await -} - -/// Write messages to a session file with metadata, including an optional scheduled job ID -/// -/// Overwrites the file with metadata as the first line, followed by all messages in JSONL format. -/// If a provider is supplied, it will automatically generate a description when appropriate. -/// -/// Security features: -/// - Validates file paths to prevent directory traversal -/// - Limits error message details in logs -/// - Uses atomic file operations via save_messages_with_metadata -pub async fn persist_messages_with_schedule_id( - session_file: &Path, - messages: &Conversation, - provider: Option>, - schedule_id: Option, - working_dir: Option, -) -> Result<()> { - // Validate the session file path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Security check: message count limit - if messages.len() > MAX_MESSAGE_COUNT { - tracing::warn!("Message count exceeds limit: {}", messages.len()); - return Err(anyhow::anyhow!("Too many messages")); - } - - // Count user messages - let user_message_count = messages - .iter() - .filter(|m| m.role == rmcp::model::Role::User && !m.as_concat_text().trim().is_empty()) - .count(); - - // Check if we need to update the description (after 1st or 3rd user message) - match provider { - Some(provider) if user_message_count < 4 => { - //generate_description is responsible for writing the messages - generate_description_with_schedule_id( - &secure_path, - messages, - provider, - schedule_id, - working_dir, - ) - .await - } - _ => { - // Read existing metadata or create new with proper working_dir - let mut metadata = if secure_path.exists() { - 
read_metadata(&secure_path)? - } else { - // Create new metadata with the provided working_dir or fall back to home - let work_dir = working_dir.clone().unwrap_or_else(get_home_dir); - SessionMetadata::new(work_dir) - }; - - // Update the working_dir if provided (even for existing files) - if let Some(work_dir) = working_dir { - metadata.working_dir = work_dir; - } - - // Update the schedule_id if provided - if schedule_id.is_some() { - metadata.schedule_id = schedule_id; - } - - // Write the file with metadata and messages - save_messages_with_metadata(&secure_path, &metadata, messages) - } - } -} - -/// Write messages to a session file with the provided metadata using secure atomic operations -/// -/// This function uses atomic file operations to prevent corruption: -/// 1. Writes to a temporary file first with secure permissions -/// 2. Uses fs2 file locking to prevent concurrent writes -/// 3. Atomically moves the temp file to the final location -/// 4. Includes comprehensive error handling and recovery -/// -/// Security features: -/// - Secure temporary file creation with restricted permissions -/// - Path validation to prevent directory traversal -/// - File size and message count limits -/// - Sanitized error messages to prevent information leakage -pub fn save_messages_with_metadata( - session_file: &Path, - metadata: &SessionMetadata, - messages: &Conversation, -) -> Result<()> { - use fs2::FileExt; - - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Security check: message count limit - if messages.len() > MAX_MESSAGE_COUNT { - tracing::warn!( - "Message count exceeds limit during save: {}", - messages.len() - ); - return Err(anyhow::anyhow!("Too many messages to save")); - } - - // Create a temporary file in the same directory to ensure atomic move - let temp_file = secure_path.with_extension("tmp"); - - // Ensure the parent directory exists - if let Some(parent) = 
secure_path.parent() { - fs::create_dir_all(parent).map_err(|e| { - tracing::error!("Failed to create parent directory: {}", e); - anyhow::anyhow!("Failed to create session directory") - })?; - } - - // Create and lock the temporary file with secure permissions - let file = fs::OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(&temp_file) - .map_err(|e| { - tracing::error!("Failed to create temporary file: {}", e); - anyhow::anyhow!("Failed to create temporary session file") - })?; - - // Set secure file permissions (Unix only - read/write for owner only) - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let mut perms = file.metadata()?.permissions(); - perms.set_mode(0o600); // rw------- - fs::set_permissions(&temp_file, perms).map_err(|e| { - tracing::error!("Failed to set secure file permissions: {}", e); - anyhow::anyhow!("Failed to secure temporary file") - })?; - } - - // Get an exclusive lock on the file - file.try_lock_exclusive().map_err(|e| { - tracing::error!("Failed to lock file: {}", e); - anyhow::anyhow!("Failed to lock session file") - })?; - - // Write to temporary file - { - let mut writer = io::BufWriter::new(&file); - - // Write metadata as the first line - serde_json::to_writer(&mut writer, &metadata).map_err(|e| { - tracing::error!("Failed to serialize metadata: {}", e); - anyhow::anyhow!("Failed to write session metadata") - })?; - writeln!(writer)?; - - // Write all messages with progress tracking - for (i, message) in messages.iter().enumerate() { - serde_json::to_writer(&mut writer, &message).map_err(|e| { - tracing::error!("Failed to serialize message {}: {}", i, e); - anyhow::anyhow!("Failed to write session message") - })?; - writeln!(writer)?; - } - - // Ensure all data is written to disk - writer.flush().map_err(|e| { - tracing::error!("Failed to flush writer: {}", e); - anyhow::anyhow!("Failed to flush session data") - })?; - } - - // Sync to ensure data is persisted - file.sync_all().map_err(|e| 
{ - tracing::error!("Failed to sync data: {}", e); - anyhow::anyhow!("Failed to sync session data") - })?; - - // Release the lock - fs2::FileExt::unlock(&file).map_err(|e| { - tracing::error!("Failed to unlock file: {}", e); - anyhow::anyhow!("Failed to unlock session file") - })?; - - // Atomically move the temporary file to the final location - fs::rename(&temp_file, &secure_path).map_err(|e| { - // Clean up temp file on failure - tracing::error!("Failed to move temporary file: {}", e); - let _ = fs::remove_file(&temp_file); - anyhow::anyhow!("Failed to finalize session file") - })?; - - tracing::debug!("Successfully saved session file: {:?}", secure_path); - Ok(()) -} - -/// Generate a description for the session using the provider -/// -/// This function is called when appropriate to generate a short description -/// of the session based on the conversation history. -pub async fn generate_description( - session_file: &Path, - messages: &Conversation, - provider: Arc, - working_dir: Option, -) -> Result<()> { - generate_description_with_schedule_id(session_file, messages, provider, None, working_dir).await -} - -/// Generate a description for the session using the provider, including an optional scheduled job ID and working directory -/// -/// This function is called when appropriate to generate a short description -/// of the session based on the conversation history. 
-/// -/// Security features: -/// - Validates file paths to prevent directory traversal -/// - Limits context size to prevent resource exhaustion -/// - Uses secure file operations for saving -pub async fn generate_description_with_schedule_id( - session_file: &Path, - messages: &Conversation, - provider: Arc, - schedule_id: Option, - working_dir: Option, -) -> Result<()> { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Security check: message count limit - if messages.len() > MAX_MESSAGE_COUNT { - tracing::warn!( - "Message count exceeds limit during description generation: {}", - messages.len() - ); - return Err(anyhow::anyhow!( - "Too many messages for description generation" - )); - } - - // Use the provider's session naming capability - let sanitized_description = provider - .generate_session_name(messages) - .await - .map_err(|e| { - tracing::error!("Failed to generate session description: {}", e); - anyhow::anyhow!("Failed to generate session description") - })?; - - // Create metadata with proper working_dir or read existing and update - let mut metadata = if secure_path.exists() { - read_metadata(&secure_path)? 
- } else { - // Create new metadata with the provided working_dir or fall back to home - let work_dir = working_dir.clone().unwrap_or_else(get_home_dir); - SessionMetadata::new(work_dir) - }; - - // Update description and schedule_id - metadata.description = sanitized_description; - if schedule_id.is_some() { - metadata.schedule_id = schedule_id; - } - - // Update the working_dir if provided (even for existing files) - if let Some(work_dir) = working_dir { - metadata.working_dir = work_dir; - } - - // Update the file with the new metadata and existing messages - save_messages_with_metadata(&secure_path, &metadata, messages) -} - -/// Update only the metadata in a session file, preserving all messages -/// -/// Security features: -/// - Validates file paths to prevent directory traversal -/// - Uses secure file operations for reading and writing -pub async fn update_metadata(session_file: &Path, metadata: &SessionMetadata) -> Result<()> { - // Validate the path for security - let secure_path = get_path(Identifier::Path(session_file.to_path_buf()))?; - - // Read all messages from the file - let messages = read_messages(&secure_path)?; - - // Rewrite the file with the new metadata and existing messages - save_messages_with_metadata(&secure_path, metadata, &messages) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::conversation::message::{Message, MessageContent}; - use tempfile::tempdir; - - #[test] - fn test_corruption_recovery() -> Result<()> { - let test_cases = vec![ - // Case 1: Unclosed quotes - ( - r#"{"role":"user","content":[{"type":"text","text":"Hello there}]"#, - "Unclosed JSON with truncated content", - ), - // Case 2: Trailing comma - ( - r#"{"role":"user","content":[{"type":"text","text":"Test"},]}"#, - "JSON with trailing comma", - ), - // Case 3: Missing closing brace - ( - r#"{"role":"user","content":[{"type":"text","text":"Test""#, - "Incomplete JSON structure", - ), - // Case 4: Control characters in text - ( - 
r#"{"role":"user","content":[{"type":"text","text":"Test\u{0000}with\u{0001}control\u{0002}chars"}]}"#, - "JSON with control characters", - ), - // Case 5: Partial message with role and text - ( - r#"broken{"role": "assistant", "text": "This is recoverable content"more broken"#, - "Partial message with recoverable content", - ), - ]; - - println!("[TEST] Starting corruption recovery tests..."); - for (i, (corrupt_json, desc)) in test_cases.iter().enumerate() { - println!("\n[TEST] Case {}: {}", i + 1, desc); - println!( - "[TEST] Input: {}", - if corrupt_json.len() > 100 { - safe_truncate(corrupt_json, 100) - } else { - corrupt_json.to_string() - } - ); - - // Try to parse the corrupted JSON - match attempt_corruption_recovery(corrupt_json, Some(50000)) { - Ok(message) => { - println!("[TEST] Successfully recovered message"); - // Verify we got some content - if let Some(MessageContent::Text(text_content)) = message.content.first() { - assert!( - !text_content.text.is_empty(), - "Recovered message should have content" - ); - println!( - "[TEST] Recovered content: {}", - if text_content.text.len() > 50 { - format!("{}...", &text_content.text[..50]) - } else { - text_content.text.clone() - } - ); - } - } - Err(e) => { - println!("[TEST] Failed to recover: {}", e); - panic!("Failed to recover from case {}: {}", i + 1, desc); - } - } - } - - println!("\n[TEST] All corruption recovery tests passed!"); - Ok(()) - } - - #[tokio::test] - async fn test_read_write_messages() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("test.jsonl"); - - // Create some test messages - let messages = Conversation::new_unvalidated(vec![ - Message::user().with_text("Hello"), - Message::assistant().with_text("Hi there"), - ]); - - // Write messages - persist_messages(&file_path, &messages, None, None).await?; - - // Read them back - let read_messages = read_messages(&file_path)?; - - // Compare - assert_eq!(messages.len(), read_messages.len()); - for (orig, read) in 
messages.iter().zip(read_messages.iter()) { - assert_eq!(orig.role, read.role); - assert_eq!(orig.content.len(), read.content.len()); - - // Compare first text content - if let (Some(MessageContent::Text(orig_text)), Some(MessageContent::Text(read_text))) = - (orig.content.first(), read.content.first()) - { - assert_eq!(orig_text.text, read_text.text); - } else { - panic!("Messages don't match expected structure"); - } - } - - Ok(()) - } - - #[test] - fn test_empty_file() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("empty.jsonl"); - - // Reading an empty file should return empty vec - let messages = read_messages(&file_path)?; - assert!(messages.is_empty()); - - Ok(()) - } - - #[test] - fn test_generate_session_id() { - let id = generate_session_id(); - - // Check that it follows the timestamp format (yyyymmdd_hhmmss) - assert_eq!(id.len(), 15); // 8 chars for date + 1 for underscore + 6 for time - assert!(id.contains('_')); - - // Split by underscore and check parts - let parts: Vec<&str> = id.split('_').collect(); - assert_eq!(parts.len(), 2); - - // Date part should be 8 digits - assert_eq!(parts[0].len(), 8); - // Time part should be 6 digits - assert_eq!(parts[1].len(), 6); - } - - #[tokio::test] - async fn test_special_characters_and_long_text() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("special.jsonl"); - - // Insert some problematic JSON-like content between moderately long text - // (keeping under truncation limit to test serialization/deserialization) - let long_text = format!( - "Start_of_message\n{}{}SOME_MIDDLE_TEXT{}End_of_message", - "A".repeat(10_000), // Reduced from 100_000 to stay under 50KB limit - "\"}]\n", - "A".repeat(10_000) // Reduced from 100_000 to stay under 50KB limit - ); - - let special_chars = vec![ - // Long text - long_text.as_str(), - // Newlines in different positions - "Line 1\nLine 2", - "Line 1\r\nLine 2", - "\nStart with newline", - "End with newline\n", - 
"\n\nMultiple\n\nNewlines\n\n", - // JSON special characters - "Quote\"in middle", - "\"Quote at start", - "Quote at end\"", - "Multiple\"\"Quotes", - "{\"json\": \"looking text\"}", - // Unicode and special characters - "Unicode: 🦆🤖👾", - "Special: \\n \\r \\t", - "Mixed: \n\"🦆\"\r\n\\n", - // Control characters - "Tab\there", - "Bell\u{0007}char", - "Null\u{0000}char", - // Long text with mixed content - "A very long message with multiple lines\nand \"quotes\"\nand emojis 🦆\nand \\escaped chars", - // Potentially problematic JSON content - "}{[]\",\\", - "]}}\"\\n\\\"{[", - "Edge case: } ] some text", - "{\"foo\": \"} ]\"}", - "}]", - ]; - - let mut messages = Conversation::empty(); - for text in special_chars { - messages.push(Message::user().with_text(text)); - messages.push(Message::assistant().with_text(text)); - } - - // Write messages with special characters - persist_messages(&file_path, &messages, None, None).await?; - - // Read them back - let read_messages = read_messages(&file_path)?; - - // Compare all messages - assert_eq!(messages.len(), read_messages.len()); - for (i, (orig, read)) in messages.iter().zip(read_messages.iter()).enumerate() { - assert_eq!(orig.role, read.role, "Role mismatch at message {}", i); - assert_eq!( - orig.content.len(), - read.content.len(), - "Content length mismatch at message {}", - i - ); - - if let (Some(MessageContent::Text(orig_text)), Some(MessageContent::Text(read_text))) = - (orig.content.first(), read.content.first()) - { - assert_eq!( - orig_text.text, read_text.text, - "Text mismatch at message {}\nExpected: {}\nGot: {}", - i, orig_text.text, read_text.text - ); - } else { - panic!("Messages don't match expected structure at index {}", i); - } - } - - // Verify file format - let contents = fs::read_to_string(&file_path)?; - let lines: Vec<&str> = contents.lines().collect(); - - // First line should be metadata - assert!( - lines[0].contains("\"description\""), - "First line should be metadata" - ); - - // Each 
subsequent line should be valid JSON - for (i, line) in lines.iter().enumerate().skip(1) { - assert!( - serde_json::from_str::(line).is_ok(), - "Invalid JSON at line {}: {}", - i + 1, - line - ); - } - - Ok(()) - } - - #[tokio::test] - async fn test_large_content_truncation() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("large_content.jsonl"); - - // Create a message with content larger than the 50KB truncation limit - let very_large_text = "A".repeat(100_000); // 100KB of text - let messages = Conversation::new_unvalidated(vec![ - Message::user().with_text(&very_large_text), - Message::assistant().with_text("Small response"), - ]); - - // Write messages - persist_messages(&file_path, &messages, None, None).await?; - - // Read them back - should be truncated - let read_messages = read_messages(&file_path)?; - - assert_eq!(messages.len(), read_messages.len()); - - // First message should be truncated - if let Some(MessageContent::Text(read_text)) = - read_messages.first().unwrap().content.first() - { - assert!( - read_text.text.len() < very_large_text.len(), - "Content should be truncated" - ); - assert!( - read_text - .text - .contains("content truncated during session loading"), - "Should contain truncation notice" - ); - assert!( - read_text.text.starts_with("AAAA"), - "Should start with original content" - ); - } else { - panic!("Expected text content in first message"); - } - - // Second message should be unchanged - if let Some(MessageContent::Text(read_text)) = read_messages.messages()[1].content.first() { - assert_eq!(read_text.text, "Small response"); - } else { - panic!("Expected text content in second message"); - } - - Ok(()) - } - - #[tokio::test] - async fn test_metadata_special_chars() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("metadata.jsonl"); - - let mut metadata = SessionMetadata::default(); - metadata.description = "Description with\nnewline and \"quotes\" and 🦆".to_string(); - - let 
messages = Conversation::new_unvalidated(vec![Message::user().with_text("test")]); - - // Write with special metadata - save_messages_with_metadata(&file_path, &metadata, &messages)?; - - // Read back metadata - let read_metadata = read_metadata(&file_path)?; - assert_eq!(metadata.description, read_metadata.description); - - Ok(()) - } - - #[test] - fn test_invalid_working_dir() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("test.jsonl"); - - // Create metadata with non-existent directory - let invalid_dir = PathBuf::from("/path/that/does/not/exist"); - - let metadata = SessionMetadata::new(invalid_dir.clone()); - - // Should fall back to home directory - assert_ne!(metadata.working_dir, invalid_dir); - assert_eq!(metadata.working_dir, get_home_dir()); - - // Test deserialization of invalid directory - let messages = Conversation::new_unvalidated(vec![Message::user().with_text("test")]); - save_messages_with_metadata(&file_path, &metadata, &messages)?; - - // Modify the file to include invalid directory - let contents = fs::read_to_string(&file_path)?; - let mut lines: Vec = contents.lines().map(String::from).collect(); - lines[0] = lines[0].replace( - &get_home_dir().to_string_lossy().into_owned(), - &invalid_dir.to_string_lossy().into_owned(), - ); - fs::write(&file_path, lines.join("\n"))?; - - // Read back - should fall back to home dir - let read_metadata = read_metadata(&file_path)?; - assert_ne!(read_metadata.working_dir, invalid_dir); - assert_eq!(read_metadata.working_dir, get_current_working_dir()); - - Ok(()) - } - - #[tokio::test] - async fn test_working_dir_preservation() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("test.jsonl"); - - // Create a temporary working directory - let working_dir = tempdir()?; - let working_dir_path = working_dir.path().to_path_buf(); - - // Create messages - let messages = - Conversation::new_unvalidated(vec![Message::user().with_text("test message")]); - - // Use 
persist_messages_with_schedule_id to set working dir - persist_messages_with_schedule_id( - &file_path, - &messages, - None, - None, - Some(working_dir_path.clone()), - ) - .await?; - - // Read back the metadata and verify working_dir is preserved - let metadata = read_metadata(&file_path)?; - assert_eq!(metadata.working_dir, working_dir_path); - - // Verify the messages are also preserved - let read_messages = read_messages(&file_path)?; - assert_eq!(read_messages.len(), 1); - assert_eq!( - read_messages.first().unwrap().role, - messages.messages()[0].role - ); - - Ok(()) - } - - #[tokio::test] - async fn test_working_dir_issue_fixed() -> Result<()> { - // This test demonstrates that the working_dir issue in jsonl files is fixed - let dir = tempdir()?; - let file_path = dir.path().join("test.jsonl"); - - // Create a temporary working directory (this simulates the actual working directory) - let working_dir = tempdir()?; - let working_dir_path = working_dir.path().to_path_buf(); - - // Create messages - let messages = - Conversation::new_unvalidated(vec![Message::user().with_text("test message")]); - - // Get the home directory for comparison - let home_dir = get_home_dir(); - - // Test 1: Using the old persist_messages function (without working_dir) - // This will fall back to home directory since no working_dir is provided - persist_messages(&file_path, &messages, None, None).await?; - - // Read back the metadata - this should now have the home directory as working_dir - let metadata_old = read_metadata(&file_path)?; - assert_eq!( - metadata_old.working_dir, home_dir, - "persist_messages should use home directory when no working_dir is provided" - ); - - // Test 2: Using persist_messages_with_schedule_id function - // This should properly set the working_dir (this is the main fix) - persist_messages_with_schedule_id( - &file_path, - &messages, - None, - None, - Some(working_dir_path.clone()), - ) - .await?; - - // Read back the metadata - this should now have the 
correct working_dir - let metadata_new = read_metadata(&file_path)?; - assert_eq!( - metadata_new.working_dir, working_dir_path, - "persist_messages_with_schedule_id should use provided working_dir" - ); - assert_ne!( - metadata_new.working_dir, home_dir, - "working_dir should be different from home directory" - ); - - // Test 3: Create a new session file without working_dir (should fall back to home) - let file_path_2 = dir.path().join("test2.jsonl"); - persist_messages_with_schedule_id( - &file_path_2, - &messages, - None, - None, - None, // No working_dir provided - ) - .await?; - - let metadata_fallback = read_metadata(&file_path_2)?; - assert_eq!(metadata_fallback.working_dir, home_dir, "persist_messages_with_schedule_id should fall back to home directory when no working_dir is provided"); - - // Test 4: Test that the fix works for existing files - // Create a session file and then add to it with different working_dir - let file_path_3 = dir.path().join("test3.jsonl"); - - // First, create with home directory - persist_messages(&file_path_3, &messages, None, None).await?; - let metadata_initial = read_metadata(&file_path_3)?; - assert_eq!( - metadata_initial.working_dir, home_dir, - "Initial session should use home directory" - ); - - // Then update with a specific working_dir - persist_messages_with_schedule_id( - &file_path_3, - &messages, - None, - None, - Some(working_dir_path.clone()), - ) - .await?; - - let metadata_updated = read_metadata(&file_path_3)?; - assert_eq!( - metadata_updated.working_dir, working_dir_path, - "Updated session should use new working_dir" - ); - - // Test 5: Most important test - simulate the real-world scenario where - // CLI and web interfaces pass the current directory instead of None - let file_path_4 = dir.path().join("test4.jsonl"); - let current_dir = std::env::current_dir()?; - - // This is what web.rs and session/mod.rs do now after the fix - persist_messages_with_schedule_id( - &file_path_4, - &messages, - None, - 
None, - Some(current_dir.clone()), - ) - .await?; - - let metadata_current = read_metadata(&file_path_4)?; - assert_eq!( - metadata_current.working_dir, current_dir, - "Session should use current directory when explicitly provided" - ); - // This should NOT be the home directory anymore (unless current_dir == home_dir) - if current_dir != home_dir { - assert_ne!( - metadata_current.working_dir, home_dir, - "working_dir should be different from home directory when current_dir is different" - ); - } - - Ok(()) - } - - #[test] - fn test_windows_path_validation() -> Result<()> { - // Test the Windows path validation logic - let temp_dir = tempfile::tempdir()?; - let session_dir = temp_dir.path().join("sessions"); - fs::create_dir_all(&session_dir)?; - - // Test case 1: Valid path within session directory - let valid_path = session_dir.join("test.jsonl"); - assert!(validate_path_within_session_dir(&valid_path, &session_dir)?); - - // Test case 2: Invalid path outside session directory - let invalid_path = temp_dir.path().join("outside.jsonl"); - assert!(!validate_path_within_session_dir( - &invalid_path, - &session_dir - )?); - - // Test case 3: Path with different separators (simulate Windows issue) - let mixed_sep_path = session_dir.join("subdir").join("test.jsonl"); - fs::create_dir_all(mixed_sep_path.parent().unwrap())?; - assert!(validate_path_within_session_dir( - &mixed_sep_path, - &session_dir - )?); - - // Test case 4: Non-existent path within session directory - let nonexistent_path = session_dir.join("nonexistent").join("test.jsonl"); - assert!(validate_path_within_session_dir( - &nonexistent_path, - &session_dir - )?); - - Ok(()) - } - - #[test] - fn test_path_normalization() { - let temp_dir = tempfile::tempdir().unwrap(); - let test_path = temp_dir.path().join("test"); - - // Test that normalization doesn't crash and returns a path - let normalized = normalize_path_for_comparison(&test_path); - assert!(!normalized.as_os_str().is_empty()); - - // Test with 
existing path - fs::create_dir_all(&test_path).unwrap(); - let normalized_existing = normalize_path_for_comparison(&test_path); - assert!(!normalized_existing.as_os_str().is_empty()); - } - - #[tokio::test] - async fn test_save_session_parameter() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("test_save_session.jsonl"); - - let messages = Conversation::new_unvalidated(vec![ - Message::user().with_text("Hello"), - Message::assistant().with_text("Hi there"), - ]); - - let metadata = SessionMetadata::default(); - - // Test with save_session = true - should create file - save_messages_with_metadata(&file_path, &metadata, &messages)?; - assert!( - file_path.exists(), - "File should be created when save_session=true" - ); - - // Verify content is correct - let read_messages = read_messages(&file_path)?; - assert_eq!(messages.len(), read_messages.len()); - - Ok(()) - } - - #[tokio::test] - async fn test_persist_messages_with_save_session_false() -> Result<()> { - let dir = tempdir()?; - let file_path = dir.path().join("test_persist_no_save.jsonl"); - - let messages = Conversation::new_unvalidated(vec![ - Message::user().with_text("Test message"), - Message::assistant().with_text("Test response"), - ]); - - // Test persist_messages_with_schedule_id with working_dir parameter - persist_messages_with_schedule_id( - &file_path, - &messages, - None, - Some("test_schedule".to_string()), - None, - ) - .await?; - - assert!( - file_path.exists(), - "File should be created when save_session=true" - ); - - // Verify the schedule_id was set correctly - let metadata = read_metadata(&file_path)?; - assert_eq!(metadata.schedule_id, Some("test_schedule".to_string())); - - Ok(()) - } -} From 37bda81ff2c394439bdcda8eda6941cc7764f485 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Wed, 1 Oct 2025 15:19:33 -0400 Subject: [PATCH 10/16] Unnecessary tests --- crates/goose/src/session/session_manager.rs | 259 -------------------- 1 file changed, 259 deletions(-) 
diff --git a/crates/goose/src/session/session_manager.rs b/crates/goose/src/session/session_manager.rs index 39b5180611f9..9a31c4ea43e0 100644 --- a/crates/goose/src/session/session_manager.rs +++ b/crates/goose/src/session/session_manager.rs @@ -858,262 +858,3 @@ impl SessionStorage { }) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::agents::extension::Envs; - use crate::config::ExtensionConfig; - use crate::session::extension_data::{EnabledExtensionsState, ExtensionState}; - use std::collections::HashMap; - use tempfile::tempdir; - - /// Helper to create a SessionStorage instance with a temporary database - async fn create_test_storage() -> Result<(SessionStorage, tempfile::TempDir)> { - let temp_dir = tempdir()?; - let db_path = temp_dir.path().join("test_sessions.db"); - let storage = SessionStorage::create(&db_path).await?; - Ok((storage, temp_dir)) - } - - #[tokio::test] - async fn test_extension_data_persistence() -> Result<()> { - let (storage, _temp_dir) = create_test_storage().await?; - - // Create a session - let session_id = "test_session_001"; - let working_dir = PathBuf::from("/test/dir"); - - sqlx::query( - r#" - INSERT INTO sessions (id, description, working_dir, extension_data) - VALUES (?, ?, ?, '{}') - "#, - ) - .bind(session_id) - .bind("Test session") - .bind(working_dir.to_string_lossy().as_ref()) - .execute(&storage.pool) - .await?; - - // Create extension_data with some state - let mut extension_data = ExtensionData::default(); - extension_data.set_extension_state("test_ext", "v1", serde_json::json!({"key": "value"})); - - // Update the session with extension_data - let extension_data_json = serde_json::to_string(&extension_data)?; - sqlx::query("UPDATE sessions SET extension_data = ? 
WHERE id = ?") - .bind(&extension_data_json) - .bind(session_id) - .execute(&storage.pool) - .await?; - - // Read back the session - let session = storage.get_session(session_id, false).await?; - - // Verify extension_data was persisted correctly - assert_eq!( - session.extension_data.get_extension_state("test_ext", "v1"), - Some(&serde_json::json!({"key": "value"})) - ); - - Ok(()) - } - - #[tokio::test] - async fn test_enabled_extensions_state_database_roundtrip() -> Result<()> { - let (storage, _temp_dir) = create_test_storage().await?; - - // Create test extension configs - let configs = vec![ - ExtensionConfig::Builtin { - name: "developer".to_string(), - display_name: Some("Developer Tools".to_string()), - description: Some("Built-in developer extension".to_string()), - timeout: Some(30), - bundled: Some(true), - available_tools: vec!["read_file".to_string(), "write_file".to_string()], - }, - ExtensionConfig::Stdio { - name: "custom_mcp".to_string(), - cmd: "python".to_string(), - args: vec!["-m".to_string(), "mcp_server".to_string()], - envs: { - let mut map = HashMap::new(); - map.insert("API_KEY".to_string(), "test123".to_string()); - Envs::new(map) - }, - env_keys: vec!["API_KEY".to_string()], - timeout: Some(60), - description: Some("Custom MCP server".to_string()), - bundled: Some(false), - available_tools: vec!["custom_tool".to_string()], - }, - ]; - - let extensions_state = EnabledExtensionsState::new(configs.clone()); - - // Create a session - let session_id = "test_session_002"; - let working_dir = PathBuf::from("/test/dir"); - - sqlx::query( - r#" - INSERT INTO sessions (id, description, working_dir, extension_data) - VALUES (?, ?, ?, '{}') - "#, - ) - .bind(session_id) - .bind("Test session with extensions") - .bind(working_dir.to_string_lossy().as_ref()) - .execute(&storage.pool) - .await?; - - // Save extension state to database - let mut extension_data = ExtensionData::default(); - extensions_state.to_extension_data(&mut extension_data)?; - - 
let extension_data_json = serde_json::to_string(&extension_data)?; - sqlx::query("UPDATE sessions SET extension_data = ? WHERE id = ?") - .bind(&extension_data_json) - .bind(session_id) - .execute(&storage.pool) - .await?; - - // Read back the session - let session = storage.get_session(session_id, false).await?; - - // Restore EnabledExtensionsState from database - let restored_state = EnabledExtensionsState::from_extension_data(&session.extension_data) - .expect("Failed to restore extension state"); - - // Verify all extensions were restored correctly - assert_eq!(restored_state.extensions.len(), 2); - - // Verify first extension (Builtin) - match &restored_state.extensions[0] { - ExtensionConfig::Builtin { - name, - display_name, - timeout, - bundled, - available_tools, - .. - } => { - assert_eq!(name, "developer"); - assert_eq!(display_name, &Some("Developer Tools".to_string())); - assert_eq!(timeout, &Some(30)); - assert_eq!(bundled, &Some(true)); - assert_eq!(available_tools.len(), 2); - } - _ => panic!("Expected Builtin variant"), - } - - // Verify second extension (Stdio) - match &restored_state.extensions[1] { - ExtensionConfig::Stdio { - name, - cmd, - envs, - timeout, - .. 
- } => { - assert_eq!(name, "custom_mcp"); - assert_eq!(cmd, "python"); - assert_eq!(envs.get_env().get("API_KEY"), Some(&"test123".to_string())); - assert_eq!(timeout, &Some(60)); - } - _ => panic!("Expected Stdio variant"), - } - - Ok(()) - } - - #[tokio::test] - async fn test_multiple_extension_states_in_database() -> Result<()> { - let (storage, _temp_dir) = create_test_storage().await?; - - // Create a session - let session_id = "test_session_003"; - let working_dir = PathBuf::from("/test/dir"); - - sqlx::query( - r#" - INSERT INTO sessions (id, description, working_dir, extension_data) - VALUES (?, ?, ?, '{}') - "#, - ) - .bind(session_id) - .bind("Test session") - .bind(working_dir.to_string_lossy().as_ref()) - .execute(&storage.pool) - .await?; - - // Create extension_data with multiple states - let mut extension_data = ExtensionData::default(); - extension_data.set_extension_state( - "state_one", - "v0", - serde_json::json!({"data": "value1"}), - ); - extension_data.set_extension_state( - "state_two", - "v1", - serde_json::json!({"data": "value2"}), - ); - - // Update via direct SQL - let extension_data_json = serde_json::to_string(&extension_data)?; - sqlx::query("UPDATE sessions SET extension_data = ? 
WHERE id = ?") - .bind(&extension_data_json) - .bind(session_id) - .execute(&storage.pool) - .await?; - - // Verify the update - let session = storage.get_session(session_id, false).await?; - assert_eq!( - session - .extension_data - .get_extension_state("state_one", "v0"), - Some(&serde_json::json!({"data": "value1"})) - ); - assert_eq!( - session - .extension_data - .get_extension_state("state_two", "v1"), - Some(&serde_json::json!({"data": "value2"})) - ); - - Ok(()) - } - - #[tokio::test] - async fn test_extension_data_empty_by_default() -> Result<()> { - let (storage, _temp_dir) = create_test_storage().await?; - - // Create a session without explicitly setting extension_data - let session_id = "test_session_004"; - let working_dir = PathBuf::from("/test/dir"); - - sqlx::query( - r#" - INSERT INTO sessions (id, description, working_dir) - VALUES (?, ?, ?) - "#, - ) - .bind(session_id) - .bind("Test session") - .bind(working_dir.to_string_lossy().as_ref()) - .execute(&storage.pool) - .await?; - - // Read the session - let session = storage.get_session(session_id, false).await?; - - // Verify extension_data is empty by default - assert_eq!(session.extension_data.extension_states.len(), 0); - - Ok(()) - } -} From 5a52420311cb521cddbad1a4d97960450fedb4b8 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Wed, 1 Oct 2025 15:59:04 -0400 Subject: [PATCH 11/16] Extension persistence needs to happen after dynamic extensions are loaded --- crates/goose-cli/src/session/builder.rs | 36 ++++++++++++++----------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index c2394ae44f5c..b9f33b7ac806 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -396,22 +396,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { } } - // Save extension state after loading all extensions - if let Some(session_id) = 
session_id.as_ref() { - let session_config = SessionConfig { - id: session_id.clone(), - working_dir: std::env::current_dir().unwrap_or_default(), - schedule_id: None, - execution_mode: None, - max_turns: None, - retry_config: None, - }; - - if let Err(e) = agent_ptr.save_extension_state(&Some(session_config)).await { - tracing::warn!("Failed to save initial extension state: {}", e); - } - } - // Determine editor mode let edit_mode = config .get_param::("EDIT_MODE") @@ -543,6 +527,26 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { } } + // Save extension state after loading all extensions + if let Some(session_id) = session_id.as_ref() { + let session_config_for_save = SessionConfig { + id: session_id.clone(), + working_dir: std::env::current_dir().unwrap_or_default(), + schedule_id: None, + execution_mode: None, + max_turns: None, + retry_config: None, + }; + + if let Err(e) = session + .agent + .save_extension_state(&Some(session_config_for_save)) + .await + { + tracing::warn!("Failed to save initial extension state: {}", e); + } + } + // Add CLI-specific system prompt extension session .agent From 9aa981fdb0c15567aa521193272854b004926854 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Thu, 2 Oct 2025 10:59:52 -0400 Subject: [PATCH 12/16] Remove unnecessary comments --- crates/goose-cli/src/session/builder.rs | 1 - crates/goose/src/agents/agent.rs | 4 ---- crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs | 2 -- crates/goose/src/session/extension_data.rs | 1 - 4 files changed, 8 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index b9f33b7ac806..0ca785e6861a 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -527,7 +527,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { } } - // Save extension state after loading all extensions if let Some(session_id) = session_id.as_ref() 
{ let session_config_for_save = SessionConfig { id: session_id.clone(), diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 84f574717992..3538cc0c92b0 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -655,16 +655,13 @@ impl Agent { let extensions_state = EnabledExtensionsState::new(extension_configs); - // Load current session let mut session_data = SessionManager::get_session(&session_config.id, false).await?; - // Update extension data if let Err(e) = extensions_state.to_extension_data(&mut session_data.extension_data) { warn!("Failed to serialize extension state: {}", e); return Err(anyhow!("Extension state serialization failed: {}", e)); } - // Save back to database SessionManager::update_session(&session_config.id) .extension_data(session_data.extension_data) .apply() @@ -774,7 +771,6 @@ impl Agent { } } - // Save extension state after successful operation if result.is_ok() { if let Err(e) = self.save_extension_state(session).await { warn!( diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs index 69fcb81975fa..4174fc1055a4 100644 --- a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -108,9 +108,7 @@ fn process_extensions( for ext in arr { if let Some(name_str) = ext.as_str() { - // Look up the full extension config by name if let Some(config) = crate::config::get_extension_by_name(name_str) { - // Check if the extension is enabled if crate::config::is_extension_enabled(&config.key()) { converted_extensions.push(config); } else { diff --git a/crates/goose/src/session/extension_data.rs b/crates/goose/src/session/extension_data.rs index b90776bd9ff9..5ffcbabb3c7b 100644 --- a/crates/goose/src/session/extension_data.rs +++ b/crates/goose/src/session/extension_data.rs @@ -108,7 +108,6 @@ impl ExtensionState for EnabledExtensionsState 
{ } impl EnabledExtensionsState { - /// Create a new enabled extensions state pub fn new(extensions: Vec) -> Self { Self { extensions } } From 5e19a07d4a7a49b82aa84815d0880fb1276690c8 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Thu, 2 Oct 2025 11:25:42 -0400 Subject: [PATCH 13/16] Simplify save_extension_state function and remove unnecessary Option/Some wrappers --- crates/goose-cli/src/session/builder.rs | 2 +- crates/goose/src/agents/agent.rs | 39 +++++++++++++------------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index 0ca785e6861a..ae4223e6401d 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -539,7 +539,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { if let Err(e) = session .agent - .save_extension_state(&Some(session_config_for_save)) + .save_extension_state(&session_config_for_save) .await { tracing::warn!("Failed to save initial extension state: {}", e); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 3538cc0c92b0..7799473e4fdf 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -649,24 +649,23 @@ impl Agent { /// Save current extension state to session metadata /// Should be called after any extension add/remove operation - pub async fn save_extension_state(&self, session: &Option) -> Result<()> { - if let Some(session_config) = session { - let extension_configs = self.extension_manager.get_extension_configs().await; + pub async fn save_extension_state(&self, session: &SessionConfig) -> Result<()> { + let extension_configs = self.extension_manager.get_extension_configs().await; - let extensions_state = EnabledExtensionsState::new(extension_configs); + let extensions_state = EnabledExtensionsState::new(extension_configs); - let mut session_data = 
SessionManager::get_session(&session_config.id, false).await?; + let mut session_data = SessionManager::get_session(&session.id, false).await?; - if let Err(e) = extensions_state.to_extension_data(&mut session_data.extension_data) { - warn!("Failed to serialize extension state: {}", e); - return Err(anyhow!("Extension state serialization failed: {}", e)); - } - - SessionManager::update_session(&session_config.id) - .extension_data(session_data.extension_data) - .apply() - .await?; + if let Err(e) = extensions_state.to_extension_data(&mut session_data.extension_data) { + warn!("Failed to serialize extension state: {}", e); + return Err(anyhow!("Extension state serialization failed: {}", e)); } + + SessionManager::update_session(&session.id) + .extension_data(session_data.extension_data) + .apply() + .await?; + Ok(()) } @@ -772,11 +771,13 @@ impl Agent { } if result.is_ok() { - if let Err(e) = self.save_extension_state(session).await { - warn!( - "Failed to save extension state after manage_extensions: {}", - e - ); + if let Some(session_config) = session { + if let Err(e) = self.save_extension_state(session_config).await { + warn!( + "Failed to save extension state after manage_extensions: {}", + e + ); + } } } From 84347c34732962e559570b26a8d6d3bc958db9b6 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Thu, 2 Oct 2025 11:51:31 -0400 Subject: [PATCH 14/16] Remove trivial LLM tests --- crates/goose/src/config/extensions.rs | 67 ----------- crates/goose/src/session/extension_data.rs | 128 --------------------- 2 files changed, 195 deletions(-) diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index c30d7df7a523..8880396bd58f 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -97,70 +97,3 @@ pub fn get_enabled_extensions() -> Vec { .map(|ext| ext.config) .collect() } - -#[cfg(test)] -mod tests { - use super::*; - use crate::agents::ExtensionConfig; - - fn 
create_test_extension_config() -> ExtensionConfig { - ExtensionConfig::Builtin { - name: "test_extension".to_string(), - display_name: Some("Test Extension".to_string()), - description: Some("A test extension".to_string()), - timeout: None, - bundled: None, - available_tools: vec![], - } - } - - #[test] - fn test_name_to_key_function() { - assert_eq!(name_to_key("Test Extension"), "testextension"); - assert_eq!(name_to_key("Developer Tools"), "developertools"); - assert_eq!(name_to_key("simple"), "simple"); - assert_eq!(name_to_key("UPPER_case MiXeD"), "upper_casemixed"); - } - - #[test] - fn test_extension_config_key_generation() { - let config = create_test_extension_config(); - assert_eq!(config.key(), "test_extension"); - - let config_with_spaces = ExtensionConfig::Builtin { - name: "Test Extension Name".to_string(), - display_name: Some("Test Extension".to_string()), - description: Some("A test extension".to_string()), - timeout: None, - bundled: None, - available_tools: vec![], - }; - assert_eq!(config_with_spaces.key(), "testextensionname"); - } - - #[test] - fn test_extension_entry_serialization() { - let config = create_test_extension_config(); - let entry = ExtensionEntry { - enabled: true, - config, - }; - - // Test that ExtensionEntry can be serialized/deserialized - let json = serde_json::to_string(&entry).unwrap(); - assert!(json.contains("\"enabled\":true")); - assert!(json.contains("\"name\":\"test_extension\"")); - - let deserialized: ExtensionEntry = serde_json::from_str(&json).unwrap(); - assert_eq!(deserialized.enabled, true); - assert_eq!(deserialized.config.name(), "test_extension"); - } - - #[test] - fn test_get_extensions_map_returns_hashmap() { - // Test that get_extensions_map returns a HashMap (may be empty or not depending on global config) - let extensions = get_extensions_map(); - // Just verify it returns a HashMap - don't assert on contents since global config may vary - assert!(extensions.is_empty() || !extensions.is_empty()); - } 
-} diff --git a/crates/goose/src/session/extension_data.rs b/crates/goose/src/session/extension_data.rs index 5ffcbabb3c7b..a03d7db5aa75 100644 --- a/crates/goose/src/session/extension_data.rs +++ b/crates/goose/src/session/extension_data.rs @@ -189,132 +189,4 @@ mod tests { ); } - #[test] - fn test_enabled_extensions_state_with_full_configs() { - use crate::agents::extension::Envs; - use std::collections::HashMap; - - // Create multiple ExtensionConfig objects with different types - let configs = vec![ - ExtensionConfig::Builtin { - name: "developer".to_string(), - display_name: Some("Developer Tools".to_string()), - description: Some("Built-in developer extension".to_string()), - timeout: Some(30), - bundled: Some(true), - available_tools: vec!["read_file".to_string(), "write_file".to_string()], - }, - ExtensionConfig::Stdio { - name: "custom_mcp".to_string(), - cmd: "python".to_string(), - args: vec!["-m".to_string(), "mcp_server".to_string()], - envs: { - let mut map = HashMap::new(); - map.insert("API_KEY".to_string(), "test123".to_string()); - Envs::new(map) - }, - env_keys: vec!["API_KEY".to_string()], - timeout: Some(60), - description: Some("Custom MCP server".to_string()), - bundled: Some(false), - available_tools: vec!["custom_tool".to_string()], - }, - ]; - - // Create EnabledExtensionsState - let state = EnabledExtensionsState::new(configs.clone()); - - // Verify basic properties - assert_eq!(state.extensions.len(), 2); - assert_eq!(state.extensions[0].name(), "developer"); - assert_eq!(state.extensions[1].name(), "custom_mcp"); - - // Test round-trip serialization through ExtensionData - let mut data = ExtensionData::default(); - state.to_extension_data(&mut data).unwrap(); - - // Verify the state was saved - assert!(data - .get_extension_state("enabled_extensions", "v0") - .is_some()); - - // Restore from ExtensionData - let restored = EnabledExtensionsState::from_extension_data(&data).unwrap(); - - // Verify all extensions were restored - 
assert_eq!(restored.extensions.len(), 2); - - // Verify first extension (Builtin) details preserved - match &restored.extensions[0] { - ExtensionConfig::Builtin { - name, - display_name, - description, - timeout, - bundled, - available_tools, - } => { - assert_eq!(name, "developer"); - assert_eq!(display_name, &Some("Developer Tools".to_string())); - assert_eq!( - description, - &Some("Built-in developer extension".to_string()) - ); - assert_eq!(timeout, &Some(30)); - assert_eq!(bundled, &Some(true)); - assert_eq!(available_tools.len(), 2); - assert_eq!(available_tools[0], "read_file"); - } - _ => panic!("Expected Builtin variant"), - } - - // Verify second extension (Stdio) details preserved - match &restored.extensions[1] { - ExtensionConfig::Stdio { - name, - cmd, - args, - envs, - env_keys, - timeout, - description, - bundled, - available_tools, - } => { - assert_eq!(name, "custom_mcp"); - assert_eq!(cmd, "python"); - assert_eq!(args.len(), 2); - assert_eq!(args[0], "-m"); - assert_eq!(envs.get_env().get("API_KEY"), Some(&"test123".to_string())); - assert_eq!(env_keys[0], "API_KEY"); - assert_eq!(timeout, &Some(60)); - assert_eq!(description, &Some("Custom MCP server".to_string())); - assert_eq!(bundled, &Some(false)); - assert_eq!(available_tools[0], "custom_tool"); - } - _ => panic!("Expected Stdio variant"), - } - } - - #[test] - fn test_enabled_extensions_state_missing_data() { - // Test loading from ExtensionData without enabled_extensions - let data = ExtensionData::default(); - let result = EnabledExtensionsState::from_extension_data(&data); - - // Should return None when the key doesn't exist - assert!(result.is_none()); - } - - #[test] - fn test_enabled_extensions_state_corrupt_data() { - // Test loading from ExtensionData with corrupt data - let mut data = ExtensionData::default(); - data.set_extension_state("enabled_extensions", "v0", json!("invalid json string")); - - let result = EnabledExtensionsState::from_extension_data(&data); - - // Should 
return None when deserialization fails - assert!(result.is_none()); - } } From 6c40fac0c0e5c0d1d99434060f6580cfdb802df9 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Thu, 2 Oct 2025 18:33:15 -0400 Subject: [PATCH 15/16] Check if extensions were uninstalled in between exiting and resuming session --- crates/goose-cli/src/session/builder.rs | 44 ++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index ae4223e6401d..a832e295ecb0 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -3,7 +3,10 @@ use super::CliSession; use console::style; use goose::agents::types::{RetryConfig, SessionConfig}; use goose::agents::Agent; -use goose::config::{get_all_extensions, get_enabled_extensions, Config, ExtensionConfig}; +use goose::config::{ + extensions::{get_extension_by_name, set_extension, ExtensionEntry}, + get_all_extensions, get_enabled_extensions, Config, ExtensionConfig, +}; use goose::providers::create; use goose::recipe::{Response, SubRecipe}; use goose::session::SessionManager; @@ -149,6 +152,41 @@ async fn offer_extension_debugging_help( Ok(()) } +fn check_missing_extensions_or_exit(saved_extensions: &[ExtensionConfig]) { + let missing: Vec<_> = saved_extensions + .iter() + .filter(|ext| get_extension_by_name(&ext.name()).is_none()) + .cloned() + .collect(); + + if !missing.is_empty() { + let names = missing + .iter() + .map(|e| e.name()) + .collect::<Vec<_>>() + .join(", "); + + if !cliclack::confirm(format!( + "Extension(s) {} from previous session are no longer in config. 
Re-add them to config?", + names + )) + .initial_value(true) + .interact() + .unwrap_or(false) + { + println!("{}", style("Resume cancelled.").yellow()); + process::exit(0); + } + + missing.into_iter().for_each(|config| { + set_extension(ExtensionEntry { + enabled: true, + config, + }); + }); + } +} + #[derive(Clone, Debug, Default)] pub struct SessionSettings { pub goose_model: Option<String>, @@ -325,14 +363,12 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { if let Some(session_id) = session_id.as_ref() { match SessionManager::get_session(session_id, false).await { Ok(session_data) => { - // Try to load saved extension configs directly if let Some(saved_state) = EnabledExtensionsState::from_extension_data(&session_data.extension_data) { - // Use the saved configs as-is (no lookup needed!) + check_missing_extensions_or_exit(&saved_state.extensions); saved_state.extensions } else { - // Fallback to currently enabled extensions get_enabled_extensions() } } From 18b35c848abb44c8addaf8169044263601f30809 Mon Sep 17 00:00:00 2001 From: Will Pfleger Date: Thu, 2 Oct 2025 18:43:49 -0400 Subject: [PATCH 16/16] nit picky cargo fmt --- crates/goose/src/session/extension_data.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/goose/src/session/extension_data.rs b/crates/goose/src/session/extension_data.rs index a03d7db5aa75..ff548ef6b196 100644 --- a/crates/goose/src/session/extension_data.rs +++ b/crates/goose/src/session/extension_data.rs @@ -188,5 +188,4 @@ mod tests { Some(&json!({"key": "value"})) ); } - }