diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index ceef7975d9..aafb112d49 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -127,7 +127,15 @@ pub trait API: Sync + Send { /// Retrieves the provider configuration for the specified agent async fn get_agent_provider(&self, agent_id: AgentId) -> anyhow::Result>; - /// Retrieves the provider configuration for the default agent + /// Gets the current session configuration (provider and model pair). + /// + /// Returns `None` when no session has been configured yet, allowing callers + /// to distinguish between "not configured" and an actual error. + async fn get_session_config(&self) -> Option; + + /// Retrieves the provider configuration for the default agent. + /// + /// Delegates to [`Self::get_session_config`] and resolves the provider. async fn get_default_provider(&self) -> anyhow::Result>; /// Applies one or more configuration mutations atomically. @@ -153,9 +161,6 @@ pub trait API: Sync + Send { /// Gets the model for the specified agent async fn get_agent_model(&self, agent_id: AgentId) -> Option; - /// Gets the default model - async fn get_default_model(&self) -> Option; - /// Gets the commit configuration (provider and model for commit message /// generation). async fn get_commit_config(&self) -> anyhow::Result>; diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index 1372a314ec..aca7637afc 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -297,10 +297,6 @@ impl< agent_provider_resolver.get_model(Some(agent_id)).await.ok() } - async fn get_default_model(&self) -> Option { - self.services.get_provider_model(None).await.ok() - } - async fn reload_mcp(&self) -> Result<()> { self.services.mcp_service().reload_mcp().await } @@ -402,9 +398,17 @@ impl< app.execute(data_parameters).await } + async fn get_session_config(&self) -> Option { + self.services.get_session_config().await + } + async fn get_default_provider(&self) -> Result> { - let provider_id = self.services.get_default_provider().await?; - self.services.get_provider(provider_id).await + let model_config = self + .services + .get_session_config() + .await + .ok_or_else(|| forge_domain::Error::NoDefaultSession)?; + self.services.get_provider(model_config.provider).await } async fn mcp_auth(&self, server_url: &str) -> Result<()> { diff --git a/crates/forge_app/src/agent.rs b/crates/forge_app/src/agent.rs index b189b0a547..8a72fecc07 100644 --- a/crates/forge_app/src/agent.rs +++ b/crates/forge_app/src/agent.rs @@ -48,7 +48,10 @@ impl> AgentSe let provider_id = if let Some(provider_id) = provider_id { provider_id } else { - self.get_default_provider().await? + self.get_session_config() + .await + .map(|c| c.provider) + .ok_or_else(|| forge_domain::Error::NoDefaultSession)? }; let provider = self.get_provider(provider_id).await?; diff --git a/crates/forge_app/src/agent_provider_resolver.rs b/crates/forge_app/src/agent_provider_resolver.rs index 14eba1d0c0..82e5a48197 100644 --- a/crates/forge_app/src/agent_provider_resolver.rs +++ b/crates/forge_app/src/agent_provider_resolver.rs @@ -33,10 +33,18 @@ where } else { // TODO: Needs review, should we throw an err here? // we can throw crate::Error::AgentNotFound - self.0.get_default_provider().await? + self.0 + .get_session_config() + .await + .map(|c| c.provider) + .ok_or_else(|| forge_domain::Error::NoDefaultSession)? } } else { - self.0.get_default_provider().await? + self.0 + .get_session_config() + .await + .map(|c| c.provider) + .ok_or_else(|| forge_domain::Error::NoDefaultSession)? }; let provider = self.0.get_provider(provider_id).await?; @@ -52,12 +60,18 @@ where } else { // TODO: Needs review, should we throw an err here? // we can throw crate::Error::AgentNotFound - let provider_id = self.get_provider(Some(agent_id)).await?.id; - Ok(self.0.get_provider_model(Some(&provider_id)).await?) + self.0 + .get_session_config() + .await + .map(|c| c.model) + .ok_or_else(|| forge_domain::Error::NoDefaultSession.into()) } } else { - let provider_id = self.get_provider(None).await?.id; - Ok(self.0.get_provider_model(Some(&provider_id)).await?) + self.0 + .get_session_config() + .await + .map(|c| c.model) + .ok_or_else(|| forge_domain::Error::NoDefaultSession.into()) } } } diff --git a/crates/forge_app/src/command_generator.rs b/crates/forge_app/src/command_generator.rs index 78ac004569..122fbc2ec8 100644 --- a/crates/forge_app/src/command_generator.rs +++ b/crates/forge_app/src/command_generator.rs @@ -62,10 +62,13 @@ where (provider, config.model) } None => { - let provider_id = self.services.get_default_provider().await?; - let provider = self.services.get_provider(provider_id).await?; - let model = self.services.get_provider_model(Some(&provider.id)).await?; - (provider, model) + let model_config = self + .services + .get_session_config() + .await + .ok_or_else(|| forge_domain::Error::NoDefaultSession)?; + let provider = self.services.get_provider(model_config.provider).await?; + (provider, model_config.model) } }; @@ -296,15 +299,11 @@ mod tests { #[async_trait::async_trait] impl AppConfigService for MockServices { - async fn get_default_provider(&self) -> Result { - Ok(ProviderId::OPENAI) - } - - async fn get_provider_model( - &self, - _provider_id: Option<&ProviderId>, - ) -> anyhow::Result { - Ok(ModelId::new("test-model")) + async fn get_session_config(&self) -> Option { + Some(forge_domain::ModelConfig::new( + ProviderId::OPENAI, + ModelId::new("test-model"), + )) } async fn get_commit_config(&self) -> Result> { diff --git a/crates/forge_app/src/data_gen.rs b/crates/forge_app/src/data_gen.rs index da80ec7061..e93313bace 100644 --- a/crates/forge_app/src/data_gen.rs +++ b/crates/forge_app/src/data_gen.rs @@ -93,9 +93,13 @@ impl DataGenerationApp { concurrency ); - let provider_id = self.services.get_default_provider().await?; - let provider = self.services.get_provider(provider_id).await?; - let model_id = self.services.get_provider_model(Some(&provider.id)).await?; + let model_config = self + .services + .get_session_config() + .await + .ok_or_else(|| forge_domain::Error::NoDefaultSession)?; + let provider = self.services.get_provider(model_config.provider).await?; + let model_id = model_config.model; debug!("Using provider: {}, model: {}", provider.id, model_id); let schema: Schema = serde_json::from_str(&schema).with_context(|| "Could not parse the JSON schema")?; diff --git a/crates/forge_app/src/services.rs b/crates/forge_app/src/services.rs index 4ec2c49809..59f88f3be7 100644 --- a/crates/forge_app/src/services.rs +++ b/crates/forge_app/src/services.rs @@ -181,20 +181,10 @@ pub trait ProviderService: Send + Sync { /// Manages user preferences for default providers and models. #[async_trait::async_trait] pub trait AppConfigService: Send + Sync { - /// Gets the user's default provider ID. - async fn get_default_provider(&self) -> anyhow::Result; - - /// Gets the user's default model for a specific provider or the currently - /// active provider. When provider_id is None, uses the currently active - /// provider. + /// Gets the current session configuration (provider and model pair). /// - /// # Errors - /// - Returns `Error::NoDefaultSession` when no provider and model are - /// configured. - async fn get_provider_model( - &self, - provider_id: Option<&forge_domain::ProviderId>, - ) -> anyhow::Result; + /// Returns `None` when no session has been configured yet. + async fn get_session_config(&self) -> Option; /// Gets the commit configuration (provider and model for commit message /// generation). @@ -956,15 +946,8 @@ impl PolicyService for I { #[async_trait::async_trait] impl AppConfigService for I { - async fn get_default_provider(&self) -> anyhow::Result { - self.config_service().get_default_provider().await - } - - async fn get_provider_model( - &self, - provider_id: Option<&forge_domain::ProviderId>, - ) -> anyhow::Result { - self.config_service().get_provider_model(provider_id).await + async fn get_session_config(&self) -> Option { + self.config_service().get_session_config().await } async fn get_commit_config(&self) -> anyhow::Result> { diff --git a/crates/forge_main/src/prompt.rs b/crates/forge_main/src/prompt.rs index b95ce2608e..4001fce80b 100644 --- a/crates/forge_main/src/prompt.rs +++ b/crates/forge_main/src/prompt.rs @@ -249,8 +249,7 @@ mod tests { #[test] fn test_render_prompt_left_with_branch() { - let mut prompt = ForgePrompt::default(); - prompt.git_branch = Some("main".to_string()); + let prompt = ForgePrompt { git_branch: Some("main".to_string()), ..Default::default() }; let actual = prompt.render_prompt_left(); // Agent name is on the right prompt, not the left diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 76f7739213..e4910997d1 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -145,7 +145,7 @@ impl A + Send + Sync> UI async fn get_agent_model(&self, agent_id: Option) -> Option { match agent_id { Some(agent_id) => self.api.get_agent_model(agent_id).await, - None => self.api.get_default_model().await, + None => self.api.get_session_config().await.map(|c| c.model), } } @@ -2711,7 +2711,7 @@ impl A + Send + Sync> UI provider_filter: Option, ) -> Result> { // Check if provider is set otherwise first ask to select a provider - if provider_filter.is_none() && self.api.get_default_provider().await.is_err() { + if provider_filter.is_none() && self.api.get_session_config().await.is_none() { if !self.on_provider_selection().await? { return Ok(None); } @@ -2719,7 +2719,7 @@ impl A + Send + Sync> UI // Provider activation may have already completed model selection. // If it did not, continue below and show the full cross-provider // model list. - if self.api.get_default_model().await.is_some() { + if self.api.get_session_config().await.is_some() { return Ok(None); } } @@ -3424,7 +3424,7 @@ impl A + Send + Sync> UI self.activate_provider(any_provider).await?; // Check if provider was actually saved — if user cancelled model selection // inside activate_provider, nothing was written - Ok(self.api.get_default_provider().await.is_ok()) + Ok(self.api.get_session_config().await.is_some()) } /// Activates a provider by configuring it if needed, setting it as default, @@ -3493,7 +3493,7 @@ impl A + Send + Sync> UI } // Check if the current model is available for the new provider - let current_model = self.api.get_default_model().await; + let current_model = self.api.get_session_config().await.map(|c| c.model); let (needs_model_selection, compatible_model) = match current_model { None => (true, None), Some(current_model) => { @@ -3642,7 +3642,7 @@ impl A + Send + Sync> UI // Validate provider is configured before loading agents // If provider is set in config but not configured (no credentials), prompt user // to login - if self.api.get_default_provider().await.is_err() && !self.on_provider_selection().await? { + if self.api.get_session_config().await.is_none() && !self.on_provider_selection().await? { return Ok(()); } @@ -4274,9 +4274,9 @@ impl A + Send + Sync> UI ConfigGetField::Model => { let model = self .api - .get_default_model() + .get_session_config() .await - .map(|m| m.as_str().to_string()); + .map(|c| c.model.as_str().to_string()); match model { Some(v) => self.writeln(v.to_string())?, None => self.writeln("Model: Not set")?, @@ -4285,10 +4285,9 @@ impl A + Send + Sync> UI ConfigGetField::Provider => { let provider = self .api - .get_default_provider() + .get_session_config() .await - .ok() - .map(|p| p.id.to_string()); + .map(|c| c.provider.to_string()); match provider { Some(v) => self.writeln(v.to_string())?, None => self.writeln("Provider: Not set")?, @@ -4335,13 +4334,16 @@ impl A + Send + Sync> UI .and_then(|str| ConversationId::from_str(str.as_str()).ok()); // Make IO calls in parallel - let (model_id, conversation) = tokio::join!(self.api.get_default_model(), async { - if let Some(cid) = cid { - self.api.conversation(&cid).await.ok().flatten() - } else { - None + let (model_id, conversation) = tokio::join!( + async { self.api.get_session_config().await.map(|c| c.model) }, + async { + if let Some(cid) = cid { + self.api.conversation(&cid).await.ok().flatten() + } else { + None + } } - }); + ); // Calculate total cost including related conversations let cost = if let Some(ref conv) = conversation { diff --git a/crates/forge_services/src/app_config.rs b/crates/forge_services/src/app_config.rs index 45551d63f7..3e279aae9b 100644 --- a/crates/forge_services/src/app_config.rs +++ b/crates/forge_services/src/app_config.rs @@ -23,38 +23,13 @@ impl ForgeAppConfigService { impl + Send + Sync> AppConfigService for ForgeAppConfigService { - async fn get_default_provider(&self) -> anyhow::Result { - let config = self.infra.get_config()?; - config - .session - .as_ref() - .map(|s| ProviderId::from(s.provider_id.clone())) - .ok_or_else(|| forge_domain::Error::NoDefaultSession.into()) - } - - async fn get_provider_model( - &self, - provider_id: Option<&ProviderId>, - ) -> anyhow::Result { - let config = self.infra.get_config()?; - - let session = config - .session - .as_ref() - .ok_or(forge_domain::Error::NoDefaultSession)?; - - // Use the requested provider or the session's active provider - let requested_provider = match provider_id { - Some(id) => id.as_ref(), - None => session.provider_id.as_str(), - }; - - // Return the session's model if the provider matches - if session.provider_id == requested_provider { - Ok(ModelId::new(session.model_id.clone())) - } else { - Err(forge_domain::Error::NoDefaultSession.into()) - } + async fn get_session_config(&self) -> Option { + let config = self.infra.get_config().ok()?; + let session = config.session.as_ref()?; + Some(ModelConfig { + provider: ProviderId::from(session.provider_id.clone()), + model: ModelId::new(session.model_id.clone()), + }) } async fn get_commit_config(&self) -> anyhow::Result> { @@ -325,18 +300,18 @@ mod tests { } #[tokio::test] - async fn test_get_default_provider_when_none_set() -> anyhow::Result<()> { + async fn test_get_session_config_when_none_set() -> anyhow::Result<()> { let fixture = MockInfra::new(); let service = ForgeAppConfigService::new(Arc::new(fixture)); - let result = service.get_default_provider().await; + let result = service.get_session_config().await; - assert!(result.is_err()); + assert!(result.is_none()); Ok(()) } #[tokio::test] - async fn test_get_default_provider_when_set() -> anyhow::Result<()> { + async fn test_get_session_config_when_set() -> anyhow::Result<()> { let fixture = MockInfra::new(); let service = ForgeAppConfigService::new(Arc::new(fixture.clone())); @@ -345,16 +320,18 @@ mod tests { DomainModelConfig::new(ProviderId::ANTHROPIC, ModelId::new("claude-3")), )]) .await?; - let actual = service.get_default_provider().await?; - let expected = ProviderId::ANTHROPIC; + let actual = service.get_session_config().await; + let expected = Some(DomainModelConfig::new( + ProviderId::ANTHROPIC, + ModelId::new("claude-3"), + )); assert_eq!(actual, expected); Ok(()) } #[tokio::test] - async fn test_get_default_provider_when_configured_provider_not_available() -> anyhow::Result<()> - { + async fn test_get_session_config_when_provider_not_available() -> anyhow::Result<()> { let mut fixture = MockInfra::new(); // Remove OpenAI from available providers but keep it in config fixture.providers.retain(|p| p.id != ProviderId::OPENAI); @@ -367,16 +344,22 @@ mod tests { )]) .await?; - // Should return the provider ID even if provider is not available + // Should return the config even if provider is not available // Validation happens when getting the actual provider via ProviderService - let result = service.get_default_provider().await?; - - assert_eq!(result, ProviderId::OPENAI); + let result = service.get_session_config().await; + + assert_eq!( + result, + Some(DomainModelConfig::new( + ProviderId::OPENAI, + ModelId::new("gpt-4") + )) + ); Ok(()) } #[tokio::test] - async fn test_set_default_provider() -> anyhow::Result<()> { + async fn test_set_session_config() -> anyhow::Result<()> { let fixture = MockInfra::new(); let service = ForgeAppConfigService::new(Arc::new(fixture.clone())); @@ -386,60 +369,57 @@ mod tests { )]) .await?; - let actual = service.get_default_provider().await?; - let expected = ProviderId::ANTHROPIC; + let actual = service.get_session_config().await; + let expected = Some(DomainModelConfig::new( + ProviderId::ANTHROPIC, + ModelId::new("claude-3"), + )); assert_eq!(actual, expected); Ok(()) } #[tokio::test] - async fn test_get_default_model_when_none_set() -> anyhow::Result<()> { + async fn test_get_session_config_model_when_none_set() -> anyhow::Result<()> { let fixture = MockInfra::new(); let service = ForgeAppConfigService::new(Arc::new(fixture)); - let result = service.get_provider_model(Some(&ProviderId::OPENAI)).await; + let result = service.get_session_config().await; - assert!(result.is_err()); + assert!(result.is_none()); Ok(()) } #[tokio::test] - async fn test_get_default_model_when_set() -> anyhow::Result<()> { + async fn test_get_session_config_model_when_set() -> anyhow::Result<()> { let fixture = MockInfra::new(); let service = ForgeAppConfigService::new(Arc::new(fixture.clone())); - // Set OpenAI as the default provider first, then set model atomically service .update_config(vec![ConfigOperation::SetSessionConfig( DomainModelConfig::new(ProviderId::OPENAI, ModelId::new("gpt-4")), )]) .await?; - let actual = service - .get_provider_model(Some(&ProviderId::OPENAI)) - .await?; - let expected = "gpt-4".to_string().into(); + let actual = service.get_session_config().await.map(|c| c.model); + let expected = Some(ModelId::new("gpt-4")); assert_eq!(actual, expected); Ok(()) } #[tokio::test] - async fn test_set_default_model() -> anyhow::Result<()> { + async fn test_set_session_config_model() -> anyhow::Result<()> { let fixture = MockInfra::new(); let service = ForgeAppConfigService::new(Arc::new(fixture.clone())); - // Set provider and model atomically service .update_config(vec![ConfigOperation::SetSessionConfig( DomainModelConfig::new(ProviderId::OPENAI, ModelId::from("gpt-4".to_string())), )]) .await?; - let actual = service - .get_provider_model(Some(&ProviderId::OPENAI)) - .await?; - let expected = "gpt-4".to_string().into(); + let actual = service.get_session_config().await.map(|c| c.model); + let expected = Some(ModelId::from("gpt-4".to_string())); assert_eq!(actual, expected); Ok(()) @@ -469,13 +449,13 @@ mod tests { // ForgeConfig only tracks a single active session, so the last // provider/model pair wins - let actual_provider = service.get_default_provider().await?; - let actual_model = service - .get_provider_model(Some(&ProviderId::ANTHROPIC)) - .await?; + let actual = service.get_session_config().await; + let expected = Some(DomainModelConfig::new( + ProviderId::ANTHROPIC, + ModelId::new("claude-3"), + )); - assert_eq!(actual_provider, ProviderId::ANTHROPIC); - assert_eq!(actual_model, ModelId::new("claude-3")); + assert_eq!(actual, expected); Ok(()) } }