diff --git a/crates/goose-cli/src/commands/acp.rs b/crates/goose-cli/src/commands/acp.rs index f6664a048c5c..7717c5d5c92b 100644 --- a/crates/goose-cli/src/commands/acp.rs +++ b/crates/goose-cli/src/commands/acp.rs @@ -4,7 +4,7 @@ use agent_client_protocol::{ }; use anyhow::Result; use goose::agents::Agent; -use goose::config::{Config, ExtensionConfigManager}; +use goose::config::{get_all_extensions, Config}; use goose::conversation::message::{Message, MessageContent}; use goose::conversation::Conversation; use goose::providers::create; @@ -124,8 +124,7 @@ impl GooseAcpAgent { agent.update_provider(provider.clone()).await?; // Load and add extensions just like the normal CLI - let extensions_to_run: Vec<_> = ExtensionConfigManager::get_all() - .map_err(|e| anyhow::anyhow!("Failed to load extensions: {}", e))? + let extensions_to_run: Vec<_> = get_all_extensions() .into_iter() .filter(|ext| ext.enabled) .map(|ext| ext.config) diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 5aa45748c62f..c1d3b53041b3 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -9,12 +9,12 @@ use goose::agents::platform_tools::{ use goose::agents::Agent; use goose::agents::{extension::Envs, ExtensionConfig}; use goose::config::custom_providers::CustomProviderConfig; -use goose::config::extensions::name_to_key; -use goose::config::permission::PermissionLevel; -use goose::config::{ - Config, ConfigError, ExperimentManager, ExtensionConfigManager, ExtensionEntry, - PermissionManager, +use goose::config::extensions::{ + get_all_extension_names, get_all_extensions, get_enabled_extensions, get_extension_by_name, + name_to_key, remove_extension, set_extension, set_extension_enabled, }; +use goose::config::permission::PermissionLevel; +use goose::config::{Config, ConfigError, ExperimentManager, ExtensionEntry, PermissionManager}; use goose::conversation::message::Message; use goose::model::ModelConfig; use goose::providers::{create, providers}; @@ -105,10 +105,10 @@ pub async fn handle_configure() -> Result<(), Box> { ); // Since we are setting up for the first time, we'll also enable the developer system // This operation is best-effort and errors are ignored - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::default(), - })?; + }); } Ok(false) => { let _ = config.clear(); @@ -641,7 +641,7 @@ pub async fn configure_provider_dialog() -> Result> { /// Configure extensions that can be used with goose /// Dialog for toggling which extensions are enabled/disabled pub fn toggle_extensions_dialog() -> Result<(), Box> { - let extensions = ExtensionConfigManager::get_all()?; + let extensions = get_all_extensions(); if extensions.is_empty() { cliclack::outro( @@ -682,10 +682,10 @@ pub fn toggle_extensions_dialog() -> Result<(), Box> { // Update enabled status for each extension for name in extension_status.iter().map(|(name, _)| name) { - ExtensionConfigManager::set_enabled( + set_extension_enabled( &name_to_key(name), selected.iter().any(|s| s.as_str() == name), - )?; + ); } cliclack::outro("Extension settings updated successfully")?; @@ -768,7 +768,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { .map(|(_, name, desc)| (name.to_string(), desc.to_string())) .unwrap_or_else(|| (extension.clone(), extension.clone())); - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Builtin { name: extension.clone(), @@ -778,12 +778,12 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { description, available_tools: Vec::new(), }, - })?; + }); cliclack::outro(format!("Enabled {} extension", style(extension).green()))?; } "stdio" => { - let extensions = ExtensionConfigManager::get_all_names()?; + let extensions = get_all_extension_names(); let name: String = cliclack::input("What would you like to call this extension?") .placeholder("my-extension") .validate(move |input: &String| { @@ -866,7 +866,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } } - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Stdio { name: name.clone(), @@ -879,12 +879,12 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { bundled: None, available_tools: Vec::new(), }, - })?; + }); cliclack::outro(format!("Added {} extension", style(name).green()))?; } "sse" => { - let extensions = ExtensionConfigManager::get_all_names()?; + let extensions = get_all_extension_names(); let name: String = cliclack::input("What would you like to call this extension?") .placeholder("my-remote-extension") .validate(move |input: &String| { @@ -962,7 +962,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } } - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Sse { name: name.clone(), @@ -974,12 +974,12 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { bundled: None, available_tools: Vec::new(), }, - })?; + }); cliclack::outro(format!("Added {} extension", style(name).green()))?; } "streamable_http" => { - let extensions = ExtensionConfigManager::get_all_names()?; + let extensions = get_all_extension_names(); let name: String = cliclack::input("What would you like to call this extension?") .placeholder("my-remote-extension") .validate(move |input: &String| { @@ -1082,7 +1082,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } } - ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::StreamableHttp { name: name.clone(), @@ -1095,7 +1095,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { bundled: None, available_tools: Vec::new(), }, - })?; + }); cliclack::outro(format!("Added {} extension", style(name).green()))?; } @@ -1106,7 +1106,7 @@ pub fn configure_extensions_dialog() -> Result<(), Box> { } pub fn remove_extension_dialog() -> Result<(), Box> { - let extensions = ExtensionConfigManager::get_all()?; + let extensions = get_all_extensions(); // Create a list of extension names and their enabled status let mut extension_status: Vec<(String, bool)> = extensions @@ -1151,7 +1151,7 @@ pub fn remove_extension_dialog() -> Result<(), Box> { .interact()?; for name in selected { - ExtensionConfigManager::remove(&name_to_key(name))?; + remove_extension(&name_to_key(name)); let mut permission_manager = PermissionManager::default(); permission_manager.remove_extension(&name_to_key(name)); cliclack::outro(format!("Removed {} extension", style(name).green()))?; @@ -1386,11 +1386,9 @@ pub fn toggle_experiments_dialog() -> Result<(), Box> { } pub async fn configure_tool_permissions_dialog() -> Result<(), Box> { - let mut extensions: Vec = ExtensionConfigManager::get_all() - .unwrap_or_default() + let mut extensions: Vec = get_enabled_extensions() .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config.name().clone()) + .map(|ext| ext.name().clone()) .collect(); extensions.push("platform".to_string()); @@ -1423,7 +1421,7 @@ pub async fn configure_tool_permissions_dialog() -> Result<(), Box> { let agent = Agent::new(); let new_provider = create(&provider_name, model_config)?; agent.update_provider(new_provider).await?; - if let Ok(Some(config)) = ExtensionConfigManager::get_config_by_name(&selected_extension_name) { + if let Some(config) = get_extension_by_name(&selected_extension_name) { agent .add_extension(config.clone()) .await @@ -1706,13 +1704,13 @@ pub async fn handle_openrouter_auth() -> Result<(), Box> { println!("✓ Configuration test passed!"); // Enable the developer extension by default if not already enabled - let entries = ExtensionConfigManager::get_all()?; + let entries = get_all_extensions(); let has_developer = entries .iter() .any(|e| e.config.name() == "developer" && e.enabled); if !has_developer { - match ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Builtin { name: "developer".to_string(), @@ -1724,12 +1722,8 @@ pub async fn handle_openrouter_auth() -> Result<(), Box> { description: "Developer extension".to_string(), available_tools: Vec::new(), }, - }) { - Ok(_) => println!("✓ Developer extension enabled"), - Err(e) => { - eprintln!("⚠️ Failed to enable developer extension: {}", e) - } - } + }); + println!("✓ Developer extension enabled"); } cliclack::outro("OpenRouter setup complete! You can now use goose.")?; @@ -1809,13 +1803,13 @@ pub async fn handle_tetrate_auth() -> Result<(), Box> { println!("✓ Configuration test passed!"); // Enable the developer extension by default if not already enabled - let entries = ExtensionConfigManager::get_all()?; + let entries = get_all_extensions(); let has_developer = entries .iter() .any(|e| e.config.name() == "developer" && e.enabled); if !has_developer { - match ExtensionConfigManager::set(ExtensionEntry { + set_extension(ExtensionEntry { enabled: true, config: ExtensionConfig::Builtin { name: "developer".to_string(), @@ -1827,12 +1821,8 @@ pub async fn handle_tetrate_auth() -> Result<(), Box> { description: "Developer extension".to_string(), available_tools: Vec::new(), }, - }) { - Ok(_) => println!("✓ Developer extension enabled"), - Err(e) => { - eprintln!("⚠️ Failed to enable developer extension: {}", e) - } - } + }); + println!("✓ Developer extension enabled"); } cliclack::outro("Tetrate Agent Router Service setup complete! You can now use goose.")?; diff --git a/crates/goose-cli/src/commands/web.rs b/crates/goose-cli/src/commands/web.rs index 432360c15de0..af0f066cb2dd 100644 --- a/crates/goose-cli/src/commands/web.rs +++ b/crates/goose-cli/src/commands/web.rs @@ -164,16 +164,10 @@ pub async fn handle_web( agent.update_provider(provider).await?; // Load and enable extensions from config - let extensions = goose::config::ExtensionConfigManager::get_all()?; - for ext_config in extensions { - if ext_config.enabled { - if let Err(e) = agent.add_extension(ext_config.config.clone()).await { - eprintln!( - "Warning: Failed to load extension {}: {}", - ext_config.config.name(), - e - ); - } + let enabled_configs = goose::config::get_enabled_extensions(); + for config in enabled_configs { + if let Err(e) = agent.add_extension(config.clone()).await { + eprintln!("Warning: Failed to load extension {}: {}", config.name(), e); } } diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index ce6081c6826a..61ad44fded53 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -1,14 +1,18 @@ use super::output; use super::CliSession; use console::style; -use goose::agents::types::RetryConfig; +use goose::agents::types::{RetryConfig, SessionConfig}; use goose::agents::Agent; -use goose::config::{Config, ExtensionConfig, ExtensionConfigManager}; +use goose::config::{ + extensions::{get_extension_by_name, set_extension, ExtensionEntry}, + get_all_extensions, get_enabled_extensions, Config, ExtensionConfig, +}; use goose::providers::create; use goose::recipe::{Response, SubRecipe}; use goose::agents::extension::PlatformExtensionContext; use goose::session::SessionManager; +use goose::session::{EnabledExtensionsState, ExtensionState}; use rustyline::EditMode; use std::collections::HashSet; use std::process; @@ -114,18 +118,17 @@ async fn offer_extension_debugging_help( debug_agent.update_provider(provider).await?; // Add the developer extension if available to help with debugging - if let Ok(extensions) = ExtensionConfigManager::get_all() { - for ext_wrapper in extensions { - if ext_wrapper.enabled && ext_wrapper.config.name() == "developer" { - if let Err(e) = debug_agent.add_extension(ext_wrapper.config).await { - // If we can't add developer extension, continue without it - eprintln!( - "Note: Could not load developer extension for debugging: {}", - e - ); - } - break; + let extensions = get_all_extensions(); + for ext_wrapper in extensions { + if ext_wrapper.enabled && ext_wrapper.config.name() == "developer" { + if let Err(e) = debug_agent.add_extension(ext_wrapper.config).await { + // If we can't add developer extension, continue without it + eprintln!( + "Note: Could not load developer extension for debugging: {}", + e + ); } + break; } } @@ -151,6 +154,41 @@ async fn offer_extension_debugging_help( Ok(()) } +fn check_missing_extensions_or_exit(saved_extensions: &[ExtensionConfig]) { + let missing: Vec<_> = saved_extensions + .iter() + .filter(|ext| get_extension_by_name(&ext.name()).is_none()) + .cloned() + .collect(); + + if !missing.is_empty() { + let names = missing + .iter() + .map(|e| e.name()) + .collect::>() + .join(", "); + + if !cliclack::confirm(format!( + "Extension(s) {} from previous session are no longer in config. Re-add them to config?", + names + )) + .initial_value(true) + .interact() + .unwrap_or(false) + { + println!("{}", style("Resume cancelled.").yellow()); + process::exit(0); + } + + missing.into_iter().for_each(|config| { + set_extension(ExtensionEntry { + enabled: true, + config, + }); + }); + } +} + #[derive(Clone, Debug, Default)] pub struct SessionSettings { pub goose_model: Option, @@ -330,13 +368,26 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { let extensions_to_run: Vec<_> = if let Some(extensions) = session_config.extensions_override { agent.disable_router_for_recipe().await; extensions.into_iter().collect() + } else if session_config.resume { + if let Some(session_id) = session_id.as_ref() { + match SessionManager::get_session(session_id, false).await { + Ok(session_data) => { + if let Some(saved_state) = + EnabledExtensionsState::from_extension_data(&session_data.extension_data) + { + check_missing_extensions_or_exit(&saved_state.extensions); + saved_state.extensions + } else { + get_enabled_extensions() + } + } + _ => get_enabled_extensions(), + } + } else { + get_enabled_extensions() + } } else { - ExtensionConfigManager::get_all() - .expect("should load extensions") - .into_iter() - .filter(|ext| ext.enabled) - .map(|ext| ext.config) - .collect() + get_enabled_extensions() }; let mut set = JoinSet::new(); @@ -416,21 +467,17 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { session_config.retry_config.clone(), ); - // Add extensions if provided + // Add stdio extensions if provided for extension_str in session_config.extensions { if let Err(e) = session.add_extension(extension_str.clone()).await { eprintln!( "{}", style(format!( - "Warning: Failed to start extension '{}': {}", + "Warning: Failed to start stdio extension '{}' ({}), continuing without it", extension_str, e )) .yellow() ); - eprintln!( - "{}", - style(format!("Continuing without extension '{}'", extension_str)).yellow() - ); // Offer debugging help if let Err(debug_err) = offer_extension_debugging_help( @@ -452,19 +499,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { eprintln!( "{}", style(format!( - "Warning: Failed to start remote extension '{}': {}", + "Warning: Failed to start remote extension '{}' ({}), continuing without it", extension_str, e )) .yellow() ); - eprintln!( - "{}", - style(format!( - "Continuing without remote extension '{}'", - extension_str - )) - .yellow() - ); // Offer debugging help if let Err(debug_err) = offer_extension_debugging_help( @@ -489,19 +528,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { eprintln!( "{}", style(format!( - "Warning: Failed to start streamable HTTP extension '{}': {}", + "Warning: Failed to start streamable HTTP extension '{}' ({}), continuing without it", extension_str, e )) .yellow() ); - eprintln!( - "{}", - style(format!( - "Continuing without streamable HTTP extension '{}'", - extension_str - )) - .yellow() - ); // Offer debugging help if let Err(debug_err) = offer_extension_debugging_help( @@ -523,19 +554,11 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { eprintln!( "{}", style(format!( - "Warning: Failed to start builtin extension '{}': {}", + "Warning: Failed to start builtin extension '{}' ({}), continuing without it", builtin, e )) .yellow() ); - eprintln!( - "{}", - style(format!( - "Continuing without builtin extension '{}'", - builtin - )) - .yellow() - ); // Offer debugging help if let Err(debug_err) = offer_extension_debugging_help( @@ -551,6 +574,25 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> CliSession { } } + if let Some(session_id) = session_id.as_ref() { + let session_config_for_save = SessionConfig { + id: session_id.clone(), + working_dir: std::env::current_dir().unwrap_or_default(), + schedule_id: None, + execution_mode: None, + max_turns: None, + retry_config: None, + }; + + if let Err(e) = session + .agent + .save_extension_state(&session_config_for_save) + .await + { + tracing::warn!("Failed to save initial extension state: {}", e); + } + } + // Add CLI-specific system prompt extension session .agent diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 1d236f431d75..cfb2083602fc 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -31,14 +31,13 @@ use goose::agents::types::RetryConfig; use goose::agents::{Agent, SessionConfig}; use goose::config::Config; use goose::providers::pricing::initialize_pricing_cache; -use goose::session; +use goose::session::SessionManager; use input::InputResult; use rmcp::model::PromptMessage; use rmcp::model::ServerNotification; use rmcp::model::{ErrorCode, ErrorData}; use goose::conversation::message::{Message, MessageContent}; -use goose::session::SessionManager; use rand::{distributions::Alphanumeric, Rng}; use rustyline::EditMode; use serde_json::Value; @@ -300,8 +299,9 @@ impl CliSession { /// * `builtin_name` - Name of the builtin extension(s), comma separated pub async fn add_builtin(&mut self, builtin_name: String) -> Result<()> { for name in builtin_name.split(',') { + let extension_name = name.trim().to_string(); let config = ExtensionConfig::Builtin { - name: name.trim().to_string(), + name: extension_name, display_name: None, // TODO: should set a timeout timeout: Some(goose::config::DEFAULT_EXTENSION_TIMEOUT), @@ -1464,7 +1464,7 @@ impl CliSession { ); } - pub async fn get_metadata(&self) -> Result { + pub async fn get_metadata(&self) -> Result { match &self.session_id { Some(id) => SessionManager::get_session(id, false).await, None => Err(anyhow::anyhow!("No session available")), diff --git a/crates/goose-server/src/routes/config_management.rs b/crates/goose-server/src/routes/config_management.rs index 070274181de0..cd429b378f9e 100644 --- a/crates/goose-server/src/routes/config_management.rs +++ b/crates/goose-server/src/routes/config_management.rs @@ -6,9 +6,9 @@ use axum::{ Json, Router, }; use etcetera::{choose_app_strategy, AppStrategy}; +use goose::config::ExtensionEntry; use goose::config::APP_STRATEGY; use goose::config::{Config, ConfigError}; -use goose::config::{ExtensionConfigManager, ExtensionEntry}; use goose::model::ModelConfig; use goose::providers::base::ProviderMetadata; use goose::providers::pricing::{ @@ -182,19 +182,8 @@ pub async fn read_config(Json(query): Json) -> Result Result, StatusCode> { - match ExtensionConfigManager::get_all() { - Ok(extensions) => Ok(Json(ExtensionResponse { extensions })), - Err(err) => { - if err - .downcast_ref::() - .is_some_and(|e| matches!(e, goose::config::base::ConfigError::DeserializeError(_))) - { - Err(StatusCode::UNPROCESSABLE_ENTITY) - } else { - Err(StatusCode::INTERNAL_SERVER_ERROR) - } - } - } + let extensions = goose::config::get_all_extensions(); + Ok(Json(ExtensionResponse { extensions })) } #[utoipa::path( @@ -211,24 +200,20 @@ pub async fn get_extensions() -> Result, StatusCode> { pub async fn add_extension( Json(extension_query): Json, ) -> Result, StatusCode> { - let extensions = - ExtensionConfigManager::get_all().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let extensions = goose::config::get_all_extensions(); let key = goose::config::extensions::name_to_key(&extension_query.name); let is_update = extensions.iter().any(|e| e.config.key() == key); - match ExtensionConfigManager::set(ExtensionEntry { + goose::config::set_extension(ExtensionEntry { enabled: extension_query.enabled, config: extension_query.config, - }) { - Ok(_) => { - if is_update { - Ok(Json(format!("Updated extension {}", extension_query.name))) - } else { - Ok(Json(format!("Added extension {}", extension_query.name))) - } - } - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), + }); + + if is_update { + Ok(Json(format!("Updated extension {}", extension_query.name))) + } else { + Ok(Json(format!("Added extension {}", extension_query.name))) } } @@ -245,10 +230,8 @@ pub async fn remove_extension( axum::extract::Path(name): axum::extract::Path, ) -> Result, StatusCode> { let key = goose::config::extensions::name_to_key(&name); - match ExtensionConfigManager::remove(&key) { - Ok(_) => Ok(Json(format!("Removed extension {}", name))), - Err(_) => Err(StatusCode::NOT_FOUND), - } + goose::config::remove_extension(&key); + Ok(Json(format!("Removed extension {}", name))) } #[utoipa::path( diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 177f1a4d1158..a5c35c7a6582 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -31,7 +31,7 @@ use crate::agents::tool_route_manager::ToolRouteManager; use crate::agents::tool_router_index_manager::ToolRouterIndexManager; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; -use crate::config::{Config, ExtensionConfigManager}; +use crate::config::{get_enabled_extensions, get_extension_by_name, Config}; use crate::context_mgmt::auto_compact; use crate::conversation::{debug_conversation_fix, fix_conversation, Conversation}; use crate::mcp_utils::ToolResult; @@ -62,6 +62,7 @@ use super::platform_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; use crate::agents::subagent_task_config::TaskConfig; use crate::conversation::message::{Message, ToolRequest}; +use crate::session::extension_data::{EnabledExtensionsState, ExtensionState}; use crate::session::SessionManager; const DEFAULT_MAX_TURNS: u32 = 1000; @@ -549,6 +550,28 @@ impl Agent { ) } + /// Save current extension state to session metadata + /// Should be called after any extension add/remove operation + pub async fn save_extension_state(&self, session: &SessionConfig) -> Result<()> { + let extension_configs = self.extension_manager.get_extension_configs().await; + + let extensions_state = EnabledExtensionsState::new(extension_configs); + + let mut session_data = SessionManager::get_session(&session.id, false).await?; + + if let Err(e) = extensions_state.to_extension_data(&mut session_data.extension_data) { + warn!("Failed to serialize extension state: {}", e); + return Err(anyhow!("Extension state serialization failed: {}", e)); + } + + SessionManager::update_session(&session.id) + .extension_data(session_data.extension_data) + .apply() + .await?; + + Ok(()) + } + #[allow(clippy::too_many_lines)] pub(super) async fn manage_extensions( &self, @@ -595,9 +618,9 @@ impl Agent { return (request_id, result); } - let config = match ExtensionConfigManager::get_config_by_name(&extension_name) { - Ok(Some(config)) => config, - Ok(None) => { + let config = match get_extension_by_name(&extension_name) { + Some(config) => config, + None => { return ( request_id, Err(ErrorData::new( @@ -610,16 +633,6 @@ impl Agent { )), ) } - Err(e) => { - return ( - request_id, - Err(ErrorData::new( - ErrorCode::INTERNAL_ERROR, - format!("Failed to get extension config: {}", e), - None, - )), - ) - } }; let result = self .extension_manager @@ -658,6 +671,7 @@ impl Agent { } } } + (request_id, result) } @@ -792,6 +806,10 @@ impl Agent { .expect("Failed to list extensions") } + pub async fn get_extension_configs(&self) -> Vec { + self.extension_manager.get_extension_configs().await + } + /// Handle a confirmation response for a tool request pub async fn handle_confirmation( &self, @@ -1199,7 +1217,12 @@ impl Agent { } } - if all_install_successful { + if all_install_successful && !enable_extension_request_ids.is_empty() { + if let Some(ref session_config) = session { + if let Err(e) = self.save_extension_state(session_config).await { + warn!("Failed to save extension state after runtime changes: {}", e); + } + } tools_updated = true; } } @@ -1558,12 +1581,7 @@ impl Agent { (instructions, activities) }; - let extensions = ExtensionConfigManager::get_all().unwrap_or_default(); - let extension_configs: Vec<_> = extensions - .iter() - .filter(|e| e.enabled) - .map(|e| e.config.clone()) - .collect(); + let extension_configs = get_enabled_extensions(); let author = Author { contact: std::env::var("USER") diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 9838214fc2d3..c5b29f124975 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -32,7 +32,7 @@ use super::tool_execution::ToolCallResult; use crate::agents::extension::{Envs, ProcessExit}; use crate::agents::extension_malware_check; use crate::agents::mcp_client::{McpClient, McpClientTrait}; -use crate::config::{Config, ExtensionConfigManager}; +use crate::config::{get_all_extensions, Config}; use crate::oauth::oauth_flow; use crate::prompt_template; use rmcp::model::{ @@ -576,6 +576,15 @@ impl ExtensionManager { Ok(self.extensions.lock().await.keys().cloned().collect()) } + pub async fn get_extension_configs(&self) -> Vec { + self.extensions + .lock() + .await + .values() + .map(|ext| ext.config.clone()) + .collect() + } + /// Get all tools from all clients with proper prefixing pub async fn get_prefixed_tools( &self, @@ -1035,7 +1044,7 @@ impl ExtensionManager { // First get disabled extensions from current config let mut disabled_extensions: Vec = vec![]; - for extension in ExtensionConfigManager::get_all().expect("should load extensions") { + for extension in get_all_extensions() { if !extension.enabled { let config = extension.config.clone(); let description = match &config { diff --git a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs index bda327fdf7df..4174fc1055a4 100644 --- a/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs +++ b/crates/goose/src/agents/recipe_tools/dynamic_task_tools.rs @@ -108,24 +108,14 @@ fn process_extensions( for ext in arr { if let Some(name_str) = ext.as_str() { - // Look up the full extension config by name - match crate::config::ExtensionConfigManager::get_config_by_name(name_str) { - Ok(Some(config)) => { - // Check if the extension is enabled - if crate::config::ExtensionConfigManager::is_enabled(&config.key()) - .unwrap_or(false) - { - converted_extensions.push(config); - } else { - tracing::warn!("Extension '{}' is disabled, skipping", name_str); - } - } - Ok(None) => { - tracing::warn!("Extension '{}' not found in configuration", name_str); - } - Err(e) => { - tracing::warn!("Error looking up extension '{}': {}", name_str, e); + if let Some(config) = crate::config::get_extension_by_name(name_str) { + if crate::config::is_extension_enabled(&config.key()) { + converted_extensions.push(config); + } else { + tracing::warn!("Extension '{}' is disabled, skipping", name_str); } + } else { + tracing::warn!("Extension '{}' not found in configuration", name_str); } } else if let Ok(ext_config) = serde_json::from_value::(ext.clone()) { converted_extensions.push(ext_config); diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 06039b7aedae..7daac47255d8 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -1,8 +1,7 @@ use crate::agents::subagent_task_config::DEFAULT_SUBAGENT_MAX_TURNS; use crate::{ - agents::extension::ExtensionConfig, agents::{extension_manager::ExtensionManager, Agent, TaskConfig}, - config::ExtensionConfigManager, + config::get_all_extensions, prompt_template::render_global_file, providers::errors::ProviderError, }; @@ -68,12 +67,11 @@ impl SubAgent { extensions.clone() } else { // Default behavior: use all enabled extensions - ExtensionConfigManager::get_all() - .unwrap_or_default() + get_all_extensions() .into_iter() .filter(|ext| ext.enabled) .map(|ext| ext.config) - .collect::>() + .collect() }; // Add the determined extensions to the subagent's extension manager diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index f4d6b15548f2..4241e1b42245 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -1,7 +1,6 @@ use super::base::Config; use crate::agents::extension::PLATFORM_EXTENSIONS; use crate::agents::ExtensionConfig; -use anyhow::Result; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; @@ -28,132 +27,140 @@ pub fn name_to_key(name: &str) -> String { .to_lowercase() } -pub struct ExtensionConfigManager; +fn get_extensions_map() -> HashMap { + let raw: Value = Config::global() + .get_param::(EXTENSIONS_CONFIG_KEY) + .unwrap_or_else(|err| { + warn!( + "Failed to load {}: {err}. Falling back to empty object.", + EXTENSIONS_CONFIG_KEY + ); + Value::Object(serde_json::Map::new()) + }); -impl ExtensionConfigManager { - fn get_extensions_map() -> Result> { - let raw: Value = Config::global() - .get_param::(EXTENSIONS_CONFIG_KEY) - .unwrap_or_else(|err| { - warn!( - "Failed to load {}: {err}. Falling back to empty object.", - EXTENSIONS_CONFIG_KEY - ); - Value::Object(serde_json::Map::new()) - }); - - let mut extensions_map: HashMap = match raw { - Value::Object(obj) => { - let mut m = HashMap::with_capacity(obj.len()); - for (k, mut v) in obj { - if let Value::Object(ref mut inner) = v { - match inner.get("description") { - Some(Value::Null) | None => { - inner.insert( - "description".to_string(), - Value::String(String::new()), - ); - } - _ => {} + let mut extensions_map: HashMap = match raw { + Value::Object(obj) => { + let mut m = HashMap::with_capacity(obj.len()); + for (k, mut v) in obj { + if let Value::Object(ref mut inner) = v { + match inner.get("description") { + Some(Value::Null) | None => { + inner.insert("description".to_string(), Value::String(String::new())); } + _ => {} } - match serde_json::from_value::(v.clone()) { - Ok(entry) => { - m.insert(k, entry); - } - Err(err) => { - let bad_json = serde_json::to_string(&v).unwrap_or_else(|e| { - format!("") - }); - warn!( - extension = %k, - error = %err, - bad_json = %bad_json, - "Skipping malformed extension" - ); - } + } + match serde_json::from_value::(v.clone()) { + Ok(entry) => { + m.insert(k, entry); + } + Err(err) => { + let bad_json = serde_json::to_string(&v).unwrap_or_else(|e| { + format!("") + }); + warn!( + extension = %k, + error = %err, + bad_json = %bad_json, + "Skipping malformed extension" + ); } } - m - } - other => { - warn!( - "Expected object for {}, got {}. Using empty map.", - EXTENSIONS_CONFIG_KEY, other - ); - HashMap::new() } - }; + m + } + other => { + warn!( + "Expected object for {}, got {}. Using empty map.", + EXTENSIONS_CONFIG_KEY, other + ); + HashMap::new() + } + }; - if !extensions_map.is_empty() { - for (name, def) in PLATFORM_EXTENSIONS.iter() { - if !extensions_map.contains_key(*name) { - extensions_map.insert( - name.to_string(), - ExtensionEntry { - config: ExtensionConfig::Platform { - name: def.name.to_string(), - description: def.description.to_string(), - bundled: Some(true), - available_tools: Vec::new(), - }, - enabled: true, + if !extensions_map.is_empty() { + for (name, def) in PLATFORM_EXTENSIONS.iter() { + if !extensions_map.contains_key(*name) { + extensions_map.insert( + name.to_string(), + ExtensionEntry { + config: ExtensionConfig::Platform { + name: def.name.to_string(), + description: def.description.to_string(), + bundled: Some(true), + available_tools: Vec::new(), }, - ); - } + enabled: true, + }, + ); } } - Ok(extensions_map) } + extensions_map +} - fn save_extensions_map(extensions: HashMap) -> Result<()> { - let config = Config::global(); - config.set_param(EXTENSIONS_CONFIG_KEY, serde_json::to_value(extensions)?)?; - Ok(()) +fn save_extensions_map(extensions: HashMap) { + let config = Config::global(); + match serde_json::to_value(extensions) { + Ok(value) => { + if let Err(e) = config.set_param(EXTENSIONS_CONFIG_KEY, value) { + tracing::debug!("Failed to save extensions config: {}", e); + } + } + Err(e) => { + tracing::debug!("Failed to serialize extensions: {}", e); + } } +} - pub fn get_config_by_name(name: &str) -> Result> { - let extensions = Self::get_extensions_map()?; - Ok(extensions - .values() - .find(|entry| entry.config.name() == name) - .map(|entry| entry.config.clone())) - } +pub fn get_extension_by_name(name: &str) -> Option { + let extensions = get_extensions_map(); + extensions + .values() + .find(|entry| entry.config.name() == name) + .map(|entry| entry.config.clone()) +} - pub fn set(entry: ExtensionEntry) -> Result<()> { - let mut extensions = Self::get_extensions_map()?; - let key = entry.config.key(); - extensions.insert(key, entry); - Self::save_extensions_map(extensions) - } +pub fn set_extension(entry: ExtensionEntry) { + let mut extensions = get_extensions_map(); + let key = entry.config.key(); + extensions.insert(key, entry); + save_extensions_map(extensions); +} - pub fn remove(key: &str) -> Result<()> { - let mut extensions = Self::get_extensions_map()?; - extensions.remove(key); - Self::save_extensions_map(extensions) - } +pub fn remove_extension(key: &str) { + let mut extensions = get_extensions_map(); + extensions.remove(key); + save_extensions_map(extensions); +} - pub fn set_enabled(key: &str, enabled: bool) -> Result<()> { - let mut extensions = Self::get_extensions_map()?; - if let Some(entry) = extensions.get_mut(key) { - entry.enabled = enabled; - Self::save_extensions_map(extensions)?; - } - Ok(()) +pub fn set_extension_enabled(key: &str, enabled: bool) { + let mut extensions = get_extensions_map(); + if let Some(entry) = extensions.get_mut(key) { + entry.enabled = enabled; + save_extensions_map(extensions); } +} - pub fn get_all() -> Result> { - let extensions = Self::get_extensions_map()?; - Ok(extensions.into_values().collect()) - } +pub fn get_all_extensions() -> Vec { + let extensions = get_extensions_map(); + extensions.into_values().collect() +} - pub fn get_all_names() -> Result> { - let extensions = Self::get_extensions_map()?; - Ok(extensions.keys().cloned().collect()) - } +pub fn get_all_extension_names() -> Vec { + let extensions = get_extensions_map(); + extensions.keys().cloned().collect() +} - pub fn is_enabled(key: &str) -> Result { - let extensions = Self::get_extensions_map()?; - Ok(extensions.get(key).map(|e| e.enabled).unwrap_or(false)) - } +pub fn is_extension_enabled(key: &str) -> bool { + let extensions = get_extensions_map(); + extensions.get(key).map(|e| e.enabled).unwrap_or(false) +} + +pub fn get_enabled_extensions() -> Vec { + get_all_extensions() + .into_iter() + .filter(|ext| ext.enabled) + .map(|ext| ext.config) + .collect() } diff --git a/crates/goose/src/config/mod.rs b/crates/goose/src/config/mod.rs index 7c009b879de4..c80204205888 100644 --- a/crates/goose/src/config/mod.rs +++ b/crates/goose/src/config/mod.rs @@ -10,7 +10,10 @@ pub use crate::agents::ExtensionConfig; pub use base::{get_config_dir, Config, ConfigError, APP_STRATEGY}; pub use custom_providers::CustomProviderConfig; pub use experiments::ExperimentManager; -pub use extensions::{ExtensionConfigManager, ExtensionEntry}; +pub use extensions::{ + get_all_extension_names, get_all_extensions, get_enabled_extensions, get_extension_by_name, + is_extension_enabled, remove_extension, set_extension, set_extension_enabled, ExtensionEntry, +}; pub use permission::PermissionManager; pub use signup_openrouter::configure_openrouter; pub use signup_tetrate::configure_tetrate; diff --git a/crates/goose/src/session/extension_data.rs b/crates/goose/src/session/extension_data.rs index 292415f25d38..ff548ef6b196 100644 --- a/crates/goose/src/session/extension_data.rs +++ b/crates/goose/src/session/extension_data.rs @@ -1,6 +1,7 @@ // Extension data management for sessions // Provides a simple way to store extension-specific data with versioned keys +use crate::config::ExtensionConfig; use anyhow::Result; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -95,6 +96,23 @@ impl TodoState { } } +/// Enabled extensions state implementation for storing which extensions are active +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EnabledExtensionsState { + pub extensions: Vec, +} + +impl ExtensionState for EnabledExtensionsState { + const EXTENSION_NAME: &'static str = "enabled_extensions"; + const VERSION: &'static str = "v0"; +} + +impl EnabledExtensionsState { + pub fn new(extensions: Vec) -> Self { + Self { extensions } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/goose/src/session/mod.rs b/crates/goose/src/session/mod.rs index 5ed8312ecd54..221f89d70b1b 100644 --- a/crates/goose/src/session/mod.rs +++ b/crates/goose/src/session/mod.rs @@ -2,4 +2,5 @@ pub mod extension_data; mod legacy; pub mod session_manager; +pub use extension_data::{EnabledExtensionsState, ExtensionData, ExtensionState, TodoState}; pub use session_manager::{Session, SessionInsights, SessionManager};