From 34889d6021e790de6ab4b2b2c562027628553138 Mon Sep 17 00:00:00 2001 From: Rasmus Rygaard Date: Thu, 16 Apr 2026 16:16:12 -0700 Subject: [PATCH 1/6] Add protos for session config --- codex-rs/Cargo.lock | 5 + codex-rs/config/Cargo.toml | 10 + codex-rs/config/examples/generate-proto.rs | 19 + codex-rs/config/scripts/generate-proto.sh | 38 ++ codex-rs/config/src/lib.rs | 1 + codex-rs/config/src/thread_config.rs | 4 + .../proto/codex.thread_config.v1.proto | 67 +++ .../proto/codex.thread_config.v1.rs | 397 +++++++++++++++ codex-rs/config/src/thread_config/remote.rs | 481 ++++++++++++++++++ 9 files changed, 1022 insertions(+) create mode 100644 codex-rs/config/examples/generate-proto.rs create mode 100755 codex-rs/config/scripts/generate-proto.sh create mode 100644 codex-rs/config/src/thread_config/proto/codex.thread_config.v1.proto create mode 100644 codex-rs/config/src/thread_config/proto/codex.thread_config.v1.rs create mode 100644 codex-rs/config/src/thread_config/remote.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 6bc53a49fbc8..a5281527d761 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2293,6 +2293,7 @@ dependencies = [ "libc", "multimap", "pretty_assertions", + "prost 0.14.3", "schemars 0.8.22", "serde", "serde_json", @@ -2301,8 +2302,12 @@ dependencies = [ "tempfile", "thiserror 2.0.18", "tokio", + "tokio-stream", "toml 0.9.11+spec-1.1.0", "toml_edit 0.24.0+spec-1.1.0", + "tonic", + "tonic-prost", + "tonic-prost-build", "tracing", "wildmatch", "winapi-util", diff --git a/codex-rs/config/Cargo.toml b/codex-rs/config/Cargo.toml index a7a50b07205d..9df08b115de0 100644 --- a/codex-rs/config/Cargo.toml +++ b/codex-rs/config/Cargo.toml @@ -4,6 +4,10 @@ version.workspace = true edition.workspace = true license.workspace = true +[[example]] +name = "generate-proto" +path = "examples/generate-proto.rs" + [lints] workspace = true @@ -21,6 +25,7 @@ codex-utils-path = { workspace = true } futures = { workspace = true, features = ["alloc", "std"] } gethostname = { workspace = true } multimap = { workspace = true } +prost = "0.14.3" schemars = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } @@ -30,6 +35,8 @@ thiserror = { workspace = true } tokio = { workspace = true, features = ["fs"] } toml = { workspace = true } toml_edit = { workspace = true } +tonic = { workspace = true } +tonic-prost = { workspace = true } tracing = { workspace = true } wildmatch = { workspace = true } @@ -44,3 +51,6 @@ winapi-util = { workspace = true } pretty_assertions = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = ["full"] } +tokio-stream = { workspace = true, features = ["net"] } +tonic = { workspace = true, features = ["router", "transport"] } +tonic-prost-build = { version = "=0.14.3", default-features = false, features = ["transport"] } diff --git a/codex-rs/config/examples/generate-proto.rs b/codex-rs/config/examples/generate-proto.rs new file mode 100644 index 000000000000..03f0f796da44 --- /dev/null +++ b/codex-rs/config/examples/generate-proto.rs @@ -0,0 +1,19 @@ +use std::path::PathBuf; + +fn main() -> Result<(), Box> { + let Some(proto_dir_arg) = std::env::args().nth(1) else { + eprintln!("Usage: generate-proto "); + std::process::exit(1); + }; + + let proto_dir = PathBuf::from(proto_dir_arg); + let proto_file = proto_dir.join("codex.thread_config.v1.proto"); + + tonic_prost_build::configure() + .build_client(true) + .build_server(true) + .out_dir(&proto_dir) + .compile_protos(&[proto_file], &[proto_dir])?; + + Ok(()) +} diff --git a/codex-rs/config/scripts/generate-proto.sh b/codex-rs/config/scripts/generate-proto.sh new file mode 100755 index 000000000000..86af22b8957c --- /dev/null +++ b/codex-rs/config/scripts/generate-proto.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +set -euo pipefail + +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +repo_root="$(cd "$script_dir/../../.." && pwd)" +proto_dir="$repo_root/codex-rs/config/src/thread_config/proto" +generated="$proto_dir/codex.thread_config.v1.rs" +tmpdir="$(mktemp -d)" + +cleanup() { + rm -rf "$tmpdir" +} +trap cleanup EXIT + +( + cd "$repo_root/codex-rs" + CARGO_TARGET_DIR="$tmpdir/target" cargo run \ + -p codex-config \ + --example generate-proto \ + -- "$proto_dir" +) + +if ! sed -n '2p' "$generated" | grep -q 'clippy::trivially_copy_pass_by_ref'; then + { + sed -n '1p' "$generated" + printf '#![allow(clippy::trivially_copy_pass_by_ref)]\n' + sed '1d' "$generated" + } > "$tmpdir/generated.rs" + mv "$tmpdir/generated.rs" "$generated" +fi + +rustfmt --edition 2024 "$generated" + +awk ' + NR == 3 && previous ~ /clippy::trivially_copy_pass_by_ref/ && $0 != "" { print "" } + { print; previous = $0 } +' "$generated" > "$tmpdir/formatted.rs" +mv "$tmpdir/formatted.rs" "$generated" diff --git a/codex-rs/config/src/lib.rs b/codex-rs/config/src/lib.rs index e70550f00960..a8f09b08f635 100644 --- a/codex-rs/config/src/lib.rs +++ b/codex-rs/config/src/lib.rs @@ -98,6 +98,7 @@ pub use state::ConfigLayerStack; pub use state::ConfigLayerStackOrdering; pub use state::LoaderOverrides; pub use thread_config::NoopThreadConfigLoader; +pub use thread_config::RemoteThreadConfigLoader; pub use thread_config::SessionThreadConfig; pub use thread_config::StaticThreadConfigLoader; pub use thread_config::ThreadConfigContext; diff --git a/codex-rs/config/src/thread_config.rs b/codex-rs/config/src/thread_config.rs index e1e598708ef6..1b3ea8fe8712 100644 --- a/codex-rs/config/src/thread_config.rs +++ b/codex-rs/config/src/thread_config.rs @@ -10,6 +10,10 @@ use toml::Value as TomlValue; use crate::ConfigLayerEntry; +mod remote; + +pub use remote::RemoteThreadConfigLoader; + /// Context available to implementations when loading thread-scoped config. #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct ThreadConfigContext { diff --git a/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.proto b/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.proto new file mode 100644 index 000000000000..acc8d46e1890 --- /dev/null +++ b/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.proto @@ -0,0 +1,67 @@ +syntax = "proto3"; + +package codex.thread_config.v1; + +service ThreadConfigLoader { + rpc Load(LoadThreadConfigRequest) returns (LoadThreadConfigResponse); +} + +message LoadThreadConfigRequest { + optional string thread_id = 1; + optional string cwd = 2; +} + +message LoadThreadConfigResponse { + repeated ThreadConfigSource sources = 1; +} + +message ThreadConfigSource { + oneof source { + SessionThreadConfig session = 1; + UserThreadConfig user = 2; + } +} + +message SessionThreadConfig { + optional string model_provider = 1; + repeated ModelProvider model_providers = 2; + map features = 3; +} + +message UserThreadConfig {} + +message ModelProvider { + string id = 1; + string name = 2; + optional string base_url = 3; + optional string env_key = 4; + optional string env_key_instructions = 5; + optional string experimental_bearer_token = 6; + optional ModelProviderAuthInfo auth = 7; + WireApi wire_api = 8; + optional StringMap query_params = 9; + optional StringMap http_headers = 10; + optional StringMap env_http_headers = 11; + optional uint64 request_max_retries = 12; + optional uint64 stream_max_retries = 13; + optional uint64 stream_idle_timeout_ms = 14; + optional uint64 websocket_connect_timeout_ms = 15; + bool requires_openai_auth = 16; + bool supports_websockets = 17; +} + +message StringMap { + map values = 1; +} + +message ModelProviderAuthInfo { + string command = 1; + repeated string args = 2; + uint64 timeout_ms = 3; + uint64 refresh_interval_ms = 4; + string cwd = 5; +} + +enum WireApi { + WIRE_API_RESPONSES = 0; +} diff --git a/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.rs b/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.rs new file mode 100644 index 000000000000..4f607488bd2f --- /dev/null +++ b/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.rs @@ -0,0 +1,397 @@ +// This file is @generated by prost-build. +#![allow(clippy::trivially_copy_pass_by_ref)] + +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct LoadThreadConfigRequest { + #[prost(string, optional, tag = "1")] + pub thread_id: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "2")] + pub cwd: ::core::option::Option<::prost::alloc::string::String>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct LoadThreadConfigResponse { + #[prost(message, repeated, tag = "1")] + pub sources: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ThreadConfigSource { + #[prost(oneof = "thread_config_source::Source", tags = "1, 2")] + pub source: ::core::option::Option, +} +/// Nested message and enum types in `ThreadConfigSource`. +pub mod thread_config_source { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Source { + #[prost(message, tag = "1")] + Session(super::SessionThreadConfig), + #[prost(message, tag = "2")] + User(super::UserThreadConfig), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SessionThreadConfig { + #[prost(string, optional, tag = "1")] + pub model_provider: ::core::option::Option<::prost::alloc::string::String>, + #[prost(message, repeated, tag = "2")] + pub model_providers: ::prost::alloc::vec::Vec, + #[prost(map = "string, bool", tag = "3")] + pub features: ::std::collections::HashMap<::prost::alloc::string::String, bool>, +} +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] +pub struct UserThreadConfig {} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ModelProvider { + #[prost(string, tag = "1")] + pub id: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub name: ::prost::alloc::string::String, + #[prost(string, optional, tag = "3")] + pub base_url: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "4")] + pub env_key: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "5")] + pub env_key_instructions: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "6")] + pub experimental_bearer_token: ::core::option::Option<::prost::alloc::string::String>, + #[prost(message, optional, tag = "7")] + pub auth: ::core::option::Option, + #[prost(enumeration = "WireApi", tag = "8")] + pub wire_api: i32, + #[prost(message, optional, tag = "9")] + pub query_params: ::core::option::Option, + #[prost(message, optional, tag = "10")] + pub http_headers: ::core::option::Option, + #[prost(message, optional, tag = "11")] + pub env_http_headers: ::core::option::Option, + #[prost(uint64, optional, tag = "12")] + pub request_max_retries: ::core::option::Option, + #[prost(uint64, optional, tag = "13")] + pub stream_max_retries: ::core::option::Option, + #[prost(uint64, optional, tag = "14")] + pub stream_idle_timeout_ms: ::core::option::Option, + #[prost(uint64, optional, tag = "15")] + pub websocket_connect_timeout_ms: ::core::option::Option, + #[prost(bool, tag = "16")] + pub requires_openai_auth: bool, + #[prost(bool, tag = "17")] + pub supports_websockets: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct StringMap { + #[prost(map = "string, string", tag = "1")] + pub values: + ::std::collections::HashMap<::prost::alloc::string::String, ::prost::alloc::string::String>, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct ModelProviderAuthInfo { + #[prost(string, tag = "1")] + pub command: ::prost::alloc::string::String, + #[prost(string, repeated, tag = "2")] + pub args: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(uint64, tag = "3")] + pub timeout_ms: u64, + #[prost(uint64, tag = "4")] + pub refresh_interval_ms: u64, + #[prost(string, tag = "5")] + pub cwd: ::prost::alloc::string::String, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum WireApi { + Responses = 0, +} +impl WireApi { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Responses => "WIRE_API_RESPONSES", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "WIRE_API_RESPONSES" => Some(Self::Responses), + _ => None, + } + } +} +/// Generated client implementations. +pub mod thread_config_loader_client { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value + )] + use tonic::codegen::http::Uri; + use tonic::codegen::*; + #[derive(Debug, Clone)] + pub struct ThreadConfigLoaderClient { + inner: tonic::client::Grpc, + } + impl ThreadConfigLoaderClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl ThreadConfigLoaderClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> ThreadConfigLoaderClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + >>::Error: + Into + std::marker::Send + std::marker::Sync, + { + ThreadConfigLoaderClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + pub async fn load( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> + { + self.inner.ready().await.map_err(|e| { + tonic::Status::unknown(format!("Service was not ready: {}", e.into())) + })?; + let codec = tonic_prost::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/codex.thread_config.v1.ThreadConfigLoader/Load", + ); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "codex.thread_config.v1.ThreadConfigLoader", + "Load", + )); + self.inner.unary(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod thread_config_loader_server { + #![allow( + unused_variables, + dead_code, + missing_docs, + clippy::wildcard_imports, + clippy::let_unit_value + )] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with ThreadConfigLoaderServer. + #[async_trait] + pub trait ThreadConfigLoader: std::marker::Send + std::marker::Sync + 'static { + async fn load( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + } + #[derive(Debug)] + pub struct ThreadConfigLoaderServer { + inner: Arc, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + impl ThreadConfigLoaderServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for ThreadConfigLoaderServer + where + T: ThreadConfigLoader, + B: Body + std::marker::Send + 'static, + B::Error: Into + std::marker::Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + "/codex.thread_config.v1.ThreadConfigLoader/Load" => { + #[allow(non_camel_case_types)] + struct LoadSvc(pub Arc); + impl + tonic::server::UnaryService for LoadSvc + { + type Response = super::LoadThreadConfigResponse; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::load(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = LoadSvc(inner); + let codec = tonic_prost::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => Box::pin(async move { + let mut response = http::Response::new(tonic::body::Body::default()); + let headers = response.headers_mut(); + headers.insert( + tonic::Status::GRPC_STATUS, + (tonic::Code::Unimplemented as i32).into(), + ); + headers.insert( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ); + Ok(response) + }), + } + } + } + impl Clone for ThreadConfigLoaderServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + /// Generated gRPC service name + pub const SERVICE_NAME: &str = "codex.thread_config.v1.ThreadConfigLoader"; + impl tonic::server::NamedService for ThreadConfigLoaderServer { + const NAME: &'static str = SERVICE_NAME; + } +} diff --git a/codex-rs/config/src/thread_config/remote.rs b/codex-rs/config/src/thread_config/remote.rs new file mode 100644 index 000000000000..1a94cd813379 --- /dev/null +++ b/codex-rs/config/src/thread_config/remote.rs @@ -0,0 +1,481 @@ +use std::collections::BTreeMap; +use std::collections::HashMap; +use std::num::NonZeroU64; + +use async_trait::async_trait; +use codex_model_provider_info::ModelProviderInfo; +use codex_model_provider_info::WireApi; +use codex_protocol::config_types::ModelProviderAuthInfo; +use codex_utils_absolute_path::AbsolutePathBuf; + +use super::SessionThreadConfig; +use super::ThreadConfigContext; +use super::ThreadConfigLoadError; +use super::ThreadConfigLoadErrorCode; +use super::ThreadConfigLoader; +use super::ThreadConfigSource; +use super::UserThreadConfig; +use proto::thread_config_loader_client::ThreadConfigLoaderClient; + +#[path = "proto/codex.thread_config.v1.rs"] +mod proto; + +/// gRPC-backed [`ThreadConfigLoader`] implementation. +#[derive(Clone, Debug)] +pub struct RemoteThreadConfigLoader { + endpoint: String, +} + +impl RemoteThreadConfigLoader { + pub fn new(endpoint: impl Into) -> Self { + Self { + endpoint: endpoint.into(), + } + } + + async fn client( + &self, + ) -> Result, ThreadConfigLoadError> { + ThreadConfigLoaderClient::connect(self.endpoint.clone()) + .await + .map_err(|err| { + ThreadConfigLoadError::new( + ThreadConfigLoadErrorCode::RequestFailed, + None, + format!("failed to connect to remote thread config loader: {err}"), + ) + }) + } +} + +#[async_trait] +impl ThreadConfigLoader for RemoteThreadConfigLoader { + async fn load( + &self, + context: ThreadConfigContext, + ) -> Result, ThreadConfigLoadError> { + let request = proto::LoadThreadConfigRequest { + thread_id: context.thread_id, + cwd: context.cwd.map(|cwd| cwd.to_string_lossy().into_owned()), + }; + + let response = self + .client() + .await? + .load(request) + .await + .map_err(remote_status_to_error)? + .into_inner(); + + response + .sources + .into_iter() + .map(thread_config_source_from_proto) + .collect() + } +} + +fn remote_status_to_error(status: tonic::Status) -> ThreadConfigLoadError { + let code = match status.code() { + tonic::Code::Unauthenticated | tonic::Code::PermissionDenied => { + ThreadConfigLoadErrorCode::Auth + } + tonic::Code::DeadlineExceeded => ThreadConfigLoadErrorCode::Timeout, + tonic::Code::Ok + | tonic::Code::Cancelled + | tonic::Code::Unknown + | tonic::Code::InvalidArgument + | tonic::Code::NotFound + | tonic::Code::AlreadyExists + | tonic::Code::ResourceExhausted + | tonic::Code::FailedPrecondition + | tonic::Code::Aborted + | tonic::Code::OutOfRange + | tonic::Code::Unimplemented + | tonic::Code::Internal + | tonic::Code::Unavailable + | tonic::Code::DataLoss => ThreadConfigLoadErrorCode::RequestFailed, + }; + ThreadConfigLoadError::new( + code, + None, + format!("remote thread config request failed: {status}"), + ) +} + +fn thread_config_source_from_proto( + source: proto::ThreadConfigSource, +) -> Result { + match source.source { + Some(proto::thread_config_source::Source::Session(config)) => { + session_thread_config_from_proto(config).map(ThreadConfigSource::Session) + } + Some(proto::thread_config_source::Source::User(_)) => { + Ok(ThreadConfigSource::User(UserThreadConfig::default())) + } + None => Err(parse_error("remote thread config omitted source payload")), + } +} + +fn session_thread_config_from_proto( + config: proto::SessionThreadConfig, +) -> Result { + let model_providers = config + .model_providers + .into_iter() + .map(model_provider_from_proto) + .collect::, _>>()?; + + Ok(SessionThreadConfig { + model_provider: config.model_provider, + model_providers, + features: config.features.into_iter().collect::>(), + }) +} + +fn model_provider_from_proto( + provider: proto::ModelProvider, +) -> Result<(String, ModelProviderInfo), ThreadConfigLoadError> { + if provider.id.is_empty() { + return Err(parse_error( + "remote thread config returned model provider without an id", + )); + } + let id = provider.id; + let wire_api = match proto::WireApi::try_from(provider.wire_api) { + Ok(proto::WireApi::Responses) => WireApi::Responses, + Err(_) => { + return Err(parse_error(format!( + "remote thread config returned unknown wire_api: {}", + provider.wire_api + ))); + } + }; + let info = ModelProviderInfo { + name: provider.name, + base_url: provider.base_url, + env_key: provider.env_key, + env_key_instructions: provider.env_key_instructions, + experimental_bearer_token: provider.experimental_bearer_token, + auth: provider + .auth + .map(model_provider_auth_from_proto) + .transpose()?, + wire_api, + query_params: provider.query_params.map(|map| map.values), + http_headers: provider.http_headers.map(|map| map.values), + env_http_headers: provider.env_http_headers.map(|map| map.values), + request_max_retries: provider.request_max_retries, + stream_max_retries: provider.stream_max_retries, + stream_idle_timeout_ms: provider.stream_idle_timeout_ms, + websocket_connect_timeout_ms: provider.websocket_connect_timeout_ms, + requires_openai_auth: provider.requires_openai_auth, + supports_websockets: provider.supports_websockets, + }; + Ok((id, info)) +} + +#[cfg(test)] +fn model_provider_to_proto( + id: impl Into, + provider: ModelProviderInfo, +) -> proto::ModelProvider { + let ModelProviderInfo { + name, + base_url, + env_key, + env_key_instructions, + experimental_bearer_token, + auth, + wire_api, + query_params, + http_headers, + env_http_headers, + request_max_retries, + stream_max_retries, + stream_idle_timeout_ms, + websocket_connect_timeout_ms, + requires_openai_auth, + supports_websockets, + } = provider; + + proto::ModelProvider { + id: id.into(), + name, + base_url, + env_key, + env_key_instructions, + experimental_bearer_token, + auth: auth.map(model_provider_auth_to_proto), + wire_api: proto_wire_api(wire_api).into(), + query_params: query_params.map(proto_string_map), + http_headers: http_headers.map(proto_string_map), + env_http_headers: env_http_headers.map(proto_string_map), + request_max_retries, + stream_max_retries, + stream_idle_timeout_ms, + websocket_connect_timeout_ms, + requires_openai_auth, + supports_websockets, + } +} + +fn model_provider_auth_from_proto( + auth: proto::ModelProviderAuthInfo, +) -> Result { + let timeout_ms = NonZeroU64::new(auth.timeout_ms) + .ok_or_else(|| parse_error("remote thread config returned zero auth timeout_ms"))?; + let cwd = AbsolutePathBuf::from_absolute_path_checked(&auth.cwd).map_err(|err| { + parse_error(format!( + "remote thread config returned invalid auth cwd {:?}: {err}", + auth.cwd + )) + })?; + + Ok(ModelProviderAuthInfo { + command: auth.command, + args: auth.args, + timeout_ms, + refresh_interval_ms: auth.refresh_interval_ms, + cwd, + }) +} + +#[cfg(test)] +fn model_provider_auth_to_proto(auth: ModelProviderAuthInfo) -> proto::ModelProviderAuthInfo { + let ModelProviderAuthInfo { + command, + args, + timeout_ms, + refresh_interval_ms, + cwd, + } = auth; + + proto::ModelProviderAuthInfo { + command, + args, + timeout_ms: timeout_ms.get(), + refresh_interval_ms, + cwd: cwd.to_string_lossy().into_owned(), + } +} + +#[cfg(test)] +fn proto_string_map(values: HashMap) -> proto::StringMap { + proto::StringMap { values } +} + +#[cfg(test)] +fn proto_wire_api(wire_api: WireApi) -> proto::WireApi { + match wire_api { + WireApi::Responses => proto::WireApi::Responses, + } +} + +fn parse_error(message: impl Into) -> ThreadConfigLoadError { + ThreadConfigLoadError::new(ThreadConfigLoadErrorCode::Parse, None, message.into()) +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + use std::collections::HashMap; + use std::num::NonZeroU64; + + use codex_model_provider_info::ModelProviderInfo; + use codex_model_provider_info::WireApi; + use codex_protocol::config_types::ModelProviderAuthInfo; + use codex_utils_absolute_path::AbsolutePathBuf; + use pretty_assertions::assert_eq; + use tonic::Request; + use tonic::Response; + use tonic::Status; + use tonic::transport::Server; + + use super::proto::thread_config_loader_server; + use super::proto::thread_config_loader_server::ThreadConfigLoaderServer; + use super::*; + use crate::SessionThreadConfig; + use crate::UserThreadConfig; + + struct TestServer { + sources: Vec, + } + + #[tonic::async_trait] + impl thread_config_loader_server::ThreadConfigLoader for TestServer { + async fn load( + &self, + request: Request, + ) -> Result, Status> { + assert_eq!( + request.into_inner(), + proto::LoadThreadConfigRequest { + thread_id: Some("thread-1".to_string()), + cwd: Some("/workspace/project".to_string()), + } + ); + + Ok(Response::new(proto::LoadThreadConfigResponse { + sources: self.sources.clone(), + })) + } + } + + #[tokio::test] + async fn load_thread_config_calls_remote_service() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test server"); + let addr = listener.local_addr().expect("test server addr"); + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let server = tokio::spawn(async move { + Server::builder() + .add_service(ThreadConfigLoaderServer::new(TestServer { + sources: proto_sources(), + })) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + async { + let _ = shutdown_rx.await; + }, + ) + .await + }); + + let loader = RemoteThreadConfigLoader::new(format!("http://{addr}")); + let loaded = loader + .load(ThreadConfigContext { + thread_id: Some("thread-1".to_string()), + cwd: Some("/workspace/project".into()), + }) + .await; + + let _ = shutdown_tx.send(()); + server.await.expect("join server").expect("server"); + + assert_eq!(loaded.expect("load thread config"), expected_sources()); + } + + #[test] + fn model_provider_proto_roundtrips_through_domain_type() { + let expected = expected_provider(); + let proto = model_provider_to_proto("local", expected.clone()); + let (id, actual) = model_provider_from_proto(proto).expect("model provider from proto"); + + assert_eq!(id, "local"); + assert_eq!(actual, expected); + } + + fn proto_sources() -> Vec { + vec![ + proto::ThreadConfigSource { + source: Some(proto::thread_config_source::Source::Session( + proto::SessionThreadConfig { + model_provider: Some("local".to_string()), + model_providers: vec![proto::ModelProvider { + id: "local".to_string(), + name: "Local".to_string(), + base_url: Some("http://127.0.0.1:8061/api/codex".to_string()), + env_key: None, + env_key_instructions: None, + experimental_bearer_token: None, + auth: Some(proto::ModelProviderAuthInfo { + command: "token-helper".to_string(), + args: vec!["--json".to_string()], + timeout_ms: 5_000, + refresh_interval_ms: 300_000, + cwd: "/workspace".to_string(), + }), + wire_api: proto::WireApi::Responses.into(), + query_params: Some(proto::StringMap { + values: HashMap::from([( + "api-version".to_string(), + "2026-04-16".to_string(), + )]), + }), + http_headers: Some(proto::StringMap { + values: HashMap::from([( + "X-Test".to_string(), + "enabled".to_string(), + )]), + }), + env_http_headers: Some(proto::StringMap { + values: HashMap::from([( + "X-Env".to_string(), + "LOCAL_HEADER".to_string(), + )]), + }), + request_max_retries: Some(7), + stream_max_retries: Some(8), + stream_idle_timeout_ms: Some(9_000), + websocket_connect_timeout_ms: Some(10_000), + requires_openai_auth: false, + supports_websockets: true, + }], + features: HashMap::from([ + ("plugins".to_string(), false), + ("tools".to_string(), true), + ]), + }, + )), + }, + proto::ThreadConfigSource { + source: Some(proto::thread_config_source::Source::User( + proto::UserThreadConfig {}, + )), + }, + ] + } + + fn expected_sources() -> Vec { + vec![ + ThreadConfigSource::Session(SessionThreadConfig { + model_provider: Some("local".to_string()), + model_providers: HashMap::from([("local".to_string(), expected_provider())]), + features: BTreeMap::from([ + ("plugins".to_string(), false), + ("tools".to_string(), true), + ]), + }), + ThreadConfigSource::User(UserThreadConfig::default()), + ] + } + + fn expected_provider() -> ModelProviderInfo { + ModelProviderInfo { + name: "Local".to_string(), + base_url: Some("http://127.0.0.1:8061/api/codex".to_string()), + env_key: None, + env_key_instructions: None, + experimental_bearer_token: None, + auth: Some(ModelProviderAuthInfo { + command: "token-helper".to_string(), + args: vec!["--json".to_string()], + timeout_ms: NonZeroU64::new(5_000).expect("non-zero timeout"), + refresh_interval_ms: 300_000, + cwd: AbsolutePathBuf::from_absolute_path_checked("/workspace") + .expect("absolute cwd"), + }), + wire_api: WireApi::Responses, + query_params: Some(HashMap::from([( + "api-version".to_string(), + "2026-04-16".to_string(), + )])), + http_headers: Some(HashMap::from([( + "X-Test".to_string(), + "enabled".to_string(), + )])), + env_http_headers: Some(HashMap::from([( + "X-Env".to_string(), + "LOCAL_HEADER".to_string(), + )])), + request_max_retries: Some(7), + stream_max_retries: Some(8), + stream_idle_timeout_ms: Some(9_000), + websocket_connect_timeout_ms: Some(10_000), + requires_openai_auth: false, + supports_websockets: true, + } + } +} From 02ad45bf677ea9cf0e4a24dede34262f39042cb4 Mon Sep 17 00:00:00 2001 From: Rasmus Rygaard Date: Tue, 21 Apr 2026 14:29:26 -0700 Subject: [PATCH 2/6] rebase --- codex-rs/config/src/thread_config/remote.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/codex-rs/config/src/thread_config/remote.rs b/codex-rs/config/src/thread_config/remote.rs index 1a94cd813379..c65664a81a10 100644 --- a/codex-rs/config/src/thread_config/remote.rs +++ b/codex-rs/config/src/thread_config/remote.rs @@ -161,6 +161,7 @@ fn model_provider_from_proto( .auth .map(model_provider_auth_from_proto) .transpose()?, + aws: None, wire_api, query_params: provider.query_params.map(|map| map.values), http_headers: provider.http_headers.map(|map| map.values), @@ -187,6 +188,7 @@ fn model_provider_to_proto( env_key_instructions, experimental_bearer_token, auth, + aws: _, wire_api, query_params, http_headers, @@ -347,7 +349,10 @@ mod tests { let loaded = loader .load(ThreadConfigContext { thread_id: Some("thread-1".to_string()), - cwd: Some("/workspace/project".into()), + cwd: Some( + AbsolutePathBuf::from_absolute_path_checked("/workspace/project") + .expect("absolute cwd"), + ), }) .await; @@ -476,6 +481,7 @@ mod tests { websocket_connect_timeout_ms: Some(10_000), requires_openai_auth: false, supports_websockets: true, + aws: None, } } } From c739ccc7e300e7a1be677c8e0f3d451a8a575dab Mon Sep 17 00:00:00 2001 From: Rasmus Rygaard Date: Tue, 21 Apr 2026 14:37:16 -0700 Subject: [PATCH 3/6] fix enum default --- .../src/thread_config/proto/codex.thread_config.v1.proto | 3 ++- .../config/src/thread_config/proto/codex.thread_config.v1.rs | 5 ++++- codex-rs/config/src/thread_config/remote.rs | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.proto b/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.proto index acc8d46e1890..1efccfd1bfb5 100644 --- a/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.proto +++ b/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.proto @@ -63,5 +63,6 @@ message ModelProviderAuthInfo { } enum WireApi { - WIRE_API_RESPONSES = 0; + WIRE_API_UNSPECIFIED = 0; + WIRE_API_RESPONSES = 1; } diff --git a/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.rs b/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.rs index 4f607488bd2f..30a76bc6b2f4 100644 --- a/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.rs +++ b/codex-rs/config/src/thread_config/proto/codex.thread_config.v1.rs @@ -98,7 +98,8 @@ pub struct ModelProviderAuthInfo { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WireApi { - Responses = 0, + Unspecified = 0, + Responses = 1, } impl WireApi { /// String value of the enum field names used in the ProtoBuf definition. @@ -107,12 +108,14 @@ impl WireApi { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { + Self::Unspecified => "WIRE_API_UNSPECIFIED", Self::Responses => "WIRE_API_RESPONSES", } } /// Creates an enum from field names used in the ProtoBuf definition. pub fn from_str_name(value: &str) -> ::core::option::Option { match value { + "WIRE_API_UNSPECIFIED" => Some(Self::Unspecified), "WIRE_API_RESPONSES" => Some(Self::Responses), _ => None, } diff --git a/codex-rs/config/src/thread_config/remote.rs b/codex-rs/config/src/thread_config/remote.rs index c65664a81a10..fbf84fc79b41 100644 --- a/codex-rs/config/src/thread_config/remote.rs +++ b/codex-rs/config/src/thread_config/remote.rs @@ -144,6 +144,9 @@ fn model_provider_from_proto( let id = provider.id; let wire_api = match proto::WireApi::try_from(provider.wire_api) { Ok(proto::WireApi::Responses) => WireApi::Responses, + Ok(proto::WireApi::Unspecified) => { + return Err(parse_error("remote thread config omitted wire_api")); + } Err(_) => { return Err(parse_error(format!( "remote thread config returned unknown wire_api: {}", From 5a05aff8bf059441105c518eb0725984c7f9dc8b Mon Sep 17 00:00:00 2001 From: Rasmus Rygaard Date: Tue, 21 Apr 2026 14:50:48 -0700 Subject: [PATCH 4/6] Fix thread config argument comments --- codex-rs/config/src/thread_config/remote.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/codex-rs/config/src/thread_config/remote.rs b/codex-rs/config/src/thread_config/remote.rs index fbf84fc79b41..71d678e063a0 100644 --- a/codex-rs/config/src/thread_config/remote.rs +++ b/codex-rs/config/src/thread_config/remote.rs @@ -41,7 +41,7 @@ impl RemoteThreadConfigLoader { .map_err(|err| { ThreadConfigLoadError::new( ThreadConfigLoadErrorCode::RequestFailed, - None, + /*status_code*/ None, format!("failed to connect to remote thread config loader: {err}"), ) }) @@ -98,7 +98,7 @@ fn remote_status_to_error(status: tonic::Status) -> ThreadConfigLoadError { }; ThreadConfigLoadError::new( code, - None, + /*status_code*/ None, format!("remote thread config request failed: {status}"), ) } @@ -278,7 +278,11 @@ fn proto_wire_api(wire_api: WireApi) -> proto::WireApi { } fn parse_error(message: impl Into) -> ThreadConfigLoadError { - ThreadConfigLoadError::new(ThreadConfigLoadErrorCode::Parse, None, message.into()) + ThreadConfigLoadError::new( + ThreadConfigLoadErrorCode::Parse, + /*status_code*/ None, + message.into(), + ) } #[cfg(test)] From 26894c0783d269fd847c98fcc4620631701daf5a Mon Sep 17 00:00:00 2001 From: Rasmus Rygaard Date: Tue, 21 Apr 2026 15:26:01 -0700 Subject: [PATCH 5/6] Use platform absolute paths in thread config tests --- codex-rs/config/src/thread_config/remote.rs | 23 ++++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/codex-rs/config/src/thread_config/remote.rs b/codex-rs/config/src/thread_config/remote.rs index 71d678e063a0..7c584b33bf27 100644 --- a/codex-rs/config/src/thread_config/remote.rs +++ b/codex-rs/config/src/thread_config/remote.rs @@ -309,6 +309,7 @@ mod tests { struct TestServer { sources: Vec, + expected_cwd: String, } #[tonic::async_trait] @@ -321,7 +322,7 @@ mod tests { request.into_inner(), proto::LoadThreadConfigRequest { thread_id: Some("thread-1".to_string()), - cwd: Some("/workspace/project".to_string()), + cwd: Some(self.expected_cwd.clone()), } ); @@ -333,6 +334,8 @@ mod tests { #[tokio::test] async fn load_thread_config_calls_remote_service() { + let cwd = workspace_dir().join("project"); + let expected_cwd = cwd.to_string_lossy().into_owned(); let listener = tokio::net::TcpListener::bind("127.0.0.1:0") .await .expect("bind test server"); @@ -342,6 +345,7 @@ mod tests { Server::builder() .add_service(ThreadConfigLoaderServer::new(TestServer { sources: proto_sources(), + expected_cwd, })) .serve_with_incoming_shutdown( tokio_stream::wrappers::TcpListenerStream::new(listener), @@ -356,10 +360,7 @@ mod tests { let loaded = loader .load(ThreadConfigContext { thread_id: Some("thread-1".to_string()), - cwd: Some( - AbsolutePathBuf::from_absolute_path_checked("/workspace/project") - .expect("absolute cwd"), - ), + cwd: Some(cwd), }) .await; @@ -380,6 +381,7 @@ mod tests { } fn proto_sources() -> Vec { + let workspace_cwd = workspace_dir().to_string_lossy().into_owned(); vec![ proto::ThreadConfigSource { source: Some(proto::thread_config_source::Source::Session( @@ -397,7 +399,7 @@ mod tests { args: vec!["--json".to_string()], timeout_ms: 5_000, refresh_interval_ms: 300_000, - cwd: "/workspace".to_string(), + cwd: workspace_cwd, }), wire_api: proto::WireApi::Responses.into(), query_params: Some(proto::StringMap { @@ -466,8 +468,7 @@ mod tests { args: vec!["--json".to_string()], timeout_ms: NonZeroU64::new(5_000).expect("non-zero timeout"), refresh_interval_ms: 300_000, - cwd: AbsolutePathBuf::from_absolute_path_checked("/workspace") - .expect("absolute cwd"), + cwd: workspace_dir(), }), wire_api: WireApi::Responses, query_params: Some(HashMap::from([( @@ -491,4 +492,10 @@ mod tests { aws: None, } } + + fn workspace_dir() -> AbsolutePathBuf { + AbsolutePathBuf::current_dir() + .expect("current dir") + .join("workspace") + } } From 77a8e290ba066893fbcae1ef024767e47e91052a Mon Sep 17 00:00:00 2001 From: Rasmus Rygaard Date: Tue, 21 Apr 2026 16:57:35 -0700 Subject: [PATCH 6/6] set timeout --- codex-rs/config/src/thread_config/remote.rs | 34 +++++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/codex-rs/config/src/thread_config/remote.rs b/codex-rs/config/src/thread_config/remote.rs index 7c584b33bf27..7b7feacec5ec 100644 --- a/codex-rs/config/src/thread_config/remote.rs +++ b/codex-rs/config/src/thread_config/remote.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::num::NonZeroU64; +use std::time::Duration; use async_trait::async_trait; use codex_model_provider_info::ModelProviderInfo; @@ -20,6 +21,8 @@ use proto::thread_config_loader_client::ThreadConfigLoaderClient; #[path = "proto/codex.thread_config.v1.rs"] mod proto; +const REMOTE_THREAD_CONFIG_LOAD_TIMEOUT: Duration = Duration::from_secs(5); + /// gRPC-backed [`ThreadConfigLoader`] implementation. #[derive(Clone, Debug)] pub struct RemoteThreadConfigLoader { @@ -54,15 +57,10 @@ impl ThreadConfigLoader for RemoteThreadConfigLoader { &self, context: ThreadConfigContext, ) -> Result, ThreadConfigLoadError> { - let request = proto::LoadThreadConfigRequest { - thread_id: context.thread_id, - cwd: context.cwd.map(|cwd| cwd.to_string_lossy().into_owned()), - }; - let response = self .client() .await? - .load(request) + .load(load_thread_config_request(context)) .await .map_err(remote_status_to_error)? .into_inner(); @@ -75,6 +73,17 @@ impl ThreadConfigLoader for RemoteThreadConfigLoader { } } +fn load_thread_config_request( + context: ThreadConfigContext, +) -> tonic::Request { + let mut request = tonic::Request::new(proto::LoadThreadConfigRequest { + thread_id: context.thread_id, + cwd: context.cwd.map(|cwd| cwd.to_string_lossy().into_owned()), + }); + request.set_timeout(REMOTE_THREAD_CONFIG_LOAD_TIMEOUT); + request +} + fn remote_status_to_error(status: tonic::Status) -> ThreadConfigLoadError { let code = match status.code() { tonic::Code::Unauthenticated | tonic::Code::PermissionDenied => { @@ -370,6 +379,19 @@ mod tests { assert_eq!(loaded.expect("load thread config"), expected_sources()); } + #[test] + fn load_thread_config_request_sets_timeout() { + let request = load_thread_config_request(ThreadConfigContext::default()); + + assert_eq!( + request + .metadata() + .get("grpc-timeout") + .and_then(|value| value.to_str().ok()), + Some("5000000u") + ); + } + #[test] fn model_provider_proto_roundtrips_through_domain_type() { let expected = expected_provider();