Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,15 @@ pub trait API: Sync + Send {
/// Retrieves the provider configuration for the specified agent
async fn get_agent_provider(&self, agent_id: AgentId) -> anyhow::Result<Provider<Url>>;

/// Retrieves the provider configuration for the default agent
/// Gets the current session configuration (provider and model pair).
///
/// Returns `None` when no session has been configured yet, allowing callers
/// to distinguish between "not configured" and an actual error.
async fn get_session_config(&self) -> Option<forge_domain::ModelConfig>;

/// Retrieves the provider configuration for the default agent.
///
/// Delegates to [`Self::get_session_config`] and resolves the provider.
async fn get_default_provider(&self) -> anyhow::Result<Provider<Url>>;

/// Applies one or more configuration mutations atomically.
Expand All @@ -153,9 +161,6 @@ pub trait API: Sync + Send {
/// Gets the model for the specified agent
async fn get_agent_model(&self, agent_id: AgentId) -> Option<ModelId>;

/// Gets the default model
async fn get_default_model(&self) -> Option<ModelId>;

/// Gets the commit configuration (provider and model for commit message
/// generation).
async fn get_commit_config(&self) -> anyhow::Result<Option<forge_domain::ModelConfig>>;
Expand Down
16 changes: 10 additions & 6 deletions crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,6 @@ impl<
agent_provider_resolver.get_model(Some(agent_id)).await.ok()
}

async fn get_default_model(&self) -> Option<ModelId> {
self.services.get_provider_model(None).await.ok()
}

async fn reload_mcp(&self) -> Result<()> {
self.services.mcp_service().reload_mcp().await
}
Expand Down Expand Up @@ -402,9 +398,17 @@ impl<
app.execute(data_parameters).await
}

async fn get_session_config(&self) -> Option<forge_domain::ModelConfig> {
self.services.get_session_config().await
}

async fn get_default_provider(&self) -> Result<Provider<Url>> {
let provider_id = self.services.get_default_provider().await?;
self.services.get_provider(provider_id).await
let model_config = self
.services
.get_session_config()
.await
.ok_or_else(|| forge_domain::Error::NoDefaultSession)?;
self.services.get_provider(model_config.provider).await
}

async fn mcp_auth(&self, server_url: &str) -> Result<()> {
Expand Down
5 changes: 4 additions & 1 deletion crates/forge_app/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ impl<T: Services + EnvironmentInfra<Config = forge_config::ForgeConfig>> AgentSe
let provider_id = if let Some(provider_id) = provider_id {
provider_id
} else {
self.get_default_provider().await?
self.get_session_config()
.await
.map(|c| c.provider)
.ok_or_else(|| forge_domain::Error::NoDefaultSession)?
};
let provider = self.get_provider(provider_id).await?;

Expand Down
26 changes: 20 additions & 6 deletions crates/forge_app/src/agent_provider_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,18 @@ where
} else {
// TODO: Needs review, should we throw an err here?
// we can throw crate::Error::AgentNotFound
self.0.get_default_provider().await?
self.0
.get_session_config()
.await
.map(|c| c.provider)
.ok_or_else(|| forge_domain::Error::NoDefaultSession)?
}
} else {
self.0.get_default_provider().await?
self.0
.get_session_config()
.await
.map(|c| c.provider)
.ok_or_else(|| forge_domain::Error::NoDefaultSession)?
};

let provider = self.0.get_provider(provider_id).await?;
Expand All @@ -52,12 +60,18 @@ where
} else {
// TODO: Needs review, should we throw an err here?
// we can throw crate::Error::AgentNotFound
let provider_id = self.get_provider(Some(agent_id)).await?.id;
Ok(self.0.get_provider_model(Some(&provider_id)).await?)
self.0
.get_session_config()
.await
.map(|c| c.model)
.ok_or_else(|| forge_domain::Error::NoDefaultSession.into())
}
} else {
let provider_id = self.get_provider(None).await?.id;
Ok(self.0.get_provider_model(Some(&provider_id)).await?)
self.0
.get_session_config()
.await
.map(|c| c.model)
.ok_or_else(|| forge_domain::Error::NoDefaultSession.into())
}
}
}
25 changes: 12 additions & 13 deletions crates/forge_app/src/command_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,13 @@ where
(provider, config.model)
}
None => {
let provider_id = self.services.get_default_provider().await?;
let provider = self.services.get_provider(provider_id).await?;
let model = self.services.get_provider_model(Some(&provider.id)).await?;
(provider, model)
let model_config = self
.services
.get_session_config()
.await
.ok_or_else(|| forge_domain::Error::NoDefaultSession)?;
let provider = self.services.get_provider(model_config.provider).await?;
(provider, model_config.model)
}
};

Expand Down Expand Up @@ -296,15 +299,11 @@ mod tests {

#[async_trait::async_trait]
impl AppConfigService for MockServices {
async fn get_default_provider(&self) -> Result<ProviderId> {
Ok(ProviderId::OPENAI)
}

async fn get_provider_model(
&self,
_provider_id: Option<&ProviderId>,
) -> anyhow::Result<ModelId> {
Ok(ModelId::new("test-model"))
async fn get_session_config(&self) -> Option<forge_domain::ModelConfig> {
Some(forge_domain::ModelConfig::new(
ProviderId::OPENAI,
ModelId::new("test-model"),
))
}

async fn get_commit_config(&self) -> Result<Option<forge_domain::ModelConfig>> {
Expand Down
10 changes: 7 additions & 3 deletions crates/forge_app/src/data_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,13 @@ impl<A: Services> DataGenerationApp<A> {
concurrency
);

let provider_id = self.services.get_default_provider().await?;
let provider = self.services.get_provider(provider_id).await?;
let model_id = self.services.get_provider_model(Some(&provider.id)).await?;
let model_config = self
.services
.get_session_config()
.await
.ok_or_else(|| forge_domain::Error::NoDefaultSession)?;
let provider = self.services.get_provider(model_config.provider).await?;
let model_id = model_config.model;
debug!("Using provider: {}, model: {}", provider.id, model_id);
let schema: Schema =
serde_json::from_str(&schema).with_context(|| "Could not parse the JSON schema")?;
Expand Down
27 changes: 5 additions & 22 deletions crates/forge_app/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,20 +181,10 @@ pub trait ProviderService: Send + Sync {
/// Manages user preferences for default providers and models.
#[async_trait::async_trait]
pub trait AppConfigService: Send + Sync {
/// Gets the user's default provider ID.
async fn get_default_provider(&self) -> anyhow::Result<ProviderId>;

/// Gets the user's default model for a specific provider or the currently
/// active provider. When provider_id is None, uses the currently active
/// provider.
/// Gets the current session configuration (provider and model pair).
///
/// # Errors
/// - Returns `Error::NoDefaultSession` when no provider and model are
/// configured.
async fn get_provider_model(
&self,
provider_id: Option<&forge_domain::ProviderId>,
) -> anyhow::Result<ModelId>;
/// Returns `None` when no session has been configured yet.
async fn get_session_config(&self) -> Option<forge_domain::ModelConfig>;

/// Gets the commit configuration (provider and model for commit message
/// generation).
Expand Down Expand Up @@ -956,15 +946,8 @@ impl<I: Services> PolicyService for I {

#[async_trait::async_trait]
impl<I: Services> AppConfigService for I {
async fn get_default_provider(&self) -> anyhow::Result<ProviderId> {
self.config_service().get_default_provider().await
}

async fn get_provider_model(
&self,
provider_id: Option<&forge_domain::ProviderId>,
) -> anyhow::Result<ModelId> {
self.config_service().get_provider_model(provider_id).await
async fn get_session_config(&self) -> Option<forge_domain::ModelConfig> {
self.config_service().get_session_config().await
}

async fn get_commit_config(&self) -> anyhow::Result<Option<forge_domain::ModelConfig>> {
Expand Down
3 changes: 1 addition & 2 deletions crates/forge_main/src/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ mod tests {

#[test]
fn test_render_prompt_left_with_branch() {
let mut prompt = ForgePrompt::default();
prompt.git_branch = Some("main".to_string());
let prompt = ForgePrompt { git_branch: Some("main".to_string()), ..Default::default() };
let actual = prompt.render_prompt_left();

// Agent name is on the right prompt, not the left
Expand Down
36 changes: 19 additions & 17 deletions crates/forge_main/src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
async fn get_agent_model(&self, agent_id: Option<AgentId>) -> Option<ModelId> {
match agent_id {
Some(agent_id) => self.api.get_agent_model(agent_id).await,
None => self.api.get_default_model().await,
None => self.api.get_session_config().await.map(|c| c.model),
}
}

Expand Down Expand Up @@ -2711,15 +2711,15 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
provider_filter: Option<ProviderId>,
) -> Result<Option<(ModelId, ProviderId)>> {
// Check if provider is set otherwise first ask to select a provider
if provider_filter.is_none() && self.api.get_default_provider().await.is_err() {
if provider_filter.is_none() && self.api.get_session_config().await.is_none() {
if !self.on_provider_selection().await? {
return Ok(None);
}

// Provider activation may have already completed model selection.
// If it did not, continue below and show the full cross-provider
// model list.
if self.api.get_default_model().await.is_some() {
if self.api.get_session_config().await.is_some() {
return Ok(None);
}
}
Expand Down Expand Up @@ -3424,7 +3424,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
self.activate_provider(any_provider).await?;
// Check if provider was actually saved — if user cancelled model selection
// inside activate_provider, nothing was written
Ok(self.api.get_default_provider().await.is_ok())
Ok(self.api.get_session_config().await.is_some())
}

/// Activates a provider by configuring it if needed, setting it as default,
Expand Down Expand Up @@ -3493,7 +3493,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
}

// Check if the current model is available for the new provider
let current_model = self.api.get_default_model().await;
let current_model = self.api.get_session_config().await.map(|c| c.model);
let (needs_model_selection, compatible_model) = match current_model {
None => (true, None),
Some(current_model) => {
Expand Down Expand Up @@ -3642,7 +3642,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
// Validate provider is configured before loading agents
// If provider is set in config but not configured (no credentials), prompt user
// to login
if self.api.get_default_provider().await.is_err() && !self.on_provider_selection().await? {
if self.api.get_session_config().await.is_none() && !self.on_provider_selection().await? {
return Ok(());
}

Expand Down Expand Up @@ -4274,9 +4274,9 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
ConfigGetField::Model => {
let model = self
.api
.get_default_model()
.get_session_config()
.await
.map(|m| m.as_str().to_string());
.map(|c| c.model.as_str().to_string());
match model {
Some(v) => self.writeln(v.to_string())?,
None => self.writeln("Model: Not set")?,
Expand All @@ -4285,10 +4285,9 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
ConfigGetField::Provider => {
let provider = self
.api
.get_default_provider()
.get_session_config()
.await
.ok()
.map(|p| p.id.to_string());
.map(|c| c.provider.to_string());
match provider {
Some(v) => self.writeln(v.to_string())?,
None => self.writeln("Provider: Not set")?,
Expand Down Expand Up @@ -4335,13 +4334,16 @@ impl<A: API + ConsoleWriter + 'static, F: Fn(ForgeConfig) -> A + Send + Sync> UI
.and_then(|str| ConversationId::from_str(str.as_str()).ok());

// Make IO calls in parallel
let (model_id, conversation) = tokio::join!(self.api.get_default_model(), async {
if let Some(cid) = cid {
self.api.conversation(&cid).await.ok().flatten()
} else {
None
let (model_id, conversation) = tokio::join!(
async { self.api.get_session_config().await.map(|c| c.model) },
async {
if let Some(cid) = cid {
self.api.conversation(&cid).await.ok().flatten()
} else {
None
}
}
});
);

// Calculate total cost including related conversations
let cost = if let Some(ref conv) = conversation {
Expand Down
Loading
Loading