Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ pub(crate) struct CodexSpawnArgs {
pub(crate) persist_extended_history: bool,
pub(crate) metrics_service_name: Option<String>,
pub(crate) inherited_shell_snapshot: Option<Arc<ShellSnapshot>>,
pub(crate) user_shell_override: Option<shell::Shell>,
pub(crate) parent_trace: Option<W3cTraceContext>,
}

Expand Down Expand Up @@ -421,6 +422,7 @@ impl Codex {
persist_extended_history,
metrics_service_name,
inherited_shell_snapshot,
user_shell_override,
parent_trace: _,
} = args;
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
Expand Down Expand Up @@ -575,6 +577,7 @@ impl Codex {
dynamic_tools,
persist_extended_history,
inherited_shell_snapshot,
user_shell_override,
};

// Generate a unique ID for the lifetime of this Codex session.
Expand Down Expand Up @@ -1037,6 +1040,7 @@ pub(crate) struct SessionConfiguration {
dynamic_tools: Vec<DynamicToolSpec>,
persist_extended_history: bool,
inherited_shell_snapshot: Option<Arc<ShellSnapshot>>,
user_shell_override: Option<shell::Shell>,
}

impl SessionConfiguration {
Expand Down Expand Up @@ -1617,7 +1621,11 @@ impl Session {
);

let use_zsh_fork_shell = config.features.enabled(Feature::ShellZshFork);
let mut default_shell = if use_zsh_fork_shell {
let mut default_shell = if let Some(user_shell_override) =
session_configuration.user_shell_override.clone()
{
user_shell_override
} else if use_zsh_fork_shell {
let zsh_path = config.zsh_path.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"zsh fork feature enabled, but `zsh_path` is not configured; set `zsh_path` in config.toml"
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/codex_delegate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub(crate) async fn run_codex_thread_interactive(
persist_extended_history: false,
metrics_service_name: None,
inherited_shell_snapshot: None,
user_shell_override: None,
parent_trace: None,
})
.await?;
Expand Down
6 changes: 6 additions & 0 deletions codex-rs/core/src/codex_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,7 @@ async fn set_rate_limits_retains_previous_credits() {
dynamic_tools: Vec::new(),
persist_extended_history: false,
inherited_shell_snapshot: None,
user_shell_override: None,
};

let mut state = SessionState::new(session_configuration);
Expand Down Expand Up @@ -1760,6 +1761,7 @@ async fn set_rate_limits_updates_plan_type_when_present() {
dynamic_tools: Vec::new(),
persist_extended_history: false,
inherited_shell_snapshot: None,
user_shell_override: None,
};

let mut state = SessionState::new(session_configuration);
Expand Down Expand Up @@ -2115,6 +2117,7 @@ pub(crate) async fn make_session_configuration_for_tests() -> SessionConfigurati
dynamic_tools: Vec::new(),
persist_extended_history: false,
inherited_shell_snapshot: None,
user_shell_override: None,
}
}

Expand Down Expand Up @@ -2345,6 +2348,7 @@ async fn session_new_fails_when_zsh_fork_enabled_without_zsh_path() {
dynamic_tools: Vec::new(),
persist_extended_history: false,
inherited_shell_snapshot: None,
user_shell_override: None,
};

let (tx_event, _rx_event) = async_channel::unbounded();
Expand Down Expand Up @@ -2439,6 +2443,7 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
dynamic_tools: Vec::new(),
persist_extended_history: false,
inherited_shell_snapshot: None,
user_shell_override: None,
};
let per_turn_config = Session::build_per_turn_config(&session_configuration);
let model_info = ModelsManager::construct_model_info_offline_for_tests(
Expand Down Expand Up @@ -3230,6 +3235,7 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
dynamic_tools,
persist_extended_history: false,
inherited_shell_snapshot: None,
user_shell_override: None,
};
let per_turn_config = Session::build_per_turn_config(&session_configuration);
let model_info = ModelsManager::construct_model_info_offline_for_tests(
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/codex_tests_guardian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ async fn guardian_subagent_does_not_inherit_parent_exec_policy_rules() {
persist_extended_history: false,
metrics_service_name: None,
inherited_shell_snapshot: None,
user_shell_override: None,
parent_trace: None,
})
.await
Expand Down
27 changes: 27 additions & 0 deletions codex-rs/core/src/test_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,33 @@ pub fn thread_manager_with_models_provider_and_home(
ThreadManager::with_models_provider_and_home_for_tests(auth, provider, codex_home)
}

pub async fn start_thread_with_user_shell_override(
thread_manager: &ThreadManager,
config: Config,
user_shell_override: crate::shell::Shell,
) -> crate::error::Result<crate::NewThread> {
thread_manager
.start_thread_with_user_shell_override_for_tests(config, user_shell_override)
.await
}

pub async fn resume_thread_from_rollout_with_user_shell_override(
thread_manager: &ThreadManager,
config: Config,
rollout_path: PathBuf,
auth_manager: Arc<AuthManager>,
user_shell_override: crate::shell::Shell,
) -> crate::error::Result<crate::NewThread> {
thread_manager
.resume_thread_from_rollout_with_user_shell_override_for_tests(
config,
rollout_path,
auth_manager,
user_shell_override,
)
.await
}

pub fn models_manager_with_provider(
codex_home: PathBuf,
auth_manager: Arc<AuthManager>,
Expand Down
51 changes: 51 additions & 0 deletions codex-rs/core/src/thread_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ impl ThreadManager {
persist_extended_history,
metrics_service_name,
parent_trace,
/*user_shell_override*/ None,
))
.await
}
Expand Down Expand Up @@ -420,6 +421,48 @@ impl ThreadManager {
persist_extended_history,
/*metrics_service_name*/ None,
parent_trace,
/*user_shell_override*/ None,
))
.await
}

pub(crate) async fn start_thread_with_user_shell_override_for_tests(
&self,
config: Config,
user_shell_override: crate::shell::Shell,
) -> CodexResult<NewThread> {
Box::pin(self.state.spawn_thread(
config,
InitialHistory::New,
Arc::clone(&self.state.auth_manager),
self.agent_control(),
Vec::new(),
/*persist_extended_history*/ false,
/*metrics_service_name*/ None,
/*parent_trace*/ None,
/*user_shell_override*/ Some(user_shell_override),
))
.await
}

pub(crate) async fn resume_thread_from_rollout_with_user_shell_override_for_tests(
&self,
config: Config,
rollout_path: PathBuf,
auth_manager: Arc<AuthManager>,
user_shell_override: crate::shell::Shell,
) -> CodexResult<NewThread> {
let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
Box::pin(self.state.spawn_thread(
config,
initial_history,
auth_manager,
self.agent_control(),
Vec::new(),
/*persist_extended_history*/ false,
/*metrics_service_name*/ None,
/*parent_trace*/ None,
/*user_shell_override*/ Some(user_shell_override),
))
.await
}
Expand Down Expand Up @@ -505,6 +548,7 @@ impl ThreadManager {
persist_extended_history,
/*metrics_service_name*/ None,
parent_trace,
/*user_shell_override*/ None,
))
.await
}
Expand Down Expand Up @@ -590,6 +634,7 @@ impl ThreadManagerState {
metrics_service_name,
inherited_shell_snapshot,
/*parent_trace*/ None,
/*user_shell_override*/ None,
))
.await
}
Expand All @@ -614,6 +659,7 @@ impl ThreadManagerState {
/*metrics_service_name*/ None,
inherited_shell_snapshot,
/*parent_trace*/ None,
/*user_shell_override*/ None,
))
.await
}
Expand All @@ -638,6 +684,7 @@ impl ThreadManagerState {
/*metrics_service_name*/ None,
inherited_shell_snapshot,
/*parent_trace*/ None,
/*user_shell_override*/ None,
))
.await
}
Expand All @@ -654,6 +701,7 @@ impl ThreadManagerState {
persist_extended_history: bool,
metrics_service_name: Option<String>,
parent_trace: Option<W3cTraceContext>,
user_shell_override: Option<crate::shell::Shell>,
) -> CodexResult<NewThread> {
Box::pin(self.spawn_thread_with_source(
config,
Expand All @@ -666,6 +714,7 @@ impl ThreadManagerState {
metrics_service_name,
/*inherited_shell_snapshot*/ None,
parent_trace,
user_shell_override,
))
.await
}
Expand All @@ -683,6 +732,7 @@ impl ThreadManagerState {
metrics_service_name: Option<String>,
inherited_shell_snapshot: Option<Arc<ShellSnapshot>>,
parent_trace: Option<W3cTraceContext>,
user_shell_override: Option<crate::shell::Shell>,
) -> CodexResult<NewThread> {
let watch_registration = self
.file_watcher
Expand All @@ -704,6 +754,7 @@ impl ThreadManagerState {
persist_extended_history,
metrics_service_name,
inherited_shell_snapshot,
user_shell_override,
parent_trace,
})
.await?;
Expand Down
47 changes: 44 additions & 3 deletions codex-rs/core/tests/common/test_codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use codex_core::built_in_model_providers;
use codex_core::config::Config;
use codex_core::features::Feature;
use codex_core::models_manager::collaboration_mode_presets::CollaborationModesConfig;
use codex_core::shell::Shell;
use codex_core::shell::get_shell_by_model_provided_path;
use codex_protocol::config_types::ServiceTier;
use codex_protocol::openai_models::ModelsResponse;
use codex_protocol::protocol::AskForApproval;
Expand Down Expand Up @@ -64,6 +66,7 @@ pub struct TestCodexBuilder {
auth: CodexAuth,
pre_build_hooks: Vec<Box<PreBuildHook>>,
home: Option<Arc<TempDir>>,
user_shell_override: Option<Shell>,
}

impl TestCodexBuilder {
Expand Down Expand Up @@ -100,6 +103,19 @@ impl TestCodexBuilder {
self
}

pub fn with_user_shell(mut self, user_shell: Shell) -> Self {
self.user_shell_override = Some(user_shell);
self
}

pub fn with_windows_cmd_shell(self) -> Self {
if cfg!(windows) {
self.with_user_shell(get_shell_by_model_provided_path(&PathBuf::from("cmd.exe")))
} else {
self
}
}

pub async fn build(&mut self, server: &wiremock::MockServer) -> anyhow::Result<TestCodex> {
let home = match self.home.clone() {
Some(home) => home,
Expand Down Expand Up @@ -199,9 +215,23 @@ impl TestCodexBuilder {
)
};
let thread_manager = Arc::new(thread_manager);
let user_shell_override = self.user_shell_override.clone();

let new_conversation = match resume_from {
Some(path) => {
let new_conversation = match (resume_from, user_shell_override) {
(Some(path), Some(user_shell_override)) => {
let auth_manager = codex_core::test_support::auth_manager_from_auth(auth);
Box::pin(
codex_core::test_support::resume_thread_from_rollout_with_user_shell_override(
thread_manager.as_ref(),
config.clone(),
path,
auth_manager,
user_shell_override,
),
)
.await?
}
(Some(path), None) => {
let auth_manager = codex_core::test_support::auth_manager_from_auth(auth);
Box::pin(thread_manager.resume_thread_from_rollout(
config.clone(),
Expand All @@ -211,7 +241,17 @@ impl TestCodexBuilder {
))
.await?
}
None => Box::pin(thread_manager.start_thread(config.clone())).await?,
(None, Some(user_shell_override)) => {
Box::pin(
codex_core::test_support::start_thread_with_user_shell_override(
thread_manager.as_ref(),
config.clone(),
user_shell_override,
),
)
.await?
}
(None, None) => Box::pin(thread_manager.start_thread(config.clone())).await?,
};

Ok(TestCodex {
Expand Down Expand Up @@ -562,6 +602,7 @@ pub fn test_codex() -> TestCodexBuilder {
auth: CodexAuth::from_api_key("dummy"),
pre_build_hooks: vec![],
home: None,
user_shell_override: None,
}
}

Expand Down
4 changes: 2 additions & 2 deletions codex-rs/core/tests/suite/agent_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn websocket_test_codex_shell_chain() -> Result<()> {
]])
.await;

let mut builder = test_codex();
let mut builder = test_codex().with_windows_cmd_shell();

let test = builder.build_with_websocket_server(&server).await?;
test.submit_turn_with_policy(
Expand Down Expand Up @@ -183,7 +183,7 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
]])
.await;

let mut builder = test_codex().with_config(|config| {
let mut builder = test_codex().with_windows_cmd_shell().with_config(|config| {
config
.features
.enable(Feature::ResponsesWebsocketsV2)
Expand Down
Loading
Loading