Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ pub trait API: Sync + Send {
/// Sets the default provider for all the agents
async fn set_default_provider(&self, provider_id: ProviderId) -> anyhow::Result<()>;

/// Updates the caller's default provider and model together, ensuring all
/// commands resolve a consistent pair without requiring a follow-up model
/// selection call.
async fn set_default_provider_and_model(
&self,
provider_id: ProviderId,
model: ModelId,
) -> anyhow::Result<()>;

/// Retrieves information about the currently authenticated user
async fn user_info(&self) -> anyhow::Result<Option<User>>;

Expand Down
13 changes: 13 additions & 0 deletions crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,19 @@ impl<A: Services, F: CommandInfra + EnvironmentInfra + SkillRepository + GrpcInf
result
}

async fn set_default_provider_and_model(
&self,
provider_id: ProviderId,
model: ModelId,
) -> anyhow::Result<()> {
let result = self
.services
.set_default_provider_and_model(provider_id, model)
.await;
let _ = self.services.reload_agents().await;
result
}

async fn get_commit_config(&self) -> anyhow::Result<Option<CommitConfig>> {
self.services.get_commit_config().await
}
Expand Down
8 changes: 8 additions & 0 deletions crates/forge_app/src/command_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,14 @@ mod tests {
Ok(())
}

async fn set_default_provider_and_model(
&self,
_provider_id: ProviderId,
_model: ModelId,
) -> anyhow::Result<()> {
Ok(())
}

async fn get_commit_config(&self) -> Result<Option<forge_domain::CommitConfig>> {
Ok(None)
}
Expand Down
18 changes: 18 additions & 0 deletions crates/forge_app/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ pub trait AppConfigService: Send + Sync {
/// Returns an error if no default provider is configured.
async fn set_default_model(&self, model: ModelId) -> anyhow::Result<()>;

/// Sets the user's default provider and default model in a single atomic
/// update so the persisted configuration never stores a mismatched pair.
async fn set_default_provider_and_model(
&self,
provider_id: ProviderId,
model: ModelId,
) -> anyhow::Result<()>;

/// Gets the commit configuration (provider and model for commit message
/// generation).
async fn get_commit_config(&self) -> anyhow::Result<Option<forge_domain::CommitConfig>>;
Expand Down Expand Up @@ -971,6 +979,16 @@ impl<I: Services> AppConfigService for I {
self.config_service().get_provider_model(provider_id).await
}

async fn set_default_provider_and_model(
&self,
provider_id: forge_domain::ProviderId,
model: ModelId,
) -> anyhow::Result<()> {
self.config_service()
.set_default_provider_and_model(provider_id, model)
.await
}

async fn set_default_model(&self, model: ModelId) -> anyhow::Result<()> {
self.config_service().set_default_model(model).await
}
Expand Down
4 changes: 0 additions & 4 deletions crates/forge_main/src/built_in_commands.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
"command": "env",
"description": "Display environment information [alias: e]"
},
{
"command": "config-provider",
"description": "Switch the providers [alias: p, provider]"
},
{
"command": "config-model",
"description": "Switch the models [alias: cm]"
Expand Down
132 changes: 67 additions & 65 deletions crates/forge_main/src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use convert_case::{Case, Casing};
use forge_api::{
API, AgentId, AnyProvider, ApiKeyRequest, AuthContextRequest, AuthContextResponse, ChatRequest,
ChatResponse, CodeRequest, Conversation, ConversationId, DeviceCodeRequest, Event,
InterruptionReason, Model, ModelId, Provider, ProviderId, TextMessage, UserPrompt,
InterruptionReason, ModelId, Provider, ProviderId, TextMessage, UserPrompt,
};
use forge_app::utils::{format_display_path, truncate_key};
use forge_app::{CommitResult, ToolResolver};
Expand Down Expand Up @@ -127,14 +127,6 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
self.spinner.ewrite_ln(title)
}

/// Retrieve available models
async fn get_models(&mut self) -> Result<Vec<Model>> {
self.spinner.start(Some("Loading"))?;
let models = self.api.get_models().await?;
self.spinner.stop(None)?;
Ok(models)
}

/// Helper to get provider for an optional agent, defaulting to the current
/// active agent's provider
async fn get_provider(&self, agent_id: Option<AgentId>) -> Result<Provider<Url>> {
Expand Down Expand Up @@ -649,6 +641,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
return Ok(());
}
TopLevelCommand::Commit(commit_group) => {
self.init_state(false).await?;
let preview = commit_group.preview;
let result = self.handle_commit_command(commit_group).await?;
if preview {
Expand Down Expand Up @@ -1899,7 +1892,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
self.on_custom_event(event.into()).await?;
}
SlashCommand::Model => {
self.on_model_selection(None).await?;
self.on_model_selection(None, None).await?;
}
SlashCommand::Provider => {
self.on_provider_selection().await?;
Expand Down Expand Up @@ -2074,15 +2067,11 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
provider_filter: Option<ProviderId>,
) -> Result<Option<ModelId>> {
// Check if provider is set otherwise first ask to select a provider
if self.api.get_default_provider().await.is_err() {
self.on_provider_selection().await?;

// Check if a model was already selected during provider activation
// Return None to signal the model selection is complete and message was already
// printed
if self.api.get_default_model().await.is_some() {
if provider_filter.is_none() && self.api.get_default_provider().await.is_err() {
if !self.on_provider_selection().await? {
return Ok(None);
}
return Ok(None);
}

// Fetch models from ALL configured providers (matches shell plugin's
Expand Down Expand Up @@ -2401,22 +2390,12 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
Ok(())
}

async fn display_credential_success(
&mut self,
provider_id: ProviderId,
) -> anyhow::Result<bool> {
async fn display_credential_success(&mut self, provider_id: ProviderId) -> anyhow::Result<()> {
self.writeln_title(TitleFormat::info(format!(
"{provider_id} configured successfully!"
)))?;

// Prompt user to set as active provider
let should_set_active = ForgeWidget::confirm(format!(
"Would you like to set {provider_id} as the active provider?"
))
.with_default(true)
.prompt()?;

Ok(should_set_active.unwrap_or(false))
Ok(())
}

async fn handle_code_flow(
Expand Down Expand Up @@ -2602,11 +2581,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
}
}

let should_set_active = self.display_credential_success(provider_id.clone()).await?;

if !should_set_active {
return Ok(None);
}
self.display_credential_success(provider_id.clone()).await?;

// Fetch and return the configured provider
let provider = self.api.get_provider(&provider_id).await?;
Expand Down Expand Up @@ -2736,6 +2711,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
async fn on_model_selection(
&mut self,
provider_filter: Option<ProviderId>,
provider_to_activate: Option<ProviderId>,
) -> Result<Option<ModelId>> {
// Select a model
let model_option = self.select_model(provider_filter).await?;
Expand All @@ -2746,8 +2722,14 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
None => return Ok(None),
};

// Update the operating model via API
self.api.set_default_model(model.clone()).await?;
// If we have a provider to activate, write both atomically
if let Some(provider_id) = provider_to_activate {
self.api
.set_default_provider_and_model(provider_id, model.clone())
.await?;
} else {
self.api.set_default_model(model.clone()).await?;
}

// Update the UI state with the new model
self.update_model(Some(model.clone()));
Expand All @@ -2757,15 +2739,18 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
Ok(Some(model))
}

async fn on_provider_selection(&mut self) -> Result<()> {
async fn on_provider_selection(&mut self) -> Result<bool> {
// Select a provider
// If no provider was selected (user canceled), return early
let any_provider = match self.select_provider().await? {
Some(provider) => provider,
None => return Ok(()),
None => return Ok(false),
};

self.activate_provider(any_provider).await
self.activate_provider(any_provider).await?;
// Check if provider was actually saved — if user cancelled model selection
// inside activate_provider, nothing was written
Ok(self.api.get_default_provider().await.is_ok())
}

/// Activates a provider by configuring it if needed, setting it as default,
Expand Down Expand Up @@ -2812,21 +2797,19 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
provider: Provider<Url>,
model: Option<ModelId>,
) -> Result<()> {
// Set the provider via API
self.api.set_default_provider(provider.id.clone()).await?;

self.writeln_title(
TitleFormat::action(format!("{}", provider.id))
.sub_title("is now the default provider"),
)?;

// If a model was pre-selected (e.g. from :model), validate and set it
// directly without prompting
if let Some(model) = model {
let model_id = self
.validate_model(model.as_str(), Some(&provider.id))
.await?;
self.api.set_default_model(model_id.clone()).await?;
self.api
.set_default_provider_and_model(provider.id.clone(), model_id.clone())
.await?;
self.writeln_title(
TitleFormat::action(format!("{}", provider.id))
.sub_title("is now the default provider"),
)?;
self.writeln_title(
TitleFormat::action(model_id.as_str()).sub_title("is now the default model"),
)?;
Expand All @@ -2835,18 +2818,37 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {

// Check if the current model is available for the new provider
let current_model = self.api.get_default_model().await;
if let Some(current_model) = current_model {
let models = self.get_models().await?;
let model_available = models.iter().any(|m| m.id == current_model);
let needs_model_selection = match current_model {
None => true,
Some(current_model) => {
let provider_models = self.api.get_all_provider_models().await?;
let model_available = provider_models
.iter()
.find(|pm| pm.provider_id == provider.id)
.map(|pm| pm.models.iter().any(|m| m.id == current_model))
.unwrap_or(false);
!model_available
}
};

if !model_available {
// Prompt user to select a new model, scoped to the activated provider
self.writeln_title(TitleFormat::info("Please select a new model"))?;
self.on_model_selection(Some(provider.id.clone())).await?;
if needs_model_selection {
self.writeln_title(TitleFormat::info("Please select a new model"))?;
let selected = self
.on_model_selection(Some(provider.id.clone()), Some(provider.id.clone()))
.await?;
if selected.is_none() {
// User cancelled — preserve existing config untouched
return Ok(());
}
} else {
// No model set, select one now scoped to the activated provider
self.on_model_selection(Some(provider.id.clone())).await?;
// Set the provider via API
// Only reaches here if model is confirmed — safe to write provider now
self.api.set_default_provider(provider.id.clone()).await?;

self.writeln_title(
TitleFormat::action(format!("{}", provider.id))
.sub_title("is now the default provider"),
)?;
}

Ok(())
Expand Down Expand Up @@ -2954,17 +2956,17 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
// Ensure we have a model selected before proceeding with initialization
let active_agent = self.api.get_active_agent().await;

let mut operating_model = self.get_agent_model(active_agent.clone()).await;
if operating_model.is_none() {
// Use the model returned from selection instead of re-fetching
operating_model = self.on_model_selection(None).await?;
}

// Validate provider is configured before loading agents
// If provider is set in config but not configured (no credentials), prompt user
// to login
if self.api.get_default_provider().await.is_err() {
self.on_provider_selection().await?;
if self.api.get_default_provider().await.is_err() && !self.on_provider_selection().await? {
return Ok(());
}

let mut operating_model = self.get_agent_model(active_agent.clone()).await;
if operating_model.is_none() {
// Use the model returned from selection instead of re-fetching
operating_model = self.on_model_selection(None, None).await?;
}

if first {
Expand Down
9 changes: 9 additions & 0 deletions crates/forge_services/src/app_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ impl<F: ProviderRepository + EnvironmentInfra + Send + Sync> AppConfigService
.await
}

async fn set_default_provider_and_model(
&self,
provider_id: ProviderId,
model: ModelId,
) -> anyhow::Result<()> {
self.update(ConfigOperation::SetModel(provider_id, model))
.await
}

async fn get_commit_config(&self) -> anyhow::Result<Option<forge_domain::CommitConfig>> {
let config = self.infra.get_config();
Ok(config.commit.map(|mc| CommitConfig {
Expand Down
21 changes: 0 additions & 21 deletions shell-plugin/lib/actions/config.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,6 @@ function _forge_action_agent() {
fi
}

# Action handler: Select provider
function _forge_action_provider() {
local input_text="$1"
echo
local selected
# Only show LLM providers (exclude context_engine and other non-LLM types)
# Pass input_text as query parameter for fuzzy search
selected=$(_forge_select_provider "" "" "llm" "$input_text")

if [[ -n "$selected" ]]; then
# Extract the second field (provider ID) from the selected line
# Format: "DisplayName provider_id host type status"
local provider_id=$(echo "$selected" | awk '{print $2}')
# Use _forge_exec_interactive because config-set may trigger
# interactive authentication prompts (rustyline) when the provider
# is not yet configured. Without /dev/tty redirection, ZLE's pipes
# cause rustyline to see EOF and fail with "API key input cancelled".
_forge_exec_interactive config set provider "$provider_id"
fi
}

# Helper: Open an fzf model picker and print the raw selected line.
#
# Model list columns (from `forge list models --porcelain`):
Expand Down
3 changes: 0 additions & 3 deletions shell-plugin/lib/dispatcher.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ function forge-accept-line() {
conversation|c)
_forge_action_conversation "$input_text"
;;
config-provider|provider|p)
_forge_action_provider "$input_text"
;;
config-model|cm)
_forge_action_model "$input_text"
;;
Expand Down
Loading