Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions codex-rs/core/src/tools/runtimes/shell/zsh_fork_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ mod imp {
use super::*;
use crate::tools::runtimes::shell::unix_escalation;
use crate::unified_exec::SpawnLifecycle;
use codex_shell_escalation::ESCALATE_SOCKET_ENV_VAR;
use codex_shell_escalation::EscalationSession;

#[derive(Debug)]
Expand All @@ -54,6 +55,15 @@ mod imp {
}

impl SpawnLifecycle for ZshForkSpawnLifecycle {
fn inherited_fds(&self) -> Vec<i32> {
self.escalation_session
.env()
.get(ESCALATE_SOCKET_ENV_VAR)
.and_then(|fd| fd.parse().ok())
.into_iter()
.collect()
}

fn after_spawn(&mut self) {
self.escalation_session.close_client_socket();
}
Expand Down
9 changes: 9 additions & 0 deletions codex-rs/core/src/unified_exec/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ use super::UnifiedExecError;
use super::head_tail_buffer::HeadTailBuffer;

pub(crate) trait SpawnLifecycle: std::fmt::Debug + Send + Sync {
/// Returns file descriptors that must stay open across the child `exec()`.
///
/// The returned descriptors must already be valid in the parent process and
/// stay valid until `after_spawn()` runs, which is the first point where
/// the parent may release its copies.
fn inherited_fds(&self) -> Vec<i32> {
Vec::new()
}

fn after_spawn(&mut self) {}
}

Expand Down
7 changes: 5 additions & 2 deletions codex-rs/core/src/unified_exec/process_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,24 +537,27 @@ impl UnifiedExecProcessManager {
.command
.split_first()
.ok_or(UnifiedExecError::MissingCommandLine)?;
let inherited_fds = spawn_lifecycle.inherited_fds();

let spawn_result = if tty {
codex_utils_pty::pty::spawn_process(
codex_utils_pty::pty::spawn_process_with_inherited_fds(
program,
args,
env.cwd.as_path(),
&env.env,
&env.arg0,
codex_utils_pty::TerminalSize::default(),
&inherited_fds,
)
.await
} else {
codex_utils_pty::pipe::spawn_process_no_stdin(
codex_utils_pty::pipe::spawn_process_no_stdin_with_inherited_fds(
program,
args,
env.cwd.as_path(),
&env.env,
&env.arg0,
&inherited_fds,
)
.await
};
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/shell-escalation/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub use unix::ShellCommandExecutor;
#[cfg(unix)]
pub use unix::Stopwatch;
#[cfg(unix)]
pub use unix::escalate_protocol::ESCALATE_SOCKET_ENV_VAR;
#[cfg(unix)]
pub use unix::main_execve_wrapper;
#[cfg(unix)]
pub use unix::run_shell_escalation_execve_wrapper;
45 changes: 39 additions & 6 deletions codex-rs/shell-escalation/src/unix/escalate_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::io;
use std::os::fd::AsFd;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd as _;
use std::os::fd::OwnedFd;

use anyhow::Context as _;
Expand Down Expand Up @@ -28,6 +28,12 @@ fn get_escalate_client() -> anyhow::Result<AsyncDatagramSocket> {
Ok(unsafe { AsyncDatagramSocket::from_raw_fd(client_fd) }?)
}

fn duplicate_fd_for_transfer(fd: impl AsFd, name: &str) -> anyhow::Result<OwnedFd> {
fd.as_fd()
.try_clone_to_owned()
.with_context(|| format!("failed to duplicate {name} for escalation transfer"))
}

pub async fn run_shell_escalation_execve_wrapper(
file: String,
argv: Vec<String>,
Expand Down Expand Up @@ -62,19 +68,26 @@ pub async fn run_shell_escalation_execve_wrapper(
.context("failed to receive EscalateResponse")?;
match message.action {
EscalateAction::Escalate => {
// TODO: maybe we should send ALL open FDs (except the escalate client)?
// Duplicate stdio before transferring ownership to the server. The
// wrapper must keep using its own stdin/stdout/stderr until the
// escalated child takes over.
let destination_fds = [
io::stdin().as_raw_fd(),
io::stdout().as_raw_fd(),
io::stderr().as_raw_fd(),
];
let fds_to_send = [
unsafe { OwnedFd::from_raw_fd(io::stdin().as_raw_fd()) },
unsafe { OwnedFd::from_raw_fd(io::stdout().as_raw_fd()) },
unsafe { OwnedFd::from_raw_fd(io::stderr().as_raw_fd()) },
duplicate_fd_for_transfer(io::stdin(), "stdin")?,
duplicate_fd_for_transfer(io::stdout(), "stdout")?,
duplicate_fd_for_transfer(io::stderr(), "stderr")?,
];

// TODO: also forward signals over the super-exec socket

client
.send_with_fds(
SuperExecMessage {
fds: fds_to_send.iter().map(AsRawFd::as_raw_fd).collect(),
fds: destination_fds.into_iter().collect(),
},
&fds_to_send,
)
Expand Down Expand Up @@ -115,3 +128,23 @@ pub async fn run_shell_escalation_execve_wrapper(
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::os::fd::AsRawFd;
use std::os::unix::net::UnixStream;

#[test]
fn duplicate_fd_for_transfer_does_not_close_original() {
let (left, _right) = UnixStream::pair().expect("socket pair");
let original_fd = left.as_raw_fd();

let duplicate = duplicate_fd_for_transfer(&left, "test fd").expect("duplicate fd");
assert_ne!(duplicate.as_raw_fd(), original_fd);

drop(duplicate);

assert_ne!(unsafe { libc::fcntl(original_fd, libc::F_GETFD) }, -1);
}
}
131 changes: 121 additions & 10 deletions codex-rs/shell-escalation/src/unix/escalate_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,6 @@ async fn handle_escalate_session_with_policy(
));
}

if msg
.fds
.iter()
.any(|src_fd| fds.iter().any(|dst_fd| dst_fd.as_raw_fd() == *src_fd))
{
return Err(anyhow::anyhow!(
"overlapping fds not yet supported in SuperExecMessage"
));
}

let PreparedExec {
command,
cwd,
Expand Down Expand Up @@ -398,6 +388,7 @@ mod tests {
use codex_utils_absolute_path::AbsolutePathBuf;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
use std::io::Write;
use std::os::fd::AsRawFd;
use std::os::fd::FromRawFd;
use std::path::PathBuf;
Expand Down Expand Up @@ -812,6 +803,126 @@ mod tests {
server_task.await?
}

/// Saves a target descriptor, closes it, and restores it when dropped.
///
/// The overlap regression test needs the next received `SCM_RIGHTS` handle
/// to land on a specific descriptor number such as stdin. Temporarily
/// closing the descriptor makes that allocation possible while still
/// letting the test put the process back the way it found it.
struct RestoredFd {
target_fd: i32,
original_fd: std::os::fd::OwnedFd,
}

impl RestoredFd {
/// Duplicates `target_fd`, then closes the original descriptor number.
///
/// The duplicate is kept alive so `Drop` can restore the original
/// process state after the test finishes.
fn close_temporarily(target_fd: i32) -> anyhow::Result<Self> {
let original_fd = unsafe { libc::dup(target_fd) };
if original_fd == -1 {
return Err(std::io::Error::last_os_error().into());
}
if unsafe { libc::close(target_fd) } == -1 {
let err = std::io::Error::last_os_error();
unsafe {
libc::close(original_fd);
}
return Err(err.into());
}
Ok(Self {
target_fd,
original_fd: unsafe { std::os::fd::OwnedFd::from_raw_fd(original_fd) },
})
}
}

/// Restores the original descriptor back onto its original fd number.
///
/// This keeps the overlap test self-contained even though it mutates the
/// current process's stdio table.
impl Drop for RestoredFd {
fn drop(&mut self) {
unsafe {
libc::dup2(self.original_fd.as_raw_fd(), self.target_fd);
}
}
}

#[tokio::test]
async fn handle_escalate_session_accepts_received_fds_that_overlap_destinations()
-> anyhow::Result<()> {
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
let mut pipe_fds = [0; 2];
if unsafe { libc::pipe(pipe_fds.as_mut_ptr()) } == -1 {
return Err(std::io::Error::last_os_error().into());
}
let read_end = unsafe { std::os::fd::OwnedFd::from_raw_fd(pipe_fds[0]) };
let mut write_end = unsafe { std::fs::File::from_raw_fd(pipe_fds[1]) };

// Force the receive-side overlap case for stdin.
//
// SCM_RIGHTS installs received descriptors into the lowest available fd
// numbers in the receiving process. The pipe is opened first so its
// read end does not consume fd 0. After stdin is temporarily closed,
// receiving `read_end` should reuse descriptor 0. The message below
// also asks the server to map that received fd to destination fd 0, so
// the pre-exec dup2 loop exercises the src_fd == dst_fd case.
let stdin_restore = RestoredFd::close_temporarily(libc::STDIN_FILENO)?;
let (server, client) = AsyncSocket::pair()?;
let server_task = tokio::spawn(handle_escalate_session_with_policy(
server,
Arc::new(DeterministicEscalationPolicy {
decision: EscalationDecision::escalate(EscalationExecution::Unsandboxed),
}),
Arc::new(ForwardingShellCommandExecutor),
CancellationToken::new(),
CancellationToken::new(),
));

client
.send(EscalateRequest {
file: PathBuf::from("/bin/sh"),
argv: vec![
"sh".to_string(),
"-c".to_string(),
"IFS= read -r line && [ \"$line\" = overlap-ok ]".to_string(),
],
workdir: AbsolutePathBuf::current_dir()?,
env: HashMap::new(),
})
.await?;

let response = client.receive::<EscalateResponse>().await?;
assert_eq!(
EscalateResponse {
action: EscalateAction::Escalate,
},
response
);

client
.send_with_fds(
SuperExecMessage {
fds: vec![libc::STDIN_FILENO],
},
&[read_end],
)
.await?;
write_end.write_all(b"overlap-ok\n")?;
drop(write_end);

let result = client.receive::<SuperExecResult>().await?;
assert_eq!(
0, result.exit_code,
"expected the escalated child to read the sent stdin payload even when the received fd reuses fd 0"
);
drop(stdin_restore);

server_task.await?
}

#[tokio::test]
async fn handle_escalate_session_passes_permissions_to_executor() -> anyhow::Result<()> {
let _guard = ESCALATE_SERVER_TEST_LOCK.lock().await;
Expand Down
33 changes: 31 additions & 2 deletions codex-rs/utils/pty/src/pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,15 @@ async fn spawn_process_with_stdin_mode(
env: &HashMap<String, String>,
arg0: &Option<String>,
stdin_mode: PipeStdinMode,
inherited_fds: &[i32],
) -> Result<SpawnedProcess> {
if program.is_empty() {
anyhow::bail!("missing program for pipe spawn");
}

#[cfg(not(unix))]
let _ = inherited_fds;

let mut command = Command::new(program);
#[cfg(unix)]
if let Some(arg0) = arg0 {
Expand All @@ -115,11 +119,14 @@ async fn spawn_process_with_stdin_mode(
#[cfg(target_os = "linux")]
let parent_pid = unsafe { libc::getpid() };
#[cfg(unix)]
let inherited_fds = inherited_fds.to_vec();
#[cfg(unix)]
unsafe {
command.pre_exec(move || {
crate::process_group::detach_from_tty()?;
#[cfg(target_os = "linux")]
crate::process_group::set_parent_death_signal(parent_pid)?;
crate::pty::close_inherited_fds_except(&inherited_fds);
Ok(())
});
}
Expand Down Expand Up @@ -250,7 +257,7 @@ pub async fn spawn_process(
env: &HashMap<String, String>,
arg0: &Option<String>,
) -> Result<SpawnedProcess> {
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Piped).await
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Piped, &[]).await
}

/// Spawn a process using regular pipes, but close stdin immediately.
Expand All @@ -261,5 +268,27 @@ pub async fn spawn_process_no_stdin(
env: &HashMap<String, String>,
arg0: &Option<String>,
) -> Result<SpawnedProcess> {
spawn_process_with_stdin_mode(program, args, cwd, env, arg0, PipeStdinMode::Null).await
spawn_process_no_stdin_with_inherited_fds(program, args, cwd, env, arg0, &[]).await
}

/// Spawn a process using regular pipes, close stdin immediately, and preserve
/// selected inherited file descriptors across exec on Unix.
pub async fn spawn_process_no_stdin_with_inherited_fds(
program: &str,
args: &[String],
cwd: &Path,
env: &HashMap<String, String>,
arg0: &Option<String>,
inherited_fds: &[i32],
) -> Result<SpawnedProcess> {
spawn_process_with_stdin_mode(
program,
args,
cwd,
env,
arg0,
PipeStdinMode::Null,
inherited_fds,
)
.await
}
Loading
Loading